diff --git a/ssh-audit.py b/ssh-audit.py index f56f51c..d917d44 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -27,8 +27,6 @@ from __future__ import print_function import os, io, sys, socket, struct SSH_BANNER = 'SSH-2.0-OpenSSH_7.2' -SOCK_CONN_TIMEOUT = 3.0 -SOCK_READ_TIMEOUT = 5.0 def usage(): p = os.path.basename(sys.argv[0]) @@ -95,6 +93,7 @@ class Kex(object): class ReadBuf(object): def __init__(self, data = None): + super(ReadBuf, self).__init__() self._buf = io.BytesIO(data) if data else io.BytesIO() self._len = len(data) if data else 0 @@ -118,19 +117,69 @@ class ReadBuf(object): list_size = self.read_int() return self.read(list_size).decode().split(',') -class SockBuf(ReadBuf): - def __init__(self, s): - super(SockBuf, self).__init__() - self.__sock = s +class SSH(object): + MSG_KEXINIT = 20 + MSG_NEWKEYS = 21 + MSG_KEXDH_INIT = 30 + MSG_KEXDH_REPLY = 32 - def recv(self, size = 2048): - data = self.__sock.recv(size) - pos = self._buf.tell() - self._buf.seek(0, 2) - self._buf.write(data) - self._len += len(data) - self._buf.seek(pos, 0) - + class Socket(ReadBuf): + def __init__(self, host, port, cto = 3.0, rto = 5.0): + super(SSH.Socket, self).__init__() + try: + self.__sock = socket.create_connection((host, port), cto) + self.__sock.settimeout(rto) + except Exception as e: + out.fail('[fail] {}'.format(e)) + sys.exit(1) + + def __enter__(self): + return self + + def recv(self, size = 2048): + data = self.__sock.recv(size) + pos = self._buf.tell() + self._buf.seek(0, 2) + self._buf.write(data) + self._len += len(data) + self._buf.seek(pos, 0) + + def send(self, data): + self.__sock.send(data) + + def read_packet(self): + block_size = 8 + if self.unread_len < block_size: + self.recv() + header = self.read(block_size) + packet_size = struct.unpack('>I', header[:4])[0] + rest = header[4:] + lrest = len(rest) + padding = ord(rest[0:1]) + packet_type = ord(rest[1:2]) + if (packet_size - lrest) % block_size != 0: + out.fail('[exception] invalid ssh packet (block size)') + sys.exit(1) + rlen = packet_size - lrest + if self.unread_len < rlen: + self.recv() + buf = self.read(rlen) + packet = rest[2:] + buf[0:packet_size - lrest] + payload = packet[0:packet_size - padding] + return packet_type, payload + + def __del__(self): + self.__cleanup() + + def __exit__(self, ex_type, ex_value, tb): + self.__cleanup() + + def __cleanup(self): + try: + self.__sock.shutdown(socket.SHUT_RDWR) + self.__sock.close() + except: + pass def get_ssh_ver(versions): tv = [] @@ -309,26 +358,6 @@ def process_kex(kex): process_algorithms('mac', kex.server.mac, maxlen) out.sep() -def read_ssh_packet(sbuf): - block_size = 8 - if sbuf.unread_len < block_size: - sbuf.recv() - header = sbuf.read(block_size) - packet_size = struct.unpack('>I', header[:4])[0] - rest = header[4:] - lrest = len(rest) - padding = ord(rest[0:1]) - packet_type = ord(rest[1:2]) - if (packet_size - lrest) % block_size != 0: - out.fail('[exception] invalid ssh packet (block size)') - sys.exit(1) - rlen = packet_size - lrest - if sbuf.unread_len < rlen: - sbuf.recv() - buf = sbuf.read(rlen) - packet = rest[2:] + buf[0:packet_size - lrest] - payload = packet[0:packet_size - padding] - return packet_type, payload def parse_int(v): try: @@ -355,30 +384,20 @@ def parse_args(): def main(): host, port = parse_args() - s = None - try: - s = socket.create_connection((host, port), SOCK_CONN_TIMEOUT) - s.settimeout(SOCK_READ_TIMEOUT) - sbuf = SockBuf(s) - s.send(SSH_BANNER.encode() + b'\r\n') - sbuf.recv() - banner = sbuf.read_line() - out.head('# general') - out.good('[info] banner: ' + banner) - if banner.startswith('SSH-1.99-'): - out.fail('[fail] protocol SSH1 enabled') - packet_type, payload = read_ssh_packet(sbuf) - if packet_type != 20: - out.fail('[exception] did not receive MSG_KEXINIT (20), instead received unknown message ({0})'.format(packet_type)) - sys.exit(1) - kex = Kex.parse(payload) - process_kex(kex) - except Exception as e: - out.fail('[fail] {}'.format(e)) + s = SSH.Socket(host, port) + s.send(SSH_BANNER.encode() + b'\r\n') + s.recv() + banner = s.read_line() + out.head('# general') + out.good('[info] banner: ' + banner) + if banner.startswith('SSH-1.99-'): + out.fail('[fail] protocol SSH1 enabled') + packet_type, payload = s.read_packet() + if packet_type != SSH.MSG_KEXINIT: + out.fail('[exception] did not receive MSG_KEXINIT (20), instead received unknown message ({0})'.format(packet_type)) sys.exit(1) - finally: - if s: - s.close() + kex = Kex.parse(payload) + process_kex(kex) if __name__ == '__main__': out = Output()