From c9d58bb82743d46870d1b7e0697ba7a6be814c1d Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Fri, 14 Oct 2016 09:14:07 +0300 Subject: [PATCH 01/28] Switch to new development version. --- ssh-audit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ssh-audit.py b/ssh-audit.py index 422abc3..a5fa042 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -26,7 +26,7 @@ from __future__ import print_function import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 -VERSION = 'v1.5.1.dev' +VERSION = 'v1.6.1.dev' def usage(err=None): From 63a9c479a7c36628ff04e000d4ed7c67b0638b3a Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Fri, 14 Oct 2016 16:17:38 +0300 Subject: [PATCH 02/28] Test kex payload generation. --- test/test_ssh2.py | 61 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/test/test_ssh2.py b/test/test_ssh2.py index a9cc425..93093b0 100644 --- a/test/test_ssh2.py +++ b/test/test_ssh2.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import pytest +import pytest, os class TestSSH2(object): @@ -46,3 +46,62 @@ class TestSSH2(object): assert kex.server.languages == [u''] assert kex.follows is False assert kex.unused == 0 + + def _get_empty_kex(self, cookie=None): + kex_algs, key_algs = [], [] + enc, mac, compression, languages = [], [], ['none'], [] + cli = self.ssh2.KexParty(enc, mac, compression, languages) + enc, mac, compression, languages = [], [], ['none'], [] + srv = self.ssh2.KexParty(enc, mac, compression, languages) + if cookie is None: + cookie = os.urandom(16) + kex = self.ssh2.Kex(cookie, kex_algs, key_algs, cli, srv, 0) + return kex + + def _get_kex_variat1(self): + cookie = b'\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff' + kex = self._get_empty_kex(cookie) + kex.kex_algorithms.append('curve25519-sha256@libssh.org') + kex.kex_algorithms.append('ecdh-sha2-nistp256') + kex.kex_algorithms.append('ecdh-sha2-nistp384') + kex.kex_algorithms.append('ecdh-sha2-nistp521') + kex.kex_algorithms.append('diffie-hellman-group-exchange-sha256') + kex.kex_algorithms.append('diffie-hellman-group14-sha1') + kex.key_algorithms.append('ssh-rsa') + kex.key_algorithms.append('rsa-sha2-512') + kex.key_algorithms.append('rsa-sha2-256') + kex.key_algorithms.append('ssh-ed25519') + kex.server.encryption.append('chacha20-poly1305@openssh.com') + kex.server.encryption.append('aes128-ctr') + kex.server.encryption.append('aes192-ctr') + kex.server.encryption.append('aes256-ctr') + kex.server.encryption.append('aes128-gcm@openssh.com') + kex.server.encryption.append('aes256-gcm@openssh.com') + kex.server.encryption.append('aes128-cbc') + kex.server.encryption.append('aes192-cbc') + kex.server.encryption.append('aes256-cbc') + kex.server.mac.append('umac-64-etm@openssh.com') + kex.server.mac.append('umac-128-etm@openssh.com') + kex.server.mac.append('hmac-sha2-256-etm@openssh.com') + kex.server.mac.append('hmac-sha2-512-etm@openssh.com') + kex.server.mac.append('hmac-sha1-etm@openssh.com') + kex.server.mac.append('umac-64@openssh.com') + kex.server.mac.append('umac-128@openssh.com') + kex.server.mac.append('hmac-sha2-256') + kex.server.mac.append('hmac-sha2-512') + kex.server.mac.append('hmac-sha1') + kex.server.compression.append('zlib@openssh.com') + for a in kex.server.encryption: + kex.client.encryption.append(a) + for a in kex.server.mac: + kex.client.mac.append(a) + for a in kex.server.compression: + if a == 'none': + continue + kex.client.compression.append(a) + return kex + + def test_key_payload(self): + kex1 = self._get_kex_variat1() + kex2 = self.ssh2.Kex.parse(self._kex_payload()) + assert kex1.payload == kex2.payload From f0651189596665e6ff0ed46c52e16144a0707a1e Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Mon, 17 Oct 2016 20:27:35 +0300 Subject: [PATCH 03/28] Create virtual socket fixture (socket mocking). --- test/conftest.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/test/conftest.py b/test/conftest.py index 33e9bb4..28ab4ef 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import pytest, os, sys, io +import pytest, os, sys, io, socket if sys.version_info[0] == 2: @@ -33,3 +33,80 @@ class _OutputSpy(list): @pytest.fixture(scope='module') def output_spy(): return _OutputSpy() + + +class _VirtualSocket(object): + def __init__(self): + self.sock_address = ('127.0.0.1', 0) + self.peer_address = None + self._connected = False + self.timeout = -1.0 + self.rdata = [] + self.sdata = [] + self.errors = {} + + def _check_err(self, method): + method_error = self.errors.get(method) + if method_error: + raise method_error + + def _connect(self, address): + self.peer_address = address + self._connected = True + self._check_err('connect') + return self + + def settimeout(self, timeout): + self.timeout = timeout + + def gettimeout(self): + return self.timeout + + def getpeername(self): + if self.peer_address is None or not self._connected: + raise socket.error(57, 'Socket is not connected') + return self.peer_address + + def getsockname(self): + return self.sock_address + + def bind(self, address): + self.sock_address = address + + def listen(self, backlog): + pass + + def accept(self): + conn = _VirtualSocket() + conn.sock_address = self.sock_address + conn.peer_address = ('127.0.0.1', 0) + conn._connected = True + return conn, conn.peer_address + + def recv(self, bufsize, flags=0): + if not self._connected: + raise socket.error(54, 'Connection reset by peer') + if not len(self.rdata) > 0: + return b'' + data = self.rdata.pop(0) + if isinstance(data, Exception): + raise data + return data + + def send(self, data): + if self.peer_address is None or not self._connected: + raise socket.error(32, 'Broken pipe') + self._check_err('send') + self.sdata.append(data) + + +@pytest.fixture() +def virtual_socket(monkeypatch): + vsocket = _VirtualSocket() + def _c(address): + return vsocket._connect(address) + def _cc(address, timeout=0, source_address=None): + return vsocket._connect(address) + monkeypatch.setattr(socket, 'create_connection', _cc) + monkeypatch.setattr(socket.socket, 'connect', _c) + return vsocket From 6b76e68d0dab0701a833ee85663161278ef2b5d0 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Mon, 17 Oct 2016 20:31:13 +0300 Subject: [PATCH 04/28] Fix wrongly introduced Python 3 incompatibility. Fixes #14 and #15. Add static type checks via mypy (optional static type checker), Add relevant tests, which could trigger the issue. --- ssh-audit.py | 20 ++++++---- test/test_errors.py | 96 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 7 deletions(-) create mode 100644 test/test_errors.py diff --git a/ssh-audit.py b/ssh-audit.py index a5fa042..7bb1e1b 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -25,6 +25,10 @@ """ from __future__ import print_function import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 +try: + from typing import List, Tuple, Text +except: + pass VERSION = 'v1.6.1.dev' @@ -940,14 +944,15 @@ class SSH(object): return self.__banner, self.__header def recv(self, size=2048): + # type: (int) -> Tuple[int, str] try: data = self.__sock.recv(size) - except socket.timeout as e: - r = 0 if e.strerror == 'timed out' else -1 - return (r, e) + except socket.timeout: + return (-1, 'timeout') except socket.error as e: - r = 0 if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK) else -1 - return (r, e) + if e.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK): + return (0, 'retry') + return (-1, str(e.args[-1])) if len(data) == 0: return (-1, None) pos = self._buf.tell() @@ -977,6 +982,7 @@ class SSH(object): raise SSH.Socket.InsufficientReadException(e) def read_packet(self, sshv=2): + # type: (int) -> Tuple[int, bytes] try: header = WriteBuf() self.ensure_read(4) @@ -1024,7 +1030,7 @@ class SSH(object): header.write(self.read(self.unread_len)) e = header.write_flush().strip() else: - e = ex.args[0] + e = ex.args[0].encode('utf-8') return (-1, e) def send_packet(self): @@ -1651,7 +1657,7 @@ def audit(conf, sshv=None): if err is None: packet_type, payload = s.read_packet(sshv) if packet_type < 0: - payload = str(payload).decode('utf-8') + payload = payload.decode('utf-8') if payload else u'empty' if payload == u'Protocol major versions differ.': if sshv == 2 and conf.ssh1: audit(conf, 1) diff --git a/test/test_errors.py b/test/test_errors.py new file mode 100644 index 0000000..17ef23c --- /dev/null +++ b/test/test_errors.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import pytest, socket + + +class TestErrors(object): + @pytest.fixture(autouse=True) + def init(self, ssh_audit): + self.AuditConf = ssh_audit.AuditConf + self.audit = ssh_audit.audit + + def _conf(self): + conf = self.AuditConf('localhost', 22) + conf.batch = True + return conf + + def test_connection_refused(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.errors['connect'] = socket.error(61, 'Connection refused') + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 1 + assert 'Connection refused' in lines[-1] + + def test_connection_closed_before_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 1 + assert 'did not receive banner' in lines[-1] + + def test_connection_closed_after_header(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'header line 1\n') + vsocket.rdata.append(b'header line 2\n') + vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 3 + assert 'did not receive banner' in lines[-1] + + def test_connection_closed_after_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') + vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 2 + assert 'error reading packet' in lines[-1] + assert 'reset by peer' in lines[-1] + + def test_empty_data_after_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 2 + assert 'error reading packet' in lines[-1] + assert 'empty' in lines[-1] + + def test_wrong_data_after_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') + vsocket.rdata.append(b'xxx\n') + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 2 + assert 'error reading packet' in lines[-1] + assert 'xxx' in lines[-1] + + def test_protocol_mismatch_by_conf(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-1.3-ssh-audit-test\r\n') + vsocket.rdata.append(b'Protocol major versions differ.\n') + output_spy.begin() + with pytest.raises(SystemExit): + conf = self._conf() + conf.ssh1, conf.ssh2 = True, False + self.audit(conf) + lines = output_spy.flush() + assert len(lines) == 3 + assert 'error reading packet' in lines[-1] + assert 'major versions differ' in lines[-1] From 8ca6ec591d43e50ff72e08e40ebc05546da567ad Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Tue, 18 Oct 2016 09:45:03 +0300 Subject: [PATCH 05/28] Handle the case when received data is in wrong encoding (not utf-8). --- ssh-audit.py | 5 ++++- test/test_errors.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/ssh-audit.py b/ssh-audit.py index 7bb1e1b..13f76b9 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -1657,7 +1657,10 @@ def audit(conf, sshv=None): if err is None: packet_type, payload = s.read_packet(sshv) if packet_type < 0: - payload = payload.decode('utf-8') if payload else u'empty' + try: + payload = payload.decode('utf-8') if payload else u'empty' + except UnicodeDecodeError: + payload = u'"{0}"'.format(repr(payload).lstrip('b')[1:-1]) if payload == u'Protocol major versions differ.': if sshv == 2 and conf.ssh1: audit(conf, 1) diff --git a/test/test_errors.py b/test/test_errors.py index 17ef23c..13cc9e2 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -81,6 +81,18 @@ class TestErrors(object): assert 'error reading packet' in lines[-1] assert 'xxx' in lines[-1] + def test_nonutf8_data_after_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') + vsocket.rdata.append(b'\x81\xff\n') + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 2 + assert 'error reading packet' in lines[-1] + assert '\\x81\\xff' in lines[-1] + def test_protocol_mismatch_by_conf(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.rdata.append(b'SSH-1.3-ssh-audit-test\r\n') From fabb4b5bb2cf5231b4d2bfcf6dd51a3069e4d7fc Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Wed, 19 Oct 2016 20:47:13 +0300 Subject: [PATCH 06/28] Add static typing and refactor code to pass all mypy checks. Move Python compatibility types to first lines of code. Add Python (text/byte) compatibility helper functions. Check for SSH banner ASCII validity. --- ssh-audit.py | 372 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 282 insertions(+), 90 deletions(-) diff --git a/ssh-audit.py b/ssh-audit.py index 13f76b9..7a27514 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -25,15 +25,26 @@ """ from __future__ import print_function import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 -try: - from typing import List, Tuple, Text -except: - pass VERSION = 'v1.6.1.dev' +if sys.version_info >= (3,): + StringIO, BytesIO = io.StringIO, io.BytesIO + text_type = str + binary_type = bytes +else: + import StringIO as _StringIO + StringIO = BytesIO = _StringIO.StringIO + text_type = unicode + binary_type = str +try: + from typing import List, Tuple, Optional, Callable, Union, Any +except: + pass + def usage(err=None): + # type: (Optional[str]) -> None out = Output() p = os.path.basename(sys.argv[0]) out.head('# {0} {1}, moo@arthepsy.eu'.format(p, VERSION)) @@ -53,6 +64,7 @@ def usage(err=None): class AuditConf(object): def __init__(self, host=None, port=22): + # type: (Optional[str], int) -> None self.host = host self.port = port self.ssh1 = True @@ -63,6 +75,7 @@ class AuditConf(object): self.minlevel = 'info' def __setattr__(self, name, value): + # type: (str, Union[str, int, bool]) -> None valid = False if name in ['ssh1', 'ssh2', 'batch', 'colors', 'verbose']: valid, value = True, True if value else False @@ -82,6 +95,7 @@ class AuditConf(object): @classmethod def from_cmdline(cls, args, usage_cb): + # type: (List[str], Callable[..., None]) -> AuditConf conf = cls() try: sopts = 'h12bnvl:' @@ -131,6 +145,7 @@ class Output(object): COLORS = {'head': 36, 'good': 32, 'warn': 33, 'fail': 31} def __init__(self): + # type: () -> None self.batch = False self.colors = True self.verbose = False @@ -138,34 +153,40 @@ class Output(object): @property def minlevel(self): + # type: () -> str if self.__minlevel < len(self.LEVELS): return self.LEVELS[self.__minlevel] return 'unknown' @minlevel.setter def minlevel(self, name): + # type: (str) -> None self.__minlevel = self.getlevel(name) def getlevel(self, name): + # type: (str) -> int cname = 'info' if name == 'good' else name if cname not in self.LEVELS: return sys.maxsize return self.LEVELS.index(cname) def sep(self): + # type: () -> None if not self.batch: print() def _colorized(self, color): + # type: (str) -> Callable[[text_type], None] return lambda x: print(u'{0}{1}\033[0m'.format(color, x)) def __getattr__(self, name): + # type: (str) -> Callable[[text_type], None] if name == 'head' and self.batch: return lambda x: None if not self.getlevel(name) >= self.__minlevel: return lambda x: None if self.colors and os.name == 'posix' and name in self.COLORS: - color = u'\033[0;{0}m'.format(self.COLORS[name]) + color = '\033[0;{0}m'.format(self.COLORS[name]) return self._colorized(color) else: return lambda x: print(u'{0}'.format(x)) @@ -173,16 +194,19 @@ class Output(object): class OutputBuffer(list): def __enter__(self): - self.__buf = utils.StringIO() + # type: () -> OutputBuffer + self.__buf = StringIO() self.__stdout = sys.stdout sys.stdout = self.__buf return self def flush(self): + # type: () -> None for line in self: print(line) def __exit__(self, *args): + # type: (*Any) -> None self.extend(self.__buf.getvalue().splitlines()) sys.stdout = self.__stdout @@ -190,6 +214,7 @@ class OutputBuffer(list): class SSH2(object): class KexParty(object): def __init__(self, enc, mac, compression, languages): + # type: (List[text_type], List[text_type], List[text_type], List[text_type]) -> None self.__enc = enc self.__mac = mac self.__compression = compression @@ -197,22 +222,27 @@ class SSH2(object): @property def encryption(self): + # type: () -> List[text_type] return self.__enc @property def mac(self): + # type: () -> List[text_type] return self.__mac @property def compression(self): + # type: () -> List[text_type] return self.__compression @property def languages(self): + # type: () -> List[text_type] return self.__languages class Kex(object): def __init__(self, cookie, kex_algs, key_algs, cli, srv, follows, unused=0): + # type: (binary_type, List[text_type], List[text_type], SSH2.KexParty, SSH2.KexParty, bool, int) -> None self.__cookie = cookie self.__kex_algs = kex_algs self.__key_algs = key_algs @@ -223,35 +253,43 @@ class SSH2(object): @property def cookie(self): + # type: () -> binary_type return self.__cookie @property def kex_algorithms(self): + # type: () -> List[text_type] return self.__kex_algs @property def key_algorithms(self): + # type: () -> List[text_type] return self.__key_algs # client_to_server @property def client(self): + # type: () -> SSH2.KexParty return self.__client # server_to_client @property def server(self): + # type: () -> SSH2.KexParty return self.__server @property def follows(self): + # type: () -> bool return self.__follows @property def unused(self): + # type: () -> int return self.__unused def write(self, wbuf): + # type: (WriteBuf) -> None wbuf.write(self.cookie) wbuf.write_list(self.kex_algorithms) wbuf.write_list(self.key_algorithms) @@ -268,12 +306,14 @@ class SSH2(object): @property def payload(self): + # type: () -> binary_type wbuf = WriteBuf() self.write(wbuf) return wbuf.write_flush() @classmethod def parse(cls, payload): + # type: (binary_type) -> SSH2.Kex buf = ReadBuf(payload) cookie = buf.read(16) kex_algs = buf.read_list() @@ -297,6 +337,7 @@ class SSH2(object): class SSH1(object): class CRC32(object): def __init__(self): + # type: () -> None self._table = [0] * 256 for i in range(256): crc = 0 @@ -308,6 +349,7 @@ class SSH1(object): self._table[i] = crc def calc(self, v): + # type: (binary_type) -> int crc, l = 0, len(v) for i in range(l): n = ord(v[i:i + 1]) @@ -315,12 +357,13 @@ class SSH1(object): crc = (crc >> 8) ^ self._table[n] return crc - _crc32 = None + _crc32 = None # type: Optional[SSH1.CRC32] CIPHERS = ['none', 'idea', 'des', '3des', 'tss', 'rc4', 'blowfish'] AUTHS = [None, 'rhosts', 'rsa', 'password', 'rhosts_rsa', 'tis', 'kerberos'] @classmethod def crc32(cls, v): + # type: (binary_type) -> int if cls._crc32 is None: cls._crc32 = cls.CRC32() return cls._crc32.calc(v) @@ -353,10 +396,11 @@ class SSH1(object): 'tis': [['1.2.2']], 'kerberos': [['1.2.2', '3.6'], [FAIL_OPENSSH37_REMOVE]], } - } + } # type: Dict[str, Dict[str, List[List[str]]]] class PublicKeyMessage(object): def __init__(self, cookie, skey, hkey, pflags, cmask, amask): + # type: (binary_type, Tuple[int, int, int], Tuple[int, int, int], int, int, int) -> None assert len(skey) == 3 assert len(hkey) == 3 self.__cookie = cookie @@ -368,67 +412,81 @@ class SSH1(object): @property def cookie(self): + # type: () -> binary_type return self.__cookie @property def server_key_bits(self): + # type: () -> int return self.__server_key[0] @property def server_key_public_exponent(self): + # type: () -> int return self.__server_key[1] @property def server_key_public_modulus(self): + # type: () -> int return self.__server_key[2] @property def host_key_bits(self): + # type: () -> int return self.__host_key[0] @property def host_key_public_exponent(self): + # type: () -> int return self.__host_key[1] @property def host_key_public_modulus(self): + # type: () -> int return self.__host_key[2] @property def host_key_fingerprint_data(self): + # type: () -> binary_type mod = WriteBuf._create_mpint(self.host_key_public_modulus, False) e = WriteBuf._create_mpint(self.host_key_public_exponent, False) return mod + e @property def protocol_flags(self): + # type: () -> int return self.__protocol_flags @property def supported_ciphers_mask(self): + # type: () -> int return self.__supported_ciphers_mask @property def supported_ciphers(self): + # type: () -> List[text_type] ciphers = [] for i in range(len(SSH1.CIPHERS)): if self.__supported_ciphers_mask & (1 << i) != 0: - ciphers.append(SSH1.CIPHERS[i]) + ciphers.append(utils.to_utext(SSH1.CIPHERS[i])) return ciphers @property def supported_authentications_mask(self): + # type: () -> int return self.__supported_authentications_mask @property def supported_authentications(self): + # type: () -> List[text_type] auths = [] for i in range(1, len(SSH1.AUTHS)): if self.__supported_authentications_mask & (1 << i) != 0: - auths.append(SSH1.AUTHS[i]) + auths.append(utils.to_utext(SSH1.AUTHS[i])) return auths def write(self, wbuf): + # type: (WriteBuf) -> None wbuf.write(self.cookie) wbuf.write_int(self.server_key_bits) wbuf.write_mpint1(self.server_key_public_exponent) @@ -442,12 +500,14 @@ class SSH1(object): @property def payload(self): + # type: () -> binary_type wbuf = WriteBuf() self.write(wbuf) return wbuf.write_flush() @classmethod def parse(cls, payload): + # type: (binary_type) -> SSH1.PublicKeyMessage buf = ReadBuf(payload) cookie = buf.read(8) server_key_bits = buf.read_int() @@ -467,36 +527,45 @@ class SSH1(object): class ReadBuf(object): def __init__(self, data=None): + # type: (Optional[binary_type]) -> None super(ReadBuf, self).__init__() - self._buf = utils.BytesIO(data) if data else utils.BytesIO() + self._buf = BytesIO(data) if data else BytesIO() self._len = len(data) if data else 0 @property def unread_len(self): + # type: () -> int return self._len - self._buf.tell() def read(self, size): + # type: (int) -> binary_type return self._buf.read(size) def read_byte(self): + # type: () -> int return struct.unpack('B', self.read(1))[0] def read_bool(self): + # type: () -> bool return self.read_byte() != 0 def read_int(self): + # type: () -> int return struct.unpack('>I', self.read(4))[0] def read_list(self): + # type: () -> List[text_type] list_size = self.read_int() return self.read(list_size).decode().split(',') def read_string(self): + # type: () -> binary_type n = self.read_int() return self.read(n) @classmethod def _parse_mpint(cls, v, pad, sf): + # type: (binary_type, binary_type, str) -> int r = 0 if len(v) % 4: v = pad * (4 - (len(v) % 4)) + v @@ -505,12 +574,14 @@ class ReadBuf(object): return r def read_mpint1(self): + # type: () -> int # NOTE: Data Type Enc @ http://www.snailbook.com/docs/protocol-1.5.txt bits = struct.unpack('>H', self.read(2))[0] n = (bits + 7) // 8 return self._parse_mpint(self.read(n), b'\x00', '>I') def read_mpint2(self): + # type: () -> int # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt v = self.read_string() if len(v) == 0: @@ -519,38 +590,47 @@ class ReadBuf(object): return self._parse_mpint(v, pad, sf) def read_line(self): + # type: () -> text_type return self._buf.readline().rstrip().decode('utf-8') class WriteBuf(object): def __init__(self, data=None): + # type: (Optional[binary_type]) -> None super(WriteBuf, self).__init__() - self._wbuf = io.BytesIO(data) if data else io.BytesIO() + self._wbuf = BytesIO(data) if data else BytesIO() def write(self, data): + # type: (binary_type) -> WriteBuf self._wbuf.write(data) return self def write_byte(self, v): + # type: (int) -> WriteBuf return self.write(struct.pack('B', v)) def write_bool(self, v): + # type: (bool) -> WriteBuf return self.write_byte(1 if v else 0) def write_int(self, v): + # type: (int) -> WriteBuf return self.write(struct.pack('>I', v)) def write_string(self, v): + # type: (Union[binary_type, text_type]) -> WriteBuf if not isinstance(v, bytes): v = bytes(bytearray(v, 'utf-8')) self.write_int(len(v)) return self.write(v) def write_list(self, v): + # type: (List[text_type]) -> WriteBuf return self.write_string(u','.join(v)) @classmethod def _bitlength(cls, n): + # type: (int) -> int try: return n.bit_length() except AttributeError: @@ -558,11 +638,12 @@ class WriteBuf(object): @classmethod def _create_mpint(cls, n, signed=True, bits=None): + # type: (int, bool, Optional[int]) -> binary_type if bits is None: bits = cls._bitlength(n) length = bits // 8 + (1 if n != 0 else 0) ql = (length + 7) // 8 - fmt, v2 = '>{0}Q'.format(ql), [b'\x00'] * ql + fmt, v2 = '>{0}Q'.format(ql), [0] * ql for i in range(ql): v2[ql - i - 1] = (n & 0xffffffffffffffff) n >>= 64 @@ -574,6 +655,7 @@ class WriteBuf(object): return data def write_mpint1(self, n): + # type: (int) -> WriteBuf # NOTE: Data Type Enc @ http://www.snailbook.com/docs/protocol-1.5.txt bits = self._bitlength(n) data = self._create_mpint(n, False, bits) @@ -581,17 +663,20 @@ class WriteBuf(object): return self.write(data) def write_mpint2(self, n): + # type: (int) -> WriteBuf # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt data = self._create_mpint(n) return self.write_string(data) def write_line(self, v): + # type: (Union[binary_type, str]) -> WriteBuf if not isinstance(v, bytes): v = bytes(bytearray(v, 'utf-8')) v += b'\r\n' return self.write(v) def write_flush(self): + # type: () -> binary_type payload = self._wbuf.getvalue() self._wbuf.truncate(0) self._wbuf.seek(0) @@ -613,6 +698,7 @@ class SSH(object): class Software(object): def __init__(self, vendor, product, version, patch, os): + # type: (Optional[str], str, str, Optional[str], Optional[str]) -> None self.__vendor = vendor self.__product = product self.__version = version @@ -621,28 +707,34 @@ class SSH(object): @property def vendor(self): + # type: () -> Optional[str] return self.__vendor @property def product(self): + # type: () -> str return self.__product @property def version(self): + # type: () -> str return self.__version @property def patch(self): + # type: () -> Optional[str] return self.__patch @property def os(self): + # type: () -> Optional[str] return self.__os def compare_version(self, other): + # type: (Union[None, SSH.Software, text_type]) -> int if other is None: return 1 - if isinstance(other, self.__class__): + if isinstance(other, SSH.Software): other = '{0}{1}'.format(other.version, other.patch or '') else: other = str(other) @@ -676,6 +768,7 @@ class SSH(object): return 0 def between_versions(self, vfrom, vtill): + # type: (str, str) -> bool if vfrom and self.compare_version(vfrom) < 0: return False if vtill and self.compare_version(vtill) > 0: @@ -683,6 +776,7 @@ class SSH(object): return True def display(self, full=True): + # type: (bool) -> str out = '{0} '.format(self.vendor) if self.vendor else '' out += self.product if self.version: @@ -701,9 +795,11 @@ class SSH(object): return out def __str__(self): + # type: () -> str return self.display() def __repr__(self): + # type: () -> str out = 'vendor={0}'.format(self.vendor) if self.vendor else '' if self.product: if self.vendor: @@ -719,10 +815,12 @@ class SSH(object): @staticmethod def _fix_patch(patch): + # type: (str) -> Optional[str] return re.sub(r'^[-_\.]+', '', patch) or None @staticmethod def _fix_date(d): + # type: (str) -> Optional[str] if d is not None and len(d) == 8: return '{0}-{1}-{2}'.format(d[:4], d[4:6], d[6:8]) else: @@ -730,6 +828,7 @@ class SSH(object): @classmethod def _extract_os(cls, c): + # type: (Optional[str]) -> str if c is None: return None mx = re.match(r'^NetBSD(?:_Secure_Shell)?(?:[\s-]+(\d{8})(.*))?$', c) @@ -756,6 +855,7 @@ class SSH(object): @classmethod def parse(cls, banner): + # type: (SSH.Banner) -> SSH.Software software = str(banner.software) mx = re.match(r'^dropbear_([\d\.]+\d+)(.*)', software) if mx: @@ -796,24 +896,35 @@ class SSH(object): RX_PROTOCOL = re.compile(re.sub(r'\\d(\+?)', '(\\d\g<1>)', _RXP)) RX_BANNER = re.compile(r'^({0}(?:(?:-{0})*)){1}$'.format(_RXP, _RXR)) - def __init__(self, protocol, software, comments): + def __init__(self, protocol, software, comments, valid_ascii): + # type: (Tuple[int, int], str, str, bool) -> None self.__protocol = protocol self.__software = software self.__comments = comments + self.__valid_ascii = valid_ascii @property def protocol(self): + # type: () -> Tuple[int, int] return self.__protocol @property def software(self): + # type: () -> str return self.__software @property def comments(self): + # type: () -> str return self.__comments + @property + def valid_ascii(self): + # type: () -> bool + return self.__valid_ascii + def __str__(self): + # type: () -> str out = 'SSH-{0}.{1}'.format(self.protocol[0], self.protocol[1]) if self.software is not None: out += '-{0}'.format(self.software) @@ -822,6 +933,7 @@ class SSH(object): return out def __repr__(self): + # type: () -> str p = '{0}.{1}'.format(self.protocol[0], self.protocol[1]) out = 'protocol={0}'.format(p) if self.software: @@ -832,7 +944,10 @@ class SSH(object): @classmethod def parse(cls, banner): - mx = cls.RX_BANNER.match(banner) + # type: (text_type) -> SSH.Banner + valid_ascii = utils.is_ascii(banner) + ascii_banner = utils.to_ascii(banner) + mx = cls.RX_BANNER.match(ascii_banner) if mx is None: return None protocol = min(re.findall(cls.RX_PROTOCOL, mx.group(1))) @@ -843,23 +958,26 @@ class SSH(object): comments = (mx.group(4) or '').strip() or None if comments is not None: comments = re.sub('\s+', ' ', comments) - return cls(protocol, software, comments) + return cls(protocol, software, comments, valid_ascii) class Fingerprint(object): def __init__(self, fpd): + # type: (binary_type) -> None self.__fpd = fpd @property def md5(self): + # type: () -> text_type h = hashlib.md5(self.__fpd).hexdigest() - h = u':'.join(h[i:i + 2] for i in range(0, len(h), 2)) - return u'MD5:{0}'.format(h) + r = u':'.join(h[i:i + 2] for i in range(0, len(h), 2)) + return u'MD5:{0}'.format(r) @property def sha256(self): + # type: () -> text_type h = base64.b64encode(hashlib.sha256(self.__fpd).digest()) - h = h.decode().rstrip('=') - return u'SHA256:{0}'.format(h) + r = h.decode('ascii').rstrip('=') + return u'SHA256:{0}'.format(r) class Security(object): CVE = { @@ -884,7 +1002,7 @@ class SSH(object): ['0.4.7', '0.5.2', 1, 'CVE-2012-4561', 5.0, 'cause DoS via unspecified vectors (invalid pointer)'], ['0.4.7', '0.5.2', 1, 'CVE-2012-4560', 7.5, 'cause DoS or execute arbitrary code (buffer overflow)'], ['0.4.7', '0.5.2', 1, 'CVE-2012-4559', 6.8, 'cause DoS or execute arbitrary code (double free)']] - } + } # type: Dict[str, List[List[Any]]] TXT = { 'Dropbear SSH': [ ['0.28', '0.34', 1, 'remote root exploit', 'remote format string buffer overflow exploit (exploit-db#387)']], @@ -892,7 +1010,7 @@ class SSH(object): ['0.3.3', '0.3.3', 1, 'null pointer check', 'missing null pointer check in "crypt_set_algorithms_server"'], ['0.3.3', '0.3.3', 1, 'integer overflow', 'integer overflow in "buffer_get_data"'], ['0.3.3', '0.3.3', 3, 'heap overflow', 'heap overflow in "packet_decrypt"']] - } + } # type: Dict[str, List[List[Any]]] class Socket(ReadBuf, WriteBuf): class InsufficientReadException(Exception): @@ -901,10 +1019,11 @@ class SSH(object): SM_BANNER_SENT = 1 def __init__(self, host, port, cto=3.0, rto=5.0): + # type: (str, int, float, float) -> None self.__block_size = 8 self.__state = 0 - self.__header = [] - self.__banner = None + self.__header = [] # type: List[text_type] + self.__banner = None # type: Optional[SSH.Banner] super(SSH.Socket, self).__init__() try: self.__sock = socket.create_connection((host, port), cto) @@ -914,9 +1033,11 @@ class SSH(object): sys.exit(1) def __enter__(self): + # type: () -> SSH.Socket return self def get_banner(self, sshv=2): + # type: (int) -> Tuple[Optional[SSH.Banner], List[text_type]] banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0') rto = self.__sock.gettimeout() self.__sock.settimeout(0.7) @@ -944,7 +1065,7 @@ class SSH(object): return self.__banner, self.__header def recv(self, size=2048): - # type: (int) -> Tuple[int, str] + # type: (int) -> Tuple[int, Optional[str]] try: data = self.__sock.recv(size) except socket.timeout: @@ -963,26 +1084,29 @@ class SSH(object): return (len(data), None) def send(self, data): + # type: (binary_type) -> Tuple[int, Optional[str]] try: self.__sock.send(data) return (0, None) except socket.error as e: - return (-1, e) + return (-1, str(e.args[-1])) self.__sock.send(data) def send_banner(self, banner): + # type: (str) -> None self.send(banner.encode() + b'\r\n') if self.__state < self.SM_BANNER_SENT: self.__state = self.SM_BANNER_SENT def ensure_read(self, size): + # type: (int) -> None while self.unread_len < size: s, e = self.recv() if s < 0: raise SSH.Socket.InsufficientReadException(e) def read_packet(self, sshv=2): - # type: (int) -> Tuple[int, bytes] + # type: (int) -> Tuple[int, binary_type] try: header = WriteBuf() self.ensure_read(4) @@ -1034,6 +1158,7 @@ class SSH(object): return (-1, e) def send_packet(self): + # type: () -> Tuple[int, Optional[str]] payload = self.write_flush() padding = -(len(payload) + 5) % 8 if padding < 4: @@ -1044,12 +1169,15 @@ class SSH(object): return self.send(data) def __del__(self): + # type: () -> None self.__cleanup() - def __exit__(self, ex_type, ex_value, tb): + def __exit__(self, *args): + # type: (*Any) -> None self.__cleanup() def __cleanup(self): + # type: () -> None try: self.__sock.shutdown(socket.SHUT_RDWR) self.__sock.close() @@ -1059,13 +1187,15 @@ class SSH(object): class KexDH(object): def __init__(self, alg, g, p): + # type: (str, int, int) -> None self.__alg = alg self.__g = g self.__p = p self.__q = (self.__p - 1) // 2 - self.__x = None + self.__x = None # type: Optional[int] def send_init(self, s): + # type: (SSH.Socket) -> None r = random.SystemRandom() self.__x = r.randrange(2, self.__q) self.__e = pow(self.__g, self.__x, self.__p) @@ -1076,6 +1206,7 @@ class KexDH(object): class KexGroup1(KexDH): def __init__(self): + # type: () -> None # rfc2409: second oakley group p = int('ffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67' 'cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6d' @@ -1087,6 +1218,7 @@ class KexGroup1(KexDH): class KexGroup14(KexDH): def __init__(self): + # type: () -> None # rfc3526: 2048-bit modp group p = int('ffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67' 'cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6d' @@ -1208,10 +1340,11 @@ class KexDB(object): 'umac-64-etm@openssh.com': [['6.2'], [], [WARN_TAG_SIZE]], 'umac-128-etm@openssh.com': [['6.2']], } - } + } # type: Dict[str, Dict[str, List[List[str]]]] def get_ssh_version(version_desc): + # type: (str) -> Tuple[str, str] if version_desc.startswith('d'): return (SSH.Product.DropbearSSH, version_desc[1:]) elif version_desc.startswith('l1'): @@ -1220,8 +1353,8 @@ def get_ssh_version(version_desc): return (SSH.Product.OpenSSH, version_desc) -def get_alg_timeframe(alg_desc, for_server=True, result={}): - versions = alg_desc[0] +def get_alg_timeframe(versions, for_server=True, result={}): + # type: (List[str], bool, Dict[str, List[Optional[str]]]) -> Dict[str, List[Optional[str]]] vlen = len(versions) for i in range(3): if i > vlen - 1: @@ -1256,23 +1389,25 @@ def get_alg_timeframe(alg_desc, for_server=True, result={}): def get_ssh_timeframe(alg_pairs, for_server=True): - timeframe = {} + # type: (List[Tuple[int, Dict[str, Dict[str, List[List[str]]]], List[Tuple[str, List[text_type]]]]], bool) -> Dict[str, List[Optional[str]]] + timeframe = {} # type: Dict[str, List[Optional[str]]] for alg_pair in alg_pairs: - sshv, alg_db = alg_pair[0] - alg_sets = alg_pair[1:] - for alg_set in alg_sets: + sshv, alg_db = alg_pair[0], alg_pair[1] + for alg_set in alg_pair[2]: alg_type, alg_list = alg_set for alg_name in alg_list: - alg_desc = alg_db[alg_type].get(alg_name) + alg_name_native = utils.to_ntext(alg_name) + alg_desc = alg_db[alg_type].get(alg_name_native) if alg_desc is None: continue - timeframe = get_alg_timeframe(alg_desc, for_server, timeframe) + versions = alg_desc[0] + timeframe = get_alg_timeframe(versions, for_server, timeframe) return timeframe -def get_alg_since_text(alg_desc): +def get_alg_since_text(versions): + # type: (List[str]) -> text_type tv = [] - versions = alg_desc[0] if len(versions) == 0 or versions[0] is None: return None for v in versions[0].split(','): @@ -1290,22 +1425,24 @@ def get_alg_since_text(alg_desc): def get_alg_pairs(kex, pkm): + # type: (Optional[SSH2.Kex], Optional[SSH1.PublicKeyMessage]) -> List[Tuple[int, Dict[str, Dict[str, List[List[str]]]], List[Tuple[str, List[text_type]]]]] alg_pairs = [] if pkm is not None: - alg_pairs.append(((1, SSH1.KexDB.ALGORITHMS), - ('key', ['ssh-rsa1']), - ('enc', pkm.supported_ciphers), - ('aut', pkm.supported_authentications))) + alg_pairs.append((1, SSH1.KexDB.ALGORITHMS, + [('key', [u'ssh-rsa1']), + ('enc', pkm.supported_ciphers), + ('aut', pkm.supported_authentications)])) if kex is not None: - alg_pairs.append(((2, KexDB.ALGORITHMS), - ('kex', kex.kex_algorithms), - ('key', kex.key_algorithms), - ('enc', kex.server.encryption), - ('mac', kex.server.mac))) + alg_pairs.append((2, KexDB.ALGORITHMS, + [('kex', kex.kex_algorithms), + ('key', kex.key_algorithms), + ('enc', kex.server.encryption), + ('mac', kex.server.mac)])) return alg_pairs def get_alg_recommendations(software, kex, pkm, for_server=True): + # type: (SSH.Software, SSH2.Kex, SSH1.PublicKeyMessage, bool) -> Tuple[SSH.Software, Dict[int, Dict[str, Dict[str, Dict[str, int]]]]] alg_pairs = get_alg_pairs(kex, pkm) vproducts = [SSH.Product.OpenSSH, SSH.Product.DropbearSSH, @@ -1322,18 +1459,17 @@ def get_alg_recommendations(software, kex, pkm, for_server=True): if version is not None: software = SSH.Software(None, product, version, None, None) break - rec = {'.software': software} + rec = {} # type: Dict[int, Dict[str, Dict[str, Dict[str, int]]]] if software is None: - return rec + return software, rec for alg_pair in alg_pairs: - sshv, alg_db = alg_pair[0] - alg_sets = alg_pair[1:] + sshv, alg_db = alg_pair[0], alg_pair[1] rec[sshv] = {} - for alg_set in alg_sets: + for alg_set in alg_pair[2]: alg_type, alg_list = alg_set if alg_type == 'aut': continue - rec[sshv][alg_type] = {'add': [], 'del': {}} + rec[sshv][alg_type] = {'add': {}, 'del': {}} for n, alg_desc in alg_db[alg_type].items(): if alg_type == 'key' and '-cert-' in n: continue @@ -1367,7 +1503,7 @@ def get_alg_recommendations(software, kex, pkm, for_server=True): if n not in alg_list: if faults > 0: continue - rec[sshv][alg_type]['add'].append(n) + rec[sshv][alg_type]['add'][n] = 0 else: if faults == 0: continue @@ -1379,10 +1515,11 @@ def get_alg_recommendations(software, kex, pkm, for_server=True): del_count = len(rec[sshv][alg_type]['del']) new_alg_count = len(alg_list) + add_count - del_count if new_alg_count < 1 and del_count > 0: - mf, new_del = min(rec[sshv][alg_type]['del'].values()), {} - for k, v in rec[sshv][alg_type]['del'].items(): - if v != mf: - new_del[k] = v + mf = min(rec[sshv][alg_type]['del'].values()) + new_del = {} + for k, cf in rec[sshv][alg_type]['del'].items(): + if cf != mf: + new_del[k] = cf if del_count != len(new_del): rec[sshv][alg_type]['del'] = new_del new_alg_count += del_count - len(new_del) @@ -1397,10 +1534,11 @@ def get_alg_recommendations(software, kex, pkm, for_server=True): del rec[sshv][alg_type] if len(rec[sshv]) == 0: del rec[sshv] - return rec + return software, rec def output_algorithms(title, alg_db, alg_type, algorithms, maxlen=0): + # type: (str, Dict[str, Dict[str, List[List[str]]]], str, List[text_type], int) -> None with OutputBuffer() as obuf: for algorithm in algorithms: output_algorithm(alg_db, alg_type, algorithm, maxlen) @@ -1411,6 +1549,7 @@ def output_algorithms(title, alg_db, alg_type, algorithms, maxlen=0): def output_algorithm(alg_db, alg_type, alg_name, alg_max_len=0): + # type: (Dict[str, Dict[str, List[List[str]]]], str, text_type, int) -> None prefix = '(' + alg_type + ') ' if alg_max_len == 0: alg_max_len = len(alg_name) @@ -1418,12 +1557,14 @@ def output_algorithm(alg_db, alg_type, alg_name, alg_max_len=0): texts = [] if len(alg_name.strip()) == 0: return - if alg_name in alg_db[alg_type]: - alg_desc = alg_db[alg_type][alg_name] + alg_name_native = utils.to_ntext(alg_name) + if alg_name_native in alg_db[alg_type]: + alg_desc = alg_db[alg_type][alg_name_native] ldesc = len(alg_desc) for idx, level in enumerate(['fail', 'warn', 'info']): if level == 'info': - since_text = get_alg_since_text(alg_desc) + versions = alg_desc[0] + since_text = get_alg_since_text(versions) if since_text: texts.append((level, since_text)) idx = idx + 1 @@ -1451,6 +1592,7 @@ def output_algorithm(alg_db, alg_type, alg_name, alg_max_len=0): def output_compatibility(kex, pkm, for_server=True): + # type: (Optional[SSH2.Kex], Optional[SSH1.PublicKeyMessage], bool) -> None alg_pairs = get_alg_pairs(kex, pkm) ssh_timeframe = get_ssh_timeframe(alg_pairs, for_server) vp = 1 if for_server else 2 @@ -1474,21 +1616,22 @@ def output_compatibility(kex, pkm, for_server=True): def output_security_sub(sub, software, padlen): + # type: (str, SSH.Software, int) -> None secdb = SSH.Security.CVE if sub == 'cve' else SSH.Security.TXT if software is None or software.product not in secdb: return for line in secdb[software.product]: - vfrom, vtill = line[0:2] + vfrom, vtill = line[0:2] # type: str, str if not software.between_versions(vfrom, vtill): continue - target, name = line[2:4] + target, name = line[2:3] # type: int, str is_server, is_client = target & 1 == 1, target & 2 == 2 is_local = target & 4 == 4 if not is_server: continue p = '' if out.batch else ' ' * (padlen - len(name)) if sub == 'cve': - cvss, descr = line[4:6] + cvss, descr = line[4:6] # type: float, str out.fail('(cve) {0}{1} -- ({2}) {3}'.format(name, p, cvss, descr)) else: descr = line[4] @@ -1496,6 +1639,7 @@ def output_security_sub(sub, software, padlen): def output_security(banner, padlen): + # type: (SSH.Banner, int) -> None with OutputBuffer() as obuf: if banner: software = SSH.Software.parse(banner) @@ -1508,6 +1652,7 @@ def output_security(banner, padlen): def output_fingerprint(kex, pkm, sha256=True, padlen=0): + # type: (Optional[SSH2.Kex], Optional[SSH1.PublicKeyMessage], bool, int) -> None with OutputBuffer() as obuf: fps = [] if pkm is not None: @@ -1517,9 +1662,9 @@ def output_fingerprint(kex, pkm, sha256=True, padlen=0): fps.append((name, fp, bits)) for fpp in fps: name, fp, bits = fpp - fp = fp.sha256 if sha256 else fp.md5 + fpo = fp.sha256 if sha256 else fp.md5 p = '' if out.batch else ' ' * (padlen - len(name)) - out.good('(fin) {0}{1} -- {2} {3}'.format(name, p, bits, fp)) + out.good('(fin) {0}{1} -- {2} {3}'.format(name, p, bits, fpo)) if len(obuf) > 0: out.head('# fingerprints') obuf.flush() @@ -1527,10 +1672,10 @@ def output_fingerprint(kex, pkm, sha256=True, padlen=0): def output_recommendations(software, kex, pkm, padlen=0): + # type: (SSH.Software, SSH2.Kex, SSH1.PublicKeyMessage, int) -> None for_server = True with OutputBuffer() as obuf: - alg_rec = get_alg_recommendations(software, kex, pkm, for_server) - software = alg_rec['.software'] + software, alg_rec = get_alg_recommendations(software, kex, pkm, for_server) for sshv in range(2, 0, -1): if sshv not in alg_rec: continue @@ -1559,12 +1704,16 @@ def output_recommendations(software, kex, pkm, padlen=0): def output(banner, header, kex=None, pkm=None): + # type: (Optional[SSH.Banner], List[text_type], Optional[SSH2.Kex], Optional[SSH1.PublicKeyMessage]) -> None sshv = 1 if pkm else 2 with OutputBuffer() as obuf: if len(header) > 0: out.info('(gen) header: ' + '\n'.join(header)) if banner is not None: out.good('(gen) banner: {0}'.format(banner)) + if not banner.valid_ascii: + # NOTE: RFC 4253, Section 4.2 + out.warn('(gen) banner contains non-printable ASCII') if sshv == 1 or banner.protocol[0] == 1: out.fail('(gen) protocol SSH1 enabled') software = SSH.Software.parse(banner) @@ -1622,20 +1771,61 @@ def output(banner, header, kex=None, pkm=None): class Utils(object): - PY2 = sys.version_info[0] == 2 + @classmethod + def _type_err(cls, v, target): + # type: (Any, text_type) -> TypeError + return TypeError('cannot convert {0} to {1}'.format(type(v), target)) @classmethod - def wrap(cls): - o = cls() - if cls.PY2: - import StringIO - o.StringIO = o.BytesIO = StringIO.StringIO - else: - o.StringIO, o.BytesIO = io.StringIO, io.BytesIO - return o + def to_bytes(cls, v, enc='utf-8'): + # type: (Union[binary_type, text_type], str) -> binary_type + if isinstance(v, binary_type): + return v + elif isinstance(v, text_type): + return v.encode(enc) + raise cls._type_err(v, 'bytes') + + @classmethod + def to_utext(cls, v, enc='utf-8'): + # type: (Union[text_type, binary_type], str) -> text_type + if isinstance(v, text_type): + return v + elif isinstance(v, binary_type): + return v.decode(enc) + raise cls._type_err(v, 'unicode text') + + @classmethod + def to_ntext(cls, v, enc='utf-8'): + # type: (Union[text_type, binary_type], str) -> str + if isinstance(v, str): + return v + elif isinstance(v, text_type): + return v.encode(enc) + elif isinstance(v, binary_type): + return v.decode(enc) + raise cls._type_err(v, 'native text') + + @classmethod + def is_ascii(cls, v, enc='utf-8'): + # type: (Union[text_type, str], str) -> bool + try: + if isinstance(v, (text_type, str)): + v.encode('ascii') + return True + except UnicodeEncodeError: + pass + return False + + @classmethod + def to_ascii(cls, v, errors='replace'): + # type: (Union[text_type, str], str) -> str + if isinstance(v, (text_type, str)): + return cls.to_ntext(v.encode('ascii', errors)) + raise cls._type_err(v, 'ascii') @staticmethod def parse_int(v): + # type: (Any) -> int try: return int(v) except: @@ -1643,6 +1833,7 @@ class Utils(object): def audit(conf, sshv=None): + # type: (AuditConf, Optional[int]) -> None out.batch = conf.batch out.colors = conf.colors out.verbose = conf.verbose @@ -1658,23 +1849,24 @@ def audit(conf, sshv=None): packet_type, payload = s.read_packet(sshv) if packet_type < 0: try: - payload = payload.decode('utf-8') if payload else u'empty' + payload_txt = payload.decode('utf-8') if payload else u'empty' except UnicodeDecodeError: - payload = u'"{0}"'.format(repr(payload).lstrip('b')[1:-1]) - if payload == u'Protocol major versions differ.': + payload_txt = u'"{0}"'.format(repr(payload).lstrip('b')[1:-1]) + if payload_txt == u'Protocol major versions differ.': if sshv == 2 and conf.ssh1: audit(conf, 1) return - err = '[exception] error reading packet ({0})'.format(payload) + err = '[exception] error reading packet ({0})'.format(payload_txt) else: + err_pair = None if sshv == 1 and packet_type != SSH.Protocol.SMSG_PUBLIC_KEY: - err = ('SMSG_PUBLIC_KEY', SSH.Protocol.SMSG_PUBLIC_KEY) + err_pair = ('SMSG_PUBLIC_KEY', SSH.Protocol.SMSG_PUBLIC_KEY) elif sshv == 2 and packet_type != SSH.Protocol.MSG_KEXINIT: - err = ('MSG_KEXINIT', SSH.Protocol.MSG_KEXINIT) - if err is not None: + err_pair = ('MSG_KEXINIT', SSH.Protocol.MSG_KEXINIT) + if err_pair is not None: fmt = '[exception] did not receive {0} ({1}), ' + \ 'instead received unknown message ({2})' - err = fmt.format(err[0], err[1], packet_type) + err = fmt.format(err_pair[0], err_pair[1], packet_type) if err: output(banner, header) out.fail(err) @@ -1687,7 +1879,7 @@ def audit(conf, sshv=None): output(banner, header, kex=kex) -utils = Utils.wrap() +utils = Utils() out = Output() if __name__ == '__main__': conf = AuditConf.from_cmdline(sys.argv[1:], usage) From ca6cfb81a2936e1eb3a5bb2ca381c5d446322a40 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Wed, 19 Oct 2016 20:51:57 +0300 Subject: [PATCH 07/28] Import mypy configuration script and run scripts (for Python 2.7 and 3.5). Import pytest coverage script. --- .gitignore | 1 + test/coverage.sh | 10 ++++++++++ test/mypy-py2.sh | 10 ++++++++++ test/mypy-py3.sh | 10 ++++++++++ test/mypy.ini | 9 +++++++++ 5 files changed, 40 insertions(+) create mode 100755 test/coverage.sh create mode 100755 test/mypy-py2.sh create mode 100755 test/mypy-py3.sh create mode 100644 test/mypy.ini diff --git a/.gitignore b/.gitignore index 2f836aa..84f8554 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ *~ *.pyc +./html/ diff --git a/test/coverage.sh b/test/coverage.sh new file mode 100755 index 0000000..28f2010 --- /dev/null +++ b/test/coverage.sh @@ -0,0 +1,10 @@ +#!/bin/sh +_cdir=$(cd -- "$(dirname "$0")" && pwd) +type py.test > /dev/null 2>&1 +if [ $? -ne 0 ]; then + echo "err: py.test (Python testing framework) not found." + exit 1 +fi +cd -- "${_cdir}/.." +mkdir -p html +py.test -v --cov-report=html:html/coverage --cov=ssh-audit test diff --git a/test/mypy-py2.sh b/test/mypy-py2.sh new file mode 100755 index 0000000..f8e9244 --- /dev/null +++ b/test/mypy-py2.sh @@ -0,0 +1,10 @@ +#!/bin/sh +_cdir=$(cd -- "$(dirname "$0")" && pwd) +type mypy > /dev/null 2>&1 +if [ $? -ne 0 ]; then + echo "err: mypy (Optional Static Typing for Python) not found." + exit 1 +fi +_htmldir="${_cdir}/../html/mypy-py2" +mkdir -p "${_htmldir}" +mypy --python-version 2.7 --config-file "${_cdir}/mypy.ini" --html-report "${_htmldir}" "${_cdir}/../ssh-audit.py" diff --git a/test/mypy-py3.sh b/test/mypy-py3.sh new file mode 100755 index 0000000..0d2dfe5 --- /dev/null +++ b/test/mypy-py3.sh @@ -0,0 +1,10 @@ +#!/bin/sh +_cdir=$(cd -- "$(dirname "$0")" && pwd) +type mypy > /dev/null 2>&1 +if [ $? -ne 0 ]; then + echo "err: mypy (Optional Static Typing for Python) not found." + exit 1 +fi +_htmldir="${_cdir}/../html/mypy-py3" +mkdir -p "${_htmldir}" +mypy --python-version 3.5 --config-file "${_cdir}/mypy.ini" --html-report "${_htmldir}" "${_cdir}/../ssh-audit.py" diff --git a/test/mypy.ini b/test/mypy.ini new file mode 100644 index 0000000..9c0a3e0 --- /dev/null +++ b/test/mypy.ini @@ -0,0 +1,9 @@ +[mypy] +silent_imports = True +disallow_untyped_calls = True +disallow_untyped_defs = True +check_untyped_defs = True +disallow-subclassing-any = True +warn-incomplete-stub = True +warn-redundant-casts = True + From 42be99a2c7f7c31d693d414930118bd190c07497 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Wed, 19 Oct 2016 20:53:47 +0300 Subject: [PATCH 08/28] Test for non-ASCII banner. --- test/test_errors.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/test_errors.py b/test/test_errors.py index 13cc9e2..abdbebe 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -11,6 +11,7 @@ class TestErrors(object): def _conf(self): conf = self.AuditConf('localhost', 22) + conf.colors = False conf.batch = True return conf @@ -81,6 +82,18 @@ class TestErrors(object): assert 'error reading packet' in lines[-1] assert 'xxx' in lines[-1] + def test_non_ascii_banner(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\xc3\xbc\r\n') + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 3 + assert 'error reading packet' in lines[-1] + assert 'ASCII' in lines[-2] + assert lines[-3].endswith('SSH-2.0-ssh-audit-test?') + def test_nonutf8_data_after_banner(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') From 67087fb920de351ab15b598bd7164c4017e5fa07 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Thu, 20 Oct 2016 16:27:11 +0300 Subject: [PATCH 09/28] Fix pylint reported anomalous-backslash-in-string. --- ssh-audit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ssh-audit.py b/ssh-audit.py index 7a27514..ec8d2df 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -784,7 +784,7 @@ class SSH(object): if full: patch = self.patch or '' if self.product == SSH.Product.OpenSSH: - mx = re.match('^(p\d)(.*)$', patch) + mx = re.match(r'^(p\d)(.*)$', patch) if mx is not None: out += mx.group(1) patch = mx.group(2).strip() @@ -843,7 +843,7 @@ class SSH(object): return 'FreeBSD' if d is None else 'FreeBSD ({0})'.format(d) w = ['RemotelyAnywhere', 'DesktopAuthority', 'RemoteSupportManager'] for win_soft in w: - mx = re.match(r'^in ' + win_soft + ' ([\d\.]+\d)$', c) + mx = re.match(r'^in ' + win_soft + r' ([\d\.]+\d)$', c) if mx: ver = mx.group(1) return 'Microsoft Windows ({0} {1})'.format(win_soft, ver) @@ -893,7 +893,7 @@ class SSH(object): class Banner(object): _RXP, _RXR = r'SSH-\d\.\s*?\d+', r'(-\s*([^\s]*)(?:\s+(.*))?)?' - RX_PROTOCOL = re.compile(re.sub(r'\\d(\+?)', '(\\d\g<1>)', _RXP)) + RX_PROTOCOL = re.compile(re.sub(r'\\d(\+?)', r'(\\d\g<1>)', _RXP)) RX_BANNER = re.compile(r'^({0}(?:(?:-{0})*)){1}$'.format(_RXP, _RXR)) def __init__(self, protocol, software, comments, valid_ascii): @@ -957,7 +957,7 @@ class SSH(object): software = '' comments = (mx.group(4) or '').strip() or None if comments is not None: - comments = re.sub('\s+', ' ', comments) + comments = re.sub(r'\s+', ' ', comments) return cls(protocol, software, comments, valid_ascii) class Fingerprint(object): From 5be64a8ad2ecbaa681c4d375b8d3b338f6855b5f Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Thu, 20 Oct 2016 16:31:48 +0300 Subject: [PATCH 10/28] Fix pylint reported dangerous-default-value. --- ssh-audit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ssh-audit.py b/ssh-audit.py index ec8d2df..1ed19c9 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -1353,8 +1353,9 @@ def get_ssh_version(version_desc): return (SSH.Product.OpenSSH, version_desc) -def get_alg_timeframe(versions, for_server=True, result={}): - # type: (List[str], bool, Dict[str, List[Optional[str]]]) -> Dict[str, List[Optional[str]]] +def get_alg_timeframe(versions, for_server=True, result=None): + # type: (List[str], bool, Optional[Dict[str, List[Optional[str]]]]) -> Dict[str, List[Optional[str]]] + result = result or {} vlen = len(versions) for i in range(3): if i > vlen - 1: From 4120377c0be5794eff8fc57037d672d4ffbc9c6d Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Thu, 20 Oct 2016 16:41:44 +0300 Subject: [PATCH 11/28] Remove unnecessary argument. --- ssh-audit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ssh-audit.py b/ssh-audit.py index 1ed19c9..cb7a1a4 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -1807,7 +1807,7 @@ class Utils(object): raise cls._type_err(v, 'native text') @classmethod - def is_ascii(cls, v, enc='utf-8'): + def is_ascii(cls, v): # type: (Union[text_type, str], str) -> bool try: if isinstance(v, (text_type, str)): From dfb8c302bfbb7b58d7d7e9d7f5e1d4be483ffd40 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Thu, 20 Oct 2016 16:46:53 +0300 Subject: [PATCH 12/28] Fix pylint reported attribute-defined-outside-init. --- ssh-audit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ssh-audit.py b/ssh-audit.py index cb7a1a4..15183cf 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -193,6 +193,7 @@ class Output(object): class OutputBuffer(list): + # pylint: disable=attribute-defined-outside-init def __enter__(self): # type: () -> OutputBuffer self.__buf = StringIO() @@ -1193,6 +1194,7 @@ class KexDH(object): self.__p = p self.__q = (self.__p - 1) // 2 self.__x = None # type: Optional[int] + self.__e = None # type: Optional[int] def send_init(self, s): # type: (SSH.Socket) -> None From cbe7ad4ac34181b878bd3d2d6c217bbc51b46033 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Thu, 20 Oct 2016 17:06:23 +0300 Subject: [PATCH 13/28] Fix pylint reported no-self-use and disable checks in py2/3 compatibility code. --- ssh-audit.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ssh-audit.py b/ssh-audit.py index 15183cf..af9c228 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -35,9 +35,10 @@ if sys.version_info >= (3,): else: import StringIO as _StringIO StringIO = BytesIO = _StringIO.StringIO - text_type = unicode + text_type = unicode # pylint: disable=undefined-variable binary_type = str try: + # pylint: disable=unused-import from typing import List, Tuple, Optional, Callable, Union, Any except: pass @@ -175,7 +176,8 @@ class Output(object): if not self.batch: print() - def _colorized(self, color): + @staticmethod + def _colorized(color): # type: (str) -> Callable[[text_type], None] return lambda x: print(u'{0}{1}\033[0m'.format(color, x)) From cdfe06e75db353439d8e124ffcc79db3a10d22b6 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Thu, 20 Oct 2016 17:19:37 +0300 Subject: [PATCH 14/28] Fix type after argument removal. --- ssh-audit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ssh-audit.py b/ssh-audit.py index af9c228..e1f82df 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -1812,7 +1812,7 @@ class Utils(object): @classmethod def is_ascii(cls, v): - # type: (Union[text_type, str], str) -> bool + # type: (Union[text_type, str]) -> bool try: if isinstance(v, (text_type, str)): v.encode('ascii') From a5f1cd91975e8d592df534f75481411b67837e30 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Thu, 20 Oct 2016 20:00:29 +0300 Subject: [PATCH 15/28] Tune prospector and pylint settings. --- test/prospector.sh | 7 ++++++- test/prospector.yml | 45 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/test/prospector.sh b/test/prospector.sh index e420aff..4398ec7 100755 --- a/test/prospector.sh +++ b/test/prospector.sh @@ -5,4 +5,9 @@ if [ $? -ne 0 ]; then echo "err: prospector (Python Static Analysis) not found." exit 1 fi -prospector --profile-path "${_cdir}" -P prospector "${_cdir}/../ssh-audit.py" +if [ X"$1" == X"" ]; then + _file="${_cdir}/../ssh-audit.py" +else + _file="$1" +fi +prospector -E --profile-path "${_cdir}" -P prospector "${_file}" diff --git a/test/prospector.yml b/test/prospector.yml index f3373b7..474af15 100644 --- a/test/prospector.yml +++ b/test/prospector.yml @@ -1,9 +1,42 @@ -inherits: - - strictness_veryhigh +strictness: veryhigh +doc-warnings: false + +pylint: + disable: + - multiple-imports + - invalid-name + - trailing-whitespace + + options: + max-args: 8 # default: 5 + max-locals: 20 # default: 15 + max-returns: 6 + max-branches: 15 # default: 12 + max-statements: 60 # default: 50 + max-parents: 7 + max-attributes: 8 # default: 7 + min-public-methods: 1 # default: 2 + max-public-methods: 20 + max-bool-expr: 5 + max-nested-blocks: 6 # default: 5 + max-line-length: 80 # default: 100 + ignore-long-lines: ^\s*(#\s+type:\s+.*|[A-Z0-9_]+\s+=\s+.*|('.*':\s+)?\[.*\],?)$ + max-module-lines: 2500 # default: 10000 pep8: disable: - - W191 - - W293 - - E501 - - E221 + - W191 # indentation contains tabs + - W293 # blank line contains whitespace + - E101 # indentation contains mixed spaces and tabs + - E401 # multiple imports on one line + - E501 # line too long + - E221 # multiple spaces before operator + +pyflakes: + disable: + - F401 # module imported but unused + - F821 # undefined name + +mccabe: + options: + max-complexity: 15 From 5b3b63062370482cfc79e465a65fb32471e3850c Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Thu, 20 Oct 2016 20:00:51 +0300 Subject: [PATCH 16/28] Fix pylint reported issues and disable unnecessary ones. --- ssh-audit.py | 167 +++++++++++++++++++++++++++------------------------ 1 file changed, 88 insertions(+), 79 deletions(-) diff --git a/ssh-audit.py b/ssh-audit.py index e1f82df..b88e512 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -33,33 +33,33 @@ if sys.version_info >= (3,): text_type = str binary_type = bytes else: - import StringIO as _StringIO + import StringIO as _StringIO # pylint: disable=import-error StringIO = BytesIO = _StringIO.StringIO text_type = unicode # pylint: disable=undefined-variable binary_type = str try: # pylint: disable=unused-import from typing import List, Tuple, Optional, Callable, Union, Any -except: +except ImportError: pass def usage(err=None): # type: (Optional[str]) -> None - out = Output() + uout = Output() p = os.path.basename(sys.argv[0]) - out.head('# {0} {1}, moo@arthepsy.eu'.format(p, VERSION)) + uout.head('# {0} {1}, moo@arthepsy.eu'.format(p, VERSION)) if err is not None: - out.fail('\n' + err) - out.info('\nusage: {0} [-12bnv] [-l ] \n'.format(p)) - out.info(' -h, --help print this help') - out.info(' -1, --ssh1 force ssh version 1 only') - out.info(' -2, --ssh2 force ssh version 2 only') - out.info(' -b, --batch batch output') - out.info(' -n, --no-colors disable colors') - out.info(' -v, --verbose verbose output') - out.info(' -l, --level= minimum output level (info|warn|fail)') - out.sep() + uout.fail('\n' + err) + uout.info('\nusage: {0} [-12bnv] [-l ] \n'.format(p)) + uout.info(' -h, --help print this help') + uout.info(' -1, --ssh1 force ssh version 1 only') + uout.info(' -2, --ssh2 force ssh version 2 only') + uout.info(' -b, --batch batch output') + uout.info(' -n, --no-colors disable colors') + uout.info(' -v, --verbose verbose output') + uout.info(' -l, --level= minimum output level (info|warn|fail)') + uout.sep() sys.exit(1) @@ -97,7 +97,8 @@ class AuditConf(object): @classmethod def from_cmdline(cls, args, usage_cb): # type: (List[str], Callable[..., None]) -> AuditConf - conf = cls() + # pylint: disable=too-many-branches + aconf = cls() try: sopts = 'h12bnvl:' lopts = ['help', 'ssh1', 'ssh2', 'batch', @@ -105,25 +106,25 @@ class AuditConf(object): opts, args = getopt.getopt(args, sopts, lopts) except getopt.GetoptError as err: usage_cb(str(err)) - conf.ssh1, conf.ssh2 = False, False + aconf.ssh1, aconf.ssh2 = False, False for o, a in opts: if o in ('-h', '--help'): usage_cb() elif o in ('-1', '--ssh1'): - conf.ssh1 = True + aconf.ssh1 = True elif o in ('-2', '--ssh2'): - conf.ssh2 = True + aconf.ssh2 = True elif o in ('-b', '--batch'): - conf.batch = True - conf.verbose = True + aconf.batch = True + aconf.verbose = True elif o in ('-n', '--no-colors'): - conf.colors = False + aconf.colors = False elif o in ('-v', '--verbose'): - conf.verbose = True + aconf.verbose = True elif o in ('-l', '--level'): if a not in ('info', 'warn', 'fail'): usage_cb('level {0} is not valid'.format(a)) - conf.minlevel = a + aconf.minlevel = a if len(args) == 0: usage_cb() s = args[0].split(':') @@ -134,11 +135,11 @@ class AuditConf(object): usage_cb('host is empty') if port <= 0 or port > 65535: usage_cb('port {0} is not valid'.format(s[1])) - conf.host = host - conf.port = port - if not (conf.ssh1 or conf.ssh2): - conf.ssh1, conf.ssh2 = True, True - return conf + aconf.host = host + aconf.port = port + if not (aconf.ssh1 or aconf.ssh2): + aconf.ssh1, aconf.ssh2 = True, True + return aconf class Output(object): @@ -195,9 +196,9 @@ class Output(object): class OutputBuffer(list): - # pylint: disable=attribute-defined-outside-init def __enter__(self): # type: () -> OutputBuffer + # pylint: disable=attribute-defined-outside-init self.__buf = StringIO() self.__stdout = sys.stdout sys.stdout = self.__buf @@ -214,7 +215,7 @@ class OutputBuffer(list): sys.stdout = self.__stdout -class SSH2(object): +class SSH2(object): # pylint: disable=too-few-public-methods class KexParty(object): def __init__(self, enc, mac, compression, languages): # type: (List[text_type], List[text_type], List[text_type], List[text_type]) -> None @@ -345,7 +346,7 @@ class SSH1(object): for i in range(256): crc = 0 n = i - for j in range(8): + for _ in range(8): x = (crc ^ n) & 1 crc = (crc >> 1) ^ (x * 0xedb88320) n = n >> 1 @@ -371,7 +372,8 @@ class SSH1(object): cls._crc32 = cls.CRC32() return cls._crc32.calc(v) - class KexDB(object): + class KexDB(object): # pylint: disable=too-few-public-methods + # pylint: disable=bad-whitespace FAIL_PLAINTEXT = 'no encryption/integrity' FAIL_OPENSSH37_REMOVE = 'removed since OpenSSH 3.7' FAIL_NA_BROKEN = 'not implemented in OpenSSH, broken algorithm' @@ -451,6 +453,7 @@ class SSH1(object): @property def host_key_fingerprint_data(self): # type: () -> binary_type + # pylint: disable=protected-access mod = WriteBuf._create_mpint(self.host_key_public_modulus, False) e = WriteBuf._create_mpint(self.host_key_public_exponent, False) return mod + e @@ -686,27 +689,28 @@ class WriteBuf(object): return payload -class SSH(object): - class Protocol(object): +class SSH(object): # pylint: disable=too-few-public-methods + class Protocol(object): # pylint: disable=too-few-public-methods + # pylint: disable=bad-whitespace SMSG_PUBLIC_KEY = 2 MSG_KEXINIT = 20 MSG_NEWKEYS = 21 MSG_KEXDH_INIT = 30 MSG_KEXDH_REPLY = 32 - class Product(object): + class Product(object): # pylint: disable=too-few-public-methods OpenSSH = 'OpenSSH' DropbearSSH = 'Dropbear SSH' LibSSH = 'libssh' class Software(object): - def __init__(self, vendor, product, version, patch, os): + def __init__(self, vendor, product, version, patch, os_version): # type: (Optional[str], str, str, Optional[str], Optional[str]) -> None self.__vendor = vendor self.__product = product self.__version = version self.__patch = patch - self.__os = os + self.__os = os_version @property def vendor(self): @@ -735,6 +739,7 @@ class SSH(object): def compare_version(self, other): # type: (Union[None, SSH.Software, text_type]) -> int + # pylint: disable=too-many-branches if other is None: return 1 if isinstance(other, SSH.Software): @@ -780,22 +785,22 @@ class SSH(object): def display(self, full=True): # type: (bool) -> str - out = '{0} '.format(self.vendor) if self.vendor else '' - out += self.product + r = '{0} '.format(self.vendor) if self.vendor else '' + r += self.product if self.version: - out += ' {0}'.format(self.version) + r += ' {0}'.format(self.version) if full: patch = self.patch or '' if self.product == SSH.Product.OpenSSH: mx = re.match(r'^(p\d)(.*)$', patch) if mx is not None: - out += mx.group(1) + r += mx.group(1) patch = mx.group(2).strip() if patch: - out += ' ({0})'.format(patch) + r += ' ({0})'.format(patch) if self.os: - out += ' running on {0}'.format(self.os) - return out + r += ' running on {0}'.format(self.os) + return r def __str__(self): # type: () -> str @@ -803,18 +808,18 @@ class SSH(object): def __repr__(self): # type: () -> str - out = 'vendor={0}'.format(self.vendor) if self.vendor else '' + r = 'vendor={0}'.format(self.vendor) if self.vendor else '' if self.product: if self.vendor: - out += ', ' - out += 'product={0}'.format(self.product) + r += ', ' + r += 'product={0}'.format(self.product) if self.version: - out += ', version={0}'.format(self.version) + r += ', version={0}'.format(self.version) if self.patch: - out += ', patch={0}'.format(self.patch) + r += ', patch={0}'.format(self.patch) if self.os: - out += ', os={0}'.format(self.os) - return '<{0}({1})>'.format(self.__class__.__name__, out) + r += ', os={0}'.format(self.os) + return '<{0}({1})>'.format(self.__class__.__name__, r) @staticmethod def _fix_patch(patch): @@ -830,7 +835,7 @@ class SSH(object): return None @classmethod - def _extract_os(cls, c): + def _extract_os_version(cls, c): # type: (Optional[str]) -> str if c is None: return None @@ -859,6 +864,7 @@ class SSH(object): @classmethod def parse(cls, banner): # type: (SSH.Banner) -> SSH.Software + # pylint: disable=too-many-return-statements software = str(banner.software) mx = re.match(r'^dropbear_([\d\.]+\d+)(.*)', software) if mx: @@ -871,14 +877,14 @@ class SSH(object): patch = cls._fix_patch(mx.group(2)) v, p = 'OpenBSD', SSH.Product.OpenSSH v = None - os = cls._extract_os(banner.comments) - return cls(v, p, mx.group(1), patch, os) + os_version = cls._extract_os_version(banner.comments) + return cls(v, p, mx.group(1), patch, os_version) mx = re.match(r'^libssh-([\d\.]+\d+)(.*)', software) if mx: patch = cls._fix_patch(mx.group(2)) v, p = None, SSH.Product.LibSSH - os = cls._extract_os(banner.comments) - return cls(v, p, mx.group(1), patch, os) + os_version = cls._extract_os_version(banner.comments) + return cls(v, p, mx.group(1), patch, os_version) mx = re.match(r'^RomSShell_([\d\.]+\d+)(.*)', software) if mx: patch = cls._fix_patch(mx.group(2)) @@ -928,22 +934,22 @@ class SSH(object): def __str__(self): # type: () -> str - out = 'SSH-{0}.{1}'.format(self.protocol[0], self.protocol[1]) + r = 'SSH-{0}.{1}'.format(self.protocol[0], self.protocol[1]) if self.software is not None: - out += '-{0}'.format(self.software) + r += '-{0}'.format(self.software) if self.comments: - out += ' {0}'.format(self.comments) - return out + r += ' {0}'.format(self.comments) + return r def __repr__(self): # type: () -> str p = '{0}.{1}'.format(self.protocol[0], self.protocol[1]) - out = 'protocol={0}'.format(p) + r = 'protocol={0}'.format(p) if self.software: - out += ', software={0}'.format(self.software) + r += ', software={0}'.format(self.software) if self.comments: - out += ', comments={0}'.format(self.comments) - return '<{0}({1})>'.format(self.__class__.__name__, out) + r += ', comments={0}'.format(self.comments) + return '<{0}({1})>'.format(self.__class__.__name__, r) @classmethod def parse(cls, banner): @@ -982,7 +988,8 @@ class SSH(object): r = h.decode('ascii').rstrip('=') return u'SHA256:{0}'.format(r) - class Security(object): + class Security(object): # pylint: disable=too-few-public-methods + # pylint: disable=bad-whitespace CVE = { 'Dropbear SSH': [ ['0.44', '2015.71', 1, 'CVE-2016-3116', 5.5, 'bypass command restrictions via xauth command injection'], @@ -1031,7 +1038,7 @@ class SSH(object): try: self.__sock = socket.create_connection((host, port), cto) self.__sock.settimeout(rto) - except Exception as e: + except Exception as e: # pylint: disable=broad-except out.fail('[fail] {0}'.format(e)) sys.exit(1) @@ -1184,7 +1191,7 @@ class SSH(object): try: self.__sock.shutdown(socket.SHUT_RDWR) self.__sock.close() - except: + except: # pylint: disable=bare-except pass @@ -1236,7 +1243,8 @@ class KexGroup14(KexDH): super(KexGroup14, self).__init__('sha1', 2, p) -class KexDB(object): +class KexDB(object): # pylint: disable=too-few-public-methods + # pylint: disable=bad-whitespace WARN_OPENSSH72_LEGACY = 'disabled (in client) since OpenSSH 7.2, legacy algorithm' FAIL_OPENSSH70_LEGACY = 'removed since OpenSSH 7.0, legacy algorithm' FAIL_OPENSSH70_WEAK = 'removed (in server) and disabled (in client) since OpenSSH 7.0, weak algorithm' @@ -1397,7 +1405,7 @@ def get_ssh_timeframe(alg_pairs, for_server=True): # type: (List[Tuple[int, Dict[str, Dict[str, List[List[str]]]], List[Tuple[str, List[text_type]]]]], bool) -> Dict[str, List[Optional[str]]] timeframe = {} # type: Dict[str, List[Optional[str]]] for alg_pair in alg_pairs: - sshv, alg_db = alg_pair[0], alg_pair[1] + alg_db = alg_pair[1] for alg_set in alg_pair[2]: alg_type, alg_list = alg_set for alg_name in alg_list: @@ -1448,6 +1456,7 @@ def get_alg_pairs(kex, pkm): def get_alg_recommendations(software, kex, pkm, for_server=True): # type: (SSH.Software, SSH2.Kex, SSH1.PublicKeyMessage, bool) -> Tuple[SSH.Software, Dict[int, Dict[str, Dict[str, Dict[str, int]]]]] + # pylint: disable=too-many-locals,too-many-statements alg_pairs = get_alg_pairs(kex, pkm) vproducts = [SSH.Product.OpenSSH, SSH.Product.DropbearSSH, @@ -1833,19 +1842,19 @@ class Utils(object): # type: (Any) -> int try: return int(v) - except: + except: # pylint: disable=bare-except return 0 -def audit(conf, sshv=None): +def audit(aconf, sshv=None): # type: (AuditConf, Optional[int]) -> None - out.batch = conf.batch - out.colors = conf.colors - out.verbose = conf.verbose - out.minlevel = conf.minlevel - s = SSH.Socket(conf.host, conf.port) + out.batch = aconf.batch + out.colors = aconf.colors + out.verbose = aconf.verbose + out.minlevel = aconf.minlevel + s = SSH.Socket(aconf.host, aconf.port) if sshv is None: - sshv = 2 if conf.ssh2 else 1 + sshv = 2 if aconf.ssh2 else 1 err = None banner, header = s.get_banner(sshv) if banner is None: @@ -1858,8 +1867,8 @@ def audit(conf, sshv=None): except UnicodeDecodeError: payload_txt = u'"{0}"'.format(repr(payload).lstrip('b')[1:-1]) if payload_txt == u'Protocol major versions differ.': - if sshv == 2 and conf.ssh1: - audit(conf, 1) + if sshv == 2 and aconf.ssh1: + audit(aconf, 1) return err = '[exception] error reading packet ({0})'.format(payload_txt) else: From 855d64f5b1f5cbfe06df1cd72726b9850dd2e664 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Tue, 25 Oct 2016 02:59:42 +0300 Subject: [PATCH 17/28] Ignore virtualenv and cache. --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 84f8554..481cc4a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ *~ *.pyc -./html/ +html/ +venv/ +.cache/ \ No newline at end of file From 385c2303765c0a1ae20c70d1700e5887aea07772 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Tue, 25 Oct 2016 11:50:12 +0300 Subject: [PATCH 18/28] Add colors support for Microsoft Windows via optional colorama dependency. --- ssh-audit.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ssh-audit.py b/ssh-audit.py index b88e512..b89e8da 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -42,6 +42,11 @@ try: from typing import List, Tuple, Optional, Callable, Union, Any except ImportError: pass +try: + from colorama import init as colorama_init + colorama_init() +except ImportError: + pass def usage(err=None): @@ -177,6 +182,11 @@ class Output(object): if not self.batch: print() + @property + def colors_supported(self): + # type: () -> bool + return 'colorama' in sys.modules or os.name == 'posix' + @staticmethod def _colorized(color): # type: (str) -> Callable[[text_type], None] @@ -188,7 +198,7 @@ class Output(object): return lambda x: None if not self.getlevel(name) >= self.__minlevel: return lambda x: None - if self.colors and os.name == 'posix' and name in self.COLORS: + if self.colors and self.colors_supported and name in self.COLORS: color = '\033[0;{0}m'.format(self.COLORS[name]) return self._colorized(color) else: From 182467e0e8463fe87e9e9453f5327a607226578a Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Tue, 25 Oct 2016 11:52:55 +0300 Subject: [PATCH 19/28] Fix typo, which slipped in while adding type system. --- ssh-audit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ssh-audit.py b/ssh-audit.py index b89e8da..b584748 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -1648,7 +1648,7 @@ def output_security_sub(sub, software, padlen): vfrom, vtill = line[0:2] # type: str, str if not software.between_versions(vfrom, vtill): continue - target, name = line[2:3] # type: int, str + target, name = line[2:4] # type: int, str is_server, is_client = target & 1 == 1, target & 2 == 2 is_local = target & 4 == 4 if not is_server: From 66bd6c3ef02f250279bd1235e98c1acd2dd7bce9 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Tue, 25 Oct 2016 11:57:13 +0300 Subject: [PATCH 20/28] Test colors only if they are supported. --- test/test_output.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_output.py b/test/test_output.py index 8ba3645..a50de38 100644 --- a/test/test_output.py +++ b/test/test_output.py @@ -62,6 +62,8 @@ class TestOutput(object): output_spy.begin() out.fail('fail color') assert output_spy.flush() == [u'fail color'] + if not out.colors_supported: + return # test with colors out.colors = True output_spy.begin() From 4bbb1f4d112faa427bba24885483baac3ff6ce1c Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Tue, 25 Oct 2016 13:53:51 +0300 Subject: [PATCH 21/28] Use safer UTF-8 decoding (with replace) and add related tests. --- ssh-audit.py | 8 ++++---- test/test_buffer.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/ssh-audit.py b/ssh-audit.py index b584748..424f936 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -40,12 +40,12 @@ else: try: # pylint: disable=unused-import from typing import List, Tuple, Optional, Callable, Union, Any -except ImportError: +except ImportError: # pragma: nocover pass try: from colorama import init as colorama_init colorama_init() -except ImportError: +except ImportError: # pragma: nocover pass @@ -572,7 +572,7 @@ class ReadBuf(object): def read_list(self): # type: () -> List[text_type] list_size = self.read_int() - return self.read(list_size).decode().split(',') + return self.read(list_size).decode('utf-8', 'replace').split(',') def read_string(self): # type: () -> binary_type @@ -607,7 +607,7 @@ class ReadBuf(object): def read_line(self): # type: () -> text_type - return self._buf.readline().rstrip().decode('utf-8') + return self._buf.readline().rstrip().decode('utf-8', 'replace') class WriteBuf(object): diff --git a/test/test_buffer.py b/test/test_buffer.py index 968e3f7..e0be311 100644 --- a/test/test_buffer.py +++ b/test/test_buffer.py @@ -9,6 +9,7 @@ class TestBuffer(object): def init(self, ssh_audit): self.rbuf = ssh_audit.ReadBuf self.wbuf = ssh_audit.WriteBuf + self.utf8rchar = b'\xef\xbf\xbd' def _b(self, v): v = re.sub(r'\s', '', v) @@ -75,6 +76,12 @@ class TestBuffer(object): assert w(p[0]) == self._b(p[1]) assert r(self._b(p[1])) == p[0] + def test_list_nonutf8(self): + r = lambda x: self.rbuf(x).read_list() + src = self._b('00 00 00 04 de ad be ef') + dst = [(b'\xde\xad' + self.utf8rchar + self.utf8rchar).decode('utf-8')] + assert r(src) == dst + def test_line(self): w = lambda x: self.wbuf().write_line(x).write_flush() r = lambda x: self.rbuf(x).read_line() @@ -83,6 +90,12 @@ class TestBuffer(object): assert w(p[0]) == self._b(p[1]) assert r(self._b(p[1])) == p[0] + def test_line_nonutf8(self): + r = lambda x: self.rbuf(x).read_line() + src = self._b('de ad be af') + dst = (b'\xde\xad' + self.utf8rchar + self.utf8rchar).decode('utf-8') + assert r(src) == dst + def test_bitlen(self): class Py26Int(int): def bit_length(self): From aa4eabda6618c7ab8a11392515f046847aceff0e Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Tue, 25 Oct 2016 14:04:54 +0300 Subject: [PATCH 22/28] Do not count coverage for missing import. --- ssh-audit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ssh-audit.py b/ssh-audit.py index 424f936..2b52ad7 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -44,7 +44,7 @@ except ImportError: # pragma: nocover pass try: from colorama import init as colorama_init - colorama_init() + colorama_init() # pragma: nocover except ImportError: # pragma: nocover pass From 318aab79bc8c294a9f38f702c84c052901344e01 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Tue, 25 Oct 2016 16:57:30 +0300 Subject: [PATCH 23/28] Add simple server tests for SSH1 and SSH2. --- test/test_ssh1.py | 63 ++++++++++++++++++++++++++++++++++++++++++----- test/test_ssh2.py | 50 ++++++++++++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 7 deletions(-) diff --git a/test/test_ssh1.py b/test/test_ssh1.py index b9eec53..b70cce4 100644 --- a/test/test_ssh1.py +++ b/test/test_ssh1.py @@ -1,8 +1,10 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import struct import pytest +# pylint: disable=line-too-long,attribute-defined-outside-init class TestSSH1(object): @pytest.fixture(autouse=True) def init(self, ssh_audit): @@ -10,15 +12,32 @@ class TestSSH1(object): self.ssh1 = ssh_audit.SSH1 self.rbuf = ssh_audit.ReadBuf self.wbuf = ssh_audit.WriteBuf + self.audit = ssh_audit.audit + self.AuditConf = ssh_audit.AuditConf - def test_crc32(self): - assert self.ssh1.crc32(b'') == 0x00 - assert self.ssh1.crc32(b'The quick brown fox jumps over the lazy dog') == 0xb9c60808 + def _conf(self): + conf = self.AuditConf('localhost', 22) + conf.colors = False + conf.batch = True + conf.verbose = True + conf.ssh1 = True + conf.ssh2 = False + return conf - def _server_key(self): + def _create_ssh1_packet(self, payload): + padding = -(len(payload) + 4) % 8 + plen = len(payload) + 4 + pad_bytes = b'\x00' * padding + cksum = self.ssh1.crc32(pad_bytes + payload) + data = struct.pack('>I', plen) + pad_bytes + payload + struct.pack('>I', cksum) + return data + + @classmethod + def _server_key(cls): return (1024, 0x10001, 0xee6552da432e0ac2c422df1a51287507748bfe3b5e3e4fa989a8f49fdc163a17754939ef18ef8a667ea3b71036a151fcd7f5e01ceef1e4439864baf3ac569047582c69d6c128212e0980dcb3168f00d371004039983f6033cd785b8b8f85096c7d9405cbfdc664e27c966356a6b4eb6ee20ad43414b50de18b22829c1880b551) - def _host_key(self): + @classmethod + def _host_key(cls): return (2048, 0x10001, 0xdfa20cd2a530ccc8c870aa60d9feb3b35deeab81c3215a96557abbd683d21f4600f38e475d87100da9a4404220eeb3bb5584e5a2b5b48ffda58530ea19104a32577d7459d91e76aa711b241050f4cc6d5327ccce254f371acad3be56d46eb5919b73f20dbdb1177b700f00891c5bf4ed128bb90ed541b778288285bcfa28432ab5cbcb8321b6e24760e998e0daa519f093a631e44276d7dd252ce0c08c75e2ab28a7349ead779f97d0f20a6d413bf3623cd216dc35375f6366690bcc41e3b2d5465840ec7ee0dc7e3f1c101d674a0c7dbccbc3942788b111396add2f8153b46a0e4b50d66e57ee92958f1c860dd97cc0e40e32febff915343ed53573142bdf4b) def _pkm_payload(self): @@ -33,11 +52,17 @@ class TestSSH1(object): w.write_int(36) return w.write_flush() + def test_crc32(self): + assert self.ssh1.crc32(b'') == 0x00 + assert self.ssh1.crc32(b'The quick brown fox jumps over the lazy dog') == 0xb9c60808 + def test_fingerprint(self): + # pylint: disable=protected-access b, e, m = self._host_key() fpd = self.wbuf._create_mpint(m, False) fpd += self.wbuf._create_mpint(e, False) fp = self.ssh.Fingerprint(fpd) + assert b == 2048 assert fp.md5 == 'MD5:9d:26:f8:39:fc:20:9d:9b:ca:cc:4a:0f:e1:93:f5:96' assert fp.sha256 == 'SHA256:vZdx3mhzbvVJmn08t/ruv8WDhJ9jfKYsCTuSzot+QIs' @@ -63,7 +88,7 @@ class TestSSH1(object): assert fp.sha256 == 'SHA256:vZdx3mhzbvVJmn08t/ruv8WDhJ9jfKYsCTuSzot+QIs' def test_pkm_payload(self): - cookie = b'\x88\x99\xaa\xbb\xcc\xdd\xee\xff' + cookie = b'\x88\x99\xaa\xbb\xcc\xdd\xee\xff' skey = self._server_key() hkey = self._host_key() pflags = 2 @@ -72,3 +97,29 @@ class TestSSH1(object): pkm1 = self.ssh1.PublicKeyMessage(cookie, skey, hkey, pflags, cmask, amask) pkm2 = self.ssh1.PublicKeyMessage.parse(self._pkm_payload()) assert pkm1.payload == pkm2.payload + + def test_ssh1_server_simple(self, output_spy, virtual_socket): + vsocket = virtual_socket + w = self.wbuf() + w.write_byte(self.ssh.Protocol.SMSG_PUBLIC_KEY) + w.write(self._pkm_payload()) + vsocket.rdata.append(b'SSH-1.5-OpenSSH_7.2 ssh-audit-test\r\n') + vsocket.rdata.append(self._create_ssh1_packet(w.write_flush())) + output_spy.begin() + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 10 + + def test_ssh1_server_invalid_first_packet(self, output_spy, virtual_socket): + vsocket = virtual_socket + w = self.wbuf() + w.write_byte(self.ssh.Protocol.SMSG_PUBLIC_KEY + 1) + w.write(self._pkm_payload()) + vsocket.rdata.append(b'SSH-1.5-OpenSSH_7.2 ssh-audit-test\r\n') + vsocket.rdata.append(self._create_ssh1_packet(w.write_flush())) + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 4 + assert 'unknown message' in lines[-1] diff --git a/test/test_ssh2.py b/test/test_ssh2.py index 93093b0..cdb348c 100644 --- a/test/test_ssh2.py +++ b/test/test_ssh2.py @@ -1,8 +1,10 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import pytest, os +import struct, os +import pytest +# pylint: disable=line-too-long,attribute-defined-outside-init class TestSSH2(object): @pytest.fixture(autouse=True) def init(self, ssh_audit): @@ -10,6 +12,27 @@ class TestSSH2(object): self.ssh2 = ssh_audit.SSH2 self.rbuf = ssh_audit.ReadBuf self.wbuf = ssh_audit.WriteBuf + self.audit = ssh_audit.audit + self.AuditConf = ssh_audit.AuditConf + + def _conf(self): + conf = self.AuditConf('localhost', 22) + conf.colors = False + conf.batch = True + conf.verbose = True + conf.ssh1 = False + conf.ssh2 = True + return conf + + @classmethod + def _create_ssh2_packet(cls, payload): + padding = -(len(payload) + 5) % 8 + if padding < 4: + padding += 8 + plen = len(payload) + padding + 1 + pad_bytes = b'\x00' * padding + data = struct.pack('>Ib', plen, padding) + payload + pad_bytes + return data def _kex_payload(self): w = self.wbuf() @@ -105,3 +128,28 @@ class TestSSH2(object): kex1 = self._get_kex_variat1() kex2 = self.ssh2.Kex.parse(self._kex_payload()) assert kex1.payload == kex2.payload + + def test_ssh2_server_simple(self, output_spy, virtual_socket): + vsocket = virtual_socket + w = self.wbuf() + w.write_byte(self.ssh.Protocol.MSG_KEXINIT) + w.write(self._kex_payload()) + vsocket.rdata.append(b'SSH-2.0-OpenSSH_7.3 ssh-audit-test\r\n') + vsocket.rdata.append(self._create_ssh2_packet(w.write_flush())) + output_spy.begin() + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 72 + + def test_ssh2_server_invalid_first_packet(self, output_spy, virtual_socket): + vsocket = virtual_socket + w = self.wbuf() + w.write_byte(self.ssh.Protocol.MSG_KEXINIT + 1) + vsocket.rdata.append(b'SSH-2.0-OpenSSH_7.3 ssh-audit-test\r\n') + vsocket.rdata.append(self._create_ssh2_packet(w.write_flush())) + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 3 + assert 'unknown message' in lines[-1] From 84dfdcaf5ea97d18dc742d18b4dfc824a0d9cad8 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Tue, 25 Oct 2016 16:59:43 +0300 Subject: [PATCH 24/28] Invalid CRC32 checksum test. --- test/test_ssh1.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/test/test_ssh1.py b/test/test_ssh1.py index b70cce4..0029845 100644 --- a/test/test_ssh1.py +++ b/test/test_ssh1.py @@ -24,11 +24,11 @@ class TestSSH1(object): conf.ssh2 = False return conf - def _create_ssh1_packet(self, payload): + def _create_ssh1_packet(self, payload, valid_crc=True): padding = -(len(payload) + 4) % 8 plen = len(payload) + 4 pad_bytes = b'\x00' * padding - cksum = self.ssh1.crc32(pad_bytes + payload) + cksum = self.ssh1.crc32(pad_bytes + payload) if valid_crc else 0 data = struct.pack('>I', plen) + pad_bytes + payload + struct.pack('>I', cksum) return data @@ -123,3 +123,17 @@ class TestSSH1(object): lines = output_spy.flush() assert len(lines) == 4 assert 'unknown message' in lines[-1] + + def test_ssh1_server_invalid_checksum(self, output_spy, virtual_socket): + vsocket = virtual_socket + w = self.wbuf() + w.write_byte(self.ssh.Protocol.SMSG_PUBLIC_KEY + 1) + w.write(self._pkm_payload()) + vsocket.rdata.append(b'SSH-1.5-OpenSSH_7.2 ssh-audit-test\r\n') + vsocket.rdata.append(self._create_ssh1_packet(w.write_flush(), False)) + output_spy.begin() + with pytest.raises(SystemExit): + self.audit(self._conf()) + lines = output_spy.flush() + assert len(lines) == 1 + assert 'checksum' in lines[-1] From 4684ff0113893f969d99de4e02109505c19ec829 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Tue, 25 Oct 2016 17:19:08 +0300 Subject: [PATCH 25/28] Add linter fixes for tests. --- ssh-audit.py | 2 +- test/test_auditconf.py | 6 +++-- test/test_banner.py | 13 ++++++----- test/test_buffer.py | 43 +++++++++++++++++++----------------- test/test_errors.py | 4 +++- test/test_output.py | 5 +++-- test/test_software.py | 18 ++++++++------- test/test_version_compare.py | 1 + 8 files changed, 52 insertions(+), 40 deletions(-) diff --git a/ssh-audit.py b/ssh-audit.py index 2b52ad7..c1dfaf2 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -1905,6 +1905,6 @@ def audit(aconf, sshv=None): utils = Utils() out = Output() -if __name__ == '__main__': +if __name__ == '__main__': # pragma: nocover conf = AuditConf.from_cmdline(sys.argv[1:], usage) audit(conf) diff --git a/test/test_auditconf.py b/test/test_auditconf.py index b4f42f4..6c23c2a 100644 --- a/test/test_auditconf.py +++ b/test/test_auditconf.py @@ -3,13 +3,15 @@ import pytest +# pylint: disable=attribute-defined-outside-init class TestAuditConf(object): @pytest.fixture(autouse=True) def init(self, ssh_audit): self.AuditConf = ssh_audit.AuditConf self.usage = ssh_audit.usage - def _test_conf(self, conf, **kwargs): + @classmethod + def _test_conf(cls, conf, **kwargs): options = { 'host': None, 'port': 22, @@ -66,7 +68,7 @@ class TestAuditConf(object): excinfo.match(r'.*invalid level.*') def test_audit_conf_cmdline(self): - c = lambda x: self.AuditConf.from_cmdline(x.split(), self.usage) + c = lambda x: self.AuditConf.from_cmdline(x.split(), self.usage) # noqa with pytest.raises(SystemExit): conf = c('') with pytest.raises(SystemExit): diff --git a/test/test_banner.py b/test/test_banner.py index b2d9991..ca93a53 100644 --- a/test/test_banner.py +++ b/test/test_banner.py @@ -3,13 +3,14 @@ import pytest +# pylint: disable=line-too-long,attribute-defined-outside-init class TestBanner(object): @pytest.fixture(autouse=True) def init(self, ssh_audit): self.ssh = ssh_audit.SSH def test_simple_banners(self): - banner = lambda x: self.ssh.Banner.parse(x) + banner = lambda x: self.ssh.Banner.parse(x) # noqa b = banner('SSH-2.0-OpenSSH_7.3') assert b.protocol == (2, 0) assert b.software == 'OpenSSH_7.3' @@ -27,12 +28,12 @@ class TestBanner(object): assert str(b) == 'SSH-1.5-Cisco-1.25' def test_invalid_banners(self): - b = lambda x: self.ssh.Banner.parse(x) + b = lambda x: self.ssh.Banner.parse(x) # noqa assert b('Something') is None assert b('SSH-XXX-OpenSSH_7.3') is None def test_banners_with_spaces(self): - b = lambda x: self.ssh.Banner.parse(x) + b = lambda x: self.ssh.Banner.parse(x) # noqa s = 'SSH-2.0-OpenSSH_4.3p2' assert str(b('SSH-2.0-OpenSSH_4.3p2 ')) == s assert str(b('SSH-2.0- OpenSSH_4.3p2')) == s @@ -43,7 +44,7 @@ class TestBanner(object): assert str(b('SSH-2.0- OpenSSH_4.3p2 Debian-9etch3 on i686-pc-linux-gnu ')) == s def test_banners_without_software(self): - b = lambda x: self.ssh.Banner.parse(x) + b = lambda x: self.ssh.Banner.parse(x) # noqa assert b('SSH-2.0').protocol == (2, 0) assert b('SSH-2.0').software is None assert b('SSH-2.0').comments is None @@ -54,13 +55,13 @@ class TestBanner(object): assert str(b('SSH-2.0-')) == 'SSH-2.0-' def test_banners_with_comments(self): - b = lambda x: self.ssh.Banner.parse(x) + b = lambda x: self.ssh.Banner.parse(x) # noqa assert repr(b('SSH-2.0-OpenSSH_7.2p2 Ubuntu-1')) == '' assert repr(b('SSH-1.99-OpenSSH_3.4p1 Debian 1:3.4p1-1.woody.3')) == '' assert repr(b('SSH-1.5-1.3.7 F-SECURE SSH')) == '' def test_banners_with_multiple_protocols(self): - b = lambda x: self.ssh.Banner.parse(x) + b = lambda x: self.ssh.Banner.parse(x) # noqa assert str(b('SSH-1.99-SSH-1.99-OpenSSH_3.6.1p2')) == 'SSH-1.99-OpenSSH_3.6.1p2' assert str(b('SSH-2.0-SSH-2.0-OpenSSH_4.3p2 Debian-9')) == 'SSH-2.0-OpenSSH_4.3p2 Debian-9' assert str(b('SSH-1.99-SSH-2.0-dropbear_0.5')) == 'SSH-1.99-dropbear_0.5' diff --git a/test/test_buffer.py b/test/test_buffer.py index e0be311..1e457bc 100644 --- a/test/test_buffer.py +++ b/test/test_buffer.py @@ -1,9 +1,10 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import pytest import re +import pytest +# pylint: disable=attribute-defined-outside-init,bad-whitespace class TestBuffer(object): @pytest.fixture(autouse=True) def init(self, ssh_audit): @@ -11,7 +12,8 @@ class TestBuffer(object): self.wbuf = ssh_audit.WriteBuf self.utf8rchar = b'\xef\xbf\xbd' - def _b(self, v): + @classmethod + def _b(cls, v): v = re.sub(r'\s', '', v) data = [int(v[i * 2:i * 2 + 2], 16) for i in range(len(v) // 2)] return bytes(bytearray(data)) @@ -26,8 +28,8 @@ class TestBuffer(object): assert r.unread_len == 0 def test_byte(self): - w = lambda x: self.wbuf().write_byte(x).write_flush() - r = lambda x: self.rbuf(x).read_byte() + w = lambda x: self.wbuf().write_byte(x).write_flush() # noqa + r = lambda x: self.rbuf(x).read_byte() # noqa tc = [(0x00, '00'), (0x01, '01'), (0x10, '10'), @@ -37,8 +39,8 @@ class TestBuffer(object): assert r(self._b(p[1])) == p[0] def test_bool(self): - w = lambda x: self.wbuf().write_bool(x).write_flush() - r = lambda x: self.rbuf(x).read_bool() + w = lambda x: self.wbuf().write_bool(x).write_flush() # noqa + r = lambda x: self.rbuf(x).read_bool() # noqa tc = [(True, '01'), (False, '00')] for p in tc: @@ -46,8 +48,8 @@ class TestBuffer(object): assert r(self._b(p[1])) == p[0] def test_int(self): - w = lambda x: self.wbuf().write_int(x).write_flush() - r = lambda x: self.rbuf(x).read_int() + w = lambda x: self.wbuf().write_int(x).write_flush() # noqa + r = lambda x: self.rbuf(x).read_int() # noqa tc = [(0x00, '00 00 00 00'), (0x01, '00 00 00 01'), (0xabcd, '00 00 ab cd'), @@ -57,8 +59,8 @@ class TestBuffer(object): assert r(self._b(p[1])) == p[0] def test_string(self): - w = lambda x: self.wbuf().write_string(x).write_flush() - r = lambda x: self.rbuf(x).read_string() + w = lambda x: self.wbuf().write_string(x).write_flush() # noqa + r = lambda x: self.rbuf(x).read_string() # noqa tc = [(u'abc1', '00 00 00 04 61 62 63 31'), (b'abc2', '00 00 00 04 61 62 63 32')] for p in tc: @@ -69,34 +71,35 @@ class TestBuffer(object): assert r(self._b(p[1])) == v def test_list(self): - w = lambda x: self.wbuf().write_list(x).write_flush() - r = lambda x: self.rbuf(x).read_list() + w = lambda x: self.wbuf().write_list(x).write_flush() # noqa + r = lambda x: self.rbuf(x).read_list() # noqa tc = [(['d', 'ef', 'ault'], '00 00 00 09 64 2c 65 66 2c 61 75 6c 74')] for p in tc: assert w(p[0]) == self._b(p[1]) assert r(self._b(p[1])) == p[0] def test_list_nonutf8(self): - r = lambda x: self.rbuf(x).read_list() + r = lambda x: self.rbuf(x).read_list() # noqa src = self._b('00 00 00 04 de ad be ef') dst = [(b'\xde\xad' + self.utf8rchar + self.utf8rchar).decode('utf-8')] assert r(src) == dst def test_line(self): - w = lambda x: self.wbuf().write_line(x).write_flush() - r = lambda x: self.rbuf(x).read_line() + w = lambda x: self.wbuf().write_line(x).write_flush() # noqa + r = lambda x: self.rbuf(x).read_line() # noqa tc = [(u'example line', '65 78 61 6d 70 6c 65 20 6c 69 6e 65 0d 0a')] for p in tc: assert w(p[0]) == self._b(p[1]) assert r(self._b(p[1])) == p[0] def test_line_nonutf8(self): - r = lambda x: self.rbuf(x).read_line() + r = lambda x: self.rbuf(x).read_line() # noqa src = self._b('de ad be af') dst = (b'\xde\xad' + self.utf8rchar + self.utf8rchar).decode('utf-8') assert r(src) == dst def test_bitlen(self): + # pylint: disable=protected-access class Py26Int(int): def bit_length(self): raise AttributeError @@ -104,8 +107,8 @@ class TestBuffer(object): assert self.wbuf._bitlength(Py26Int(42)) == 6 def test_mpint1(self): - mpint1w = lambda x: self.wbuf().write_mpint1(x).write_flush() - mpint1r = lambda x: self.rbuf(x).read_mpint1() + mpint1w = lambda x: self.wbuf().write_mpint1(x).write_flush() # noqa + mpint1r = lambda x: self.rbuf(x).read_mpint1() # noqa tc = [(0x0, '00 00'), (0x1234, '00 0d 12 34'), (0x12345, '00 11 01 23 45'), @@ -115,8 +118,8 @@ class TestBuffer(object): assert mpint1r(self._b(p[1])) == p[0] def test_mpint2(self): - mpint2w = lambda x: self.wbuf().write_mpint2(x).write_flush() - mpint2r = lambda x: self.rbuf(x).read_mpint2() + mpint2w = lambda x: self.wbuf().write_mpint2(x).write_flush() # noqa + mpint2r = lambda x: self.rbuf(x).read_mpint2() # noqa tc = [(0x0, '00 00 00 00'), (0x80, '00 00 00 02 00 80'), (0x9a378f9b2e332a7, '00 00 00 08 09 a3 78 f9 b2 e3 32 a7'), diff --git a/test/test_errors.py b/test/test_errors.py index abdbebe..ad35a54 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -1,8 +1,10 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import pytest, socket +import socket +import pytest +# pylint: disable=attribute-defined-outside-init class TestErrors(object): @pytest.fixture(autouse=True) def init(self, ssh_audit): diff --git a/test/test_output.py b/test/test_output.py index a50de38..74b2c19 100644 --- a/test/test_output.py +++ b/test/test_output.py @@ -1,9 +1,10 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- from __future__ import print_function -import pytest, io, sys +import pytest +# pylint: disable=attribute-defined-outside-init class TestOutput(object): @pytest.fixture(autouse=True) def init(self, ssh_audit): @@ -23,7 +24,7 @@ class TestOutput(object): def test_output_buffer_no_flush(self, output_spy): output_spy.begin() - with self.OutputBuffer() as obuf: + with self.OutputBuffer(): print(u'abc') assert output_spy.flush() == [] diff --git a/test/test_software.py b/test/test_software.py index 20eca18..4b7ca75 100644 --- a/test/test_software.py +++ b/test/test_software.py @@ -3,19 +3,21 @@ import pytest +# pylint: disable=line-too-long,attribute-defined-outside-init class TestSoftware(object): @pytest.fixture(autouse=True) def init(self, ssh_audit): self.ssh = ssh_audit.SSH def test_unknown_software(self): - ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) + ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) # noqa assert ps('SSH-1.5') is None assert ps('SSH-1.99-AlfaMegaServer') is None assert ps('SSH-2.0-BetaMegaServer 0.0.1') is None def test_openssh_software(self): - ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) + # pylint: disable=too-many-statements + ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) # noqa # common s = ps('SSH-2.0-OpenSSH_7.3') assert s.vendor is None @@ -102,7 +104,7 @@ class TestSoftware(object): assert repr(s) == '' def test_dropbear_software(self): - ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) + ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) # noqa # common s = ps('SSH-2.0-dropbear_2016.74') assert s.vendor is None @@ -153,7 +155,7 @@ class TestSoftware(object): assert repr(s) == '' def test_libssh_software(self): - ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) + ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) # noqa # common s = ps('SSH-2.0-libssh-0.2') assert s.vendor is None @@ -179,7 +181,7 @@ class TestSoftware(object): assert repr(s) == '' def test_romsshell_software(self): - ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) + ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) # noqa # common s = ps('SSH-2.0-RomSShell_5.40') assert s.vendor == 'Allegro Software' @@ -194,7 +196,7 @@ class TestSoftware(object): assert repr(s) == '' def test_hp_ilo_software(self): - ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) + ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) # noqa # common s = ps('SSH-2.0-mpSSH_0.2.1') assert s.vendor == 'HP' @@ -209,7 +211,7 @@ class TestSoftware(object): assert repr(s) == '' def test_cisco_software(self): - ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) + ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) # noqa # common s = ps('SSH-1.5-Cisco-1.25') assert s.vendor == 'Cisco' @@ -224,7 +226,7 @@ class TestSoftware(object): assert repr(s) == '' def test_sofware_os(self): - ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) + ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) # noqa # unknown s = ps('SSH-2.0-OpenSSH_3.7.1 MegaOperatingSystem 123') assert s.os is None diff --git a/test/test_version_compare.py b/test/test_version_compare.py index 2f74310..d3f8554 100644 --- a/test/test_version_compare.py +++ b/test/test_version_compare.py @@ -3,6 +3,7 @@ import pytest +# pylint: disable=attribute-defined-outside-init class TestVersionCompare(object): @pytest.fixture(autouse=True) def init(self, ssh_audit): From 8018209dd1a3126be070c18a00a59eb9d27341a9 Mon Sep 17 00:00:00 2001 From: Andrew Murray Date: Wed, 26 Oct 2016 05:52:58 +1100 Subject: [PATCH 26/28] Fixed typos --- README.md | 2 +- test/test_software.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7830cc5..68b9b78 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ usage: ssh-audit.py [-bnv] [-l ] - implement full SSH1 support with fingerprint information - automatically fallback to SSH1 on protocol mismatch - add new options to force SSH1 or SSH2 (both allowed by default) - - parse banner information and convert it to specific sofware and OS version + - parse banner information and convert it to specific software and OS version - do not use padding in batch mode - several fixes (Cisco sshd, rare hangs, error handling, etc) diff --git a/test/test_software.py b/test/test_software.py index 4b7ca75..141ffec 100644 --- a/test/test_software.py +++ b/test/test_software.py @@ -225,7 +225,7 @@ class TestSoftware(object): assert s.display(False) == str(s) assert repr(s) == '' - def test_sofware_os(self): + def test_software_os(self): ps = lambda x: self.ssh.Software.parse(self.ssh.Banner.parse(x)) # noqa # unknown s = ps('SSH-2.0-OpenSSH_3.7.1 MegaOperatingSystem 123') From 66b9e079a8b35f990df6a64f6be4d9924438fe9f Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Wed, 26 Oct 2016 18:33:00 +0300 Subject: [PATCH 27/28] Implement new options (-4/--ipv4, -6/--ipv6, -p/--port ). By default both IPv4 and IPv6 is supported and order of precedence depends on OS. By using -46, IPv4 is prefered, but by using -64, IPv6 is preferd. For now the old way how to specify port (host:port) has been kept intact. --- ssh-audit.py | 165 +++++++++++++++++++++++++++++++++-------- test/conftest.py | 34 +++++++-- test/test_auditconf.py | 77 ++++++++++++++++++- 3 files changed, 237 insertions(+), 39 deletions(-) diff --git a/ssh-audit.py b/ssh-audit.py index c1dfaf2..9e11be0 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -28,21 +28,22 @@ import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 VERSION = 'v1.6.1.dev' -if sys.version_info >= (3,): +if sys.version_info >= (3,): # pragma: nocover StringIO, BytesIO = io.StringIO, io.BytesIO text_type = str binary_type = bytes -else: +else: # pragma: nocover import StringIO as _StringIO # pylint: disable=import-error StringIO = BytesIO = _StringIO.StringIO text_type = unicode # pylint: disable=undefined-variable binary_type = str -try: +try: # pragma: nocover # pylint: disable=unused-import - from typing import List, Tuple, Optional, Callable, Union, Any + from typing import List, Set, Sequence, Tuple, Iterable + from typing import Callable, Optional, Union, Any except ImportError: # pragma: nocover pass -try: +try: # pragma: nocover from colorama import init as colorama_init colorama_init() # pragma: nocover except ImportError: # pragma: nocover @@ -53,13 +54,16 @@ def usage(err=None): # type: (Optional[str]) -> None uout = Output() p = os.path.basename(sys.argv[0]) - uout.head('# {0} {1}, moo@arthepsy.eu'.format(p, VERSION)) + uout.head('# {0} {1}, moo@arthepsy.eu\n'.format(p, VERSION)) if err is not None: uout.fail('\n' + err) - uout.info('\nusage: {0} [-12bnv] [-l ] \n'.format(p)) + uout.info('usage: {0} [-1246pbnvl] \n'.format(p)) uout.info(' -h, --help print this help') uout.info(' -1, --ssh1 force ssh version 1 only') uout.info(' -2, --ssh2 force ssh version 2 only') + uout.info(' -4, --ipv4 enable IPv4 (order of precedence)') + uout.info(' -6, --ipv6 enable IPv6 (order of precedence)') + uout.info(' -p, --port= port to connect') uout.info(' -b, --batch batch output') uout.info(' -n, --no-colors disable colors') uout.info(' -v, --verbose verbose output') @@ -69,6 +73,7 @@ def usage(err=None): class AuditConf(object): + # pylint: disable=too-many-instance-attributes def __init__(self, host=None, port=22): # type: (Optional[str], int) -> None self.host = host @@ -79,12 +84,35 @@ class AuditConf(object): self.colors = True self.verbose = False self.minlevel = 'info' + self.ipvo = () # type: Sequence[int] + self.ipv4 = False + self.ipv6 = False def __setattr__(self, name, value): - # type: (str, Union[str, int, bool]) -> None + # type: (str, Union[str, int, bool, Sequence[int]]) -> None valid = False if name in ['ssh1', 'ssh2', 'batch', 'colors', 'verbose']: valid, value = True, True if value else False + elif name in ['ipv4', 'ipv6']: + valid = False + value = True if value else False + ipv = 4 if name == 'ipv4' else 6 + if value: + value = tuple(list(self.ipvo) + [ipv]) + else: + if len(self.ipvo) == 0: + value = (6,) if ipv == 4 else (4,) + else: + value = tuple(filter(lambda x: x != ipv, self.ipvo)) + self.__setattr__('ipvo', value) + elif name == 'ipvo': + if isinstance(value, (tuple, list)): + uniq_value = utils.unique_seq(value) + value = tuple(filter(lambda x: x in (4, 6), uniq_value)) + valid = True + ipv_both = len(value) == 0 + object.__setattr__(self, 'ipv4', ipv_both or 4 in value) + object.__setattr__(self, 'ipv6', ipv_both or 6 in value) elif name == 'port': valid, port = True, utils.parse_int(value) if port < 1 or port > 65535: @@ -105,13 +133,14 @@ class AuditConf(object): # pylint: disable=too-many-branches aconf = cls() try: - sopts = 'h12bnvl:' - lopts = ['help', 'ssh1', 'ssh2', 'batch', - 'no-colors', 'verbose', 'level='] + sopts = 'h1246p:bnvl:' + lopts = ['help', 'ssh1', 'ssh2', 'ipv4', 'ipv6', 'port', + 'batch', 'no-colors', 'verbose', 'level='] opts, args = getopt.getopt(args, sopts, lopts) except getopt.GetoptError as err: usage_cb(str(err)) aconf.ssh1, aconf.ssh2 = False, False + oport = None for o, a in opts: if o in ('-h', '--help'): usage_cb() @@ -119,6 +148,12 @@ class AuditConf(object): aconf.ssh1 = True elif o in ('-2', '--ssh2'): aconf.ssh2 = True + elif o in ('-4', '--ipv4'): + aconf.ipv4 = True + elif o in ('-6', '--ipv6'): + aconf.ipv6 = True + elif o in ('-p', '--port'): + oport = a elif o in ('-b', '--batch'): aconf.batch = True aconf.verbose = True @@ -132,14 +167,20 @@ class AuditConf(object): aconf.minlevel = a if len(args) == 0: usage_cb() - s = args[0].split(':') - host, port = s[0].strip(), 22 - if len(s) > 1: - port = utils.parse_int(s[1]) + if oport is not None: + host = args[0] + port = utils.parse_int(oport) + else: + s = args[0].split(':') + host = s[0].strip() + if len(s) == 2: + oport, port = s[1], utils.parse_int(s[1]) + else: + oport, port = '22', 22 if not host: usage_cb('host is empty') if port <= 0 or port > 65535: - usage_cb('port {0} is not valid'.format(s[1])) + usage_cb('port {0} is not valid'.format(oport)) aconf.host = host aconf.port = port if not (aconf.ssh1 or aconf.ssh2): @@ -1038,24 +1079,67 @@ class SSH(object): # pylint: disable=too-few-public-methods SM_BANNER_SENT = 1 - def __init__(self, host, port, cto=3.0, rto=5.0): - # type: (str, int, float, float) -> None + def __init__(self, host, port): + # type: (str, int) -> None + super(SSH.Socket, self).__init__() self.__block_size = 8 self.__state = 0 self.__header = [] # type: List[text_type] self.__banner = None # type: Optional[SSH.Banner] - super(SSH.Socket, self).__init__() - try: - self.__sock = socket.create_connection((host, port), cto) - self.__sock.settimeout(rto) - except Exception as e: # pylint: disable=broad-except - out.fail('[fail] {0}'.format(e)) - sys.exit(1) + self.__host = host + self.__port = port + self.__sock = None # type: socket.socket def __enter__(self): # type: () -> SSH.Socket return self + def _resolve(self, ipvo): + # type: (Sequence[int]) -> Iterable[Tuple[int, Tuple[Any, ...]]] + ipvo = tuple(filter(lambda x: x in (4, 6), utils.unique_seq(ipvo))) + ipvo_len = len(ipvo) + prefer_ipvo = ipvo_len > 0 + prefer_ipv4 = prefer_ipvo and ipvo[0] == 4 + if len(ipvo) == 1: + family = {4: socket.AF_INET, 6: socket.AF_INET6}.get(ipvo[0]) + else: + family = socket.AF_UNSPEC + try: + stype = socket.SOCK_STREAM + r = socket.getaddrinfo(self.__host, self.__port, family, stype) + if prefer_ipvo: + r = sorted(r, key=lambda x: x[0], reverse=not prefer_ipv4) + check = any(stype == rline[2] for rline in r) + for (af, socktype, proto, canonname, addr) in r: + if not check or socktype == socket.SOCK_STREAM: + yield (af, addr) + except socket.error as e: + out.fail('[exception] {0}'.format(e)) + sys.exit(1) + + def connect(self, ipvo=(), cto=3.0, rto=5.0): + # type: (Sequence[int], float, float) -> None + err = None + for (af, addr) in self._resolve(ipvo): + s = None + try: + s = socket.socket(af, socket.SOCK_STREAM) + s.settimeout(cto) + s.connect(addr) + s.settimeout(rto) + self.__sock = s + return + except socket.error as e: + err = e + self._close_socket(s) + if err is None: + errm = 'host {0} has no DNS records'.format(self.__host) + else: + errt = (self.__host, self.__port, err) + errm = 'cannot connect to {0} port {1}: {2}'.format(*errt) + out.fail('[exception] {0}'.format(errm)) + sys.exit(1) + def get_banner(self, sshv=2): # type: (int) -> Tuple[Optional[SSH.Banner], List[text_type]] banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0') @@ -1188,6 +1272,15 @@ class SSH(object): # pylint: disable=too-few-public-methods data = struct.pack('>Ib', plen, padding) + payload + pad_bytes return self.send(data) + def _close_socket(self, s): + # type: (Optional[socket.socket]) -> None + try: + if s is not None: + s.shutdown(socket.SHUT_RDWR) + s.close() + except: # pylint: disable=bare-except + pass + def __del__(self): # type: () -> None self.__cleanup() @@ -1198,11 +1291,7 @@ class SSH(object): # pylint: disable=too-few-public-methods def __cleanup(self): # type: () -> None - try: - self.__sock.shutdown(socket.SHUT_RDWR) - self.__sock.close() - except: # pylint: disable=bare-except - pass + self._close_socket(self.__sock) class KexDH(object): @@ -1847,6 +1936,21 @@ class Utils(object): return cls.to_ntext(v.encode('ascii', errors)) raise cls._type_err(v, 'ascii') + @classmethod + def unique_seq(cls, seq): + # type: (Sequence[Any]) -> Sequence[Any] + seen = set() # type: Set[Any] + + def _seen_add(x): + # type: (Any) -> bool + seen.add(x) + return False + + if isinstance(seq, tuple): + return tuple(x for x in seq if x not in seen and not _seen_add(x)) + else: + return [x for x in seq if x not in seen and not _seen_add(x)] + @staticmethod def parse_int(v): # type: (Any) -> int @@ -1863,6 +1967,7 @@ def audit(aconf, sshv=None): out.verbose = aconf.verbose out.minlevel = aconf.minlevel s = SSH.Socket(aconf.host, aconf.port) + s.connect(aconf.ipvo) if sshv is None: sshv = 2 if aconf.ssh2 else 1 err = None diff --git a/test/conftest.py b/test/conftest.py index 28ab4ef..524c0fa 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,10 +1,14 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import pytest, os, sys, io, socket +import os +import io +import sys +import socket +import pytest if sys.version_info[0] == 2: - import StringIO + import StringIO # pylint: disable=import-error StringIO = StringIO.StringIO else: StringIO = io.StringIO @@ -17,6 +21,7 @@ def ssh_audit(): return __import__('ssh-audit') +# pylint: disable=attribute-defined-outside-init class _OutputSpy(list): def begin(self): self.__out = StringIO() @@ -50,11 +55,14 @@ class _VirtualSocket(object): if method_error: raise method_error - def _connect(self, address): + def connect(self, address): + return self._connect(address, False) + + def _connect(self, address, ret=True): self.peer_address = address self._connected = True self._check_err('connect') - return self + return self if ret else None def settimeout(self, timeout): self.timeout = timeout @@ -77,6 +85,7 @@ class _VirtualSocket(object): pass def accept(self): + # pylint: disable=protected-access conn = _VirtualSocket() conn.sock_address = self.sock_address conn.peer_address = ('127.0.0.1', 0) @@ -84,6 +93,7 @@ class _VirtualSocket(object): return conn, conn.peer_address def recv(self, bufsize, flags=0): + # pylint: disable=unused-argument if not self._connected: raise socket.error(54, 'Connection reset by peer') if not len(self.rdata) > 0: @@ -103,10 +113,18 @@ class _VirtualSocket(object): @pytest.fixture() def virtual_socket(monkeypatch): vsocket = _VirtualSocket() - def _c(address): - return vsocket._connect(address) + + # pylint: disable=unused-argument + def _socket(family=socket.AF_INET, + socktype=socket.SOCK_STREAM, + proto=0, + fileno=None): + return vsocket + def _cc(address, timeout=0, source_address=None): - return vsocket._connect(address) + # pylint: disable=protected-access + return vsocket._connect(address, True) + monkeypatch.setattr(socket, 'create_connection', _cc) - monkeypatch.setattr(socket.socket, 'connect', _c) + monkeypatch.setattr(socket, 'socket', _socket) return vsocket diff --git a/test/test_auditconf.py b/test/test_auditconf.py index 6c23c2a..3472c42 100644 --- a/test/test_auditconf.py +++ b/test/test_auditconf.py @@ -20,7 +20,10 @@ class TestAuditConf(object): 'batch': False, 'colors': True, 'verbose': False, - 'minlevel': 'info' + 'minlevel': 'info', + 'ipv4': True, + 'ipv6': True, + 'ipvo': () } for k, v in kwargs.items(): options[k] = v @@ -32,6 +35,9 @@ class TestAuditConf(object): assert conf.colors is options['colors'] assert conf.verbose is options['verbose'] assert conf.minlevel == options['minlevel'] + assert conf.ipv4 == options['ipv4'] + assert conf.ipv6 == options['ipv6'] + assert conf.ipvo == options['ipvo'] def test_audit_conf_defaults(self): conf = self.AuditConf() @@ -57,6 +63,58 @@ class TestAuditConf(object): conf.port = port excinfo.match(r'.*invalid port.*') + def test_audit_conf_ipvo(self): + # ipv4-only + conf = self.AuditConf() + conf.ipv4 = True + assert conf.ipv4 is True + assert conf.ipv6 is False + assert conf.ipvo == (4,) + # ipv6-only + conf = self.AuditConf() + conf.ipv6 = True + assert conf.ipv4 is False + assert conf.ipv6 is True + assert conf.ipvo == (6,) + # ipv4-only (by removing ipv6) + conf = self.AuditConf() + conf.ipv6 = False + assert conf.ipv4 is True + assert conf.ipv6 is False + assert conf.ipvo == (4, ) + # ipv6-only (by removing ipv4) + conf = self.AuditConf() + conf.ipv4 = False + assert conf.ipv4 is False + assert conf.ipv6 is True + assert conf.ipvo == (6, ) + # ipv4-preferred + conf = self.AuditConf() + conf.ipv4 = True + conf.ipv6 = True + assert conf.ipv4 is True + assert conf.ipv6 is True + assert conf.ipvo == (4, 6) + # ipv6-preferred + conf = self.AuditConf() + conf.ipv6 = True + conf.ipv4 = True + assert conf.ipv4 is True + assert conf.ipv6 is True + assert conf.ipvo == (6, 4) + # ipvo empty + conf = self.AuditConf() + conf.ipvo = () + assert conf.ipv4 is True + assert conf.ipv6 is True + assert conf.ipvo == () + # ipvo validation + conf = self.AuditConf() + conf.ipvo = (1, 2, 3, 4, 5, 6) + assert conf.ipvo == (4, 6) + conf.ipvo = (4, 4, 4, 6, 6) + assert conf.ipvo == (4, 6) + def test_audit_conf_minlevel(self): conf = self.AuditConf() for level in ['info', 'warn', 'fail']: @@ -68,6 +126,7 @@ class TestAuditConf(object): excinfo.match(r'.*invalid level.*') def test_audit_conf_cmdline(self): + # pylint: disable=too-many-statements c = lambda x: self.AuditConf.from_cmdline(x.split(), self.usage) # noqa with pytest.raises(SystemExit): conf = c('') @@ -87,20 +146,36 @@ class TestAuditConf(object): self._test_conf(conf, host='github.com') conf = c('localhost:2222') self._test_conf(conf, host='localhost', port=2222) + conf = c('-p 2222 localhost') + self._test_conf(conf, host='localhost', port=2222) with pytest.raises(SystemExit): conf = c('localhost:') with pytest.raises(SystemExit): conf = c('localhost:abc') + with pytest.raises(SystemExit): + conf = c('-p abc localhost') with pytest.raises(SystemExit): conf = c('localhost:-22') + with pytest.raises(SystemExit): + conf = c('-p -22 localhost') with pytest.raises(SystemExit): conf = c('localhost:99999') + with pytest.raises(SystemExit): + conf = c('-p 99999 localhost') conf = c('-1 localhost') self._test_conf(conf, host='localhost', ssh1=True, ssh2=False) conf = c('-2 localhost') self._test_conf(conf, host='localhost', ssh1=False, ssh2=True) conf = c('-12 localhost') self._test_conf(conf, host='localhost', ssh1=True, ssh2=True) + conf = c('-4 localhost') + self._test_conf(conf, host='localhost', ipv4=True, ipv6=False, ipvo=(4,)) + conf = c('-6 localhost') + self._test_conf(conf, host='localhost', ipv4=False, ipv6=True, ipvo=(6,)) + conf = c('-46 localhost') + self._test_conf(conf, host='localhost', ipv4=True, ipv6=True, ipvo=(4, 6)) + conf = c('-64 localhost') + self._test_conf(conf, host='localhost', ipv4=True, ipv6=True, ipvo=(6, 4)) conf = c('-b localhost') self._test_conf(conf, host='localhost', batch=True, verbose=True) conf = c('-n localhost') From 4fbd339c547c3f81a3d4f16e203fb1ba9d3d5644 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Wed, 26 Oct 2016 18:56:38 +0300 Subject: [PATCH 28/28] Document changes and add coverage badge. --- README.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 68b9b78..53367f7 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # ssh-audit -[![build status](https://api.travis-ci.org/arthepsy/ssh-audit.svg)](https://travis-ci.org/arthepsy/ssh-audit) +[![build status](https://api.travis-ci.org/arthepsy/ssh-audit.svg)](https://travis-ci.org/arthepsy/ssh-audit) +[![coverage status](https://coveralls.io/repos/github/arthepsy/ssh-audit/badge.svg)](https://coveralls.io/github/arthepsy/ssh-audit) **ssh-audit** is a tool for ssh server auditing. ## Features @@ -15,16 +16,20 @@ ## Usage ``` -usage: ssh-audit.py [-bnv] [-l ] +usage: ssh-audit.py [-1246pbnvl] -1, --ssh1 force ssh version 1 only -2, --ssh2 force ssh version 2 only + -4, --ipv4 enable IPv4 (order of precedence) + -6, --ipv6 enable IPv6 (order of precedence) + -p, --port= port to connect -b, --batch batch output -n, --no-colors disable colors -v, --verbose verbose output -l, --level= minimum output level (info|warn|fail) ``` +* if both IPv4 and IPv6 are used, order of precedence can be set by using either `-46` or `-64`. * batch flag `-b` will output sections without header and without empty lines (implies verbose flag). * verbose flag `-v` will prefix each line with section type and algorithm name. @@ -32,6 +37,13 @@ usage: ssh-audit.py [-bnv] [-l ] ![screenshot](https://cloud.githubusercontent.com/assets/7356025/19233757/3e09b168-8ef0-11e6-91b4-e880bacd0b8a.png) ## ChangeLog +### v1.x.x (2016-xx-xx) + - implement options to allow specify IPv4/IPv6 usage and order of precedence + - implement option to specify remote port (old behavior kept for compatibility) + - add colors support for Microsoft Windows via optional colorama dependency + - fix encoding and decoding issues, add tests, do not crash on encoding errors + - use mypy-lang for static type checking and verify all code + ### v1.6.0 (2016-10-14) - implement algorithm recommendations section (based on recognized software) - implement full libssh support (version history, algorithms, security, etc)