Added support for mixed host key/CA key types (i.e.: RSA host keys signed by ED25519 CAs) (#120).

This commit is contained in:
Joe Testa
2023-04-25 09:17:32 -04:00
parent 4f31304b66
commit 263267c5ad
34 changed files with 556 additions and 308 deletions

View File

@ -21,6 +21,8 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import copy
import json
import sys
from typing import Dict, List, Tuple
@ -28,9 +30,9 @@ from typing import Optional, Any, Union, cast
from datetime import date
from ssh_audit import exitcodes
from ssh_audit.ssh2_kex import SSH2_Kex # pylint: disable=unused-import
from ssh_audit.banner import Banner # pylint: disable=unused-import
from ssh_audit.banner import Banner
from ssh_audit.globals import SNAP_PACKAGE, SNAP_PERMISSIONS_ERROR
from ssh_audit.ssh2_kex import SSH2_Kex
# Validates policy files and performs policy testing
@ -87,8 +89,9 @@ class Policy:
}
WARNING_DEPRECATED_DIRECTIVES = "\nWARNING: this policy is using deprecated features. Future versions of ssh-audit may remove support for them. Re-generating the policy file is perhaps the most straight-forward way of resolving this issue. Manually converting the 'hostkey_size_*', 'cakey_size_*', and 'dh_modulus_size_*' directives into the new format is another option.\n"
def __init__(self, policy_file: Optional[str] = None, policy_data: Optional[str] = None, manual_load: bool = False) -> None:
def __init__(self, policy_file: Optional[str] = None, policy_data: Optional[str] = None, manual_load: bool = False, json_output: bool = False) -> None:
self._name: Optional[str] = None
self._version: Optional[str] = None
self._banner: Optional[str] = None
@ -98,13 +101,19 @@ class Policy:
self._kex: Optional[List[str]] = None
self._ciphers: Optional[List[str]] = None
self._macs: Optional[List[str]] = None
self._hostkey_sizes: Optional[Dict[str, int]] = None
self._hostkey_sizes: Optional[Dict[str, Dict[str, Union[int, str, bytes]]]] = None
self._cakey_sizes: Optional[Dict[str, int]] = None
self._dh_modulus_sizes: Optional[Dict[str, int]] = None
self._server_policy = True
self._name_and_version: str = ''
# If invoked while JSON output is expected, send warnings to stderr instead of stdout (which would corrupt the JSON output).
if json_output:
self._warning_target = sys.stderr
else:
self._warning_target = sys.stdout
# Ensure that only one mode was specified.
num_modes = 0
if policy_file is not None:
@ -154,7 +163,7 @@ class Policy:
key = key.strip()
val = val.strip()
if key not in ['name', 'version', 'banner', 'compressions', 'host keys', 'optional host keys', 'key exchanges', 'ciphers', 'macs', 'client policy'] and not key.startswith('hostkey_size_') and not key.startswith('cakey_size_') and not key.startswith('dh_modulus_size_'):
if key not in ['name', 'version', 'banner', 'compressions', 'host keys', 'optional host keys', 'key exchanges', 'ciphers', 'macs', 'client policy', 'host_key_sizes', 'dh_modulus_sizes'] and not key.startswith('hostkey_size_') and not key.startswith('cakey_size_') and not key.startswith('dh_modulus_size_'):
raise ValueError("invalid field found in policy: %s" % line)
if key in ['name', 'banner']:
@ -173,8 +182,10 @@ class Policy:
self._name = val
elif key == 'banner':
self._banner = val
elif key == 'version':
self._version = val
elif key in ['compressions', 'host keys', 'optional host keys', 'key exchanges', 'ciphers', 'macs']:
try:
algs = val.split(',')
@ -197,21 +208,59 @@ class Policy:
self._ciphers = algs
elif key == 'macs':
self._macs = algs
elif key.startswith('hostkey_size_'):
elif key.startswith('hostkey_size_'): # Old host key size format.
print(Policy.WARNING_DEPRECATED_DIRECTIVES, file=self._warning_target) # Warn the user that the policy file is using deprecated directives.
hostkey_type = key[13:]
hostkey_size = int(val)
if self._hostkey_sizes is None:
self._hostkey_sizes = {}
self._hostkey_sizes[hostkey_type] = int(val)
elif key.startswith('cakey_size_'):
cakey_type = key[11:]
if self._cakey_sizes is None:
self._cakey_sizes = {}
self._cakey_sizes[cakey_type] = int(val)
elif key.startswith('dh_modulus_size_'):
dh_modulus_type = key[16:]
self._hostkey_sizes[hostkey_type] = {'hostkey_size': hostkey_size, 'ca_key_type': '', 'ca_key_size': 0}
elif key.startswith('cakey_size_'): # Old host key size format.
print(Policy.WARNING_DEPRECATED_DIRECTIVES, file=self._warning_target) # Warn the user that the policy file is using deprecated directives.
hostkey_type = key[11:]
ca_key_size = int(val)
ca_key_type = 'ssh-ed25519'
if hostkey_type in ['ssh-rsa-cert-v01@openssh.com', 'rsa-sha2-256-cert-v01@openssh.com', 'rsa-sha2-512-cert-v01@openssh.com']:
ca_key_type = 'ssh-rsa'
if self._hostkey_sizes is None:
self._hostkey_sizes = {}
self._hostkey_sizes[hostkey_type] = {'hostkey_size': hostkey_size, 'ca_key_type': ca_key_type, 'ca_key_size': ca_key_size}
elif key == 'host_key_sizes': # New host key size format.
self._hostkey_sizes = json.loads(val)
# Fill in the trimmed fields that were omitted from the policy.
if self._hostkey_sizes is not None:
for host_key_type in self._hostkey_sizes:
if 'ca_key_type' not in self._hostkey_sizes[host_key_type]:
self._hostkey_sizes[host_key_type]['ca_key_type'] = ''
if 'ca_key_size' not in self._hostkey_sizes[host_key_type]:
self._hostkey_sizes[host_key_type]['ca_key_size'] = 0
if 'raw_hostkey_bytes' not in self._hostkey_sizes[host_key_type]:
self._hostkey_sizes[host_key_type]['raw_hostkey_bytes'] = b''
elif key.startswith('dh_modulus_size_'): # Old DH modulus format.
print(Policy.WARNING_DEPRECATED_DIRECTIVES, file=self._warning_target) # Warn the user that the policy file is using deprecated directives.
dh_type = key[16:]
dh_size = int(val)
if self._dh_modulus_sizes is None:
self._dh_modulus_sizes = {}
self._dh_modulus_sizes[dh_modulus_type] = int(val)
self._dh_modulus_sizes[dh_type] = dh_size
elif key == 'dh_modulus_sizes': # New DH modulus format.
self._dh_modulus_sizes = json.loads(val)
elif key.startswith('client policy') and val.lower() == 'true':
self._server_policy = False
@ -243,10 +292,9 @@ class Policy:
kex_algs = None
ciphers = None
macs = None
rsa_hostkey_sizes_str = ''
rsa_cakey_sizes_str = ''
dh_modulus_sizes_str = ''
client_policy_str = ''
host_keys_json = ''
if client_audit:
client_policy_str = "\n# Set to true to signify this is a policy for clients, not servers.\nclient policy = true\n"
@ -262,26 +310,23 @@ class Policy:
ciphers = ', '.join(kex.server.encryption)
if kex.server.mac is not None:
macs = ', '.join(kex.server.mac)
if kex.rsa_key_sizes():
rsa_key_sizes_dict = kex.rsa_key_sizes()
for host_key_type in sorted(rsa_key_sizes_dict):
hostkey_size, cakey_size = rsa_key_sizes_dict[host_key_type]
rsa_hostkey_sizes_str = "%shostkey_size_%s = %d\n" % (rsa_hostkey_sizes_str, host_key_type, hostkey_size)
if cakey_size != -1:
rsa_cakey_sizes_str = "%scakey_size_%s = %d\n" % (rsa_cakey_sizes_str, host_key_type, cakey_size)
if kex.host_keys():
# Make a deep copy of the host keys dict, then delete all the raw hostkey bytes from the copy.
host_keys_trimmed = copy.deepcopy(kex.host_keys())
for hostkey_alg in host_keys_trimmed:
del host_keys_trimmed[hostkey_alg]['raw_hostkey_bytes']
# Delete the CA signature if any of its fields are empty.
if host_keys_trimmed[hostkey_alg]['ca_key_type'] == '' or host_keys_trimmed[hostkey_alg]['ca_key_size'] == 0:
del host_keys_trimmed[hostkey_alg]['ca_key_type']
del host_keys_trimmed[hostkey_alg]['ca_key_size']
host_keys_json = "\n# Dictionary containing all host key and size information. Optionally contains the certificate authority's signature algorithm ('ca_key_type') and signature length ('ca_key_size'), if any.\nhost_key_sizes = %s\n" % json.dumps(host_keys_trimmed)
if len(rsa_hostkey_sizes_str) > 0:
rsa_hostkey_sizes_str = "\n# RSA host key sizes.\n%s" % rsa_hostkey_sizes_str
if len(rsa_cakey_sizes_str) > 0:
rsa_cakey_sizes_str = "\n# RSA CA key sizes.\n%s" % rsa_cakey_sizes_str
if kex.dh_modulus_sizes():
dh_modulus_sizes_dict = kex.dh_modulus_sizes()
for gex_type in sorted(dh_modulus_sizes_dict):
modulus_size, _ = dh_modulus_sizes_dict[gex_type]
dh_modulus_sizes_str = "%sdh_modulus_size_%s = %d\n" % (dh_modulus_sizes_str, gex_type, modulus_size)
if len(dh_modulus_sizes_str) > 0:
dh_modulus_sizes_str = "\n# Group exchange DH modulus sizes.\n%s" % dh_modulus_sizes_str
dh_modulus_sizes_str = "\n# Group exchange DH modulus sizes.\ndh_modulus_sizes = %s\n" % json.dumps(kex.dh_modulus_sizes())
policy_data = '''#
@ -299,7 +344,7 @@ version = 1
# The compression options that must match exactly (order matters). Commented out to ignore by default.
# compressions = %s
%s%s%s
%s%s
# The host key types that must match exactly (order matters).
host keys = %s
@ -314,7 +359,7 @@ ciphers = %s
# The MACs that must match exactly (order matters).
macs = %s
''' % (source, today, client_policy_str, source, today, banner, compressions, rsa_hostkey_sizes_str, rsa_cakey_sizes_str, dh_modulus_sizes_str, host_keys, kex_algs, ciphers, macs)
''' % (source, today, client_policy_str, source, today, banner, compressions, host_keys_json, dh_modulus_sizes_str, host_keys, kex_algs, ciphers, macs)
return policy_data
@ -351,23 +396,29 @@ macs = %s
hostkey_types = list(self._hostkey_sizes.keys())
hostkey_types.sort() # Sorted to make testing output repeatable.
for hostkey_type in hostkey_types:
expected_hostkey_size = self._hostkey_sizes[hostkey_type]
if hostkey_type in kex.rsa_key_sizes():
actual_hostkey_size, actual_cakey_size = kex.rsa_key_sizes()[hostkey_type]
expected_hostkey_size = self._hostkey_sizes[hostkey_type]['hostkey_size']
server_host_keys = kex.host_keys()
if hostkey_type in server_host_keys:
actual_hostkey_size = server_host_keys[hostkey_type]['hostkey_size']
if actual_hostkey_size != expected_hostkey_size:
ret = False
self._append_error(errors, 'RSA host key (%s) sizes' % hostkey_type, [str(expected_hostkey_size)], None, [str(actual_hostkey_size)])
self._append_error(errors, 'Host key (%s) sizes' % hostkey_type, [str(expected_hostkey_size)], None, [str(actual_hostkey_size)])
if self._cakey_sizes is not None:
hostkey_types = list(self._cakey_sizes.keys())
hostkey_types.sort() # Sorted to make testing output repeatable.
for hostkey_type in hostkey_types:
expected_cakey_size = self._cakey_sizes[hostkey_type]
if hostkey_type in kex.rsa_key_sizes():
actual_hostkey_size, actual_cakey_size = kex.rsa_key_sizes()[hostkey_type]
if actual_cakey_size != expected_cakey_size:
ret = False
self._append_error(errors, 'RSA CA key (%s) sizes' % hostkey_type, [str(expected_cakey_size)], None, [str(actual_cakey_size)])
# If we have expected CA signatures set, check them against what the server returned.
if self._hostkey_sizes is not None and len(cast(str, self._hostkey_sizes[hostkey_type]['ca_key_type'])) > 0 and cast(int, self._hostkey_sizes[hostkey_type]['ca_key_size']) > 0:
expected_ca_key_type = cast(str, self._hostkey_sizes[hostkey_type]['ca_key_type'])
expected_ca_key_size = cast(int, self._hostkey_sizes[hostkey_type]['ca_key_size'])
actual_ca_key_type = cast(str, server_host_keys[hostkey_type]['ca_key_type'])
actual_ca_key_size = cast(int, server_host_keys[hostkey_type]['ca_key_size'])
# Ensure that the CA signature type is what's expected (i.e.: the server doesn't have an RSA sig when we're expecting an ED25519 sig).
if actual_ca_key_type != expected_ca_key_type:
ret = False
self._append_error(errors, 'CA signature type', [expected_ca_key_type], None, [actual_ca_key_type])
# Ensure that the actual and expected signature sizes match.
elif actual_ca_key_size != expected_ca_key_size:
ret = False
self._append_error(errors, 'CA signature size (%s)' % actual_ca_key_type, [str(expected_ca_key_size)], None, [str(actual_ca_key_size)])
if kex.kex_algorithms != self._kex:
ret = False
@ -387,7 +438,7 @@ macs = %s
for dh_modulus_type in dh_modulus_types:
expected_dh_modulus_size = self._dh_modulus_sizes[dh_modulus_type]
if dh_modulus_type in kex.dh_modulus_sizes():
actual_dh_modulus_size, _ = kex.dh_modulus_sizes()[dh_modulus_type]
actual_dh_modulus_size = kex.dh_modulus_sizes()[dh_modulus_type]
if expected_dh_modulus_size != actual_dh_modulus_size:
ret = False
self._append_error(errors, 'Group exchange (%s) modulus sizes' % dh_modulus_type, [str(expected_dh_modulus_size)], None, [str(actual_dh_modulus_size)])
@ -449,12 +500,12 @@ macs = %s
@staticmethod
def load_builtin_policy(policy_name: str) -> Optional['Policy']:
def load_builtin_policy(policy_name: str, json_output: bool = False) -> Optional['Policy']:
'''Returns a Policy with the specified built-in policy name loaded, or None if no policy of that name exists.'''
p = None
if policy_name in Policy.BUILTIN_POLICIES:
policy_struct = Policy.BUILTIN_POLICIES[policy_name]
p = Policy(manual_load=True)
p = Policy(manual_load=True, json_output=json_output)
policy_name_without_version = policy_name[0:policy_name.rfind(' (')]
p._name = policy_name_without_version # pylint: disable=protected-access
p._version = cast(str, policy_struct['version']) # pylint: disable=protected-access
@ -465,7 +516,7 @@ macs = %s
p._kex = cast(Optional[List[str]], policy_struct['kex']) # pylint: disable=protected-access
p._ciphers = cast(Optional[List[str]], policy_struct['ciphers']) # pylint: disable=protected-access
p._macs = cast(Optional[List[str]], policy_struct['macs']) # pylint: disable=protected-access
p._hostkey_sizes = cast(Optional[Dict[str, int]], policy_struct['hostkey_sizes']) # pylint: disable=protected-access
p._hostkey_sizes = cast(Optional[Dict[str, Dict[str, Union[int, str, bytes]]]], policy_struct['hostkey_sizes']) # pylint: disable=protected-access
p._cakey_sizes = cast(Optional[Dict[str, int]], policy_struct['cakey_sizes']) # pylint: disable=protected-access
p._dh_modulus_sizes = cast(Optional[Dict[str, int]], policy_struct['dh_modulus_sizes']) # pylint: disable=protected-access
p._server_policy = cast(bool, policy_struct['server_policy']) # pylint: disable=protected-access