diff --git a/ssh-audit.py b/ssh-audit.py index bf43e95..0dc8bb1 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -1141,16 +1141,17 @@ class SSH(object): # pylint: disable=too-few-public-methods sys.exit(1) def get_banner(self, sshv=2): - # type: (int) -> Tuple[Optional[SSH.Banner], List[text_type]] + # type: (int) -> Tuple[Optional[SSH.Banner], List[text_type], Optional[str]] banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0') rto = self.__sock.gettimeout() self.__sock.settimeout(0.7) s, e = self.recv() self.__sock.settimeout(rto) if s < 0: - return self.__banner, self.__header + return self.__banner, self.__header, e if self.__state < self.SM_BANNER_SENT: self.send_banner(banner) + e = None while self.__banner is None: if not s > 0: s, e = self.recv() @@ -1166,14 +1167,14 @@ class SSH(object): # pylint: disable=too-few-public-methods continue self.__header.append(line) s = 0 - return self.__banner, self.__header + return self.__banner, self.__header, e def recv(self, size=2048): # type: (int) -> Tuple[int, Optional[str]] try: data = self.__sock.recv(size) except socket.timeout: - return (-1, 'timeout') + return (-1, 'timed out') except socket.error as e: if e.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK): return (0, 'retry') @@ -1971,9 +1972,12 @@ def audit(aconf, sshv=None): if sshv is None: sshv = 2 if aconf.ssh2 else 1 err = None - banner, header = s.get_banner(sshv) + banner, header, err = s.get_banner(sshv) if banner is None: - err = '[exception] did not receive banner.' + if err is None: + err = '[exception] did not receive banner.' + else: + err = '[exception] did not receive banner: {0}'.format(err) if err is None: packet_type, payload = s.read_packet(sshv) if packet_type < 0: diff --git a/test/test_errors.py b/test/test_errors.py index ad35a54..e37f60e 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -17,46 +17,91 @@ class TestErrors(object): conf.batch = True return conf + def _audit(self, spy, conf=None, sysexit=True): + if conf is None: + conf = self._conf() + spy.begin() + if sysexit: + with pytest.raises(SystemExit): + self.audit(conf) + else: + self.audit(conf) + lines = spy.flush() + return lines + def test_connection_refused(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.errors['connect'] = socket.error(61, 'Connection refused') - output_spy.begin() - with pytest.raises(SystemExit): - self.audit(self._conf()) - lines = output_spy.flush() + lines = self._audit(output_spy) assert len(lines) == 1 assert 'Connection refused' in lines[-1] + def test_connection_timeout(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.errors['connect'] = socket.timeout('timed out') + lines = self._audit(output_spy) + assert len(lines) == 1 + assert 'timed out' in lines[-1] + + def test_recv_empty(self, output_spy, virtual_socket): + vsocket = virtual_socket + lines = self._audit(output_spy) + assert len(lines) == 1 + assert 'did not receive banner' in lines[-1] + + def test_recv_timeout(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(socket.timeout('timed out')) + lines = self._audit(output_spy) + assert len(lines) == 1 + assert 'did not receive banner' in lines[-1] + assert 'timed out' in lines[-1] + + def test_recv_retry_till_timeout(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable')) + vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable')) + vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable')) + vsocket.rdata.append(socket.timeout('timed out')) + lines = self._audit(output_spy) + assert len(lines) == 1 + assert 'did not receive banner' in lines[-1] + assert 'timed out' in lines[-1] + + def test_recv_retry_till_reset(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable')) + vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable')) + vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable')) + vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) + lines = self._audit(output_spy) + assert len(lines) == 1 + assert 'did not receive banner' in lines[-1] + assert 'reset by peer' in lines[-1] + def test_connection_closed_before_banner(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) - output_spy.begin() - with pytest.raises(SystemExit): - self.audit(self._conf()) - lines = output_spy.flush() + lines = self._audit(output_spy) assert len(lines) == 1 assert 'did not receive banner' in lines[-1] + assert 'reset by peer' in lines[-1] def test_connection_closed_after_header(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.rdata.append(b'header line 1\n') vsocket.rdata.append(b'header line 2\n') vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) - output_spy.begin() - with pytest.raises(SystemExit): - self.audit(self._conf()) - lines = output_spy.flush() + lines = self._audit(output_spy) assert len(lines) == 3 assert 'did not receive banner' in lines[-1] + assert 'reset by peer' in lines[-1] def test_connection_closed_after_banner(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) - output_spy.begin() - with pytest.raises(SystemExit): - self.audit(self._conf()) - lines = output_spy.flush() + lines = self._audit(output_spy) assert len(lines) == 2 assert 'error reading packet' in lines[-1] assert 'reset by peer' in lines[-1] @@ -64,10 +109,7 @@ class TestErrors(object): def test_empty_data_after_banner(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') - output_spy.begin() - with pytest.raises(SystemExit): - self.audit(self._conf()) - lines = output_spy.flush() + lines = self._audit(output_spy) assert len(lines) == 2 assert 'error reading packet' in lines[-1] assert 'empty' in lines[-1] @@ -76,10 +118,7 @@ class TestErrors(object): vsocket = virtual_socket vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') vsocket.rdata.append(b'xxx\n') - output_spy.begin() - with pytest.raises(SystemExit): - self.audit(self._conf()) - lines = output_spy.flush() + lines = self._audit(output_spy) assert len(lines) == 2 assert 'error reading packet' in lines[-1] assert 'xxx' in lines[-1] @@ -87,10 +126,7 @@ class TestErrors(object): def test_non_ascii_banner(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\xc3\xbc\r\n') - output_spy.begin() - with pytest.raises(SystemExit): - self.audit(self._conf()) - lines = output_spy.flush() + lines = self._audit(output_spy) assert len(lines) == 3 assert 'error reading packet' in lines[-1] assert 'ASCII' in lines[-2] @@ -100,10 +136,7 @@ class TestErrors(object): vsocket = virtual_socket vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') vsocket.rdata.append(b'\x81\xff\n') - output_spy.begin() - with pytest.raises(SystemExit): - self.audit(self._conf()) - lines = output_spy.flush() + lines = self._audit(output_spy) assert len(lines) == 2 assert 'error reading packet' in lines[-1] assert '\\x81\\xff' in lines[-1] @@ -112,12 +145,9 @@ class TestErrors(object): vsocket = virtual_socket vsocket.rdata.append(b'SSH-1.3-ssh-audit-test\r\n') vsocket.rdata.append(b'Protocol major versions differ.\n') - output_spy.begin() - with pytest.raises(SystemExit): - conf = self._conf() - conf.ssh1, conf.ssh2 = True, False - self.audit(conf) - lines = output_spy.flush() + conf = self._conf() + conf.ssh1, conf.ssh2 = True, False + lines = self._audit(output_spy, conf) assert len(lines) == 3 assert 'error reading packet' in lines[-1] assert 'major versions differ' in lines[-1]