diff --git a/ssh-audit.py b/ssh-audit.py index a7c8e40..fa7978b 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -51,8 +51,7 @@ except ImportError: # pragma: nocover pass -def usage(err=None): - # type: (Optional[str]) -> None +def usage(err: Optional[str] = None) -> None: uout = Output() p = os.path.basename(sys.argv[0]) uout.head('# {} {}, https://github.com/jtesta/ssh-audit\n'.format(p, VERSION)) @@ -78,8 +77,7 @@ def usage(err=None): class AuditConf: # pylint: disable=too-many-instance-attributes - def __init__(self, host=None, port=22): - # type: (Optional[str], int) -> None + def __init__(self, host: Optional[str] = None, port: int = 22) -> None: self.host = host self.port = port self.ssh1 = True @@ -96,8 +94,7 @@ class AuditConf: self.timeout = 5.0 self.timeout_set = False # Set to True when the user explicitly sets it. - def __setattr__(self, name, value): - # type: (str, Union[str, int, float, bool, Sequence[int]]) -> None + def __setattr__(self, name: str, value: Union[str, int, float, bool, Sequence[int]]) -> None: valid = False if name in ['ssh1', 'ssh2', 'batch', 'client_audit', 'colors', 'verbose', 'timeout_set', 'json']: valid, value = True, bool(value) @@ -141,8 +138,7 @@ class AuditConf: object.__setattr__(self, name, value) @classmethod - def from_cmdline(cls, args, usage_cb): # pylint: disable=too-many-statements - # type: (List[str], Callable[..., None]) -> AuditConf + def from_cmdline(cls, args: List[str], usage_cb: Callable[..., None]) -> 'AuditConf': # pylint: disable=too-many-statements # pylint: disable=too-many-branches aconf = cls() try: @@ -220,8 +216,7 @@ class Output: LEVELS = ('info', 'warn', 'fail') # type: Sequence[str] COLORS = {'head': 36, 'good': 32, 'warn': 33, 'fail': 31} - def __init__(self): - # type: () -> None + def __init__(self) -> None: self.batch = False self.verbose = False self.use_colors = True @@ -230,41 +225,34 @@ class Output: self.__colsupport = 'colorama' in sys.modules or os.name == 'posix' @property - def level(self): - # type: () -> str + def level(self) -> str: if self.__level < len(self.LEVELS): return self.LEVELS[self.__level] return 'unknown' @level.setter - def level(self, name): - # type: (str) -> None + def level(self, name: str) -> None: self.__level = self.get_level(name) - def get_level(self, name): - # type: (str) -> int + def get_level(self, name: str) -> int: cname = 'info' if name == 'good' else name if cname not in self.LEVELS: return sys.maxsize return self.LEVELS.index(cname) - def sep(self): - # type: () -> None + def sep(self) -> None: if not self.batch: print() @property - def colors_supported(self): - # type: () -> bool + def colors_supported(self) -> bool: return self.__colsupport @staticmethod - def _colorized(color): - # type: (str) -> Callable[[str], None] + def _colorized(color: str) -> Callable[[str], None]: return lambda x: print(u'{}{}\033[0m'.format(color, x)) - def __getattr__(self, name): - # type: (str) -> Callable[[str], None] + def __getattr__(self, name: str) -> Callable[[str], None]: if name == 'head' and self.batch: return lambda x: None if not self.get_level(name) >= self.__level: @@ -277,24 +265,21 @@ class Output: class OutputBuffer(list): - def __enter__(self): - # type: () -> OutputBuffer + def __enter__(self) -> 'OutputBuffer': # pylint: disable=attribute-defined-outside-init self.__buf = io.StringIO() self.__stdout = sys.stdout sys.stdout = self.__buf return self - def flush(self, sort_lines=False): - # type: (bool) -> None + def flush(self, sort_lines: bool = False) -> None: # Lines must be sorted in some cases to ensure consistent testing. if sort_lines: self.sort() for line in self: print(line) - def __exit__(self, *args): - # type: (*Any) -> None + def __exit__(self, *args: Any) -> None: self.extend(self.__buf.getvalue().splitlines()) sys.stdout = self.__stdout @@ -520,36 +505,30 @@ class SSH2: # pylint: disable=too-few-public-methods } # type: Dict[str, Dict[str, List[List[Optional[str]]]]] class KexParty: - def __init__(self, enc, mac, compression, languages): - # type: (List[str], List[str], List[str], List[str]) -> None + def __init__(self, enc: List[str], mac: List[str], compression: List[str], languages: List[str]) -> None: self.__enc = enc self.__mac = mac self.__compression = compression self.__languages = languages @property - def encryption(self): - # type: () -> List[str] + def encryption(self) -> List[str]: return self.__enc @property - def mac(self): - # type: () -> List[str] + def mac(self) -> List[str]: return self.__mac @property - def compression(self): - # type: () -> List[str] + def compression(self) -> List[str]: return self.__compression @property - def languages(self): - # type: () -> List[str] + def languages(self) -> List[str]: return self.__languages class Kex: - def __init__(self, cookie, kex_algs, key_algs, cli, srv, follows, unused=0): - # type: (bytes, List[str], List[str], SSH2.KexParty, SSH2.KexParty, bool, int) -> None + def __init__(self, cookie: bytes, kex_algs: List[str], key_algs: List[str], cli: 'SSH2.KexParty', srv: 'SSH2.KexParty', follows: bool, unused: int = 0) -> None: self.__cookie = cookie self.__kex_algs = kex_algs self.__key_algs = key_algs @@ -563,40 +542,33 @@ class SSH2: # pylint: disable=too-few-public-methods self.__host_keys = {} # type: Dict[str, bytes] @property - def cookie(self): - # type: () -> bytes + def cookie(self) -> bytes: return self.__cookie @property - def kex_algorithms(self): - # type: () -> List[str] + def kex_algorithms(self) -> List[str]: return self.__kex_algs @property - def key_algorithms(self): - # type: () -> List[str] + def key_algorithms(self) -> List[str]: return self.__key_algs # client_to_server @property - def client(self): - # type: () -> SSH2.KexParty + def client(self) -> 'SSH2.KexParty': return self.__client # server_to_client @property - def server(self): - # type: () -> SSH2.KexParty + def server(self) -> 'SSH2.KexParty': return self.__server @property - def follows(self): - # type: () -> bool + def follows(self) -> bool: return self.__follows @property - def unused(self): - # type: () -> int + def unused(self) -> int: return self.__unused def set_rsa_key_size(self, rsa_type, hostkey_size, ca_size=-1): @@ -617,8 +589,7 @@ class SSH2: # pylint: disable=too-few-public-methods def host_keys(self): return self.__host_keys - def write(self, wbuf): - # type: (WriteBuf) -> None + def write(self, wbuf: 'WriteBuf') -> None: wbuf.write(self.cookie) wbuf.write_list(self.kex_algorithms) wbuf.write_list(self.key_algorithms) @@ -634,15 +605,13 @@ class SSH2: # pylint: disable=too-few-public-methods wbuf.write_int(self.__unused) @property - def payload(self): - # type: () -> bytes + def payload(self) -> bytes: wbuf = WriteBuf() self.write(wbuf) return wbuf.write_flush() @classmethod - def parse(cls, payload): - # type: (bytes) -> SSH2.Kex + def parse(cls, payload: bytes) -> 'SSH2.Kex': buf = ReadBuf(payload) cookie = buf.read(16) kex_algs = buf.read_list() @@ -915,8 +884,7 @@ class SSH2: # pylint: disable=too-few-public-methods class SSH1: class CRC32: - def __init__(self): - # type: () -> None + def __init__(self) -> None: self._table = [0] * 256 for i in range(256): crc = 0 @@ -927,8 +895,7 @@ class SSH1: n = n >> 1 self._table[i] = crc - def calc(self, v): - # type: (bytes) -> int + def calc(self, v: bytes) -> int: crc, length = 0, len(v) for i in range(length): n = ord(v[i:i + 1]) @@ -941,8 +908,7 @@ class SSH1: AUTHS = ['none', 'rhosts', 'rsa', 'password', 'rhosts_rsa', 'tis', 'kerberos'] @classmethod - def crc32(cls, v): - # type: (bytes) -> int + def crc32(cls, v: bytes) -> int: if cls._crc32 is None: cls._crc32 = cls.CRC32() return cls._crc32.calc(v) @@ -979,8 +945,7 @@ class SSH1: } # type: Dict[str, Dict[str, List[List[Optional[str]]]]] class PublicKeyMessage: - def __init__(self, cookie, skey, hkey, pflags, cmask, amask): - # type: (bytes, Tuple[int, int, int], Tuple[int, int, int], int, int, int) -> None + def __init__(self, cookie: bytes, skey: Tuple[int, int, int], hkey: Tuple[int, int, int], pflags: int, cmask: int, amask: int) -> None: if len(skey) != 3: raise ValueError('invalid server key pair: {}'.format(skey)) if len(hkey) != 3: @@ -993,61 +958,50 @@ class SSH1: self.__supported_authentications_mask = amask @property - def cookie(self): - # type: () -> bytes + def cookie(self) -> bytes: return self.__cookie @property - def server_key_bits(self): - # type: () -> int + def server_key_bits(self) -> int: return self.__server_key[0] @property - def server_key_public_exponent(self): - # type: () -> int + def server_key_public_exponent(self) -> int: return self.__server_key[1] @property - def server_key_public_modulus(self): - # type: () -> int + def server_key_public_modulus(self) -> int: return self.__server_key[2] @property - def host_key_bits(self): - # type: () -> int + def host_key_bits(self) -> int: return self.__host_key[0] @property - def host_key_public_exponent(self): - # type: () -> int + def host_key_public_exponent(self) -> int: return self.__host_key[1] @property - def host_key_public_modulus(self): - # type: () -> int + def host_key_public_modulus(self) -> int: return self.__host_key[2] @property - def host_key_fingerprint_data(self): - # type: () -> bytes + def host_key_fingerprint_data(self) -> bytes: # pylint: disable=protected-access mod = WriteBuf._create_mpint(self.host_key_public_modulus, False) e = WriteBuf._create_mpint(self.host_key_public_exponent, False) return mod + e @property - def protocol_flags(self): - # type: () -> int + def protocol_flags(self) -> int: return self.__protocol_flags @property - def supported_ciphers_mask(self): - # type: () -> int + def supported_ciphers_mask(self) -> int: return self.__supported_ciphers_mask @property - def supported_ciphers(self): - # type: () -> List[str] + def supported_ciphers(self) -> List[str]: ciphers = [] for i in range(len(SSH1.CIPHERS)): if self.__supported_ciphers_mask & (1 << i) != 0: @@ -1055,21 +1009,18 @@ class SSH1: return ciphers @property - def supported_authentications_mask(self): - # type: () -> int + def supported_authentications_mask(self) -> int: return self.__supported_authentications_mask @property - def supported_authentications(self): - # type: () -> List[str] + def supported_authentications(self) -> List[str]: auths = [] for i in range(1, len(SSH1.AUTHS)): if self.__supported_authentications_mask & (1 << i) != 0: auths.append(utils.to_text(SSH1.AUTHS[i])) return auths - def write(self, wbuf): - # type: (WriteBuf) -> None + def write(self, wbuf: 'WriteBuf') -> None: wbuf.write(self.cookie) wbuf.write_int(self.server_key_bits) wbuf.write_mpint1(self.server_key_public_exponent) @@ -1082,15 +1033,13 @@ class SSH1: wbuf.write_int(self.supported_authentications_mask) @property - def payload(self): - # type: () -> bytes + def payload(self) -> bytes: wbuf = WriteBuf() self.write(wbuf) return wbuf.write_flush() @classmethod - def parse(cls, payload): - # type: (bytes) -> SSH1.PublicKeyMessage + def parse(cls, payload: bytes) -> 'SSH1.PublicKeyMessage': buf = ReadBuf(payload) cookie = buf.read(8) server_key_bits = buf.read_int() @@ -1109,48 +1058,39 @@ class SSH1: class ReadBuf: - def __init__(self, data=None): - # type: (Optional[bytes]) -> None + def __init__(self, data: Optional[bytes] = None) -> None: super(ReadBuf, self).__init__() self._buf = io.BytesIO(data) if data is not None else io.BytesIO() self._len = len(data) if data is not None else 0 @property - def unread_len(self): - # type: () -> int + def unread_len(self) -> int: return self._len - self._buf.tell() - def read(self, size): - # type: (int) -> bytes + def read(self, size: int) -> bytes: return self._buf.read(size) - def read_byte(self): - # type: () -> int + def read_byte(self) -> int: v = struct.unpack('B', self.read(1))[0] # type: int return v - def read_bool(self): - # type: () -> bool + def read_bool(self) -> bool: return self.read_byte() != 0 - def read_int(self): - # type: () -> int + def read_int(self) -> int: v = struct.unpack('>I', self.read(4))[0] # type: int return v - def read_list(self): - # type: () -> List[str] + def read_list(self) -> List[str]: list_size = self.read_int() return self.read(list_size).decode('utf-8', 'replace').split(',') - def read_string(self): - # type: () -> bytes + def read_string(self) -> bytes: n = self.read_int() return self.read(n) @classmethod - def _parse_mpint(cls, v, pad, f): - # type: (bytes, bytes, str) -> int + def _parse_mpint(cls, v: bytes, pad: bytes, f: str) -> int: r = 0 if len(v) % 4 != 0: v = pad * (4 - (len(v) % 4)) + v @@ -1158,15 +1098,13 @@ class ReadBuf: r = (r << 32) | struct.unpack(f, v[i:i + 4])[0] return r - def read_mpint1(self): - # type: () -> int + def read_mpint1(self) -> int: # NOTE: Data Type Enc @ http://www.snailbook.com/docs/protocol-1.5.txt bits = struct.unpack('>H', self.read(2))[0] n = (bits + 7) // 8 return self._parse_mpint(self.read(n), b'\x00', '>I') - def read_mpint2(self): - # type: () -> int + def read_mpint2(self) -> int: # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt v = self.read_string() if len(v) == 0: @@ -1174,8 +1112,7 @@ class ReadBuf: pad, f = (b'\xff', '>i') if ord(v[0:1]) & 0x80 != 0 else (b'\x00', '>I') return self._parse_mpint(v, pad, f) - def read_line(self): - # type: () -> str + def read_line(self) -> str: return self._buf.readline().rstrip().decode('utf-8', 'replace') def reset(self): @@ -1184,50 +1121,41 @@ class ReadBuf: class WriteBuf: - def __init__(self, data=None): - # type: (Optional[bytes]) -> None + def __init__(self, data: Optional[bytes] = None) -> None: super(WriteBuf, self).__init__() self._wbuf = io.BytesIO(data) if data is not None else io.BytesIO() - def write(self, data): - # type: (bytes) -> WriteBuf + def write(self, data: bytes) -> 'WriteBuf': self._wbuf.write(data) return self - def write_byte(self, v): - # type: (int) -> WriteBuf + def write_byte(self, v: int) -> 'WriteBuf': return self.write(struct.pack('B', v)) - def write_bool(self, v): - # type: (bool) -> WriteBuf + def write_bool(self, v: bool) -> 'WriteBuf': return self.write_byte(1 if v else 0) - def write_int(self, v): - # type: (int) -> WriteBuf + def write_int(self, v: int) -> 'WriteBuf': return self.write(struct.pack('>I', v)) - def write_string(self, v): - # type: (Union[bytes, str]) -> WriteBuf + def write_string(self, v: Union[bytes, str]) -> 'WriteBuf': if not isinstance(v, bytes): v = bytes(bytearray(v, 'utf-8')) self.write_int(len(v)) return self.write(v) - def write_list(self, v): - # type: (List[str]) -> WriteBuf + def write_list(self, v: List[str]) -> 'WriteBuf': return self.write_string(u','.join(v)) @classmethod - def _bitlength(cls, n): - # type: (int) -> int + def _bitlength(cls, n: int) -> int: try: return n.bit_length() except AttributeError: return len(bin(n)) - (2 if n > 0 else 3) @classmethod - def _create_mpint(cls, n, signed=True, bits=None): - # type: (int, bool, Optional[int]) -> bytes + def _create_mpint(cls, n: int, signed: bool = True, bits: Optional[int] = None) -> bytes: if bits is None: bits = cls._bitlength(n) length = bits // 8 + (1 if n != 0 else 0) @@ -1243,29 +1171,25 @@ class WriteBuf: data = data[1:] return data - def write_mpint1(self, n): - # type: (int) -> WriteBuf + def write_mpint1(self, n: int) -> 'WriteBuf': # NOTE: Data Type Enc @ http://www.snailbook.com/docs/protocol-1.5.txt bits = self._bitlength(n) data = self._create_mpint(n, False, bits) self.write(struct.pack('>H', bits)) return self.write(data) - def write_mpint2(self, n): - # type: (int) -> WriteBuf + def write_mpint2(self, n: int) -> 'WriteBuf': # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt data = self._create_mpint(n) return self.write_string(data) - def write_line(self, v): - # type: (Union[bytes, str]) -> WriteBuf + def write_line(self, v: Union[bytes, str]) -> 'WriteBuf': if not isinstance(v, bytes): v = bytes(bytearray(v, 'utf-8')) v += b'\r\n' return self.write(v) - def write_flush(self): - # type: () -> bytes + def write_flush(self) -> bytes: payload = self._wbuf.getvalue() self._wbuf.truncate(0) self._wbuf.seek(0) @@ -1297,8 +1221,7 @@ class SSH: # pylint: disable=too-few-public-methods PuTTY = 'PuTTY' class Software: - def __init__(self, vendor, product, version, patch, os_version): - # type: (Optional[str], str, str, Optional[str], Optional[str]) -> None + def __init__(self, vendor: Optional[str], product: str, version: str, patch: Optional[str], os_version: Optional[str]) -> None: self.__vendor = vendor self.__product = product self.__version = version @@ -1306,32 +1229,26 @@ class SSH: # pylint: disable=too-few-public-methods self.__os = os_version @property - def vendor(self): - # type: () -> Optional[str] + def vendor(self) -> Optional[str]: return self.__vendor @property - def product(self): - # type: () -> str + def product(self) -> str: return self.__product @property - def version(self): - # type: () -> str + def version(self) -> str: return self.__version @property - def patch(self): - # type: () -> Optional[str] + def patch(self) -> Optional[str]: return self.__patch @property - def os(self): - # type: () -> Optional[str] + def os(self) -> Optional[str]: return self.__os - def compare_version(self, other): - # type: (Union[None, SSH.Software, str]) -> int + def compare_version(self, other: Union[None, 'SSH.Software', str]) -> int: # pylint: disable=too-many-branches if other is None: return 1 @@ -1368,16 +1285,14 @@ class SSH: # pylint: disable=too-few-public-methods return 1 return 0 - def between_versions(self, vfrom, vtill): - # type: (str, str) -> bool + def between_versions(self, vfrom: str, vtill: str) -> bool: if bool(vfrom) and self.compare_version(vfrom) < 0: return False if bool(vtill) and self.compare_version(vtill) > 0: return False return True - def display(self, full=True): - # type: (bool) -> str + def display(self, full: bool = True) -> str: r = '{} '.format(self.vendor) if bool(self.vendor) else '' r += self.product if bool(self.version): @@ -1395,12 +1310,10 @@ class SSH: # pylint: disable=too-few-public-methods r += ' running on {}'.format(self.os) return r - def __str__(self): - # type: () -> str + def __str__(self) -> str: return self.display() - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: r = 'vendor={}, '.format(self.vendor) if bool(self.vendor) else '' r += 'product={}'.format(self.product) if bool(self.version): @@ -1412,21 +1325,18 @@ class SSH: # pylint: disable=too-few-public-methods return '<{}({})>'.format(self.__class__.__name__, r) @staticmethod - def _fix_patch(patch): - # type: (str) -> Optional[str] + def _fix_patch(patch: str) -> Optional[str]: return re.sub(r'^[-_\.]+', '', patch) or None @staticmethod - def _fix_date(d): - # type: (str) -> Optional[str] + def _fix_date(d: str) -> Optional[str]: if d is not None and len(d) == 8: return '{}-{}-{}'.format(d[:4], d[4:6], d[6:8]) else: return None @classmethod - def _extract_os_version(cls, c): - # type: (Optional[str]) -> Optional[str] + def _extract_os_version(cls, c: Optional[str]) -> Optional[str]: if c is None: return None mx = re.match(r'^NetBSD(?:_Secure_Shell)?(?:[\s-]+(\d{8})(.*))?$', c) @@ -1452,8 +1362,7 @@ class SSH: # pylint: disable=too-few-public-methods return None @classmethod - def parse(cls, banner): - # type: (SSH.Banner) -> Optional[SSH.Software] + def parse(cls, banner: 'SSH.Banner') -> Optional['SSH.Software']: # pylint: disable=too-many-return-statements software = str(banner.software) mx = re.match(r'^dropbear_([\d\.]+\d+)(.*)', software) @@ -1508,35 +1417,29 @@ class SSH: # pylint: disable=too-few-public-methods RX_PROTOCOL = re.compile(re.sub(r'\\d(\+?)', r'(\\d\g<1>)', _RXP)) RX_BANNER = re.compile(r'^({0}(?:(?:-{0})*)){1}$'.format(_RXP, _RXR)) - def __init__(self, protocol, software, comments, valid_ascii): - # type: (Tuple[int, int], Optional[str], Optional[str], bool) -> None + def __init__(self, protocol: Tuple[int, int], software: Optional[str], comments: Optional[str], valid_ascii: bool) -> None: self.__protocol = protocol self.__software = software self.__comments = comments self.__valid_ascii = valid_ascii @property - def protocol(self): - # type: () -> Tuple[int, int] + def protocol(self) -> Tuple[int, int]: return self.__protocol @property - def software(self): - # type: () -> Optional[str] + def software(self) -> Optional[str]: return self.__software @property - def comments(self): - # type: () -> Optional[str] + def comments(self) -> Optional[str]: return self.__comments @property - def valid_ascii(self): - # type: () -> bool + def valid_ascii(self) -> bool: return self.__valid_ascii - def __str__(self): - # type: () -> str + def __str__(self) -> str: r = 'SSH-{}.{}'.format(self.protocol[0], self.protocol[1]) if self.software is not None: r += '-{}'.format(self.software) @@ -1544,8 +1447,7 @@ class SSH: # pylint: disable=too-few-public-methods r += ' {}'.format(self.comments) return r - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: p = '{}.{}'.format(self.protocol[0], self.protocol[1]) r = 'protocol={}'.format(p) if self.software is not None: @@ -1555,8 +1457,7 @@ class SSH: # pylint: disable=too-few-public-methods return '<{}({})>'.format(self.__class__.__name__, r) @classmethod - def parse(cls, banner): - # type: (str) -> Optional[SSH.Banner] + def parse(cls, banner: str) -> Optional['SSH.Banner']: valid_ascii = utils.is_print_ascii(banner) ascii_banner = utils.to_print_ascii(banner) mx = cls.RX_BANNER.match(ascii_banner) @@ -1573,56 +1474,45 @@ class SSH: # pylint: disable=too-few-public-methods return cls(protocol, software, comments, valid_ascii) class Fingerprint: - def __init__(self, fpd): - # type: (bytes) -> None + def __init__(self, fpd: bytes) -> None: self.__fpd = fpd @property - def md5(self): - # type: () -> str + def md5(self) -> str: h = hashlib.md5(self.__fpd).hexdigest() r = u':'.join(h[i:i + 2] for i in range(0, len(h), 2)) return u'MD5:{}'.format(r) @property - def sha256(self): - # type: () -> str + def sha256(self) -> str: h = base64.b64encode(hashlib.sha256(self.__fpd).digest()) r = h.decode('ascii').rstrip('=') return u'SHA256:{}'.format(r) class Algorithm: class Timeframe: - def __init__(self): - # type: () -> None + def __init__(self) -> None: self.__storage = {} # type: Dict[str, List[Optional[str]]] - def __contains__(self, product): - # type: (str) -> bool + def __contains__(self, product: str) -> bool: return product in self.__storage - def __getitem__(self, product): - # type: (str) -> Sequence[Optional[str]] + def __getitem__(self, product): # type: (str) -> Sequence[Optional[str]] return tuple(self.__storage.get(product, [None] * 4)) - def __str__(self): - # type: () -> str + def __str__(self) -> str: return self.__storage.__str__() - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return self.__str__() - def get_from(self, product, for_server=True): - # type: (str, bool) -> Optional[str] + def get_from(self, product: str, for_server: bool = True) -> Optional[str]: return self[product][0 if bool(for_server) else 2] - def get_till(self, product, for_server=True): - # type: (str, bool) -> Optional[str] + def get_till(self, product: str, for_server: bool = True) -> Optional[str]: return self[product][1 if bool(for_server) else 3] - def _update(self, versions, pos): - # type: (Optional[str], int) -> None + def _update(self, versions: Optional[str], pos: int) -> None: ssh_versions = {} # type: Dict[str, str] for_srv, for_cli = pos < 2, pos > 1 for v in (versions or '').split(','): @@ -1637,8 +1527,7 @@ class SSH: # pylint: disable=too-few-public-methods if (prev is None or (prev < ssh_version and pos % 2 == 0) or (prev > ssh_version and pos % 2 == 1)): self.__storage[ssh_product][pos] = ssh_version - def update(self, versions, for_server=None): - # type: (List[Optional[str]], Optional[bool]) -> SSH.Algorithm.Timeframe + def update(self, versions: List[Optional[str]], for_server: Optional[bool] = None) -> 'SSH.Algorithm.Timeframe': for_cli = for_server is None or for_server is False for_srv = for_server is None or for_server is True vlen = len(versions) @@ -1650,8 +1539,7 @@ class SSH: # pylint: disable=too-few-public-methods return self @staticmethod - def get_ssh_version(version_desc): - # type: (str) -> Tuple[str, str, bool] + def get_ssh_version(version_desc: str) -> Tuple[str, str, bool]: is_client = version_desc.endswith('C') if is_client: version_desc = version_desc[:-1] @@ -1663,8 +1551,7 @@ class SSH: # pylint: disable=too-few-public-methods return SSH.Product.OpenSSH, version_desc, is_client @classmethod - def get_since_text(cls, versions): - # type: (List[Optional[str]]) -> Optional[str] + def get_since_text(cls, versions: List[Optional[str]]) -> Optional[str]: tv = [] if len(versions) == 0 or versions[0] is None: return None @@ -1682,24 +1569,20 @@ class SSH: # pylint: disable=too-few-public-methods return 'available since ' + ', '.join(tv).rstrip(', ') class Algorithms: - def __init__(self, pkm, kex): - # type: (Optional[SSH1.PublicKeyMessage], Optional[SSH2.Kex]) -> None + def __init__(self, pkm: Optional[SSH1.PublicKeyMessage], kex: Optional[SSH2.Kex]) -> None: self.__ssh1kex = pkm self.__ssh2kex = kex @property - def ssh1kex(self): - # type: () -> Optional[SSH1.PublicKeyMessage] + def ssh1kex(self) -> Optional[SSH1.PublicKeyMessage]: return self.__ssh1kex @property - def ssh2kex(self): - # type: () -> Optional[SSH2.Kex] + def ssh2kex(self) -> Optional[SSH2.Kex]: return self.__ssh2kex @property - def ssh1(self): - # type: () -> Optional[SSH.Algorithms.Item] + def ssh1(self) -> Optional['SSH.Algorithms.Item']: if self.ssh1kex is None: return None item = SSH.Algorithms.Item(1, SSH1.KexDB.ALGORITHMS) @@ -1709,8 +1592,7 @@ class SSH: # pylint: disable=too-few-public-methods return item @property - def ssh2(self): - # type: () -> Optional[SSH.Algorithms.Item] + def ssh2(self) -> Optional['SSH.Algorithms.Item']: if self.ssh2kex is None: return None item = SSH.Algorithms.Item(2, SSH2.KexDB.ALGORITHMS) @@ -1721,17 +1603,14 @@ class SSH: # pylint: disable=too-few-public-methods return item @property - def values(self): - # type: () -> Iterable[SSH.Algorithms.Item] + def values(self) -> Iterable['SSH.Algorithms.Item']: for item in [self.ssh1, self.ssh2]: if item is not None: yield item @property - def maxlen(self): - # type: () -> int - def _ml(items): - # type: (Sequence[str]) -> int + def maxlen(self) -> int: + def _ml(items: Sequence[str]) -> int: return max(len(i) for i in items) maxlen = 0 if self.ssh1kex is not None: @@ -1746,8 +1625,7 @@ class SSH: # pylint: disable=too-few-public-methods maxlen) return maxlen - def get_ssh_timeframe(self, for_server=None): - # type: (Optional[bool]) -> SSH.Algorithm.Timeframe + def get_ssh_timeframe(self, for_server: Optional[bool] = None) -> 'SSH.Algorithm.Timeframe': timeframe = SSH.Algorithm.Timeframe() for alg_pair in self.values: alg_db = alg_pair.db @@ -1761,8 +1639,7 @@ class SSH: # pylint: disable=too-few-public-methods timeframe.update(versions, for_server) return timeframe - def get_recommendations(self, software, for_server=True): - # type: (Optional[SSH.Software], bool) -> Tuple[Optional[SSH.Software], Dict[int, Dict[str, Dict[str, Dict[str, int]]]]] + def get_recommendations(self, software: Optional['SSH.Software'], for_server: bool = True) -> Tuple[Optional['SSH.Software'], Dict[int, Dict[str, Dict[str, Dict[str, int]]]]]: # pylint: disable=too-many-locals,too-many-statements vproducts = [SSH.Product.OpenSSH, SSH.Product.DropbearSSH, @@ -1869,28 +1746,23 @@ class SSH: # pylint: disable=too-few-public-methods return software, rec class Item: - def __init__(self, sshv, db): - # type: (int, Dict[str, Dict[str, List[List[Optional[str]]]]]) -> None + def __init__(self, sshv: int, db: Dict[str, Dict[str, List[List[Optional[str]]]]]) -> None: self.__sshv = sshv self.__db = db self.__storage = {} # type: Dict[str, List[str]] @property - def sshv(self): - # type: () -> int + def sshv(self) -> int: return self.__sshv @property - def db(self): - # type: () -> Dict[str, Dict[str, List[List[Optional[str]]]]] + def db(self) -> Dict[str, Dict[str, List[List[Optional[str]]]]]: return self.__db - def add(self, key, value): - # type: (str, List[str]) -> None + def add(self, key: str, value: List[str]) -> None: self.__storage[key] = value - def items(self): - # type: () -> Iterable[Tuple[str, List[str]]] + def items(self) -> Iterable[Tuple[str, List[str]]]: return self.__storage.items() class Security: # pylint: disable=too-few-public-methods @@ -2021,8 +1893,7 @@ class SSH: # pylint: disable=too-few-public-methods SM_BANNER_SENT = 1 - def __init__(self, host, port, ipvo=None, timeout=5, timeout_set=False): - # type: (Optional[str], int, Optional[Sequence[int]], Union[int,float], bool) -> None + def __init__(self, host: Optional[str], port: int, ipvo: Optional[Sequence[int]] = None, timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: super(SSH.Socket, self).__init__() self.__sock = None # type: Optional[socket.socket] self.__sock_map = {} # type: Dict[int, socket.socket] @@ -2046,8 +1917,7 @@ class SSH: # pylint: disable=too-few-public-methods self.client_host = None self.client_port = None - def _resolve(self, ipvo): - # type: (Sequence[int]) -> Iterable[Tuple[int, Tuple[Any, ...]]] + def _resolve(self, ipvo: Sequence[int]) -> Iterable[Tuple[int, Tuple[Any, ...]]]: ipvo = tuple([x for x in utils.unique_seq(ipvo) if x in (4, 6)]) ipvo_len = len(ipvo) prefer_ipvo = ipvo_len > 0 @@ -2124,8 +1994,7 @@ class SSH: # pylint: disable=too-few-public-methods c.settimeout(self.__timeout) self.__sock = c - def connect(self): - # type: () -> None + def connect(self) -> None: err = None for af, addr in self._resolve(self.__ipvo): s = None @@ -2146,8 +2015,7 @@ class SSH: # pylint: disable=too-few-public-methods out.fail('[exception] {}'.format(errm)) sys.exit(1) - def get_banner(self, sshv=2): - # type: (int) -> Tuple[Optional[SSH.Banner], List[str], Optional[str]] + def get_banner(self, sshv: int = 2) -> Tuple[Optional['SSH.Banner'], List[str], Optional[str]]: if self.__sock is None: return self.__banner, self.__header, 'not connected' banner = SSH_HEADER.format('1.5' if sshv == 1 else '2.0') @@ -2177,8 +2045,7 @@ class SSH: # pylint: disable=too-few-public-methods s = 0 return self.__banner, self.__header, e - def recv(self, size=2048): - # type: (int) -> Tuple[int, Optional[str]] + def recv(self, size: int = 2048) -> Tuple[int, Optional[str]]: if self.__sock is None: return -1, 'not connected' try: @@ -2198,8 +2065,7 @@ class SSH: # pylint: disable=too-few-public-methods self._buf.seek(pos, 0) return len(data), None - def send(self, data): - # type: (bytes) -> Tuple[int, Optional[str]] + def send(self, data: bytes) -> Tuple[int, Optional[str]]: if self.__sock is None: return -1, 'not connected' try: @@ -2209,21 +2075,18 @@ class SSH: # pylint: disable=too-few-public-methods return -1, str(e.args[-1]) self.__sock.send(data) - def send_banner(self, banner): - # type: (str) -> None + def send_banner(self, banner: str) -> None: self.send(banner.encode() + b'\r\n') if self.__state < self.SM_BANNER_SENT: self.__state = self.SM_BANNER_SENT - def ensure_read(self, size): - # type: (int) -> None + def ensure_read(self, size: int) -> None: while self.unread_len < size: s, e = self.recv() if s < 0: raise SSH.Socket.InsufficientReadException(e) - def read_packet(self, sshv=2): - # type: (int) -> Tuple[int, bytes] + def read_packet(self, sshv: int = 2) -> Tuple[int, bytes]: try: header = WriteBuf() self.ensure_read(4) @@ -2274,8 +2137,7 @@ class SSH: # pylint: disable=too-few-public-methods e = ex.args[0].encode('utf-8') return -1, e - def send_packet(self): - # type: () -> Tuple[int, Optional[str]] + def send_packet(self) -> Tuple[int, Optional[str]]: payload = self.write_flush() padding = -(len(payload) + 5) % 8 if padding < 4: @@ -2296,8 +2158,7 @@ class SSH: # pylint: disable=too-few-public-methods self.__header = [] self.__banner = None - def _close_socket(self, s): # pylint: disable=no-self-use - # type: (Optional[socket.socket]) -> None + def _close_socket(self, s: Optional[socket.socket]) -> None: # pylint: disable=no-self-use try: if s is not None: s.shutdown(socket.SHUT_RDWR) @@ -2305,12 +2166,10 @@ class SSH: # pylint: disable=too-few-public-methods except Exception: pass - def __del__(self): - # type: () -> None + def __del__(self) -> None: self.__cleanup() - def __cleanup(self): - # type: () -> None + def __cleanup(self) -> None: self._close_socket(self.__sock) for fd in self.__sock_map: self._close_socket(self.__sock_map[fd]) @@ -2318,8 +2177,7 @@ class SSH: # pylint: disable=too-few-public-methods class KexDH: # pragma: nocover - def __init__(self, kex_name, hash_alg, g, p): - # type: (str, str, int, int) -> None + def __init__(self, kex_name: str, hash_alg: str, g: int, p: int) -> None: self.__kex_name = kex_name self.__hash_alg = hash_alg self.__g = 0 @@ -2345,8 +2203,7 @@ class KexDH: # pragma: nocover self.__x = 0 self.__e = 0 - def send_init(self, s, init_msg=SSH.Protocol.MSG_KEXDH_INIT): - # type: (SSH.Socket, int) -> None + def send_init(self, s: SSH.Socket, init_msg: int = SSH.Protocol.MSG_KEXDH_INIT) -> None: r = random.SystemRandom() self.__x = r.randrange(2, self.__q) self.__e = pow(self.__g, self.__x, self.__p) @@ -2503,16 +2360,14 @@ class KexDH: # pragma: nocover class KexGroup1(KexDH): # pragma: nocover - def __init__(self): - # type: () -> None + def __init__(self) -> None: # rfc2409: second oakley group p = int('ffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece65381ffffffffffffffff', 16) super(KexGroup1, self).__init__('KexGroup1', 'sha1', 2, p) class KexGroup14(KexDH): # pragma: nocover - def __init__(self, hash_alg): - # type: (str) -> None + def __init__(self, hash_alg: str) -> None: # rfc3526: 2048-bit modp group p = int('ffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aacaa68ffffffffffffffff', 16) super(KexGroup14, self).__init__('KexGroup14', hash_alg, 2, p) @@ -2653,8 +2508,7 @@ class KexGroupExchange_SHA256(KexGroupExchange): super(KexGroupExchange_SHA256, self).__init__('KexGroupExchange_SHA256', 'sha256') -def output_algorithms(title, alg_db, alg_type, algorithms, unknown_algs, maxlen=0, alg_sizes=None): - # type: (str, Dict[str, Dict[str, List[List[Optional[str]]]]], str, List[str], List[str], int, Optional[Dict[str, Iterable[int]]]) -> None +def output_algorithms(title: str, alg_db: Dict[str, Dict[str, List[List[Optional[str]]]]], alg_type: str, algorithms: List[str], unknown_algs: List[str], maxlen: int = 0, alg_sizes: Optional[Dict[str, Iterable[int]]] = None) -> None: with OutputBuffer() as obuf: for algorithm in algorithms: output_algorithm(alg_db, alg_type, algorithm, unknown_algs, maxlen, alg_sizes) @@ -2664,8 +2518,7 @@ def output_algorithms(title, alg_db, alg_type, algorithms, unknown_algs, maxlen= out.sep() -def output_algorithm(alg_db, alg_type, alg_name, unknown_algs, alg_max_len=0, alg_sizes=None): - # type: (Dict[str, Dict[str, List[List[Optional[str]]]]], str, str, List[str], int, Optional[Dict[str, Iterable[int]]]) -> None +def output_algorithm(alg_db: Dict[str, Dict[str, List[List[Optional[str]]]]], alg_type: str, alg_name: str, unknown_algs: List[str], alg_max_len: int = 0, alg_sizes: Optional[Dict[str, Iterable[int]]] = None) -> None: prefix = '(' + alg_type + ') ' if alg_max_len == 0: alg_max_len = len(alg_name) @@ -2726,8 +2579,7 @@ def output_algorithm(alg_db, alg_type, alg_name, unknown_algs, alg_max_len=0, al f(' ' * len(prefix + alg_name) + comment) -def output_compatibility(algs, client_audit, for_server=True): - # type: (SSH.Algorithms, bool, bool) -> None +def output_compatibility(algs: SSH.Algorithms, client_audit: bool, for_server: bool = True) -> None: # Don't output any compatibility info if we're doing a client audit. if client_audit: @@ -2757,8 +2609,7 @@ def output_compatibility(algs, client_audit, for_server=True): out.good('(gen) compatibility: ' + ', '.join(comp_text)) -def output_security_sub(sub, software, client_audit, padlen): - # type: (str, Optional[SSH.Software], bool, int) -> None +def output_security_sub(sub: str, software: Optional[SSH.Software], client_audit: bool, padlen: int) -> None: secdb = SSH.Security.CVE if sub == 'cve' else SSH.Security.TXT if software is None or software.product not in secdb: return @@ -2794,8 +2645,7 @@ def output_security_sub(sub, software, client_audit, padlen): out.fail('(sec) {}{} -- {}'.format(name, p, descr)) -def output_security(banner, client_audit, padlen): - # type: (Optional[SSH.Banner], bool, int) -> None +def output_security(banner: Optional[SSH.Banner], client_audit: bool, padlen: int) -> None: with OutputBuffer() as obuf: if banner is not None: software = SSH.Software.parse(banner) @@ -2807,8 +2657,7 @@ def output_security(banner, client_audit, padlen): out.sep() -def output_fingerprints(algs, sha256=True): - # type: (SSH.Algorithms, bool) -> None +def output_fingerprints(algs: SSH.Algorithms, sha256: bool = True) -> None: with OutputBuffer() as obuf: fps = [] if algs.ssh1kex is not None: @@ -2846,8 +2695,7 @@ def output_fingerprints(algs, sha256=True): # Returns True if no warnings or failures encountered in configuration. -def output_recommendations(algs, software, padlen=0): - # type: (SSH.Algorithms, Optional[SSH.Software], int) -> bool +def output_recommendations(algs: SSH.Algorithms, software: Optional[SSH.Software], padlen: int = 0) -> bool: ret = True # PuTTY's algorithms cannot be modified, so there's no point in issuing recommendations. @@ -2935,8 +2783,7 @@ def output_info(software, client_audit, any_problems): out.sep() -def output(banner, header, client_host=None, kex=None, pkm=None): - # type: (Optional[SSH.Banner], List[str], Optional[str], Optional[SSH2.Kex], Optional[SSH1.PublicKeyMessage]) -> None +def output(banner: Optional[SSH.Banner], header: List[str], client_host: Optional[str] = None, kex: Optional[SSH2.Kex] = None, pkm: Optional[SSH1.PublicKeyMessage] = None) -> None: client_audit = client_host is not None # If set, this is a client audit. sshv = 1 if pkm is not None else 2 algs = SSH.Algorithms(pkm, kex) @@ -3004,13 +2851,11 @@ def output(banner, header, client_host=None, kex=None, pkm=None): class Utils: @classmethod - def _type_err(cls, v, target): - # type: (Any, str) -> TypeError + def _type_err(cls, v: Any, target: str) -> TypeError: return TypeError('cannot convert {} to {}'.format(type(v), target)) @classmethod - def to_bytes(cls, v, enc='utf-8'): - # type: (Union[bytes, str], str) -> bytes + def to_bytes(cls, v: Union[bytes, str], enc: str = 'utf-8') -> bytes: if isinstance(v, bytes): return v elif isinstance(v, str): @@ -3018,8 +2863,7 @@ class Utils: raise cls._type_err(v, 'bytes') @classmethod - def to_text(cls, v, enc='utf-8'): - # type: (Union[str, bytes], str) -> str + def to_text(cls, v: Union[str, bytes], enc: str = 'utf-8') -> str: if isinstance(v, str): return v elif isinstance(v, bytes): @@ -3027,8 +2871,7 @@ class Utils: raise cls._type_err(v, 'unicode text') @classmethod - def _is_ascii(cls, v, char_filter=lambda x: x <= 127): - # type: (str, Callable[[int], bool]) -> bool + def _is_ascii(cls, v: str, char_filter: Callable[[int], bool] = lambda x: x <= 127) -> bool: r = False if isinstance(v, str): for c in v: @@ -3039,8 +2882,7 @@ class Utils: return r @classmethod - def _to_ascii(cls, v, char_filter=lambda x: x <= 127, errors='replace'): - # type: (str, Callable[[int], bool], str) -> str + def _to_ascii(cls, v: str, char_filter: Callable[[int], bool] = lambda x: x <= 127, errors: str = 'replace') -> str: if isinstance(v, str): r = bytearray() for c in v: @@ -3055,32 +2897,26 @@ class Utils: raise cls._type_err(v, 'ascii') @classmethod - def is_ascii(cls, v): - # type: (str) -> bool + def is_ascii(cls, v: str) -> bool: return cls._is_ascii(v) @classmethod - def to_ascii(cls, v, errors='replace'): - # type: (str, str) -> str + def to_ascii(cls, v: str, errors: str = 'replace') -> str: return cls._to_ascii(v, errors=errors) @classmethod - def is_print_ascii(cls, v): - # type: (str) -> bool + def is_print_ascii(cls, v: str) -> bool: return cls._is_ascii(v, lambda x: 126 >= x >= 32) @classmethod - def to_print_ascii(cls, v, errors='replace'): - # type: (str, str) -> str + def to_print_ascii(cls, v: str, errors: str = 'replace') -> str: return cls._to_ascii(v, lambda x: 126 >= x >= 32, errors) @classmethod - def unique_seq(cls, seq): - # type: (Sequence[Any]) -> Sequence[Any] + def unique_seq(cls, seq: Sequence[Any]) -> Sequence[Any]: seen = set() # type: Set[Any] - def _seen_add(x): - # type: (Any) -> bool + def _seen_add(x: Any) -> bool: seen.add(x) return False @@ -3090,24 +2926,21 @@ class Utils: return [x for x in seq if x not in seen and not _seen_add(x)] @classmethod - def ctoi(cls, c): - # type: (Union[str, int]) -> int + def ctoi(cls, c: Union[str, int]) -> int: if isinstance(c, str): return ord(c[0]) else: return c @staticmethod - def parse_int(v): - # type: (Any) -> int + def parse_int(v: Any) -> int: try: return int(v) except Exception: # pylint: disable=bare-except return 0 @staticmethod - def parse_float(v): - # type: (Any) -> float + def parse_float(v: Any) -> float: try: return float(v) except Exception: # pylint: disable=bare-except @@ -3192,8 +3025,7 @@ def build_struct(banner, kex=None, pkm=None, client_host=None): return res -def audit(aconf, sshv=None): - # type: (AuditConf, Optional[int]) -> None +def audit(aconf: AuditConf, sshv: Optional[int] = None) -> None: out.batch = aconf.batch out.verbose = aconf.verbose out.level = aconf.level