from collections import OrderedDict
from datetime import datetime, timezone
from io import StringIO
from unittest.mock import Mock

import pytest

from . import BaseTestCase
from ..crypto.key import PlaintextKey
from ..archive import Archive, CacheChunkBuffer, RobustUnpacker, valid_msgpacked_dict, ITEM_KEYS, Statistics
from ..archive import BackupOSError, backup_io, backup_io_iter
from ..helpers import Manifest
from ..helpers import msgpack
from ..item import Item, ArchiveItem


@pytest.fixture()
def stats():
    stats = Statistics()
    stats.update(20, 10, unique=True)
    return stats


def test_stats_basic(stats):
    assert stats.osize == 20
    assert stats.csize == stats.usize == 10
    stats.update(20, 10, unique=False)
    assert stats.osize == 40
    assert stats.csize == 20
    assert stats.usize == 10


def test_stats_progress(stats, monkeypatch, columns=80):
    monkeypatch.setenv('COLUMNS', str(columns))
    out = StringIO()
    stats.show_progress(stream=out)
    s = '20 B O 10 B C 10 B D 0 N '
    buf = ' ' * (columns - len(s))
    assert out.getvalue() == s + buf + "\r"

    out = StringIO()
    stats.update(10**3, 0, unique=False)
    stats.show_progress(item=Item(path='foo'), final=False, stream=out)
    s = '1.02 kB O 10 B C 10 B D 0 N foo'
    buf = ' ' * (columns - len(s))
    assert out.getvalue() == s + buf + "\r"
    out = StringIO()
    stats.show_progress(item=Item(path='foo'*40), final=False, stream=out)
    s = '1.02 kB O 10 B C 10 B D 0 N foofoofoofoofoofoofoofo...oofoofoofoofoofoofoofoofoo'
    buf = ' ' * (columns - len(s))
    assert out.getvalue() == s + buf + "\r"


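# A hedged companion check, not part of the original suite: it assumes that
# show_progress(final=True) blanks the whole progress line before the
# carriage return, consistent with the "\r"-terminated output asserted above.
def test_stats_progress_final(stats, monkeypatch, columns=80):
    monkeypatch.setenv('COLUMNS', str(columns))
    out = StringIO()
    stats.show_progress(final=True, stream=out)
    assert out.getvalue() == ' ' * columns + "\r"

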
def test_stats_format(stats):
    assert str(stats) == """\
This archive:                   20 B                 10 B                 10 B"""
    s = "{0.osize_fmt}".format(stats)
    assert s == "20 B"
    # kind of redundant, but the object id varies, so we can't match the repr literally
    assert repr(stats) == '<Statistics object at {:#x} (20, 10, 10)>'.format(id(stats))


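# A hedged companion to the format test above: it assumes Statistics also
# exposes csize_fmt and usize_fmt properties analogous to osize_fmt.
def test_stats_format_sizes(stats):
    assert "{0.csize_fmt}".format(stats) == "10 B"
    assert "{0.usize_fmt}".format(stats) == "10 B"

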
class MockCache:

    class MockRepo:
        def async_response(self, wait=True):
            pass

    def __init__(self):
        self.objects = {}
        self.repository = self.MockRepo()

    def add_chunk(self, id, chunk, stats=None, wait=True):
        # mimic Cache.add_chunk: store the chunk, return (id, size, csize)
        self.objects[id] = chunk
        return id, len(chunk), len(chunk)


class ArchiveTimestampTestCase(BaseTestCase):

    def _test_timestamp_parsing(self, isoformat, expected):
        repository = Mock()
        key = PlaintextKey(repository)
        manifest = Manifest(repository, key)
        a = Archive(repository, key, manifest, 'test', create=True)
        a.metadata = ArchiveItem(time=isoformat)
        self.assert_equal(a.ts, expected)

    def test_with_microseconds(self):
        self._test_timestamp_parsing(
            '1970-01-01T00:00:01.000001',
            datetime(1970, 1, 1, 0, 0, 1, 1, timezone.utc))

    def test_without_microseconds(self):
        self._test_timestamp_parsing(
            '1970-01-01T00:00:01',
            datetime(1970, 1, 1, 0, 0, 1, 0, timezone.utc))


class ChunkBufferTestCase(BaseTestCase):

    def test(self):
        data = [Item(path='p1'), Item(path='p2')]
        cache = MockCache()
        key = PlaintextKey(None)
        chunks = CacheChunkBuffer(cache, key, None)
        for d in data:
            chunks.add(d)
            chunks.flush()
        chunks.flush(flush=True)
        self.assert_equal(len(chunks.chunks), 2)
        unpacker = msgpack.Unpacker()
        for id in chunks.chunks:
            unpacker.feed(cache.objects[id])
        self.assert_equal(data, [Item(internal_dict=d) for d in unpacker])

    def test_partial(self):
        big = "0123456789abcdefghijklmnopqrstuvwxyz" * 25000
        data = [Item(path='full', source=big), Item(path='partial', source=big)]
        cache = MockCache()
        key = PlaintextKey(None)
        chunks = CacheChunkBuffer(cache, key, None)
        for d in data:
            chunks.add(d)
        chunks.flush(flush=False)
        # the code is expected to leave the last partial chunk in the buffer
        self.assert_equal(len(chunks.chunks), 3)
        self.assert_true(chunks.buffer.tell() > 0)
        # now really flush
        chunks.flush(flush=True)
        self.assert_equal(len(chunks.chunks), 4)
        self.assert_true(chunks.buffer.tell() == 0)
        unpacker = msgpack.Unpacker()
        for id in chunks.chunks:
            unpacker.feed(cache.objects[id])
        self.assert_equal(data, [Item(internal_dict=d) for d in unpacker])
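
    # A hedged extra check, not in the original suite: it assumes that flushing
    # an untouched CacheChunkBuffer is a no-op and produces no chunks.
    def test_empty_flush(self):
        cache = MockCache()
        key = PlaintextKey(None)
        chunks = CacheChunkBuffer(cache, key, None)
        chunks.flush(flush=True)
        self.assert_equal(len(chunks.chunks), 0)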


class RobustUnpackerTestCase(BaseTestCase):

    def make_chunks(self, items):
        return b''.join(msgpack.packb({'path': item}) for item in items)

    def _validator(self, value):
        return isinstance(value, dict) and value.get(b'path') in (b'foo', b'bar', b'boo', b'baz')

    def process(self, input):
        unpacker = RobustUnpacker(validator=self._validator, item_keys=ITEM_KEYS)
        result = []
        for should_sync, chunks in input:
            if should_sync:
                unpacker.resync()
            for data in chunks:
                unpacker.feed(data)
                for item in unpacker:
                    result.append(item)
        return result

    def test_extra_garbage_no_sync(self):
        chunks = [(False, [self.make_chunks([b'foo', b'bar'])]),
                  (False, [b'garbage'] + [self.make_chunks([b'boo', b'baz'])])]
        result = self.process(chunks)
        self.assert_equal(result, [
            {b'path': b'foo'}, {b'path': b'bar'},
            103, 97, 114, 98, 97, 103, 101,  # the individual byte values of b'garbage'
            {b'path': b'boo'},
            {b'path': b'baz'}])

    def split(self, left, length):
        parts = []
        while left:
            parts.append(left[:length])
            left = left[length:]
        return parts

    def test_correct_stream(self):
        chunks = self.split(self.make_chunks([b'foo', b'bar', b'boo', b'baz']), 2)
        input = [(False, chunks)]
        result = self.process(input)
        self.assert_equal(result, [{b'path': b'foo'}, {b'path': b'bar'}, {b'path': b'boo'}, {b'path': b'baz'}])

    def test_missing_chunk(self):
        chunks = self.split(self.make_chunks([b'foo', b'bar', b'boo', b'baz']), 4)
        input = [(False, chunks[:3]), (True, chunks[4:])]  # chunk 3 is dropped entirely
        result = self.process(input)
        self.assert_equal(result, [{b'path': b'foo'}, {b'path': b'boo'}, {b'path': b'baz'}])

    def test_corrupt_chunk(self):
        chunks = self.split(self.make_chunks([b'foo', b'bar', b'boo', b'baz']), 4)
        input = [(False, chunks[:3]), (True, [b'gar', b'bage'] + chunks[3:])]
        result = self.process(input)
        self.assert_equal(result, [{b'path': b'foo'}, {b'path': b'boo'}, {b'path': b'baz'}])


@pytest.fixture
def item_keys_serialized():
    return [msgpack.packb(name) for name in ITEM_KEYS]


@pytest.mark.parametrize('packed',
    [b'', b'x', b'foobar', ] +
    [msgpack.packb(o) for o in (
        [None, 0, 0.0, False, '', {}, [], ()] +
        [42, 23.42, True, b'foobar', {b'foo': b'bar'}, [b'foo', b'bar'], (b'foo', b'bar')]
    )])
def test_invalid_msgpacked_item(packed, item_keys_serialized):
    assert not valid_msgpacked_dict(packed, item_keys_serialized)


# pytest-xdist requires that the keys and dicts always have the same order:
IK = sorted(list(ITEM_KEYS))


@pytest.mark.parametrize('packed',
    [msgpack.packb(o) for o in [
        {b'path': b'/a/b/c'},  # small (different msgpack mapping type!)
        OrderedDict((k, b'') for k in IK),  # as big (key count) as it gets
        OrderedDict((k, b'x' * 1000) for k in IK),  # as big (key count and volume) as it gets
    ]])
def test_valid_msgpacked_items(packed, item_keys_serialized):
    assert valid_msgpacked_dict(packed, item_keys_serialized)


def test_key_length_msgpacked_items():
    key = b'x' * 32  # exceeds the 31 byte fixstr limit, so msgpack must use the str8 type
    data = {key: b''}
    item_keys_serialized = [msgpack.packb(key), ]
    assert valid_msgpacked_dict(msgpack.packb(data), item_keys_serialized)


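# A hedged boundary companion to the test above, not in the original suite:
# a 31 byte key sits exactly at the msgpack fixstr limit, the same key type the
# small-dict case in test_valid_msgpacked_items already exercises, so it should
# validate as well.
def test_key_length_limit_msgpacked_items():
    key = b'x' * 31  # exactly the fixstr limit
    data = {key: b''}
    item_keys_serialized = [msgpack.packb(key), ]
    assert valid_msgpacked_dict(msgpack.packb(data), item_keys_serialized)

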
def test_backup_io():
    with pytest.raises(BackupOSError):
        with backup_io:
            raise OSError(123)


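# A hedged companion check, assuming backup_io translates only OSError (as the
# test above exercises): unrelated exceptions should propagate unchanged.
def test_backup_io_passthrough():
    with pytest.raises(ValueError):
        with backup_io:
            raise ValueError('this is not an OSError')

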
def test_backup_io_iter():
    class Iterator:
        def __init__(self, exc):
            self.exc = exc

        def __next__(self):
            raise self.exc()

    oserror_iterator = Iterator(OSError)
    with pytest.raises(BackupOSError):
        for _ in backup_io_iter(oserror_iterator):
            pass

    normal_iterator = Iterator(StopIteration)
    for _ in backup_io_iter(normal_iterator):
        assert False, 'StopIteration handled incorrectly'