diff --git a/ssh-audit.py b/ssh-audit.py index a5fa042..7bb1e1b 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -25,6 +25,10 @@ """ from __future__ import print_function import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 +try: + from typing import List, Tuple, Text +except: + pass VERSION = 'v1.6.1.dev' @@ -940,14 +944,15 @@ class SSH(object): return self.__banner, self.__header def recv(self, size=2048): + # type: (int) -> Tuple[int, str] try: data = self.__sock.recv(size) - except socket.timeout as e: - r = 0 if e.strerror == 'timed out' else -1 - return (r, e) + except socket.timeout: + return (-1, 'timeout') except socket.error as e: - r = 0 if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK) else -1 - return (r, e) + if e.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK): + return (0, 'retry') + return (-1, str(e.args[-1])) if len(data) == 0: return (-1, None) pos = self._buf.tell() @@ -977,6 +982,7 @@ class SSH(object): raise SSH.Socket.InsufficientReadException(e) def read_packet(self, sshv=2): + # type: (int) -> Tuple[int, bytes] try: header = WriteBuf() self.ensure_read(4) @@ -1024,7 +1030,7 @@ class SSH(object): header.write(self.read(self.unread_len)) e = header.write_flush().strip() else: - e = ex.args[0] + e = ex.args[0].encode('utf-8') return (-1, e) def send_packet(self): @@ -1651,7 +1657,7 @@ def audit(conf, sshv=None): if err is None: packet_type, payload = s.read_packet(sshv) if packet_type < 0: - payload = str(payload).decode('utf-8') + payload = payload.decode('utf-8') if payload else u'empty' if payload == u'Protocol major versions differ.': if sshv == 2 and conf.ssh1: audit(conf, 1) diff --git a/test/test_errors.py b/test/test_errors.py new file mode 100644 index 0000000..17ef23c --- /dev/null +++ b/test/test_errors.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import pytest, socket + + +class TestErrors(object): + @pytest.fixture(autouse=True) + def init(self, ssh_audit): + self.AuditConf = ssh_audit.AuditConf + self.audit = ssh_audit.audit + + def _conf(self): + conf = self.AuditConf('localhost', 22) + conf.batch = True + return conf + + def test_connection_refused(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.errors['connect'] = socket.error(61, 'Connection refused') + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 1 + assert 'Connection refused' in lines[-1] + + def test_connection_closed_before_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 1 + assert 'did not receive banner' in lines[-1] + + def test_connection_closed_after_header(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'header line 1\n') + vsocket.rdata.append(b'header line 2\n') + vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 3 + assert 'did not receive banner' in lines[-1] + + def test_connection_closed_after_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') + vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 2 + assert 'error reading packet' in lines[-1] + assert 'reset by peer' in lines[-1] + + def test_empty_data_after_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 2 + assert 'error reading packet' in lines[-1] + assert 'empty' in lines[-1] + + def test_wrong_data_after_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') + vsocket.rdata.append(b'xxx\n') + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 2 + assert 'error reading packet' in lines[-1] + assert 'xxx' in lines[-1] + + def test_protocol_mismatch_by_conf(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-1.3-ssh-audit-test\r\n') + vsocket.rdata.append(b'Protocol major versions differ.\n') + output_spy.begin() + with pytest.raises(SystemExit): + conf = self._conf() + conf.ssh1, conf.ssh2 = True, False + self.audit(conf) + lines = output_spy.flush() + assert len(lines) == 3 + assert 'error reading packet' in lines[-1] + assert 'major versions differ' in lines[-1]