1"""Unit tests for collections.defaultdict."""
2
3import os
4import copy
5import pickle
6import tempfile
7import unittest
8
9from collections import defaultdict
10
def foobar():
    """Return the built-in ``list`` type.

    Used as a non-builtin ``default_factory`` in the copy tests. Calling it
    yields the ``list`` type itself, not an empty list.
    """
    factory = list
    return factory
13
class TestDefaultDict(unittest.TestCase):
    """Behavioral tests for collections.defaultdict.

    Covers construction, __missing__, repr, copy()/copy.copy/copy.deepcopy,
    KeyError behavior without a factory, pickling, and the | / |= operators.
    """

    def test_basic(self):
        # default_factory starts as None and can be assigned afterwards.
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        # A bare lookup of a missing key inserts default_factory().
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every inserted value must be a distinct object. Check each pair
        # explicitly: a chained ``a is not b is not c`` never compares a to c.
        self.assertIsNot(d1[12], d1[13])
        self.assertIsNot(d1[13], d1[14])
        self.assertIsNot(d1[12], d1[14])

        # Arguments after the factory are forwarded to the dict constructor.
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
        self.assertIn("foo", d2)
        self.assertIn("foo", d2.keys())
        self.assertIn("bar", d2)
        self.assertIn("bar", d2.keys())
        # The d2[42] lookup above inserted key 42.
        self.assertIn(42, d2)
        self.assertIn(42, d2.keys())
        # Membership tests must not insert missing keys.
        self.assertNotIn(12, d2)
        self.assertNotIn(12, d2.keys())

        # Without a factory, a missing key raises KeyError carrying the key.
        d2.default_factory = None
        self.assertIsNone(d2.default_factory)
        with self.assertRaises(KeyError) as cm:
            d2[15]
        self.assertEqual(cm.exception.args, (15,))

        # A non-callable, non-None factory is rejected.
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        # __missing__ raises KeyError when there is no factory and returns
        # default_factory() once one is set.
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])

    def test_repr(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        # The repr of a factory-less, empty defaultdict round-trips via eval().
        self.assertEqual(eval(repr(d1)), d1)
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")

        d2 = defaultdict(int)
        self.assertIs(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")

        # A plain function factory is shown using its own repr.
        def foo():
            return 43

        d3 = defaultdict(foo)
        self.assertIs(d3.default_factory, foo)
        d3[13]
        self.assertEqual(repr(d3), f"defaultdict({foo!r}, {{13: 43}})")

    def test_copy(self):
        # defaultdict.copy() returns a defaultdict with the same factory
        # and the same contents.
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertIs(type(d2), defaultdict)
        self.assertIsNone(d2.default_factory)
        self.assertEqual(d2, {})

        d1.default_factory = list
        d3 = d1.copy()
        self.assertIs(type(d3), defaultdict)
        self.assertIs(d3.default_factory, list)
        self.assertEqual(d3, {})

        d1[42]
        d4 = d1.copy()
        self.assertIs(type(d4), defaultdict)
        self.assertIs(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        # The copy is independent: inserting into it leaves d1 unchanged.
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})
        self.assertEqual(d1, {42: []})

        # Issue 6637: copy() of a defaultdict created without a
        # default_factory (factory is None; the dict itself is not empty).
        d = defaultdict()
        d['a'] = 42
        e = d.copy()
        self.assertEqual(e['a'], 42)

    def test_shallow_copy(self):
        # copy.copy keeps the same factory object and equal contents.
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertIs(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        # copy.deepcopy keeps the factory but copies mutable values.
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertIs(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        self.assertIsNot(d1[1], d2[1])
        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_keyerror_without_factory(self):
        # A tuple key must arrive intact as the single KeyError argument,
        # not be unpacked into separate arguments.
        d1 = defaultdict()
        with self.assertRaises(KeyError) as cm:
            d1[(1,)]
        self.assertEqual(cm.exception.args, ((1,),))

    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        # of the defaultdict itself (repr must detect the recursion).
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory

            def _factory(self):
                return []

        d = sub()
        self.assertRegex(repr(d),
            r"sub\(<bound method .*sub\._factory "
            r"of sub\(\.\.\., \{\}\)>, \{\}\)")

    def test_callable_arg(self):
        # A dict is not callable, so it is rejected as a factory.
        self.assertRaises(TypeError, defaultdict, {})

    def test_pickling(self):
        # Every pickle protocol must preserve the type, the factory and
        # the contents.
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                s = pickle.dumps(d, proto)
                o = pickle.loads(s)
                self.assertEqual(d, o)
                self.assertIs(type(o), defaultdict)
                self.assertIs(o.default_factory, int)

    def test_union(self):
        i = defaultdict(int, {1: 1, 2: 2})
        s = defaultdict(str, {0: "zero", 1: "one"})

        # For a | b, the result keeps the factory of whichever operand is a
        # defaultdict (the left one when both are) and dict insertion order.
        i_s = i | s
        self.assertIs(i_s.default_factory, int)
        self.assertDictEqual(i_s, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i_s), [1, 2, 0])

        s_i = s | i
        self.assertIs(s_i.default_factory, str)
        self.assertDictEqual(s_i, {0: "zero", 1: 1, 2: 2})
        self.assertEqual(list(s_i), [0, 1, 2])

        i_ds = i | dict(s)
        self.assertIs(i_ds.default_factory, int)
        self.assertDictEqual(i_ds, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i_ds), [1, 2, 0])

        ds_i = dict(s) | i
        self.assertIs(ds_i.default_factory, int)
        self.assertDictEqual(ds_i, {0: "zero", 1: 1, 2: 2})
        self.assertEqual(list(ds_i), [0, 1, 2])

        # The binary operator only accepts mappings.
        with self.assertRaises(TypeError):
            i | list(s.items())
        with self.assertRaises(TypeError):
            list(s.items()) | i

        # We inherit a fine |= from dict, so just a few sanity checks here:
        i |= list(s.items())
        self.assertIs(i.default_factory, int)
        self.assertDictEqual(i, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i), [1, 2, 0])

        with self.assertRaises(TypeError):
            i |= None
190
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()
193