diff --git a/ssh-audit.py b/ssh-audit.py index 321ac00..3193674 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -117,16 +117,50 @@ class ReadBuf(object): list_size = self.read_int() return self.read(list_size).decode().split(',') +class WriteBuf(object): + def __init__(self, data = None): + super(WriteBuf, self).__init__() + self._wbuf = io.BytesIO(data) if data else io.BytesIO() + + def write(self, data): + self._wbuf.write(data) + + def write_byte(self, v): + self.write(struct.pack('>B', v)) + + def write_bool(self, v): + self.write_byte(1 if v else 0) + + def write_int(self, v): + self.write(struct.pack('>I', v)) + + def write_string(self, v): + if not isinstance(v, bytes): + v = bytes(bytearray(v, 'utf-8')) + n = len(v) + self.write(struct.pack('>I', n)) + self.write(v) + + def write_list(self, v): + self.write_string(','.join(v)) + + def write_mpint(self, v): + length = v.bit_length() // 8 + 1 + data = [(v >> i*8) & 0xff for i in reversed(range(length))] + data = bytes(bytearray(data)) + self.write_string(data) + class SSH(object): MSG_KEXINIT = 20 MSG_NEWKEYS = 21 MSG_KEXDH_INIT = 30 MSG_KEXDH_REPLY = 32 - class Socket(ReadBuf): + class Socket(ReadBuf, WriteBuf): SM_BANNER_SENT = 1 def __init__(self, host, port, cto = 3.0, rto = 5.0): + self.__block_size = 8 self.__state = 0 self.__banner = None super(SSH.Socket, self).__init__() @@ -165,16 +199,15 @@ class SSH(object): self.__state = self.SM_BANNER_SENT def read_packet(self): - block_size = 8 - if self.unread_len < block_size: + if self.unread_len < self.__block_size: self.recv() - header = self.read(block_size) + header = self.read(self.__block_size) packet_size = struct.unpack('>I', header[:4])[0] rest = header[4:] lrest = len(rest) padding = ord(rest[0:1]) packet_type = ord(rest[1:2]) - if (packet_size - lrest) % block_size != 0: + if (packet_size - lrest) % self.__block_size != 0: out.fail('[exception] invalid ssh packet (block size)') sys.exit(1) rlen = packet_size - lrest @@ -185,6 +218,18 @@ class SSH(object): payload = packet[0:packet_size - padding] return packet_type, payload + def send_packet(self): + payload = self._wbuf.getvalue() + self._wbuf.truncate(0) + self._wbuf.seek(0) + padding = -(len(payload) + 5) % 8 + if padding < 4: + padding += 8 + plen = len(payload) + padding + 1 + pad_bytes = '\x00' * padding + data = struct.pack('>Ib', plen, padding) + payload + pad_bytes + self.send(data) + def __del__(self): self.__cleanup()