1from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
2                            NAME_MAPPING, REVERSE_NAME_MAPPING)
3import builtins
4import pickle
5import io
6import collections
7import struct
8import sys
9import warnings
10import weakref
11
12import doctest
13import unittest
14from test import support
15from test.support import import_helper
16
17from test.pickletester import AbstractHookTests
18from test.pickletester import AbstractUnpickleTests
19from test.pickletester import AbstractPickleTests
20from test.pickletester import AbstractPickleModuleTests
21from test.pickletester import AbstractPersistentPicklerTests
22from test.pickletester import AbstractIdentityPersistentPicklerTests
23from test.pickletester import AbstractPicklerUnpicklerObjectTests
24from test.pickletester import AbstractDispatchTableTests
25from test.pickletester import AbstractCustomPicklerClass
26from test.pickletester import BigmemPickleTests
27
# Probe for the optional C accelerator module; the C-specific test classes
# further down are only defined when it can be imported.
try:
    import _pickle
except ImportError:
    has_c_implementation = False
else:
    has_c_implementation = True
33
34
class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
    """Module-level API tests bound to the underscore-prefixed (Python) API."""

    # The pickler/unpickler classes under test.
    Pickler = pickle._Pickler
    Unpickler = pickle._Unpickler

    # staticmethod() keeps these plain functions from being bound to the
    # test instance when they are looked up through ``self``.
    dumps = staticmethod(pickle._dumps)
    loads = staticmethod(pickle._loads)
    dump = staticmethod(pickle._dump)
    load = staticmethod(pickle._load)
42
43
class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
    """Unpickling tests run against ``pickle._Unpickler``."""

    unpickler = pickle._Unpickler

    # Exception types consumed by the inherited tests (presumably the ones
    # tolerated for truncated input and for bad stack states, respectively).
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)
    bad_stack_errors = (IndexError,)

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds* to it."""
        stream = io.BytesIO(buf)
        return self.unpickler(stream, **kwds).load()
56
57
class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
    """Round-trip tests using ``self.pickler`` and ``self.unpickler``.

    Subclasses swap in other pickler/unpickler classes by overriding the
    two class attributes.
    """

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* with ``self.pickler`` and return the bytes written."""
        sink = io.BytesIO()
        self.pickler(sink, proto, **kwargs).dump(arg)
        # getvalue() returns the whole buffer regardless of stream position.
        return sink.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds* to it."""
        source = io.BytesIO(buf)
        return self.unpickler(source, **kwds).load()
74
75
class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests, unittest.TestCase):
    """Pickle, unpickle and bigmem tests driven through pickle.dumps/loads."""

    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)
    bad_stack_errors = (pickle.UnpicklingError, IndexError)

    # Rebinding the name to None replaces whatever the base classes provide
    # under it, so no such test method is collected for this class.
    test_framed_write_sizes_with_delayed_writer = None

    def dumps(self, arg, protocol=None, **kwargs):
        """Return ``pickle.dumps(arg, protocol, **kwargs)``."""
        pickled = pickle.dumps(arg, protocol, **kwargs)
        return pickled

    def loads(self, buf, **kwds):
        """Return ``pickle.loads(buf, **kwds)``."""
        restored = pickle.loads(buf, **kwds)
        return restored
91
92
class PersistentPicklerUnpicklerMixin(object):
    """Provide dumps()/loads() that route persistent IDs through the test case.

    The host class must define ``pickler`` and ``unpickler`` (the classes to
    subclass) and the ``persistent_id(obj)`` / ``persistent_load(pid)``
    methods that the temporary subclasses delegate to.
    """

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* and return the resulting bytes.

        *proto* and any extra keyword arguments are forwarded to the pickler
        constructor, mirroring how loads() forwards its keyword arguments.
        """
        class PersPickler(self.pickler):
            # ``subself`` is the pickler instance; ``self`` is deliberately
            # the enclosing test case, whose hook decides the persistent ID.
            def persistent_id(subself, obj):
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto, **kwargs)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf*, forwarding *kwds* to the unpickler constructor."""
        class PersUnpickler(self.unpickler):
            # Same delegation as in dumps(): resolve IDs via the test case.
            def persistent_load(subself, obj):
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()
111
112
class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Persistent-ID tests using pickle._Pickler and pickle._Unpickler."""

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler
118
119
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Identity persistent-ID tests plus reference-cycle checks."""

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        # A pickler whose persistent_id hook is overridden must be released
        # as soon as its last reference goes away.
        def assert_released(pickler_cls):
            for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
                stream = io.BytesIO()
                instance = pickler_cls(stream, protocol)
                instance.dump('abc')
                self.assertEqual(self.loads(stream.getvalue()), 'abc')
            instance = pickler_cls(io.BytesIO())
            self.assertEqual(instance.persistent_id('def'), 'def')
            probe = weakref.ref(instance)
            del instance
            # The weak reference is dead only if nothing else holds the object.
            self.assertIsNone(probe())

        # Hook defined as a regular method.
        class MethodHookPickler(self.pickler):
            def persistent_id(subself, obj):
                return obj
        assert_released(MethodHookPickler)

        # Hook defined as a class method.
        class ClassHookPickler(self.pickler):
            @classmethod
            def persistent_id(cls, obj):
                return obj
        assert_released(ClassHookPickler)

        # Hook defined as a static method.
        class StaticHookPickler(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj
        assert_released(StaticHookPickler)

    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        # Same check as above for an unpickler overriding persistent_load.
        def assert_released(unpickler_cls):
            for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
                pickled = self.dumps('abc', protocol)
                instance = unpickler_cls(io.BytesIO(pickled))
                self.assertEqual(instance.load(), 'abc')
            instance = unpickler_cls(io.BytesIO())
            self.assertEqual(instance.persistent_load('def'), 'def')
            probe = weakref.ref(instance)
            del instance
            self.assertIsNone(probe())

        # Hook defined as a regular method.
        class MethodHookUnpickler(self.unpickler):
            def persistent_load(subself, pid):
                return pid
        assert_released(MethodHookUnpickler)

        # Hook defined as a class method.
        class ClassHookUnpickler(self.unpickler):
            @classmethod
            def persistent_load(cls, pid):
                return pid
        assert_released(ClassHookUnpickler)

        # Hook defined as a static method.
        class StaticHookUnpickler(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid
        assert_released(StaticHookUnpickler)
185
186
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests,
                                    unittest.TestCase):
    """Pickler/Unpickler object tests for pickle._Pickler / pickle._Unpickler."""

    unpickler_class = pickle._Unpickler
    pickler_class = pickle._Pickler
191
192
class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests using a private copy of pickle.dispatch_table."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a shallow copy of ``pickle.dispatch_table``."""
        table = pickle.dispatch_table.copy()
        return table
199
200
class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests using a ChainMap layered over the global table."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a ChainMap: an empty dict in front of pickle.dispatch_table."""
        overrides = {}
        return collections.ChainMap(overrides, pickle.dispatch_table)
207
208
class PyPicklerHookTests(AbstractHookTests, unittest.TestCase):
    """Hook tests using a pickle._Pickler subclass with the custom mixin."""

    class CustomPyPicklerClass(pickle._Pickler, AbstractCustomPicklerClass):
        """pickle._Pickler combined with AbstractCustomPicklerClass."""

    pickler_class = CustomPyPicklerClass
214
215
216if has_c_implementation:
217    class CPickleTests(AbstractPickleModuleTests, unittest.TestCase):
218        from _pickle import dump, dumps, load, loads, Pickler, Unpickler
219
220    class CUnpicklerTests(PyUnpicklerTests):
221        unpickler = _pickle.Unpickler
222        bad_stack_errors = (pickle.UnpicklingError,)
223        truncated_errors = (pickle.UnpicklingError,)
224
225    class CPicklerTests(PyPicklerTests):
226        pickler = _pickle.Pickler
227        unpickler = _pickle.Unpickler
228
229    class CPersPicklerTests(PyPersPicklerTests):
230        pickler = _pickle.Pickler
231        unpickler = _pickle.Unpickler
232
233    class CIdPersPicklerTests(PyIdPersPicklerTests):
234        pickler = _pickle.Pickler
235        unpickler = _pickle.Unpickler
236
237    class CDumpPickle_LoadPickle(PyPicklerTests):
238        pickler = _pickle.Pickler
239        unpickler = pickle._Unpickler
240
241    class DumpPickle_CLoadPickle(PyPicklerTests):
242        pickler = pickle._Pickler
243        unpickler = _pickle.Unpickler
244
245    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
246        pickler_class = _pickle.Pickler
247        unpickler_class = _pickle.Unpickler
248
249        def test_issue18339(self):
250            unpickler = self.unpickler_class(io.BytesIO())
251            with self.assertRaises(TypeError):
252                unpickler.memo = object
253            # used to cause a segfault
254            with self.assertRaises(ValueError):
255                unpickler.memo = {-1: None}
256            unpickler.memo = {1: None}
257
258    class CDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
259        pickler_class = pickle.Pickler
260        def get_dispatch_table(self):
261            return pickle.dispatch_table.copy()
262
263    class CChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
264        pickler_class = pickle.Pickler
265        def get_dispatch_table(self):
266            return collections.ChainMap({}, pickle.dispatch_table)
267
268    class CPicklerHookTests(AbstractHookTests, unittest.TestCase):
269        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
270            pass
271        pickler_class = CustomCPicklerClass
272
    # Pins the sizes reported for _pickle.Pickler and _pickle.Unpickler
    # objects, both freshly created and after they have done some work.
    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        check_sizeof = support.check_sizeof

        def test_pickler(self):
            # NOTE(review): the format string presumably mirrors the fields of
            # the C Pickler struct -- confirm against the C source on changes.
            basesize = support.calcobjsize('7P2n3i2n3i2P')
            p = _pickle.Pickler(io.BytesIO())
            # The generic object.__sizeof__ must report only the fixed part.
            self.assertEqual(object.__sizeof__(p), basesize)
            # Sizes of the memo table header and of one memo entry.
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            # Dump six distinct one-character strings with the same pickler.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            # NOTE(review): the format string presumably mirrors the fields of
            # the C Unpickler struct -- confirm against the C source on changes.
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            # A fresh unpickler: the expected size grows with the lengths of
            # the encoding and errors strings, each plus one extra byte.
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            # Baseline for the checks below, which all use ASCII/strict.
            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1
            # Load a pickle of *data* and check the unpickler's size given the
            # expected number of memo entries and mark-stack entries.
            def check_unpickler(data, memo_size, marks_size):
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)
            # Build *deep* levels of nested two-element lists, each level
            # holding the previous level twice.
            def recurse(deep):
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data
            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 20)
            check_unpickler(recurse(50), 64, 60)
            check_unpickler(recurse(100), 128, 140)

            # Protocol 0 pickle of a one-character string.
            # NOTE(review): the extra ``2 + 1`` bytes presumably account for a
            # line buffer kept by the unpickler -- confirm against the C source.
            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)
337
338
# (module2, module3) pairs from IMPORT_MAPPING for which
# test_reverse_import_mapping does not demand a matching reverse entry.
ALT_IMPORT_MAPPING = {
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
    ('cPickle', 'pickle'),
    ('_elementtree', 'xml.etree.ElementTree'),
}

# (module2, name2, module3, name3) entries from NAME_MAPPING for which
# test_reverse_name_mapping does not demand that the reverse mapping lead
# back to (module2, name2).
ALT_NAME_MAPPING = {
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}
352
def mapping(module, name):
    """Translate a (module, name) pair through the forward tables.

    An exact (module, name) entry in NAME_MAPPING wins; otherwise only the
    module is translated through IMPORT_MAPPING; otherwise the pair is
    returned unchanged.
    """
    try:
        mapped_module, mapped_name = NAME_MAPPING[(module, name)]
    except KeyError:
        return IMPORT_MAPPING.get(module, module), name
    return mapped_module, mapped_name
359
def reverse_mapping(module, name):
    """Translate a (module, name) pair through the reverse tables.

    An exact (module, name) entry in REVERSE_NAME_MAPPING wins; otherwise
    only the module is translated through REVERSE_IMPORT_MAPPING; otherwise
    the pair is returned unchanged.
    """
    key = (module, name)
    if key in REVERSE_NAME_MAPPING:
        mapped_module, mapped_name = REVERSE_NAME_MAPPING[key]
        return mapped_module, mapped_name
    return REVERSE_IMPORT_MAPPING.get(module, module), name
366
def getmodule(module):
    """Return the module object named *module*, importing it if necessary.

    While importing, DeprecationWarning is always shown when support.verbose
    is set and ignored otherwise.

    Raises ImportError if the import fails.  An AttributeError escaping from
    the import is converted to an ImportError that carries a message, the
    module name and the original error as its cause.
    """
    try:
        return sys.modules[module]
    except KeyError:
        # Not imported yet.  Import outside this handler so that a failure
        # is not reported as happening "during handling of" the KeyError.
        pass

    try:
        with warnings.catch_warnings():
            action = 'always' if support.verbose else 'ignore'
            warnings.simplefilter(action, DeprecationWarning)
            __import__(module)
    except AttributeError as exc:
        message = "Can't import module %r: %s" % (module, exc)
        if support.verbose:
            print(message)
        raise ImportError(message, name=module) from exc
    except ImportError as exc:
        if support.verbose:
            print(exc)
        raise
    return sys.modules[module]
385
def getattribute(module, name):
    """Resolve the dotted *name* inside the module named *module*.

    The module is obtained with getmodule(); each dot-separated component
    of *name* is then looked up with getattr() on the previous result.
    """
    target = getmodule(module)
    components = name.split('.')
    for component in components:
        target = getattr(target, component)
    return target
391
def get_exceptions(mod):
    """Yield (name, class) for each attribute of *mod* that is an exception.

    Attributes are enumerated with dir(); only classes derived from
    BaseException are reported.
    """
    for attr_name in dir(mod):
        candidate = getattr(mod, attr_name)
        if not isinstance(candidate, type):
            continue
        if issubclass(candidate, BaseException):
            yield attr_name, candidate
397
class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the name tables imported from _compat_pickle."""

    def test_import(self):
        # Try to import every module named on the "3" side of any table;
        # modules that cannot be imported are tolerated.
        candidates = set(REVERSE_IMPORT_MAPPING)
        candidates.update(IMPORT_MAPPING.values())
        candidates.update(mod for mod, _ in REVERSE_NAME_MAPPING)
        candidates.update(mod for mod, _ in NAME_MAPPING.values())
        for modname in candidates:
            try:
                getmodule(modname)
            except ImportError:
                pass

    def test_import_mapping(self):
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                if module3.startswith('_'):
                    # Underscore-prefixed modules need no forward entry.
                    continue
                self.assertIn(module2, IMPORT_MAPPING)
                self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        # Reverse targets for which only a subclass relationship is checked.
        subclass_only = {
            ('exceptions', 'OSError'): OSError,
            ('exceptions', 'ImportError'): ImportError,
        }
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                expected_base = subclass_only.get((module2, name2))
                if expected_base is not None:
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, expected_base))
                    continue
                module, name = mapping(module2, name2)
                if not module3.startswith('_'):
                    self.assertEqual((module, name), (module3, name3))
                try:
                    attr = getattribute(module3, name3)
                except ImportError:
                    continue
                self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                listed_as_alt = (module2, module3) in ALT_IMPORT_MAPPING
                mapped_back = REVERSE_IMPORT_MAPPING.get(module3) == module2
                if not (listed_as_alt or mapped_back):
                    # Last resort: some name-level reverse entry must link
                    # the two modules.
                    linked_by_name = any(
                        (m3, m2) == (module3, module2)
                        for (m3, _n3), (m2, _n2) in REVERSE_NAME_MAPPING.items()
                    )
                    if not linked_by_name:
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                # Reverse followed by forward must come back to module3.
                back = REVERSE_IMPORT_MAPPING.get(module3, module3)
                self.assertEqual(IMPORT_MAPPING.get(back, back), module3)

    def test_reverse_name_mapping(self):
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                try:
                    getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                is_alt = (module2, name2, module3, name3) in ALT_NAME_MAPPING
                if not is_alt:
                    self.assertEqual((module, name), (module2, name2))
                # Reverse followed by forward must come back to the start.
                self.assertEqual(mapping(module, name), (module3, name3))

    def test_exceptions(self):
        py3_exception = ('builtins', 'Exception')
        py3_oserror = ('builtins', 'OSError')
        self.assertEqual(mapping('exceptions', 'StandardError'), py3_exception)
        self.assertEqual(mapping('exceptions', 'Exception'), py3_exception)
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'), py3_oserror)
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # Exceptions excluded from the mapping checks.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError,
                           EncodingWarning,
                           BaseExceptionGroup,
                           ExceptionGroup):
                    continue
                reverse = reverse_mapping('builtins', name)
                if exc is not OSError and issubclass(exc, OSError):
                    # Proper OSError subclasses collapse onto OSError.
                    self.assertEqual(reverse, ('exceptions', 'OSError'))
                    continue
                if exc is not ImportError and issubclass(exc, ImportError):
                    # Proper ImportError subclasses collapse onto ImportError
                    # and have no forward entry of their own.
                    self.assertEqual(reverse, ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                    continue
                self.assertEqual(reverse, ('exceptions', name))
                self.assertEqual(mapping('exceptions', name),
                                 ('builtins', name))

    def test_multiprocessing_exceptions(self):
        context = import_helper.import_module('multiprocessing.context')
        for name, _exc in get_exceptions(context):
            with self.subTest(name):
                self.assertEqual(
                    reverse_mapping('multiprocessing.context', name),
                    ('multiprocessing', name))
                self.assertEqual(
                    mapping('multiprocessing', name),
                    ('multiprocessing.context', name))
518
519
def load_tests(loader, tests, pattern):
    """Add this module's doctests to *tests* and return the same suite.

    *loader* and *pattern* are accepted but not used.
    """
    # DocTestSuite() without arguments inspects the calling module, so it
    # has to be called directly from this function.
    doctest_suite = doctest.DocTestSuite()
    tests.addTest(doctest_suite)
    return tests
523
524
# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
527