diff --git a/ssh-audit.py b/ssh-audit.py index d917d44..321ac00 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -124,7 +124,11 @@ class SSH(object): MSG_KEXDH_REPLY = 32 class Socket(ReadBuf): + SM_BANNER_SENT = 1 + def __init__(self, host, port, cto = 3.0, rto = 5.0): + self.__state = 0 + self.__banner = None super(SSH.Socket, self).__init__() try: self.__sock = socket.create_connection((host, port), cto) @@ -136,6 +140,14 @@ class SSH(object): def __enter__(self): return self + def get_banner(self): + if self.__state < self.SM_BANNER_SENT: + self.send_banner() + if self.__banner is None: + self.recv() + self.__banner = self.read_line() + return self.__banner + def recv(self, size = 2048): data = self.__sock.recv(size) pos = self._buf.tell() @@ -147,6 +159,11 @@ class SSH(object): def send(self, data): self.__sock.send(data) + def send_banner(self, banner = SSH_BANNER): + self.send(banner.encode() + b'\r\n') + if self.__state < self.SM_BANNER_SENT: + self.__state = self.SM_BANNER_SENT + def read_packet(self): block_size = 8 if self.unread_len < block_size: @@ -385,9 +402,7 @@ def parse_args(): def main(): host, port = parse_args() s = SSH.Socket(host, port) - s.send(SSH_BANNER.encode() + b'\r\n') - s.recv() - banner = s.read_line() + banner = s.get_banner() out.head('# general') out.good('[info] banner: ' + banner) if banner.startswith('SSH-1.99-'):