1from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
2                            NAME_MAPPING, REVERSE_NAME_MAPPING)
3import builtins
4import pickle
5import io
6import collections
7import struct
8import sys
9import warnings
10import weakref
11
12import doctest
13import unittest
14from test import support
15
16from test.pickletester import AbstractHookTests
17from test.pickletester import AbstractUnpickleTests
18from test.pickletester import AbstractPickleTests
19from test.pickletester import AbstractPickleModuleTests
20from test.pickletester import AbstractPersistentPicklerTests
21from test.pickletester import AbstractIdentityPersistentPicklerTests
22from test.pickletester import AbstractPicklerUnpicklerObjectTests
23from test.pickletester import AbstractDispatchTableTests
24from test.pickletester import AbstractCustomPicklerClass
25from test.pickletester import BigmemPickleTests
26
27try:
28    import _pickle
29    has_c_implementation = True
30except ImportError:
31    has_c_implementation = False
32
33
class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
    """Module-level pickle API tests bound to the pure-Python implementation."""

    Pickler = pickle._Pickler
    Unpickler = pickle._Unpickler
    # Plain functions stored on a class would become bound methods of the
    # test case; staticmethod keeps them callable with their own signature.
    dump = staticmethod(pickle._dump)
    dumps = staticmethod(pickle._dumps)
    load = staticmethod(pickle._load)
    loads = staticmethod(pickle._loads)
41
42
class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
    """Unpickling tests run against the pure-Python ``pickle._Unpickler``."""

    unpickler = pickle._Unpickler
    # Exception types accepted by the inherited tests for the bad-stack and
    # truncated-input cases (presumably consumed by AbstractUnpickleTests).
    bad_stack_errors = (IndexError,)
    truncated_errors = (
        pickle.UnpicklingError, EOFError, AttributeError, ValueError,
        struct.error, IndexError, ImportError,
    )

    def loads(self, buf, **kwds):
        # Wrap *buf* in an in-memory file and read it back with a fresh
        # unpickler instance; keyword options go to the constructor.
        return self.unpickler(io.BytesIO(buf), **kwds).load()
55
56
class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
    """Pickle round-trip tests using the pure-Python pickler and unpickler."""

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler

    def dumps(self, arg, proto=None, **kwargs):
        # Serialize *arg* into an in-memory buffer and hand back its bytes.
        buffer = io.BytesIO()
        self.pickler(buffer, proto, **kwargs).dump(arg)
        return buffer.getvalue()

    def loads(self, buf, **kwds):
        # Deserialize *buf* through a fresh unpickler instance.
        return self.unpickler(io.BytesIO(buf), **kwds).load()
73
74
class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests, unittest.TestCase):
    """Round-trip tests through the top-level ``pickle.dumps``/``pickle.loads``."""

    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (
        pickle.UnpicklingError, EOFError, AttributeError, ValueError,
        struct.error, IndexError, ImportError,
    )

    # Setting the inherited test to None (not callable) removes it from this
    # class; presumably because in-memory dumps() has no separate writer
    # whose writes could be delayed.
    test_framed_write_sizes_with_delayed_writer = None

    def dumps(self, arg, protocol=None, **kwargs):
        return pickle.dumps(arg, protocol, **kwargs)

    def loads(self, buf, **kwds):
        return pickle.loads(buf, **kwds)
90
91
class PersistentPicklerUnpicklerMixin(object):
    """Mixin providing ``dumps``/``loads`` that route persistent IDs to the test.

    The host class supplies ``pickler`` and ``unpickler`` classes plus
    ``persistent_id(obj)`` and ``persistent_load(pid)`` methods.  Each call
    builds a throwaway subclass whose hook delegates to those methods, so the
    test case decides which objects are stored out of band.
    """

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* with protocol *proto* and return the resulting bytes.

        Extra keyword arguments (e.g. ``fix_imports``) are forwarded to the
        pickler constructor, matching the ``dumps`` signature used by
        PyPicklerTests.
        """
        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                # Delegate to the test case, not to the pickler instance.
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto, **kwargs)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf*, forwarding *kwds* to the unpickler constructor."""
        class PersUnpickler(self.unpickler):
            def persistent_load(subself, obj):
                # Delegate to the test case, not to the unpickler instance.
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()
110
111
class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Persistent-ID tests using the pure-Python pickler and unpickler."""

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler
117
118
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Identity persistent-ID tests for the pure-Python pickler/unpickler."""

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        # Whether persistent_id is an instance, class or static method, the
        # pickler must still work and must be freed as soon as its last
        # reference is dropped (i.e. the override creates no reference cycle).
        def check(pickler_cls):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                buf = io.BytesIO()
                p = pickler_cls(buf, proto)
                p.dump('abc')
                self.assertEqual(self.loads(buf.getvalue()), 'abc')
            p = pickler_cls(io.BytesIO())
            self.assertEqual(p.persistent_id('def'), 'def')
            ref = weakref.ref(p)
            del p
            self.assertIsNone(ref())

        class InstanceMethodPickler(self.pickler):
            def persistent_id(inner_self, obj):
                return obj

        class ClassMethodPickler(self.pickler):
            @classmethod
            def persistent_id(cls, obj):
                return obj

        class StaticMethodPickler(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj

        for pickler_cls in (InstanceMethodPickler, ClassMethodPickler,
                            StaticMethodPickler):
            check(pickler_cls)

    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        # Unpickler-side counterpart: each persistent_load flavour must work
        # and must not keep the unpickler alive after it is deleted.
        def check(unpickler_cls):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                u = unpickler_cls(io.BytesIO(self.dumps('abc', proto)))
                self.assertEqual(u.load(), 'abc')
            u = unpickler_cls(io.BytesIO())
            self.assertEqual(u.persistent_load('def'), 'def')
            ref = weakref.ref(u)
            del u
            self.assertIsNone(ref())

        class InstanceMethodUnpickler(self.unpickler):
            def persistent_load(inner_self, pid):
                return pid

        class ClassMethodUnpickler(self.unpickler):
            @classmethod
            def persistent_load(cls, pid):
                return pid

        class StaticMethodUnpickler(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid

        for unpickler_cls in (InstanceMethodUnpickler, ClassMethodUnpickler,
                              StaticMethodUnpickler):
            check(unpickler_cls)
184
185
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
    """Pickler/Unpickler object tests for the pure-Python classes."""

    unpickler_class = pickle._Unpickler
    pickler_class = pickle._Pickler
190
191
class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests for the pure-Python pickler with a dict copy."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        # A private shallow copy: changes made through it leave the shared
        # pickle.dispatch_table untouched.
        return dict(pickle.dispatch_table)
198
199
class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests for the pure-Python pickler with a ChainMap view."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        # new_child() puts an empty dict in front: writes land there while
        # lookups fall through to the global pickle.dispatch_table.
        return collections.ChainMap(pickle.dispatch_table).new_child()
206
207
class PyPicklerHookTests(AbstractHookTests, unittest.TestCase):
    """Run AbstractHookTests against a customized pure-Python pickler."""

    class CustomPyPicklerClass(pickle._Pickler,
                               AbstractCustomPicklerClass):
        """``pickle._Pickler`` extended with AbstractCustomPicklerClass."""

    pickler_class = CustomPyPicklerClass
213
214
# Everything below targets the C accelerator (_pickle) and is defined only
# when that module imported successfully at the top of this file.
if has_c_implementation:
    class CPickleTests(AbstractPickleModuleTests, unittest.TestCase):
        # Importing inside the class body binds dump/dumps/load/loads and
        # Pickler/Unpickler as class attributes, mirroring the attributes
        # PyPickleTests sets for the pure-Python implementation.
        from _pickle import dump, dumps, load, loads, Pickler, Unpickler

    class CUnpicklerTests(PyUnpicklerTests):
        # Reuse the pure-Python unpickling tests with the C Unpickler, which
        # is expected to report both error categories only as
        # UnpicklingError.
        unpickler = _pickle.Unpickler
        bad_stack_errors = (pickle.UnpicklingError,)
        truncated_errors = (pickle.UnpicklingError,)

    class CPicklerTests(PyPicklerTests):
        # C Pickler and C Unpickler on both sides of the round trip.
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CPersPicklerTests(PyPersPicklerTests):
        # Persistent-ID tests with the C classes.
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CIdPersPicklerTests(PyIdPersPicklerTests):
        # Identity persistent-ID and reference-cycle tests with the C classes.
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CDumpPickle_LoadPickle(PyPicklerTests):
        # Cross-implementation: pickles written by the C Pickler are read
        # back by the pure-Python Unpickler.
        pickler = _pickle.Pickler
        unpickler = pickle._Unpickler

    class DumpPickle_CLoadPickle(PyPicklerTests):
        # Cross-implementation: pickles written by the pure-Python Pickler
        # are read back by the C Unpickler.
        pickler = pickle._Pickler
        unpickler = _pickle.Unpickler

    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
        pickler_class = _pickle.Pickler
        unpickler_class = _pickle.Unpickler

        def test_issue18339(self):
            # The C Unpickler's memo setter must validate its input:
            # assigning `object` (not a dict) raises TypeError, a negative
            # key raises ValueError, and a non-negative key is accepted.
            unpickler = self.unpickler_class(io.BytesIO())
            with self.assertRaises(TypeError):
                unpickler.memo = object
            # used to cause a segfault
            with self.assertRaises(ValueError):
                unpickler.memo = {-1: None}
            unpickler.memo = {1: None}

    class CDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        # NOTE(review): uses pickle.Pickler rather than _pickle.Pickler;
        # presumably the same class when _pickle is importable -- confirm
        # against pickle.py.
        pickler_class = pickle.Pickler
        def get_dispatch_table(self):
            return pickle.dispatch_table.copy()

    class CChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        # Like CDispatchTableTests, but the table is a ChainMap layered over
        # the global pickle.dispatch_table instead of a copy.
        pickler_class = pickle.Pickler
        def get_dispatch_table(self):
            return collections.ChainMap({}, pickle.dispatch_table)

    class CPicklerHookTests(AbstractHookTests, unittest.TestCase):
        # Hook tests against a C Pickler subclass that also mixes in
        # AbstractCustomPicklerClass.
        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
            pass
        pickler_class = CustomCPicklerClass

    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        # Compares __sizeof__() of the C Pickler/Unpickler with sizes built
        # from struct format strings.  NOTE(review): those format strings
        # encode the C object layouts and must be kept in sync with the
        # structs in _pickle.c -- confirm after any layout change.
        check_sizeof = support.check_sizeof

        def test_pickler(self):
            basesize = support.calcobjsize('7P2n3i2n3i2P')
            p = _pickle.Pickler(io.BytesIO())
            self.assertEqual(object.__sizeof__(p), basesize)
            # Presumably memo-table header (MT) and memo-entry (ME) sizes.
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            # A fresh unpickler accounts for the minimal memo table plus
            # len + 1 bytes for each of the encoding and errors strings
            # (presumably NUL-terminated copies -- confirm in _pickle.c).
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            # Baseline for an ASCII/strict unpickler, used by the checks below.
            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1
            # Load *data* and expect memo_size memo entries and marks_size
            # mark-stack entries on top of the baseline.
            def check_unpickler(data, memo_size, marks_size):
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)
            # Build nested [data, data] lists *deep* levels deep.
            def recurse(deep):
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data
            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 20)
            check_unpickler(recurse(50), 64, 60)
            check_unpickler(recurse(100), 128, 140)

            # NOTE(review): the extra 2 + 1 bytes after loading a protocol-0
            # pickle of 'a' are not explained here; confirm in _pickle.c.
            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)
336
337
# (Python 2 module, Python 3 module) pairs that are deliberately one-way:
# test_reverse_import_mapping does not require them to round-trip.
ALT_IMPORT_MAPPING = {
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
    ('cPickle', 'pickle'),
    ('_elementtree', 'xml.etree.ElementTree'),
}

# (py2 module, py2 name, py3 module, py3 name) entries exempt from the
# reverse round-trip assertion in test_reverse_name_mapping.
ALT_NAME_MAPPING = {
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}
351
def mapping(module, name):
    """Translate a Python 2 ``(module, name)`` pair to its Python 3 location.

    An exact ``(module, name)`` entry in NAME_MAPPING wins; otherwise only the
    module is renamed through IMPORT_MAPPING.  Pairs found in neither table
    are returned unchanged.
    """
    try:
        new_module, new_name = NAME_MAPPING[(module, name)]
    except KeyError:
        return IMPORT_MAPPING.get(module, module), name
    return new_module, new_name
358
def reverse_mapping(module, name):
    """Translate a Python 3 ``(module, name)`` pair back to its Python 2 location.

    An exact entry in REVERSE_NAME_MAPPING takes precedence; otherwise only
    the module is looked up in REVERSE_IMPORT_MAPPING.  Pairs found in
    neither table are returned unchanged.
    """
    try:
        old_module, old_name = REVERSE_NAME_MAPPING[(module, name)]
    except KeyError:
        return REVERSE_IMPORT_MAPPING.get(module, module), name
    return old_module, old_name
365
def getmodule(module):
    """Return the module object named *module*, importing it if necessary.

    DeprecationWarnings raised during the import are shown only when
    ``support.verbose`` is set and are ignored otherwise.  An AttributeError
    raised while importing is converted to ImportError, so callers only need
    to handle ImportError.

    Raises:
        ImportError: if the module cannot be imported.  The converted error
            carries a message and chains the original AttributeError.
    """
    try:
        return sys.modules[module]
    except KeyError:
        # Not imported yet; import outside this handler so any failure is
        # not reported as "during handling of KeyError".
        pass
    try:
        with warnings.catch_warnings():
            action = 'always' if support.verbose else 'ignore'
            warnings.simplefilter(action, DeprecationWarning)
            __import__(module)
    except AttributeError as exc:
        msg = "Can't import module %r: %s" % (module, exc)
        if support.verbose:
            print(msg)
        # Give callers that print the exception a meaningful message and
        # keep the real cause attached.
        raise ImportError(msg) from exc
    except ImportError as exc:
        if support.verbose:
            print(exc)
        raise
    return sys.modules[module]
384
def getattribute(module, name):
    """Import *module* via getmodule() and resolve the dotted path *name* in it.

    ImportError from getmodule() and AttributeError for a missing path
    component both propagate to the caller.
    """
    target = getmodule(module)
    remaining = name.split('.')
    while remaining:
        target = getattr(target, remaining.pop(0))
    return target
390
def get_exceptions(mod):
    """Yield ``(name, cls)`` for each exception class among *mod*'s attributes.

    Attributes are visited in ``dir(mod)`` order.  Anything that is not a
    class, or a class not derived from BaseException, is skipped.
    """
    for attr_name in dir(mod):
        candidate = getattr(mod, attr_name)
        if not isinstance(candidate, type):
            continue
        if not issubclass(candidate, BaseException):
            continue
        yield attr_name, candidate
396
class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the Python 2 <-> 3 tables in _compat_pickle."""

    # Every module named anywhere in the mapping tables is imported through
    # getmodule(); modules that raise ImportError are tolerated.
    def test_import(self):
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    # Each Python 3 -> Python 2 module entry must be the exact inverse of an
    # IMPORT_MAPPING entry, except for Python 3 modules whose name starts
    # with an underscore.
    def test_import_mapping(self):
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    # Each reverse name entry must map forward again to the same Python 3
    # location (checked for non-underscore modules) and, when importable,
    # both locations must resolve to the same object.  Names reverse-mapped
    # to exceptions.OSError / exceptions.ImportError only need to resolve to
    # a subclass of that exception.
    def test_name_mapping(self):
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    # Each forward module entry must be reversible: it is exempted by
    # ALT_IMPORT_MAPPING, REVERSE_IMPORT_MAPPING inverts it exactly, or some
    # REVERSE_NAME_MAPPING entry maps module3 back to module2.  Mapping
    # module3 back and then forward again must give module3.
    def test_reverse_import_mapping(self):
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
                    for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items():
                        if (module3, module2) == (m3, m2):
                            break
                    else:
                        # for/else: reached only when no name entry matched.
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    # Each forward name entry must reverse back to its Python 2 location
    # (unless exempted by ALT_NAME_MAPPING), and mapping that result forward
    # again must give the original Python 3 location.
    def test_reverse_name_mapping(self):
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                # Only checks that the target resolves; ImportError is
                # tolerated, but a missing attribute propagates and fails.
                try:
                    attr = getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        # Spot checks for the exceptions module, then a sweep over every
        # exception class in builtins.
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # NOTE(review): these are skipped, presumably because they
                # have no Python 2 counterpart in the tables -- confirm.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError):
                    continue
                # OSError subclasses reverse-map to exceptions.OSError.
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                # ImportError subclasses reverse-map to
                # exceptions.ImportError and have no forward entry of their
                # own, so mapping leaves them in 'exceptions'.
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                # Everything else maps one-to-one by name.
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    # Exception classes in multiprocessing.context must reverse-map to, and
    # map back from, the top-level multiprocessing module.
    def test_multiprocessing_exceptions(self):
        module = support.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))
514
515
def load_tests(loader, tests, pattern):
    """unittest ``load_tests`` hook: append this module's doctests to *tests*."""
    module_doctests = doctest.DocTestSuite()
    tests.addTest(module_doctests)
    return tests
519
520
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
523