diff --git a/test/conftest.py b/test/conftest.py index 33e9bb4..28ab4ef 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import pytest, os, sys, io +import pytest, os, sys, io, socket if sys.version_info[0] == 2: @@ -33,3 +33,80 @@ class _OutputSpy(list): @pytest.fixture(scope='module') def output_spy(): return _OutputSpy() + + +class _VirtualSocket(object): + def __init__(self): + self.sock_address = ('127.0.0.1', 0) + self.peer_address = None + self._connected = False + self.timeout = -1.0 + self.rdata = [] + self.sdata = [] + self.errors = {} + + def _check_err(self, method): + method_error = self.errors.get(method) + if method_error: + raise method_error + + def _connect(self, address): + self.peer_address = address + self._connected = True + self._check_err('connect') + return self + + def settimeout(self, timeout): + self.timeout = timeout + + def gettimeout(self): + return self.timeout + + def getpeername(self): + if self.peer_address is None or not self._connected: + raise socket.error(57, 'Socket is not connected') + return self.peer_address + + def getsockname(self): + return self.sock_address + + def bind(self, address): + self.sock_address = address + + def listen(self, backlog): + pass + + def accept(self): + conn = _VirtualSocket() + conn.sock_address = self.sock_address + conn.peer_address = ('127.0.0.1', 0) + conn._connected = True + return conn, conn.peer_address + + def recv(self, bufsize, flags=0): + if not self._connected: + raise socket.error(54, 'Connection reset by peer') + if not len(self.rdata) > 0: + return b'' + data = self.rdata.pop(0) + if isinstance(data, Exception): + raise data + return data + + def send(self, data): + if self.peer_address is None or not self._connected: + raise socket.error(32, 'Broken pipe') + self._check_err('send') + self.sdata.append(data) + + +@pytest.fixture() +def virtual_socket(monkeypatch): + vsocket = _VirtualSocket() + def _c(address): + return vsocket._connect(address) + def _cc(address, timeout=0, source_address=None): + return vsocket._connect(address) + monkeypatch.setattr(socket, 'create_connection', _cc) + monkeypatch.setattr(socket.socket, 'connect', _c) + return vsocket