diff --git a/ssh-audit.py b/ssh-audit.py index da4bc7b..f81f2d2 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -172,17 +172,27 @@ class ReadBuf(object): n = self.read_int() return self.read(n) - def read_mpint2(self): - # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt - r, v = 0, self.read_string() - if len(v) == 0: - return r - pad, sf = (b'\xff', '>i') if ord(v[0:1]) & 0x80 else (b'\x00', '>I') + def _parse_mpint(self, v, pad, sf): + r = 0 if len(v) % 4: v = pad * (4 - (len(v) % 4)) + v for i in range(0, len(v), 4): r = (r << 32) | struct.unpack(sf, v[i:i + 4])[0] return r + + def read_mpint1(self): + # NOTE: Data Type Enc @ http://www.snailbook.com/docs/protocol-1.5.txt + bits = struct.unpack('>H', self.read(2))[0] + n = (bits + 7) // 8 + return self._parse_mpint(self.read(n), b'\x00', '>I') + + def read_mpint2(self): + # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt + v = self.read_string() + if len(v) == 0: + return 0 + pad, sf = (b'\xff', '>i') if ord(v[0:1]) & 0x80 else (b'\x00', '>I') + return self._parse_mpint(v, pad, sf) def read_line(self): return self._buf.readline().rstrip().decode('utf-8') @@ -215,13 +225,27 @@ class WriteBuf(object): def write_list(self, v): self.write_string(u','.join(v)) - def write_mpint2(self, v): - # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt + def _create_mpint(self, v): length = v.bit_length() // 8 + (1 if v != 0 else 0) - data = [(v >> i * 8) & 0xff for i in reversed(range(length))] - if length > 1 and data[0] == 0xff and data[1] & 0x80: - data.pop(0) - data = bytes(bytearray(data)) + ql = (length + 7) // 8 + fmt, v2 = '>{0}Q'.format(ql), [b'\x00'] * ql + for i in range(ql): + v2[ql - i - 1] = (v & 0xffffffffffffffff) + v >>= 64 + data = bytes(struct.pack(fmt, *v2)[-length:]) + if data.startswith(b'\xff\x80'): + data = data[1:] + return data + + def write_mpint1(self, v): + # NOTE: Data Type Enc @ http://www.snailbook.com/docs/protocol-1.5.txt + data = self._create_mpint(v) + self.write(struct.pack('>H', v.bit_length())) + return self.write(data) + + def write_mpint2(self, v): + # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt + data = self._create_mpint(v) return self.write_string(data) def write_flush(self): diff --git a/test/test_protocol.py b/test/test_protocol.py index bcff2cc..8519930 100644 --- a/test/test_protocol.py +++ b/test/test_protocol.py @@ -15,23 +15,27 @@ class TestProtocol(object): data = [int(v[i * 2:i * 2 + 2], 16) for i in range(len(v) // 2)] return bytes(bytearray(data)) - def test_mpint2_write(self): - wbuf, _b = self.wbuf(), self._b - mpint = lambda x: wbuf.write_mpint2(x).write_flush() - assert mpint(0x0) == _b('00 00 00 00') - assert mpint(0x80) == _b('00 00 00 02 00 80') - assert mpint(0x9a378f9b2e332a7) == _b('00 00 00 08 09 a3 78 f9 b2 e3 32 a7') - assert mpint(-0x1234) == _b('00 00 00 02 ed cc') - assert mpint(-0xdeadbeef) == _b('00 00 00 05 ff 21 52 41 11') - assert mpint(-0x80) == _b('00 00 00 01 80') + def test_mpint1(self): + mpint1w = lambda x: self.wbuf().write_mpint1(x).write_flush() + mpint1r = lambda x: self.rbuf(x).read_mpint1() + tc = [(0x0, '00 00'), + (0x1234, '00 0d 12 34'), + (0x12345, '00 11 01 23 45')] + for p in tc: + assert mpint1w(p[0]) == self._b(p[1]) + assert mpint1r(self._b(p[1])) == p[0] - def test_mpint2_read(self): - rbuf, _b = self.rbuf, self._b - mpint = lambda x: rbuf(x).read_mpint2() - assert mpint(_b('00 00 00 00')) == 0x00 - assert mpint(_b('00 00 00 02 00 80')) == 0x80 - assert mpint(_b('00 00 00 08 09 a3 78 f9 b2 e3 32 a7')) == 0x9a378f9b2e332a7 - assert mpint(_b('00 00 00 02 ed cc')) == -0x1234 - assert mpint(_b('00 00 00 05 ff 21 52 41 11')) == -0xdeadbeef - assert mpint(_b('00 00 00 01 80')) == -0x80 - assert mpint(_b('00 00 00 02 ff 80')) == -0x80 + def test_mpint2(self): + mpint2w = lambda x: self.wbuf().write_mpint2(x).write_flush() + mpint2r = lambda x: self.rbuf(x).read_mpint2() + tc = [(0x0, '00 00 00 00'), + (0x80, '00 00 00 02 00 80'), + (0x9a378f9b2e332a7, '00 00 00 08 09 a3 78 f9 b2 e3 32 a7'), + (-0x1234, '00 00 00 02 ed cc'), + (-0xdeadbeef, '00 00 00 05 ff 21 52 41 11'), + (-0x8000, '00 00 00 02 80 00'), + (-0x80, '00 00 00 01 80')] + for p in tc: + assert mpint2w(p[0]) == self._b(p[1]) + assert mpint2r(self._b(p[1])) == p[0] + assert mpint2r(self._b('00 00 00 02 ff 80')) == -0x80