1import re
2import pytest
3
4from ssh_audit.readbuf import ReadBuf
5from ssh_audit.writebuf import WriteBuf
6
7
8# pylint: disable=attribute-defined-outside-init,bad-whitespace
class TestBuffer:
    """Encode/decode tests for the ReadBuf and WriteBuf wire-format helpers."""

    @pytest.fixture(autouse=True)
    def init(self, ssh_audit):
        # The ssh_audit fixture is requested but not used directly here;
        # presumably it is needed for its setup side effects (it is defined
        # outside this file) -- TODO confirm.
        self.rbuf = ReadBuf
        self.wbuf = WriteBuf
        # UTF-8 encoding of U+FFFD REPLACEMENT CHARACTER.
        self.utf8rchar = b'\xef\xbf\xbd'

    @classmethod
    def _b(cls, v):
        """Convert a hex string such as '00 0d 12 34' to bytes.

        All whitespace is ignored. Raises ValueError if the remaining text
        contains a non-hex character or an odd number of hex digits, so a
        mistyped test vector fails loudly instead of being silently
        truncated.
        """
        return bytes.fromhex(re.sub(r'\s', '', v))

    def _roundtrip(self, method, cases):
        """Check WriteBuf.write_<method> and ReadBuf.read_<method> per case.

        Each case is a (value, hexstr) pair: writing ``value`` into a fresh
        WriteBuf must produce exactly the bytes of ``hexstr``, and reading
        those bytes back must yield ``value``.
        """
        write_attr = 'write_' + method
        read_attr = 'read_' + method
        for value, hexstr in cases:
            data = self._b(hexstr)
            writer = getattr(self.wbuf(), write_attr)
            assert writer(value).write_flush() == data
            assert getattr(self.rbuf(data), read_attr)() == value

    def test_unread(self):
        w = self.wbuf().write_byte(1).write_int(2).write_flush()
        r = self.rbuf(w)
        # One byte plus a 4-byte int.
        assert r.unread_len == 5
        r.read_byte()
        assert r.unread_len == 4
        r.read_int()
        assert r.unread_len == 0

    def test_byte(self):
        self._roundtrip('byte', [
            (0x00, '00'),
            (0x01, '01'),
            (0x10, '10'),
            (0xff, 'ff'),
        ])

    def test_bool(self):
        self._roundtrip('bool', [
            (True, '01'),
            (False, '00'),
        ])

    def test_int(self):
        # Ints are written as 4-byte big-endian values.
        self._roundtrip('int', [
            (0x00, '00 00 00 00'),
            (0x01, '00 00 00 01'),
            (0xabcd, '00 00 ab cd'),
            (0xffffffff, 'ff ff ff ff'),
        ])

    def test_string(self):
        # A string is a 4-byte length prefix followed by its bytes. The
        # writer accepts str or bytes; the reader always returns bytes, so
        # str inputs are compared in their UTF-8 encoded form.
        cases = [('abc1', '00 00 00 04 61 62 63 31'),
                 (b'abc2', '00 00 00 04 61 62 63 32')]
        for value, hexstr in cases:
            data = self._b(hexstr)
            assert self.wbuf().write_string(value).write_flush() == data
            if isinstance(value, bytes):
                expected = value
            else:
                expected = value.encode('utf-8')
            assert self.rbuf(data).read_string() == expected

    def test_list(self):
        # A list is written as one length-prefixed, comma-joined string.
        self._roundtrip('list', [
            (['d', 'ef', 'ault'], '00 00 00 09 64 2c 65 66 2c 61 75 6c 74'),
        ])

    def test_list_nonutf8(self):
        # Invalid UTF-8 is decoded with U+FFFD replacement characters
        # rather than raising.
        src = self._b('00 00 00 04 de ad be ef')
        dst = [(b'\xde\xad' + self.utf8rchar + self.utf8rchar).decode('utf-8')]
        assert self.rbuf(src).read_list() == dst

    def test_line(self):
        # Lines are written with a trailing CRLF, which read_line strips.
        self._roundtrip('line', [
            ('example line', '65 78 61 6d 70 6c 65 20 6c 69 6e 65 0d 0a'),
        ])

    def test_line_nonutf8(self):
        # Invalid UTF-8 is decoded with U+FFFD replacement characters
        # rather than raising.
        src = self._b('de ad be af')
        dst = (b'\xde\xad' + self.utf8rchar + self.utf8rchar).decode('utf-8')
        assert self.rbuf(src).read_line() == dst

    def test_bitlen(self):
        # pylint: disable=protected-access
        class Py26Int(int):
            # Simulates an int whose bit_length() is unavailable (as on
            # Python 2.6), to exercise _bitlength's handling of that case.
            def bit_length(self):
                raise AttributeError
        assert self.wbuf._bitlength(42) == 6
        assert self.wbuf._bitlength(Py26Int(42)) == 6

    def test_mpint1(self):
        # Format: a 2-byte big-endian bit count followed by the magnitude
        # bytes (e.g. 0x1234 has 13 significant bits -> '00 0d').
        self._roundtrip('mpint1', [
            (0x0, '00 00'),
            (0x1234, '00 0d 12 34'),
            (0x12345, '00 11 01 23 45'),
            (0xdeadbeef, '00 20 de ad be ef'),
        ])

    def test_mpint2(self):
        # Format: a 4-byte length prefix followed by big-endian two's-
        # complement bytes; a positive value with its top bit set gets a
        # leading 00 byte (0x80 -> '00 80').
        self._roundtrip('mpint2', [
            (0x0, '00 00 00 00'),
            (0x80, '00 00 00 02 00 80'),
            (0x9a378f9b2e332a7, '00 00 00 08 09 a3 78 f9 b2 e3 32 a7'),
            (-0x1234, '00 00 00 02 ed cc'),
            (-0xdeadbeef, '00 00 00 05 ff 21 52 41 11'),
            (-0x8000, '00 00 00 02 80 00'),
            (-0x80, '00 00 00 01 80'),
        ])
        # A non-minimal encoding with a redundant ff sign byte still decodes.
        assert self.rbuf(self._b('00 00 00 02 ff 80')).read_mpint2() == -0x80

    def test_reset(self):
        w = self.wbuf()
        w.write_int(7)
        w.write_int(13)
        assert len(w.write_flush()) == 8

        # reset() discards anything written since the last flush.
        w.write_int(7)
        w.write_int(13)
        w.reset()
        assert len(w.write_flush()) == 0
146