Use safer UTF-8 decoding (with replace) and add related tests.

This commit is contained in:
Andris Raugulis 2016-10-25 13:53:51 +03:00
parent 66bd6c3ef0
commit 4bbb1f4d11
2 changed files with 17 additions and 4 deletions

View File

@ -40,12 +40,12 @@ else:
try: try:
# pylint: disable=unused-import # pylint: disable=unused-import
from typing import List, Tuple, Optional, Callable, Union, Any from typing import List, Tuple, Optional, Callable, Union, Any
except ImportError: except ImportError: # pragma: nocover
pass pass
try: try:
from colorama import init as colorama_init from colorama import init as colorama_init
colorama_init() colorama_init()
except ImportError: except ImportError: # pragma: nocover
pass pass
@ -572,7 +572,7 @@ class ReadBuf(object):
def read_list(self): def read_list(self):
# type: () -> List[text_type] # type: () -> List[text_type]
list_size = self.read_int() list_size = self.read_int()
return self.read(list_size).decode().split(',') return self.read(list_size).decode('utf-8', 'replace').split(',')
def read_string(self): def read_string(self):
# type: () -> binary_type # type: () -> binary_type
@ -607,7 +607,7 @@ class ReadBuf(object):
def read_line(self): def read_line(self):
# type: () -> text_type # type: () -> text_type
return self._buf.readline().rstrip().decode('utf-8') return self._buf.readline().rstrip().decode('utf-8', 'replace')
class WriteBuf(object): class WriteBuf(object):

View File

@ -9,6 +9,7 @@ class TestBuffer(object):
def init(self, ssh_audit): def init(self, ssh_audit):
self.rbuf = ssh_audit.ReadBuf self.rbuf = ssh_audit.ReadBuf
self.wbuf = ssh_audit.WriteBuf self.wbuf = ssh_audit.WriteBuf
self.utf8rchar = b'\xef\xbf\xbd'
def _b(self, v): def _b(self, v):
v = re.sub(r'\s', '', v) v = re.sub(r'\s', '', v)
@ -75,6 +76,12 @@ class TestBuffer(object):
assert w(p[0]) == self._b(p[1]) assert w(p[0]) == self._b(p[1])
assert r(self._b(p[1])) == p[0] assert r(self._b(p[1])) == p[0]
def test_list_nonutf8(self):
    """read_list() must turn undecodable bytes into U+FFFD rather than fail."""
    # read_list() reads an integer length first and then that many bytes,
    # so the four payload bytes are preceded by the length field 00 00 00 04.
    payload = self._b('00 00 00 04 de ad be ef')
    # de ad forms a valid two-byte UTF-8 sequence; be and ef cannot be
    # decoded here and are each expected to become the replacement char.
    expected = b'\xde\xad' + self.utf8rchar + self.utf8rchar
    names = self.rbuf(payload).read_list()
    assert names == [expected.decode('utf-8')]
def test_line(self): def test_line(self):
w = lambda x: self.wbuf().write_line(x).write_flush() w = lambda x: self.wbuf().write_line(x).write_flush()
r = lambda x: self.rbuf(x).read_line() r = lambda x: self.rbuf(x).read_line()
@ -83,6 +90,12 @@ class TestBuffer(object):
assert w(p[0]) == self._b(p[1]) assert w(p[0]) == self._b(p[1])
assert r(self._b(p[1])) == p[0] assert r(self._b(p[1])) == p[0]
def test_line_nonutf8(self):
    """read_line() must turn undecodable bytes into U+FFFD rather than fail."""
    raw = self._b('de ad be af')
    # de ad forms a valid two-byte UTF-8 sequence; be and af are stray
    # continuation bytes and are each expected to become the replacement char.
    expected = b'\xde\xad' + self.utf8rchar + self.utf8rchar
    line = self.rbuf(raw).read_line()
    assert line == expected.decode('utf-8')
def test_bitlen(self): def test_bitlen(self):
class Py26Int(int): class Py26Int(int):
def bit_length(self): def bit_length(self):