1import copyreg
2import unittest
3
4from test.pickletester import ExtensionSaver
5
# Plain class passed as the ob_type argument to copyreg.pickle() in
# CopyRegTestCase.test_class.
class C:
    pass
8
9
# Class with no __slots__ at all; copyreg._slotnames() is expected to
# return an empty list for it.
class WithoutSlots(object):
    pass
12
# Only slot is '__weakref__'; copyreg._slotnames() is expected to leave it
# out and return an empty list.
class WithWeakref(object):
    __slots__ = ('__weakref__',)
15
# Private slot name: '__spam' is name-mangled, so copyreg._slotnames() is
# expected to report '_WithPrivate__spam'.
class WithPrivate(object):
    __slots__ = ('__spam',)
18
# Private slot on a class whose name starts with '_': the expected mangled
# name is '_WithLeadingUnderscoreAndPrivate__spam' (a single leading
# underscore, not two).
class _WithLeadingUnderscoreAndPrivate(object):
    __slots__ = ('__spam',)
21
# Class name made only of underscores: no name mangling is applied, so
# copyreg._slotnames() is expected to report the slot as plain '__spam'.
class ___(object):
    __slots__ = ('__spam',)
24
# __slots__ given as a bare string rather than a sequence; it names a
# single slot, 'spam'.
class WithSingleString(object):
    __slots__ = 'spam'
27
# Adds its own slot 'eggs' on top of 'spam' inherited from
# WithSingleString; copyreg._slotnames() is expected to report both.
class WithInherited(WithSingleString):
    __slots__ = ('eggs',)
30
31
class CopyRegTestCase(unittest.TestCase):
    """Tests for copyreg.pickle(), the extension registry and
    copyreg._slotnames()."""

    def _preserve_dispatch_table(self):
        # copyreg.pickle() stores the reduce function in the global
        # copyreg.dispatch_table *before* it validates constructor_ob, so a
        # call that is expected to raise can still leave an entry behind.
        # Snapshot the table and restore it in place once the test is done;
        # rebinding is not an option because other modules (e.g. pickle,
        # copy) import the same dict object by name.
        saved = copyreg.dispatch_table.copy()

        def restore():
            copyreg.dispatch_table.clear()
            copyreg.dispatch_table.update(saved)

        self.addCleanup(restore)

    def test_class(self):
        # The TypeError is raised because the reduce function (None) is not
        # callable.
        self._preserve_dispatch_table()
        self.assertRaises(TypeError, copyreg.pickle,
                          C, None, None)

    def test_noncallable_reduce(self):
        self._preserve_dispatch_table()
        self.assertRaises(TypeError, copyreg.pickle,
                          type(1), "not a callable")

    def test_noncallable_constructor(self):
        # Here the reduce function (int) *is* callable, so copyreg registers
        # it before rejecting the constructor; the guard undoes that.
        self._preserve_dispatch_table()
        self.assertRaises(TypeError, copyreg.pickle,
                          type(1), int, "not a callable")

    def test_bool(self):
        import copy
        self.assertEqual(True, copy.copy(True))

    def test_extension_registry(self):
        mod, func, code = 'junk1 ', ' junk2', 0xabcd
        e = ExtensionSaver(code)
        try:
            # Shouldn't be in registry now.
            self.assertRaises(ValueError, copyreg.remove_extension,
                              mod, func, code)
            copyreg.add_extension(mod, func, code)
            # Should be in the registry.
            self.assertEqual(copyreg._extension_registry[mod, func], code)
            self.assertEqual(copyreg._inverted_registry[code], (mod, func))
            # Shouldn't be in the cache.
            self.assertNotIn(code, copyreg._extension_cache)
            # Redundant registration should be OK.
            copyreg.add_extension(mod, func, code)  # shouldn't blow up
            # Conflicting code.
            self.assertRaises(ValueError, copyreg.add_extension,
                              mod, func, code + 1)
            self.assertRaises(ValueError, copyreg.remove_extension,
                              mod, func, code + 1)
            # Conflicting module name.
            self.assertRaises(ValueError, copyreg.add_extension,
                              mod[1:], func, code)
            self.assertRaises(ValueError, copyreg.remove_extension,
                              mod[1:], func, code)
            # Conflicting function name.
            self.assertRaises(ValueError, copyreg.add_extension,
                              mod, func[1:], code)
            self.assertRaises(ValueError, copyreg.remove_extension,
                              mod, func[1:], code)
            # Can't remove one that isn't registered at all.
            if code + 1 not in copyreg._inverted_registry:
                self.assertRaises(ValueError, copyreg.remove_extension,
                                  mod[1:], func[1:], code + 1)

        finally:
            e.restore()

        # Shouldn't be there anymore.
        self.assertNotIn((mod, func), copyreg._extension_registry)
        # The code *may* still be in copyreg._inverted_registry, though, if
        # it was already registered to something else before this test
        # (ExtensionSaver puts that registration back).  So don't check
        # for that.

        # Check valid codes at the limits.
        for code in 1, 0x7fffffff:
            with self.subTest(code=code):
                e = ExtensionSaver(code)
                try:
                    copyreg.add_extension(mod, func, code)
                    copyreg.remove_extension(mod, func, code)
                finally:
                    e.restore()

        # Ensure invalid codes blow up.
        for code in -1, 0, 0x80000000:
            with self.subTest(code=code):
                self.assertRaises(ValueError, copyreg.add_extension,
                                  mod, func, code)

    def test_slotnames(self):
        self.assertEqual(copyreg._slotnames(WithoutSlots), [])
        self.assertEqual(copyreg._slotnames(WithWeakref), [])
        expected = ['_WithPrivate__spam']
        self.assertEqual(copyreg._slotnames(WithPrivate), expected)
        expected = ['_WithLeadingUnderscoreAndPrivate__spam']
        self.assertEqual(copyreg._slotnames(_WithLeadingUnderscoreAndPrivate),
                         expected)
        self.assertEqual(copyreg._slotnames(___), ['__spam'])
        self.assertEqual(copyreg._slotnames(WithSingleString), ['spam'])
        # _slotnames() caches its result on the class (__slotnames__) and
        # returns that very list, so compare a sorted copy instead of
        # sorting the cached list in place.
        self.assertEqual(sorted(copyreg._slotnames(WithInherited)),
                         ['eggs', 'spam'])
123
124
# Run this module's test cases when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
127