diff --git a/ssh-audit.py b/ssh-audit.py index 4e64ec5..7d6d641 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -25,7 +25,7 @@ THE SOFTWARE. """ from __future__ import print_function -import binascii, os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 +import base64, binascii, errno, hashlib, getopt, io, os, random, re, select, socket, struct, sys VERSION = 'v2.1.0-dev' SSH_HEADER = 'SSH-{0}-OpenSSH_8.0' # SSH software to impersonate @@ -1957,7 +1957,7 @@ class SSH(object): # pylint: disable=too-few-public-methods # type: (Optional[str], int) -> None super(SSH.Socket, self).__init__() self.__sock = None # type: Optional[socket.socket] - self.__sock_server = None + self.__sock_map = {} self.__block_size = 8 self.__state = 0 self.__header = [] # type: List[text_type] @@ -2003,15 +2003,27 @@ class SSH(object): # pylint: disable=too-few-public-methods # Listens on a server socket and accepts one connection (used for # auditing client connections). def listen_and_accept(self): + + # Socket to listen on all IPv4 addresses. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.__sock_server = s - - # TODO: listen on IPv6 address if necessary. s.bind(('0.0.0.0', self.__port)) s.listen() + self.__sock_map[s.fileno()] = s - c, addr = s.accept() + # Socket to listen on all IPv6 addresses. + s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) + s.bind(('::', self.__port)) + s.listen() + self.__sock_map[s.fileno()] = s + + # Wait for a connection on either socket. + fds = select.select(self.__sock_map.keys(), [], []) + + # Accept the connection. + c, addr = self.__sock_map[fds[0][0]].accept() self.client_host = addr[0] self.client_port = addr[1] c.settimeout(self.__timeout) @@ -2209,7 +2221,8 @@ class SSH(object): # pylint: disable=too-few-public-methods def __cleanup(self): # type: () -> None self._close_socket(self.__sock) - self._close_socket(self.__sock_server) + for fd in self.__sock_map: + self._close_socket(self.__sock_map[fd]) self.__sock = None