Add simple server tests for SSH1 and SSH2.

This commit is contained in:
Andris Raugulis
2016-10-25 16:57:30 +03:00
parent aa4eabda66
commit 318aab79bc
2 changed files with 106 additions and 7 deletions

View File

@ -1,8 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest, os
import struct, os
import pytest
# pylint: disable=line-too-long,attribute-defined-outside-init
class TestSSH2(object):
@pytest.fixture(autouse=True)
def init(self, ssh_audit):
@ -10,6 +12,27 @@ class TestSSH2(object):
self.ssh2 = ssh_audit.SSH2
self.rbuf = ssh_audit.ReadBuf
self.wbuf = ssh_audit.WriteBuf
self.audit = ssh_audit.audit
self.AuditConf = ssh_audit.AuditConf
def _conf(self):
conf = self.AuditConf('localhost', 22)
conf.colors = False
conf.batch = True
conf.verbose = True
conf.ssh1 = False
conf.ssh2 = True
return conf
@classmethod
def _create_ssh2_packet(cls, payload):
padding = -(len(payload) + 5) % 8
if padding < 4:
padding += 8
plen = len(payload) + padding + 1
pad_bytes = b'\x00' * padding
data = struct.pack('>Ib', plen, padding) + payload + pad_bytes
return data
def _kex_payload(self):
w = self.wbuf()
@ -105,3 +128,28 @@ class TestSSH2(object):
kex1 = self._get_kex_variat1()
kex2 = self.ssh2.Kex.parse(self._kex_payload())
assert kex1.payload == kex2.payload
def test_ssh2_server_simple(self, output_spy, virtual_socket):
vsocket = virtual_socket
w = self.wbuf()
w.write_byte(self.ssh.Protocol.MSG_KEXINIT)
w.write(self._kex_payload())
vsocket.rdata.append(b'SSH-2.0-OpenSSH_7.3 ssh-audit-test\r\n')
vsocket.rdata.append(self._create_ssh2_packet(w.write_flush()))
output_spy.begin()
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 72
def test_ssh2_server_invalid_first_packet(self, output_spy, virtual_socket):
vsocket = virtual_socket
w = self.wbuf()
w.write_byte(self.ssh.Protocol.MSG_KEXINIT + 1)
vsocket.rdata.append(b'SSH-2.0-OpenSSH_7.3 ssh-audit-test\r\n')
vsocket.rdata.append(self._create_ssh2_packet(w.write_flush()))
output_spy.begin()
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 3
assert 'unknown message' in lines[-1]