1from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
2                            NAME_MAPPING, REVERSE_NAME_MAPPING)
3import builtins
4import pickle
5import io
6import collections
7import struct
8import sys
9import weakref
10
11import unittest
12from test import support
13
14from test.pickletester import AbstractHookTests
15from test.pickletester import AbstractUnpickleTests
16from test.pickletester import AbstractPickleTests
17from test.pickletester import AbstractPickleModuleTests
18from test.pickletester import AbstractPersistentPicklerTests
19from test.pickletester import AbstractIdentityPersistentPicklerTests
20from test.pickletester import AbstractPicklerUnpicklerObjectTests
21from test.pickletester import AbstractDispatchTableTests
22from test.pickletester import AbstractCustomPicklerClass
23from test.pickletester import BigmemPickleTests
24
# Probe for the C accelerator module.  The C-specific test classes in this
# file are only defined when it imports successfully.
try:
    import _pickle
except ImportError:
    has_c_implementation = False
else:
    has_c_implementation = True
30
31
class PyPickleTests(AbstractPickleModuleTests):
    """Module-level API tests run against the pure-Python pickle code."""

    Pickler = pickle._Pickler
    Unpickler = pickle._Unpickler
    # The pure-Python helpers are plain functions; staticmethod stops them
    # from being bound as methods when looked up through ``self``.
    dump, dumps = staticmethod(pickle._dump), staticmethod(pickle._dumps)
    load, loads = staticmethod(pickle._load), staticmethod(pickle._loads)
39
40
class PyUnpicklerTests(AbstractUnpickleTests):
    """Unpickling tests run against the pure-Python Unpickler."""

    unpickler = pickle._Unpickler
    # Exception types the shared tests accept for a malformed pickle stack
    # and for truncated input, respectively.
    bad_stack_errors = (IndexError,)
    truncated_errors = (
        pickle.UnpicklingError, EOFError, AttributeError, ValueError,
        struct.error, IndexError, ImportError,
    )

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds*."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()
53
54
class PyPicklerTests(AbstractPickleTests):
    """Round-trip pickling tests using the pure-Python Pickler/Unpickler."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* with ``self.pickler`` and return the bytes."""
        out = io.BytesIO()
        self.pickler(out, proto, **kwargs).dump(arg)
        return out.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds*."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()
71
72
class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests):
    """Tests for the module-level ``pickle.dumps`` / ``pickle.loads``."""

    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (
        pickle.UnpicklingError, EOFError, AttributeError, ValueError,
        struct.error, IndexError, ImportError,
    )

    # Disable this inherited test here (presumably it needs a file-like
    # writer, which the in-memory API lacks).  None is not callable, so
    # unittest does not collect it.
    test_framed_write_sizes_with_delayed_writer = None

    def dumps(self, arg, protocol=None, **kwargs):
        return pickle.dumps(arg, protocol, **kwargs)

    def loads(self, buf, **kwds):
        return pickle.loads(buf, **kwds)
88
89
class PersistentPicklerUnpicklerMixin(object):
    """Provide dumps()/loads() that route persistent-ID hooks to the test.

    The concrete test class must supply ``pickler`` and ``unpickler``
    classes plus ``persistent_id(obj)`` and ``persistent_load(pid)``
    methods.  Each call builds a throwaway Pickler/Unpickler subclass whose
    hook forwards to those test-case methods.
    """

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* with ``self.pickler`` and return the bytes.

        Extra keyword arguments (e.g. ``fix_imports``) are forwarded to the
        pickler constructor, mirroring ``loads()``, which already forwards
        its keywords to the unpickler.
        """
        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                # ``self`` here is the test case, not the pickler.
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto, **kwargs)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds*."""
        class PersUnpickler(self.unpickler):
            def persistent_load(subself, obj):
                # ``self`` here is the test case, not the unpickler.
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()
108
109
class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin):
    """Persistent-ID tests against the pure-Python implementation."""

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler
115
116
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin):
    """Identity persistent-ID tests against the pure-Python implementation."""

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        # Overriding persistent_id as a plain method, a classmethod or a
        # staticmethod must not keep the pickler alive through a cycle.
        def verify(pickler_cls):
            for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
                out = io.BytesIO()
                instance = pickler_cls(out, protocol)
                instance.dump('abc')
                self.assertEqual(self.loads(out.getvalue()), 'abc')
            instance = pickler_cls(io.BytesIO())
            self.assertEqual(instance.persistent_id('def'), 'def')
            # Without a cycle, dropping the only strong reference frees the
            # pickler at once, so the weak reference goes dead.
            ref = weakref.ref(instance)
            del instance
            self.assertIsNone(ref())

        class MethodPickler(self.pickler):
            def persistent_id(subself, obj):
                return obj

        class ClassMethodPickler(self.pickler):
            @classmethod
            def persistent_id(cls, obj):
                return obj

        class StaticMethodPickler(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj

        for pickler_cls in (MethodPickler, ClassMethodPickler,
                            StaticMethodPickler):
            verify(pickler_cls)

    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        # Same check for persistent_load on the unpickler side.
        def verify(unpickler_cls):
            for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
                instance = unpickler_cls(
                    io.BytesIO(self.dumps('abc', protocol)))
                self.assertEqual(instance.load(), 'abc')
            instance = unpickler_cls(io.BytesIO())
            self.assertEqual(instance.persistent_load('def'), 'def')
            ref = weakref.ref(instance)
            del instance
            self.assertIsNone(ref())

        class MethodUnpickler(self.unpickler):
            def persistent_load(subself, pid):
                return pid

        class ClassMethodUnpickler(self.unpickler):
            @classmethod
            def persistent_load(cls, pid):
                return pid

        class StaticMethodUnpickler(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid

        for unpickler_cls in (MethodUnpickler, ClassMethodUnpickler,
                              StaticMethodUnpickler):
            verify(unpickler_cls)
182
183
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
    """Pickler/Unpickler object tests for the pure-Python classes."""

    unpickler_class = pickle._Unpickler
    pickler_class = pickle._Pickler
188
189
class PyDispatchTableTests(AbstractDispatchTableTests):
    """Dispatch-table tests for the pure-Python Pickler using a dict copy."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        # A fresh dict with the same entries, so modifications stay local
        # instead of touching the global pickle.dispatch_table.
        return dict(pickle.dispatch_table)
196
197
class PyChainDispatchTableTests(AbstractDispatchTableTests):
    """Dispatch-table tests for the pure-Python Pickler using a ChainMap."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        # Writes land in the empty front dict; lookups fall through to the
        # global pickle.dispatch_table.
        overrides = {}
        return collections.ChainMap(overrides, pickle.dispatch_table)
204
205
class PyPicklerHookTests(AbstractHookTests):
    """Hook tests (see AbstractHookTests) for the pure-Python Pickler."""

    class CustomPyPicklerClass(pickle._Pickler,
                               AbstractCustomPicklerClass):
        """Pure-Python Pickler combined with AbstractCustomPicklerClass."""

    pickler_class = CustomPyPicklerClass
211
212
if has_c_implementation:
    class CPickleTests(AbstractPickleModuleTests):
        """Module-level API tests against the C accelerator (_pickle)."""
        # C functions are builtins and do not bind as methods, so unlike
        # the pure-Python variant no staticmethod wrapper is needed.
        dump, dumps = _pickle.dump, _pickle.dumps
        load, loads = _pickle.load, _pickle.loads
        Pickler, Unpickler = _pickle.Pickler, _pickle.Unpickler
216
217    class CUnpicklerTests(PyUnpicklerTests):
218        unpickler = _pickle.Unpickler
219        bad_stack_errors = (pickle.UnpicklingError,)
220        truncated_errors = (pickle.UnpicklingError,)
221
222    class CPicklerTests(PyPicklerTests):
223        pickler = _pickle.Pickler
224        unpickler = _pickle.Unpickler
225
226    class CPersPicklerTests(PyPersPicklerTests):
227        pickler = _pickle.Pickler
228        unpickler = _pickle.Unpickler
229
230    class CIdPersPicklerTests(PyIdPersPicklerTests):
231        pickler = _pickle.Pickler
232        unpickler = _pickle.Unpickler
233
234    class CDumpPickle_LoadPickle(PyPicklerTests):
235        pickler = _pickle.Pickler
236        unpickler = pickle._Unpickler
237
238    class DumpPickle_CLoadPickle(PyPicklerTests):
239        pickler = pickle._Pickler
240        unpickler = _pickle.Unpickler
241
242    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
243        pickler_class = _pickle.Pickler
244        unpickler_class = _pickle.Unpickler
245
246        def test_issue18339(self):
247            unpickler = self.unpickler_class(io.BytesIO())
248            with self.assertRaises(TypeError):
249                unpickler.memo = object
250            # used to cause a segfault
251            with self.assertRaises(ValueError):
252                unpickler.memo = {-1: None}
253            unpickler.memo = {1: None}
254
255    class CDispatchTableTests(AbstractDispatchTableTests):
256        pickler_class = pickle.Pickler
257        def get_dispatch_table(self):
258            return pickle.dispatch_table.copy()
259
260    class CChainDispatchTableTests(AbstractDispatchTableTests):
261        pickler_class = pickle.Pickler
262        def get_dispatch_table(self):
263            return collections.ChainMap({}, pickle.dispatch_table)
264
265    class CPicklerHookTests(AbstractHookTests):
266        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
267            pass
268        pickler_class = CustomCPicklerClass
269
    # Exact memory-footprint checks for the C Pickler/Unpickler objects.
    # NOTE(review): the struct format strings below are assumed to mirror the
    # C object layouts in _pickle; they must be updated whenever those C
    # structs change.
    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        # Bound as a method so calls read check(obj, expected_size).
        # Presumably support.check_sizeof(test, obj, size) asserts that
        # sys.getsizeof(obj) matches size (GC header included) — confirm.
        check_sizeof = support.check_sizeof

        def test_pickler(self):
            # Fixed-size part of the C Pickler object.
            basesize = support.calcobjsize('7P2n3i2n3i2P')
            p = _pickle.Pickler(io.BytesIO())
            # object.__sizeof__ reports only that fixed part.
            self.assertEqual(object.__sizeof__(p), basesize)
            # Memo table header (MT) and per-entry (ME) sizes; the trailing
            # zero-count codes (0n / 0P) only add alignment padding.
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            # Six distinct one-character strings are memoized, growing the
            # memo table from 8 to 32 entries.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            # Fixed-size part of the C Unpickler object (struct ignores the
            # whitespace; it only groups the fields for readability).
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            # A fresh unpickler also accounts for a 32-entry memo table and
            # len + 1 bytes for each of the encoding and errors strings
            # (presumably NUL-terminated copies held by the C object).
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            # Baseline for the 'ASCII' / 'strict' configuration used below.
            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1
            # Unpickle pickle.dumps(data), then check the size includes
            # memo_size memo entries and marks_size mark-stack entries.
            def check_unpickler(data, memo_size, marks_size):
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)
            # Nested list of depth `deep`; each level holds the previous
            # level twice, so deeper nesting needs a larger mark stack.
            def recurse(deep):
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data
            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 20)
            check_unpickler(recurse(50), 64, 60)
            check_unpickler(recurse(100), 128, 140)

            # Protocol 0 (text) input.  NOTE(review): the extra 2 + 1 bytes
            # presumably account for a line buffer kept after readline-based
            # parsing — TODO confirm against the C implementation.
            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)
334
335
# (Python 2 module, Python 3 module) pairs that are valid forward mappings
# but whose reverse mapping may point to a different Python 2 module, so
# the round-trip checks tolerate them.
ALT_IMPORT_MAPPING = {
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
    ('cPickle', 'pickle'),
    ('_elementtree', 'xml.etree.ElementTree'),
}

# Same idea at name granularity:
# (Python 2 module, Python 2 name, Python 3 module, Python 3 name).
ALT_NAME_MAPPING = {
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}
349
def mapping(module, name):
    """Translate a Python 2 ``(module, name)`` pair to its Python 3 location.

    An exact NAME_MAPPING entry wins; otherwise only the module is renamed
    through IMPORT_MAPPING; otherwise the pair comes back unchanged.
    """
    key = (module, name)
    if key in NAME_MAPPING:
        new_module, new_name = NAME_MAPPING[key]
        return new_module, new_name
    return IMPORT_MAPPING.get(module, module), name
356
def reverse_mapping(module, name):
    """Translate a Python 3 ``(module, name)`` pair back to Python 2.

    Mirror of mapping(): an exact REVERSE_NAME_MAPPING entry wins, then a
    module-only rename via REVERSE_IMPORT_MAPPING, else no change.
    """
    key = (module, name)
    if key in REVERSE_NAME_MAPPING:
        old_module, old_name = REVERSE_NAME_MAPPING[key]
        return old_module, old_name
    return REVERSE_IMPORT_MAPPING.get(module, module), name
363
def getmodule(module):
    """Return the module object named *module*, importing it if needed.

    Already-imported modules are served from ``sys.modules``.  Otherwise the
    module is imported.  ``__import__`` returns the top-level package for a
    dotted name, so the submodule itself is then fetched from ``sys.modules``.

    Raises ImportError if the import fails.  An AttributeError raised while
    importing is converted to ImportError, so callers only handle one
    exception type.  The converted error carries a message naming the
    module, ``name=module``, and the original error as ``__cause__``.
    """
    try:
        return sys.modules[module]
    except KeyError:
        # Not imported yet.  The import happens outside this handler so a
        # failure's traceback is not cluttered with this KeyError.
        pass
    try:
        __import__(module)
    except AttributeError as exc:
        msg = "Can't import module %r: %s" % (module, exc)
        if support.verbose:
            print(msg)
        # A bare ``raise ImportError`` would give callers an empty message.
        raise ImportError(msg, name=module) from exc
    except ImportError as exc:
        if support.verbose:
            print(exc)
        raise
    return sys.modules[module]
379
def getattribute(module, name):
    """Resolve the dotted attribute path *name* starting from *module*.

    The module is loaded through getmodule(), so ImportError propagates.
    A missing attribute at any step raises AttributeError.
    """
    target = getmodule(module)
    for attr in name.split('.'):
        target = getattr(target, attr)
    return target
385
def get_exceptions(mod):
    """Yield ``(attribute_name, cls)`` for each exception class in dir(mod)."""
    members = ((attr_name, getattr(mod, attr_name)) for attr_name in dir(mod))
    yield from ((attr_name, obj) for attr_name, obj in members
                if isinstance(obj, type) and issubclass(obj, BaseException))
391
class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the Python 2 <-> Python 3 tables in
    _compat_pickle (IMPORT_MAPPING, NAME_MAPPING and their reverses)."""

    def test_import(self):
        # Every Python 3 module named anywhere in the tables is imported;
        # modules unavailable on this platform are tolerated.
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    def test_import_mapping(self):
        # Each public (non-underscore) Python 3 -> 2 module entry must map
        # forward again to the same Python 3 module.
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        # Python 3 -> 2 name entries.  OSError/ImportError subclasses collapse
        # onto their base class, so only the subclass relationship is
        # checked.  Other entries must map forward to the same object.
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        # Each Python 2 -> 3 module entry must be reversible, either
        # directly, through a name-level reverse entry, or by being listed
        # in ALT_IMPORT_MAPPING.
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
                    for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items():
                        if (module3, module2) == (m3, m2):
                            break
                    else:
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    def test_reverse_name_mapping(self):
        # Each Python 2 -> 3 name entry must reverse to the same Python 2
        # pair (unless listed in ALT_NAME_MAPPING) and map forward again.
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                try:
                    # Only checks that the Python 3 target resolves; the
                    # resolved object itself is not needed.
                    getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        # Spot checks, then every builtin exception class.
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # Skipped: these exceptions get no mapping checks here.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError):
                    continue
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        # Exceptions living in multiprocessing.context round-trip through
        # the top-level multiprocessing package name.
        module = support.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))
509
510
def test_main():
    """Run every test class in this module, then pickle's doctests.

    Pure-Python test classes always run.  Classes that need the C
    accelerator (and InMemoryPickleTests/SizeofTests) are only added when
    ``_pickle`` imported successfully.
    """
    tests = [PyPickleTests, PyUnpicklerTests, PyPicklerTests,
             PyPersPicklerTests, PyIdPersPicklerTests,
             # Pure-Python and defined unconditionally, so it must not
             # depend on the C accelerator being present.
             PyPicklerUnpicklerObjectTests,
             PyDispatchTableTests, PyChainDispatchTableTests,
             CompatPickleTests, PyPicklerHookTests]
    if has_c_implementation:
        tests.extend([CPickleTests, CUnpicklerTests, CPicklerTests,
                      CPersPicklerTests, CIdPersPicklerTests,
                      CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
                      CPicklerUnpicklerObjectTests,
                      CDispatchTableTests, CChainDispatchTableTests,
                      CPicklerHookTests,
                      InMemoryPickleTests, SizeofTests])
    support.run_unittest(*tests)
    support.run_doctest(pickle)


if __name__ == "__main__":
    test_main()
530