Refactor ssh connection within class for future improvements.

This commit is contained in:
Andris Raugulis 2016-04-01 17:56:06 +03:00
parent 8442dfac0e
commit 06992d7da6

View File

@ -27,8 +27,6 @@ from __future__ import print_function
import os, io, sys, socket, struct
SSH_BANNER = 'SSH-2.0-OpenSSH_7.2'
SOCK_CONN_TIMEOUT = 3.0
SOCK_READ_TIMEOUT = 5.0
def usage():
p = os.path.basename(sys.argv[0])
@ -95,6 +93,7 @@ class Kex(object):
class ReadBuf(object):
def __init__(self, data = None):
super(ReadBuf, self).__init__()
self._buf = io.BytesIO(data) if data else io.BytesIO()
self._len = len(data) if data else 0
@ -118,10 +117,24 @@ class ReadBuf(object):
list_size = self.read_int()
return self.read(list_size).decode().split(',')
class SockBuf(ReadBuf):
def __init__(self, s):
super(SockBuf, self).__init__()
self.__sock = s
class SSH(object):
MSG_KEXINIT = 20
MSG_NEWKEYS = 21
MSG_KEXDH_INIT = 30
MSG_KEXDH_REPLY = 32
class Socket(ReadBuf):
def __init__(self, host, port, cto = 3.0, rto = 5.0):
super(SSH.Socket, self).__init__()
try:
self.__sock = socket.create_connection((host, port), cto)
self.__sock.settimeout(rto)
except Exception as e:
out.fail('[fail] {}'.format(e))
sys.exit(1)
def __enter__(self):
return self
def recv(self, size = 2048):
data = self.__sock.recv(size)
@ -131,6 +144,42 @@ class SockBuf(ReadBuf):
self._len += len(data)
self._buf.seek(pos, 0)
def send(self, data):
self.__sock.send(data)
def read_packet(self):
block_size = 8
if self.unread_len < block_size:
self.recv()
header = self.read(block_size)
packet_size = struct.unpack('>I', header[:4])[0]
rest = header[4:]
lrest = len(rest)
padding = ord(rest[0:1])
packet_type = ord(rest[1:2])
if (packet_size - lrest) % block_size != 0:
out.fail('[exception] invalid ssh packet (block size)')
sys.exit(1)
rlen = packet_size - lrest
if self.unread_len < rlen:
self.recv()
buf = self.read(rlen)
packet = rest[2:] + buf[0:packet_size - lrest]
payload = packet[0:packet_size - padding]
return packet_type, payload
def __del__(self):
self.__cleanup()
def __exit__(self, ex_type, ex_value, tb):
self.__cleanup()
def __cleanup(self):
try:
self.__sock.shutdown(socket.SHUT_RDWR)
self.__sock.close()
except:
pass
def get_ssh_ver(versions):
tv = []
@ -309,26 +358,6 @@ def process_kex(kex):
process_algorithms('mac', kex.server.mac, maxlen)
out.sep()
def read_ssh_packet(sbuf):
block_size = 8
if sbuf.unread_len < block_size:
sbuf.recv()
header = sbuf.read(block_size)
packet_size = struct.unpack('>I', header[:4])[0]
rest = header[4:]
lrest = len(rest)
padding = ord(rest[0:1])
packet_type = ord(rest[1:2])
if (packet_size - lrest) % block_size != 0:
out.fail('[exception] invalid ssh packet (block size)')
sys.exit(1)
rlen = packet_size - lrest
if sbuf.unread_len < rlen:
sbuf.recv()
buf = sbuf.read(rlen)
packet = rest[2:] + buf[0:packet_size - lrest]
payload = packet[0:packet_size - padding]
return packet_type, payload
def parse_int(v):
try:
@ -355,30 +384,20 @@ def parse_args():
def main():
host, port = parse_args()
s = None
try:
s = socket.create_connection((host, port), SOCK_CONN_TIMEOUT)
s.settimeout(SOCK_READ_TIMEOUT)
sbuf = SockBuf(s)
s = SSH.Socket(host, port)
s.send(SSH_BANNER.encode() + b'\r\n')
sbuf.recv()
banner = sbuf.read_line()
s.recv()
banner = s.read_line()
out.head('# general')
out.good('[info] banner: ' + banner)
if banner.startswith('SSH-1.99-'):
out.fail('[fail] protocol SSH1 enabled')
packet_type, payload = read_ssh_packet(sbuf)
if packet_type != 20:
packet_type, payload = s.read_packet()
if packet_type != SSH.MSG_KEXINIT:
out.fail('[exception] did not receive MSG_KEXINIT (20), instead received unknown message ({0})'.format(packet_type))
sys.exit(1)
kex = Kex.parse(payload)
process_kex(kex)
except Exception as e:
out.fail('[fail] {}'.format(e))
sys.exit(1)
finally:
if s:
s.close()
if __name__ == '__main__':
out = Output()