1from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
2                            NAME_MAPPING, REVERSE_NAME_MAPPING)
3import builtins
4import pickle
5import io
6import collections
7import struct
8import sys
9import weakref
10
11import unittest
12from test import support
13
14from test.pickletester import AbstractUnpickleTests
15from test.pickletester import AbstractPickleTests
16from test.pickletester import AbstractPickleModuleTests
17from test.pickletester import AbstractPersistentPicklerTests
18from test.pickletester import AbstractIdentityPersistentPicklerTests
19from test.pickletester import AbstractPicklerUnpicklerObjectTests
20from test.pickletester import AbstractDispatchTableTests
21from test.pickletester import BigmemPickleTests
22
try:
    import _pickle
except ImportError:
    # The C accelerator is optional; without it only the pure-Python
    # implementation is tested.
    has_c_implementation = False
else:
    has_c_implementation = True
28
29
class PyPickleTests(AbstractPickleModuleTests):
    """Module-level API tests run against the pure-Python pickle functions."""

    # Python functions stored on a class would be bound to ``self`` when
    # looked up on a test instance; staticmethod prevents that.
    load = staticmethod(pickle._load)
    loads = staticmethod(pickle._loads)
    dump = staticmethod(pickle._dump)
    dumps = staticmethod(pickle._dumps)
    Unpickler = pickle._Unpickler
    Pickler = pickle._Pickler
37
38
class PyUnpicklerTests(AbstractUnpickleTests):
    """Unpickling tests for the pure-Python ``pickle._Unpickler``."""

    unpickler = pickle._Unpickler
    # Exception types accepted for malformed stacks and for truncated
    # input respectively (presumably consulted by the abstract tests).
    bad_stack_errors = (IndexError,)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds*."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()
51
52
class PyPicklerTests(AbstractPickleTests):
    """Round-trip pickling tests for the pure-Python Pickler/Unpickler."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None):
        """Pickle *arg* with ``self.pickler`` and return the bytes."""
        stream = io.BytesIO()
        self.pickler(stream, proto).dump(arg)
        return stream.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds*."""
        stream = io.BytesIO(buf)
        return self.unpickler(stream, **kwds).load()
69
70
class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests):
    """Run the pickle/unpickle suites through ``pickle.dumps``/``pickle.loads``."""

    # Accepts the union of the error types expected from the C unpickler
    # (UnpicklingError) and the pure-Python one (IndexError).
    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def dumps(self, arg, protocol=None):
        """Pickle *arg* in memory with the public ``pickle.dumps``."""
        return pickle.dumps(arg, protocol)

    def loads(self, buf, **kwds):
        """Unpickle *buf* in memory with the public ``pickle.loads``."""
        return pickle.loads(buf, **kwds)

    # A non-callable attribute hides the inherited test from unittest's
    # collection (presumably because there is no writer to delay when
    # pickling straight to bytes).
    test_framed_write_sizes_with_delayed_writer = None
86
87
class PersistentPicklerUnpicklerMixin(object):
    """Route persistent IDs through the test case's own hooks.

    The host class supplies ``pickler``/``unpickler`` classes and
    ``persistent_id``/``persistent_load`` methods; ``dumps``/``loads``
    build throwaway subclasses whose hooks delegate to those methods.
    """

    def dumps(self, arg, proto=None):
        """Pickle *arg* with a delegating ``self.pickler`` subclass."""
        test = self

        class PersPickler(test.pickler):
            def persistent_id(self, obj):
                return test.persistent_id(obj)

        out = io.BytesIO()
        PersPickler(out, proto).dump(arg)
        return out.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with a delegating ``self.unpickler`` subclass."""
        test = self

        class PersUnpickler(test.unpickler):
            def persistent_load(self, obj):
                return test.persistent_load(obj)

        return PersUnpickler(io.BytesIO(buf), **kwds).load()
106
107
class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin):
    """Persistent-ID tests using the pure-Python pickler and unpickler."""

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler
113
114
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin):
    """Identity persistent-ID tests for the pure-Python implementation."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        # Overriding persistent_id as an instance, class or static method
        # must not keep the pickler alive once its last strong reference
        # is dropped (checked via a weakref right after ``del``).
        def check(pickler_cls):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                out = io.BytesIO()
                pickler_cls(out, proto).dump('abc')
                self.assertEqual(self.loads(out.getvalue()), 'abc')
            instance = pickler_cls(io.BytesIO())
            self.assertEqual(instance.persistent_id('def'), 'def')
            ref = weakref.ref(instance)
            del instance
            self.assertIsNone(ref())

        class PersPickler(self.pickler):
            def persistent_id(this, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @classmethod
            def persistent_id(klass, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj
        check(PersPickler)

    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        # Same check for persistent_load overrides on the unpickler.
        def check(unpickler_cls):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                src = io.BytesIO(self.dumps('abc', proto))
                self.assertEqual(unpickler_cls(src).load(), 'abc')
            instance = unpickler_cls(io.BytesIO())
            self.assertEqual(instance.persistent_load('def'), 'def')
            ref = weakref.ref(instance)
            del instance
            self.assertIsNone(ref())

        class PersUnpickler(self.unpickler):
            def persistent_load(this, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @classmethod
            def persistent_load(klass, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid
        check(PersUnpickler)
180
181
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
    """Pickler/Unpickler object tests for the pure-Python classes."""

    unpickler_class = pickle._Unpickler
    pickler_class = pickle._Pickler
186
187
class PyDispatchTableTests(AbstractDispatchTableTests):
    """Dispatch-table tests for the pure-Python pickler (plain dict copy)."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a fresh shallow copy of the global ``pickle.dispatch_table``."""
        table = pickle.dispatch_table.copy()
        return table
194
195
class PyChainDispatchTableTests(AbstractDispatchTableTests):
    """Dispatch-table tests for the pure-Python pickler (ChainMap overlay)."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return an empty writable layer over the global dispatch table."""
        overlay = {}
        return collections.ChainMap(overlay, pickle.dispatch_table)
202
203
if has_c_implementation:
    class CPickleTests(AbstractPickleModuleTests):
        """Module-level API tests run against the C ``_pickle`` functions."""
        # Builtin (C) functions are not bound to ``self`` when looked up on
        # an instance, so unlike PyPickleTests no staticmethod is needed.
        dump = _pickle.dump
        dumps = _pickle.dumps
        load = _pickle.load
        loads = _pickle.loads
        Pickler = _pickle.Pickler
        Unpickler = _pickle.Unpickler
207
208    class CUnpicklerTests(PyUnpicklerTests):
209        unpickler = _pickle.Unpickler
210        bad_stack_errors = (pickle.UnpicklingError,)
211        truncated_errors = (pickle.UnpicklingError,)
212
213    class CPicklerTests(PyPicklerTests):
214        pickler = _pickle.Pickler
215        unpickler = _pickle.Unpickler
216
217    class CPersPicklerTests(PyPersPicklerTests):
218        pickler = _pickle.Pickler
219        unpickler = _pickle.Unpickler
220
221    class CIdPersPicklerTests(PyIdPersPicklerTests):
222        pickler = _pickle.Pickler
223        unpickler = _pickle.Unpickler
224
225    class CDumpPickle_LoadPickle(PyPicklerTests):
226        pickler = _pickle.Pickler
227        unpickler = pickle._Unpickler
228
229    class DumpPickle_CLoadPickle(PyPicklerTests):
230        pickler = pickle._Pickler
231        unpickler = _pickle.Unpickler
232
233    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
234        pickler_class = _pickle.Pickler
235        unpickler_class = _pickle.Unpickler
236
237        def test_issue18339(self):
238            unpickler = self.unpickler_class(io.BytesIO())
239            with self.assertRaises(TypeError):
240                unpickler.memo = object
241            # used to cause a segfault
242            with self.assertRaises(ValueError):
243                unpickler.memo = {-1: None}
244            unpickler.memo = {1: None}
245
246    class CDispatchTableTests(AbstractDispatchTableTests):
247        pickler_class = pickle.Pickler
248        def get_dispatch_table(self):
249            return pickle.dispatch_table.copy()
250
251    class CChainDispatchTableTests(AbstractDispatchTableTests):
252        pickler_class = pickle.Pickler
253        def get_dispatch_table(self):
254            return collections.ChainMap({}, pickle.dispatch_table)
255
    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        """Check ``sys.getsizeof`` accounting of the C Pickler/Unpickler.

        NOTE(review): the ``calcobjsize``/``calcsize`` format strings below
        presumably mirror the C struct layouts in ``Modules/_pickle.c`` and
        must be updated whenever those structs change -- TODO confirm.
        """
        # Presumably a plain function, so ``self.check_sizeof(obj, size)``
        # passes this TestCase as the first argument of
        # ``support.check_sizeof``.
        check_sizeof = support.check_sizeof

        def test_pickler(self):
            # Size of a bare C Pickler object as reported by
            # object.__sizeof__, i.e. without the memo table and write
            # buffer that the full check below adds on top.
            basesize = support.calcobjsize('6P2n3i2n3iP')
            p = _pickle.Pickler(io.BytesIO())
            self.assertEqual(object.__sizeof__(p), basesize)
            # Memo table header size and per-entry size (assumed to match
            # the C memo structs -- TODO confirm).
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            # A fresh pickler accounts for a minimal memo table plus a
            # minimal write buffer.
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            for i in range(6):
                p.dump(chr(i))
            # After dumping, the memo table has grown and the write buffer
            # no longer contributes.
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            # Size of a bare C Unpickler object as reported by
            # object.__sizeof__ (see the NOTE on the class).
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n6P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            # A fresh unpickler accounts for a 32-entry memo table plus the
            # encoding and errors names, each counted with one extra byte
            # (presumably a NUL terminator).
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            # Baseline for the ASCII/strict unpicklers used below.
            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1
            # Load *data* and check the unpickler's size given the memo
            # table and mark stack capacities (in entries) it is expected
            # to have grown to.
            def check_unpickler(data, memo_size, marks_size):
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)
            # Build lists nested *deep* levels deep, which the expected
            # sizes below show growing the mark stack (and memo table).
            def recurse(deep):
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data
            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 58)
            check_unpickler(recurse(50), 64, 58)
            check_unpickler(recurse(100), 128, 134)

            # Protocol 0 pickle of a one-character string.
            # NOTE(review): the extra ``2 + 1`` bytes are state left behind
            # by this load; exactly what they are is not visible here --
            # TODO confirm against _pickle.c.
            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)
320
321
# Entries of the _compat_pickle tables that deliberately do not round-trip
# exactly; the reverse-mapping checks in CompatPickleTests skip them.

# (Python 2 module, Python 3 module) pairs from IMPORT_MAPPING.
ALT_IMPORT_MAPPING = {
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
    ('cPickle', 'pickle'),
    ('_elementtree', 'xml.etree.ElementTree'),
}

# (Python 2 module, Python 2 name, Python 3 module, Python 3 name) entries
# from NAME_MAPPING.
ALT_NAME_MAPPING = {
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}
335
def mapping(module, name):
    """Translate a Python 2 (module, name) pair to its Python 3 spelling.

    A full (module, name) entry in NAME_MAPPING wins; otherwise only the
    module is translated through IMPORT_MAPPING.  Pairs found in neither
    table are returned unchanged.
    """
    try:
        new_module, new_name = NAME_MAPPING[(module, name)]
    except KeyError:
        return IMPORT_MAPPING.get(module, module), name
    return new_module, new_name
342
def reverse_mapping(module, name):
    """Translate a Python 3 (module, name) pair back to its Python 2 spelling.

    A full (module, name) entry in REVERSE_NAME_MAPPING wins; otherwise
    only the module is translated through REVERSE_IMPORT_MAPPING.  Pairs
    found in neither table are returned unchanged.
    """
    try:
        old_module, old_name = REVERSE_NAME_MAPPING[(module, name)]
    except KeyError:
        return REVERSE_IMPORT_MAPPING.get(module, module), name
    return old_module, old_name
349
def getmodule(module):
    """Return the module object named *module*, importing it if needed.

    An already-imported module is returned straight from ``sys.modules``.
    Otherwise the module is imported with ``__import__``.

    Raises:
        ImportError: if the import fails with ImportError (re-raised as
            is), or with AttributeError (converted to an ImportError that
            names the module and chains the original as ``__cause__``).
    """
    try:
        return sys.modules[module]
    except KeyError:
        pass
    # Import outside the KeyError handler so that a failure is not
    # reported as "During handling of the above exception (KeyError)...".
    try:
        __import__(module)
    except AttributeError as exc:
        # Callers only catch ImportError, so convert; keep a descriptive
        # message and the original exception instead of a bare ImportError.
        msg = "Can't import module %r: %s" % (module, exc)
        if support.verbose:
            print(msg)
        raise ImportError(msg) from exc
    except ImportError as exc:
        if support.verbose:
            print(exc)
        raise
    return sys.modules[module]
365
def getattribute(module, name):
    """Resolve the possibly dotted attribute path *name* inside *module*.

    The module is obtained via getmodule(); each dot-separated component
    of *name* is then looked up with getattr in turn.
    """
    target = getmodule(module)
    parts = name.split('.')
    for part in parts:
        target = getattr(target, part)
    return target
371
def get_exceptions(mod):
    """Lazily yield ``(name, cls)`` for every exception class visible in *mod*.

    Attributes are taken in ``dir(mod)`` order; only classes deriving from
    BaseException are produced.
    """
    for attr_name in dir(mod):
        candidate = getattr(mod, attr_name)
        if not isinstance(candidate, type):
            continue
        if issubclass(candidate, BaseException):
            yield attr_name, candidate
377
class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the Python 2 <-> 3 tables in _compat_pickle."""

    def test_import(self):
        # Every Python 3 module named anywhere in the tables must either
        # import or fail cleanly with ImportError.
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    def test_import_mapping(self):
        # For public (non-underscore) Python 3 modules, IMPORT_MAPPING must
        # map the Python 2 module straight back to the Python 3 one.
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        # Each reverse name entry must round-trip through mapping() and,
        # where importable, resolve to the same object on both sides.
        # Entries collapsed onto Python 2's OSError/ImportError only need
        # the Python 3 class to be a subclass of that base.
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        # Every IMPORT_MAPPING entry needs a reverse mapping (module-level
        # or name-level) unless listed in ALT_IMPORT_MAPPING, and mapping
        # back and forth must land on the Python 3 module again.
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
                    for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items():
                        if (module3, module2) == (m3, m2):
                            break
                    else:
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    def test_reverse_name_mapping(self):
        # Every NAME_MAPPING entry must be inverted by reverse_mapping()
        # (except the deliberate ALT_NAME_MAPPING cases) and survive a
        # full round trip through mapping().
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                # Only verifies that the target resolves when its module is
                # importable; the resolved object itself is not needed.
                try:
                    getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        # A few fixed translations between 'exceptions' and 'builtins'.
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        # Every builtin exception class: OSError and ImportError subclasses
        # map back onto their Python 2 base class; all others map by name
        # between 'exceptions' and 'builtins'.
        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # Excluded from the per-name checks below.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError):
                    continue
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        # Exceptions defined in multiprocessing.context are mapped to and
        # from the public 'multiprocessing' package name.
        module = support.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))
495
496
def test_main():
    """Run every pickle test suite in this module, then pickle's doctests.

    Suites that only use the pure-Python classes, or the public
    ``pickle.dumps``/``pickle.loads`` API (which works with either
    implementation), always run.  Suites that reference ``_pickle``
    directly are added only when the C accelerator imported successfully.
    """
    tests = [PyPickleTests, PyUnpicklerTests, PyPicklerTests,
             PyPersPicklerTests, PyIdPersPicklerTests,
             PyPicklerUnpicklerObjectTests,
             PyDispatchTableTests, PyChainDispatchTableTests,
             InMemoryPickleTests,
             CompatPickleTests]
    if has_c_implementation:
        tests.extend([CPickleTests, CUnpicklerTests, CPicklerTests,
                      CPersPicklerTests, CIdPersPicklerTests,
                      CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
                      CPicklerUnpicklerObjectTests,
                      CDispatchTableTests, CChainDispatchTableTests,
                      SizeofTests])
    support.run_unittest(*tests)
    support.run_doctest(pickle)

if __name__ == "__main__":
    test_main()
515