mirror of
				https://github.com/jtesta/ssh-audit.git
				synced 2025-11-03 18:52:15 +01:00 
			
		
		
		
	Refactored IPv4/6 preference logic to fix pylint warnings.
This commit is contained in:
		@@ -43,7 +43,7 @@ class AuditConf:
 | 
			
		||||
        self.json = False
 | 
			
		||||
        self.verbose = False
 | 
			
		||||
        self.level = 'info'
 | 
			
		||||
        self.ipvo: Sequence[int] = ()
 | 
			
		||||
        self.ip_version_preference: List[int] = []  # Holds only 5 possible values: [] (no preference), [4] (use IPv4 only), [6] (use IPv6 only), [46] (use both IPv4 and IPv6, but prioritize v4), and [64] (use both IPv4 and IPv6, but prioritize v6).
 | 
			
		||||
        self.ipv4 = False
 | 
			
		||||
        self.ipv6 = False
 | 
			
		||||
        self.make_policy = False  # When True, creates a policy file from an audit scan.
 | 
			
		||||
@@ -60,28 +60,14 @@ class AuditConf:
 | 
			
		||||
 | 
			
		||||
    def __setattr__(self, name: str, value: Union[str, int, float, bool, Sequence[int]]) -> None:
 | 
			
		||||
        valid = False
 | 
			
		||||
        if name in ['ssh1', 'ssh2', 'batch', 'client_audit', 'colors', 'verbose', 'timeout_set', 'json', 'make_policy', 'list_policies', 'manual']:
 | 
			
		||||
        if name in ['batch', 'client_audit', 'colors', 'json', 'list_policies', 'manual', 'make_policy', 'ssh1', 'ssh2', 'timeout_set', 'verbose']:
 | 
			
		||||
            valid, value = True, bool(value)
 | 
			
		||||
        elif name in ['ipv4', 'ipv6']:
 | 
			
		||||
            valid = False
 | 
			
		||||
            value = bool(value)
 | 
			
		||||
            ipv = 4 if name == 'ipv4' else 6
 | 
			
		||||
            if value:
 | 
			
		||||
                value = tuple(list(self.ipvo) + [ipv])
 | 
			
		||||
            else:  # pylint: disable=else-if-used
 | 
			
		||||
                if len(self.ipvo) == 0:
 | 
			
		||||
                    value = (6,) if ipv == 4 else (4,)
 | 
			
		||||
                else:
 | 
			
		||||
                    value = tuple([x for x in self.ipvo if x != ipv])
 | 
			
		||||
            self.__setattr__('ipvo', value)
 | 
			
		||||
        elif name == 'ipvo':
 | 
			
		||||
            if isinstance(value, (tuple, list)):
 | 
			
		||||
                uniq_value = Utils.unique_seq(value)
 | 
			
		||||
                value = tuple([x for x in uniq_value if x in (4, 6)])
 | 
			
		||||
                valid = True
 | 
			
		||||
                ipv_both = len(value) == 0
 | 
			
		||||
                object.__setattr__(self, 'ipv4', ipv_both or 4 in value)
 | 
			
		||||
                object.__setattr__(self, 'ipv6', ipv_both or 6 in value)
 | 
			
		||||
            valid, value = True, bool(value)
 | 
			
		||||
            if len(self.ip_version_preference) == 2:  # Being called more than twice is not valid.
 | 
			
		||||
                valid = False
 | 
			
		||||
            elif value:
 | 
			
		||||
                self.ip_version_preference.append(4 if name == 'ipv4' else 6)
 | 
			
		||||
        elif name == 'port':
 | 
			
		||||
            valid, port = True, Utils.parse_int(value)
 | 
			
		||||
            if port < 1 or port > 65535:
 | 
			
		||||
@@ -98,7 +84,7 @@ class AuditConf:
 | 
			
		||||
            if value == -1.0:
 | 
			
		||||
                raise ValueError('invalid timeout: {}'.format(value))
 | 
			
		||||
            valid = True
 | 
			
		||||
        elif name in ['policy_file', 'policy', 'target_file', 'target_list', 'lookup']:
 | 
			
		||||
        elif name in ['ip_version_preference', 'lookup', 'policy_file', 'policy', 'target_file', 'target_list']:
 | 
			
		||||
            valid = True
 | 
			
		||||
        elif name == "threads":
 | 
			
		||||
            valid, num_threads = True, Utils.parse_int(value)
 | 
			
		||||
 
 | 
			
		||||
@@ -815,7 +815,7 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print
 | 
			
		||||
    out.verbose = aconf.verbose
 | 
			
		||||
    out.level = aconf.level
 | 
			
		||||
    out.use_colors = aconf.colors
 | 
			
		||||
    s = SSH_Socket(aconf.host, aconf.port, aconf.ipvo, aconf.timeout, aconf.timeout_set)
 | 
			
		||||
    s = SSH_Socket(aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set)
 | 
			
		||||
    if aconf.client_audit:
 | 
			
		||||
        out.v("Listening for client connection on port %d..." % aconf.port, write_now=True)
 | 
			
		||||
        s.listen_and_accept()
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
"""
 | 
			
		||||
   The MIT License (MIT)
 | 
			
		||||
 | 
			
		||||
   Copyright (C) 2017-2020 Joe Testa (jtesta@positronsecurity.com)
 | 
			
		||||
   Copyright (C) 2017-2021 Joe Testa (jtesta@positronsecurity.com)
 | 
			
		||||
   Copyright (C) 2017 Andris Raugulis (moo@arthepsy.eu)
 | 
			
		||||
 | 
			
		||||
   Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
			
		||||
@@ -52,7 +52,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
 | 
			
		||||
    SM_BANNER_SENT = 1
 | 
			
		||||
 | 
			
		||||
    def __init__(self, host: Optional[str], port: int, ipvo: Optional[Sequence[int]] = None, timeout: Union[int, float] = 5, timeout_set: bool = False) -> None:
 | 
			
		||||
    def __init__(self, host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None:  # pylint: disable=dangerous-default-value
 | 
			
		||||
        super(SSH_Socket, self).__init__()
 | 
			
		||||
        self.__sock: Optional[socket.socket] = None
 | 
			
		||||
        self.__sock_map: Dict[int, socket.socket] = {}
 | 
			
		||||
@@ -67,32 +67,27 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
            raise ValueError('invalid port: {}'.format(port))
 | 
			
		||||
        self.__host = host
 | 
			
		||||
        self.__port = nport
 | 
			
		||||
        if ipvo is not None:
 | 
			
		||||
            self.__ipvo = ipvo
 | 
			
		||||
        else:
 | 
			
		||||
            self.__ipvo = ()
 | 
			
		||||
        self.__ip_version_preference = ip_version_preference  # Holds only 5 possible values: [] (no preference), [4] (use IPv4 only), [6] (use IPv6 only), [46] (use both IPv4 and IPv6, but prioritize v4), and [64] (use both IPv4 and IPv6, but prioritize v6).
 | 
			
		||||
        self.__timeout = timeout
 | 
			
		||||
        self.__timeout_set = timeout_set
 | 
			
		||||
        self.client_host: Optional[str] = None
 | 
			
		||||
        self.client_port = None
 | 
			
		||||
 | 
			
		||||
    def _resolve(self, ipvo: Sequence[int]) -> Iterable[Tuple[int, Tuple[Any, ...]]]:
 | 
			
		||||
        ipvo = tuple([x for x in Utils.unique_seq(ipvo) if x in (4, 6)])
 | 
			
		||||
        ipvo_len = len(ipvo)
 | 
			
		||||
        prefer_ipvo = ipvo_len > 0
 | 
			
		||||
        prefer_ipv4 = prefer_ipvo and ipvo[0] == 4
 | 
			
		||||
        if ipvo_len == 1:
 | 
			
		||||
            family = socket.AF_INET if ipvo[0] == 4 else socket.AF_INET6
 | 
			
		||||
    def _resolve(self) -> Iterable[Tuple[int, Tuple[Any, ...]]]:
 | 
			
		||||
        # If __ip_version_preference has only one entry, then it means that ONLY that IP version should be used.
 | 
			
		||||
        if len(self.__ip_version_preference) == 1:
 | 
			
		||||
            family = socket.AF_INET if self.__ip_version_preference[0] == 4 else socket.AF_INET6
 | 
			
		||||
        else:
 | 
			
		||||
            family = socket.AF_UNSPEC
 | 
			
		||||
        try:
 | 
			
		||||
            stype = socket.SOCK_STREAM
 | 
			
		||||
            r = socket.getaddrinfo(self.__host, self.__port, family, stype)
 | 
			
		||||
            if prefer_ipvo:
 | 
			
		||||
                r = sorted(r, key=lambda x: x[0], reverse=not prefer_ipv4)
 | 
			
		||||
            check = any(stype == rline[2] for rline in r)
 | 
			
		||||
 | 
			
		||||
            # If the user has a preference for using IPv4 over IPv6 (or vice-versa), then sort the list returned by getaddrinfo() so that the preferred address type comes first.
 | 
			
		||||
            if len(self.__ip_version_preference) == 2:
 | 
			
		||||
                r = sorted(r, key=lambda x: x[0], reverse=(self.__ip_version_preference[0] == 6))
 | 
			
		||||
            for af, socktype, _proto, _canonname, addr in r:
 | 
			
		||||
                if not check or socktype == socket.SOCK_STREAM:
 | 
			
		||||
                if socktype == socket.SOCK_STREAM:
 | 
			
		||||
                    yield af, addr
 | 
			
		||||
        except socket.error as e:
 | 
			
		||||
            OutputBuffer().fail('[exception] {}'.format(e)).write()
 | 
			
		||||
@@ -156,7 +151,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
 | 
			
		||||
    def connect(self) -> Optional[str]:
 | 
			
		||||
        '''Returns None on success, or an error string.'''
 | 
			
		||||
        err = None
 | 
			
		||||
        for af, addr in self._resolve(self.__ipvo):
 | 
			
		||||
        for af, addr in self._resolve():
 | 
			
		||||
            s = None
 | 
			
		||||
            try:
 | 
			
		||||
                s = socket.socket(af, socket.SOCK_STREAM)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user