Do not repeat strings, use constants. Also, encapsulate MSG constants.

This commit is contained in:
Andris Raugulis 2016-09-08 14:55:58 +03:00
parent 71a18e153c
commit dbcc0f2c4f

View File

@ -204,11 +204,16 @@ class WriteBuf(object):
class SSH(object): class SSH(object):
class Protocol(object):
MSG_KEXINIT = 20 MSG_KEXINIT = 20
MSG_NEWKEYS = 21 MSG_NEWKEYS = 21
MSG_KEXDH_INIT = 30 MSG_KEXDH_INIT = 30
MSG_KEXDH_REPLY = 32 MSG_KEXDH_REPLY = 32
class Product(object):
OpenSSH = 'OpenSSH'
DropbearSSH = 'Dropbear SSH'
class Software(object): class Software(object):
def __init__(self, vendor, product, version, patch, os): def __init__(self, vendor, product, version, patch, os):
self.__vendor = vendor self.__vendor = vendor
@ -243,7 +248,7 @@ class SSH(object):
if self.version: if self.version:
out += ' {0}'.format(self.version) out += ' {0}'.format(self.version)
patch = self.patch patch = self.patch
if self.product == 'OpenSSH': if self.product == SSH.Product.OpenSSH:
mx = re.match('^(p\d)(.*)$', self.patch) mx = re.match('^(p\d)(.*)$', self.patch)
if mx is not None: if mx is not None:
out += mx.group(1) out += mx.group(1)
@ -309,13 +314,13 @@ class SSH(object):
mx = re.match(r'^dropbear_(\d+.\d+)(.*)', software) mx = re.match(r'^dropbear_(\d+.\d+)(.*)', software)
if mx: if mx:
patch = cls._fix_patch(mx.group(2)) patch = cls._fix_patch(mx.group(2))
v, p = 'Matt Johnston', 'Dropbear SSH' v, p = 'Matt Johnston', SSH.Product.DropbearSSH
v = None v = None
return cls(v, p, mx.group(1), patch, None) return cls(v, p, mx.group(1), patch, None)
mx = re.match(r'^OpenSSH[_\.-]+([\d\.]+\d+)(.*)', software) mx = re.match(r'^OpenSSH[_\.-]+([\d\.]+\d+)(.*)', software)
if mx: if mx:
patch = cls._fix_patch(mx.group(2)) patch = cls._fix_patch(mx.group(2))
v, p = 'OpenBSD', 'OpenSSH' v, p = 'OpenBSD', SSH.Product.OpenSSH
v = None v = None
os = cls._extract_os(banner.comments) os = cls._extract_os(banner.comments)
return cls(v, p, mx.group(1), patch, os) return cls(v, p, mx.group(1), patch, os)
@ -524,7 +529,7 @@ class KexDH(object):
r = random.SystemRandom() r = random.SystemRandom()
self.__x = r.randrange(2, self.__q) self.__x = r.randrange(2, self.__q)
self.__e = pow(self.__g, self.__x, self.__p) self.__e = pow(self.__g, self.__x, self.__p)
s.write_byte(SSH.MSG_KEXDH_INIT) s.write_byte(SSH.Protocol.MSG_KEXDH_INIT)
s.write_mpint(self.__e) s.write_mpint(self.__e)
s.send_packet() s.send_packet()
@ -668,9 +673,9 @@ class KexDB(object):
def get_ssh_version(version_desc): def get_ssh_version(version_desc):
if version_desc.startswith('d'): if version_desc.startswith('d'):
return ('Dropbear SSH', version_desc[1:]) return (SSH.Product.DropbearSSH, version_desc[1:])
else: else:
return ('OpenSSH', version_desc) return (SSH.Product.OpenSSH, version_desc)
def get_alg_timeframe(alg_desc, result={}): def get_alg_timeframe(alg_desc, result={}):
@ -776,7 +781,7 @@ def output_compatibility(kex, client=False):
ssh_timeframe = get_ssh_timeframe(kex) ssh_timeframe = get_ssh_timeframe(kex)
cp = 2 if client else 1 cp = 2 if client else 1
comp_text = [] comp_text = []
for sshd_name in ['OpenSSH', 'Dropbear SSH']: for sshd_name in [SSH.Product.OpenSSH, SSH.Product.DropbearSSH]:
if sshd_name not in ssh_timeframe: if sshd_name not in ssh_timeframe:
continue continue
v = ssh_timeframe[sshd_name] v = ssh_timeframe[sshd_name]
@ -885,7 +890,7 @@ def main():
packet_type, payload = s.read_packet() packet_type, payload = s.read_packet()
if packet_type < 0: if packet_type < 0:
err = '[exception] error reading packet ({0})'.format(payload) err = '[exception] error reading packet ({0})'.format(payload)
elif packet_type != SSH.MSG_KEXINIT: elif packet_type != SSH.Protocol.MSG_KEXINIT:
err = '[exception] did not receive MSG_KEXINIT (20), ' + \ err = '[exception] did not receive MSG_KEXINIT (20), ' + \
'instead received unknown message ({0})'.format(packet_type) 'instead received unknown message ({0})'.format(packet_type)
if err: if err: