# Tests common to dict, UserDict and other mapping implementations
2import unittest
3import collections
4import sys
5
6
7class BasicTestMappingProtocol(unittest.TestCase):
8    # This base class can be used to check that an object conforms to the
9    # mapping protocol
10
11    # Functions that can be useful to override to adapt to dictionary
12    # semantics
13    type2test = None # which class is being tested (overwrite in subclasses)
14
15    def _reference(self):
16        """Return a dictionary of values which are invariant by storage
17        in the object under test."""
18        return {"1": "2", "key1":"value1", "key2":(1,2,3)}
19    def _empty_mapping(self):
20        """Return an empty mapping object"""
21        return self.type2test()
22    def _full_mapping(self, data):
23        """Return a mapping object with the value contained in data
24        dictionary"""
25        x = self._empty_mapping()
26        for key, value in data.items():
27            x[key] = value
28        return x
29
30    def __init__(self, *args, **kw):
31        unittest.TestCase.__init__(self, *args, **kw)
32        self.reference = self._reference().copy()
33
34        # A (key, value) pair not in the mapping
35        key, value = self.reference.popitem()
36        self.other = {key:value}
37
38        # A (key, value) pair in the mapping
39        key, value = self.reference.popitem()
40        self.inmapping = {key:value}
41        self.reference[key] = value
42
43    def test_read(self):
44        # Test for read only operations on mapping
45        p = self._empty_mapping()
46        p1 = dict(p) #workaround for singleton objects
47        d = self._full_mapping(self.reference)
48        if d is p:
49            p = p1
50        #Indexing
51        for key, value in self.reference.items():
52            self.assertEqual(d[key], value)
53        knownkey = list(self.other.keys())[0]
54        self.assertRaises(KeyError, lambda:d[knownkey])
55        #len
56        self.assertEqual(len(p), 0)
57        self.assertEqual(len(d), len(self.reference))
58        #__contains__
59        for k in self.reference:
60            self.assertIn(k, d)
61        for k in self.other:
62            self.assertNotIn(k, d)
63        #cmp
64        self.assertEqual(p, p)
65        self.assertEqual(d, d)
66        self.assertNotEqual(p, d)
67        self.assertNotEqual(d, p)
68        #bool
69        if p: self.fail("Empty mapping must compare to False")
70        if not d: self.fail("Full mapping must compare to True")
71        # keys(), items(), iterkeys() ...
72        def check_iterandlist(iter, lst, ref):
73            self.assertTrue(hasattr(iter, '__next__'))
74            self.assertTrue(hasattr(iter, '__iter__'))
75            x = list(iter)
76            self.assertTrue(set(x)==set(lst)==set(ref))
77        check_iterandlist(iter(d.keys()), list(d.keys()),
78                          self.reference.keys())
79        check_iterandlist(iter(d), list(d.keys()), self.reference.keys())
80        check_iterandlist(iter(d.values()), list(d.values()),
81                          self.reference.values())
82        check_iterandlist(iter(d.items()), list(d.items()),
83                          self.reference.items())
84        #get
85        key, value = next(iter(d.items()))
86        knownkey, knownvalue = next(iter(self.other.items()))
87        self.assertEqual(d.get(key, knownvalue), value)
88        self.assertEqual(d.get(knownkey, knownvalue), knownvalue)
89        self.assertNotIn(knownkey, d)
90
91    def test_write(self):
92        # Test for write operations on mapping
93        p = self._empty_mapping()
94        #Indexing
95        for key, value in self.reference.items():
96            p[key] = value
97            self.assertEqual(p[key], value)
98        for key in self.reference.keys():
99            del p[key]
100            self.assertRaises(KeyError, lambda:p[key])
101        p = self._empty_mapping()
102        #update
103        p.update(self.reference)
104        self.assertEqual(dict(p), self.reference)
105        items = list(p.items())
106        p = self._empty_mapping()
107        p.update(items)
108        self.assertEqual(dict(p), self.reference)
109        d = self._full_mapping(self.reference)
110        #setdefault
111        key, value = next(iter(d.items()))
112        knownkey, knownvalue = next(iter(self.other.items()))
113        self.assertEqual(d.setdefault(key, knownvalue), value)
114        self.assertEqual(d[key], value)
115        self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue)
116        self.assertEqual(d[knownkey], knownvalue)
117        #pop
118        self.assertEqual(d.pop(knownkey), knownvalue)
119        self.assertNotIn(knownkey, d)
120        self.assertRaises(KeyError, d.pop, knownkey)
121        default = 909
122        d[knownkey] = knownvalue
123        self.assertEqual(d.pop(knownkey, default), knownvalue)
124        self.assertNotIn(knownkey, d)
125        self.assertEqual(d.pop(knownkey, default), default)
126        #popitem
127        key, value = d.popitem()
128        self.assertNotIn(key, d)
129        self.assertEqual(value, self.reference[key])
130        p=self._empty_mapping()
131        self.assertRaises(KeyError, p.popitem)
132
133    def test_constructor(self):
134        self.assertEqual(self._empty_mapping(), self._empty_mapping())
135
136    def test_bool(self):
137        self.assertTrue(not self._empty_mapping())
138        self.assertTrue(self.reference)
139        self.assertTrue(bool(self._empty_mapping()) is False)
140        self.assertTrue(bool(self.reference) is True)
141
142    def test_keys(self):
143        d = self._empty_mapping()
144        self.assertEqual(list(d.keys()), [])
145        d = self.reference
146        self.assertIn(list(self.inmapping.keys())[0], d.keys())
147        self.assertNotIn(list(self.other.keys())[0], d.keys())
148        self.assertRaises(TypeError, d.keys, None)
149
150    def test_values(self):
151        d = self._empty_mapping()
152        self.assertEqual(list(d.values()), [])
153
154        self.assertRaises(TypeError, d.values, None)
155
156    def test_items(self):
157        d = self._empty_mapping()
158        self.assertEqual(list(d.items()), [])
159
160        self.assertRaises(TypeError, d.items, None)
161
162    def test_len(self):
163        d = self._empty_mapping()
164        self.assertEqual(len(d), 0)
165
166    def test_getitem(self):
167        d = self.reference
168        self.assertEqual(d[list(self.inmapping.keys())[0]],
169                         list(self.inmapping.values())[0])
170
171        self.assertRaises(TypeError, d.__getitem__)
172
173    def test_update(self):
174        # mapping argument
175        d = self._empty_mapping()
176        d.update(self.other)
177        self.assertEqual(list(d.items()), list(self.other.items()))
178
179        # No argument
180        d = self._empty_mapping()
181        d.update()
182        self.assertEqual(d, self._empty_mapping())
183
184        # item sequence
185        d = self._empty_mapping()
186        d.update(self.other.items())
187        self.assertEqual(list(d.items()), list(self.other.items()))
188
189        # Iterator
190        d = self._empty_mapping()
191        d.update(self.other.items())
192        self.assertEqual(list(d.items()), list(self.other.items()))
193
194        # FIXME: Doesn't work with UserDict
195        # self.assertRaises((TypeError, AttributeError), d.update, None)
196        self.assertRaises((TypeError, AttributeError), d.update, 42)
197
198        outerself = self
199        class SimpleUserDict:
200            def __init__(self):
201                self.d = outerself.reference
202            def keys(self):
203                return self.d.keys()
204            def __getitem__(self, i):
205                return self.d[i]
206        d.clear()
207        d.update(SimpleUserDict())
208        i1 = sorted(d.items())
209        i2 = sorted(self.reference.items())
210        self.assertEqual(i1, i2)
211
212        class Exc(Exception): pass
213
214        d = self._empty_mapping()
215        class FailingUserDict:
216            def keys(self):
217                raise Exc
218        self.assertRaises(Exc, d.update, FailingUserDict())
219
220        d.clear()
221
222        class FailingUserDict:
223            def keys(self):
224                class BogonIter:
225                    def __init__(self):
226                        self.i = 1
227                    def __iter__(self):
228                        return self
229                    def __next__(self):
230                        if self.i:
231                            self.i = 0
232                            return 'a'
233                        raise Exc
234                return BogonIter()
235            def __getitem__(self, key):
236                return key
237        self.assertRaises(Exc, d.update, FailingUserDict())
238
239        class FailingUserDict:
240            def keys(self):
241                class BogonIter:
242                    def __init__(self):
243                        self.i = ord('a')
244                    def __iter__(self):
245                        return self
246                    def __next__(self):
247                        if self.i <= ord('z'):
248                            rtn = chr(self.i)
249                            self.i += 1
250                            return rtn
251                        raise StopIteration
252                return BogonIter()
253            def __getitem__(self, key):
254                raise Exc
255        self.assertRaises(Exc, d.update, FailingUserDict())
256
257        d = self._empty_mapping()
258        class badseq(object):
259            def __iter__(self):
260                return self
261            def __next__(self):
262                raise Exc()
263
264        self.assertRaises(Exc, d.update, badseq())
265
266        self.assertRaises(ValueError, d.update, [(1, 2, 3)])
267
268    # no test_fromkeys or test_copy as both os.environ and selves don't support it
269
270    def test_get(self):
271        d = self._empty_mapping()
272        self.assertTrue(d.get(list(self.other.keys())[0]) is None)
273        self.assertEqual(d.get(list(self.other.keys())[0], 3), 3)
274        d = self.reference
275        self.assertTrue(d.get(list(self.other.keys())[0]) is None)
276        self.assertEqual(d.get(list(self.other.keys())[0], 3), 3)
277        self.assertEqual(d.get(list(self.inmapping.keys())[0]),
278                         list(self.inmapping.values())[0])
279        self.assertEqual(d.get(list(self.inmapping.keys())[0], 3),
280                         list(self.inmapping.values())[0])
281        self.assertRaises(TypeError, d.get)
282        self.assertRaises(TypeError, d.get, None, None, None)
283
284    def test_setdefault(self):
285        d = self._empty_mapping()
286        self.assertRaises(TypeError, d.setdefault)
287
288    def test_popitem(self):
289        d = self._empty_mapping()
290        self.assertRaises(KeyError, d.popitem)
291        self.assertRaises(TypeError, d.popitem, 42)
292
293    def test_pop(self):
294        d = self._empty_mapping()
295        k, v = list(self.inmapping.items())[0]
296        d[k] = v
297        self.assertRaises(KeyError, d.pop, list(self.other.keys())[0])
298
299        self.assertEqual(d.pop(k), v)
300        self.assertEqual(len(d), 0)
301
302        self.assertRaises(KeyError, d.pop, k)
303
304
class TestMappingProtocol(BasicTestMappingProtocol):
    """Extend the basic mapping checks with the rest of the dict-like API:
    keyword construction, fromkeys(), copy(), clear(), richer update(),
    and mutation through indexing."""

    def test_constructor(self):
        BasicTestMappingProtocol.test_constructor(self)
        self.assertIsNot(self._empty_mapping(), self._empty_mapping())
        self.assertEqual(self.type2test(x=1, y=2), {"x": 1, "y": 2})

    def test_bool(self):
        BasicTestMappingProtocol.test_bool(self)
        self.assertFalse(self._empty_mapping())
        self.assertTrue(self._full_mapping({"x": "y"}))
        self.assertIs(bool(self._empty_mapping()), False)
        self.assertIs(bool(self._full_mapping({"x": "y"})), True)

    def test_keys(self):
        BasicTestMappingProtocol.test_keys(self)
        d = self._empty_mapping()
        self.assertEqual(list(d.keys()), [])
        d = self._full_mapping({'a': 1, 'b': 2})
        k = d.keys()
        self.assertIn('a', k)
        self.assertIn('b', k)
        self.assertNotIn('c', k)

    def test_values(self):
        BasicTestMappingProtocol.test_values(self)
        d = self._full_mapping({1:2})
        self.assertEqual(list(d.values()), [2])

    def test_items(self):
        BasicTestMappingProtocol.test_items(self)

        d = self._full_mapping({1:2})
        self.assertEqual(list(d.items()), [(1, 2)])

    def test_contains(self):
        d = self._empty_mapping()
        # Exercise both spellings of the negative membership test.
        self.assertNotIn('a', d)
        self.assertFalse('a' in d)
        self.assertTrue('a' not in d)
        d = self._full_mapping({'a': 1, 'b': 2})
        self.assertIn('a', d)
        self.assertIn('b', d)
        self.assertNotIn('c', d)

        self.assertRaises(TypeError, d.__contains__)

    def test_len(self):
        BasicTestMappingProtocol.test_len(self)
        d = self._full_mapping({'a': 1, 'b': 2})
        self.assertEqual(len(d), 2)

    def test_getitem(self):
        BasicTestMappingProtocol.test_getitem(self)
        d = self._full_mapping({'a': 1, 'b': 2})
        self.assertEqual(d['a'], 1)
        self.assertEqual(d['b'], 2)
        d['c'] = 3
        d['a'] = 4
        self.assertEqual(d['c'], 3)
        self.assertEqual(d['a'], 4)
        del d['b']
        self.assertEqual(d, {'a': 4, 'c': 3})

        self.assertRaises(TypeError, d.__getitem__)

    def test_clear(self):
        d = self._full_mapping({1:1, 2:2, 3:3})
        d.clear()
        self.assertEqual(d, {})

        self.assertRaises(TypeError, d.clear, None)

    def test_update(self):
        BasicTestMappingProtocol.test_update(self)
        # mapping argument
        d = self._empty_mapping()
        d.update({1:100})
        d.update({2:20})
        d.update({1:1, 2:2, 3:3})
        self.assertEqual(d, {1:1, 2:2, 3:3})

        # no argument
        d.update()
        self.assertEqual(d, {1:1, 2:2, 3:3})

        # keyword arguments
        d = self._empty_mapping()
        d.update(x=100)
        d.update(y=20)
        d.update(x=1, y=2, z=3)
        self.assertEqual(d, {"x":1, "y":2, "z":3})

        # item sequence
        d = self._empty_mapping()
        d.update([("x", 100), ("y", 20)])
        self.assertEqual(d, {"x":100, "y":20})

        # Both item sequence and keyword arguments; keywords win
        d = self._empty_mapping()
        d.update([("x", 100), ("y", 20)], x=1, y=2)
        self.assertEqual(d, {"x":1, "y":2})

        # iterator: a one-shot iterator over another mapping's items
        d = self._full_mapping({1:3, 2:4})
        d.update(iter(self._full_mapping({1:2, 3:4, 5:6}).items()))
        self.assertEqual(d, {1:2, 2:4, 3:4, 5:6})

        class SimpleUserDict:
            # Minimal mapping-like object: only keys() and __getitem__.
            def __init__(self):
                self.d = {1:1, 2:2, 3:3}
            def keys(self):
                return self.d.keys()
            def __getitem__(self, i):
                return self.d[i]
        d.clear()
        d.update(SimpleUserDict())
        self.assertEqual(d, {1:1, 2:2, 3:3})

    def test_fromkeys(self):
        self.assertEqual(self.type2test.fromkeys('abc'), {'a':None, 'b':None, 'c':None})
        d = self._empty_mapping()
        # fromkeys() builds a new mapping rather than filling d.
        self.assertIsNot(d.fromkeys('abc'), d)
        self.assertEqual(d.fromkeys('abc'), {'a':None, 'b':None, 'c':None})
        self.assertEqual(d.fromkeys((4,5),0), {4:0, 5:0})
        self.assertEqual(d.fromkeys([]), {})
        def g():
            yield 1
        self.assertEqual(d.fromkeys(g()), {1:None})
        # Non-iterable keys are rejected by the class under test.
        self.assertRaises(TypeError, self.type2test.fromkeys, 3)
        # Subclasses get instances of the subclass back.
        class dictlike(self.type2test): pass
        self.assertEqual(dictlike.fromkeys('a'), {'a':None})
        self.assertEqual(dictlike().fromkeys('a'), {'a':None})
        self.assertIs(dictlike.fromkeys('a').__class__, dictlike)
        self.assertIs(dictlike().fromkeys('a').__class__, dictlike)
        self.assertIs(type(dictlike.fromkeys('a')), dictlike)
        # A subclass whose __new__ returns a foreign mapping.
        class mydict(self.type2test):
            def __new__(cls):
                return collections.UserDict()
        ud = mydict.fromkeys('ab')
        self.assertEqual(ud, {'a':None, 'b':None})
        self.assertIsInstance(ud, collections.UserDict)
        # The iterable argument is mandatory.
        self.assertRaises(TypeError, self.type2test.fromkeys)

        class Exc(Exception): pass

        # Exceptions from the constructor must propagate.
        class baddict1(self.type2test):
            def __init__(self, *args, **kwargs):
                raise Exc()

        self.assertRaises(Exc, baddict1.fromkeys, [1])

        # Exceptions from iterating the keys must propagate.
        class BadSeq(object):
            def __iter__(self):
                return self
            def __next__(self):
                raise Exc()

        self.assertRaises(Exc, self.type2test.fromkeys, BadSeq())

        # Exceptions from __setitem__ must propagate.
        class baddict2(self.type2test):
            def __setitem__(self, key, value):
                raise Exc()

        self.assertRaises(Exc, baddict2.fromkeys, [1])

    def test_copy(self):
        d = self._full_mapping({1:1, 2:2, 3:3})
        self.assertEqual(d.copy(), {1:1, 2:2, 3:3})
        d = self._empty_mapping()
        self.assertEqual(d.copy(), d)
        self.assertIsInstance(d.copy(), d.__class__)
        self.assertRaises(TypeError, d.copy, None)

    def test_get(self):
        BasicTestMappingProtocol.test_get(self)
        d = self._empty_mapping()
        self.assertIsNone(d.get('c'))
        self.assertEqual(d.get('c', 3), 3)
        d = self._full_mapping({'a' : 1, 'b' : 2})
        self.assertIsNone(d.get('c'))
        self.assertEqual(d.get('c', 3), 3)
        self.assertEqual(d.get('a'), 1)
        self.assertEqual(d.get('a', 3), 1)

    def test_setdefault(self):
        BasicTestMappingProtocol.test_setdefault(self)
        d = self._empty_mapping()
        self.assertIsNone(d.setdefault('key0'))
        # An existing key keeps its value (None) even if a default is given.
        d.setdefault('key0', [])
        self.assertIsNone(d.setdefault('key0'))
        # setdefault returns the stored object, so it can be mutated in place.
        d.setdefault('key', []).append(3)
        self.assertEqual(d['key'][0], 3)
        d.setdefault('key', []).append(4)
        self.assertEqual(len(d['key']), 2)

    def test_popitem(self):
        BasicTestMappingProtocol.test_popitem(self)
        for copymode in -1, +1:
            # -1: b has same structure as a
            # +1: b is a.copy()
            for log2size in range(12):
                size = 2**log2size
                a = self._empty_mapping()
                b = self._empty_mapping()
                for i in range(size):
                    a[repr(i)] = i
                    if copymode < 0:
                        b[repr(i)] = i
                if copymode > 0:
                    b = a.copy()
                for i in range(size):
                    ka, va = ta = a.popitem()
                    self.assertEqual(va, int(ka))
                    kb, vb = tb = b.popitem()
                    self.assertEqual(vb, int(kb))
                    if copymode < 0:
                        # Identically built mappings pop in the same order.
                        self.assertEqual(ta, tb)
                self.assertFalse(a)
                self.assertFalse(b)

    def test_pop(self):
        BasicTestMappingProtocol.test_pop(self)

        # Tests for pop with specified key
        d = self._empty_mapping()
        k, v = 'abc', 'def'

        self.assertEqual(d.pop(k, v), v)
        d[k] = v
        self.assertEqual(d.pop(k, 1), v)
534
535
class TestHashMappingProtocol(TestMappingProtocol):
    """Mapping checks that rely on hashing semantics: exceptions from
    __hash__/__eq__, mutation during iteration, repr() and equality."""

    @staticmethod
    def _failing_hash_key():
        """Return ``(exc_type, key)``.

        ``key`` hashes to 42 until its ``fail`` attribute is set to True;
        after that, hashing it raises ``exc_type``.
        """
        class Exc(Exception): pass

        class BadHash(object):
            fail = False
            def __hash__(self):
                if self.fail:
                    raise Exc()
                else:
                    return 42

        return Exc, BadHash()

    def test_getitem(self):
        TestMappingProtocol.test_getitem(self)
        Exc, x = self._failing_hash_key()

        class BadEq(object):
            def __eq__(self, other):
                raise Exc()
            def __hash__(self):
                return 24

        # Looking up a key with a different hash must raise KeyError,
        # without invoking the stored key's failing __eq__.
        d = self._empty_mapping()
        d[BadEq()] = 42
        self.assertRaises(KeyError, d.__getitem__, 23)

        # An exception raised by __hash__ during lookup must propagate.
        d = self._empty_mapping()
        d[x] = 42
        x.fail = True
        self.assertRaises(Exc, d.__getitem__, x)

    def test_pop(self):
        TestMappingProtocol.test_pop(self)
        Exc, x = self._failing_hash_key()

        # An exception raised by __hash__ during pop() must propagate.
        d = self._empty_mapping()
        d[x] = 42
        x.fail = True
        self.assertRaises(Exc, d.pop, x)

    def test_mutatingiteration(self):
        # Growing the mapping while iterating over it must raise
        # RuntimeError; silently stopping or continuing is a failure.
        d = self._empty_mapping()
        d[1] = 1
        with self.assertRaises(RuntimeError):
            count = 0
            for i in d:
                d[i+1] = 1
                if count >= 1:
                    self.fail("changing dict size during iteration doesn't raise Error")
                count += 1

    def test_repr(self):
        d = self._empty_mapping()
        self.assertEqual(repr(d), '{}')
        d[1] = 2
        self.assertEqual(repr(d), '{1: 2}')
        # A self-referencing mapping is shown with a placeholder.
        d = self._empty_mapping()
        d[1] = d
        self.assertEqual(repr(d), '{1: {...}}')

        class Exc(Exception): pass

        class BadRepr(object):
            def __repr__(self):
                raise Exc()

        # Exceptions from a value's __repr__ must propagate.
        d = self._full_mapping({1: BadRepr()})
        self.assertRaises(Exc, repr, d)

    def test_repr_deep(self):
        # Nesting deeper than the recursion limit must raise
        # RecursionError from repr().
        d = self._empty_mapping()
        for i in range(sys.getrecursionlimit() + 100):
            d0 = d
            d = self._empty_mapping()
            d[1] = d0
        self.assertRaises(RecursionError, repr, d)

    def test_eq(self):
        self.assertEqual(self._empty_mapping(), self._empty_mapping())
        self.assertEqual(self._full_mapping({1: 2}),
                         self._full_mapping({1: 2}))

        class Exc(Exception): pass

        class BadCmp(object):
            def __eq__(self, other):
                raise Exc()
            def __hash__(self):
                return 1

        # Keys 1 and BadCmp() share a hash, so comparing the mappings
        # compares the keys and must propagate the __eq__ exception.
        d1 = self._full_mapping({BadCmp(): 1})
        d2 = self._full_mapping({1: 1})
        self.assertRaises(Exc, lambda: BadCmp()==1)
        self.assertRaises(Exc, lambda: d1==d2)

    def test_setdefault(self):
        TestMappingProtocol.test_setdefault(self)
        Exc, x = self._failing_hash_key()

        # An exception raised by __hash__ during setdefault() must propagate.
        d = self._empty_mapping()
        d[x] = 42
        x.fail = True
        self.assertRaises(Exc, d.setdefault, x, [])
669