diff --git a/ssh-audit.py b/ssh-audit.py index cdd9b8c..6dd8002 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -26,8 +26,7 @@ from __future__ import print_function import os, io, sys, socket, struct, random, errno, getopt, re -VERSION = 'v1.0.20160908' -SSH_BANNER = 'SSH-2.0-OpenSSH_7.3' +VERSION = 'v1.0.20160915' def usage(err=None): @@ -142,6 +141,121 @@ class Kex(object): return kex +class SSH1(object): + class CRC32(object): + def __init__(self): + self._table = [0] * 256 + for i in range(256): + crc = 0 + n = i + for j in range(8): + x = (crc ^ n) & 1 + crc = (crc >> 1) ^ (x * 0xedb88320) + n = n >> 1 + self._table[i] = crc + + def calc(self, v): + crc, l = 0, len(v) + for i in range(l): + n = ord(v[i:i + 1]) + n = n ^ (crc & 0xff) + crc = (crc >> 8) ^ self._table[n] + return crc + + _crc32 = CRC32() + CIPHERS = [None, 'idea', 'des', '3des', 'tss', 'rc4', 'blowfish'] + AUTHS = [None, 'rhosts', 'rsa', 'password', 'rhosts_rsa', 'tis', 'kerberos'] + + @classmethod + def crc32(cls, v): + return cls._crc32.calc(v) + + class PublicKeyMessage(object): + def __init__(self, cookie, skey, hkey, pflags, cmask, amask): + assert len(skey) == 3 + assert len(hkey) == 3 + self.__cookie = cookie + self.__server_key = skey + self.__host_key = hkey + self.__protocol_flags = pflags + self.__supported_ciphers_mask = cmask + self.__supported_authentications_mask = amask + + @property + def cookie(self): + return self.__cookie + + @property + def server_key_bits(self): + return self.__server_key[0] + + @property + def server_key_public_exponent(self): + return self.__server_key[1] + + @property + def server_key_public_modulus(self): + return self.__server_key[2] + + @property + def host_key_bits(self): + return self.__host_key[0] + + @property + def host_key_public_exponent(self): + return self.__host_key[1] + + @property + def host_key_public_modulus(self): + return self.__host_key[2] + + @property + def protocol_flags(self): + return self.__protocol_flags + + @property + def supported_ciphers_mask(self): + return self.__supported_ciphers_mask + + @property + def supported_ciphers(self): + ciphers = [] + for i in range(len(SSH1.CIPHERS)): + if self.__supported_ciphers_mask & (1 << i) != 0: + ciphers.append(SSH1.CIPHERS[i]) + return ciphers + + @property + def supported_authentications_mask(self): + return self.__supported_authentications_mask + + @property + def supported_authentications(self): + auths = [] + for i in range(len(SSH1.AUTHS)): + if self.__supported_authentications_mask & (1 << i) != 0: + auths.append(SSH1.AUTHS[i]) + return auths + + @classmethod + def parse(cls, payload): + buf = ReadBuf(payload) + cookie = buf.read(8) + server_key_bits = buf.read_int() + server_key_exponent = buf.read_mpint1() + server_key_modulus = buf.read_mpint1() + skey = (server_key_bits, server_key_exponent, server_key_modulus) + host_key_bits = buf.read_int() + host_key_exponent = buf.read_mpint1() + host_key_modulus = buf.read_mpint1() + hkey = (host_key_bits, host_key_exponent, host_key_modulus) + pflags = buf.read_int() + cmask = buf.read_int() + amask = buf.read_int() + pkm = cls(cookie, skey, hkey, pflags, cmask, amask) + return pkm + + class ReadBuf(object): def __init__(self, data=None): super(ReadBuf, self).__init__() @@ -269,6 +383,7 @@ class WriteBuf(object): class SSH(object): class Protocol(object): + SMSG_PUBLIC_KEY = 2 MSG_KEXINIT = 20 MSG_NEWKEYS = 21 MSG_KEXDH_INIT = 30 @@ -516,6 +631,9 @@ class SSH(object): } class Socket(ReadBuf, WriteBuf): + class InsufficientReadException(Exception): + pass + SM_BANNER_SENT = 1 def __init__(self, host, port, cto=3.0, rto=5.0): @@ -534,7 +652,8 @@ class SSH(object): def __enter__(self): return self - def get_banner(self): + def get_banner(self, sshv=2): + banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0') rto = self.__sock.gettimeout() self.__sock.settimeout(0.7) s, e = self.recv() @@ -542,7 +661,7 @@ class SSH(object): if s < 0: return self.__banner, self.__header if self.__state < self.SM_BANNER_SENT: - self.send_banner() + self.send_banner(banner) while self.__banner is None: if not s > 0: s, e = self.recv() @@ -586,41 +705,67 @@ class SSH(object): return (-1, e) self.__sock.send(data) - def send_banner(self, banner=SSH_BANNER): + def send_banner(self, banner): self.send(banner.encode() + b'\r\n') if self.__state < self.SM_BANNER_SENT: self.__state = self.SM_BANNER_SENT - def read_packet(self): - while self.unread_len < self.__block_size: + def ensure_read(self, size): + while self.unread_len < size: s, e = self.recv() if s < 0: - if e is None: - e = self.read(self.unread_len).strip() - return -1, e - header = self.read(self.__block_size) - if len(header) == 0: - out.fail('[exception] empty ssh packet (no data)') - sys.exit(1) - packet_size = struct.unpack('>I', header[:4])[0] - rest = header[4:] - lrest = len(rest) - padding = ord(rest[0:1]) - packet_type = ord(rest[1:2]) - if (packet_size - lrest) % self.__block_size != 0: - out.fail('[exception] invalid ssh packet (block size)') - sys.exit(1) - rlen = packet_size - lrest - while self.unread_len < rlen: - s, e = self.recv() - if s < 0: - if e is None: - e = (header + self.read(self.unread_len)).strip() - return -1, e - buf = self.read(rlen) - packet = rest[2:] + buf[0:packet_size - lrest] - payload = packet[0:packet_size - padding] - return packet_type, payload + raise SSH.Socket.InsufficientReadException(e) + + def read_packet(self, sshv=2): + try: + header = WriteBuf() + self.ensure_read(4) + packet_length = self.read_int() + header.write_int(packet_length) + # XXX: validate length + if sshv == 1: + padding_length = (8 - packet_length % 8) + self.ensure_read(padding_length) + padding = self.read(padding_length) + header.write(padding) + payload_length = packet_length + check_size = padding_length + payload_length + else: + self.ensure_read(1) + padding_length = self.read_byte() + header.write_byte(padding_length) + payload_length = packet_length - padding_length - 1 + check_size = 4 + 1 + payload_length + padding_length + if check_size % self.__block_size != 0: + out.fail('[exception] invalid ssh packet (block size)') + sys.exit(1) + self.ensure_read(payload_length) + if sshv == 1: + payload = self.read(payload_length - 4) + header.write(payload) + crc = self.read_int() + header.write_int(crc) + else: + payload = self.read(payload_length) + header.write(payload) + packet_type = ord(payload[0:1]) + if sshv == 1: + rcrc = SSH1.crc32(padding + payload) + if crc != rcrc: + out.fail('[exception] packet checksum CRC32 mismatch.') + sys.exit(1) + else: + self.ensure_read(padding_length) + padding = self.read(padding_length) + payload = payload[1:] + return packet_type, payload + except SSH.Socket.InsufficientReadException as ex: + if ex.args[0] is None: + header.write(self.read(self.unread_len)) + e = header.write_flush().strip() + else: + e = ex.args[0] + return (-1, e) def send_packet(self): payload = self.write_flush() @@ -961,13 +1106,14 @@ def output_security(banner, padlen): out.sep() -def output(banner, header, kex): +def output(banner, header, kex=None, pkm=None): + sshv = 1 if pkm else 2 with OutputBuffer() as obuf: if len(header) > 0: out.info('(gen) header: ' + '\n'.join(header)) if banner is not None: out.good('(gen) banner: {0}'.format(banner)) - if banner.protocol[0] == 1: + if sshv == 1 or banner.protocol[0] == 1: out.fail('(gen) protocol SSH1 enabled') software = SSH.Software.parse(banner) if software is not None: @@ -1045,26 +1191,44 @@ def parse_args(): return host, port -def main(): - host, port = parse_args() +def audit(host, port, sshv=2): s = SSH.Socket(host, port) err = None - banner, header = s.get_banner() + banner, header = s.get_banner(sshv) if banner is None: err = '[exception] did not receive banner.' if err is None: - packet_type, payload = s.read_packet() + packet_type, payload = s.read_packet(sshv) if packet_type < 0: + if payload == b'Protocol major versions differ.': + if sshv == 2: + audit(host, port, 1) + return err = '[exception] error reading packet ({0})'.format(payload) - elif packet_type != SSH.Protocol.MSG_KEXINIT: - err = '[exception] did not receive MSG_KEXINIT (20), ' + \ - 'instead received unknown message ({0})'.format(packet_type) + else: + if sshv == 1 and packet_type != SSH.Protocol.SMSG_PUBLIC_KEY: + err = ('SMSG_PUBLIC_KEY', SSH.Protocol.SMSG_PUBLIC_KEY) + elif sshv == 2 and packet_type != SSH.Protocol.MSG_KEXINIT: + err = ('MSG_KEXINIT', SSH.Protocol.MSG_KEXINIT) + if err is not None: + fmt = '[exception] did not receive {0} ({1}), ' + \ + 'instead received unknown message ({2})' + err = fmt.format(err[0], err[1], packet_type) if err: - output(banner, header, None) + output(banner, header) out.fail(err) sys.exit(1) - kex = Kex.parse(payload) - output(banner, header, kex) + if sshv == 1: + pkm = SSH1.PublicKeyMessage.parse(payload) + output(banner, header, pkm=pkm) + elif sshv == 2: + kex = Kex.parse(payload) + output(banner, header, kex=kex) + + +def main(): + host, port = parse_args() + audit(host, port) if __name__ == '__main__':