Make HostKeyTest class reusable (#278)

Because the `HostKeyTest` class was mutating its static/global
`HOST_KEY_TYPES` dict, this class could not actually be used more than once
in a single thread!

Rather than mutate this dict after parsing each key type
(`HOST_KEY_TYPES[host_key_type]['parsed'] = True`), the `perform_test`
method should simple add the parsed key types to a local `set()`.
This commit is contained in:
Daniel Lenski 2024-07-05 07:11:18 -07:00 committed by GitHub
parent e42961fa9a
commit d8f8b7c57c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -40,7 +40,7 @@ class HostKeyTest:
# Tracks the RSA host key types. As of this writing, testing one in this family yields valid results for the rest. # Tracks the RSA host key types. As of this writing, testing one in this family yields valid results for the rest.
RSA_FAMILY = ['ssh-rsa', 'rsa-sha2-256', 'rsa-sha2-512'] RSA_FAMILY = ['ssh-rsa', 'rsa-sha2-256', 'rsa-sha2-512']
# Dict holding the host key types we should extract & parse. 'cert' is True to denote that a host key type handles certificates (thus requires additional parsing). 'variable_key_len' is True for host key types that can have variable sizes (True only for RSA types, as the rest are of fixed-size). After the host key type is fully parsed, the key 'parsed' is added with a value of True. # Dict holding the host key types we should extract & parse. 'cert' is True to denote that a host key type handles certificates (thus requires additional parsing). 'variable_key_len' is True for host key types that can have variable sizes (True only for RSA types, as the rest are of fixed-size).
HOST_KEY_TYPES = { HOST_KEY_TYPES = {
'ssh-rsa': {'cert': False, 'variable_key_len': True}, 'ssh-rsa': {'cert': False, 'variable_key_len': True},
'rsa-sha2-256': {'cert': False, 'variable_key_len': True}, 'rsa-sha2-256': {'cert': False, 'variable_key_len': True},
@ -93,6 +93,7 @@ class HostKeyTest:
def perform_test(out: 'OutputBuffer', s: 'SSH_Socket', server_kex: 'SSH2_Kex', kex_str: str, kex_group: 'KexDH', host_key_types: Dict[str, Dict[str, bool]]) -> None: def perform_test(out: 'OutputBuffer', s: 'SSH_Socket', server_kex: 'SSH2_Kex', kex_str: str, kex_group: 'KexDH', host_key_types: Dict[str, Dict[str, bool]]) -> None:
hostkey_modulus_size = 0 hostkey_modulus_size = 0
ca_modulus_size = 0 ca_modulus_size = 0
parsed_host_key_types = set()
# If the connection still exists, close it so we can test # If the connection still exists, close it so we can test
# using a clean slate (otherwise it may exist in a non-testable # using a clean slate (otherwise it may exist in a non-testable
@ -106,7 +107,7 @@ class HostKeyTest:
key_warn_comments = [] key_warn_comments = []
# Skip those already handled (i.e.: those in the RSA family, as testing one tests them all). # Skip those already handled (i.e.: those in the RSA family, as testing one tests them all).
if 'parsed' in host_key_types[host_key_type] and host_key_types[host_key_type]['parsed']: if host_key_type in parsed_host_key_types:
continue continue
# If this host key type is supported by the server, we test it. # If this host key type is supported by the server, we test it.
@ -216,7 +217,7 @@ class HostKeyTest:
# If this host key type is in the RSA family, then mark them all as parsed (since results in one are valid for them all). # If this host key type is in the RSA family, then mark them all as parsed (since results in one are valid for them all).
if host_key_type in HostKeyTest.RSA_FAMILY: if host_key_type in HostKeyTest.RSA_FAMILY:
for rsa_type in HostKeyTest.RSA_FAMILY: for rsa_type in HostKeyTest.RSA_FAMILY:
host_key_types[rsa_type]['parsed'] = True parsed_host_key_types.add(rsa_type)
# If the current key is a member of the RSA family, then populate all RSA family members with the same # If the current key is a member of the RSA family, then populate all RSA family members with the same
# failure and/or warning comments. # failure and/or warning comments.
@ -228,7 +229,7 @@ class HostKeyTest:
db['key'][rsa_type][2].extend(key_warn_comments) db['key'][rsa_type][2].extend(key_warn_comments)
else: else:
host_key_types[host_key_type]['parsed'] = True parsed_host_key_types.add(host_key_type)
db = SSH2_KexDB.get_db() db = SSH2_KexDB.get_db()
while len(db['key'][host_key_type]) < 3: while len(db['key'][host_key_type]) < 3:
db['key'][host_key_type].append([]) db['key'][host_key_type].append([])