diff --git a/test/conftest.py b/test/conftest.py index 524c0fa..0bc4124 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -40,6 +40,41 @@ def output_spy(): return _OutputSpy() +class _VirtualGlobalSocket(object): + def __init__(self, vsocket): + self.vsocket = vsocket + self.addrinfodata = {} + + # pylint: disable=unused-argument + def create_connection(self, address, timeout=0, source_address=None): + # pylint: disable=protected-access + return self.vsocket._connect(address, True) + + # pylint: disable=unused-argument + def socket(self, + family=socket.AF_INET, + socktype=socket.SOCK_STREAM, + proto=0, + fileno=None): + return self.vsocket + + def getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0): + key = '{0}#{1}'.format(host, port) + if key in self.addrinfodata: + data = self.addrinfodata[key] + if isinstance(data, Exception): + raise data + return data + if host == 'localhost': + r = [] + if family in (0, socket.AF_INET): + r.append((socket.AF_INET, 1, 6, '', ('127.0.0.1', port))) + if family in (0, socket.AF_INET6): + r.append((socket.AF_INET6, 1, 6, '', ('::1', port))) + return r + return [] + + class _VirtualSocket(object): def __init__(self): self.sock_address = ('127.0.0.1', 0) @@ -49,6 +84,7 @@ class _VirtualSocket(object): self.rdata = [] self.sdata = [] self.errors = {} + self.gsock = _VirtualGlobalSocket(self) def _check_err(self, method): method_error = self.errors.get(method) @@ -113,18 +149,8 @@ class _VirtualSocket(object): @pytest.fixture() def virtual_socket(monkeypatch): vsocket = _VirtualSocket() - - # pylint: disable=unused-argument - def _socket(family=socket.AF_INET, - socktype=socket.SOCK_STREAM, - proto=0, - fileno=None): - return vsocket - - def _cc(address, timeout=0, source_address=None): - # pylint: disable=protected-access - return vsocket._connect(address, True) - - monkeypatch.setattr(socket, 'create_connection', _cc) - monkeypatch.setattr(socket, 'socket', _socket) + gsock = vsocket.gsock + monkeypatch.setattr(socket, 'create_connection', gsock.create_connection) + monkeypatch.setattr(socket, 'socket', gsock.socket) + monkeypatch.setattr(socket, 'getaddrinfo', gsock.getaddrinfo) return vsocket diff --git a/test/test_errors.py b/test/test_errors.py index 4f3d6cc..abf720e 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -30,6 +30,13 @@ class TestErrors(object): lines = spy.flush() return lines + def test_connection_unresolved(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.gsock.addrinfodata['localhost#22'] = [] + lines = self._audit(output_spy) + assert len(lines) == 1 + assert 'has no DNS records' in lines[-1] + def test_connection_refused(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.errors['connect'] = socket.error(errno.ECONNREFUSED, 'Connection refused') @@ -91,6 +98,7 @@ class TestErrors(object): def test_connection_closed_after_header(self, output_spy, virtual_socket): vsocket = virtual_socket vsocket.rdata.append(b'header line 1\n') + vsocket.rdata.append(b'\n') vsocket.rdata.append(b'header line 2\n') vsocket.rdata.append(socket.error(errno.ECONNRESET, 'Connection reset by peer')) lines = self._audit(output_spy) diff --git a/test/test_resolve.py b/test/test_resolve.py new file mode 100644 index 0000000..8fcddf6 --- /dev/null +++ b/test/test_resolve.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import socket +import pytest + + +# pylint: disable=attribute-defined-outside-init,protected-access +class TestResolve(object): + @pytest.fixture(autouse=True) + def init(self, ssh_audit): + self.AuditConf = ssh_audit.AuditConf + self.audit = ssh_audit.audit + self.ssh = ssh_audit.SSH + + def _conf(self): + conf = self.AuditConf('localhost', 22) + conf.colors = False + conf.batch = True + return conf + + def test_resolve_error(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known') + s = self.ssh.Socket('localhost', 22) + conf = self._conf() + output_spy.begin() + with pytest.raises(SystemExit): + r = list(s._resolve(conf.ipvo)) + lines = output_spy.flush() + assert len(lines) == 1 + assert 'hostname nor servname provided' in lines[-1] + + def test_resolve_hostname_without_records(self, output_spy, virtual_socket): + vsocket = virtual_socket + vsocket.gsock.addrinfodata['localhost#22'] = [] + s = self.ssh.Socket('localhost', 22) + conf = self._conf() + output_spy.begin() + r = list(s._resolve(conf.ipvo)) + assert len(r) == 0 + + def test_resolve_ipv4(self, virtual_socket): + vsocket = virtual_socket + conf = self._conf() + conf.ipv4 = True + s = self.ssh.Socket('localhost', 22) + r = list(s._resolve(conf.ipvo)) + assert len(r) == 1 + assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) + + def test_resolve_ipv6(self, virtual_socket): + vsocket = virtual_socket + s = self.ssh.Socket('localhost', 22) + conf = self._conf() + conf.ipv6 = True + r = list(s._resolve(conf.ipvo)) + assert len(r) == 1 + assert r[0] == (socket.AF_INET6, ('::1', 22)) + + def test_resolve_ipv46_both(self, virtual_socket): + vsocket = virtual_socket + s = self.ssh.Socket('localhost', 22) + conf = self._conf() + r = list(s._resolve(conf.ipvo)) + assert len(r) == 2 + assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) + assert r[1] == (socket.AF_INET6, ('::1', 22)) + + def test_resolve_ipv46_order(self, virtual_socket): + vsocket = virtual_socket + s = self.ssh.Socket('localhost', 22) + conf = self._conf() + conf.ipv4 = True + conf.ipv6 = True + r = list(s._resolve(conf.ipvo)) + assert len(r) == 2 + assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) + assert r[1] == (socket.AF_INET6, ('::1', 22)) + conf = self._conf() + conf.ipv6 = True + conf.ipv4 = True + r = list(s._resolve(conf.ipvo)) + assert len(r) == 2 + assert r[0] == (socket.AF_INET6, ('::1', 22)) + assert r[1] == (socket.AF_INET, ('127.0.0.1', 22))