From b300ad1252ca4c595a3a24264e08b028e4e69ea5 Mon Sep 17 00:00:00 2001 From: Joe Testa Date: Tue, 23 Feb 2021 16:05:01 -0500 Subject: [PATCH] Refactored IPv4/6 preference logic to fix pylint warnings. --- src/ssh_audit/auditconf.py | 30 +++++++------------------- src/ssh_audit/ssh_audit.py | 2 +- src/ssh_audit/ssh_socket.py | 31 +++++++++++--------------- test/test_auditconf.py | 43 ++++++++++--------------------------- test/test_resolve.py | 27 ++++++++++++----------- 5 files changed, 47 insertions(+), 86 deletions(-) diff --git a/src/ssh_audit/auditconf.py b/src/ssh_audit/auditconf.py index 25a2c06..71b71ac 100644 --- a/src/ssh_audit/auditconf.py +++ b/src/ssh_audit/auditconf.py @@ -43,7 +43,7 @@ class AuditConf: self.json = False self.verbose = False self.level = 'info' - self.ipvo: Sequence[int] = () + self.ip_version_preference: List[int] = [] # Holds only 5 possible values: [] (no preference), [4] (use IPv4 only), [6] (use IPv6 only), [46] (use both IPv4 and IPv6, but prioritize v4), and [64] (use both IPv4 and IPv6, but prioritize v6). self.ipv4 = False self.ipv6 = False self.make_policy = False # When True, creates a policy file from an audit scan. @@ -60,28 +60,14 @@ class AuditConf: def __setattr__(self, name: str, value: Union[str, int, float, bool, Sequence[int]]) -> None: valid = False - if name in ['ssh1', 'ssh2', 'batch', 'client_audit', 'colors', 'verbose', 'timeout_set', 'json', 'make_policy', 'list_policies', 'manual']: + if name in ['batch', 'client_audit', 'colors', 'json', 'list_policies', 'manual', 'make_policy', 'ssh1', 'ssh2', 'timeout_set', 'verbose']: valid, value = True, bool(value) elif name in ['ipv4', 'ipv6']: - valid = False - value = bool(value) - ipv = 4 if name == 'ipv4' else 6 - if value: - value = tuple(list(self.ipvo) + [ipv]) - else: # pylint: disable=else-if-used - if len(self.ipvo) == 0: - value = (6,) if ipv == 4 else (4,) - else: - value = tuple([x for x in self.ipvo if x != ipv]) - self.__setattr__('ipvo', value) - elif name == 'ipvo': - if isinstance(value, (tuple, list)): - uniq_value = Utils.unique_seq(value) - value = tuple([x for x in uniq_value if x in (4, 6)]) - valid = True - ipv_both = len(value) == 0 - object.__setattr__(self, 'ipv4', ipv_both or 4 in value) - object.__setattr__(self, 'ipv6', ipv_both or 6 in value) + valid, value = True, bool(value) + if len(self.ip_version_preference) == 2: # Being called more than twice is not valid. + valid = False + elif value: + self.ip_version_preference.append(4 if name == 'ipv4' else 6) elif name == 'port': valid, port = True, Utils.parse_int(value) if port < 1 or port > 65535: @@ -98,7 +84,7 @@ class AuditConf: if value == -1.0: raise ValueError('invalid timeout: {}'.format(value)) valid = True - elif name in ['policy_file', 'policy', 'target_file', 'target_list', 'lookup']: + elif name in ['ip_version_preference', 'lookup', 'policy_file', 'policy', 'target_file', 'target_list']: valid = True elif name == "threads": valid, num_threads = True, Utils.parse_int(value) diff --git a/src/ssh_audit/ssh_audit.py b/src/ssh_audit/ssh_audit.py index b4d5a9e..eb08707 100755 --- a/src/ssh_audit/ssh_audit.py +++ b/src/ssh_audit/ssh_audit.py @@ -815,7 +815,7 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print out.verbose = aconf.verbose out.level = aconf.level out.use_colors = aconf.colors - s = SSH_Socket(aconf.host, aconf.port, aconf.ipvo, aconf.timeout, aconf.timeout_set) + s = SSH_Socket(aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set) if aconf.client_audit: out.v("Listening for client connection on port %d..." % aconf.port, write_now=True) s.listen_and_accept() diff --git a/src/ssh_audit/ssh_socket.py b/src/ssh_audit/ssh_socket.py index 0960537..28dfc50 100644 --- a/src/ssh_audit/ssh_socket.py +++ b/src/ssh_audit/ssh_socket.py @@ -1,7 +1,7 @@ """ The MIT License (MIT) - Copyright (C) 2017-2020 Joe Testa (jtesta@positronsecurity.com) + Copyright (C) 2017-2021 Joe Testa (jtesta@positronsecurity.com) Copyright (C) 2017 Andris Raugulis (moo@arthepsy.eu) Permission is hereby granted, free of charge, to any person obtaining a copy @@ -52,7 +52,7 @@ class SSH_Socket(ReadBuf, WriteBuf): SM_BANNER_SENT = 1 - def __init__(self, host: Optional[str], port: int, ipvo: Optional[Sequence[int]] = None, timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: + def __init__(self, host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: # pylint: disable=dangerous-default-value super(SSH_Socket, self).__init__() self.__sock: Optional[socket.socket] = None self.__sock_map: Dict[int, socket.socket] = {} @@ -67,32 +67,27 @@ class SSH_Socket(ReadBuf, WriteBuf): raise ValueError('invalid port: {}'.format(port)) self.__host = host self.__port = nport - if ipvo is not None: - self.__ipvo = ipvo - else: - self.__ipvo = () + self.__ip_version_preference = ip_version_preference # Holds only 5 possible values: [] (no preference), [4] (use IPv4 only), [6] (use IPv6 only), [46] (use both IPv4 and IPv6, but prioritize v4), and [64] (use both IPv4 and IPv6, but prioritize v6). self.__timeout = timeout self.__timeout_set = timeout_set self.client_host: Optional[str] = None self.client_port = None - def _resolve(self, ipvo: Sequence[int]) -> Iterable[Tuple[int, Tuple[Any, ...]]]: - ipvo = tuple([x for x in Utils.unique_seq(ipvo) if x in (4, 6)]) - ipvo_len = len(ipvo) - prefer_ipvo = ipvo_len > 0 - prefer_ipv4 = prefer_ipvo and ipvo[0] == 4 - if ipvo_len == 1: - family = socket.AF_INET if ipvo[0] == 4 else socket.AF_INET6 + def _resolve(self) -> Iterable[Tuple[int, Tuple[Any, ...]]]: + # If __ip_version_preference has only one entry, then it means that ONLY that IP version should be used. + if len(self.__ip_version_preference) == 1: + family = socket.AF_INET if self.__ip_version_preference[0] == 4 else socket.AF_INET6 else: family = socket.AF_UNSPEC try: stype = socket.SOCK_STREAM r = socket.getaddrinfo(self.__host, self.__port, family, stype) - if prefer_ipvo: - r = sorted(r, key=lambda x: x[0], reverse=not prefer_ipv4) - check = any(stype == rline[2] for rline in r) + + # If the user has a preference for using IPv4 over IPv6 (or vice-versa), then sort the list returned by getaddrinfo() so that the preferred address type comes first. + if len(self.__ip_version_preference) == 2: + r = sorted(r, key=lambda x: x[0], reverse=(self.__ip_version_preference[0] == 6)) for af, socktype, _proto, _canonname, addr in r: - if not check or socktype == socket.SOCK_STREAM: + if socktype == socket.SOCK_STREAM: yield af, addr except socket.error as e: OutputBuffer().fail('[exception] {}'.format(e)).write() @@ -156,7 +151,7 @@ class SSH_Socket(ReadBuf, WriteBuf): def connect(self) -> Optional[str]: '''Returns None on success, or an error string.''' err = None - for af, addr in self._resolve(self.__ipvo): + for af, addr in self._resolve(): s = None try: s = socket.socket(af, socket.SOCK_STREAM) diff --git a/test/test_auditconf.py b/test/test_auditconf.py index 064af54..05e7089 100644 --- a/test/test_auditconf.py +++ b/test/test_auditconf.py @@ -22,9 +22,8 @@ class TestAuditConf: 'colors': True, 'verbose': False, 'level': 'info', - 'ipv4': True, - 'ipv6': True, - 'ipvo': () + 'ipv4': False, + 'ipv6': False } for k, v in kwargs.items(): options[k] = v @@ -38,7 +37,6 @@ class TestAuditConf: assert conf.level == options['level'] assert conf.ipv4 == options['ipv4'] assert conf.ipv6 == options['ipv6'] - assert conf.ipvo == options['ipvo'] def test_audit_conf_defaults(self): conf = self.AuditConf() @@ -64,57 +62,38 @@ class TestAuditConf: conf.port = port excinfo.match(r'.*invalid port.*') - def test_audit_conf_ipvo(self): + def test_audit_conf_ip_version_preference(self): # ipv4-only conf = self.AuditConf() conf.ipv4 = True assert conf.ipv4 is True assert conf.ipv6 is False - assert conf.ipvo == (4,) + assert conf.ip_version_preference == [4] # ipv6-only conf = self.AuditConf() conf.ipv6 = True assert conf.ipv4 is False assert conf.ipv6 is True - assert conf.ipvo == (6,) - # ipv4-only (by removing ipv6) - conf = self.AuditConf() - conf.ipv6 = False - assert conf.ipv4 is True - assert conf.ipv6 is False - assert conf.ipvo == (4, ) - # ipv6-only (by removing ipv4) - conf = self.AuditConf() - conf.ipv4 = False - assert conf.ipv4 is False - assert conf.ipv6 is True - assert conf.ipvo == (6, ) + assert conf.ip_version_preference == [6] # ipv4-preferred conf = self.AuditConf() conf.ipv4 = True conf.ipv6 = True assert conf.ipv4 is True assert conf.ipv6 is True - assert conf.ipvo == (4, 6) + assert conf.ip_version_preference == [4, 6] # ipv6-preferred conf = self.AuditConf() conf.ipv6 = True conf.ipv4 = True assert conf.ipv4 is True assert conf.ipv6 is True - assert conf.ipvo == (6, 4) - # ipvo empty + assert conf.ip_version_preference == [6, 4] + # defaults conf = self.AuditConf() - conf.ipvo = () - assert conf.ipv4 is True - assert conf.ipv6 is True - assert conf.ipvo == () - # ipvo validation - conf = self.AuditConf() - conf.ipvo = (1, 2, 3, 4, 5, 6) - assert conf.ipvo == (4, 6) - conf.ipvo = (4, 4, 4, 6, 6) - assert conf.ipvo == (4, 6) + assert conf.ipv4 is False + assert conf.ipv6 is False + assert conf.ip_version_preference == [] def test_audit_conf_level(self): conf = self.AuditConf() diff --git a/test/test_resolve.py b/test/test_resolve.py index 11cba1e..c5b612c 100644 --- a/test/test_resolve.py +++ b/test/test_resolve.py @@ -19,11 +19,11 @@ class TestResolve: def test_resolve_error(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known') - s = self.ssh_socket('localhost', 22) conf = self._conf() + s = self.ssh_socket('localhost', 22, conf.ip_version_preference) output_spy.begin() with pytest.raises(SystemExit): - list(s._resolve(conf.ipvo)) + list(s._resolve()) lines = output_spy.flush() assert len(lines) == 1 assert 'hostname nor servname provided' in lines[-1] @@ -31,49 +31,50 @@ class TestResolve: def test_resolve_hostname_without_records(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.gsock.addrinfodata['localhost#22'] = [] - s = self.ssh_socket('localhost', 22) conf = self._conf() + s = self.ssh_socket('localhost', 22, conf.ip_version_preference) output_spy.begin() - r = list(s._resolve(conf.ipvo)) + r = list(s._resolve()) assert len(r) == 0 def test_resolve_ipv4(self, virtual_socket): conf = self._conf() conf.ipv4 = True - s = self.ssh_socket('localhost', 22) - r = list(s._resolve(conf.ipvo)) + s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + r = list(s._resolve()) assert len(r) == 1 assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) def test_resolve_ipv6(self, virtual_socket): - s = self.ssh_socket('localhost', 22) conf = self._conf() conf.ipv6 = True - r = list(s._resolve(conf.ipvo)) + s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + r = list(s._resolve()) assert len(r) == 1 assert r[0] == (socket.AF_INET6, ('::1', 22)) def test_resolve_ipv46_both(self, virtual_socket): - s = self.ssh_socket('localhost', 22) conf = self._conf() - r = list(s._resolve(conf.ipvo)) + s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + r = list(s._resolve()) assert len(r) == 2 assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) assert r[1] == (socket.AF_INET6, ('::1', 22)) def test_resolve_ipv46_order(self, virtual_socket): - s = self.ssh_socket('localhost', 22) conf = self._conf() conf.ipv4 = True conf.ipv6 = True - r = list(s._resolve(conf.ipvo)) + s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + r = list(s._resolve()) assert len(r) == 2 assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) assert r[1] == (socket.AF_INET6, ('::1', 22)) conf = self._conf() conf.ipv6 = True conf.ipv4 = True - r = list(s._resolve(conf.ipvo)) + s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + r = list(s._resolve()) assert len(r) == 2 assert r[0] == (socket.AF_INET6, ('::1', 22)) assert r[1] == (socket.AF_INET, ('127.0.0.1', 22))