From fabb4b5bb2cf5231b4d2bfcf6dd51a3069e4d7fc Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Wed, 19 Oct 2016 20:47:13 +0300 Subject: [PATCH] Add static typing and refactor code to pass all mypy checks. Move Python compatibility types to first lines of code. Add Python (text/byte) compatibility helper functions. Check for SSH banner ASCII validity. --- ssh-audit.py | 372 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 282 insertions(+), 90 deletions(-) diff --git a/ssh-audit.py b/ssh-audit.py index 13f76b9..7a27514 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -25,15 +25,26 @@ """ from __future__ import print_function import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 -try: - from typing import List, Tuple, Text -except: - pass VERSION = 'v1.6.1.dev' +if sys.version_info >= (3,): + StringIO, BytesIO = io.StringIO, io.BytesIO + text_type = str + binary_type = bytes +else: + import StringIO as _StringIO + StringIO = BytesIO = _StringIO.StringIO + text_type = unicode + binary_type = str +try: + from typing import List, Tuple, Optional, Callable, Union, Any +except: + pass + def usage(err=None): + # type: (Optional[str]) -> None out = Output() p = os.path.basename(sys.argv[0]) out.head('# {0} {1}, moo@arthepsy.eu'.format(p, VERSION)) @@ -53,6 +64,7 @@ def usage(err=None): class AuditConf(object): def __init__(self, host=None, port=22): + # type: (Optional[str], int) -> None self.host = host self.port = port self.ssh1 = True @@ -63,6 +75,7 @@ class AuditConf(object): self.minlevel = 'info' def __setattr__(self, name, value): + # type: (str, Union[str, int, bool]) -> None valid = False if name in ['ssh1', 'ssh2', 'batch', 'colors', 'verbose']: valid, value = True, True if value else False @@ -82,6 +95,7 @@ class AuditConf(object): @classmethod def from_cmdline(cls, args, usage_cb): + # type: (List[str], Callable[..., None]) -> AuditConf conf = cls() try: sopts = 'h12bnvl:' @@ -131,6 +145,7 @@ class Output(object): COLORS = {'head': 36, 'good': 32, 'warn': 33, 'fail': 31} def __init__(self): + # type: () -> None self.batch = False self.colors = True self.verbose = False @@ -138,34 +153,40 @@ class Output(object): @property def minlevel(self): + # type: () -> str if self.__minlevel < len(self.LEVELS): return self.LEVELS[self.__minlevel] return 'unknown' @minlevel.setter def minlevel(self, name): + # type: (str) -> None self.__minlevel = self.getlevel(name) def getlevel(self, name): + # type: (str) -> int cname = 'info' if name == 'good' else name if cname not in self.LEVELS: return sys.maxsize return self.LEVELS.index(cname) def sep(self): + # type: () -> None if not self.batch: print() def _colorized(self, color): + # type: (str) -> Callable[[text_type], None] return lambda x: print(u'{0}{1}\033[0m'.format(color, x)) def __getattr__(self, name): + # type: (str) -> Callable[[text_type], None] if name == 'head' and self.batch: return lambda x: None if not self.getlevel(name) >= self.__minlevel: return lambda x: None if self.colors and os.name == 'posix' and name in self.COLORS: - color = u'\033[0;{0}m'.format(self.COLORS[name]) + color = '\033[0;{0}m'.format(self.COLORS[name]) return self._colorized(color) else: return lambda x: print(u'{0}'.format(x)) @@ -173,16 +194,19 @@ class Output(object): class OutputBuffer(list): def __enter__(self): - self.__buf = utils.StringIO() + # type: () -> OutputBuffer + self.__buf = StringIO() self.__stdout = sys.stdout sys.stdout = self.__buf return self def flush(self): + # type: () -> None for line in self: print(line) def __exit__(self, *args): + # type: (*Any) -> None self.extend(self.__buf.getvalue().splitlines()) sys.stdout = self.__stdout @@ -190,6 +214,7 @@ class OutputBuffer(list): class SSH2(object): class KexParty(object): def __init__(self, enc, mac, compression, languages): + # type: (List[text_type], List[text_type], List[text_type], List[text_type]) -> None self.__enc = enc self.__mac = mac self.__compression = compression @@ -197,22 +222,27 @@ class SSH2(object): @property def encryption(self): + # type: () -> List[text_type] return self.__enc @property def mac(self): + # type: () -> List[text_type] return self.__mac @property def compression(self): + # type: () -> List[text_type] return self.__compression @property def languages(self): + # type: () -> List[text_type] return self.__languages class Kex(object): def __init__(self, cookie, kex_algs, key_algs, cli, srv, follows, unused=0): + # type: (binary_type, List[text_type], List[text_type], SSH2.KexParty, SSH2.KexParty, bool, int) -> None self.__cookie = cookie self.__kex_algs = kex_algs self.__key_algs = key_algs @@ -223,35 +253,43 @@ class SSH2(object): @property def cookie(self): + # type: () -> binary_type return self.__cookie @property def kex_algorithms(self): + # type: () -> List[text_type] return self.__kex_algs @property def key_algorithms(self): + # type: () -> List[text_type] return self.__key_algs # client_to_server @property def client(self): + # type: () -> SSH2.KexParty return self.__client # server_to_client @property def server(self): + # type: () -> SSH2.KexParty return self.__server @property def follows(self): + # type: () -> bool return self.__follows @property def unused(self): + # type: () -> int return self.__unused def write(self, wbuf): + # type: (WriteBuf) -> None wbuf.write(self.cookie) wbuf.write_list(self.kex_algorithms) wbuf.write_list(self.key_algorithms) @@ -268,12 +306,14 @@ class SSH2(object): @property def payload(self): + # type: () -> binary_type wbuf = WriteBuf() self.write(wbuf) return wbuf.write_flush() @classmethod def parse(cls, payload): + # type: (binary_type) -> SSH2.Kex buf = ReadBuf(payload) cookie = buf.read(16) kex_algs = buf.read_list() @@ -297,6 +337,7 @@ class SSH2(object): class SSH1(object): class CRC32(object): def __init__(self): + # type: () -> None self._table = [0] * 256 for i in range(256): crc = 0 @@ -308,6 +349,7 @@ class SSH1(object): self._table[i] = crc def calc(self, v): + # type: (binary_type) -> int crc, l = 0, len(v) for i in range(l): n = ord(v[i:i + 1]) @@ -315,12 +357,13 @@ class SSH1(object): crc = (crc >> 8) ^ self._table[n] return crc - _crc32 = None + _crc32 = None # type: Optional[SSH1.CRC32] CIPHERS = ['none', 'idea', 'des', '3des', 'tss', 'rc4', 'blowfish'] AUTHS = [None, 'rhosts', 'rsa', 'password', 'rhosts_rsa', 'tis', 'kerberos'] @classmethod def crc32(cls, v): + # type: (binary_type) -> int if cls._crc32 is None: cls._crc32 = cls.CRC32() return cls._crc32.calc(v) @@ -353,10 +396,11 @@ class SSH1(object): 'tis': [['1.2.2']], 'kerberos': [['1.2.2', '3.6'], [FAIL_OPENSSH37_REMOVE]], } - } + } # type: Dict[str, Dict[str, List[List[str]]]] class PublicKeyMessage(object): def __init__(self, cookie, skey, hkey, pflags, cmask, amask): + # type: (binary_type, Tuple[int, int, int], Tuple[int, int, int], int, int, int) -> None assert len(skey) == 3 assert len(hkey) == 3 self.__cookie = cookie @@ -368,67 +412,81 @@ class SSH1(object): @property def cookie(self): + # type: () -> binary_type return self.__cookie @property def server_key_bits(self): + # type: () -> int return self.__server_key[0] @property def server_key_public_exponent(self): + # type: () -> int return self.__server_key[1] @property def server_key_public_modulus(self): + # type: () -> int return self.__server_key[2] @property def host_key_bits(self): + # type: () -> int return self.__host_key[0] @property def host_key_public_exponent(self): + # type: () -> int return self.__host_key[1] @property def host_key_public_modulus(self): + # type: () -> int return self.__host_key[2] @property def host_key_fingerprint_data(self): + # type: () -> binary_type mod = WriteBuf._create_mpint(self.host_key_public_modulus, False) e = WriteBuf._create_mpint(self.host_key_public_exponent, False) return mod + e @property def protocol_flags(self): + # type: () -> int return self.__protocol_flags @property def supported_ciphers_mask(self): + # type: () -> int return self.__supported_ciphers_mask @property def supported_ciphers(self): + # type: () -> List[text_type] ciphers = [] for i in range(len(SSH1.CIPHERS)): if self.__supported_ciphers_mask & (1 << i) != 0: - ciphers.append(SSH1.CIPHERS[i]) + ciphers.append(utils.to_utext(SSH1.CIPHERS[i])) return ciphers @property def supported_authentications_mask(self): + # type: () -> int return self.__supported_authentications_mask @property def supported_authentications(self): + # type: () -> List[text_type] auths = [] for i in range(1, len(SSH1.AUTHS)): if self.__supported_authentications_mask & (1 << i) != 0: - auths.append(SSH1.AUTHS[i]) + auths.append(utils.to_utext(SSH1.AUTHS[i])) return auths def write(self, wbuf): + # type: (WriteBuf) -> None wbuf.write(self.cookie) wbuf.write_int(self.server_key_bits) wbuf.write_mpint1(self.server_key_public_exponent) @@ -442,12 +500,14 @@ class SSH1(object): @property def payload(self): + # type: () -> binary_type wbuf = WriteBuf() self.write(wbuf) return wbuf.write_flush() @classmethod def parse(cls, payload): + # type: (binary_type) -> SSH1.PublicKeyMessage buf = ReadBuf(payload) cookie = buf.read(8) server_key_bits = buf.read_int() @@ -467,36 +527,45 @@ class SSH1(object): class ReadBuf(object): def __init__(self, data=None): + # type: (Optional[binary_type]) -> None super(ReadBuf, self).__init__() - self._buf = utils.BytesIO(data) if data else utils.BytesIO() + self._buf = BytesIO(data) if data else BytesIO() self._len = len(data) if data else 0 @property def unread_len(self): + # type: () -> int return self._len - self._buf.tell() def read(self, size): + # type: (int) -> binary_type return self._buf.read(size) def read_byte(self): + # type: () -> int return struct.unpack('B', self.read(1))[0] def read_bool(self): + # type: () -> bool return self.read_byte() != 0 def read_int(self): + # type: () -> int return struct.unpack('>I', self.read(4))[0] def read_list(self): + # type: () -> List[text_type] list_size = self.read_int() return self.read(list_size).decode().split(',') def read_string(self): + # type: () -> binary_type n = self.read_int() return self.read(n) @classmethod def _parse_mpint(cls, v, pad, sf): + # type: (binary_type, binary_type, str) -> int r = 0 if len(v) % 4: v = pad * (4 - (len(v) % 4)) + v @@ -505,12 +574,14 @@ class ReadBuf(object): return r def read_mpint1(self): + # type: () -> int # NOTE: Data Type Enc @ http://www.snailbook.com/docs/protocol-1.5.txt bits = struct.unpack('>H', self.read(2))[0] n = (bits + 7) // 8 return self._parse_mpint(self.read(n), b'\x00', '>I') def read_mpint2(self): + # type: () -> int # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt v = self.read_string() if len(v) == 0: @@ -519,38 +590,47 @@ class ReadBuf(object): return self._parse_mpint(v, pad, sf) def read_line(self): + # type: () -> text_type return self._buf.readline().rstrip().decode('utf-8') class WriteBuf(object): def __init__(self, data=None): + # type: (Optional[binary_type]) -> None super(WriteBuf, self).__init__() - self._wbuf = io.BytesIO(data) if data else io.BytesIO() + self._wbuf = BytesIO(data) if data else BytesIO() def write(self, data): + # type: (binary_type) -> WriteBuf self._wbuf.write(data) return self def write_byte(self, v): + # type: (int) -> WriteBuf return self.write(struct.pack('B', v)) def write_bool(self, v): + # type: (bool) -> WriteBuf return self.write_byte(1 if v else 0) def write_int(self, v): + # type: (int) -> WriteBuf return self.write(struct.pack('>I', v)) def write_string(self, v): + # type: (Union[binary_type, text_type]) -> WriteBuf if not isinstance(v, bytes): v = bytes(bytearray(v, 'utf-8')) self.write_int(len(v)) return self.write(v) def write_list(self, v): + # type: (List[text_type]) -> WriteBuf return self.write_string(u','.join(v)) @classmethod def _bitlength(cls, n): + # type: (int) -> int try: return n.bit_length() except AttributeError: @@ -558,11 +638,12 @@ class WriteBuf(object): @classmethod def _create_mpint(cls, n, signed=True, bits=None): + # type: (int, bool, Optional[int]) -> binary_type if bits is None: bits = cls._bitlength(n) length = bits // 8 + (1 if n != 0 else 0) ql = (length + 7) // 8 - fmt, v2 = '>{0}Q'.format(ql), [b'\x00'] * ql + fmt, v2 = '>{0}Q'.format(ql), [0] * ql for i in range(ql): v2[ql - i - 1] = (n & 0xffffffffffffffff) n >>= 64 @@ -574,6 +655,7 @@ class WriteBuf(object): return data def write_mpint1(self, n): + # type: (int) -> WriteBuf # NOTE: Data Type Enc @ http://www.snailbook.com/docs/protocol-1.5.txt bits = self._bitlength(n) data = self._create_mpint(n, False, bits) @@ -581,17 +663,20 @@ class WriteBuf(object): return self.write(data) def write_mpint2(self, n): + # type: (int) -> WriteBuf # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt data = self._create_mpint(n) return self.write_string(data) def write_line(self, v): + # type: (Union[binary_type, str]) -> WriteBuf if not isinstance(v, bytes): v = bytes(bytearray(v, 'utf-8')) v += b'\r\n' return self.write(v) def write_flush(self): + # type: () -> binary_type payload = self._wbuf.getvalue() self._wbuf.truncate(0) self._wbuf.seek(0) @@ -613,6 +698,7 @@ class SSH(object): class Software(object): def __init__(self, vendor, product, version, patch, os): + # type: (Optional[str], str, str, Optional[str], Optional[str]) -> None self.__vendor = vendor self.__product = product self.__version = version @@ -621,28 +707,34 @@ class SSH(object): @property def vendor(self): + # type: () -> Optional[str] return self.__vendor @property def product(self): + # type: () -> str return self.__product @property def version(self): + # type: () -> str return self.__version @property def patch(self): + # type: () -> Optional[str] return self.__patch @property def os(self): + # type: () -> Optional[str] return self.__os def compare_version(self, other): + # type: (Union[None, SSH.Software, text_type]) -> int if other is None: return 1 - if isinstance(other, self.__class__): + if isinstance(other, SSH.Software): other = '{0}{1}'.format(other.version, other.patch or '') else: other = str(other) @@ -676,6 +768,7 @@ class SSH(object): return 0 def between_versions(self, vfrom, vtill): + # type: (str, str) -> bool if vfrom and self.compare_version(vfrom) < 0: return False if vtill and self.compare_version(vtill) > 0: @@ -683,6 +776,7 @@ class SSH(object): return True def display(self, full=True): + # type: (bool) -> str out = '{0} '.format(self.vendor) if self.vendor else '' out += self.product if self.version: @@ -701,9 +795,11 @@ class SSH(object): return out def __str__(self): + # type: () -> str return self.display() def __repr__(self): + # type: () -> str out = 'vendor={0}'.format(self.vendor) if self.vendor else '' if self.product: if self.vendor: @@ -719,10 +815,12 @@ class SSH(object): @staticmethod def _fix_patch(patch): + # type: (str) -> Optional[str] return re.sub(r'^[-_\.]+', '', patch) or None @staticmethod def _fix_date(d): + # type: (str) -> Optional[str] if d is not None and len(d) == 8: return '{0}-{1}-{2}'.format(d[:4], d[4:6], d[6:8]) else: @@ -730,6 +828,7 @@ class SSH(object): @classmethod def _extract_os(cls, c): + # type: (Optional[str]) -> str if c is None: return None mx = re.match(r'^NetBSD(?:_Secure_Shell)?(?:[\s-]+(\d{8})(.*))?$', c) @@ -756,6 +855,7 @@ class SSH(object): @classmethod def parse(cls, banner): + # type: (SSH.Banner) -> SSH.Software software = str(banner.software) mx = re.match(r'^dropbear_([\d\.]+\d+)(.*)', software) if mx: @@ -796,24 +896,35 @@ class SSH(object): RX_PROTOCOL = re.compile(re.sub(r'\\d(\+?)', '(\\d\g<1>)', _RXP)) RX_BANNER = re.compile(r'^({0}(?:(?:-{0})*)){1}$'.format(_RXP, _RXR)) - def __init__(self, protocol, software, comments): + def __init__(self, protocol, software, comments, valid_ascii): + # type: (Tuple[int, int], str, str, bool) -> None self.__protocol = protocol self.__software = software self.__comments = comments + self.__valid_ascii = valid_ascii @property def protocol(self): + # type: () -> Tuple[int, int] return self.__protocol @property def software(self): + # type: () -> str return self.__software @property def comments(self): + # type: () -> str return self.__comments + @property + def valid_ascii(self): + # type: () -> bool + return self.__valid_ascii + def __str__(self): + # type: () -> str out = 'SSH-{0}.{1}'.format(self.protocol[0], self.protocol[1]) if self.software is not None: out += '-{0}'.format(self.software) @@ -822,6 +933,7 @@ class SSH(object): return out def __repr__(self): + # type: () -> str p = '{0}.{1}'.format(self.protocol[0], self.protocol[1]) out = 'protocol={0}'.format(p) if self.software: @@ -832,7 +944,10 @@ class SSH(object): @classmethod def parse(cls, banner): - mx = cls.RX_BANNER.match(banner) + # type: (text_type) -> SSH.Banner + valid_ascii = utils.is_ascii(banner) + ascii_banner = utils.to_ascii(banner) + mx = cls.RX_BANNER.match(ascii_banner) if mx is None: return None protocol = min(re.findall(cls.RX_PROTOCOL, mx.group(1))) @@ -843,23 +958,26 @@ class SSH(object): comments = (mx.group(4) or '').strip() or None if comments is not None: comments = re.sub('\s+', ' ', comments) - return cls(protocol, software, comments) + return cls(protocol, software, comments, valid_ascii) class Fingerprint(object): def __init__(self, fpd): + # type: (binary_type) -> None self.__fpd = fpd @property def md5(self): + # type: () -> text_type h = hashlib.md5(self.__fpd).hexdigest() - h = u':'.join(h[i:i + 2] for i in range(0, len(h), 2)) - return u'MD5:{0}'.format(h) + r = u':'.join(h[i:i + 2] for i in range(0, len(h), 2)) + return u'MD5:{0}'.format(r) @property def sha256(self): + # type: () -> text_type h = base64.b64encode(hashlib.sha256(self.__fpd).digest()) - h = h.decode().rstrip('=') - return u'SHA256:{0}'.format(h) + r = h.decode('ascii').rstrip('=') + return u'SHA256:{0}'.format(r) class Security(object): CVE = { @@ -884,7 +1002,7 @@ class SSH(object): ['0.4.7', '0.5.2', 1, 'CVE-2012-4561', 5.0, 'cause DoS via unspecified vectors (invalid pointer)'], ['0.4.7', '0.5.2', 1, 'CVE-2012-4560', 7.5, 'cause DoS or execute arbitrary code (buffer overflow)'], ['0.4.7', '0.5.2', 1, 'CVE-2012-4559', 6.8, 'cause DoS or execute arbitrary code (double free)']] - } + } # type: Dict[str, List[List[Any]]] TXT = { 'Dropbear SSH': [ ['0.28', '0.34', 1, 'remote root exploit', 'remote format string buffer overflow exploit (exploit-db#387)']], @@ -892,7 +1010,7 @@ class SSH(object): ['0.3.3', '0.3.3', 1, 'null pointer check', 'missing null pointer check in "crypt_set_algorithms_server"'], ['0.3.3', '0.3.3', 1, 'integer overflow', 'integer overflow in "buffer_get_data"'], ['0.3.3', '0.3.3', 3, 'heap overflow', 'heap overflow in "packet_decrypt"']] - } + } # type: Dict[str, List[List[Any]]] class Socket(ReadBuf, WriteBuf): class InsufficientReadException(Exception): @@ -901,10 +1019,11 @@ class SSH(object): SM_BANNER_SENT = 1 def __init__(self, host, port, cto=3.0, rto=5.0): + # type: (str, int, float, float) -> None self.__block_size = 8 self.__state = 0 - self.__header = [] - self.__banner = None + self.__header = [] # type: List[text_type] + self.__banner = None # type: Optional[SSH.Banner] super(SSH.Socket, self).__init__() try: self.__sock = socket.create_connection((host, port), cto) @@ -914,9 +1033,11 @@ class SSH(object): sys.exit(1) def __enter__(self): + # type: () -> SSH.Socket return self def get_banner(self, sshv=2): + # type: (int) -> Tuple[Optional[SSH.Banner], List[text_type]] banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0') rto = self.__sock.gettimeout() self.__sock.settimeout(0.7) @@ -944,7 +1065,7 @@ class SSH(object): return self.__banner, self.__header def recv(self, size=2048): - # type: (int) -> Tuple[int, str] + # type: (int) -> Tuple[int, Optional[str]] try: data = self.__sock.recv(size) except socket.timeout: @@ -963,26 +1084,29 @@ class SSH(object): return (len(data), None) def send(self, data): + # type: (binary_type) -> Tuple[int, Optional[str]] try: self.__sock.send(data) return (0, None) except socket.error as e: - return (-1, e) + return (-1, str(e.args[-1])) self.__sock.send(data) def send_banner(self, banner): + # type: (str) -> None self.send(banner.encode() + b'\r\n') if self.__state < self.SM_BANNER_SENT: self.__state = self.SM_BANNER_SENT def ensure_read(self, size): + # type: (int) -> None while self.unread_len < size: s, e = self.recv() if s < 0: raise SSH.Socket.InsufficientReadException(e) def read_packet(self, sshv=2): - # type: (int) -> Tuple[int, bytes] + # type: (int) -> Tuple[int, binary_type] try: header = WriteBuf() self.ensure_read(4) @@ -1034,6 +1158,7 @@ class SSH(object): return (-1, e) def send_packet(self): + # type: () -> Tuple[int, Optional[str]] payload = self.write_flush() padding = -(len(payload) + 5) % 8 if padding < 4: @@ -1044,12 +1169,15 @@ class SSH(object): return self.send(data) def __del__(self): + # type: () -> None self.__cleanup() - def __exit__(self, ex_type, ex_value, tb): + def __exit__(self, *args): + # type: (*Any) -> None self.__cleanup() def __cleanup(self): + # type: () -> None try: self.__sock.shutdown(socket.SHUT_RDWR) self.__sock.close() @@ -1059,13 +1187,15 @@ class SSH(object): class KexDH(object): def __init__(self, alg, g, p): + # type: (str, int, int) -> None self.__alg = alg self.__g = g self.__p = p self.__q = (self.__p - 1) // 2 - self.__x = None + self.__x = None # type: Optional[int] def send_init(self, s): + # type: (SSH.Socket) -> None r = random.SystemRandom() self.__x = r.randrange(2, self.__q) self.__e = pow(self.__g, self.__x, self.__p) @@ -1076,6 +1206,7 @@ class KexDH(object): class KexGroup1(KexDH): def __init__(self): + # type: () -> None # rfc2409: second oakley group p = int('ffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67' 'cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6d' @@ -1087,6 +1218,7 @@ class KexGroup1(KexDH): class KexGroup14(KexDH): def __init__(self): + # type: () -> None # rfc3526: 2048-bit modp group p = int('ffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67' 'cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6d' @@ -1208,10 +1340,11 @@ class KexDB(object): 'umac-64-etm@openssh.com': [['6.2'], [], [WARN_TAG_SIZE]], 'umac-128-etm@openssh.com': [['6.2']], } - } + } # type: Dict[str, Dict[str, List[List[str]]]] def get_ssh_version(version_desc): + # type: (str) -> Tuple[str, str] if version_desc.startswith('d'): return (SSH.Product.DropbearSSH, version_desc[1:]) elif version_desc.startswith('l1'): @@ -1220,8 +1353,8 @@ def get_ssh_version(version_desc): return (SSH.Product.OpenSSH, version_desc) -def get_alg_timeframe(alg_desc, for_server=True, result={}): - versions = alg_desc[0] +def get_alg_timeframe(versions, for_server=True, result={}): + # type: (List[str], bool, Dict[str, List[Optional[str]]]) -> Dict[str, List[Optional[str]]] vlen = len(versions) for i in range(3): if i > vlen - 1: @@ -1256,23 +1389,25 @@ def get_alg_timeframe(alg_desc, for_server=True, result={}): def get_ssh_timeframe(alg_pairs, for_server=True): - timeframe = {} + # type: (List[Tuple[int, Dict[str, Dict[str, List[List[str]]]], List[Tuple[str, List[text_type]]]]], bool) -> Dict[str, List[Optional[str]]] + timeframe = {} # type: Dict[str, List[Optional[str]]] for alg_pair in alg_pairs: - sshv, alg_db = alg_pair[0] - alg_sets = alg_pair[1:] - for alg_set in alg_sets: + sshv, alg_db = alg_pair[0], alg_pair[1] + for alg_set in alg_pair[2]: alg_type, alg_list = alg_set for alg_name in alg_list: - alg_desc = alg_db[alg_type].get(alg_name) + alg_name_native = utils.to_ntext(alg_name) + alg_desc = alg_db[alg_type].get(alg_name_native) if alg_desc is None: continue - timeframe = get_alg_timeframe(alg_desc, for_server, timeframe) + versions = alg_desc[0] + timeframe = get_alg_timeframe(versions, for_server, timeframe) return timeframe -def get_alg_since_text(alg_desc): +def get_alg_since_text(versions): + # type: (List[str]) -> text_type tv = [] - versions = alg_desc[0] if len(versions) == 0 or versions[0] is None: return None for v in versions[0].split(','): @@ -1290,22 +1425,24 @@ def get_alg_since_text(alg_desc): def get_alg_pairs(kex, pkm): + # type: (Optional[SSH2.Kex], Optional[SSH1.PublicKeyMessage]) -> List[Tuple[int, Dict[str, Dict[str, List[List[str]]]], List[Tuple[str, List[text_type]]]]] alg_pairs = [] if pkm is not None: - alg_pairs.append(((1, SSH1.KexDB.ALGORITHMS), - ('key', ['ssh-rsa1']), - ('enc', pkm.supported_ciphers), - ('aut', pkm.supported_authentications))) + alg_pairs.append((1, SSH1.KexDB.ALGORITHMS, + [('key', [u'ssh-rsa1']), + ('enc', pkm.supported_ciphers), + ('aut', pkm.supported_authentications)])) if kex is not None: - alg_pairs.append(((2, KexDB.ALGORITHMS), - ('kex', kex.kex_algorithms), - ('key', kex.key_algorithms), - ('enc', kex.server.encryption), - ('mac', kex.server.mac))) + alg_pairs.append((2, KexDB.ALGORITHMS, + [('kex', kex.kex_algorithms), + ('key', kex.key_algorithms), + ('enc', kex.server.encryption), + ('mac', kex.server.mac)])) return alg_pairs def get_alg_recommendations(software, kex, pkm, for_server=True): + # type: (SSH.Software, SSH2.Kex, SSH1.PublicKeyMessage, bool) -> Tuple[SSH.Software, Dict[int, Dict[str, Dict[str, Dict[str, int]]]]] alg_pairs = get_alg_pairs(kex, pkm) vproducts = [SSH.Product.OpenSSH, SSH.Product.DropbearSSH, @@ -1322,18 +1459,17 @@ def get_alg_recommendations(software, kex, pkm, for_server=True): if version is not None: software = SSH.Software(None, product, version, None, None) break - rec = {'.software': software} + rec = {} # type: Dict[int, Dict[str, Dict[str, Dict[str, int]]]] if software is None: - return rec + return software, rec for alg_pair in alg_pairs: - sshv, alg_db = alg_pair[0] - alg_sets = alg_pair[1:] + sshv, alg_db = alg_pair[0], alg_pair[1] rec[sshv] = {} - for alg_set in alg_sets: + for alg_set in alg_pair[2]: alg_type, alg_list = alg_set if alg_type == 'aut': continue - rec[sshv][alg_type] = {'add': [], 'del': {}} + rec[sshv][alg_type] = {'add': {}, 'del': {}} for n, alg_desc in alg_db[alg_type].items(): if alg_type == 'key' and '-cert-' in n: continue @@ -1367,7 +1503,7 @@ def get_alg_recommendations(software, kex, pkm, for_server=True): if n not in alg_list: if faults > 0: continue - rec[sshv][alg_type]['add'].append(n) + rec[sshv][alg_type]['add'][n] = 0 else: if faults == 0: continue @@ -1379,10 +1515,11 @@ def get_alg_recommendations(software, kex, pkm, for_server=True): del_count = len(rec[sshv][alg_type]['del']) new_alg_count = len(alg_list) + add_count - del_count if new_alg_count < 1 and del_count > 0: - mf, new_del = min(rec[sshv][alg_type]['del'].values()), {} - for k, v in rec[sshv][alg_type]['del'].items(): - if v != mf: - new_del[k] = v + mf = min(rec[sshv][alg_type]['del'].values()) + new_del = {} + for k, cf in rec[sshv][alg_type]['del'].items(): + if cf != mf: + new_del[k] = cf if del_count != len(new_del): rec[sshv][alg_type]['del'] = new_del new_alg_count += del_count - len(new_del) @@ -1397,10 +1534,11 @@ def get_alg_recommendations(software, kex, pkm, for_server=True): del rec[sshv][alg_type] if len(rec[sshv]) == 0: del rec[sshv] - return rec + return software, rec def output_algorithms(title, alg_db, alg_type, algorithms, maxlen=0): + # type: (str, Dict[str, Dict[str, List[List[str]]]], str, List[text_type], int) -> None with OutputBuffer() as obuf: for algorithm in algorithms: output_algorithm(alg_db, alg_type, algorithm, maxlen) @@ -1411,6 +1549,7 @@ def output_algorithms(title, alg_db, alg_type, algorithms, maxlen=0): def output_algorithm(alg_db, alg_type, alg_name, alg_max_len=0): + # type: (Dict[str, Dict[str, List[List[str]]]], str, text_type, int) -> None prefix = '(' + alg_type + ') ' if alg_max_len == 0: alg_max_len = len(alg_name) @@ -1418,12 +1557,14 @@ def output_algorithm(alg_db, alg_type, alg_name, alg_max_len=0): texts = [] if len(alg_name.strip()) == 0: return - if alg_name in alg_db[alg_type]: - alg_desc = alg_db[alg_type][alg_name] + alg_name_native = utils.to_ntext(alg_name) + if alg_name_native in alg_db[alg_type]: + alg_desc = alg_db[alg_type][alg_name_native] ldesc = len(alg_desc) for idx, level in enumerate(['fail', 'warn', 'info']): if level == 'info': - since_text = get_alg_since_text(alg_desc) + versions = alg_desc[0] + since_text = get_alg_since_text(versions) if since_text: texts.append((level, since_text)) idx = idx + 1 @@ -1451,6 +1592,7 @@ def output_algorithm(alg_db, alg_type, alg_name, alg_max_len=0): def output_compatibility(kex, pkm, for_server=True): + # type: (Optional[SSH2.Kex], Optional[SSH1.PublicKeyMessage], bool) -> None alg_pairs = get_alg_pairs(kex, pkm) ssh_timeframe = get_ssh_timeframe(alg_pairs, for_server) vp = 1 if for_server else 2 @@ -1474,21 +1616,22 @@ def output_compatibility(kex, pkm, for_server=True): def output_security_sub(sub, software, padlen): + # type: (str, SSH.Software, int) -> None secdb = SSH.Security.CVE if sub == 'cve' else SSH.Security.TXT if software is None or software.product not in secdb: return for line in secdb[software.product]: - vfrom, vtill = line[0:2] + vfrom, vtill = line[0:2] # type: str, str if not software.between_versions(vfrom, vtill): continue - target, name = line[2:4] + target, name = line[2:3] # type: int, str is_server, is_client = target & 1 == 1, target & 2 == 2 is_local = target & 4 == 4 if not is_server: continue p = '' if out.batch else ' ' * (padlen - len(name)) if sub == 'cve': - cvss, descr = line[4:6] + cvss, descr = line[4:6] # type: float, str out.fail('(cve) {0}{1} -- ({2}) {3}'.format(name, p, cvss, descr)) else: descr = line[4] @@ -1496,6 +1639,7 @@ def output_security_sub(sub, software, padlen): def output_security(banner, padlen): + # type: (SSH.Banner, int) -> None with OutputBuffer() as obuf: if banner: software = SSH.Software.parse(banner) @@ -1508,6 +1652,7 @@ def output_security(banner, padlen): def output_fingerprint(kex, pkm, sha256=True, padlen=0): + # type: (Optional[SSH2.Kex], Optional[SSH1.PublicKeyMessage], bool, int) -> None with OutputBuffer() as obuf: fps = [] if pkm is not None: @@ -1517,9 +1662,9 @@ def output_fingerprint(kex, pkm, sha256=True, padlen=0): fps.append((name, fp, bits)) for fpp in fps: name, fp, bits = fpp - fp = fp.sha256 if sha256 else fp.md5 + fpo = fp.sha256 if sha256 else fp.md5 p = '' if out.batch else ' ' * (padlen - len(name)) - out.good('(fin) {0}{1} -- {2} {3}'.format(name, p, bits, fp)) + out.good('(fin) {0}{1} -- {2} {3}'.format(name, p, bits, fpo)) if len(obuf) > 0: out.head('# fingerprints') obuf.flush() @@ -1527,10 +1672,10 @@ def output_fingerprint(kex, pkm, sha256=True, padlen=0): def output_recommendations(software, kex, pkm, padlen=0): + # type: (SSH.Software, SSH2.Kex, SSH1.PublicKeyMessage, int) -> None for_server = True with OutputBuffer() as obuf: - alg_rec = get_alg_recommendations(software, kex, pkm, for_server) - software = alg_rec['.software'] + software, alg_rec = get_alg_recommendations(software, kex, pkm, for_server) for sshv in range(2, 0, -1): if sshv not in alg_rec: continue @@ -1559,12 +1704,16 @@ def output_recommendations(software, kex, pkm, padlen=0): def output(banner, header, kex=None, pkm=None): + # type: (Optional[SSH.Banner], List[text_type], Optional[SSH2.Kex], Optional[SSH1.PublicKeyMessage]) -> None sshv = 1 if pkm else 2 with OutputBuffer() as obuf: if len(header) > 0: out.info('(gen) header: ' + '\n'.join(header)) if banner is not None: out.good('(gen) banner: {0}'.format(banner)) + if not banner.valid_ascii: + # NOTE: RFC 4253, Section 4.2 + out.warn('(gen) banner contains non-printable ASCII') if sshv == 1 or banner.protocol[0] == 1: out.fail('(gen) protocol SSH1 enabled') software = SSH.Software.parse(banner) @@ -1622,20 +1771,61 @@ def output(banner, header, kex=None, pkm=None): class Utils(object): - PY2 = sys.version_info[0] == 2 + @classmethod + def _type_err(cls, v, target): + # type: (Any, text_type) -> TypeError + return TypeError('cannot convert {0} to {1}'.format(type(v), target)) @classmethod - def wrap(cls): - o = cls() - if cls.PY2: - import StringIO - o.StringIO = o.BytesIO = StringIO.StringIO - else: - o.StringIO, o.BytesIO = io.StringIO, io.BytesIO - return o + def to_bytes(cls, v, enc='utf-8'): + # type: (Union[binary_type, text_type], str) -> binary_type + if isinstance(v, binary_type): + return v + elif isinstance(v, text_type): + return v.encode(enc) + raise cls._type_err(v, 'bytes') + + @classmethod + def to_utext(cls, v, enc='utf-8'): + # type: (Union[text_type, binary_type], str) -> text_type + if isinstance(v, text_type): + return v + elif isinstance(v, binary_type): + return v.decode(enc) + raise cls._type_err(v, 'unicode text') + + @classmethod + def to_ntext(cls, v, enc='utf-8'): + # type: (Union[text_type, binary_type], str) -> str + if isinstance(v, str): + return v + elif isinstance(v, text_type): + return v.encode(enc) + elif isinstance(v, binary_type): + return v.decode(enc) + raise cls._type_err(v, 'native text') + + @classmethod + def is_ascii(cls, v, enc='utf-8'): + # type: (Union[text_type, str], str) -> bool + try: + if isinstance(v, (text_type, str)): + v.encode('ascii') + return True + except UnicodeEncodeError: + pass + return False + + @classmethod + def to_ascii(cls, v, errors='replace'): + # type: (Union[text_type, str], str) -> str + if isinstance(v, (text_type, str)): + return cls.to_ntext(v.encode('ascii', errors)) + raise cls._type_err(v, 'ascii') @staticmethod def parse_int(v): + # type: (Any) -> int try: return int(v) except: @@ -1643,6 +1833,7 @@ class Utils(object): def audit(conf, sshv=None): + # type: (AuditConf, Optional[int]) -> None out.batch = conf.batch out.colors = conf.colors out.verbose = conf.verbose @@ -1658,23 +1849,24 @@ def audit(conf, sshv=None): packet_type, payload = s.read_packet(sshv) if packet_type < 0: try: - payload = payload.decode('utf-8') if payload else u'empty' + payload_txt = payload.decode('utf-8') if payload else u'empty' except UnicodeDecodeError: - payload = u'"{0}"'.format(repr(payload).lstrip('b')[1:-1]) - if payload == u'Protocol major versions differ.': + payload_txt = u'"{0}"'.format(repr(payload).lstrip('b')[1:-1]) + if payload_txt == u'Protocol major versions differ.': if sshv == 2 and conf.ssh1: audit(conf, 1) return - err = '[exception] error reading packet ({0})'.format(payload) + err = '[exception] error reading packet ({0})'.format(payload_txt) else: + err_pair = None if sshv == 1 and packet_type != SSH.Protocol.SMSG_PUBLIC_KEY: - err = ('SMSG_PUBLIC_KEY', SSH.Protocol.SMSG_PUBLIC_KEY) + err_pair = ('SMSG_PUBLIC_KEY', SSH.Protocol.SMSG_PUBLIC_KEY) elif sshv == 2 and packet_type != SSH.Protocol.MSG_KEXINIT: - err = ('MSG_KEXINIT', SSH.Protocol.MSG_KEXINIT) - if err is not None: + err_pair = ('MSG_KEXINIT', SSH.Protocol.MSG_KEXINIT) + if err_pair is not None: fmt = '[exception] did not receive {0} ({1}), ' + \ 'instead received unknown message ({2})' - err = fmt.format(err[0], err[1], packet_type) + err = fmt.format(err_pair[0], err_pair[1], packet_type) if err: output(banner, header) out.fail(err) @@ -1687,7 +1879,7 @@ def audit(conf, sshv=None): output(banner, header, kex=kex) -utils = Utils.wrap() +utils = Utils() out = Output() if __name__ == '__main__': conf = AuditConf.from_cmdline(sys.argv[1:], usage)