Specify error when couldn't get banner. Test for timeout and retry cases.

This commit is contained in:
Andris Raugulis 2016-11-02 13:00:24 +02:00
parent dd3ca9688e
commit 44c1d4827c
2 changed files with 78 additions and 44 deletions

View File

@ -1141,16 +1141,17 @@ class SSH(object): # pylint: disable=too-few-public-methods
sys.exit(1) sys.exit(1)
def get_banner(self, sshv=2): def get_banner(self, sshv=2):
# type: (int) -> Tuple[Optional[SSH.Banner], List[text_type]] # type: (int) -> Tuple[Optional[SSH.Banner], List[text_type], Optional[str]]
banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0') banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0')
rto = self.__sock.gettimeout() rto = self.__sock.gettimeout()
self.__sock.settimeout(0.7) self.__sock.settimeout(0.7)
s, e = self.recv() s, e = self.recv()
self.__sock.settimeout(rto) self.__sock.settimeout(rto)
if s < 0: if s < 0:
return self.__banner, self.__header return self.__banner, self.__header, e
if self.__state < self.SM_BANNER_SENT: if self.__state < self.SM_BANNER_SENT:
self.send_banner(banner) self.send_banner(banner)
e = None
while self.__banner is None: while self.__banner is None:
if not s > 0: if not s > 0:
s, e = self.recv() s, e = self.recv()
@ -1166,14 +1167,14 @@ class SSH(object): # pylint: disable=too-few-public-methods
continue continue
self.__header.append(line) self.__header.append(line)
s = 0 s = 0
return self.__banner, self.__header return self.__banner, self.__header, e
def recv(self, size=2048): def recv(self, size=2048):
# type: (int) -> Tuple[int, Optional[str]] # type: (int) -> Tuple[int, Optional[str]]
try: try:
data = self.__sock.recv(size) data = self.__sock.recv(size)
except socket.timeout: except socket.timeout:
return (-1, 'timeout') return (-1, 'timed out')
except socket.error as e: except socket.error as e:
if e.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK): if e.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
return (0, 'retry') return (0, 'retry')
@ -1971,9 +1972,12 @@ def audit(aconf, sshv=None):
if sshv is None: if sshv is None:
sshv = 2 if aconf.ssh2 else 1 sshv = 2 if aconf.ssh2 else 1
err = None err = None
banner, header = s.get_banner(sshv) banner, header, err = s.get_banner(sshv)
if banner is None: if banner is None:
if err is None:
err = '[exception] did not receive banner.' err = '[exception] did not receive banner.'
else:
err = '[exception] did not receive banner: {0}'.format(err)
if err is None: if err is None:
packet_type, payload = s.read_packet(sshv) packet_type, payload = s.read_packet(sshv)
if packet_type < 0: if packet_type < 0:

View File

@ -17,46 +17,91 @@ class TestErrors(object):
conf.batch = True conf.batch = True
return conf return conf
def _audit(self, spy, conf=None, sysexit=True):
if conf is None:
conf = self._conf()
spy.begin()
if sysexit:
with pytest.raises(SystemExit):
self.audit(conf)
else:
self.audit(conf)
lines = spy.flush()
return lines
def test_connection_refused(self, output_spy, virtual_socket): def test_connection_refused(self, output_spy, virtual_socket):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.errors['connect'] = socket.error(61, 'Connection refused') vsocket.errors['connect'] = socket.error(61, 'Connection refused')
output_spy.begin() lines = self._audit(output_spy)
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 1 assert len(lines) == 1
assert 'Connection refused' in lines[-1] assert 'Connection refused' in lines[-1]
def test_connection_timeout(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.errors['connect'] = socket.timeout('timed out')
lines = self._audit(output_spy)
assert len(lines) == 1
assert 'timed out' in lines[-1]
def test_recv_empty(self, output_spy, virtual_socket):
vsocket = virtual_socket
lines = self._audit(output_spy)
assert len(lines) == 1
assert 'did not receive banner' in lines[-1]
def test_recv_timeout(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.rdata.append(socket.timeout('timed out'))
lines = self._audit(output_spy)
assert len(lines) == 1
assert 'did not receive banner' in lines[-1]
assert 'timed out' in lines[-1]
def test_recv_retry_till_timeout(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable'))
vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable'))
vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable'))
vsocket.rdata.append(socket.timeout('timed out'))
lines = self._audit(output_spy)
assert len(lines) == 1
assert 'did not receive banner' in lines[-1]
assert 'timed out' in lines[-1]
def test_recv_retry_till_reset(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable'))
vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable'))
vsocket.rdata.append(socket.error(35, 'Resource temporarily unavailable'))
vsocket.rdata.append(socket.error(54, 'Connection reset by peer'))
lines = self._audit(output_spy)
assert len(lines) == 1
assert 'did not receive banner' in lines[-1]
assert 'reset by peer' in lines[-1]
def test_connection_closed_before_banner(self, output_spy, virtual_socket): def test_connection_closed_before_banner(self, output_spy, virtual_socket):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) vsocket.rdata.append(socket.error(54, 'Connection reset by peer'))
output_spy.begin() lines = self._audit(output_spy)
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 1 assert len(lines) == 1
assert 'did not receive banner' in lines[-1] assert 'did not receive banner' in lines[-1]
assert 'reset by peer' in lines[-1]
def test_connection_closed_after_header(self, output_spy, virtual_socket): def test_connection_closed_after_header(self, output_spy, virtual_socket):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.rdata.append(b'header line 1\n') vsocket.rdata.append(b'header line 1\n')
vsocket.rdata.append(b'header line 2\n') vsocket.rdata.append(b'header line 2\n')
vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) vsocket.rdata.append(socket.error(54, 'Connection reset by peer'))
output_spy.begin() lines = self._audit(output_spy)
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 3 assert len(lines) == 3
assert 'did not receive banner' in lines[-1] assert 'did not receive banner' in lines[-1]
assert 'reset by peer' in lines[-1]
def test_connection_closed_after_banner(self, output_spy, virtual_socket): def test_connection_closed_after_banner(self, output_spy, virtual_socket):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n')
vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) vsocket.rdata.append(socket.error(54, 'Connection reset by peer'))
output_spy.begin() lines = self._audit(output_spy)
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 2 assert len(lines) == 2
assert 'error reading packet' in lines[-1] assert 'error reading packet' in lines[-1]
assert 'reset by peer' in lines[-1] assert 'reset by peer' in lines[-1]
@ -64,10 +109,7 @@ class TestErrors(object):
def test_empty_data_after_banner(self, output_spy, virtual_socket): def test_empty_data_after_banner(self, output_spy, virtual_socket):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n')
output_spy.begin() lines = self._audit(output_spy)
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 2 assert len(lines) == 2
assert 'error reading packet' in lines[-1] assert 'error reading packet' in lines[-1]
assert 'empty' in lines[-1] assert 'empty' in lines[-1]
@ -76,10 +118,7 @@ class TestErrors(object):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n')
vsocket.rdata.append(b'xxx\n') vsocket.rdata.append(b'xxx\n')
output_spy.begin() lines = self._audit(output_spy)
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 2 assert len(lines) == 2
assert 'error reading packet' in lines[-1] assert 'error reading packet' in lines[-1]
assert 'xxx' in lines[-1] assert 'xxx' in lines[-1]
@ -87,10 +126,7 @@ class TestErrors(object):
def test_non_ascii_banner(self, output_spy, virtual_socket): def test_non_ascii_banner(self, output_spy, virtual_socket):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\xc3\xbc\r\n') vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\xc3\xbc\r\n')
output_spy.begin() lines = self._audit(output_spy)
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 3 assert len(lines) == 3
assert 'error reading packet' in lines[-1] assert 'error reading packet' in lines[-1]
assert 'ASCII' in lines[-2] assert 'ASCII' in lines[-2]
@ -100,10 +136,7 @@ class TestErrors(object):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n')
vsocket.rdata.append(b'\x81\xff\n') vsocket.rdata.append(b'\x81\xff\n')
output_spy.begin() lines = self._audit(output_spy)
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 2 assert len(lines) == 2
assert 'error reading packet' in lines[-1] assert 'error reading packet' in lines[-1]
assert '\\x81\\xff' in lines[-1] assert '\\x81\\xff' in lines[-1]
@ -112,12 +145,9 @@ class TestErrors(object):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.rdata.append(b'SSH-1.3-ssh-audit-test\r\n') vsocket.rdata.append(b'SSH-1.3-ssh-audit-test\r\n')
vsocket.rdata.append(b'Protocol major versions differ.\n') vsocket.rdata.append(b'Protocol major versions differ.\n')
output_spy.begin()
with pytest.raises(SystemExit):
conf = self._conf() conf = self._conf()
conf.ssh1, conf.ssh2 = True, False conf.ssh1, conf.ssh2 = True, False
self.audit(conf) lines = self._audit(output_spy, conf)
lines = output_spy.flush()
assert len(lines) == 3 assert len(lines) == 3
assert 'error reading packet' in lines[-1] assert 'error reading packet' in lines[-1]
assert 'major versions differ' in lines[-1] assert 'major versions differ' in lines[-1]