Initial SSH1 support (packet reading, SMSG_PUBLIC_KEY, CRC32, etc) #6.

This commit is contained in:
Andris Raugulis 2016-09-15 18:00:09 +03:00
parent d6980242ba
commit 9030e71892

View File

@ -26,8 +26,7 @@
from __future__ import print_function from __future__ import print_function
import os, io, sys, socket, struct, random, errno, getopt, re import os, io, sys, socket, struct, random, errno, getopt, re
VERSION = 'v1.0.20160908' VERSION = 'v1.0.20160915'
SSH_BANNER = 'SSH-2.0-OpenSSH_7.3'
def usage(err=None): def usage(err=None):
@ -142,6 +141,121 @@ class Kex(object):
return kex return kex
class SSH1(object):
class CRC32(object):
def __init__(self):
self._table = [0] * 256
for i in range(256):
crc = 0
n = i
for j in range(8):
x = (crc ^ n) & 1
crc = (crc >> 1) ^ (x * 0xedb88320)
n = n >> 1
self._table[i] = crc
def calc(self, v):
crc, l = 0, len(v)
for i in range(l):
n = ord(v[i:i + 1])
n = n ^ (crc & 0xff)
crc = (crc >> 8) ^ self._table[n]
return crc
_crc32 = CRC32()
CIPHERS = [None, 'idea', 'des', '3des', 'tss', 'rc4', 'blowfish']
AUTHS = [None, 'rhosts', 'rsa', 'password', 'rhosts_rsa', 'tis', 'kerberos']
@classmethod
def crc32(cls, v):
return cls._crc32.calc(v)
class PublicKeyMessage(object):
def __init__(self, cookie, skey, hkey, pflags, cmask, amask):
assert len(skey) == 3
assert len(hkey) == 3
self.__cookie = cookie
self.__server_key = skey
self.__host_key = hkey
self.__protocol_flags = pflags
self.__supported_ciphers_mask = cmask
self.__supported_authentications_mask = amask
@property
def cookie(self):
return self.__cookie
@property
def server_key_bits(self):
return self.__server_key[0]
@property
def server_key_public_exponent(self):
return self.__server_key[1]
@property
def server_key_public_modulus(self):
return self.__server_key[2]
@property
def host_key_bits(self):
return self.__host_key[0]
@property
def host_key_public_exponent(self):
return self.__host_key[1]
@property
def host_key_public_modulus(self):
return self.__host_key[2]
@property
def protocol_flags(self):
return self.__protocol_flags
@property
def supported_ciphers_mask(self):
return self.__supported_ciphers_mask
@property
def supported_ciphers(self):
ciphers = []
for i in range(len(SSH1.CIPHERS)):
if self.__supported_ciphers_mask & (1 << i) != 0:
ciphers.append(SSH1.CIPHERS[i])
return ciphers
@property
def supported_authentications_mask(self):
return self.__supported_authentications_mask
@property
def supported_authentications(self):
auths = []
for i in range(len(SSH1.AUTHS)):
if self.__supported_authentications_mask & (1 << i) != 0:
auths.append(SSH1.AUTHS[i])
return auths
@classmethod
def parse(cls, payload):
buf = ReadBuf(payload)
cookie = buf.read(8)
server_key_bits = buf.read_int()
server_key_exponent = buf.read_mpint1()
server_key_modulus = buf.read_mpint1()
skey = (server_key_bits, server_key_exponent, server_key_modulus)
host_key_bits = buf.read_int()
host_key_exponent = buf.read_mpint1()
host_key_modulus = buf.read_mpint1()
hkey = (host_key_bits, host_key_exponent, host_key_modulus)
pflags = buf.read_int()
cmask = buf.read_int()
amask = buf.read_int()
pkm = cls(cookie, skey, hkey, pflags, cmask, amask)
return pkm
class ReadBuf(object): class ReadBuf(object):
def __init__(self, data=None): def __init__(self, data=None):
super(ReadBuf, self).__init__() super(ReadBuf, self).__init__()
@ -269,6 +383,7 @@ class WriteBuf(object):
class SSH(object): class SSH(object):
class Protocol(object): class Protocol(object):
SMSG_PUBLIC_KEY = 2
MSG_KEXINIT = 20 MSG_KEXINIT = 20
MSG_NEWKEYS = 21 MSG_NEWKEYS = 21
MSG_KEXDH_INIT = 30 MSG_KEXDH_INIT = 30
@ -516,6 +631,9 @@ class SSH(object):
} }
class Socket(ReadBuf, WriteBuf): class Socket(ReadBuf, WriteBuf):
class InsufficientReadException(Exception):
pass
SM_BANNER_SENT = 1 SM_BANNER_SENT = 1
def __init__(self, host, port, cto=3.0, rto=5.0): def __init__(self, host, port, cto=3.0, rto=5.0):
@ -534,7 +652,8 @@ class SSH(object):
def __enter__(self): def __enter__(self):
return self return self
def get_banner(self): def get_banner(self, sshv=2):
banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0')
rto = self.__sock.gettimeout() rto = self.__sock.gettimeout()
self.__sock.settimeout(0.7) self.__sock.settimeout(0.7)
s, e = self.recv() s, e = self.recv()
@ -542,7 +661,7 @@ class SSH(object):
if s < 0: if s < 0:
return self.__banner, self.__header return self.__banner, self.__header
if self.__state < self.SM_BANNER_SENT: if self.__state < self.SM_BANNER_SENT:
self.send_banner() self.send_banner(banner)
while self.__banner is None: while self.__banner is None:
if not s > 0: if not s > 0:
s, e = self.recv() s, e = self.recv()
@ -586,41 +705,67 @@ class SSH(object):
return (-1, e) return (-1, e)
self.__sock.send(data) self.__sock.send(data)
def send_banner(self, banner=SSH_BANNER): def send_banner(self, banner):
self.send(banner.encode() + b'\r\n') self.send(banner.encode() + b'\r\n')
if self.__state < self.SM_BANNER_SENT: if self.__state < self.SM_BANNER_SENT:
self.__state = self.SM_BANNER_SENT self.__state = self.SM_BANNER_SENT
def read_packet(self): def ensure_read(self, size):
while self.unread_len < self.__block_size: while self.unread_len < size:
s, e = self.recv() s, e = self.recv()
if s < 0: if s < 0:
if e is None: raise SSH.Socket.InsufficientReadException(e)
e = self.read(self.unread_len).strip()
return -1, e def read_packet(self, sshv=2):
header = self.read(self.__block_size) try:
if len(header) == 0: header = WriteBuf()
out.fail('[exception] empty ssh packet (no data)') self.ensure_read(4)
sys.exit(1) packet_length = self.read_int()
packet_size = struct.unpack('>I', header[:4])[0] header.write_int(packet_length)
rest = header[4:] # XXX: validate length
lrest = len(rest) if sshv == 1:
padding = ord(rest[0:1]) padding_length = (8 - packet_length % 8)
packet_type = ord(rest[1:2]) self.ensure_read(padding_length)
if (packet_size - lrest) % self.__block_size != 0: padding = self.read(padding_length)
out.fail('[exception] invalid ssh packet (block size)') header.write(padding)
sys.exit(1) payload_length = packet_length
rlen = packet_size - lrest check_size = padding_length + payload_length
while self.unread_len < rlen: else:
s, e = self.recv() self.ensure_read(1)
if s < 0: padding_length = self.read_byte()
if e is None: header.write_byte(padding_length)
e = (header + self.read(self.unread_len)).strip() payload_length = packet_length - padding_length - 1
return -1, e check_size = 4 + 1 + payload_length + padding_length
buf = self.read(rlen) if check_size % self.__block_size != 0:
packet = rest[2:] + buf[0:packet_size - lrest] out.fail('[exception] invalid ssh packet (block size)')
payload = packet[0:packet_size - padding] sys.exit(1)
return packet_type, payload self.ensure_read(payload_length)
if sshv == 1:
payload = self.read(payload_length - 4)
header.write(payload)
crc = self.read_int()
header.write_int(crc)
else:
payload = self.read(payload_length)
header.write(payload)
packet_type = ord(payload[0:1])
if sshv == 1:
rcrc = SSH1.crc32(padding + payload)
if crc != rcrc:
out.fail('[exception] packet checksum CRC32 mismatch.')
sys.exit(1)
else:
self.ensure_read(padding_length)
padding = self.read(padding_length)
payload = payload[1:]
return packet_type, payload
except SSH.Socket.InsufficientReadException as ex:
if ex.args[0] is None:
header.write(self.read(self.unread_len))
e = header.write_flush().strip()
else:
e = ex.args[0]
return (-1, e)
def send_packet(self): def send_packet(self):
payload = self.write_flush() payload = self.write_flush()
@ -961,13 +1106,14 @@ def output_security(banner, padlen):
out.sep() out.sep()
def output(banner, header, kex): def output(banner, header, kex=None, pkm=None):
sshv = 1 if pkm else 2
with OutputBuffer() as obuf: with OutputBuffer() as obuf:
if len(header) > 0: if len(header) > 0:
out.info('(gen) header: ' + '\n'.join(header)) out.info('(gen) header: ' + '\n'.join(header))
if banner is not None: if banner is not None:
out.good('(gen) banner: {0}'.format(banner)) out.good('(gen) banner: {0}'.format(banner))
if banner.protocol[0] == 1: if sshv == 1 or banner.protocol[0] == 1:
out.fail('(gen) protocol SSH1 enabled') out.fail('(gen) protocol SSH1 enabled')
software = SSH.Software.parse(banner) software = SSH.Software.parse(banner)
if software is not None: if software is not None:
@ -1045,26 +1191,44 @@ def parse_args():
return host, port return host, port
def main(): def audit(host, port, sshv=2):
host, port = parse_args()
s = SSH.Socket(host, port) s = SSH.Socket(host, port)
err = None err = None
banner, header = s.get_banner() banner, header = s.get_banner(sshv)
if banner is None: if banner is None:
err = '[exception] did not receive banner.' err = '[exception] did not receive banner.'
if err is None: if err is None:
packet_type, payload = s.read_packet() packet_type, payload = s.read_packet(sshv)
if packet_type < 0: if packet_type < 0:
if payload == b'Protocol major versions differ.':
if sshv == 2:
audit(host, port, 1)
return
err = '[exception] error reading packet ({0})'.format(payload) err = '[exception] error reading packet ({0})'.format(payload)
elif packet_type != SSH.Protocol.MSG_KEXINIT: else:
err = '[exception] did not receive MSG_KEXINIT (20), ' + \ if sshv == 1 and packet_type != SSH.Protocol.SMSG_PUBLIC_KEY:
'instead received unknown message ({0})'.format(packet_type) err = ('SMSG_PUBLIC_KEY', SSH.Protocol.SMSG_PUBLIC_KEY)
elif sshv == 2 and packet_type != SSH.Protocol.MSG_KEXINIT:
err = ('MSG_KEXINIT', SSH.Protocol.MSG_KEXINIT)
if err is not None:
fmt = '[exception] did not receive {0} ({1}), ' + \
'instead received unknown message ({2})'
err = fmt.format(err[0], err[1], packet_type)
if err: if err:
output(banner, header, None) output(banner, header)
out.fail(err) out.fail(err)
sys.exit(1) sys.exit(1)
kex = Kex.parse(payload) if sshv == 1:
output(banner, header, kex) pkm = SSH1.PublicKeyMessage.parse(payload)
output(banner, header, pkm=pkm)
elif sshv == 2:
kex = Kex.parse(payload)
output(banner, header, kex=kex)
def main():
host, port = parse_args()
audit(host, port)
if __name__ == '__main__': if __name__ == '__main__':