diff --git a/src/ssh_audit/ssh_audit.py b/src/ssh_audit/ssh_audit.py index ede5337..9c320d4 100755 --- a/src/ssh_audit/ssh_audit.py +++ b/src/ssh_audit/ssh_audit.py @@ -841,10 +841,9 @@ def process_commandline(out: OutputBuffer, args: List[str], usage_cb: Callable[. aconf.ssh1, aconf.ssh2 = False, False host: str = '' - oport: Optional[str] = None - parser = argparse.ArgumentParser(description='SSH Audit Tool', add_help=False) - + parser = argparse.ArgumentParser(prog='SSH Audit Tool', description='SSH Audit Tool', add_help=False, allow_abbrev=False) + # Add short options to the parser parser.add_argument('-1', '--ssh1', action="store_true", dest='ssh1', default=None) parser.add_argument('-2', '--ssh2', action="store_true", dest='ssh2', default=None) @@ -862,6 +861,7 @@ def process_commandline(out: OutputBuffer, args: List[str], usage_cb: Callable[. parser.add_argument('-L', '--list-policies', action="store_true", dest='list_policies', default=None) parser.add_argument('-M', '--make-policy', action="store", dest='make_policy', default=None) parser.add_argument('-m', '--manual', action="store_true", dest='manual', default=None) + parser.add_argument('-n', '--no-colors', action="store_true", dest='no_colors', default=None) parser.add_argument('-P', '--policy', action="store", dest='policy', default=None) parser.add_argument('-p', '--port', action="store", dest='port', default='22', type=int) parser.add_argument('-T', '--targets', action="store", dest='targets', default=None) @@ -873,88 +873,81 @@ def process_commandline(out: OutputBuffer, args: List[str], usage_cb: Callable[. parser.add_argument('--conn-rate-test', action="store", dest='conn_rate_test', default='0', type=int) parser.add_argument('--dheat', action="store", dest='dheat', default='0', type=int) parser.add_argument('--lookup', action="store", dest='lookup', default=None) - parser.add_argument('--no-colors', action="store_true", dest='no_colors', default=None) parser.add_argument('--skip-rate-test', action="store_true", dest='skip_rate_test', default=None) - parser.add_argument('--threads', action="store", dest='threads', default='1', type=int) + parser.add_argument('--threads', action="store", dest='threads', default='32', type=int) try: - namespace = parser.parse_args() + argument = parser.parse_args() - if namespace.help is True: + if argument.help is True: usage_cb(out) - aconf.host = namespace.host - host = namespace.host - port = namespace.port - aconf.ssh1 = namespace.ssh1 - aconf.ssh2 = namespace.ssh2 - aconf.ipv4 = namespace.ipv4 - aconf.ipv6 = namespace.ipv6 + aconf.host = argument.host + host = argument.host + port = argument.port + aconf.ssh1 = argument.ssh1 + aconf.ssh2 = argument.ssh2 + aconf.ipv4 = argument.ipv4 + aconf.ipv6 = argument.ipv6 - aconf.json = namespace.json - if namespace.json_indent is True: - setattr(namespace, 'json', True) - aconf.json = namespace.json - aconf.json_print_indent = namespace.json_indent + aconf.json = argument.json + if argument.json_indent is True: + setattr(argument, 'json', True) + aconf.json = argument.json + aconf.json_print_indent = argument.json_indent - if namespace.batch is True: + if argument.batch is True: aconf.batch = True aconf.verbose = True - aconf.client_audit = namespace.client_audit + aconf.client_audit = argument.client_audit - ttime = namespace.timeout + ttime = argument.timeout if ttime != 5: - aconf.timeout = float(namespace.timeout) + aconf.timeout = float(argument.timeout) aconf.timeout_set = True - if namespace.verbose is True: + if argument.verbose is True: aconf.verbose = True out.verbose = True # Get error level regex - err_level = (namespace.level) - pattern = re.compile(r'info|warn|fail') - if pattern.match(err_level): - aconf.level = str(namespace.level) + err_level = argument.level + if err_level in ["info", "warn", "fail"]: + aconf.level = str(argument.level) else: usage_cb(out, 'Error level : {} is not valid'.format(err_level)) - if getattr(namespace, 'make_policy') is True: + if getattr(argument, 'make_policy') is True: aconf.make_policy = True - aconf.policy_file = namespace.make_policy + aconf.policy_file = argument.make_policy - if getattr(namespace, 'policy') is True: - aconf.policy_file = namespace.policy + if getattr(argument, 'policy') is True: + aconf.policy_file = argument.policy - if getattr(namespace, 'targets') is True: - aconf.target_file = namespace.targets + if getattr(argument, 'targets') is True: + aconf.target_file = argument.targets - if os.name == 'nt': - aconf.threads = '1' - - if os.name == 'posix': - aconf.threads = int(namespace.threads) - else: - aconf.threads = '1' - - if getattr(namespace, 'list_policies') is True: + if argument.threads != 32: + aconf.threads = argument.threads + + if getattr(argument, 'list_policies') is True: aconf.list_policies = True - - if getattr(namespace, 'lookup') is True: - aconf.lookup = namespace.lookup - if getattr(namespace, 'manual') is True: + if getattr(argument, 'lookup') is True: + aconf.lookup = argument.lookup + + if getattr(argument, 'manual') is True: aconf.manual = True else: aconf.manual = False - - if namespace.debug == True: + + if argument.debug is True: aconf.debug = True out.debug = True - if getattr(namespace, 'gex_test') == True: - dh_gex = namespace.gex_test + if getattr(argument, 'gex_test') is True: + dh_gex = argument.gex_test permitted_syntax = get_permitted_syntax_for_gex_test() if not any(re.search(regex_str, dh_gex) for regex_str in permitted_syntax.values()): @@ -975,21 +968,21 @@ def process_commandline(out: OutputBuffer, args: List[str], usage_cb: Callable[. if all(x < 0 for x in (bits_left_bound, bits_right_bound)): usage_cb(out, '{} {} {} is not valid'.format(dh_gex, bits_left_bound, bits_right_bound)) - aconf.gex_test = namespace.gex_test + aconf.gex_test = argument.gex_test - if int(namespace.dheat) > 0: - aconf.dheat = int(namespace.dheat) + if int(argument.dheat) > 0: + aconf.dheat = argument.dheat - aconf.skip_rate_test = namespace.skip_rate_test + aconf.skip_rate_test = argument.skip_rate_test + + if int(argument.conn_rate_test) > 0: + aconf.conn_rate_test = argument.conn_rate_test - if int(namespace.conn_rate_test) > 0: - aconf.conn_rate_test = int(namespace.conn_rate_test) - except argparse.ArgumentError as err: usage_cb(out, str(err)) - if namespace.host is None and namespace.client_audit is None and namespace.targets is None and namespace.list_policies is None and namespace.lookup is None and namespace.manual is None: + if argument.host is None and argument.client_audit is None and argument.targets is None and argument.list_policies is None and argument.lookup is None and argument.manual is None: usage_cb(out) if aconf.manual: @@ -1003,25 +996,19 @@ def process_commandline(out: OutputBuffer, args: List[str], usage_cb: Callable[. sys.exit(exitcodes.GOOD) if aconf.client_audit is None and aconf.target_file is None: - if oport is not None: - host = host - else: - host = host - port = port - #if not host and aconf.target_file is None: - if host is None and aconf.target_file is None: + host = argument.host + port = argument.port + + if argument.host is None and aconf.target_file is None: usage_cb(out, 'host is empty') - if port == 0 and oport is None: - if aconf.client_audit: # The default port to listen on during a client audit is 2222. - port = 2222 - else: - port = port + if aconf.client_audit is True: # The default port to listen on during a client audit is 2222. + port = 2222 - if oport is not None: - port = Utils.parse_int(oport) + if argument.port != 22: + port = Utils.parse_int(argument.port) if port <= 0 or port > 65535: - usage_cb(out, 'port {} is not valid'.format(oport)) + usage_cb(out, 'port {} is not valid'.format(argument.port)) aconf.host = host aconf.port = port