SSH_Socket's constructor now takes an OutputBuffer for verbose & debugging output.

This commit is contained in:
Joe Testa 2021-03-02 11:25:37 -05:00
parent 83bd049486
commit 8e9fe20fac
6 changed files with 37 additions and 33 deletions

View File

@ -43,12 +43,12 @@ class GEXTest:
if s.is_connected():
return True
err = s.connect(out)
err = s.connect()
if err is not None:
out.v(err, write_now=True)
return False
_, _, err = s.get_banner(out)
_, _, err = s.get_banner()
if err is not None:
out.v(err, write_now=True)
s.close()
@ -56,7 +56,7 @@ class GEXTest:
# Send our KEX using the specified group-exchange and most of the
# server's own values.
s.send_kexinit(out, key_exchanges=[gex_alg], hostkeys=kex.key_algorithms, ciphers=kex.server.encryption, macs=kex.server.mac, compressions=kex.server.compression, languages=kex.server.languages)
s.send_kexinit(key_exchanges=[gex_alg], hostkeys=kex.key_algorithms, ciphers=kex.server.encryption, macs=kex.server.mac, compressions=kex.server.compression, languages=kex.server.languages)
# Parse the server's KEX.
_, payload = s.read_packet(2)

View File

@ -109,19 +109,19 @@ class HostKeyTest:
# If the connection is closed, re-open it and get the kex again.
if not s.is_connected():
err = s.connect(out)
err = s.connect()
if err is not None:
out.v(err, write_now=True)
return
_, _, err = s.get_banner(out)
_, _, err = s.get_banner()
if err is not None:
out.v(err, write_now=True)
s.close()
return
# Send our KEX using the specified group-exchange and most of the server's own values.
s.send_kexinit(out, key_exchanges=[kex_str], hostkeys=[host_key_type], ciphers=server_kex.server.encryption, macs=server_kex.server.mac, compressions=server_kex.server.compression, languages=server_kex.server.languages)
s.send_kexinit(key_exchanges=[kex_str], hostkeys=[host_key_type], ciphers=server_kex.server.encryption, macs=server_kex.server.mac, compressions=server_kex.server.compression, languages=server_kex.server.languages)
# Parse the server's KEX.
_, payload = s.read_packet()

View File

@ -820,14 +820,14 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print
out.debug = aconf.debug
out.level = aconf.level
out.use_colors = aconf.colors
s = SSH_Socket(aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set)
s = SSH_Socket(out, aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set)
if aconf.client_audit:
out.v("Listening for client connection on port %d..." % aconf.port, write_now=True)
s.listen_and_accept()
else:
out.v("Starting audit of %s:%d..." % ('[%s]' % aconf.host if Utils.is_ipv6_address(aconf.host) else aconf.host, aconf.port), write_now=True)
err = s.connect(out)
err = s.connect()
if err is not None:
out.fail(err)
@ -842,14 +842,14 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print
if sshv is None:
sshv = 2 if aconf.ssh2 else 1
err = None
banner, header, err = s.get_banner(out, sshv)
banner, header, err = s.get_banner(sshv)
if banner is None:
if err is None:
err = '[exception] did not receive banner.'
else:
err = '[exception] did not receive banner: {}'.format(err)
if err is None:
s.send_kexinit(out) # Send the algorithms we support (except we don't since this isn't a real SSH connection).
s.send_kexinit() # Send the algorithms we support (except we don't since this isn't a real SSH connection).
packet_type, payload = s.read_packet(sshv)
if packet_type < 0:

View File

@ -52,8 +52,9 @@ class SSH_Socket(ReadBuf, WriteBuf):
SM_BANNER_SENT = 1
def __init__(self, host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: # pylint: disable=dangerous-default-value
def __init__(self, outputbuffer: 'OutputBuffer', host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: # pylint: disable=dangerous-default-value
super(SSH_Socket, self).__init__()
self.__outputbuffer = outputbuffer
self.__sock: Optional[socket.socket] = None
self.__sock_map: Dict[int, socket.socket] = {}
self.__block_size = 8
@ -90,7 +91,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
if socktype == socket.SOCK_STREAM:
yield af, addr
except socket.error as e:
OutputBuffer().fail('[exception] {}'.format(e)).write()
self.__outputbuffer.fail('[exception] {}'.format(e)).write()
sys.exit(exitcodes.CONNECTION_ERROR)
# Listens on a server socket and accepts one connection (used for
@ -148,7 +149,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
c.settimeout(self.__timeout)
self.__sock = c
def connect(self, out: 'OutputBuffer') -> Optional[str]:
def connect(self) -> Optional[str]:
'''Returns None on success, or an error string.'''
err = None
for af, addr in self._resolve():
@ -156,7 +157,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
try:
s = socket.socket(af, socket.SOCK_STREAM)
s.settimeout(self.__timeout)
out.d(("Connecting to %s:%d..." % ('[%s]' % addr[0] if Utils.is_ipv6_address(addr[0]) else addr[0], addr[1])), write_now=True)
self.__outputbuffer.d(("Connecting to %s:%d..." % ('[%s]' % addr[0] if Utils.is_ipv6_address(addr[0]) else addr[0], addr[1])), write_now=True)
s.connect(addr)
self.__sock = s
return None
@ -170,8 +171,8 @@ class SSH_Socket(ReadBuf, WriteBuf):
errm = 'cannot connect to {} port {}: {}'.format(*errt)
return '[exception] {}'.format(errm)
def get_banner(self, out: 'OutputBuffer', sshv: int = 2) -> Tuple[Optional['Banner'], List[str], Optional[str]]:
out.d('Getting banner...', write_now=True)
def get_banner(self, sshv: int = 2) -> Tuple[Optional['Banner'], List[str], Optional[str]]:
self.__outputbuffer.d('Getting banner...', write_now=True)
if self.__sock is None:
return self.__banner, self.__header, 'not connected'
@ -229,10 +230,10 @@ class SSH_Socket(ReadBuf, WriteBuf):
return -1, str(e.args[-1])
# Send a KEXINIT with the lists of key exchanges, hostkeys, ciphers, MACs, compressions, and languages that we "support".
def send_kexinit(self, out: 'OutputBuffer', key_exchanges: List[str] = ['curve25519-sha256', 'curve25519-sha256@libssh.org', 'ecdh-sha2-nistp256', 'ecdh-sha2-nistp384', 'ecdh-sha2-nistp521', 'diffie-hellman-group-exchange-sha256', 'diffie-hellman-group16-sha512', 'diffie-hellman-group18-sha512', 'diffie-hellman-group14-sha256'], hostkeys: List[str] = ['rsa-sha2-512', 'rsa-sha2-256', 'ssh-rsa', 'ecdsa-sha2-nistp256', 'ssh-ed25519'], ciphers: List[str] = ['chacha20-poly1305@openssh.com', 'aes128-ctr', 'aes192-ctr', 'aes256-ctr', 'aes128-gcm@openssh.com', 'aes256-gcm@openssh.com'], macs: List[str] = ['umac-64-etm@openssh.com', 'umac-128-etm@openssh.com', 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-512-etm@openssh.com', 'hmac-sha1-etm@openssh.com', 'umac-64@openssh.com', 'umac-128@openssh.com', 'hmac-sha2-256', 'hmac-sha2-512', 'hmac-sha1'], compressions: List[str] = ['none', 'zlib@openssh.com'], languages: List[str] = ['']) -> None: # pylint: disable=dangerous-default-value
def send_kexinit(self, key_exchanges: List[str] = ['curve25519-sha256', 'curve25519-sha256@libssh.org', 'ecdh-sha2-nistp256', 'ecdh-sha2-nistp384', 'ecdh-sha2-nistp521', 'diffie-hellman-group-exchange-sha256', 'diffie-hellman-group16-sha512', 'diffie-hellman-group18-sha512', 'diffie-hellman-group14-sha256'], hostkeys: List[str] = ['rsa-sha2-512', 'rsa-sha2-256', 'ssh-rsa', 'ecdsa-sha2-nistp256', 'ssh-ed25519'], ciphers: List[str] = ['chacha20-poly1305@openssh.com', 'aes128-ctr', 'aes192-ctr', 'aes256-ctr', 'aes128-gcm@openssh.com', 'aes256-gcm@openssh.com'], macs: List[str] = ['umac-64-etm@openssh.com', 'umac-128-etm@openssh.com', 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-512-etm@openssh.com', 'hmac-sha1-etm@openssh.com', 'umac-64@openssh.com', 'umac-128@openssh.com', 'hmac-sha2-256', 'hmac-sha2-512', 'hmac-sha1'], compressions: List[str] = ['none', 'zlib@openssh.com'], languages: List[str] = ['']) -> None: # pylint: disable=dangerous-default-value
'''Sends the list of supported host keys, key exchanges, ciphers, and MACs. Emulates OpenSSH v8.2.'''
out.d('KEX initialisation...', write_now=True)
self.__outputbuffer.d('KEX initialisation...', write_now=True)
kexparty = SSH2_KexParty(ciphers, macs, compressions, languages)
kex = SSH2_Kex(os.urandom(16), key_exchanges, hostkeys, kexparty, kexparty, False, 0)
@ -273,7 +274,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
payload_length = packet_length - padding_length - 1
check_size = 4 + 1 + payload_length + padding_length
if check_size % self.__block_size != 0:
OutputBuffer().fail('[exception] invalid ssh packet (block size)').write()
self.__outputbuffer.fail('[exception] invalid ssh packet (block size)').write()
sys.exit(exitcodes.CONNECTION_ERROR)
self.ensure_read(payload_length)
if sshv == 1:
@ -288,7 +289,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
if sshv == 1:
rcrc = SSH1.crc32(padding + payload)
if crc != rcrc:
OutputBuffer().fail('[exception] packet checksum CRC32 mismatch.').write()
self.__outputbuffer.fail('[exception] packet checksum CRC32 mismatch.').write()
sys.exit(exitcodes.CONNECTION_ERROR)
else:
self.ensure_read(padding_length)

View File

@ -8,6 +8,7 @@ class TestResolve:
def init(self, ssh_audit):
self.AuditConf = ssh_audit.AuditConf
self.audit = ssh_audit.audit
self.OutputBuffer = ssh_audit.OutputBuffer
self.ssh_socket = ssh_audit.SSH_Socket
def _conf(self):
@ -20,7 +21,7 @@ class TestResolve:
vsocket = virtual_socket
vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known')
conf = self._conf()
s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
output_spy.begin()
with pytest.raises(SystemExit):
list(s._resolve())
@ -32,7 +33,7 @@ class TestResolve:
vsocket = virtual_socket
vsocket.gsock.addrinfodata['localhost#22'] = []
conf = self._conf()
s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
output_spy.begin()
r = list(s._resolve())
assert len(r) == 0
@ -40,7 +41,7 @@ class TestResolve:
def test_resolve_ipv4(self, virtual_socket):
conf = self._conf()
conf.ipv4 = True
s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
r = list(s._resolve())
assert len(r) == 1
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
@ -48,14 +49,14 @@ class TestResolve:
def test_resolve_ipv6(self, virtual_socket):
conf = self._conf()
conf.ipv6 = True
s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
r = list(s._resolve())
assert len(r) == 1
assert r[0] == (socket.AF_INET6, ('::1', 22))
def test_resolve_ipv46_both(self, virtual_socket):
conf = self._conf()
s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
r = list(s._resolve())
assert len(r) == 2
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
@ -65,7 +66,7 @@ class TestResolve:
conf = self._conf()
conf.ipv4 = True
conf.ipv6 = True
s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
r = list(s._resolve())
assert len(r) == 2
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
@ -73,7 +74,7 @@ class TestResolve:
conf = self._conf()
conf.ipv6 = True
conf.ipv4 = True
s = self.ssh_socket('localhost', 22, conf.ip_version_preference)
s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
r = list(s._resolve())
assert len(r) == 2
assert r[0] == (socket.AF_INET6, ('::1', 22))

View File

@ -1,5 +1,6 @@
import pytest
from ssh_audit.outputbuffer import OutputBuffer
from ssh_audit.ssh_socket import SSH_Socket
@ -7,24 +8,25 @@ from ssh_audit.ssh_socket import SSH_Socket
class TestSocket:
@pytest.fixture(autouse=True)
def init(self, ssh_audit):
self.OutputBuffer = OutputBuffer
self.ssh_socket = SSH_Socket
def test_invalid_host(self, virtual_socket):
with pytest.raises(ValueError):
self.ssh_socket(None, 22)
self.ssh_socket(self.OutputBuffer(), None, 22)
def test_invalid_port(self, virtual_socket):
with pytest.raises(ValueError):
self.ssh_socket('localhost', 'abc')
self.ssh_socket(self.OutputBuffer(), 'localhost', 'abc')
with pytest.raises(ValueError):
self.ssh_socket('localhost', -1)
self.ssh_socket(self.OutputBuffer(), 'localhost', -1)
with pytest.raises(ValueError):
self.ssh_socket('localhost', 0)
self.ssh_socket(self.OutputBuffer(), 'localhost', 0)
with pytest.raises(ValueError):
self.ssh_socket('localhost', 65536)
self.ssh_socket(self.OutputBuffer(), 'localhost', 65536)
def test_not_connected_socket(self, virtual_socket):
sock = self.ssh_socket('localhost', 22)
sock = self.ssh_socket(self.OutputBuffer(), 'localhost', 22)
banner, header, err = sock.get_banner()
assert banner is None
assert len(header) == 0