Refactor and test SSH.Algorithm.

This commit is contained in:
Andris Raugulis
2017-04-10 13:20:32 +03:00
parent 774d1c1fe4
commit 72a6b9eeaf
2 changed files with 252 additions and 67 deletions

View File

@ -1156,49 +1156,79 @@ class SSH(object): # pylint: disable=too-few-public-methods
return u'SHA256:{0}'.format(r)
class Algorithm(object):
class Timeframe(object):
def __init__(self):
# type: () -> None
self.__storage = {} # type: Dict[str, List[Optional[str]]]
def __contains__(self, product):
# type: (str) -> bool
return product in self.__storage
def __getitem__(self, product):
# type: (str) -> Sequence[Optional[str]]
return tuple(self.__storage.get(product, [None]*4))
def __str__(self):
# type: () -> str
return self.__storage.__str__()
def __repr__(self):
# type: () -> str
return self.__str__()
def get_from(self, product, for_server=True):
# type: (str, bool) -> Optional[str]
return self[product][0 if bool(for_server) else 2]
def get_till(self, product, for_server=True):
# type: (str, bool) -> Optional[str]
return self[product][1 if bool(for_server) else 3]
def _update(self, versions, pos):
# type: (Optional[str], int) -> None
ssh_versions = {} # type: Dict[str, str]
for_srv, for_cli = pos < 2, pos > 1
for v in (versions or '').split(','):
ssh_prod, ssh_ver, is_cli = SSH.Algorithm.get_ssh_version(v)
if (not ssh_ver or
(is_cli and for_srv) or
(not is_cli and for_cli and ssh_prod in ssh_versions)):
continue
ssh_versions[ssh_prod] = ssh_ver
for ssh_product, ssh_version in ssh_versions.items():
if ssh_product not in self.__storage:
self.__storage[ssh_product] = [None]*4
prev = self[ssh_product][pos]
if (prev is None or
(prev < ssh_version and pos % 2 == 0) or
(prev > ssh_version and pos % 2 == 1)):
self.__storage[ssh_product][pos] = ssh_version
def update(self, versions, for_server=None):
# type: (List[Optional[str]], Optional[bool]) -> SSH.Algorithm.Timeframe
for_cli = for_server is None or for_server is False
for_srv = for_server is None or for_server is True
vlen = len(versions)
for i in range(min(3, vlen)):
if for_srv and i < 2:
self._update(versions[i], i)
if for_cli and (i % 2 == 0 or vlen == 2):
self._update(versions[i], 3 - 0**i)
return self
@staticmethod
def get_ssh_version(version_desc):
# type: (str) -> Tuple[str, str]
# type: (str) -> Tuple[str, str, bool]
is_client = version_desc.endswith('C')
if is_client:
version_desc = version_desc[:-1]
if version_desc.startswith('d'):
return SSH.Product.DropbearSSH, version_desc[1:]
return SSH.Product.DropbearSSH, version_desc[1:], is_client
elif version_desc.startswith('l1'):
return SSH.Product.LibSSH, version_desc[2:]
return SSH.Product.LibSSH, version_desc[2:], is_client
else:
return SSH.Product.OpenSSH, version_desc
@classmethod
def get_timeframe(cls, versions, for_server=True, result=None):
# type: (List[Optional[str]], bool, Optional[Dict[str, List[Optional[str]]]]) -> Dict[str, List[Optional[str]]]
result = result or {}
vlen = len(versions)
for i in range(3):
if i > vlen - 1:
if i == 2 and vlen > 1:
cversions = versions[1]
else:
continue
else:
cversions = versions[i]
if cversions is None:
continue
for v in cversions.split(','):
ssh_prefix, ssh_version = cls.get_ssh_version(v)
if not ssh_version:
continue
if ssh_version.endswith('C'):
if for_server:
continue
ssh_version = ssh_version[:-1]
if ssh_prefix not in result:
result[ssh_prefix] = [None, None, None]
prev, push = result[ssh_prefix][i], False
if (prev is None or
(prev < ssh_version and i == 0) or
(prev > ssh_version and i > 0)):
push = True
if push:
result[ssh_prefix][i] = ssh_version
return result
return SSH.Product.OpenSSH, version_desc, is_client
@classmethod
def get_since_text(cls, versions):
@ -1207,14 +1237,14 @@ class SSH(object): # pylint: disable=too-few-public-methods
if len(versions) == 0 or versions[0] is None:
return None
for v in versions[0].split(','):
ssh_prefix, ssh_version = cls.get_ssh_version(v)
if not ssh_version:
ssh_prod, ssh_ver, is_cli = cls.get_ssh_version(v)
if not ssh_ver:
continue
if ssh_prefix in [SSH.Product.LibSSH]:
if ssh_prod in [SSH.Product.LibSSH]:
continue
if ssh_version.endswith('C'):
ssh_version = '{0} (client only)'.format(ssh_version[:-1])
tv.append('{0} {1}'.format(ssh_prefix, ssh_version))
if is_cli:
ssh_ver = '{0} (client only)'.format(ssh_ver)
tv.append('{0} {1}'.format(ssh_prod, ssh_ver))
if len(tv) == 0:
return None
return 'available since ' + ', '.join(tv).rstrip(', ')
@ -1284,9 +1314,9 @@ class SSH(object): # pylint: disable=too-few-public-methods
maxlen)
return maxlen
def get_ssh_timeframe(self, for_server=True):
# type: (bool) -> Dict[str, List[Optional[str]]]
r = {} # type: Dict[str, List[Optional[str]]]
def get_ssh_timeframe(self, for_server=None):
# type: (Optional[bool]) -> SSH.Algorithm.Timeframe
timeframe = SSH.Algorithm.Timeframe()
for alg_pair in self.values:
alg_db = alg_pair.db
for alg_type, alg_list in alg_pair.items():
@ -1296,8 +1326,8 @@ class SSH(object): # pylint: disable=too-few-public-methods
if alg_desc is None:
continue
versions = alg_desc[0]
r = SSH.Algorithm.get_timeframe(versions, for_server, r)
return r
timeframe.update(versions, for_server)
return timeframe
def get_recommendations(self, software, for_server=True):
# type: (Optional[SSH.Software], bool) -> Tuple[Optional[SSH.Software], Dict[int, Dict[str, Dict[str, Dict[str, int]]]]]
@ -1313,7 +1343,7 @@ class SSH(object): # pylint: disable=too-few-public-methods
for product in vproducts:
if product not in ssh_timeframe:
continue
version = ssh_timeframe[product][0]
version = ssh_timeframe.get_from(product, for_server)
if version is not None:
software = SSH.Software(None, product, version, None, None)
break
@ -1335,15 +1365,13 @@ class SSH(object): # pylint: disable=too-few-public-methods
continue
matches = False
for v in versions[0].split(','):
ssh_prefix, ssh_version = SSH.Algorithm.get_ssh_version(v)
ssh_prefix, ssh_version, is_cli = SSH.Algorithm.get_ssh_version(v)
if not ssh_version:
continue
if ssh_prefix != software.product:
continue
if ssh_version.endswith('C'):
if for_server:
continue
ssh_version = ssh_version[:-1]
if is_cli and for_server:
continue
if software.compare_version(ssh_version) < 0:
continue
matches = True
@ -1842,25 +1870,25 @@ def output_algorithm(alg_db, alg_type, alg_name, alg_max_len=0):
def output_compatibility(algs, for_server=True):
# type: (SSH.Algorithms, bool) -> None
ssh_timeframe = algs.get_ssh_timeframe(for_server)
vp = 1 if for_server else 2
comp_text = []
for sshd_name in [SSH.Product.OpenSSH, SSH.Product.DropbearSSH]:
if sshd_name not in ssh_timeframe:
for ssh_prod in [SSH.Product.OpenSSH, SSH.Product.DropbearSSH]:
if ssh_prod not in ssh_timeframe:
continue
v = ssh_timeframe[sshd_name]
if v[0] is None:
v_from = ssh_timeframe.get_from(ssh_prod, for_server)
v_till = ssh_timeframe.get_till(ssh_prod, for_server)
if v_from is None:
continue
if v[vp] is None:
comp_text.append('{0} {1}+'.format(sshd_name, v[0]))
elif v[0] == v[vp]:
comp_text.append('{0} {1}'.format(sshd_name, v[0]))
if v_till is None:
comp_text.append('{0} {1}+'.format(ssh_prod, v_from))
elif v_from == v_till:
comp_text.append('{0} {1}'.format(ssh_prod, v_from))
else:
software = SSH.Software(None, sshd_name, v[0], None, None)
if software.compare_version(v[vp]) > 0:
software = SSH.Software(None, ssh_prod, v_from, None, None)
if software.compare_version(v_till) > 0:
tfmt = '{0} {1}+ (some functionality from {2})'
else:
tfmt = '{0} {1}-{2}'
comp_text.append(tfmt.format(sshd_name, v[0], v[vp]))
comp_text.append(tfmt.format(ssh_prod, v_from, v_till))
if len(comp_text) > 0:
out.good('(gen) compatibility: ' + ', '.join(comp_text))