diff --git a/src/ssh_audit/gextest.py b/src/ssh_audit/gextest.py index 7f84fc6..312e6aa 100644 --- a/src/ssh_audit/gextest.py +++ b/src/ssh_audit/gextest.py @@ -43,12 +43,12 @@ class GEXTest: if s.is_connected(): return True - err = s.connect(out) + err = s.connect() if err is not None: out.v(err, write_now=True) return False - _, _, err = s.get_banner(out) + _, _, err = s.get_banner() if err is not None: out.v(err, write_now=True) s.close() @@ -56,7 +56,7 @@ class GEXTest: # Send our KEX using the specified group-exchange and most of the # server's own values. - s.send_kexinit(out, key_exchanges=[gex_alg], hostkeys=kex.key_algorithms, ciphers=kex.server.encryption, macs=kex.server.mac, compressions=kex.server.compression, languages=kex.server.languages) + s.send_kexinit(key_exchanges=[gex_alg], hostkeys=kex.key_algorithms, ciphers=kex.server.encryption, macs=kex.server.mac, compressions=kex.server.compression, languages=kex.server.languages) # Parse the server's KEX. _, payload = s.read_packet(2) diff --git a/src/ssh_audit/hostkeytest.py b/src/ssh_audit/hostkeytest.py index 7a37055..8b0a6eb 100644 --- a/src/ssh_audit/hostkeytest.py +++ b/src/ssh_audit/hostkeytest.py @@ -109,19 +109,19 @@ class HostKeyTest: # If the connection is closed, re-open it and get the kex again. if not s.is_connected(): - err = s.connect(out) + err = s.connect() if err is not None: out.v(err, write_now=True) return - _, _, err = s.get_banner(out) + _, _, err = s.get_banner() if err is not None: out.v(err, write_now=True) s.close() return # Send our KEX using the specified group-exchange and most of the server's own values. - s.send_kexinit(out, key_exchanges=[kex_str], hostkeys=[host_key_type], ciphers=server_kex.server.encryption, macs=server_kex.server.mac, compressions=server_kex.server.compression, languages=server_kex.server.languages) + s.send_kexinit(key_exchanges=[kex_str], hostkeys=[host_key_type], ciphers=server_kex.server.encryption, macs=server_kex.server.mac, compressions=server_kex.server.compression, languages=server_kex.server.languages) # Parse the server's KEX. _, payload = s.read_packet() diff --git a/src/ssh_audit/ssh_audit.py b/src/ssh_audit/ssh_audit.py index 1382e61..50424e5 100755 --- a/src/ssh_audit/ssh_audit.py +++ b/src/ssh_audit/ssh_audit.py @@ -820,14 +820,14 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print out.debug = aconf.debug out.level = aconf.level out.use_colors = aconf.colors - s = SSH_Socket(aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set) + s = SSH_Socket(out, aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set) if aconf.client_audit: out.v("Listening for client connection on port %d..." % aconf.port, write_now=True) s.listen_and_accept() else: out.v("Starting audit of %s:%d..." % ('[%s]' % aconf.host if Utils.is_ipv6_address(aconf.host) else aconf.host, aconf.port), write_now=True) - err = s.connect(out) + err = s.connect() if err is not None: out.fail(err) @@ -842,14 +842,14 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print if sshv is None: sshv = 2 if aconf.ssh2 else 1 err = None - banner, header, err = s.get_banner(out, sshv) + banner, header, err = s.get_banner(sshv) if banner is None: if err is None: err = '[exception] did not receive banner.' else: err = '[exception] did not receive banner: {}'.format(err) if err is None: - s.send_kexinit(out) # Send the algorithms we support (except we don't since this isn't a real SSH connection). + s.send_kexinit() # Send the algorithms we support (except we don't since this isn't a real SSH connection). packet_type, payload = s.read_packet(sshv) if packet_type < 0: diff --git a/src/ssh_audit/ssh_socket.py b/src/ssh_audit/ssh_socket.py index 35c2a9c..3891588 100644 --- a/src/ssh_audit/ssh_socket.py +++ b/src/ssh_audit/ssh_socket.py @@ -52,8 +52,9 @@ class SSH_Socket(ReadBuf, WriteBuf): SM_BANNER_SENT = 1 - def __init__(self, host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: # pylint: disable=dangerous-default-value + def __init__(self, outputbuffer: 'OutputBuffer', host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: # pylint: disable=dangerous-default-value super(SSH_Socket, self).__init__() + self.__outputbuffer = outputbuffer self.__sock: Optional[socket.socket] = None self.__sock_map: Dict[int, socket.socket] = {} self.__block_size = 8 @@ -90,7 +91,7 @@ class SSH_Socket(ReadBuf, WriteBuf): if socktype == socket.SOCK_STREAM: yield af, addr except socket.error as e: - OutputBuffer().fail('[exception] {}'.format(e)).write() + self.__outputbuffer.fail('[exception] {}'.format(e)).write() sys.exit(exitcodes.CONNECTION_ERROR) # Listens on a server socket and accepts one connection (used for @@ -148,7 +149,7 @@ class SSH_Socket(ReadBuf, WriteBuf): c.settimeout(self.__timeout) self.__sock = c - def connect(self, out: 'OutputBuffer') -> Optional[str]: + def connect(self) -> Optional[str]: '''Returns None on success, or an error string.''' err = None for af, addr in self._resolve(): @@ -156,7 +157,7 @@ class SSH_Socket(ReadBuf, WriteBuf): try: s = socket.socket(af, socket.SOCK_STREAM) s.settimeout(self.__timeout) - out.d(("Connecting to %s:%d..." % ('[%s]' % addr[0] if Utils.is_ipv6_address(addr[0]) else addr[0], addr[1])), write_now=True) + self.__outputbuffer.d(("Connecting to %s:%d..." % ('[%s]' % addr[0] if Utils.is_ipv6_address(addr[0]) else addr[0], addr[1])), write_now=True) s.connect(addr) self.__sock = s return None @@ -170,8 +171,8 @@ class SSH_Socket(ReadBuf, WriteBuf): errm = 'cannot connect to {} port {}: {}'.format(*errt) return '[exception] {}'.format(errm) - def get_banner(self, out: 'OutputBuffer', sshv: int = 2) -> Tuple[Optional['Banner'], List[str], Optional[str]]: - out.d('Getting banner...', write_now=True) + def get_banner(self, sshv: int = 2) -> Tuple[Optional['Banner'], List[str], Optional[str]]: + self.__outputbuffer.d('Getting banner...', write_now=True) if self.__sock is None: return self.__banner, self.__header, 'not connected' @@ -229,10 +230,10 @@ class SSH_Socket(ReadBuf, WriteBuf): return -1, str(e.args[-1]) # Send a KEXINIT with the lists of key exchanges, hostkeys, ciphers, MACs, compressions, and languages that we "support". - def send_kexinit(self, out: 'OutputBuffer', key_exchanges: List[str] = ['curve25519-sha256', 'curve25519-sha256@libssh.org', 'ecdh-sha2-nistp256', 'ecdh-sha2-nistp384', 'ecdh-sha2-nistp521', 'diffie-hellman-group-exchange-sha256', 'diffie-hellman-group16-sha512', 'diffie-hellman-group18-sha512', 'diffie-hellman-group14-sha256'], hostkeys: List[str] = ['rsa-sha2-512', 'rsa-sha2-256', 'ssh-rsa', 'ecdsa-sha2-nistp256', 'ssh-ed25519'], ciphers: List[str] = ['chacha20-poly1305@openssh.com', 'aes128-ctr', 'aes192-ctr', 'aes256-ctr', 'aes128-gcm@openssh.com', 'aes256-gcm@openssh.com'], macs: List[str] = ['umac-64-etm@openssh.com', 'umac-128-etm@openssh.com', 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-512-etm@openssh.com', 'hmac-sha1-etm@openssh.com', 'umac-64@openssh.com', 'umac-128@openssh.com', 'hmac-sha2-256', 'hmac-sha2-512', 'hmac-sha1'], compressions: List[str] = ['none', 'zlib@openssh.com'], languages: List[str] = ['']) -> None: # pylint: disable=dangerous-default-value + def send_kexinit(self, key_exchanges: List[str] = ['curve25519-sha256', 'curve25519-sha256@libssh.org', 'ecdh-sha2-nistp256', 'ecdh-sha2-nistp384', 'ecdh-sha2-nistp521', 'diffie-hellman-group-exchange-sha256', 'diffie-hellman-group16-sha512', 'diffie-hellman-group18-sha512', 'diffie-hellman-group14-sha256'], hostkeys: List[str] = ['rsa-sha2-512', 'rsa-sha2-256', 'ssh-rsa', 'ecdsa-sha2-nistp256', 'ssh-ed25519'], ciphers: List[str] = ['chacha20-poly1305@openssh.com', 'aes128-ctr', 'aes192-ctr', 'aes256-ctr', 'aes128-gcm@openssh.com', 'aes256-gcm@openssh.com'], macs: List[str] = ['umac-64-etm@openssh.com', 'umac-128-etm@openssh.com', 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-512-etm@openssh.com', 'hmac-sha1-etm@openssh.com', 'umac-64@openssh.com', 'umac-128@openssh.com', 'hmac-sha2-256', 'hmac-sha2-512', 'hmac-sha1'], compressions: List[str] = ['none', 'zlib@openssh.com'], languages: List[str] = ['']) -> None: # pylint: disable=dangerous-default-value '''Sends the list of supported host keys, key exchanges, ciphers, and MACs. Emulates OpenSSH v8.2.''' - out.d('KEX initialisation...', write_now=True) + self.__outputbuffer.d('KEX initialisation...', write_now=True) kexparty = SSH2_KexParty(ciphers, macs, compressions, languages) kex = SSH2_Kex(os.urandom(16), key_exchanges, hostkeys, kexparty, kexparty, False, 0) @@ -273,7 +274,7 @@ class SSH_Socket(ReadBuf, WriteBuf): payload_length = packet_length - padding_length - 1 check_size = 4 + 1 + payload_length + padding_length if check_size % self.__block_size != 0: - OutputBuffer().fail('[exception] invalid ssh packet (block size)').write() + self.__outputbuffer.fail('[exception] invalid ssh packet (block size)').write() sys.exit(exitcodes.CONNECTION_ERROR) self.ensure_read(payload_length) if sshv == 1: @@ -288,7 +289,7 @@ class SSH_Socket(ReadBuf, WriteBuf): if sshv == 1: rcrc = SSH1.crc32(padding + payload) if crc != rcrc: - OutputBuffer().fail('[exception] packet checksum CRC32 mismatch.').write() + self.__outputbuffer.fail('[exception] packet checksum CRC32 mismatch.').write() sys.exit(exitcodes.CONNECTION_ERROR) else: self.ensure_read(padding_length) diff --git a/test/test_resolve.py b/test/test_resolve.py index c5b612c..fbaa033 100644 --- a/test/test_resolve.py +++ b/test/test_resolve.py @@ -8,6 +8,7 @@ class TestResolve: def init(self, ssh_audit): self.AuditConf = ssh_audit.AuditConf self.audit = ssh_audit.audit + self.OutputBuffer = ssh_audit.OutputBuffer self.ssh_socket = ssh_audit.SSH_Socket def _conf(self): @@ -20,7 +21,7 @@ class TestResolve: vsocket = virtual_socket vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known') conf = self._conf() - s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference) output_spy.begin() with pytest.raises(SystemExit): list(s._resolve()) @@ -32,7 +33,7 @@ class TestResolve: vsocket = virtual_socket vsocket.gsock.addrinfodata['localhost#22'] = [] conf = self._conf() - s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference) output_spy.begin() r = list(s._resolve()) assert len(r) == 0 @@ -40,7 +41,7 @@ class TestResolve: def test_resolve_ipv4(self, virtual_socket): conf = self._conf() conf.ipv4 = True - s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference) r = list(s._resolve()) assert len(r) == 1 assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) @@ -48,14 +49,14 @@ class TestResolve: def test_resolve_ipv6(self, virtual_socket): conf = self._conf() conf.ipv6 = True - s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference) r = list(s._resolve()) assert len(r) == 1 assert r[0] == (socket.AF_INET6, ('::1', 22)) def test_resolve_ipv46_both(self, virtual_socket): conf = self._conf() - s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference) r = list(s._resolve()) assert len(r) == 2 assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) @@ -65,7 +66,7 @@ class TestResolve: conf = self._conf() conf.ipv4 = True conf.ipv6 = True - s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference) r = list(s._resolve()) assert len(r) == 2 assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) @@ -73,7 +74,7 @@ class TestResolve: conf = self._conf() conf.ipv6 = True conf.ipv4 = True - s = self.ssh_socket('localhost', 22, conf.ip_version_preference) + s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference) r = list(s._resolve()) assert len(r) == 2 assert r[0] == (socket.AF_INET6, ('::1', 22)) diff --git a/test/test_socket.py b/test/test_socket.py index 6cf7490..9569768 100644 --- a/test/test_socket.py +++ b/test/test_socket.py @@ -1,5 +1,6 @@ import pytest +from ssh_audit.outputbuffer import OutputBuffer from ssh_audit.ssh_socket import SSH_Socket @@ -7,24 +8,25 @@ from ssh_audit.ssh_socket import SSH_Socket class TestSocket: @pytest.fixture(autouse=True) def init(self, ssh_audit): + self.OutputBuffer = OutputBuffer self.ssh_socket = SSH_Socket def test_invalid_host(self, virtual_socket): with pytest.raises(ValueError): - self.ssh_socket(None, 22) + self.ssh_socket(self.OutputBuffer(), None, 22) def test_invalid_port(self, virtual_socket): with pytest.raises(ValueError): - self.ssh_socket('localhost', 'abc') + self.ssh_socket(self.OutputBuffer(), 'localhost', 'abc') with pytest.raises(ValueError): - self.ssh_socket('localhost', -1) + self.ssh_socket(self.OutputBuffer(), 'localhost', -1) with pytest.raises(ValueError): - self.ssh_socket('localhost', 0) + self.ssh_socket(self.OutputBuffer(), 'localhost', 0) with pytest.raises(ValueError): - self.ssh_socket('localhost', 65536) + self.ssh_socket(self.OutputBuffer(), 'localhost', 65536) def test_not_connected_socket(self, virtual_socket): - sock = self.ssh_socket('localhost', 22) + sock = self.ssh_socket(self.OutputBuffer(), 'localhost', 22) banner, header, err = sock.get_banner() assert banner is None assert len(header) == 0