diff --git a/ssh-audit.py b/ssh-audit.py index e9243ec..0ae8111 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -133,7 +133,7 @@ class Output(object): class OutputBuffer(list): def __enter__(self): - self.__buf = io.StringIO() + self.__buf = utils.StringIO() self.__stdout = sys.stdout sys.stdout = self.__buf return self @@ -389,7 +389,7 @@ class SSH1(object): class ReadBuf(object): def __init__(self, data=None): super(ReadBuf, self).__init__() - self._buf = io.BytesIO(data) if data else io.BytesIO() + self._buf = utils.BytesIO(data) if data else utils.BytesIO() self._len = len(data) if data else 0 @property @@ -1538,11 +1538,25 @@ def output(banner, header, kex=None, pkm=None): output_fingerprint(kex, pkm, True, maxlen) -def parse_int(v): - try: - return int(v) - except: - return 0 +class Utils(object): + PY2 = sys.version_info[0] == 2 + + @classmethod + def wrap(cls): + o = cls() + if cls.PY2: + import StringIO + o.StringIO = o.BytesIO = StringIO.StringIO + else: + o.StringIO, o.BytesIO = io.StringIO, io.BytesIO + return o + + @staticmethod + def parse_int(v): + try: + return int(v) + except: + return 0 def parse_args(): @@ -1576,7 +1590,7 @@ def parse_args(): s = args[0].split(':') host, port = s[0].strip(), 22 if len(s) > 1: - port = parse_int(s[1]) + port = utils.parse_int(s[1]) if not host or port <= 0: usage('port {0} is not valid'.format(port)) conf.host = host @@ -1625,6 +1639,7 @@ def audit(conf, sshv=None): output(banner, header, kex=kex) +utils = Utils.wrap() if __name__ == '__main__': out = Output() conf = parse_args()