1import concurrent.futures
2import contextvars
3import functools
4import gc
5import random
6import time
7import unittest
8import weakref
9
10try:
11    from _testcapi import hamt
12except ImportError:
13    hamt = None
14
15
def isolated_context(func):
    """Decorate *func* so every call runs inside a brand-new, empty Context.

    Needed to make reftracking test mode work: any ContextVar values the
    decorated function sets are stored in the throwaway Context created
    for that call rather than in the caller's current context.
    """
    @functools.wraps(func)
    def run_in_fresh_context(*args, **kwargs):
        fresh = contextvars.Context()
        return fresh.run(func, *args, **kwargs)

    return run_in_fresh_context
23
24
class ContextTest(unittest.TestCase):
    """Tests for contextvars.ContextVar, contextvars.Context and Token."""

    def test_context_var_new_1(self):
        # Constructor argument validation, read-only ``name``, and that a
        # var's hash is not simply the hash of its name.
        with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
            contextvars.ContextVar()

        with self.assertRaisesRegex(TypeError, 'must be a str'):
            contextvars.ContextVar(1)

        c = contextvars.ContextVar('aaa')
        self.assertEqual(c.name, 'aaa')

        with self.assertRaises(AttributeError):
            c.name = 'bbb'

        self.assertNotEqual(hash(c), hash('aaa'))

    @isolated_context
    def test_context_var_repr_1(self):
        # repr() of vars (including a self-referential default) and of
        # tokens before and after being used by reset().
        c = contextvars.ContextVar('a')
        self.assertIn('a', repr(c))

        c = contextvars.ContextVar('a', default=123)
        self.assertIn('123', repr(c))

        lst = []
        c = contextvars.ContextVar('a', default=lst)
        lst.append(c)
        self.assertIn('...', repr(c))
        self.assertIn('...', repr(lst))

        t = c.set(1)
        self.assertIn(repr(c), repr(t))
        self.assertNotIn(' used ', repr(t))
        c.reset(t)
        self.assertIn(' used ', repr(t))

    def test_context_subclassing_1(self):
        # ContextVar, Context and Token are final types.
        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyContextVar(contextvars.ContextVar):
                # Potentially we might want ContextVars to be subclassable.
                pass

        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyContext(contextvars.Context):
                pass

        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyToken(contextvars.Token):
                pass

    def test_context_new_1(self):
        # Context() rejects any positional or keyword arguments.
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1, a=1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(a=1)
        contextvars.Context(**{})

    def test_context_typerrors_1(self):
        # Mapping operations on a Context require ContextVar keys.
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            ctx[1]
        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            1 in ctx
        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            ctx.get(1)

    def test_context_get_context_1(self):
        ctx = contextvars.copy_context()
        self.assertIsInstance(ctx, contextvars.Context)

    def test_context_run_1(self):
        # run() needs a callable.
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'missing 1 required'):
            ctx.run()

    def test_context_run_2(self):
        # run() forwards positional and keyword arguments unchanged and
        # does not mutate a caller-supplied **kwargs dict.
        ctx = contextvars.Context()

        def func(*args, **kwargs):
            kwargs['spam'] = 'foo'
            args += ('bar',)
            return args, kwargs

        for f in (func, functools.partial(func)):
            # partial doesn't support FASTCALL, so each case is run
            # through both a plain function and a partial.

            self.assertEqual(ctx.run(f), (('bar',), {'spam': 'foo'}))
            self.assertEqual(ctx.run(f, 1), ((1, 'bar'), {'spam': 'foo'}))

            self.assertEqual(
                ctx.run(f, a=2),
                (('bar',), {'a': 2, 'spam': 'foo'}))

            self.assertEqual(
                ctx.run(f, 11, a=2),
                ((11, 'bar'), {'a': 2, 'spam': 'foo'}))

            a = {}
            self.assertEqual(
                ctx.run(f, 11, **a),
                ((11, 'bar'), {'spam': 'foo'}))
            self.assertEqual(a, {})

    def test_context_run_3(self):
        # Exceptions raised by the callable propagate out of run().
        ctx = contextvars.Context()

        def func(*args, **kwargs):
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(func)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(func, 1, 2)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(func, 1, 2, a=123)

    @isolated_context
    def test_context_run_4(self):
        # Values set while running in ctx1 are invisible from ctx2 and are
        # kept in ctx1 after run() returns.
        ctx1 = contextvars.Context()
        ctx2 = contextvars.Context()
        var = contextvars.ContextVar('var')

        def func2():
            self.assertIsNone(var.get(None))

        def func1():
            self.assertIsNone(var.get(None))
            var.set('spam')
            ctx2.run(func2)
            self.assertEqual(var.get(None), 'spam')

            cur = contextvars.copy_context()
            self.assertEqual(len(cur), 1)
            self.assertEqual(cur[var], 'spam')
            return cur

        returned_ctx = ctx1.run(func1)
        self.assertEqual(ctx1, returned_ctx)
        self.assertEqual(returned_ctx[var], 'spam')
        self.assertIn(var, returned_ctx)

    def test_context_run_5(self):
        # A value set inside a run() that then raises does not leak into
        # the caller's context.
        ctx = contextvars.Context()
        var = contextvars.ContextVar('var')

        def func():
            self.assertIsNone(var.get(None))
            var.set('spam')
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(func)

        self.assertIsNone(var.get(None))

    def test_context_run_6(self):
        # Inside run(), var.get()/var.set() and ctx.get() agree.
        ctx = contextvars.Context()
        c = contextvars.ContextVar('a', default=0)

        def fun():
            self.assertEqual(c.get(), 0)
            self.assertIsNone(ctx.get(c))

            c.set(42)
            self.assertEqual(c.get(), 42)
            self.assertEqual(ctx.get(c), 42)

        ctx.run(fun)

    def test_context_run_7(self):
        # A context cannot be entered recursively.
        ctx = contextvars.Context()

        def fun():
            with self.assertRaisesRegex(RuntimeError, 'is already entered'):
                ctx.run(fun)

        ctx.run(fun)

    @isolated_context
    def test_context_getset_1(self):
        # set()/get()/reset() round trips, token attributes, single-use
        # tokens, and that copy_context() snapshots are unaffected by
        # later changes.
        c = contextvars.ContextVar('c')
        with self.assertRaises(LookupError):
            c.get()

        self.assertIsNone(c.get(None))

        t0 = c.set(42)
        self.assertEqual(c.get(), 42)
        self.assertEqual(c.get(None), 42)
        self.assertIs(t0.old_value, t0.MISSING)
        self.assertIs(t0.old_value, contextvars.Token.MISSING)
        self.assertIs(t0.var, c)

        t = c.set('spam')
        self.assertEqual(c.get(), 'spam')
        self.assertEqual(c.get(None), 'spam')
        self.assertEqual(t.old_value, 42)
        c.reset(t)

        self.assertEqual(c.get(), 42)
        self.assertEqual(c.get(None), 42)

        c.set('spam2')
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            c.reset(t)
        self.assertEqual(c.get(), 'spam2')

        ctx1 = contextvars.copy_context()
        self.assertIn(c, ctx1)

        c.reset(t0)
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            c.reset(t0)
        self.assertIsNone(c.get(None))

        self.assertIn(c, ctx1)
        self.assertEqual(ctx1[c], 'spam2')
        self.assertEqual(ctx1.get(c, 'aa'), 'spam2')
        self.assertEqual(len(ctx1), 1)
        self.assertEqual(list(ctx1.items()), [(c, 'spam2')])
        self.assertEqual(list(ctx1.values()), ['spam2'])
        self.assertEqual(list(ctx1.keys()), [c])
        self.assertEqual(list(ctx1), [c])

        ctx2 = contextvars.copy_context()
        self.assertNotIn(c, ctx2)
        with self.assertRaises(KeyError):
            ctx2[c]
        self.assertEqual(ctx2.get(c, 'aa'), 'aa')
        self.assertEqual(len(ctx2), 0)
        self.assertEqual(list(ctx2), [])

    @isolated_context
    def test_context_getset_2(self):
        # A token can only be used to reset the var that created it.
        v1 = contextvars.ContextVar('v1')
        v2 = contextvars.ContextVar('v2')

        t1 = v1.set(42)
        with self.assertRaisesRegex(ValueError, 'by a different'):
            v2.reset(t1)

    @isolated_context
    def test_context_getset_3(self):
        # A var's default is not stored in the Context mapping; set() adds
        # an entry and reset() to MISSING removes it again.
        c = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        def fun():
            self.assertEqual(c.get(), 42)
            with self.assertRaises(KeyError):
                ctx[c]
            self.assertIsNone(ctx.get(c))
            self.assertEqual(ctx.get(c, 'spam'), 'spam')
            self.assertNotIn(c, ctx)
            self.assertEqual(list(ctx.keys()), [])

            t = c.set(1)
            self.assertEqual(list(ctx.keys()), [c])
            self.assertEqual(ctx[c], 1)

            c.reset(t)
            self.assertEqual(list(ctx.keys()), [])
            with self.assertRaises(KeyError):
                ctx[c]

        ctx.run(fun)

    @isolated_context
    def test_context_getset_4(self):
        # A token can only be used in the Context that created it.
        c = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        tok = ctx.run(c.set, 1)

        with self.assertRaisesRegex(ValueError, 'different Context'):
            c.reset(tok)

    @isolated_context
    def test_context_getset_5(self):
        # Setting a var inside a copied context does not affect the
        # original context's value.
        c = contextvars.ContextVar('c', default=42)
        c.set([])

        def fun():
            c.set([])
            c.get().append(42)
            self.assertEqual(c.get(), [42])

        contextvars.copy_context().run(fun)
        self.assertEqual(c.get(), [])

    def test_context_copy_1(self):
        # Context.copy() produces an independent snapshot: later sets in
        # either context are not visible in the other.
        ctx1 = contextvars.Context()
        c = contextvars.ContextVar('c', default=42)

        def ctx1_fun():
            c.set(10)

            ctx2 = ctx1.copy()
            self.assertEqual(ctx2[c], 10)

            c.set(20)
            self.assertEqual(ctx1[c], 20)
            self.assertEqual(ctx2[c], 10)

            ctx2.run(ctx2_fun)
            self.assertEqual(ctx1[c], 20)
            self.assertEqual(ctx2[c], 30)

        def ctx2_fun():
            self.assertEqual(c.get(), 10)
            c.set(30)
            self.assertEqual(c.get(), 30)

        ctx1.run(ctx1_fun)

    @isolated_context
    def test_context_threads_1(self):
        # Each worker repeatedly sets the same var with random sleeps in
        # between; a value set by another thread must never be observed.
        cvar = contextvars.ContextVar('cvar')

        def sub(num):
            for i in range(10):
                cvar.set(num + i)
                time.sleep(random.uniform(0.001, 0.05))
                self.assertEqual(cvar.get(), num + i)
            return num

        # Leaving the with-block calls shutdown(wait=True), including when
        # an assertion error from a worker propagates out of map().
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as tp:
            results = list(tp.map(sub, range(10)))
        self.assertEqual(results, list(range(10)))
360
361
362# HAMT Tests
363
364
class HashKey:
    """Test key whose hash value is chosen explicitly by the test.

    Lets HAMT tests force hash collisions and specific tree shapes.  Two
    keys are equal when both their ``name`` and ``hash`` match.

    *error_on_eq_to*, if given, is another key; comparing this key with
    that exact object (in either direction) raises ``ValueError``.

    While a ``HaskKeyCrasher`` is active (``HashKey._crasher``),
    ``__hash__`` raises ``HashingError`` and ``__eq__`` raises ``EqError``
    according to the crasher's flags.
    """

    # The currently active crasher, or None.
    _crasher = None

    def __init__(self, hash, name, *, error_on_eq_to=None):
        # A __hash__ result of -1 is turned into -2 by CPython, so a key
        # could never actually report hash -1.
        assert hash != -1
        self.name = name
        self.hash = hash
        self.error_on_eq_to = error_on_eq_to

    def __repr__(self):
        return f'<Key name:{self.name} hash:{self.hash}>'

    def __hash__(self):
        active = self._crasher
        if active is not None and active.error_on_hash:
            raise HashingError
        return self.hash

    def __eq__(self, other):
        if not isinstance(other, HashKey):
            return NotImplemented

        active = self._crasher
        if active is not None and active.error_on_eq:
            raise EqError

        # Check both directions: a key may refuse comparison with the
        # other side, or the other side may refuse comparison with it.
        for left, right in ((self, other), (other, self)):
            forbidden = left.error_on_eq_to
            if forbidden is not None and forbidden is right:
                raise ValueError(f'cannot compare {left!r} to {right!r}')

        mine = (self.name, self.hash)
        theirs = (other.name, other.hash)
        return mine == theirs
396
397
class KeyStr(str):
    """``str`` subclass whose hashing and equality can be made to fail.

    Behaves like a plain ``str`` unless a crasher is active (stored in
    ``HashKey._crasher``): then ``__hash__`` raises ``HashingError`` when
    the crasher's ``error_on_hash`` is set, and ``__eq__`` raises
    ``EqError`` when its ``error_on_eq`` is set.
    """

    @staticmethod
    def _crasher_wants(flag):
        # Truthy when a crasher is active and its *flag* attribute is set.
        active = HashKey._crasher
        return active is not None and getattr(active, flag)

    def __hash__(self):
        if self._crasher_wants('error_on_hash'):
            raise HashingError
        return super().__hash__()

    def __eq__(self, other):
        if self._crasher_wants('error_on_eq'):
            raise EqError
        return super().__eq__(other)
408
409
class HaskKeyCrasher:
    """Context manager that makes HashKey/KeyStr hashing or equality raise.

    (The name is a misspelling of "HashKeyCrasher"; it is kept because the
    tests refer to it by this name.)

    On entry the instance is stored in ``HashKey._crasher``.  ``HashKey``
    and ``KeyStr`` consult it and raise ``HashingError`` from ``__hash__``
    when *error_on_hash* is true, and ``EqError`` from ``__eq__`` when
    *error_on_eq* is true.  Crashers cannot be nested.
    """

    def __init__(self, *, error_on_hash=False, error_on_eq=False):
        self.error_on_hash = error_on_hash
        self.error_on_eq = error_on_eq

    def __enter__(self):
        """Activate this crasher and return it.

        Raises:
            RuntimeError: if another crasher is already active.
        """
        if HashKey._crasher is not None:
            raise RuntimeError('cannot nest crashers')
        HashKey._crasher = self
        # Returning self allows ``with HaskKeyCrasher(...) as c:``.  The
        # flags are read at hash/eq time, so they can be toggled mid-block.
        return self

    def __exit__(self, *exc):
        # Always deactivate; returning None lets any exception propagate.
        HashKey._crasher = None
422
423
class HashingError(Exception):
    """Raised from ``__hash__`` of test keys while an error_on_hash crasher is active."""
426
427
class EqError(Exception):
    """Raised from ``__eq__`` of test keys while an error_on_eq crasher is active."""
430
431
@unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
class HamtTest(unittest.TestCase):
    """Tests for the immutable HAMT mapping exposed by _testcapi.hamt()."""

    def test_hashkey_helper_1(self):
        # Sanity check of the HashKey helper: same hash, distinct keys.
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')

        self.assertNotEqual(k1, k2)
        self.assertEqual(hash(k1), hash(k2))

        d = dict()
        d[k1] = 'a'
        d[k2] = 'b'

        self.assertEqual(d[k1], 'a')
        self.assertEqual(d[k2], 'b')

    def test_hamt_basics_1(self):
        # Creating and dropping an empty HAMT must not crash.
        h = hamt()
        h = None  # NoQA

    def test_hamt_basics_2(self):
        # set() returns a new mapping and never mutates the original.
        h = hamt()
        self.assertEqual(len(h), 0)

        h2 = h.set('a', 'b')
        self.assertIsNot(h, h2)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)

        self.assertIsNone(h.get('a'))
        self.assertEqual(h.get('a', 42), 42)

        self.assertEqual(h2.get('a'), 'b')

        h3 = h2.set('b', 10)
        self.assertIsNot(h2, h3)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(h3.get('a'), 'b')
        self.assertEqual(h3.get('b'), 10)

        self.assertIsNone(h.get('b'))
        self.assertIsNone(h2.get('b'))

        self.assertIsNone(h.get('a'))
        self.assertEqual(h2.get('a'), 'b')

        h = h2 = h3 = None

    def test_hamt_basics_3(self):
        # Re-setting a key to the identical value returns the same mapping.
        h = hamt()
        o = object()
        h1 = h.set('1', o)
        h2 = h1.set('1', o)
        self.assertIs(h1, h2)

    def test_hamt_basics_4(self):
        # Re-setting a key to a different (even if equal) value makes a
        # new mapping.
        h = hamt()
        h1 = h.set('key', [])
        h2 = h1.set('key', [])
        self.assertIsNot(h1, h2)
        self.assertEqual(len(h1), 1)
        self.assertEqual(len(h2), 1)
        self.assertIsNot(h1.get('key'), h2.get('key'))

    def test_hamt_collision_1(self):
        # Keys with identical hashes are stored and looked up correctly.
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')
        k3 = HashKey(10, 'ccc')

        h = hamt()
        h2 = h.set(k1, 'a')
        h3 = h2.set(k2, 'b')

        self.assertEqual(h.get(k1), None)
        self.assertEqual(h.get(k2), None)

        self.assertEqual(h2.get(k1), 'a')
        self.assertEqual(h2.get(k2), None)

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')

        h4 = h3.set(k2, 'cc')
        h5 = h4.set(k3, 'aa')

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')
        self.assertEqual(h4.get(k1), 'a')
        self.assertEqual(h4.get(k2), 'cc')
        self.assertEqual(h4.get(k3), None)
        self.assertEqual(h5.get(k1), 'a')
        self.assertEqual(h5.get(k2), 'cc')
        self.assertEqual(h5.get(k3), 'aa')

        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(len(h4), 2)
        self.assertEqual(len(h5), 3)

    def test_hamt_stress(self):
        # Randomized insert/delete workload cross-checked against a dict,
        # with periodic injected __hash__/__eq__ failures.
        COLLECTION_SIZE = 7000
        TEST_ITERS_EVERY = 647
        CRASH_HASH_EVERY = 97
        CRASH_EQ_EVERY = 11
        RUN_XTIMES = 3

        for _ in range(RUN_XTIMES):
            h = hamt()
            d = dict()

            for i in range(COLLECTION_SIZE):
                key = KeyStr(i)

                if not (i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.set(key, i)

                h = h.set(key, i)

                if not (i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.get(KeyStr(i))  # really trigger __eq__

                d[key] = i
                self.assertEqual(len(d), len(h))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.items()), set(d.items()))
                    self.assertEqual(len(h.items()), len(d.items()))

            self.assertEqual(len(h), COLLECTION_SIZE)

            for key in range(COLLECTION_SIZE):
                self.assertEqual(h.get(KeyStr(key), 'not found'), key)

            keys_to_delete = list(range(COLLECTION_SIZE))
            random.shuffle(keys_to_delete)
            for iter_i, i in enumerate(keys_to_delete):
                key = KeyStr(i)

                if not (iter_i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.delete(key)

                if not (iter_i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.delete(KeyStr(i))

                h = h.delete(key)
                self.assertEqual(h.get(key, 'not found'), 'not found')
                del d[key]
                self.assertEqual(len(d), len(h))

                # Snapshot the half-deleted state for the second phase.
                if iter_i == COLLECTION_SIZE // 2:
                    hm = h
                    dm = d.copy()

                if not (iter_i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.keys()), set(d.keys()))
                    self.assertEqual(len(h.keys()), len(d.keys()))

            self.assertEqual(len(d), 0)
            self.assertEqual(len(h), 0)

            # ============
            # Second phase: the snapshot must be unaffected by the later
            # deletions, and replaying all deletions (some keys are
            # already absent) must empty it.  Lookups use plain str keys
            # against the stored KeyStr keys.

            for key in dm:
                self.assertEqual(hm.get(str(key)), dm[key])
            self.assertEqual(len(dm), len(hm))

            for i, key in enumerate(keys_to_delete):
                hm = hm.delete(str(key))
                self.assertEqual(hm.get(str(key), 'not found'), 'not found')
                dm.pop(str(key), None)
                self.assertEqual(len(dm), len(hm))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(hm.values()), set(dm.values()))
                    self.assertEqual(len(hm.values()), len(dm.values()))

            self.assertEqual(len(dm), 0)
            self.assertEqual(len(hm), 0)
            self.assertEqual(list(hm.items()), [])

    def test_hamt_delete_1(self):
        # Deletion from a single flat BitmapNode, including a missing key
        # and a key whose comparison raises.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(102, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(103, 'Er', error_on_eq_to=D)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:C hash:102>: 'c'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 2)

        # Deleting a missing key returns the very same mapping.
        h2 = h.delete(Z)
        self.assertIs(h2, h)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(A, 42), 42)
        self.assertEqual(h.get(B), 'b')
        self.assertEqual(h.get(E), 'e')

    def test_hamt_delete_2(self):
        # Deletion when some keys live in a nested BitmapNode.
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(201001, 'Er', error_on_eq_to=B)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=8 bitmap=0b1110010000):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b100000000001000000000):
        #             <Key name:B hash:201001>: 'b'
        #             <Key name:C hash:101001>: 'c'

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(Z)
        self.assertEqual(len(h), orig_len)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(B)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(D), 'd')
        self.assertEqual(h.get(E), 'e')

        h = h.delete(A)
        h = h.delete(B)
        h = h.delete(D)
        h = h.delete(E)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_3(self):
        # Deletion next to a nested CollisionNode.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(104, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=6 bitmap=0b100110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=4 id=0x108572410):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        self.assertEqual(h.get(C), 'c')
        self.assertEqual(h.get(B), 'b')

    def test_hamt_delete_4(self):
        # Deleting every entry of a three-way CollisionNode, then the rest.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=4 bitmap=0b110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=6 id=0x10515ef30):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #                     <Key name:E hash:100100>: 'e'
        #     <Key name:B hash:101>: 'b'

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 3)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 4)

        h = h.delete(B)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_5(self):
        # Deletion from an ArrayNode root, including a collision entry and
        # repeated deletion of the same key.
        h = hamt()

        keys = []
        for i in range(17):
            key = HashKey(i, str(i))
            keys.append(key)
            h = h.set(key, f'val-{i}')

        collision_key16 = HashKey(16, '18')
        h = h.set(collision_key16, 'collision')

        # ArrayNode(id=0x10f8b9318):
        #     0::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:0 hash:0>: 'val-0'
        #
        # ... 14 more BitmapNodes ...
        #
        #     15::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:15 hash:15>: 'val-15'
        #
        #     16::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         NULL:
        #             CollisionNode(size=4 id=0x10f2f5af8):
        #                 <Key name:16 hash:16>: 'val-16'
        #                 <Key name:18 hash:16>: 'collision'

        self.assertEqual(len(h), 18)

        h = h.delete(keys[2])
        self.assertEqual(len(h), 17)

        h = h.delete(collision_key16)
        self.assertEqual(len(h), 16)
        h = h.delete(keys[16])
        self.assertEqual(len(h), 15)

        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)
        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)

        for key in keys:
            h = h.delete(key)
        self.assertEqual(len(h), 0)

    def test_hamt_items_1(self):
        # items() over a tree with a nested BitmapNode.
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_items_2(self):
        # items() over a tree containing a CollisionNode.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_keys_1(self):
        # keys() and plain iteration both yield every key.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
        self.assertEqual(set(list(h)), {A, B, C, D, E, F})

    def test_hamt_items_3(self):
        # items() of an empty mapping.
        h = hamt()
        self.assertEqual(len(h.items()), 0)
        self.assertEqual(list(h.items()), [])

    def test_hamt_eq_1(self):
        # == / != compare both keys and values.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(120, 'E')

        h1 = hamt()
        h1 = h1.set(A, 'a')
        h1 = h1.set(B, 'b')
        h1 = h1.set(C, 'c')
        h1 = h1.set(D, 'd')

        h2 = hamt()
        h2 = h2.set(A, 'a')

        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(B, 'b')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(C, 'c')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd2')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd')
        self.assertTrue(h1 == h2)
        self.assertFalse(h1 != h2)

        h2 = h2.set(E, 'e')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.delete(D)
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(E, 'd')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

    def test_hamt_eq_2(self):
        # Errors raised by key comparison propagate out of == and !=.
        A = HashKey(100, 'A')
        Er = HashKey(100, 'Er', error_on_eq_to=A)

        h1 = hamt()
        h1 = h1.set(A, 'a')

        h2 = hamt()
        h2 = h2.set(Er, 'a')

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 == h2

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 != h2

    def test_hamt_gc_1(self):
        # A HAMT caught in a reference cycle is collected by the GC.
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(0, 0)  # empty HAMT node is memoized in hamt.c
        ref = weakref.ref(h)

        a = []
        a.append(a)
        a.append(h)
        b = []
        a.append(b)
        b.append(a)
        h = h.set(A, b)

        del h, a, b

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_gc_2(self):
        # A HAMT that contains itself, with a live items() iterator, is
        # collected once both are dropped.
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(A, h)

        ref = weakref.ref(h)
        hi = h.items()
        next(hi)

        del h, hi

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_in_1(self):
        # ``in`` uses hash and equality and propagates their errors.
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertTrue(A in h)
        self.assertFalse(B in h)

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                AA in h

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                AA in h

    def test_hamt_getitem_1(self):
        # Subscripting finds equal keys, raises KeyError for missing ones
        # and propagates hash/eq errors.
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertEqual(h[A], 1)
        self.assertEqual(h[AA], 1)

        with self.assertRaises(KeyError):
            h[B]

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                h[AA]

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                h[AA]
1063
1064
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
1067