1import concurrent.futures
2import contextvars
3import functools
4import gc
5import random
6import time
7import unittest
8import weakref
9
10try:
11    from _testcapi import hamt
12except ImportError:
13    hamt = None
14
15
def isolated_context(func):
    """Decorate *func* so every call runs inside a fresh, empty Context.

    Values that the wrapped call sets on ContextVars stay in that throwaway
    Context and do not persist in the caller's context afterwards.  This is
    needed to make the reference-tracking test mode work.
    """
    def wrapper(*args, **kwargs):
        fresh = contextvars.Context()
        result = fresh.run(func, *args, **kwargs)
        return result

    return functools.wraps(func)(wrapper)
23
24
class ContextTest(unittest.TestCase):
    """Tests for the ``contextvars`` API: ContextVar, Context and Token."""

    # ContextVar() argument validation and the read-only ``name`` attribute.
    def test_context_var_new_1(self):
        with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
            contextvars.ContextVar()

        with self.assertRaisesRegex(TypeError, 'must be a str'):
            contextvars.ContextVar(1)

        c = contextvars.ContextVar('aaa')
        self.assertEqual(c.name, 'aaa')

        with self.assertRaises(AttributeError):
            c.name = 'bbb'

        self.assertNotEqual(hash(c), hash('aaa'))

    # repr() of ContextVar and Token, including a self-referencing default
    # value and the ' used ' marker a Token gains after reset().
    @isolated_context
    def test_context_var_repr_1(self):
        c = contextvars.ContextVar('a')
        self.assertIn('a', repr(c))

        c = contextvars.ContextVar('a', default=123)
        self.assertIn('123', repr(c))

        lst = []
        c = contextvars.ContextVar('a', default=lst)
        lst.append(c)
        self.assertIn('...', repr(c))
        self.assertIn('...', repr(lst))

        t = c.set(1)
        self.assertIn(repr(c), repr(t))
        self.assertNotIn(' used ', repr(t))
        c.reset(t)
        self.assertIn(' used ', repr(t))

    # ContextVar, Context and Token cannot be subclassed.
    def test_context_subclassing_1(self):
        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyContextVar(contextvars.ContextVar):
                # Potentially we might want ContextVars to be subclassable.
                pass

        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyContext(contextvars.Context):
                pass

        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyToken(contextvars.Token):
                pass

    # Context() accepts no positional or keyword arguments.
    def test_context_new_1(self):
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1, a=1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(a=1)
        contextvars.Context(**{})

    # Context lookups reject keys that are not ContextVars.
    def test_context_typerrors_1(self):
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            ctx[1]
        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            1 in ctx
        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            ctx.get(1)

    # copy_context() returns a Context instance.
    def test_context_get_context_1(self):
        ctx = contextvars.copy_context()
        self.assertIsInstance(ctx, contextvars.Context)

    # Context.run() requires a callable argument.
    def test_context_run_1(self):
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'missing 1 required'):
            ctx.run()

    # Context.run() forwards positional and keyword arguments unchanged and
    # does not mutate a caller-supplied kwargs dict.
    def test_context_run_2(self):
        ctx = contextvars.Context()

        def func(*args, **kwargs):
            kwargs['spam'] = 'foo'
            args += ('bar',)
            return args, kwargs

        # partial doesn't support FASTCALL, so running it through ctx.run
        # presumably exercises a different call path than a plain function.
        for f in (func, functools.partial(func)):
            self.assertEqual(ctx.run(f), (('bar',), {'spam': 'foo'}))
            self.assertEqual(ctx.run(f, 1), ((1, 'bar'), {'spam': 'foo'}))

            self.assertEqual(
                ctx.run(f, a=2),
                (('bar',), {'a': 2, 'spam': 'foo'}))

            self.assertEqual(
                ctx.run(f, 11, a=2),
                ((11, 'bar'), {'a': 2, 'spam': 'foo'}))

            a = {}
            self.assertEqual(
                ctx.run(f, 11, **a),
                ((11, 'bar'), {'spam': 'foo'}))
            self.assertEqual(a, {})

    # Exceptions raised by the callable propagate out of Context.run().
    def test_context_run_3(self):
        ctx = contextvars.Context()

        def func(*args, **kwargs):
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(func)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(func, 1, 2)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(func, 1, 2, a=123)

    # Nested runs in different Contexts see independent values; values set
    # inside a run remain visible through that Context afterwards.
    @isolated_context
    def test_context_run_4(self):
        ctx1 = contextvars.Context()
        ctx2 = contextvars.Context()
        var = contextvars.ContextVar('var')

        def func2():
            self.assertIsNone(var.get(None))

        def func1():
            self.assertIsNone(var.get(None))
            var.set('spam')
            ctx2.run(func2)
            self.assertEqual(var.get(None), 'spam')

            cur = contextvars.copy_context()
            self.assertEqual(len(cur), 1)
            self.assertEqual(cur[var], 'spam')
            return cur

        returned_ctx = ctx1.run(func1)
        self.assertEqual(ctx1, returned_ctx)
        self.assertEqual(returned_ctx[var], 'spam')
        self.assertIn(var, returned_ctx)

    # A value set inside a run that raises does not leak to the caller.
    def test_context_run_5(self):
        ctx = contextvars.Context()
        var = contextvars.ContextVar('var')

        def func():
            self.assertIsNone(var.get(None))
            var.set('spam')
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(func)

        self.assertIsNone(var.get(None))

    # A ContextVar default is returned by var.get() but is not stored in the
    # Context; an explicit set() is.
    def test_context_run_6(self):
        ctx = contextvars.Context()
        c = contextvars.ContextVar('a', default=0)

        def fun():
            self.assertEqual(c.get(), 0)
            self.assertIsNone(ctx.get(c))

            c.set(42)
            self.assertEqual(c.get(), 42)
            self.assertEqual(ctx.get(c), 42)

        ctx.run(fun)

    # A Context cannot be entered recursively.
    def test_context_run_7(self):
        ctx = contextvars.Context()

        def fun():
            with self.assertRaisesRegex(RuntimeError, 'is already entered'):
                ctx.run(fun)

        ctx.run(fun)

    # set()/get()/reset() semantics, Token attributes, single-use tokens and
    # the Mapping interface of copied Contexts.
    @isolated_context
    def test_context_getset_1(self):
        c = contextvars.ContextVar('c')
        with self.assertRaises(LookupError):
            c.get()

        self.assertIsNone(c.get(None))

        t0 = c.set(42)
        self.assertEqual(c.get(), 42)
        self.assertEqual(c.get(None), 42)
        self.assertIs(t0.old_value, t0.MISSING)
        self.assertIs(t0.old_value, contextvars.Token.MISSING)
        self.assertIs(t0.var, c)

        t = c.set('spam')
        self.assertEqual(c.get(), 'spam')
        self.assertEqual(c.get(None), 'spam')
        self.assertEqual(t.old_value, 42)
        c.reset(t)

        self.assertEqual(c.get(), 42)
        self.assertEqual(c.get(None), 42)

        c.set('spam2')
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            c.reset(t)
        self.assertEqual(c.get(), 'spam2')

        ctx1 = contextvars.copy_context()
        self.assertIn(c, ctx1)

        c.reset(t0)
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            c.reset(t0)
        self.assertIsNone(c.get(None))

        self.assertIn(c, ctx1)
        self.assertEqual(ctx1[c], 'spam2')
        self.assertEqual(ctx1.get(c, 'aa'), 'spam2')
        self.assertEqual(len(ctx1), 1)
        self.assertEqual(list(ctx1.items()), [(c, 'spam2')])
        self.assertEqual(list(ctx1.values()), ['spam2'])
        self.assertEqual(list(ctx1.keys()), [c])
        self.assertEqual(list(ctx1), [c])

        ctx2 = contextvars.copy_context()
        self.assertNotIn(c, ctx2)
        with self.assertRaises(KeyError):
            ctx2[c]
        self.assertEqual(ctx2.get(c, 'aa'), 'aa')
        self.assertEqual(len(ctx2), 0)
        self.assertEqual(list(ctx2), [])

    # A Token can only be reset by the ContextVar that created it.
    @isolated_context
    def test_context_getset_2(self):
        v1 = contextvars.ContextVar('v1')
        v2 = contextvars.ContextVar('v2')

        t1 = v1.set(42)
        with self.assertRaisesRegex(ValueError, 'by a different'):
            v2.reset(t1)

    # reset() of a Token whose old value was MISSING removes the var from
    # the Context again.
    @isolated_context
    def test_context_getset_3(self):
        c = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        def fun():
            self.assertEqual(c.get(), 42)
            with self.assertRaises(KeyError):
                ctx[c]
            self.assertIsNone(ctx.get(c))
            self.assertEqual(ctx.get(c, 'spam'), 'spam')
            self.assertNotIn(c, ctx)
            self.assertEqual(list(ctx.keys()), [])

            t = c.set(1)
            self.assertEqual(list(ctx.keys()), [c])
            self.assertEqual(ctx[c], 1)

            c.reset(t)
            self.assertEqual(list(ctx.keys()), [])
            with self.assertRaises(KeyError):
                ctx[c]

        ctx.run(fun)

    # A Token created in one Context cannot be reset from another.
    @isolated_context
    def test_context_getset_4(self):
        c = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        tok = ctx.run(c.set, 1)

        with self.assertRaisesRegex(ValueError, 'different Context'):
            c.reset(tok)

    # Setting a var inside a copied Context does not affect the original.
    @isolated_context
    def test_context_getset_5(self):
        c = contextvars.ContextVar('c', default=42)
        c.set([])

        def fun():
            c.set([])
            c.get().append(42)
            self.assertEqual(c.get(), [42])

        contextvars.copy_context().run(fun)
        self.assertEqual(c.get(), [])

    # Context.copy() snapshots values; later sets in either Context are
    # independent.
    def test_context_copy_1(self):
        ctx1 = contextvars.Context()
        c = contextvars.ContextVar('c', default=42)

        def ctx1_fun():
            c.set(10)

            ctx2 = ctx1.copy()
            self.assertEqual(ctx2[c], 10)

            c.set(20)
            self.assertEqual(ctx1[c], 20)
            self.assertEqual(ctx2[c], 10)

            ctx2.run(ctx2_fun)
            self.assertEqual(ctx1[c], 20)
            self.assertEqual(ctx2[c], 30)

        def ctx2_fun():
            self.assertEqual(c.get(), 10)
            c.set(30)
            self.assertEqual(c.get(), 30)

        ctx1.run(ctx1_fun)

    # ContextVar values set in different worker threads do not interfere.
    @isolated_context
    def test_context_threads_1(self):
        cvar = contextvars.ContextVar('cvar')

        def sub(num):
            for i in range(10):
                cvar.set(num + i)
                # Random sleeps encourage interleaving between the workers.
                time.sleep(random.uniform(0.001, 0.05))
                self.assertEqual(cvar.get(), num + i)
            return num

        # The with-block shuts the pool down (waiting for the workers) even
        # when a worker's assertion failure propagates out of map().
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as tp:
            results = list(tp.map(sub, range(10)))
        self.assertEqual(results, list(range(10)))

    # ContextVar supports class subscription, which returns the class.
    def test_contextvar_getitem(self):
        clss = contextvars.ContextVar
        self.assertEqual(clss[str], clss)
364
365
366# HAMT Tests
367
368
class HashKey:
    """Test key whose hash value is chosen by the caller.

    Two keys are equal only when both ``name`` and ``hash`` match.  Keys that
    share a hash but differ in name therefore collide without being equal.

    If *error_on_eq_to* is given, comparing this key with that exact object
    (in either direction) raises ValueError.  When the class attribute
    ``_crasher`` is set (see HaskKeyCrasher), hashing or comparing can be
    forced to raise HashingError / EqError.
    """

    _crasher = None

    def __init__(self, hash, name, *, error_on_eq_to=None):
        # presumably excluded because CPython uses -1 internally to signal a
        # hashing error, so such a key could not report its chosen hash
        assert hash != -1
        self.name = name
        self.hash = hash
        self.error_on_eq_to = error_on_eq_to

    def __repr__(self):
        return f'<Key name:{self.name} hash:{self.hash}>'

    def __hash__(self):
        crasher = self._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return self.hash

    def __eq__(self, other):
        if not isinstance(other, HashKey):
            return NotImplemented

        crasher = self._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError

        # ``other`` and ``self`` are HashKey instances (never None), so an
        # identity test alone is enough to detect a forbidden comparison.
        if self.error_on_eq_to is other:
            raise ValueError(f'cannot compare {self!r} to {other!r}')
        if other.error_on_eq_to is self:
            raise ValueError(f'cannot compare {other!r} to {self!r}')

        ours = (self.name, self.hash)
        theirs = (other.name, other.hash)
        return ours == theirs
400
401
class KeyStr(str):
    """``str`` subclass whose hashing and equality obey the active crasher.

    When ``HashKey._crasher`` is set and requests it, ``__hash__`` raises
    HashingError and ``__eq__`` raises EqError; otherwise both behave exactly
    like ``str``.
    """

    def __hash__(self):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return str.__hash__(self)

    def __eq__(self, other):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError
        return str.__eq__(self, other)
412
413
class HaskKeyCrasher:
    """Context manager that makes HashKey/KeyStr hashing or equality raise.

    On entry it installs itself as ``HashKey._crasher``, and on exit it
    clears that slot.  Only one crasher may be active at a time.  ``__exit__``
    returns None, so any exception raised in the body keeps propagating.
    """

    def __init__(self, *, error_on_hash=False, error_on_eq=False):
        self.error_on_eq = error_on_eq
        self.error_on_hash = error_on_hash

    def __enter__(self):
        if HashKey._crasher is None:
            HashKey._crasher = self
        else:
            raise RuntimeError('cannot nest crashers')

    def __exit__(self, *exc):
        HashKey._crasher = None
426
427
class HashingError(Exception):
    """Raised by HashKey/KeyStr ``__hash__`` while a crasher requests it."""
430
431
class EqError(Exception):
    """Raised by HashKey/KeyStr ``__eq__`` while a crasher requests it."""
434
435
@unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
class HamtTest(unittest.TestCase):
    """Tests for the HAMT mapping exposed by ``_testcapi.hamt()``.

    The whole class is skipped when ``_testcapi`` does not provide it.
    """

    # Sanity check of the HashKey helper: same hash, different keys.
    def test_hashkey_helper_1(self):
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')

        self.assertNotEqual(k1, k2)
        self.assertEqual(hash(k1), hash(k2))

        d = dict()
        d[k1] = 'a'
        d[k2] = 'b'

        self.assertEqual(d[k1], 'a')
        self.assertEqual(d[k2], 'b')

    # Create and immediately drop an empty HAMT.
    def test_hamt_basics_1(self):
        h = hamt()
        h = None  # NoQA

    # set() returns a new HAMT and leaves the original untouched.
    def test_hamt_basics_2(self):
        h = hamt()
        self.assertEqual(len(h), 0)

        h2 = h.set('a', 'b')
        self.assertIsNot(h, h2)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)

        self.assertIsNone(h.get('a'))
        self.assertEqual(h.get('a', 42), 42)

        self.assertEqual(h2.get('a'), 'b')

        h3 = h2.set('b', 10)
        self.assertIsNot(h2, h3)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(h3.get('a'), 'b')
        self.assertEqual(h3.get('b'), 10)

        self.assertIsNone(h.get('b'))
        self.assertIsNone(h2.get('b'))

        self.assertIsNone(h.get('a'))
        self.assertEqual(h2.get('a'), 'b')

        h = h2 = h3 = None

    # Re-setting a key to the identical value object returns the same HAMT.
    def test_hamt_basics_3(self):
        h = hamt()
        o = object()
        h1 = h.set('1', o)
        h2 = h1.set('1', o)
        self.assertIs(h1, h2)

    # Re-setting a key to a distinct (even if equal) value creates a new HAMT.
    def test_hamt_basics_4(self):
        h = hamt()
        h1 = h.set('key', [])
        h2 = h1.set('key', [])
        self.assertIsNot(h1, h2)
        self.assertEqual(len(h1), 1)
        self.assertEqual(len(h2), 1)
        self.assertIsNot(h1.get('key'), h2.get('key'))

    # Keys with identical hashes are stored and looked up independently.
    def test_hamt_collision_1(self):
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')
        k3 = HashKey(10, 'ccc')

        h = hamt()
        h2 = h.set(k1, 'a')
        h3 = h2.set(k2, 'b')

        self.assertEqual(h.get(k1), None)
        self.assertEqual(h.get(k2), None)

        self.assertEqual(h2.get(k1), 'a')
        self.assertEqual(h2.get(k2), None)

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')

        h4 = h3.set(k2, 'cc')
        h5 = h4.set(k3, 'aa')

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')
        self.assertEqual(h4.get(k1), 'a')
        self.assertEqual(h4.get(k2), 'cc')
        self.assertEqual(h4.get(k3), None)
        self.assertEqual(h5.get(k1), 'a')
        self.assertEqual(h5.get(k2), 'cc')
        self.assertEqual(h5.get(k3), 'aa')

        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(len(h4), 2)
        self.assertEqual(len(h5), 3)

    # Insert and then delete COLLECTION_SIZE KeyStr keys, mirroring every
    # operation in a dict and periodically injecting hash/eq failures that
    # must propagate as exceptions while the HAMT stays in sync with the dict.
    def test_hamt_stress(self):
        COLLECTION_SIZE = 7000
        TEST_ITERS_EVERY = 647
        CRASH_HASH_EVERY = 97
        CRASH_EQ_EVERY = 11
        RUN_XTIMES = 3

        for _ in range(RUN_XTIMES):
            h = hamt()
            d = dict()

            for i in range(COLLECTION_SIZE):
                key = KeyStr(i)

                if not (i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.set(key, i)

                h = h.set(key, i)

                if not (i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.get(KeyStr(i))  # really trigger __eq__

                d[key] = i
                self.assertEqual(len(d), len(h))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.items()), set(d.items()))
                    self.assertEqual(len(h.items()), len(d.items()))

            self.assertEqual(len(h), COLLECTION_SIZE)

            for key in range(COLLECTION_SIZE):
                self.assertEqual(h.get(KeyStr(key), 'not found'), key)

            keys_to_delete = list(range(COLLECTION_SIZE))
            random.shuffle(keys_to_delete)
            for iter_i, i in enumerate(keys_to_delete):
                key = KeyStr(i)

                if not (iter_i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.delete(key)

                if not (iter_i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.delete(KeyStr(i))

                h = h.delete(key)
                self.assertEqual(h.get(key, 'not found'), 'not found')
                del d[key]
                self.assertEqual(len(d), len(h))

                # Snapshot the half-drained pair for the second phase below.
                if iter_i == COLLECTION_SIZE // 2:
                    hm = h
                    dm = d.copy()

                if not (iter_i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.keys()), set(d.keys()))
                    self.assertEqual(len(h.keys()), len(d.keys()))

            self.assertEqual(len(d), 0)
            self.assertEqual(len(h), 0)

            # ============
            # Second phase: the snapshot ``hm`` must be unaffected by the
            # deletions made after it was taken, and must drain in lockstep
            # with its dict copy ``dm``.  Deleting a key that is already
            # gone leaves the HAMT unchanged.

            for key in dm:
                self.assertEqual(hm.get(str(key)), dm[key])
            self.assertEqual(len(dm), len(hm))

            for i, key in enumerate(keys_to_delete):
                hm = hm.delete(str(key))
                self.assertEqual(hm.get(str(key), 'not found'), 'not found')
                dm.pop(str(key), None)
                self.assertEqual(len(dm), len(hm))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(hm.values()), set(dm.values()))
                    self.assertEqual(len(hm.values()), len(dm.values()))

            self.assertEqual(len(dm), 0)
            self.assertEqual(len(hm), 0)
            self.assertEqual(list(hm.items()), [])

    # Deleting from a single bitmap node; an __eq__ error during delete
    # propagates, and deleting a missing key returns the same HAMT.
    def test_hamt_delete_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(102, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(103, 'Er', error_on_eq_to=D)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:C hash:102>: 'c'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 2)

        h2 = h.delete(Z)
        self.assertIs(h2, h)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(A, 42), 42)
        self.assertEqual(h.get(B), 'b')
        self.assertEqual(h.get(E), 'e')

    # Deleting from a bitmap node that contains a nested bitmap node.
    def test_hamt_delete_2(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(201001, 'Er', error_on_eq_to=B)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=8 bitmap=0b1110010000):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b100000000001000000000):
        #             <Key name:B hash:201001>: 'b'
        #             <Key name:C hash:101001>: 'c'

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(Z)
        self.assertEqual(len(h), orig_len)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(B)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(D), 'd')
        self.assertEqual(h.get(E), 'e')

        h = h.delete(A)
        h = h.delete(B)
        h = h.delete(D)
        h = h.delete(E)
        self.assertEqual(len(h), 0)

    # Deleting around a collision node nested two levels deep.
    def test_hamt_delete_3(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(104, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=6 bitmap=0b100110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=4 id=0x108572410):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        self.assertEqual(h.get(C), 'c')
        self.assertEqual(h.get(B), 'b')

    # Draining a three-entry collision node and then the rest of the HAMT.
    def test_hamt_delete_4(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=4 bitmap=0b110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=6 id=0x10515ef30):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #                     <Key name:E hash:100100>: 'e'
        #     <Key name:B hash:101>: 'b'

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 3)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 4)

        h = h.delete(B)
        self.assertEqual(len(h), 0)

    # Deleting from an array node (17+ distinct hashes), including a slot
    # that holds a collision node; repeated deletes are no-ops.
    def test_hamt_delete_5(self):
        h = hamt()

        keys = []
        for i in range(17):
            key = HashKey(i, str(i))
            keys.append(key)
            h = h.set(key, f'val-{i}')

        collision_key16 = HashKey(16, '18')
        h = h.set(collision_key16, 'collision')

        # ArrayNode(id=0x10f8b9318):
        #     0::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:0 hash:0>: 'val-0'
        #
        # ... 14 more BitmapNodes ...
        #
        #     15::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:15 hash:15>: 'val-15'
        #
        #     16::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         NULL:
        #             CollisionNode(size=4 id=0x10f2f5af8):
        #                 <Key name:16 hash:16>: 'val-16'
        #                 <Key name:18 hash:16>: 'collision'

        self.assertEqual(len(h), 18)

        h = h.delete(keys[2])
        self.assertEqual(len(h), 17)

        h = h.delete(collision_key16)
        self.assertEqual(len(h), 16)
        h = h.delete(keys[16])
        self.assertEqual(len(h), 15)

        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)
        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)

        for key in keys:
            h = h.delete(key)
        self.assertEqual(len(h), 0)

    # items() yields every pair, including those in nested bitmap nodes.
    def test_hamt_items_1(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    # items() yields every pair, including those in collision nodes.
    def test_hamt_items_2(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    # keys() and plain iteration yield every key.
    def test_hamt_keys_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
        self.assertEqual(set(list(h)), {A, B, C, D, E, F})

    # items() of an empty HAMT is empty.
    def test_hamt_items_3(self):
        h = hamt()
        self.assertEqual(len(h.items()), 0)
        self.assertEqual(list(h.items()), [])

    # == / != compare both keys and values.
    def test_hamt_eq_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(120, 'E')

        h1 = hamt()
        h1 = h1.set(A, 'a')
        h1 = h1.set(B, 'b')
        h1 = h1.set(C, 'c')
        h1 = h1.set(D, 'd')

        h2 = hamt()
        h2 = h2.set(A, 'a')

        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(B, 'b')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(C, 'c')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd2')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd')
        self.assertTrue(h1 == h2)
        self.assertFalse(h1 != h2)

        h2 = h2.set(E, 'e')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.delete(D)
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(E, 'd')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

    # Errors raised by a key's __eq__ propagate out of HAMT comparison.
    def test_hamt_eq_2(self):
        A = HashKey(100, 'A')
        Er = HashKey(100, 'Er', error_on_eq_to=A)

        h1 = hamt()
        h1 = h1.set(A, 'a')

        h2 = hamt()
        h2 = h2.set(Er, 'a')

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 == h2

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 != h2

    # A HAMT caught in a reference cycle is reclaimed by the cyclic GC.
    def test_hamt_gc_1(self):
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(0, 0)  # empty HAMT node is memoized in hamt.c
        ref = weakref.ref(h)

        a = []
        a.append(a)
        a.append(h)
        b = []
        a.append(b)
        b.append(a)
        h = h.set(A, b)

        del h, a, b

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    # A HAMT that holds itself as a value, with a partially consumed
    # items() iterator, is reclaimed by the cyclic GC.
    def test_hamt_gc_2(self):
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(A, h)

        ref = weakref.ref(h)
        hi = h.items()
        next(hi)

        del h, hi

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    # Membership tests use hash and __eq__, and propagate their errors.
    def test_hamt_in_1(self):
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertTrue(A in h)
        self.assertFalse(B in h)

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                AA in h

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                AA in h

    # Subscription finds equal keys, raises KeyError for missing ones, and
    # propagates hash/__eq__ errors.
    def test_hamt_getitem_1(self):
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertEqual(h[A], 1)
        self.assertEqual(h[AA], 1)

        with self.assertRaises(KeyError):
            h[B]

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                h[AA]

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                h[AA]
1067
1068
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
1071