diff --git a/ssh-audit.py b/ssh-audit.py index c1dfaf2..9e11be0 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -28,21 +28,22 @@ import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 VERSION = 'v1.6.1.dev' -if sys.version_info >= (3,): +if sys.version_info >= (3,): # pragma: nocover StringIO, BytesIO = io.StringIO, io.BytesIO text_type = str binary_type = bytes -else: +else: # pragma: nocover import StringIO as _StringIO # pylint: disable=import-error StringIO = BytesIO = _StringIO.StringIO text_type = unicode # pylint: disable=undefined-variable binary_type = str -try: +try: # pragma: nocover # pylint: disable=unused-import - from typing import List, Tuple, Optional, Callable, Union, Any + from typing import List, Set, Sequence, Tuple, Iterable + from typing import Callable, Optional, Union, Any except ImportError: # pragma: nocover pass -try: +try: # pragma: nocover from colorama import init as colorama_init colorama_init() # pragma: nocover except ImportError: # pragma: nocover @@ -53,13 +54,16 @@ def usage(err=None): # type: (Optional[str]) -> None uout = Output() p = os.path.basename(sys.argv[0]) - uout.head('# {0} {1}, moo@arthepsy.eu'.format(p, VERSION)) + uout.head('# {0} {1}, moo@arthepsy.eu\n'.format(p, VERSION)) if err is not None: uout.fail('\n' + err) - uout.info('\nusage: {0} [-12bnv] [-l ] \n'.format(p)) + uout.info('usage: {0} [-1246pbnvl] \n'.format(p)) uout.info(' -h, --help print this help') uout.info(' -1, --ssh1 force ssh version 1 only') uout.info(' -2, --ssh2 force ssh version 2 only') + uout.info(' -4, --ipv4 enable IPv4 (order of precedence)') + uout.info(' -6, --ipv6 enable IPv6 (order of precedence)') + uout.info(' -p, --port= port to connect') uout.info(' -b, --batch batch output') uout.info(' -n, --no-colors disable colors') uout.info(' -v, --verbose verbose output') @@ -69,6 +73,7 @@ def usage(err=None): class AuditConf(object): + # pylint: disable=too-many-instance-attributes def __init__(self, host=None, port=22): # type: (Optional[str], int) -> None self.host = host @@ -79,12 +84,35 @@ class AuditConf(object): self.colors = True self.verbose = False self.minlevel = 'info' + self.ipvo = () # type: Sequence[int] + self.ipv4 = False + self.ipv6 = False def __setattr__(self, name, value): - # type: (str, Union[str, int, bool]) -> None + # type: (str, Union[str, int, bool, Sequence[int]]) -> None valid = False if name in ['ssh1', 'ssh2', 'batch', 'colors', 'verbose']: valid, value = True, True if value else False + elif name in ['ipv4', 'ipv6']: + valid = False + value = True if value else False + ipv = 4 if name == 'ipv4' else 6 + if value: + value = tuple(list(self.ipvo) + [ipv]) + else: + if len(self.ipvo) == 0: + value = (6,) if ipv == 4 else (4,) + else: + value = tuple(filter(lambda x: x != ipv, self.ipvo)) + self.__setattr__('ipvo', value) + elif name == 'ipvo': + if isinstance(value, (tuple, list)): + uniq_value = utils.unique_seq(value) + value = tuple(filter(lambda x: x in (4, 6), uniq_value)) + valid = True + ipv_both = len(value) == 0 + object.__setattr__(self, 'ipv4', ipv_both or 4 in value) + object.__setattr__(self, 'ipv6', ipv_both or 6 in value) elif name == 'port': valid, port = True, utils.parse_int(value) if port < 1 or port > 65535: @@ -105,13 +133,14 @@ class AuditConf(object): # pylint: disable=too-many-branches aconf = cls() try: - sopts = 'h12bnvl:' - lopts = ['help', 'ssh1', 'ssh2', 'batch', - 'no-colors', 'verbose', 'level='] + sopts = 'h1246p:bnvl:' + lopts = ['help', 'ssh1', 'ssh2', 'ipv4', 'ipv6', 'port', + 'batch', 'no-colors', 'verbose', 'level='] opts, args = getopt.getopt(args, sopts, lopts) except getopt.GetoptError as err: usage_cb(str(err)) aconf.ssh1, aconf.ssh2 = False, False + oport = None for o, a in opts: if o in ('-h', '--help'): usage_cb() @@ -119,6 +148,12 @@ class AuditConf(object): aconf.ssh1 = True elif o in ('-2', '--ssh2'): aconf.ssh2 = True + elif o in ('-4', '--ipv4'): + aconf.ipv4 = True + elif o in ('-6', '--ipv6'): + aconf.ipv6 = True + elif o in ('-p', '--port'): + oport = a elif o in ('-b', '--batch'): aconf.batch = True aconf.verbose = True @@ -132,14 +167,20 @@ class AuditConf(object): aconf.minlevel = a if len(args) == 0: usage_cb() - s = args[0].split(':') - host, port = s[0].strip(), 22 - if len(s) > 1: - port = utils.parse_int(s[1]) + if oport is not None: + host = args[0] + port = utils.parse_int(oport) + else: + s = args[0].split(':') + host = s[0].strip() + if len(s) == 2: + oport, port = s[1], utils.parse_int(s[1]) + else: + oport, port = '22', 22 if not host: usage_cb('host is empty') if port <= 0 or port > 65535: - usage_cb('port {0} is not valid'.format(s[1])) + usage_cb('port {0} is not valid'.format(oport)) aconf.host = host aconf.port = port if not (aconf.ssh1 or aconf.ssh2): @@ -1038,24 +1079,67 @@ class SSH(object): # pylint: disable=too-few-public-methods SM_BANNER_SENT = 1 - def __init__(self, host, port, cto=3.0, rto=5.0): - # type: (str, int, float, float) -> None + def __init__(self, host, port): + # type: (str, int) -> None + super(SSH.Socket, self).__init__() self.__block_size = 8 self.__state = 0 self.__header = [] # type: List[text_type] self.__banner = None # type: Optional[SSH.Banner] - super(SSH.Socket, self).__init__() - try: - self.__sock = socket.create_connection((host, port), cto) - self.__sock.settimeout(rto) - except Exception as e: # pylint: disable=broad-except - out.fail('[fail] {0}'.format(e)) - sys.exit(1) + self.__host = host + self.__port = port + self.__sock = None # type: socket.socket def __enter__(self): # type: () -> SSH.Socket return self + def _resolve(self, ipvo): + # type: (Sequence[int]) -> Iterable[Tuple[int, Tuple[Any, ...]]] + ipvo = tuple(filter(lambda x: x in (4, 6), utils.unique_seq(ipvo))) + ipvo_len = len(ipvo) + prefer_ipvo = ipvo_len > 0 + prefer_ipv4 = prefer_ipvo and ipvo[0] == 4 + if len(ipvo) == 1: + family = {4: socket.AF_INET, 6: socket.AF_INET6}.get(ipvo[0]) + else: + family = socket.AF_UNSPEC + try: + stype = socket.SOCK_STREAM + r = socket.getaddrinfo(self.__host, self.__port, family, stype) + if prefer_ipvo: + r = sorted(r, key=lambda x: x[0], reverse=not prefer_ipv4) + check = any(stype == rline[2] for rline in r) + for (af, socktype, proto, canonname, addr) in r: + if not check or socktype == socket.SOCK_STREAM: + yield (af, addr) + except socket.error as e: + out.fail('[exception] {0}'.format(e)) + sys.exit(1) + + def connect(self, ipvo=(), cto=3.0, rto=5.0): + # type: (Sequence[int], float, float) -> None + err = None + for (af, addr) in self._resolve(ipvo): + s = None + try: + s = socket.socket(af, socket.SOCK_STREAM) + s.settimeout(cto) + s.connect(addr) + s.settimeout(rto) + self.__sock = s + return + except socket.error as e: + err = e + self._close_socket(s) + if err is None: + errm = 'host {0} has no DNS records'.format(self.__host) + else: + errt = (self.__host, self.__port, err) + errm = 'cannot connect to {0} port {1}: {2}'.format(*errt) + out.fail('[exception] {0}'.format(errm)) + sys.exit(1) + def get_banner(self, sshv=2): # type: (int) -> Tuple[Optional[SSH.Banner], List[text_type]] banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0') @@ -1188,6 +1272,15 @@ class SSH(object): # pylint: disable=too-few-public-methods data = struct.pack('>Ib', plen, padding) + payload + pad_bytes return self.send(data) + def _close_socket(self, s): + # type: (Optional[socket.socket]) -> None + try: + if s is not None: + s.shutdown(socket.SHUT_RDWR) + s.close() + except: # pylint: disable=bare-except + pass + def __del__(self): # type: () -> None self.__cleanup() @@ -1198,11 +1291,7 @@ class SSH(object): # pylint: disable=too-few-public-methods def __cleanup(self): # type: () -> None - try: - self.__sock.shutdown(socket.SHUT_RDWR) - self.__sock.close() - except: # pylint: disable=bare-except - pass + self._close_socket(self.__sock) class KexDH(object): @@ -1847,6 +1936,21 @@ class Utils(object): return cls.to_ntext(v.encode('ascii', errors)) raise cls._type_err(v, 'ascii') + @classmethod + def unique_seq(cls, seq): + # type: (Sequence[Any]) -> Sequence[Any] + seen = set() # type: Set[Any] + + def _seen_add(x): + # type: (Any) -> bool + seen.add(x) + return False + + if isinstance(seq, tuple): + return tuple(x for x in seq if x not in seen and not _seen_add(x)) + else: + return [x for x in seq if x not in seen and not _seen_add(x)] + @staticmethod def parse_int(v): # type: (Any) -> int @@ -1863,6 +1967,7 @@ def audit(aconf, sshv=None): out.verbose = aconf.verbose out.minlevel = aconf.minlevel s = SSH.Socket(aconf.host, aconf.port) + s.connect(aconf.ipvo) if sshv is None: sshv = 2 if aconf.ssh2 else 1 err = None diff --git a/test/conftest.py b/test/conftest.py index 28ab4ef..524c0fa 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,10 +1,14 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import pytest, os, sys, io, socket +import os +import io +import sys +import socket +import pytest if sys.version_info[0] == 2: - import StringIO + import StringIO # pylint: disable=import-error StringIO = StringIO.StringIO else: StringIO = io.StringIO @@ -17,6 +21,7 @@ def ssh_audit(): return __import__('ssh-audit') +# pylint: disable=attribute-defined-outside-init class _OutputSpy(list): def begin(self): self.__out = StringIO() @@ -50,11 +55,14 @@ class _VirtualSocket(object): if method_error: raise method_error - def _connect(self, address): + def connect(self, address): + return self._connect(address, False) + + def _connect(self, address, ret=True): self.peer_address = address self._connected = True self._check_err('connect') - return self + return self if ret else None def settimeout(self, timeout): self.timeout = timeout @@ -77,6 +85,7 @@ class _VirtualSocket(object): pass def accept(self): + # pylint: disable=protected-access conn = _VirtualSocket() conn.sock_address = self.sock_address conn.peer_address = ('127.0.0.1', 0) @@ -84,6 +93,7 @@ class _VirtualSocket(object): return conn, conn.peer_address def recv(self, bufsize, flags=0): + # pylint: disable=unused-argument if not self._connected: raise socket.error(54, 'Connection reset by peer') if not len(self.rdata) > 0: @@ -103,10 +113,18 @@ class _VirtualSocket(object): @pytest.fixture() def virtual_socket(monkeypatch): vsocket = _VirtualSocket() - def _c(address): - return vsocket._connect(address) + + # pylint: disable=unused-argument + def _socket(family=socket.AF_INET, + socktype=socket.SOCK_STREAM, + proto=0, + fileno=None): + return vsocket + def _cc(address, timeout=0, source_address=None): - return vsocket._connect(address) + # pylint: disable=protected-access + return vsocket._connect(address, True) + monkeypatch.setattr(socket, 'create_connection', _cc) - monkeypatch.setattr(socket.socket, 'connect', _c) + monkeypatch.setattr(socket, 'socket', _socket) return vsocket diff --git a/test/test_auditconf.py b/test/test_auditconf.py index 6c23c2a..3472c42 100644 --- a/test/test_auditconf.py +++ b/test/test_auditconf.py @@ -20,7 +20,10 @@ class TestAuditConf(object): 'batch': False, 'colors': True, 'verbose': False, - 'minlevel': 'info' + 'minlevel': 'info', + 'ipv4': True, + 'ipv6': True, + 'ipvo': () } for k, v in kwargs.items(): options[k] = v @@ -32,6 +35,9 @@ class TestAuditConf(object): assert conf.colors is options['colors'] assert conf.verbose is options['verbose'] assert conf.minlevel == options['minlevel'] + assert conf.ipv4 == options['ipv4'] + assert conf.ipv6 == options['ipv6'] + assert conf.ipvo == options['ipvo'] def test_audit_conf_defaults(self): conf = self.AuditConf() @@ -57,6 +63,58 @@ class TestAuditConf(object): conf.port = port excinfo.match(r'.*invalid port.*') + def test_audit_conf_ipvo(self): + # ipv4-only + conf = self.AuditConf() + conf.ipv4 = True + assert conf.ipv4 is True + assert conf.ipv6 is False + assert conf.ipvo == (4,) + # ipv6-only + conf = self.AuditConf() + conf.ipv6 = True + assert conf.ipv4 is False + assert conf.ipv6 is True + assert conf.ipvo == (6,) + # ipv4-only (by removing ipv6) + conf = self.AuditConf() + conf.ipv6 = False + assert conf.ipv4 is True + assert conf.ipv6 is False + assert conf.ipvo == (4, ) + # ipv6-only (by removing ipv4) + conf = self.AuditConf() + conf.ipv4 = False + assert conf.ipv4 is False + assert conf.ipv6 is True + assert conf.ipvo == (6, ) + # ipv4-preferred + conf = self.AuditConf() + conf.ipv4 = True + conf.ipv6 = True + assert conf.ipv4 is True + assert conf.ipv6 is True + assert conf.ipvo == (4, 6) + # ipv6-preferred + conf = self.AuditConf() + conf.ipv6 = True + conf.ipv4 = True + assert conf.ipv4 is True + assert conf.ipv6 is True + assert conf.ipvo == (6, 4) + # ipvo empty + conf = self.AuditConf() + conf.ipvo = () + assert conf.ipv4 is True + assert conf.ipv6 is True + assert conf.ipvo == () + # ipvo validation + conf = self.AuditConf() + conf.ipvo = (1, 2, 3, 4, 5, 6) + assert conf.ipvo == (4, 6) + conf.ipvo = (4, 4, 4, 6, 6) + assert conf.ipvo == (4, 6) + def test_audit_conf_minlevel(self): conf = self.AuditConf() for level in ['info', 'warn', 'fail']: @@ -68,6 +126,7 @@ class TestAuditConf(object): excinfo.match(r'.*invalid level.*') def test_audit_conf_cmdline(self): + # pylint: disable=too-many-statements c = lambda x: self.AuditConf.from_cmdline(x.split(), self.usage) # noqa with pytest.raises(SystemExit): conf = c('') @@ -87,20 +146,36 @@ class TestAuditConf(object): self._test_conf(conf, host='github.com') conf = c('localhost:2222') self._test_conf(conf, host='localhost', port=2222) + conf = c('-p 2222 localhost') + self._test_conf(conf, host='localhost', port=2222) with pytest.raises(SystemExit): conf = c('localhost:') with pytest.raises(SystemExit): conf = c('localhost:abc') + with pytest.raises(SystemExit): + conf = c('-p abc localhost') with pytest.raises(SystemExit): conf = c('localhost:-22') + with pytest.raises(SystemExit): + conf = c('-p -22 localhost') with pytest.raises(SystemExit): conf = c('localhost:99999') + with pytest.raises(SystemExit): + conf = c('-p 99999 localhost') conf = c('-1 localhost') self._test_conf(conf, host='localhost', ssh1=True, ssh2=False) conf = c('-2 localhost') self._test_conf(conf, host='localhost', ssh1=False, ssh2=True) conf = c('-12 localhost') self._test_conf(conf, host='localhost', ssh1=True, ssh2=True) + conf = c('-4 localhost') + self._test_conf(conf, host='localhost', ipv4=True, ipv6=False, ipvo=(4,)) + conf = c('-6 localhost') + self._test_conf(conf, host='localhost', ipv4=False, ipv6=True, ipvo=(6,)) + conf = c('-46 localhost') + self._test_conf(conf, host='localhost', ipv4=True, ipv6=True, ipvo=(4, 6)) + conf = c('-64 localhost') + self._test_conf(conf, host='localhost', ipv4=True, ipv6=True, ipvo=(6, 4)) conf = c('-b localhost') self._test_conf(conf, host='localhost', batch=True, verbose=True) conf = c('-n localhost')