mirror of
				https://github.com/jtesta/ssh-audit.git
				synced 2025-11-03 18:52:15 +01:00 
			
		
		
		
	SSH_Socket's constructor now takes an OutputBuffer for verbose & debugging output.
This commit is contained in:
		@@ -43,12 +43,12 @@ class GEXTest:
 | 
			
		||||
        if s.is_connected():
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        err = s.connect(out)
 | 
			
		||||
        err = s.connect()
 | 
			
		||||
        if err is not None:
 | 
			
		||||
            out.v(err, write_now=True)
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        _, _, err = s.get_banner(out)
 | 
			
		||||
        _, _, err = s.get_banner()
 | 
			
		||||
        if err is not None:
 | 
			
		||||
            out.v(err, write_now=True)
 | 
			
		||||
            s.close()
 | 
			
		||||
@@ -56,7 +56,7 @@ class GEXTest:
 | 
			
		||||
 | 
			
		||||
        # Send our KEX using the specified group-exchange and most of the
 | 
			
		||||
        # server's own values.
 | 
			
		||||
        s.send_kexinit(out, key_exchanges=[gex_alg], hostkeys=kex.key_algorithms, ciphers=kex.server.encryption, macs=kex.server.mac, compressions=kex.server.compression, languages=kex.server.languages)
 | 
			
		||||
        s.send_kexinit(key_exchanges=[gex_alg], hostkeys=kex.key_algorithms, ciphers=kex.server.encryption, macs=kex.server.mac, compressions=kex.server.compression, languages=kex.server.languages)
 | 
			
		||||
 | 
			
		||||
        # Parse the server's KEX.
 | 
			
		||||
        _, payload = s.read_packet(2)
 | 
			
		||||
 
 | 
			
		||||
@@ -109,19 +109,19 @@ class HostKeyTest:
 | 
			
		||||
 | 
			
		||||
                # If the connection is closed, re-open it and get the kex again.
 | 
			
		||||
                if not s.is_connected():
 | 
			
		||||
                    err = s.connect(out)
 | 
			
		||||
                    err = s.connect()
 | 
			
		||||
                    if err is not None:
 | 
			
		||||
                        out.v(err, write_now=True)
 | 
			
		||||
                        return
 | 
			
		||||
 | 
			
		||||
                    _, _, err = s.get_banner(out)
 | 
			
		||||
                    _, _, err = s.get_banner()
 | 
			
		||||
                    if err is not None:
 | 
			
		||||
                        out.v(err, write_now=True)
 | 
			
		||||
                        s.close()
 | 
			
		||||
                        return
 | 
			
		||||
 | 
			
		||||
                    # Send our KEX using the specified group-exchange and most of the server's own values.
 | 
			
		||||
                    s.send_kexinit(out, key_exchanges=[kex_str], hostkeys=[host_key_type], ciphers=server_kex.server.encryption, macs=server_kex.server.mac, compressions=server_kex.server.compression, languages=server_kex.server.languages)
 | 
			
		||||
                    s.send_kexinit(key_exchanges=[kex_str], hostkeys=[host_key_type], ciphers=server_kex.server.encryption, macs=server_kex.server.mac, compressions=server_kex.server.compression, languages=server_kex.server.languages)
 | 
			
		||||
 | 
			
		||||
                    # Parse the server's KEX.
 | 
			
		||||
                    _, payload = s.read_packet()
 | 
			
		||||
 
 | 
			
		||||
@@ -820,14 +820,14 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print
 | 
			
		||||
    out.debug = aconf.debug
 | 
			
		||||
    out.level = aconf.level
 | 
			
		||||
    out.use_colors = aconf.colors
 | 
			
		||||
    s = SSH_Socket(aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set)
 | 
			
		||||
    s = SSH_Socket(out, aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set)
 | 
			
		||||
 | 
			
		||||
    if aconf.client_audit:
 | 
			
		||||
        out.v("Listening for client connection on port %d..." % aconf.port, write_now=True)
 | 
			
		||||
        s.listen_and_accept()
 | 
			
		||||
    else:
 | 
			
		||||
        out.v("Starting audit of %s:%d..." % ('[%s]' % aconf.host if Utils.is_ipv6_address(aconf.host) else aconf.host, aconf.port), write_now=True)
 | 
			
		||||
        err = s.connect(out)
 | 
			
		||||
        err = s.connect()
 | 
			
		||||
 | 
			
		||||
        if err is not None:
 | 
			
		||||
            out.fail(err)
 | 
			
		||||
@@ -842,14 +842,14 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print
 | 
			
		||||
    if sshv is None:
 | 
			
		||||
        sshv = 2 if aconf.ssh2 else 1
 | 
			
		||||
    err = None
 | 
			
		||||
    banner, header, err = s.get_banner(out, sshv)
 | 
			
		||||
    banner, header, err = s.get_banner(sshv)
 | 
			
		||||
    if banner is None:
 | 
			
		||||
        if err is None:
 | 
			
		||||
            err = '[exception] did not receive banner.'
 | 
			
		||||
        else:
 | 
			
		||||
            err = '[exception] did not receive banner: {}'.format(err)
 | 
			
		||||
    if err is None:
 | 
			
		||||
        s.send_kexinit(out)  # Send the algorithms we support (except we don't since this isn't a real SSH connection).
 | 
			
		||||
        s.send_kexinit()  # Send the algorithms we support (except we don't since this isn't a real SSH connection).
 | 
			
		||||
 | 
			
		||||
        packet_type, payload = s.read_packet(sshv)
 | 
			
		||||
        if packet_type < 0:
 | 
			
		||||
 
 | 
			
		||||
@@ -52,8 +52,9 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
 | 
			
		||||
    SM_BANNER_SENT = 1
 | 
			
		||||
 | 
			
		||||
    def __init__(self, host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None:  # pylint: disable=dangerous-default-value
 | 
			
		||||
    def __init__(self, outputbuffer: 'OutputBuffer', host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None:  # pylint: disable=dangerous-default-value
 | 
			
		||||
        super(SSH_Socket, self).__init__()
 | 
			
		||||
        self.__outputbuffer = outputbuffer
 | 
			
		||||
        self.__sock: Optional[socket.socket] = None
 | 
			
		||||
        self.__sock_map: Dict[int, socket.socket] = {}
 | 
			
		||||
        self.__block_size = 8
 | 
			
		||||
@@ -90,7 +91,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
                if socktype == socket.SOCK_STREAM:
 | 
			
		||||
                    yield af, addr
 | 
			
		||||
        except socket.error as e:
 | 
			
		||||
            OutputBuffer().fail('[exception] {}'.format(e)).write()
 | 
			
		||||
            self.__outputbuffer.fail('[exception] {}'.format(e)).write()
 | 
			
		||||
            sys.exit(exitcodes.CONNECTION_ERROR)
 | 
			
		||||
 | 
			
		||||
    # Listens on a server socket and accepts one connection (used for
 | 
			
		||||
@@ -148,7 +149,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
        c.settimeout(self.__timeout)
 | 
			
		||||
        self.__sock = c
 | 
			
		||||
 | 
			
		||||
    def connect(self, out: 'OutputBuffer') -> Optional[str]:
 | 
			
		||||
    def connect(self) -> Optional[str]:
 | 
			
		||||
        '''Returns None on success, or an error string.'''
 | 
			
		||||
        err = None
 | 
			
		||||
        for af, addr in self._resolve():
 | 
			
		||||
@@ -156,7 +157,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
            try:
 | 
			
		||||
                s = socket.socket(af, socket.SOCK_STREAM)
 | 
			
		||||
                s.settimeout(self.__timeout)
 | 
			
		||||
                out.d(("Connecting to %s:%d..." % ('[%s]' % addr[0] if Utils.is_ipv6_address(addr[0]) else addr[0], addr[1])), write_now=True)
 | 
			
		||||
                self.__outputbuffer.d(("Connecting to %s:%d..." % ('[%s]' % addr[0] if Utils.is_ipv6_address(addr[0]) else addr[0], addr[1])), write_now=True)
 | 
			
		||||
                s.connect(addr)
 | 
			
		||||
                self.__sock = s
 | 
			
		||||
                return None
 | 
			
		||||
@@ -170,8 +171,8 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
            errm = 'cannot connect to {} port {}: {}'.format(*errt)
 | 
			
		||||
        return '[exception] {}'.format(errm)
 | 
			
		||||
 | 
			
		||||
    def get_banner(self, out: 'OutputBuffer', sshv: int = 2) -> Tuple[Optional['Banner'], List[str], Optional[str]]:
 | 
			
		||||
        out.d('Getting banner...', write_now=True)
 | 
			
		||||
    def get_banner(self, sshv: int = 2) -> Tuple[Optional['Banner'], List[str], Optional[str]]:
 | 
			
		||||
        self.__outputbuffer.d('Getting banner...', write_now=True)
 | 
			
		||||
 | 
			
		||||
        if self.__sock is None:
 | 
			
		||||
            return self.__banner, self.__header, 'not connected'
 | 
			
		||||
@@ -229,10 +230,10 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
            return -1, str(e.args[-1])
 | 
			
		||||
 | 
			
		||||
    # Send a KEXINIT with the lists of key exchanges, hostkeys, ciphers, MACs, compressions, and languages that we "support".
 | 
			
		||||
    def send_kexinit(self, out: 'OutputBuffer', key_exchanges: List[str] = ['curve25519-sha256', 'curve25519-sha256@libssh.org', 'ecdh-sha2-nistp256', 'ecdh-sha2-nistp384', 'ecdh-sha2-nistp521', 'diffie-hellman-group-exchange-sha256', 'diffie-hellman-group16-sha512', 'diffie-hellman-group18-sha512', 'diffie-hellman-group14-sha256'], hostkeys: List[str] = ['rsa-sha2-512', 'rsa-sha2-256', 'ssh-rsa', 'ecdsa-sha2-nistp256', 'ssh-ed25519'], ciphers: List[str] = ['chacha20-poly1305@openssh.com', 'aes128-ctr', 'aes192-ctr', 'aes256-ctr', 'aes128-gcm@openssh.com', 'aes256-gcm@openssh.com'], macs: List[str] = ['umac-64-etm@openssh.com', 'umac-128-etm@openssh.com', 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-512-etm@openssh.com', 'hmac-sha1-etm@openssh.com', 'umac-64@openssh.com', 'umac-128@openssh.com', 'hmac-sha2-256', 'hmac-sha2-512', 'hmac-sha1'], compressions: List[str] = ['none', 'zlib@openssh.com'], languages: List[str] = ['']) -> None:  # pylint: disable=dangerous-default-value
 | 
			
		||||
    def send_kexinit(self, key_exchanges: List[str] = ['curve25519-sha256', 'curve25519-sha256@libssh.org', 'ecdh-sha2-nistp256', 'ecdh-sha2-nistp384', 'ecdh-sha2-nistp521', 'diffie-hellman-group-exchange-sha256', 'diffie-hellman-group16-sha512', 'diffie-hellman-group18-sha512', 'diffie-hellman-group14-sha256'], hostkeys: List[str] = ['rsa-sha2-512', 'rsa-sha2-256', 'ssh-rsa', 'ecdsa-sha2-nistp256', 'ssh-ed25519'], ciphers: List[str] = ['chacha20-poly1305@openssh.com', 'aes128-ctr', 'aes192-ctr', 'aes256-ctr', 'aes128-gcm@openssh.com', 'aes256-gcm@openssh.com'], macs: List[str] = ['umac-64-etm@openssh.com', 'umac-128-etm@openssh.com', 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-512-etm@openssh.com', 'hmac-sha1-etm@openssh.com', 'umac-64@openssh.com', 'umac-128@openssh.com', 'hmac-sha2-256', 'hmac-sha2-512', 'hmac-sha1'], compressions: List[str] = ['none', 'zlib@openssh.com'], languages: List[str] = ['']) -> None:  # pylint: disable=dangerous-default-value
 | 
			
		||||
        '''Sends the list of supported host keys, key exchanges, ciphers, and MACs.  Emulates OpenSSH v8.2.'''
 | 
			
		||||
 | 
			
		||||
        out.d('KEX initialisation...', write_now=True)
 | 
			
		||||
        self.__outputbuffer.d('KEX initialisation...', write_now=True)
 | 
			
		||||
 | 
			
		||||
        kexparty = SSH2_KexParty(ciphers, macs, compressions, languages)
 | 
			
		||||
        kex = SSH2_Kex(os.urandom(16), key_exchanges, hostkeys, kexparty, kexparty, False, 0)
 | 
			
		||||
@@ -273,7 +274,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
                payload_length = packet_length - padding_length - 1
 | 
			
		||||
                check_size = 4 + 1 + payload_length + padding_length
 | 
			
		||||
            if check_size % self.__block_size != 0:
 | 
			
		||||
                OutputBuffer().fail('[exception] invalid ssh packet (block size)').write()
 | 
			
		||||
                self.__outputbuffer.fail('[exception] invalid ssh packet (block size)').write()
 | 
			
		||||
                sys.exit(exitcodes.CONNECTION_ERROR)
 | 
			
		||||
            self.ensure_read(payload_length)
 | 
			
		||||
            if sshv == 1:
 | 
			
		||||
@@ -288,7 +289,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
            if sshv == 1:
 | 
			
		||||
                rcrc = SSH1.crc32(padding + payload)
 | 
			
		||||
                if crc != rcrc:
 | 
			
		||||
                    OutputBuffer().fail('[exception] packet checksum CRC32 mismatch.').write()
 | 
			
		||||
                    self.__outputbuffer.fail('[exception] packet checksum CRC32 mismatch.').write()
 | 
			
		||||
                    sys.exit(exitcodes.CONNECTION_ERROR)
 | 
			
		||||
            else:
 | 
			
		||||
                self.ensure_read(padding_length)
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ class TestResolve:
 | 
			
		||||
    def init(self, ssh_audit):
 | 
			
		||||
        self.AuditConf = ssh_audit.AuditConf
 | 
			
		||||
        self.audit = ssh_audit.audit
 | 
			
		||||
        self.OutputBuffer = ssh_audit.OutputBuffer
 | 
			
		||||
        self.ssh_socket = ssh_audit.SSH_Socket
 | 
			
		||||
 | 
			
		||||
    def _conf(self):
 | 
			
		||||
@@ -20,7 +21,7 @@ class TestResolve:
 | 
			
		||||
        vsocket = virtual_socket
 | 
			
		||||
        vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known')
 | 
			
		||||
        conf = self._conf()
 | 
			
		||||
        s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        output_spy.begin()
 | 
			
		||||
        with pytest.raises(SystemExit):
 | 
			
		||||
            list(s._resolve())
 | 
			
		||||
@@ -32,7 +33,7 @@ class TestResolve:
 | 
			
		||||
        vsocket = virtual_socket
 | 
			
		||||
        vsocket.gsock.addrinfodata['localhost#22'] = []
 | 
			
		||||
        conf = self._conf()
 | 
			
		||||
        s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        output_spy.begin()
 | 
			
		||||
        r = list(s._resolve())
 | 
			
		||||
        assert len(r) == 0
 | 
			
		||||
@@ -40,7 +41,7 @@ class TestResolve:
 | 
			
		||||
    def test_resolve_ipv4(self, virtual_socket):
 | 
			
		||||
        conf = self._conf()
 | 
			
		||||
        conf.ipv4 = True
 | 
			
		||||
        s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        r = list(s._resolve())
 | 
			
		||||
        assert len(r) == 1
 | 
			
		||||
        assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
 | 
			
		||||
@@ -48,14 +49,14 @@ class TestResolve:
 | 
			
		||||
    def test_resolve_ipv6(self, virtual_socket):
 | 
			
		||||
        conf = self._conf()
 | 
			
		||||
        conf.ipv6 = True
 | 
			
		||||
        s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        r = list(s._resolve())
 | 
			
		||||
        assert len(r) == 1
 | 
			
		||||
        assert r[0] == (socket.AF_INET6, ('::1', 22))
 | 
			
		||||
 | 
			
		||||
    def test_resolve_ipv46_both(self, virtual_socket):
 | 
			
		||||
        conf = self._conf()
 | 
			
		||||
        s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        r = list(s._resolve())
 | 
			
		||||
        assert len(r) == 2
 | 
			
		||||
        assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
 | 
			
		||||
@@ -65,7 +66,7 @@ class TestResolve:
 | 
			
		||||
        conf = self._conf()
 | 
			
		||||
        conf.ipv4 = True
 | 
			
		||||
        conf.ipv6 = True
 | 
			
		||||
        s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        r = list(s._resolve())
 | 
			
		||||
        assert len(r) == 2
 | 
			
		||||
        assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
 | 
			
		||||
@@ -73,7 +74,7 @@ class TestResolve:
 | 
			
		||||
        conf = self._conf()
 | 
			
		||||
        conf.ipv6 = True
 | 
			
		||||
        conf.ipv4 = True
 | 
			
		||||
        s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
 | 
			
		||||
        r = list(s._resolve())
 | 
			
		||||
        assert len(r) == 2
 | 
			
		||||
        assert r[0] == (socket.AF_INET6, ('::1', 22))
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,6 @@
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from ssh_audit.outputbuffer import OutputBuffer
 | 
			
		||||
from ssh_audit.ssh_socket import SSH_Socket
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -7,24 +8,25 @@ from ssh_audit.ssh_socket import SSH_Socket
 | 
			
		||||
class TestSocket:
 | 
			
		||||
    @pytest.fixture(autouse=True)
 | 
			
		||||
    def init(self, ssh_audit):
 | 
			
		||||
        self.OutputBuffer = OutputBuffer
 | 
			
		||||
        self.ssh_socket = SSH_Socket
 | 
			
		||||
 | 
			
		||||
    def test_invalid_host(self, virtual_socket):
 | 
			
		||||
        with pytest.raises(ValueError):
 | 
			
		||||
            self.ssh_socket(None, 22)
 | 
			
		||||
            self.ssh_socket(self.OutputBuffer(), None, 22)
 | 
			
		||||
 | 
			
		||||
    def test_invalid_port(self, virtual_socket):
 | 
			
		||||
        with pytest.raises(ValueError):
 | 
			
		||||
            self.ssh_socket('localhost', 'abc')
 | 
			
		||||
            self.ssh_socket(self.OutputBuffer(), 'localhost', 'abc')
 | 
			
		||||
        with pytest.raises(ValueError):
 | 
			
		||||
            self.ssh_socket('localhost', -1)
 | 
			
		||||
            self.ssh_socket(self.OutputBuffer(), 'localhost', -1)
 | 
			
		||||
        with pytest.raises(ValueError):
 | 
			
		||||
            self.ssh_socket('localhost', 0)
 | 
			
		||||
            self.ssh_socket(self.OutputBuffer(), 'localhost', 0)
 | 
			
		||||
        with pytest.raises(ValueError):
 | 
			
		||||
            self.ssh_socket('localhost', 65536)
 | 
			
		||||
            self.ssh_socket(self.OutputBuffer(), 'localhost', 65536)
 | 
			
		||||
 | 
			
		||||
    def test_not_connected_socket(self, virtual_socket):
 | 
			
		||||
        sock = self.ssh_socket('localhost', 22)
 | 
			
		||||
        sock = self.ssh_socket(self.OutputBuffer(), 'localhost', 22)
 | 
			
		||||
        banner, header, err = sock.get_banner()
 | 
			
		||||
        assert banner is None
 | 
			
		||||
        assert len(header) == 0
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user