diff --git a/src/ssh_audit/policy.py b/src/ssh_audit/policy.py index 8e73360..591feec 100644 --- a/src/ssh_audit/policy.py +++ b/src/ssh_audit/policy.py @@ -54,6 +54,7 @@ class Policy: self._hostkey_sizes: Optional[Dict[str, Dict[str, Union[int, str, bytes]]]] = None self._dh_modulus_sizes: Optional[Dict[str, int]] = None self._server_policy = True + self._allow_algorithm_subset_and_reordering = False self._name_and_version: str = '' @@ -112,7 +113,7 @@ class Policy: key = key.strip() val = val.strip() - if key not in ['name', 'version', 'banner', 'compressions', 'host keys', 'optional host keys', 'key exchanges', 'ciphers', 'macs', 'client policy', 'host_key_sizes', 'dh_modulus_sizes'] and not key.startswith('hostkey_size_') and not key.startswith('cakey_size_') and not key.startswith('dh_modulus_size_'): + if key not in ['name', 'version', 'banner', 'compressions', 'host keys', 'optional host keys', 'key exchanges', 'ciphers', 'macs', 'client policy', 'host_key_sizes', 'dh_modulus_sizes', 'allow_algorithm_subset_and_reordering'] and not key.startswith('hostkey_size_') and not key.startswith('cakey_size_') and not key.startswith('dh_modulus_size_'): raise ValueError("invalid field found in policy: %s" % line) if key in ['name', 'banner']: @@ -150,7 +151,7 @@ class Policy: elif key == 'host keys': self._host_keys = algs elif key == 'optional host keys': - self._optional_host_keys = algs + self._optional_host_keys = algs elif key == 'key exchanges': self._kex = algs elif key == 'ciphers': @@ -205,7 +206,8 @@ class Policy: elif key.startswith('client policy') and val.lower() == 'true': self._server_policy = False - + elif key == 'allow_algorithm_subset_and_reordering' and val.lower() == 'true': + self._allow_algorithm_subset_and_reordering = True if self._name is None: raise ValueError('The policy does not have a name field.') @@ -332,7 +334,7 @@ macs = %s # All subsequent tests require a valid kex, so end here if we don't have one. if kex is None: - return ret, errors, self._get_error_str(errors) + return ret, errors, self._get_error_str(errors, self._allow_algorithm_subset_and_reordering) if (self._compressions is not None) and (kex.server.compression != self._compressions): ret = False @@ -342,11 +344,20 @@ macs = %s pruned_host_keys = kex.key_algorithms if self._optional_host_keys is not None: pruned_host_keys = [x for x in kex.key_algorithms if x not in self._optional_host_keys] + + # Checking Hostkeys + if self._host_keys is not None: + if self._allow_algorithm_subset_and_reordering: + for hostkey_t in kex.key_algorithms: + if hostkey_t not in self._host_keys: + ret = False + self._append_error(errors, 'Host keys', self._host_keys, None, kex.key_algorithms) + break + elif pruned_host_keys != self._host_keys: + ret = False + self._append_error(errors, 'Host keys', self._host_keys, None, kex.key_algorithms) - if (self._host_keys is not None) and (pruned_host_keys != self._host_keys): - ret = False - self._append_error(errors, 'Host keys', self._host_keys, self._optional_host_keys, kex.key_algorithms) - + # Checking Host Key Sizes if self._hostkey_sizes is not None: hostkey_types = list(self._hostkey_sizes.keys()) hostkey_types.sort() # Sorted to make testing output repeatable. @@ -374,18 +385,42 @@ macs = %s elif actual_ca_key_size != expected_ca_key_size: ret = False self._append_error(errors, 'CA signature size (%s)' % actual_ca_key_type, [str(expected_ca_key_size)], None, [str(actual_ca_key_size)]) - - if kex.kex_algorithms != self._kex: - ret = False - self._append_error(errors, 'Key exchanges', self._kex, None, kex.kex_algorithms) - - if (self._ciphers is not None) and (kex.server.encryption != self._ciphers): - ret = False - self._append_error(errors, 'Ciphers', self._ciphers, None, kex.server.encryption) - - if (self._macs is not None) and (kex.server.mac != self._macs): - ret = False - self._append_error(errors, 'MACs', self._macs, None, kex.server.mac) + + # Checking KEX + if self._kex is not None: + if self._allow_algorithm_subset_and_reordering: + for kex_t in kex.kex_algorithms: + if kex_t not in self._kex: + ret = False + self._append_error(errors, 'Key exchanges', self._kex, None, kex.kex_algorithms) + break + elif kex.kex_algorithms != self._kex: # Requires perfect match + ret = False + self._append_error(errors, 'Key exchanges', self._kex, None, kex.kex_algorithms) + + # Checking Ciphers + if self._ciphers is not None: + if self._allow_algorithm_subset_and_reordering: + for cipher_t in kex.server.encryption: + if cipher_t not in self._ciphers: + ret = False + self._append_error(errors, 'Ciphers', self._ciphers, None, kex.server.encryption) + break + elif kex.server.encryption != self._ciphers: # Requires perfect match + ret = False + self._append_error(errors, 'Ciphers', self._ciphers, None, kex.server.encryption) + + # Checking MACs + if self._macs is not None: + if self._allow_algorithm_subset_and_reordering: + for mac_t in kex.server.mac: + if mac_t not in self._macs: + ret = False + self._append_error(errors, 'MACs', self._macs, None, kex.server.mac) + break + elif kex.server.mac != self._macs: # Requires perfect match + ret = False + self._append_error(errors, 'MACs', self._macs, None, kex.server.mac) if self._dh_modulus_sizes is not None: dh_modulus_types = list(self._dh_modulus_sizes.keys()) @@ -398,22 +433,28 @@ macs = %s ret = False self._append_error(errors, 'Group exchange (%s) modulus sizes' % dh_modulus_type, [str(expected_dh_modulus_size)], None, [str(actual_dh_modulus_size)]) - return ret, errors, self._get_error_str(errors) + return ret, errors, self._get_error_str(errors, self._allow_algorithm_subset_and_reordering) @staticmethod - def _get_error_str(errors: List[Any]) -> str: + def _get_error_str(errors: List[Any], allow_algorithm_subset_and_reordering: bool = False) -> str: '''Transforms an error struct to a flat string of error messages.''' + if allow_algorithm_subset_and_reordering: + expected_str = 'allowed' + else: + expected_str = 'required' + error_list = [] spacer = '' for e in errors: e_str = " * %s did not match.\n" % e['mismatched_field'] + if ('expected_optional' in e) and (e['expected_optional'] != ['']): - e_str += " - Expected (required): %s\n - Expected (optional): %s\n" % (Policy._normalize_error_field(e['expected_required']), Policy._normalize_error_field(e['expected_optional'])) + e_str += " - Expected (" + expected_str + "): %s\n - Expected (optional): %s\n" % (Policy._normalize_error_field(e['expected_required']), Policy._normalize_error_field(e['expected_optional'])) spacer = ' ' else: - e_str += " - Expected: %s\n" % Policy._normalize_error_field(e['expected_required']) + e_str += " - Expected (" + expected_str + "): %s\n" % Policy._normalize_error_field(e['expected_required']) spacer = ' ' e_str += " - Actual:%s%s\n" % (spacer, Policy._normalize_error_field(e['actual'])) error_list.append(e_str)