diff --git a/src/ssh_audit/policy.py b/src/ssh_audit/policy.py index 2463c42..fdb59c0 100644 --- a/src/ssh_audit/policy.py +++ b/src/ssh_audit/policy.py @@ -48,13 +48,9 @@ class Policy: self._compressions: Optional[List[str]] = None self._host_keys: Optional[List[str]] = None self._optional_host_keys: Optional[List[str]] = None - self._allowed_host_keys: Optional[List[str]] = None self._kex: Optional[List[str]] = None - self._allowed_kex: Optional[List[str]] = None self._ciphers: Optional[List[str]] = None - self._allowed_ciphers: Optional[List[str]] = None self._macs: Optional[List[str]] = None - self._allowed_macs: Optional[List[str]] = None self._hostkey_sizes: Optional[Dict[str, Dict[str, Union[int, str, bytes]]]] = None self._dh_modulus_sizes: Optional[Dict[str, int]] = None self._server_policy = True @@ -117,7 +113,7 @@ class Policy: key = key.strip() val = val.strip() - if key not in ['name', 'version', 'banner', 'compressions', 'host keys', 'optional host keys', 'allowed host keys', 'key exchanges', 'allowed key exchanges', 'ciphers', 'allowed ciphers', 'macs', 'allowed macs', 'client policy', 'host_key_sizes', 'dh_modulus_sizes', 'allow_algorithm_subset_and_reordering'] and not key.startswith('hostkey_size_') and not key.startswith('cakey_size_') and not key.startswith('dh_modulus_size_'): + if key not in ['name', 'version', 'banner', 'compressions', 'host keys', 'optional host keys', 'key exchanges', 'ciphers', 'macs', 'client policy', 'host_key_sizes', 'dh_modulus_sizes', 'allow_algorithm_subset_and_reordering'] and not key.startswith('hostkey_size_') and not key.startswith('cakey_size_') and not key.startswith('dh_modulus_size_'): raise ValueError("invalid field found in policy: %s" % line) if key in ['name', 'banner']: @@ -140,7 +136,7 @@ class Policy: elif key == 'version': self._version = val - elif key in ['compressions', 'host keys', 'optional host keys', 'allowed host keys', 'key exchanges', 'allowed key exchanges', 'ciphers', 'allowed ciphers', 'macs', 'allowed macs']: + elif key in ['compressions', 'host keys', 'optional host keys', 'key exchanges', 'ciphers', 'macs']: try: algs = val.split(',') except ValueError: @@ -155,21 +151,13 @@ class Policy: elif key == 'host keys': self._host_keys = algs elif key == 'optional host keys': - self._optional_host_keys = algs - elif key == 'allowed host keys': - self._allowed_host_keys = algs + self._optional_host_keys = algs elif key == 'key exchanges': self._kex = algs - elif key == 'allowed key exchanges': - self._allowed_kex = algs elif key == 'ciphers': self._ciphers = algs - elif key == 'allowed ciphers': - self._allowed_ciphers = algs elif key == 'macs': self._macs = algs - elif key == 'allowed macs': - self._allowed_macs = algs elif key.startswith('hostkey_size_'): # Old host key size format. print(Policy.WARNING_DEPRECATED_DIRECTIVES, file=self._warning_target) # Warn the user that the policy file is using deprecated directives. @@ -230,14 +218,12 @@ class Policy: @staticmethod - def _append_error(errors: List[Any], mismatched_field: str, expected_required: Optional[List[str]], expected_allowed: Optional[List[str]], expected_optional: Optional[List[str]], actual: List[str]) -> None: + def _append_error(errors: List[Any], mismatched_field: str, expected_required: Optional[List[str]], expected_optional: Optional[List[str]], actual: List[str]) -> None: if expected_required is None: expected_required = [''] if expected_optional is None: expected_optional = [''] - if expected_allowed is None: - expected_allowed = [''] - errors.append({'mismatched_field': mismatched_field, 'expected_required': expected_required, 'expected_allowed': expected_allowed, 'expected_optional': expected_optional, 'actual': actual}) + errors.append({'mismatched_field': mismatched_field, 'expected_required': expected_required, 'expected_optional': expected_optional, 'actual': actual}) def _normalize_hostkey_sizes(self) -> None: @@ -344,42 +330,32 @@ macs = %s banner_str = str(banner) if (self._banner is not None) and (banner_str != self._banner): ret = False - self._append_error(errors, 'Banner', [self._banner], None, None, [banner_str]) + self._append_error(errors, 'Banner', [self._banner], None, [banner_str]) # All subsequent tests require a valid kex, so end here if we don't have one. if kex is None: - return ret, errors, self._get_error_str(errors) + return ret, errors, self._get_error_str(errors, self._allow_algorithm_subset_and_reordering) if (self._compressions is not None) and (kex.server.compression != self._compressions): ret = False - self._append_error(errors, 'Compression', self._compressions, None, None, kex.server.compression) + self._append_error(errors, 'Compression', self._compressions, None, kex.server.compression) # If a list of optional host keys was given in the policy, remove any of its entries from the list retrieved from the server. This allows us to do an exact comparison with the expected list below. pruned_host_keys = kex.key_algorithms if self._optional_host_keys is not None: pruned_host_keys = [x for x in kex.key_algorithms if x not in self._optional_host_keys] - - # Checking allowed Hostkeys - hostkey_error = False - if self._allowed_host_keys is not None: - for hostkey_t in kex.key_algorithms: - if hostkey_t not in self._allowed_host_keys: - self._append_error(errors, 'Host keys', self._host_keys, self._allowed_host_keys, self._optional_host_keys, kex.key_algorithms) - ret = False - hostkey_error = True - # Checking required Hostkeys + # Checking Hostkeys if self._host_keys is not None: if self._allow_algorithm_subset_and_reordering: - for hostkey_t in self._host_keys: - if hostkey_t not in kex.key_algorithms: + for hostkey_t in kex.key_algorithms: + if hostkey_t not in self._host_keys: ret = False - if not hostkey_error: - self._append_error(errors, 'Host keys', self._host_keys, None, self._optional_host_keys, kex.key_algorithms) + self._append_error(errors, 'Host keys', self._host_keys, None, kex.key_algorithms) break elif pruned_host_keys != self._host_keys: ret = False - self._append_error(errors, 'Host keys', self._host_keys, None, self._optional_host_keys, kex.key_algorithms) + self._append_error(errors, 'Host keys', self._host_keys, None, kex.key_algorithms) # Checking Host Key Sizes if self._hostkey_sizes is not None: @@ -392,7 +368,7 @@ macs = %s actual_hostkey_size = server_host_keys[hostkey_type]['hostkey_size'] if actual_hostkey_size != expected_hostkey_size: ret = False - self._append_error(errors, 'Host key (%s) sizes' % hostkey_type, [str(expected_hostkey_size)], None, None, [str(actual_hostkey_size)]) + self._append_error(errors, 'Host key (%s) sizes' % hostkey_type, [str(expected_hostkey_size)], None, [str(actual_hostkey_size)]) # If we have expected CA signatures set, check them against what the server returned. if self._hostkey_sizes is not None and len(cast(str, self._hostkey_sizes[hostkey_type]['ca_key_type'])) > 0 and cast(int, self._hostkey_sizes[hostkey_type]['ca_key_size']) > 0: @@ -404,80 +380,47 @@ macs = %s # Ensure that the CA signature type is what's expected (i.e.: the server doesn't have an RSA sig when we're expecting an ED25519 sig). if actual_ca_key_type != expected_ca_key_type: ret = False - self._append_error(errors, 'CA signature type', [expected_ca_key_type], None, None, [actual_ca_key_type]) + self._append_error(errors, 'CA signature type', [expected_ca_key_type], None, [actual_ca_key_type]) # Ensure that the actual and expected signature sizes match. elif actual_ca_key_size != expected_ca_key_size: ret = False - self._append_error(errors, 'CA signature size (%s)' % actual_ca_key_type, [str(expected_ca_key_size)], None, None, [str(actual_ca_key_size)]) - - # Checking allowed KEX - kex_error = False - if self._allowed_kex is not None: - for kex_t in kex.kex_algorithms: - if kex_t not in self._allowed_kex: - self._append_error(errors, 'Kex Exchanges', None, self._allowed_kex, None, kex.kex_algorithms) - ret = False - kex_error = True - break + self._append_error(errors, 'CA signature size (%s)' % actual_ca_key_type, [str(expected_ca_key_size)], None, [str(actual_ca_key_size)]) - # Checking required KEX + # Checking KEX if self._kex is not None: if self._allow_algorithm_subset_and_reordering: - for kex_t in self._kex: - if kex_t not in kex.kex_algorithms: + for kex_t in kex.kex_algorithms: + if kex_t not in self._kex: ret = False - if not kex_error: - self._append_error(errors, 'Key exchanges', self._kex, None, None, kex.kex_algorithms) + self._append_error(errors, 'Key exchanges', self._kex, None, kex.kex_algorithms) break elif kex.kex_algorithms != self._kex: # Requires perfect match ret = False - self._append_error(errors, 'Key exchanges', self._kex, None, None, kex.kex_algorithms) - - # Checking allowed Ciphers - cipher_error = False - if self._allowed_ciphers is not None: - for cipher_t in kex.server.encryption: - if cipher_t not in self._allowed_ciphers: - self._append_error(errors, 'Ciphers', self._ciphers, self._allowed_ciphers, None, kex.server.encryption) - ret = False - cipher_error = True - break + self._append_error(errors, 'Key exchanges', self._kex, None, kex.kex_algorithms) - # Checking required Ciphers + # Checking Ciphers if self._ciphers is not None: if self._allow_algorithm_subset_and_reordering: - for cipher_t in self._ciphers: - if cipher_t not in kex.server.encryption: + for cipher_t in kex.server.encryption: + if cipher_t not in self._ciphers: ret = False - if not cipher_error: - self._append_error(errors, 'Ciphers', self._ciphers, None, None, kex.server.encryption) + self._append_error(errors, 'Ciphers', self._ciphers, None, kex.server.encryption) break elif kex.server.encryption != self._ciphers: # Requires perfect match ret = False - self._append_error(errors, 'Ciphers', self._ciphers, None, None, kex.server.encryption) - - # Checking allowed MACs - mac_error = False - if self._allowed_macs is not None: - for mac_t in kex.server.mac: - if mac_t not in self._allowed_macs: - ret = False - mac_error = True - self._append_error(errors, 'MACs', self._macs, self._allowed_macs, None, kex.server.mac) - break + self._append_error(errors, 'Ciphers', self._ciphers, None, kex.server.encryption) - # Checking required MACs + # Checking MACs if self._macs is not None: if self._allow_algorithm_subset_and_reordering: - for mac_t in self._macs: - if mac_t not in kex.server.mac: + for mac_t in kex.server.mac: + if mac_t not in self._macs: ret = False - if not mac_error: - self._append_error(errors, 'MACs', self._macs, None, None, kex.server.mac) - break + self._append_error(errors, 'MACs', self._macs, None, kex.server.mac) + break elif kex.server.mac != self._macs: # Requires perfect match ret = False - self._append_error(errors, 'MACs', self._macs, None, None, kex.server.mac) + self._append_error(errors, 'MACs', self._macs, None, kex.server.mac) if self._dh_modulus_sizes is not None: dh_modulus_types = list(self._dh_modulus_sizes.keys()) @@ -488,32 +431,30 @@ macs = %s actual_dh_modulus_size = kex.dh_modulus_sizes()[dh_modulus_type] if expected_dh_modulus_size != actual_dh_modulus_size: ret = False - self._append_error(errors, 'Group exchange (%s) modulus sizes' % dh_modulus_type, [str(expected_dh_modulus_size)], None, [str(actual_dh_modulus_size)]) + self._append_error(errors, 'Group exchange (%s) modulus sizes' % dh_modulus_type, [str(expected_dh_modulus_size)], [str(actual_dh_modulus_size)]) - return ret, errors, self._get_error_str(errors) + return ret, errors, self._get_error_str(errors, self._allow_algorithm_subset_and_reordering) @staticmethod - def _get_error_str(errors: List[Any]) -> str: + def _get_error_str(errors: List[Any], allow_algorithm_subset_and_reordering: bool = False) -> str: '''Transforms an error struct to a flat string of error messages.''' - + + if allow_algorithm_subset_and_reordering: + expected_str = 'allowed' + else: + expected_str = 'required' + error_list = [] spacer = '' for e in errors: e_str = " * %s did not match.\n" % e['mismatched_field'] - - if ('expected_optional' in e) and (e['expected_optional'] != ['']) \ - and ('expected_allowed' in e) and (e['expected_allowed'] != ['']): - e_str += " - Expected (required): %s\n - Expected (allowed): %s\n - Expected (optional): %s\n" % (Policy._normalize_error_field(e['expected_required']), Policy._normalize_error_field(e['expected_allowed']), Policy._normalize_error_field(e['expected_optional'])) - spacer = ' ' - elif ('expected_allowed' in e) and (e['expected_allowed'] != ['']): - e_str += " - Expected (required): %s\n - Expected (allowed): %s\n" % (Policy._normalize_error_field(e['expected_required']), Policy._normalize_error_field(e['expected_allowed'])) - spacer = ' ' - elif ('expected_optional' in e) and (e['expected_optional'] != ['']): - e_str += " - Expected (required): %s\n - Expected (optional): %s\n" % (Policy._normalize_error_field(e['expected_required']), Policy._normalize_error_field(e['expected_optional'])) + + if ('expected_optional' in e) and (e['expected_optional'] != ['']): + e_str += " - Expected (" + expected_str + "): %s\n - Expected (optional): %s\n" % (Policy._normalize_error_field(e['expected_required']), Policy._normalize_error_field(e['expected_optional'])) spacer = ' ' else: - e_str += " - Expected: %s\n" % Policy._normalize_error_field(e['expected_required']) + e_str += " - Expected (" + expected_str + "): %s\n" % Policy._normalize_error_field(e['expected_required']) spacer = ' ' e_str += " - Actual:%s%s\n" % (spacer, Policy._normalize_error_field(e['actual'])) error_list.append(e_str)