Implement new options (-4/--ipv4, -6/--ipv6, -p/--port <port>).

By default both IPv4 and IPv6 is supported and order of precedence depends on OS.
By using -46, IPv4 is prefered, but by using -64, IPv6 is preferd.
For now the old way how to specify port (host:port) has been kept intact.
This commit is contained in:
Andris Raugulis
2016-10-26 18:33:00 +03:00
parent 8018209dd1
commit 66b9e079a8
3 changed files with 237 additions and 39 deletions

View File

@ -1,10 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest, os, sys, io, socket
import os
import io
import sys
import socket
import pytest
if sys.version_info[0] == 2:
import StringIO
import StringIO # pylint: disable=import-error
StringIO = StringIO.StringIO
else:
StringIO = io.StringIO
@ -17,6 +21,7 @@ def ssh_audit():
return __import__('ssh-audit')
# pylint: disable=attribute-defined-outside-init
class _OutputSpy(list):
def begin(self):
self.__out = StringIO()
@ -50,11 +55,14 @@ class _VirtualSocket(object):
if method_error:
raise method_error
def _connect(self, address):
def connect(self, address):
return self._connect(address, False)
def _connect(self, address, ret=True):
self.peer_address = address
self._connected = True
self._check_err('connect')
return self
return self if ret else None
def settimeout(self, timeout):
self.timeout = timeout
@ -77,6 +85,7 @@ class _VirtualSocket(object):
pass
def accept(self):
# pylint: disable=protected-access
conn = _VirtualSocket()
conn.sock_address = self.sock_address
conn.peer_address = ('127.0.0.1', 0)
@ -84,6 +93,7 @@ class _VirtualSocket(object):
return conn, conn.peer_address
def recv(self, bufsize, flags=0):
# pylint: disable=unused-argument
if not self._connected:
raise socket.error(54, 'Connection reset by peer')
if not len(self.rdata) > 0:
@ -103,10 +113,18 @@ class _VirtualSocket(object):
@pytest.fixture()
def virtual_socket(monkeypatch):
vsocket = _VirtualSocket()
def _c(address):
return vsocket._connect(address)
# pylint: disable=unused-argument
def _socket(family=socket.AF_INET,
socktype=socket.SOCK_STREAM,
proto=0,
fileno=None):
return vsocket
def _cc(address, timeout=0, source_address=None):
return vsocket._connect(address)
# pylint: disable=protected-access
return vsocket._connect(address, True)
monkeypatch.setattr(socket, 'create_connection', _cc)
monkeypatch.setattr(socket.socket, 'connect', _c)
monkeypatch.setattr(socket, 'socket', _socket)
return vsocket

View File

@ -20,7 +20,10 @@ class TestAuditConf(object):
'batch': False,
'colors': True,
'verbose': False,
'minlevel': 'info'
'minlevel': 'info',
'ipv4': True,
'ipv6': True,
'ipvo': ()
}
for k, v in kwargs.items():
options[k] = v
@ -32,6 +35,9 @@ class TestAuditConf(object):
assert conf.colors is options['colors']
assert conf.verbose is options['verbose']
assert conf.minlevel == options['minlevel']
assert conf.ipv4 == options['ipv4']
assert conf.ipv6 == options['ipv6']
assert conf.ipvo == options['ipvo']
def test_audit_conf_defaults(self):
conf = self.AuditConf()
@ -57,6 +63,58 @@ class TestAuditConf(object):
conf.port = port
excinfo.match(r'.*invalid port.*')
def test_audit_conf_ipvo(self):
# ipv4-only
conf = self.AuditConf()
conf.ipv4 = True
assert conf.ipv4 is True
assert conf.ipv6 is False
assert conf.ipvo == (4,)
# ipv6-only
conf = self.AuditConf()
conf.ipv6 = True
assert conf.ipv4 is False
assert conf.ipv6 is True
assert conf.ipvo == (6,)
# ipv4-only (by removing ipv6)
conf = self.AuditConf()
conf.ipv6 = False
assert conf.ipv4 is True
assert conf.ipv6 is False
assert conf.ipvo == (4, )
# ipv6-only (by removing ipv4)
conf = self.AuditConf()
conf.ipv4 = False
assert conf.ipv4 is False
assert conf.ipv6 is True
assert conf.ipvo == (6, )
# ipv4-preferred
conf = self.AuditConf()
conf.ipv4 = True
conf.ipv6 = True
assert conf.ipv4 is True
assert conf.ipv6 is True
assert conf.ipvo == (4, 6)
# ipv6-preferred
conf = self.AuditConf()
conf.ipv6 = True
conf.ipv4 = True
assert conf.ipv4 is True
assert conf.ipv6 is True
assert conf.ipvo == (6, 4)
# ipvo empty
conf = self.AuditConf()
conf.ipvo = ()
assert conf.ipv4 is True
assert conf.ipv6 is True
assert conf.ipvo == ()
# ipvo validation
conf = self.AuditConf()
conf.ipvo = (1, 2, 3, 4, 5, 6)
assert conf.ipvo == (4, 6)
conf.ipvo = (4, 4, 4, 6, 6)
assert conf.ipvo == (4, 6)
def test_audit_conf_minlevel(self):
conf = self.AuditConf()
for level in ['info', 'warn', 'fail']:
@ -68,6 +126,7 @@ class TestAuditConf(object):
excinfo.match(r'.*invalid level.*')
def test_audit_conf_cmdline(self):
# pylint: disable=too-many-statements
c = lambda x: self.AuditConf.from_cmdline(x.split(), self.usage) # noqa
with pytest.raises(SystemExit):
conf = c('')
@ -87,20 +146,36 @@ class TestAuditConf(object):
self._test_conf(conf, host='github.com')
conf = c('localhost:2222')
self._test_conf(conf, host='localhost', port=2222)
conf = c('-p 2222 localhost')
self._test_conf(conf, host='localhost', port=2222)
with pytest.raises(SystemExit):
conf = c('localhost:')
with pytest.raises(SystemExit):
conf = c('localhost:abc')
with pytest.raises(SystemExit):
conf = c('-p abc localhost')
with pytest.raises(SystemExit):
conf = c('localhost:-22')
with pytest.raises(SystemExit):
conf = c('-p -22 localhost')
with pytest.raises(SystemExit):
conf = c('localhost:99999')
with pytest.raises(SystemExit):
conf = c('-p 99999 localhost')
conf = c('-1 localhost')
self._test_conf(conf, host='localhost', ssh1=True, ssh2=False)
conf = c('-2 localhost')
self._test_conf(conf, host='localhost', ssh1=False, ssh2=True)
conf = c('-12 localhost')
self._test_conf(conf, host='localhost', ssh1=True, ssh2=True)
conf = c('-4 localhost')
self._test_conf(conf, host='localhost', ipv4=True, ipv6=False, ipvo=(4,))
conf = c('-6 localhost')
self._test_conf(conf, host='localhost', ipv4=False, ipv6=True, ipvo=(6,))
conf = c('-46 localhost')
self._test_conf(conf, host='localhost', ipv4=True, ipv6=True, ipvo=(4, 6))
conf = c('-64 localhost')
self._test_conf(conf, host='localhost', ipv4=True, ipv6=True, ipvo=(6, 4))
conf = c('-b localhost')
self._test_conf(conf, host='localhost', batch=True, verbose=True)
conf = c('-n localhost')