Ensure reading enough data.

This commit is contained in:
Andris Raugulis 2016-01-05 14:10:02 +02:00
parent 173f3d79de
commit c485ffb01e

View File

@ -78,7 +78,7 @@ class Kex(object):
def parse(cls, payload): def parse(cls, payload):
kex = cls() kex = cls()
buf = ReadBuf(payload) buf = ReadBuf(payload)
kex.cookie = buf.read_raw(16) kex.cookie = buf.read(16)
kex.kex_algorithms = buf.read_list() kex.kex_algorithms = buf.read_list()
kex.key_algorithms = buf.read_list() kex.key_algorithms = buf.read_list()
kex.client.encryption = buf.read_list() kex.client.encryption = buf.read_list()
@ -102,21 +102,21 @@ class ReadBuf(object):
def unread_len(self): def unread_len(self):
return self._len - self._buf.tell() return self._len - self._buf.tell()
def read_raw(self, size): def read(self, size):
return self._buf.read(size) return self._buf.read(size)
def read_line(self): def read_line(self):
return self._buf.readline().rstrip().decode('utf-8') return self._buf.readline().rstrip().decode('utf-8')
def read_int(self): def read_int(self):
return struct.unpack('>I', self._buf.read(4))[0] return struct.unpack('>I', self.read(4))[0]
def read_bool(self): def read_bool(self):
return struct.unpack('b', self._buf.read(1))[0] != 0 return struct.unpack('b', self.read(1))[0] != 0
def read_list(self): def read_list(self):
list_size = self.read_int() list_size = self.read_int()
return self._buf.read(list_size).decode().split(',') return self.read(list_size).decode().split(',')
class SockBuf(ReadBuf): class SockBuf(ReadBuf):
def __init__(self, s): def __init__(self, s):
@ -292,7 +292,9 @@ def process_kex(kex):
def read_ssh_packet(sbuf): def read_ssh_packet(sbuf):
block_size = 8 block_size = 8
header = sbuf.read_raw(block_size) if sbuf.unread_len < block_size:
sbuf.recv()
header = sbuf.read(block_size)
packet_size = struct.unpack('>I', header[:4])[0] packet_size = struct.unpack('>I', header[:4])[0]
rest = header[4:] rest = header[4:]
lrest = len(rest) lrest = len(rest)
@ -301,7 +303,10 @@ def read_ssh_packet(sbuf):
if (packet_size - lrest) % block_size != 0: if (packet_size - lrest) % block_size != 0:
out.fail('[exception] invalid ssh packet (block size)') out.fail('[exception] invalid ssh packet (block size)')
sys.exit(1) sys.exit(1)
buf = sbuf.read_raw(packet_size - lrest) rlen = packet_size - lrest
if sbuf.unread_len < rlen:
sbuf.recv()
buf = sbuf.read(rlen)
packet = rest[2:] + buf[0:packet_size - lrest] packet = rest[2:] + buf[0:packet_size - lrest]
payload = packet[0:packet_size - padding] payload = packet[0:packet_size - padding]
return packet_type, payload return packet_type, payload
@ -343,8 +348,6 @@ def main():
out.good('[info] banner: ' + banner) out.good('[info] banner: ' + banner)
if banner.startswith('SSH-1.99-'): if banner.startswith('SSH-1.99-'):
out.fail('[fail] protocol SSH1 enabled') out.fail('[fail] protocol SSH1 enabled')
if sbuf.unread_len == 0:
sbuf.recv()
packet_type, payload = read_ssh_packet(sbuf) packet_type, payload = read_ssh_packet(sbuf)
if packet_type != 20: if packet_type != 20:
out.fail('[exception] did not receive MSG_KEXINIT (20), instead received unknown message ({0})'.format(packet_type)) out.fail('[exception] did not receive MSG_KEXINIT (20), instead received unknown message ({0})'.format(packet_type))