mirror of
				https://github.com/jtesta/ssh-audit.git
				synced 2025-11-04 03:02:15 +01:00 
			
		
		
		
	Initial SSH1 support (packet reading, SMSG_PUBLIC_KEY, CRC32, etc) #6.
This commit is contained in:
		
							
								
								
									
										246
									
								
								ssh-audit.py
									
									
									
									
									
								
							
							
						
						
									
										246
									
								
								ssh-audit.py
									
									
									
									
									
								
							@@ -26,8 +26,7 @@
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
import os, io, sys, socket, struct, random, errno, getopt, re
 | 
			
		||||
 | 
			
		||||
VERSION = 'v1.0.20160908'
 | 
			
		||||
SSH_BANNER = 'SSH-2.0-OpenSSH_7.3'
 | 
			
		||||
VERSION = 'v1.0.20160915'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def usage(err=None):
 | 
			
		||||
@@ -142,6 +141,121 @@ class Kex(object):
 | 
			
		||||
		return kex
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SSH1(object):
 | 
			
		||||
	class CRC32(object):
 | 
			
		||||
		def __init__(self):
 | 
			
		||||
			self._table = [0] * 256
 | 
			
		||||
			for i in range(256):
 | 
			
		||||
				crc = 0
 | 
			
		||||
				n = i
 | 
			
		||||
				for j in range(8):
 | 
			
		||||
					x = (crc ^ n) & 1
 | 
			
		||||
					crc = (crc >> 1) ^ (x * 0xedb88320)
 | 
			
		||||
					n = n >> 1
 | 
			
		||||
				self._table[i] = crc
 | 
			
		||||
		
 | 
			
		||||
		def calc(self, v):
 | 
			
		||||
			crc, l = 0, len(v)
 | 
			
		||||
			for i in range(l):
 | 
			
		||||
				n = ord(v[i:i + 1])
 | 
			
		||||
				n = n ^ (crc & 0xff)
 | 
			
		||||
				crc = (crc >> 8) ^ self._table[n]
 | 
			
		||||
			return crc
 | 
			
		||||
	
 | 
			
		||||
	_crc32 = CRC32()
 | 
			
		||||
	CIPHERS = [None, 'idea', 'des', '3des', 'tss', 'rc4', 'blowfish']
 | 
			
		||||
	AUTHS = [None, 'rhosts', 'rsa', 'password', 'rhosts_rsa', 'tis', 'kerberos']
 | 
			
		||||
	
 | 
			
		||||
	@classmethod
 | 
			
		||||
	def crc32(cls, v):
 | 
			
		||||
		return cls._crc32.calc(v)
 | 
			
		||||
	
 | 
			
		||||
	class PublicKeyMessage(object):
 | 
			
		||||
		def __init__(self, cookie, skey, hkey, pflags, cmask, amask):
 | 
			
		||||
			assert len(skey) == 3
 | 
			
		||||
			assert len(hkey) == 3
 | 
			
		||||
			self.__cookie = cookie
 | 
			
		||||
			self.__server_key = skey
 | 
			
		||||
			self.__host_key = hkey
 | 
			
		||||
			self.__protocol_flags = pflags
 | 
			
		||||
			self.__supported_ciphers_mask = cmask
 | 
			
		||||
			self.__supported_authentications_mask = amask
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def cookie(self):
 | 
			
		||||
			return self.__cookie
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def server_key_bits(self):
 | 
			
		||||
			return self.__server_key[0]
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def server_key_public_exponent(self):
 | 
			
		||||
			return self.__server_key[1]
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def server_key_public_modulus(self):
 | 
			
		||||
			return self.__server_key[2]
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def host_key_bits(self):
 | 
			
		||||
			return self.__host_key[0]
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def host_key_public_exponent(self):
 | 
			
		||||
			return self.__host_key[1]
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def host_key_public_modulus(self):
 | 
			
		||||
			return self.__host_key[2]
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def protocol_flags(self):
 | 
			
		||||
			return self.__protocol_flags
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def supported_ciphers_mask(self):
 | 
			
		||||
			return self.__supported_ciphers_mask
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def supported_ciphers(self):
 | 
			
		||||
			ciphers = []
 | 
			
		||||
			for i in range(len(SSH1.CIPHERS)):
 | 
			
		||||
				if self.__supported_ciphers_mask & (1 << i) != 0:
 | 
			
		||||
					ciphers.append(SSH1.CIPHERS[i])
 | 
			
		||||
			return ciphers
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def supported_authentications_mask(self):
 | 
			
		||||
			return self.__supported_authentications_mask
 | 
			
		||||
		
 | 
			
		||||
		@property
 | 
			
		||||
		def supported_authentications(self):
 | 
			
		||||
			auths = []
 | 
			
		||||
			for i in range(len(SSH1.AUTHS)):
 | 
			
		||||
				if self.__supported_authentications_mask & (1 << i) != 0:
 | 
			
		||||
					auths.append(SSH1.AUTHS[i])
 | 
			
		||||
			return auths
 | 
			
		||||
		
 | 
			
		||||
		@classmethod
 | 
			
		||||
		def parse(cls, payload):
 | 
			
		||||
			buf = ReadBuf(payload)
 | 
			
		||||
			cookie = buf.read(8)
 | 
			
		||||
			server_key_bits = buf.read_int()
 | 
			
		||||
			server_key_exponent = buf.read_mpint1()
 | 
			
		||||
			server_key_modulus = buf.read_mpint1()
 | 
			
		||||
			skey = (server_key_bits, server_key_exponent, server_key_modulus)
 | 
			
		||||
			host_key_bits = buf.read_int()
 | 
			
		||||
			host_key_exponent = buf.read_mpint1()
 | 
			
		||||
			host_key_modulus = buf.read_mpint1()
 | 
			
		||||
			hkey = (host_key_bits, host_key_exponent, host_key_modulus)
 | 
			
		||||
			pflags = buf.read_int()
 | 
			
		||||
			cmask = buf.read_int()
 | 
			
		||||
			amask = buf.read_int()
 | 
			
		||||
			pkm = cls(cookie, skey, hkey, pflags, cmask, amask)
 | 
			
		||||
			return pkm
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReadBuf(object):
 | 
			
		||||
	def __init__(self, data=None):
 | 
			
		||||
		super(ReadBuf, self).__init__()
 | 
			
		||||
@@ -269,6 +383,7 @@ class WriteBuf(object):
 | 
			
		||||
 | 
			
		||||
class SSH(object):
 | 
			
		||||
	class Protocol(object):
 | 
			
		||||
		SMSG_PUBLIC_KEY = 2
 | 
			
		||||
		MSG_KEXINIT     = 20
 | 
			
		||||
		MSG_NEWKEYS     = 21
 | 
			
		||||
		MSG_KEXDH_INIT  = 30
 | 
			
		||||
@@ -516,6 +631,9 @@ class SSH(object):
 | 
			
		||||
		}
 | 
			
		||||
	
 | 
			
		||||
	class Socket(ReadBuf, WriteBuf):
 | 
			
		||||
		class InsufficientReadException(Exception):
 | 
			
		||||
			pass
 | 
			
		||||
		
 | 
			
		||||
		SM_BANNER_SENT = 1
 | 
			
		||||
		
 | 
			
		||||
		def __init__(self, host, port, cto=3.0, rto=5.0):
 | 
			
		||||
@@ -534,7 +652,8 @@ class SSH(object):
 | 
			
		||||
		def __enter__(self):
 | 
			
		||||
			return self
 | 
			
		||||
		
 | 
			
		||||
		def get_banner(self):
 | 
			
		||||
		def get_banner(self, sshv=2):
 | 
			
		||||
			banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0')
 | 
			
		||||
			rto = self.__sock.gettimeout()
 | 
			
		||||
			self.__sock.settimeout(0.7)
 | 
			
		||||
			s, e = self.recv()
 | 
			
		||||
@@ -542,7 +661,7 @@ class SSH(object):
 | 
			
		||||
			if s < 0:
 | 
			
		||||
				return self.__banner, self.__header
 | 
			
		||||
			if self.__state < self.SM_BANNER_SENT:
 | 
			
		||||
				self.send_banner()
 | 
			
		||||
				self.send_banner(banner)
 | 
			
		||||
			while self.__banner is None:
 | 
			
		||||
				if not s > 0:
 | 
			
		||||
					s, e = self.recv()
 | 
			
		||||
@@ -586,41 +705,67 @@ class SSH(object):
 | 
			
		||||
				return (-1, e)
 | 
			
		||||
			self.__sock.send(data)
 | 
			
		||||
		
 | 
			
		||||
		def send_banner(self, banner=SSH_BANNER):
 | 
			
		||||
		def send_banner(self, banner):
 | 
			
		||||
			self.send(banner.encode() + b'\r\n')
 | 
			
		||||
			if self.__state < self.SM_BANNER_SENT:
 | 
			
		||||
				self.__state = self.SM_BANNER_SENT
 | 
			
		||||
		
 | 
			
		||||
		def read_packet(self):
 | 
			
		||||
			while self.unread_len < self.__block_size:
 | 
			
		||||
		def ensure_read(self, size):
 | 
			
		||||
			while self.unread_len < size:
 | 
			
		||||
				s, e = self.recv()
 | 
			
		||||
				if s < 0:
 | 
			
		||||
					if e is None:
 | 
			
		||||
						e = self.read(self.unread_len).strip()
 | 
			
		||||
					return -1, e
 | 
			
		||||
			header = self.read(self.__block_size)
 | 
			
		||||
			if len(header) == 0:
 | 
			
		||||
				out.fail('[exception] empty ssh packet (no data)')
 | 
			
		||||
				sys.exit(1)
 | 
			
		||||
			packet_size = struct.unpack('>I', header[:4])[0]
 | 
			
		||||
			rest = header[4:]
 | 
			
		||||
			lrest = len(rest)
 | 
			
		||||
			padding = ord(rest[0:1])
 | 
			
		||||
			packet_type = ord(rest[1:2])
 | 
			
		||||
			if (packet_size - lrest) % self.__block_size != 0:
 | 
			
		||||
					raise SSH.Socket.InsufficientReadException(e)
 | 
			
		||||
		
 | 
			
		||||
		def read_packet(self, sshv=2):
 | 
			
		||||
			try:
 | 
			
		||||
				header = WriteBuf()
 | 
			
		||||
				self.ensure_read(4)
 | 
			
		||||
				packet_length = self.read_int()
 | 
			
		||||
				header.write_int(packet_length)
 | 
			
		||||
				# XXX: validate length
 | 
			
		||||
				if sshv == 1:
 | 
			
		||||
					padding_length = (8 - packet_length % 8)
 | 
			
		||||
					self.ensure_read(padding_length)
 | 
			
		||||
					padding = self.read(padding_length)
 | 
			
		||||
					header.write(padding)
 | 
			
		||||
					payload_length = packet_length
 | 
			
		||||
					check_size = padding_length + payload_length
 | 
			
		||||
				else:
 | 
			
		||||
					self.ensure_read(1)
 | 
			
		||||
					padding_length = self.read_byte()
 | 
			
		||||
					header.write_byte(padding_length)
 | 
			
		||||
					payload_length = packet_length - padding_length - 1
 | 
			
		||||
					check_size = 4 + 1 + payload_length + padding_length
 | 
			
		||||
				if check_size % self.__block_size != 0:
 | 
			
		||||
					out.fail('[exception] invalid ssh packet (block size)')
 | 
			
		||||
					sys.exit(1)
 | 
			
		||||
			rlen = packet_size - lrest
 | 
			
		||||
			while self.unread_len < rlen:
 | 
			
		||||
				s, e = self.recv()
 | 
			
		||||
				if s < 0:
 | 
			
		||||
					if e is None:
 | 
			
		||||
						e = (header + self.read(self.unread_len)).strip()
 | 
			
		||||
					return -1, e
 | 
			
		||||
			buf = self.read(rlen)
 | 
			
		||||
			packet = rest[2:] + buf[0:packet_size - lrest]
 | 
			
		||||
			payload = packet[0:packet_size - padding]
 | 
			
		||||
				self.ensure_read(payload_length)
 | 
			
		||||
				if sshv == 1:
 | 
			
		||||
					payload = self.read(payload_length - 4)
 | 
			
		||||
					header.write(payload)
 | 
			
		||||
					crc = self.read_int()
 | 
			
		||||
					header.write_int(crc)
 | 
			
		||||
				else:
 | 
			
		||||
					payload = self.read(payload_length)
 | 
			
		||||
					header.write(payload)
 | 
			
		||||
				packet_type = ord(payload[0:1])
 | 
			
		||||
				if sshv == 1:
 | 
			
		||||
					rcrc = SSH1.crc32(padding + payload)
 | 
			
		||||
					if crc != rcrc:
 | 
			
		||||
						out.fail('[exception] packet checksum CRC32 mismatch.')
 | 
			
		||||
						sys.exit(1)
 | 
			
		||||
				else:
 | 
			
		||||
					self.ensure_read(padding_length)
 | 
			
		||||
					padding = self.read(padding_length)
 | 
			
		||||
				payload = payload[1:]
 | 
			
		||||
				return packet_type, payload
 | 
			
		||||
			except SSH.Socket.InsufficientReadException as ex:
 | 
			
		||||
				if ex.args[0] is None:
 | 
			
		||||
					header.write(self.read(self.unread_len))
 | 
			
		||||
					e = header.write_flush().strip()
 | 
			
		||||
				else:
 | 
			
		||||
					e = ex.args[0]
 | 
			
		||||
				return (-1, e)
 | 
			
		||||
		
 | 
			
		||||
		def send_packet(self):
 | 
			
		||||
			payload = self.write_flush()
 | 
			
		||||
@@ -961,13 +1106,14 @@ def output_security(banner, padlen):
 | 
			
		||||
		out.sep()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def output(banner, header, kex):
 | 
			
		||||
def output(banner, header, kex=None, pkm=None):
 | 
			
		||||
	sshv = 1 if pkm else 2
 | 
			
		||||
	with OutputBuffer() as obuf:
 | 
			
		||||
		if len(header) > 0:
 | 
			
		||||
			out.info('(gen) header: ' + '\n'.join(header))
 | 
			
		||||
		if banner is not None:
 | 
			
		||||
			out.good('(gen) banner: {0}'.format(banner))
 | 
			
		||||
			if banner.protocol[0] == 1:
 | 
			
		||||
			if sshv == 1 or banner.protocol[0] == 1:
 | 
			
		||||
				out.fail('(gen) protocol SSH1 enabled')
 | 
			
		||||
			software = SSH.Software.parse(banner)
 | 
			
		||||
			if software is not None:
 | 
			
		||||
@@ -1045,26 +1191,44 @@ def parse_args():
 | 
			
		||||
	return host, port
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
	host, port = parse_args()
 | 
			
		||||
def audit(host, port, sshv=2):
 | 
			
		||||
	s = SSH.Socket(host, port)
 | 
			
		||||
	err = None
 | 
			
		||||
	banner, header = s.get_banner()
 | 
			
		||||
	banner, header = s.get_banner(sshv)
 | 
			
		||||
	if banner is None:
 | 
			
		||||
		err = '[exception] did not receive banner.'
 | 
			
		||||
	if err is None:
 | 
			
		||||
		packet_type, payload = s.read_packet()
 | 
			
		||||
		packet_type, payload = s.read_packet(sshv)
 | 
			
		||||
		if packet_type < 0:
 | 
			
		||||
			if payload == b'Protocol major versions differ.':
 | 
			
		||||
				if sshv == 2:
 | 
			
		||||
					audit(host, port, 1)
 | 
			
		||||
					return
 | 
			
		||||
			err = '[exception] error reading packet ({0})'.format(payload)
 | 
			
		||||
		elif packet_type != SSH.Protocol.MSG_KEXINIT:
 | 
			
		||||
			err = '[exception] did not receive MSG_KEXINIT (20), ' + \
 | 
			
		||||
			      'instead received unknown message ({0})'.format(packet_type)
 | 
			
		||||
		else:
 | 
			
		||||
			if sshv == 1 and packet_type != SSH.Protocol.SMSG_PUBLIC_KEY:
 | 
			
		||||
				err = ('SMSG_PUBLIC_KEY', SSH.Protocol.SMSG_PUBLIC_KEY)
 | 
			
		||||
			elif sshv == 2 and packet_type != SSH.Protocol.MSG_KEXINIT:
 | 
			
		||||
				err = ('MSG_KEXINIT', SSH.Protocol.MSG_KEXINIT)
 | 
			
		||||
			if err is not None:
 | 
			
		||||
				fmt = '[exception] did not receive {0} ({1}), ' + \
 | 
			
		||||
				      'instead received unknown message ({2})'
 | 
			
		||||
				err = fmt.format(err[0], err[1], packet_type)
 | 
			
		||||
	if err:
 | 
			
		||||
		output(banner, header, None)
 | 
			
		||||
		output(banner, header)
 | 
			
		||||
		out.fail(err)
 | 
			
		||||
		sys.exit(1)
 | 
			
		||||
	if sshv == 1:
 | 
			
		||||
		pkm = SSH1.PublicKeyMessage.parse(payload)
 | 
			
		||||
		output(banner, header, pkm=pkm)
 | 
			
		||||
	elif sshv == 2:
 | 
			
		||||
		kex = Kex.parse(payload)
 | 
			
		||||
	output(banner, header, kex)
 | 
			
		||||
		output(banner, header, kex=kex)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
	host, port = parse_args()
 | 
			
		||||
	audit(host, port)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user