diff --git a/ssh-audit.py b/ssh-audit.py index 7bb1e1b..13f76b9 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -1657,7 +1657,10 @@ def audit(conf, sshv=None): if err is None: packet_type, payload = s.read_packet(sshv) if packet_type < 0: - payload = payload.decode('utf-8') if payload else u'empty' + try: + payload = payload.decode('utf-8') if payload else u'empty' + except UnicodeDecodeError: + payload = u'"{0}"'.format(repr(payload).lstrip('b')[1:-1]) if payload == u'Protocol major versions differ.': if sshv == 2 and conf.ssh1: audit(conf, 1) diff --git a/test/test_errors.py b/test/test_errors.py index 17ef23c..13cc9e2 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -81,6 +81,18 @@ class TestErrors(object): assert 'error reading packet' in lines[-1] assert 'xxx' in lines[-1] + def test_nonutf8_data_after_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') + vsocket.rdata.append(b'\x81\xff\n') + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 2 + assert 'error reading packet' in lines[-1] + assert '\\x81\\xff' in lines[-1] + def test_protocol_mismatch_by_conf(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.rdata.append(b'SSH-1.3-ssh-audit-test\r\n')