import copy
import copy_reg
import unittest

from test import test_support
from test.pickletester import ExtensionSaver
6
class C:
    # Declared without a base class on purpose: test_class expects
    # copy_reg.pickle to refuse to register this class with a TypeError.
    pass
9
10
class WithoutSlots(object):
    # Defines no __slots__ at all; test_slotnames expects
    # copy_reg._slotnames() to return an empty list for it.
    pass
13
class WithWeakref(object):
    # The only slot is '__weakref__'; test_slotnames expects
    # copy_reg._slotnames() to report no slot names for this class.
    __slots__ = ('__weakref__',)
16
class WithPrivate(object):
    # Double-underscore slot; test_slotnames expects it to be reported
    # under the mangled name '_WithPrivate__spam'.
    __slots__ = ('__spam',)
19
class _WithLeadingUnderscoreAndPrivate(object):
    # Class name itself starts with an underscore; test_slotnames expects
    # the slot as '_WithLeadingUnderscoreAndPrivate__spam', i.e. with a
    # single underscore in front of the class-name part.
    __slots__ = ('__spam',)
22
class ___(object):
    # Class name consists solely of underscores; test_slotnames expects
    # the slot to be reported unchanged as '__spam'.
    __slots__ = ('__spam',)
25
class WithSingleString(object):
    # __slots__ is deliberately a bare string rather than a tuple;
    # test_slotnames expects it to be treated as the single slot 'spam'.
    __slots__ = 'spam'
28
class WithInherited(WithSingleString):
    # Adds 'eggs' on top of the 'spam' slot of its base class;
    # test_slotnames expects both names to be reported (in any order).
    __slots__ = ('eggs',)
31
32
class CopyRegTestCase(unittest.TestCase):
    """Checks copy_reg.pickle() argument validation, the extension
    registry functions and the _slotnames() helper."""

    def test_class(self):
        # C is declared without any base class; copy_reg.pickle must
        # refuse to register it.
        self.assertRaises(TypeError, copy_reg.pickle,
                          C, None, None)

    def test_noncallable_reduce(self):
        # The reduction function (second argument) must be callable.
        self.assertRaises(TypeError, copy_reg.pickle,
                          type(1), "not a callable")

    def test_noncallable_constructor(self):
        # The optional constructor (third argument) must be callable.
        #
        # The reducer passed here is valid, so copy_reg.pickle may already
        # have stored it in copy_reg.dispatch_table by the time the bad
        # constructor is rejected.  Put the table back the way it was so
        # the failed registration cannot leak into later tests.
        int_type = type(1)
        table = copy_reg.dispatch_table
        had_entry = int_type in table
        previous = table.get(int_type)
        try:
            self.assertRaises(TypeError, copy_reg.pickle,
                              int_type, int, "not a callable")
        finally:
            if had_entry:
                table[int_type] = previous
            else:
                table.pop(int_type, None)

    def test_bool(self):
        # Copying a bool must give back an equal value.
        self.assertEqual(True, copy.copy(True))

    def test_extension_registry(self):
        mod, func, code = 'junk1 ', ' junk2', 0xabcd
        e = ExtensionSaver(code)
        try:
            # Shouldn't be in registry now.
            self.assertRaises(ValueError, copy_reg.remove_extension,
                              mod, func, code)
            copy_reg.add_extension(mod, func, code)
            # Should be in the registry.  assertEqual (rather than
            # assertTrue(a == b)) so a failure shows both values.
            self.assertEqual(copy_reg._extension_registry[mod, func], code)
            self.assertEqual(copy_reg._inverted_registry[code], (mod, func))
            # Shouldn't be in the cache.
            self.assertNotIn(code, copy_reg._extension_cache)
            # Redundant registration should be OK.
            copy_reg.add_extension(mod, func, code)  # shouldn't blow up
            # Conflicting code.
            self.assertRaises(ValueError, copy_reg.add_extension,
                              mod, func, code + 1)
            self.assertRaises(ValueError, copy_reg.remove_extension,
                              mod, func, code + 1)
            # Conflicting module name.
            self.assertRaises(ValueError, copy_reg.add_extension,
                              mod[1:], func, code)
            self.assertRaises(ValueError, copy_reg.remove_extension,
                              mod[1:], func, code)
            # Conflicting function name.
            self.assertRaises(ValueError, copy_reg.add_extension,
                              mod, func[1:], code)
            self.assertRaises(ValueError, copy_reg.remove_extension,
                              mod, func[1:], code)
            # Can't remove one that isn't registered at all.
            if code + 1 not in copy_reg._inverted_registry:
                self.assertRaises(ValueError, copy_reg.remove_extension,
                                  mod[1:], func[1:], code + 1)

        finally:
            e.restore()

        # Shouldn't be there anymore.
        self.assertNotIn((mod, func), copy_reg._extension_registry)
        # The code itself *may* still be registered under some other key,
        # though, if we happened to pick on a registered code.  So don't
        # check for that.

        # Check valid codes at the limits.
        for code in 1, 0x7fffffff:
            e = ExtensionSaver(code)
            try:
                copy_reg.add_extension(mod, func, code)
                copy_reg.remove_extension(mod, func, code)
            finally:
                e.restore()

        # Ensure invalid codes blow up.  The literal needs no 'L' suffix:
        # large integer literals are promoted automatically.
        for code in -1, 0, 0x80000000:
            self.assertRaises(ValueError, copy_reg.add_extension,
                              mod, func, code)

    def test_slotnames(self):
        self.assertEqual(copy_reg._slotnames(WithoutSlots), [])
        self.assertEqual(copy_reg._slotnames(WithWeakref), [])
        expected = ['_WithPrivate__spam']
        self.assertEqual(copy_reg._slotnames(WithPrivate), expected)
        expected = ['_WithLeadingUnderscoreAndPrivate__spam']
        self.assertEqual(copy_reg._slotnames(_WithLeadingUnderscoreAndPrivate),
                         expected)
        self.assertEqual(copy_reg._slotnames(___), ['__spam'])
        self.assertEqual(copy_reg._slotnames(WithSingleString), ['spam'])
        # Order of inherited slot names is not checked.  _slotnames() may
        # hand back a list it also keeps for later calls, so compare a
        # sorted copy instead of sorting the returned list in place.
        self.assertEqual(sorted(copy_reg._slotnames(WithInherited)),
                         ['eggs', 'spam'])
124
125
def test_main():
    """Run CopyRegTestCase through test_support.run_unittest()."""
    test_support.run_unittest(CopyRegTestCase)
128
129
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    test_main()
132