1import abc
2import builtins
3import collections
4import collections.abc
5import copy
6from itertools import permutations
7import pickle
8from random import choice
9import sys
10from test import support
11import threading
12import time
13import typing
14import unittest
15import unittest.mock
16import os
17import weakref
18import gc
19from weakref import proxy
20import contextlib
21
22from test.support.script_helper import assert_python_ok
23
24import functools
25
# Two independent copies of functools: one with the _functools C
# accelerator blocked (pure-Python partial/reduce), and one imported fresh
# together with _functools.  c_functools is evidently allowed to be falsy
# (see the skipUnless/`if c_functools:` guards below) when the C module is
# unavailable.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# NOTE(review): not used in the visible part of the file; presumably used
# by tests further down.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
30
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The previously registered module is put back when the ``with`` block
    exits, whether it exits normally or by an exception.  *name* must
    already be present in ``sys.modules`` (a KeyError is raised otherwise).
    """
    modules = sys.modules
    saved = modules[name]
    modules[name] = replacement
    try:
        yield
    finally:
        # Restore unconditionally so a failing test cannot leak the swap.
        modules[name] = saved
39
def capture(*args, **kw):
    """Echo the call back: return ``(args, kw)`` exactly as received."""
    received = (args, kw)
    return received
43
44
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
48
class MyTuple(tuple):
    """Plain tuple subclass; partial state must be normalised to tuple."""
51
class BadTuple(tuple):
    """tuple subclass whose ``+`` yields a list instead of a tuple.

    Used to check that partial does not merge positional arguments via the
    stored args object's own ``__add__``.
    """
    def __add__(self, other):
        return [*self, *other]
55
class MyDict(dict):
    """Plain dict subclass; partial keywords must be normalised to dict."""
58
59
class TestPartial:
    """Behaviour shared by every partial implementation under test.

    This is a mixin: concrete subclasses also derive from unittest.TestCase
    and provide ``partial`` (the implementation under test) and
    ``AllowPickle`` (a context manager entered around the pickling tests).
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected with TypeError,
        # either on construction or, at the latest, when called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            got, empty = p('x')
            # assertEqual (not assertTrue(a == b and ...)) so a failure
            # reports the offending values.
            self.assertEqual(got, args + ('x',))
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            empty, got = p(x=None)
            self.assertEqual(got, {'a':a,'x':None})
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        support.gc_collect()  # For PyPy or other GCs.
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # keywords=None is normalised to an empty dict, just as a None
        # namespace becomes an empty __dict__.
        self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
384
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: nothing to set up for the C partial."""
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc_value, tb):
            # Never suppress an exception raised inside the block.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # deleting the instance __dict__ must be rejected with TypeError
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        # ...is still rejected when the partial is actually called.
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
435
436
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""
    partial = py_functools.partial

    class AllowPickle:
        """Make ``sys.modules['functools']`` be py_functools while active.

        Pickle records partial's class by its qualified name
        ('functools.partial'), so the swap lets that name resolve to the
        pure-Python module under test.
        """
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)

        def __enter__(self):
            cm = self._cm
            return cm.__enter__()

        def __exit__(self, exc_type, exc_value, tb):
            cm = self._cm
            return cm.__exit__(exc_type, exc_value, tb)
447
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial type."""
451
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial type."""
454
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the C partial tests against a trivial subclass of C partial."""
    if c_functools:
        partial = CPartialSubclass

    # Nested partial(partial(...)) flattening is not applied to partial
    # subclasses, so the inherited check for it is disabled here.
    test_nested_optimization = None
462
class TestPartialPySubclass(TestPartialPy):
    """Run the pure-Python partial tests against a trivial subclass."""
    partial = PyPartialSubclass
465
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    # Shared instance used by the tests below.
    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be an abstract *class* (ABCMeta as its metaclass),
        # not a subclass of ABCMeta, so that the ABC machinery actually
        # inspects the partialmethod's __isabstractmethod__.
        class Abstract(abc.ABC):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must register the partialmethod as abstract too.
        self.assertEqual(Abstract.__abstractmethods__,
                         frozenset({'add', 'add5'}))

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        # NOTE(review): exercises functools.partial (not partialmethod)
        # with a positional-only parameter.
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
594
595
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper (TestWraps reuses them for wraps)."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* was updated from *wrapped*.

        Every attribute named in *assigned* must be the same object on both;
        every entry of each mapping attribute named in *updated* on
        *wrapped* must appear (by identity) on *wrapper*; and
        ``wrapper.__wrapped__`` must be *wrapped*.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        """Build ``(wrapper, f)`` with update_wrapper's default arguments."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # A stale __wrapped__ on f; the wrapper's __wrapped__ must still
        # end up pointing at f itself.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # __annotations__ is replaced wholesale, not merged.
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # With empty assigned/updated tuples only __wrapped__ is set.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only the explicitly listed attributes are assigned/updated.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        # ...and they must support .update(); an int does not.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        # Wrapping a builtin (max), which has no __dict__ entries of its own.
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
708
709
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper tests using the functools.wraps decorator."""

    def _default_update(self):
        """Build ``(wrapper, f)`` with ``@functools.wraps(f)`` defaults."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Stale __wrapped__ on f; the wrapper must still point at f.
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # wraps(f, (), ()) assigns/updates nothing besides __wrapped__.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        # Gives the wrapper the dict_attr mapping that 'update' requires.
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
770
771
772class TestReduce:
773    def test_reduce(self):
774        class Squares:
775            def __init__(self, max):
776                self.max = max
777                self.sofar = []
778
779            def __len__(self):
780                return len(self.sofar)
781
782            def __getitem__(self, i):
783                if not 0 <= i < self.max: raise IndexError
784                n = len(self.sofar)
785                while n <= i:
786                    self.sofar.append(n*n)
787                    n += 1
788                return self.sofar[i]
789        def add(x, y):
790            return x + y
791        self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
792        self.assertEqual(
793            self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
794            ['a','c','d','w']
795        )
796        self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
797        self.assertEqual(
798            self.reduce(lambda x, y: x*y, range(2,21), 1),
799            2432902008176640000
800        )
801        self.assertEqual(self.reduce(add, Squares(10)), 285)
802        self.assertEqual(self.reduce(add, Squares(10), 0), 285)
803        self.assertEqual(self.reduce(add, Squares(0), 0), 0)
804        self.assertRaises(TypeError, self.reduce)
805        self.assertRaises(TypeError, self.reduce, 42, 42)
806        self.assertRaises(TypeError, self.reduce, 42, 42, 42)
807        self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
808        self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
809        self.assertRaises(TypeError, self.reduce, 42, (42, 42))
810        self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
811        self.assertRaises(TypeError, self.reduce, add, "")
812        self.assertRaises(TypeError, self.reduce, add, ())
813        self.assertRaises(TypeError, self.reduce, add, object())
814
815        class TestFailingIter:
816            def __iter__(self):
817                raise RuntimeError
818        self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
819
820        self.assertEqual(self.reduce(add, [], None), None)
821        self.assertEqual(self.reduce(add, [], 42), 42)
822
823        class BadSeq:
824            def __getitem__(self, index):
825                raise ValueError
826        self.assertRaises(ValueError, self.reduce, 42, BadSeq())
827
828    # Test reduce()'s use of iterators.
829    def test_iterator_usage(self):
830        class SequenceClass:
831            def __init__(self, n):
832                self.n = n
833            def __getitem__(self, i):
834                if 0 <= i < self.n:
835                    return i
836                else:
837                    raise IndexError
838
839        from operator import add
840        self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
841        self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
842        self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
843        self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
844        self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
845        self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
846
847        d = {"one": 1, "two": 2, "three": 3}
848        self.assertEqual(self.reduce(add, d), "".join(d.keys()))
849
850
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared TestReduce checks against the C reduce()."""

    # When the C module is unavailable the decorator skips this class, so
    # ``reduce`` is simply left undefined.
    if c_functools:
        # staticmethod() mirrors TestReducePy; a C builtin would not bind
        # to ``self`` anyway, so lookups return the same object.
        reduce = staticmethod(c_functools.reduce)
855
856
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared TestReduce checks against the pure-Python reduce()."""

    # staticmethod() keeps the plain Python function from binding to self.
    reduce = staticmethod(py_functools.reduce)
859
860
861class TestCmpToKey:
862
863    def test_cmp_to_key(self):
864        def cmp1(x, y):
865            return (x > y) - (x < y)
866        key = self.cmp_to_key(cmp1)
867        self.assertEqual(key(3), key(3))
868        self.assertGreater(key(3), key(1))
869        self.assertGreaterEqual(key(3), key(3))
870
871        def cmp2(x, y):
872            return int(x) - int(y)
873        key = self.cmp_to_key(cmp2)
874        self.assertEqual(key(4.0), key('4'))
875        self.assertLess(key(2), key('35'))
876        self.assertLessEqual(key(2), key('35'))
877        self.assertNotEqual(key(2), key('35'))
878
879    def test_cmp_to_key_arguments(self):
880        def cmp1(x, y):
881            return (x > y) - (x < y)
882        key = self.cmp_to_key(mycmp=cmp1)
883        self.assertEqual(key(obj=3), key(obj=3))
884        self.assertGreater(key(obj=3), key(obj=1))
885        with self.assertRaises((TypeError, AttributeError)):
886            key(3) > 1    # rhs is not a K object
887        with self.assertRaises((TypeError, AttributeError)):
888            1 < key(3)    # lhs is not a K object
889        with self.assertRaises(TypeError):
890            key = self.cmp_to_key()             # too few args
891        with self.assertRaises(TypeError):
892            key = self.cmp_to_key(cmp1, None)   # too many args
893        key = self.cmp_to_key(cmp1)
894        with self.assertRaises(TypeError):
895            key()                                    # too few args
896        with self.assertRaises(TypeError):
897            key(None, None)                          # too many args
898
899    def test_bad_cmp(self):
900        def cmp1(x, y):
901            raise ZeroDivisionError
902        key = self.cmp_to_key(cmp1)
903        with self.assertRaises(ZeroDivisionError):
904            key(3) > key(1)
905
906        class BadCmp:
907            def __lt__(self, other):
908                raise ZeroDivisionError
909        def cmp1(x, y):
910            return BadCmp()
911        with self.assertRaises(ZeroDivisionError):
912            key(3) > key(1)
913
914    def test_obj_field(self):
915        def cmp1(x, y):
916            return (x > y) - (x < y)
917        key = self.cmp_to_key(mycmp=cmp1)
918        self.assertEqual(key(50).obj, 50)
919
920    def test_sort_int(self):
921        def mycmp(x, y):
922            return y - x
923        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
924                         [4, 3, 2, 1, 0])
925
926    def test_sort_int_str(self):
927        def mycmp(x, y):
928            x, y = int(x), int(y)
929            return (x > y) - (x < y)
930        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
931        values = sorted(values, key=self.cmp_to_key(mycmp))
932        self.assertEqual([int(value) for value in values],
933                         [0, 1, 1, 2, 3, 4, 5, 7, 10])
934
935    def test_hash(self):
936        def mycmp(x, y):
937            return y - x
938        key = self.cmp_to_key(mycmp)
939        k = key(10)
940        self.assertRaises(TypeError, hash, k)
941        self.assertNotIsInstance(k, collections.abc.Hashable)
942
943
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared TestCmpToKey checks against the C cmp_to_key()."""

    # When the C module is unavailable the decorator skips this class, so
    # ``cmp_to_key`` is simply left undefined.
    if c_functools:
        # staticmethod() mirrors TestCmpToKeyPy; a C builtin would not bind
        # to ``self`` anyway, so lookups return the same object.
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
948
949
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared TestCmpToKey checks against the pure-Python version."""

    # staticmethod() keeps the plain Python function from binding to self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
952
953
954class TestTotalOrdering(unittest.TestCase):
955
956    def test_total_ordering_lt(self):
957        @functools.total_ordering
958        class A:
959            def __init__(self, value):
960                self.value = value
961            def __lt__(self, other):
962                return self.value < other.value
963            def __eq__(self, other):
964                return self.value == other.value
965        self.assertTrue(A(1) < A(2))
966        self.assertTrue(A(2) > A(1))
967        self.assertTrue(A(1) <= A(2))
968        self.assertTrue(A(2) >= A(1))
969        self.assertTrue(A(2) <= A(2))
970        self.assertTrue(A(2) >= A(2))
971        self.assertFalse(A(1) > A(2))
972
973    def test_total_ordering_le(self):
974        @functools.total_ordering
975        class A:
976            def __init__(self, value):
977                self.value = value
978            def __le__(self, other):
979                return self.value <= other.value
980            def __eq__(self, other):
981                return self.value == other.value
982        self.assertTrue(A(1) < A(2))
983        self.assertTrue(A(2) > A(1))
984        self.assertTrue(A(1) <= A(2))
985        self.assertTrue(A(2) >= A(1))
986        self.assertTrue(A(2) <= A(2))
987        self.assertTrue(A(2) >= A(2))
988        self.assertFalse(A(1) >= A(2))
989
990    def test_total_ordering_gt(self):
991        @functools.total_ordering
992        class A:
993            def __init__(self, value):
994                self.value = value
995            def __gt__(self, other):
996                return self.value > other.value
997            def __eq__(self, other):
998                return self.value == other.value
999        self.assertTrue(A(1) < A(2))
1000        self.assertTrue(A(2) > A(1))
1001        self.assertTrue(A(1) <= A(2))
1002        self.assertTrue(A(2) >= A(1))
1003        self.assertTrue(A(2) <= A(2))
1004        self.assertTrue(A(2) >= A(2))
1005        self.assertFalse(A(2) < A(1))
1006
1007    def test_total_ordering_ge(self):
1008        @functools.total_ordering
1009        class A:
1010            def __init__(self, value):
1011                self.value = value
1012            def __ge__(self, other):
1013                return self.value >= other.value
1014            def __eq__(self, other):
1015                return self.value == other.value
1016        self.assertTrue(A(1) < A(2))
1017        self.assertTrue(A(2) > A(1))
1018        self.assertTrue(A(1) <= A(2))
1019        self.assertTrue(A(2) >= A(1))
1020        self.assertTrue(A(2) <= A(2))
1021        self.assertTrue(A(2) >= A(2))
1022        self.assertFalse(A(2) <= A(1))
1023
1024    def test_total_ordering_no_overwrite(self):
1025        # new methods should not overwrite existing
1026        @functools.total_ordering
1027        class A(int):
1028            pass
1029        self.assertTrue(A(1) < A(2))
1030        self.assertTrue(A(2) > A(1))
1031        self.assertTrue(A(1) <= A(2))
1032        self.assertTrue(A(2) >= A(1))
1033        self.assertTrue(A(2) <= A(2))
1034        self.assertTrue(A(2) >= A(2))
1035
1036    def test_no_operations_defined(self):
1037        with self.assertRaises(ValueError):
1038            @functools.total_ordering
1039            class A:
1040                pass
1041
1042    def test_type_error_when_not_implemented(self):
1043        # bug 10042; ensure stack overflow does not occur
1044        # when decorated types return NotImplemented
1045        @functools.total_ordering
1046        class ImplementsLessThan:
1047            def __init__(self, value):
1048                self.value = value
1049            def __eq__(self, other):
1050                if isinstance(other, ImplementsLessThan):
1051                    return self.value == other.value
1052                return False
1053            def __lt__(self, other):
1054                if isinstance(other, ImplementsLessThan):
1055                    return self.value < other.value
1056                return NotImplemented
1057
1058        @functools.total_ordering
1059        class ImplementsGreaterThan:
1060            def __init__(self, value):
1061                self.value = value
1062            def __eq__(self, other):
1063                if isinstance(other, ImplementsGreaterThan):
1064                    return self.value == other.value
1065                return False
1066            def __gt__(self, other):
1067                if isinstance(other, ImplementsGreaterThan):
1068                    return self.value > other.value
1069                return NotImplemented
1070
1071        @functools.total_ordering
1072        class ImplementsLessThanEqualTo:
1073            def __init__(self, value):
1074                self.value = value
1075            def __eq__(self, other):
1076                if isinstance(other, ImplementsLessThanEqualTo):
1077                    return self.value == other.value
1078                return False
1079            def __le__(self, other):
1080                if isinstance(other, ImplementsLessThanEqualTo):
1081                    return self.value <= other.value
1082                return NotImplemented
1083
1084        @functools.total_ordering
1085        class ImplementsGreaterThanEqualTo:
1086            def __init__(self, value):
1087                self.value = value
1088            def __eq__(self, other):
1089                if isinstance(other, ImplementsGreaterThanEqualTo):
1090                    return self.value == other.value
1091                return False
1092            def __ge__(self, other):
1093                if isinstance(other, ImplementsGreaterThanEqualTo):
1094                    return self.value >= other.value
1095                return NotImplemented
1096
1097        @functools.total_ordering
1098        class ComparatorNotImplemented:
1099            def __init__(self, value):
1100                self.value = value
1101            def __eq__(self, other):
1102                if isinstance(other, ComparatorNotImplemented):
1103                    return self.value == other.value
1104                return False
1105            def __lt__(self, other):
1106                return NotImplemented
1107
1108        with self.subTest("LT < 1"), self.assertRaises(TypeError):
1109            ImplementsLessThan(-1) < 1
1110
1111        with self.subTest("LT < LE"), self.assertRaises(TypeError):
1112            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1113
1114        with self.subTest("LT < GT"), self.assertRaises(TypeError):
1115            ImplementsLessThan(1) < ImplementsGreaterThan(1)
1116
1117        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1118            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1119
1120        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1121            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1122
1123        with self.subTest("GT > GE"), self.assertRaises(TypeError):
1124            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1125
1126        with self.subTest("GT > LT"), self.assertRaises(TypeError):
1127            ImplementsGreaterThan(5) > ImplementsLessThan(5)
1128
1129        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1130            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1131
1132        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1133            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1134
1135        with self.subTest("GE when equal"):
1136            a = ComparatorNotImplemented(8)
1137            b = ComparatorNotImplemented(8)
1138            self.assertEqual(a, b)
1139            with self.assertRaises(TypeError):
1140                a >= b
1141
1142        with self.subTest("LE when equal"):
1143            a = ComparatorNotImplemented(9)
1144            b = ComparatorNotImplemented(9)
1145            self.assertEqual(a, b)
1146            with self.assertRaises(TypeError):
1147                a <= b
1148
1149    def test_pickle(self):
1150        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1151            for name in '__lt__', '__gt__', '__le__', '__ge__':
1152                with self.subTest(method=name, proto=proto):
1153                    method = getattr(Orderable_LT, name)
1154                    method_copy = pickle.loads(pickle.dumps(method, proto))
1155                    self.assertIs(method_copy, method)
1156
1157
1158    def test_total_ordering_for_metaclasses_issue_44605(self):
1159
1160        @functools.total_ordering
1161        class SortableMeta(type):
1162            def __new__(cls, name, bases, ns):
1163                return super().__new__(cls, name, bases, ns)
1164
1165            def __lt__(self, other):
1166                if not isinstance(other, SortableMeta):
1167                    pass
1168                return self.__name__ < other.__name__
1169
1170            def __eq__(self, other):
1171                if not isinstance(other, SortableMeta):
1172                    pass
1173                return self.__name__ == other.__name__
1174
1175        class B(metaclass=SortableMeta):
1176            pass
1177
1178        class A(metaclass=SortableMeta):
1179            pass
1180
1181        self.assertTrue(A < B)
1182        self.assertFalse(A > B)
1183
1184
@functools.total_ordering
class Orderable_LT:
    """Value wrapper whose ordering derives from ``__lt__`` and ``__eq__``.

    Kept at module level so that pickle can locate its methods by qualified
    name (TestTotalOrdering.test_pickle looks them up on this class).
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1193
1194
1195class TestCache:
1196    # This tests that the pass-through is working as designed.
1197    # The underlying functionality is tested in TestLRU.
1198
1199    def test_cache(self):
1200        @self.module.cache
1201        def fib(n):
1202            if n < 2:
1203                return n
1204            return fib(n-1) + fib(n-2)
1205        self.assertEqual([fib(n) for n in range(16)],
1206            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1207        self.assertEqual(fib.cache_info(),
1208            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1209        fib.cache_clear()
1210        self.assertEqual(fib.cache_info(),
1211            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1212
1213
1214class TestLRU:
1215
1216    def test_lru(self):
1217        def orig(x, y):
1218            return 3 * x + y
1219        f = self.module.lru_cache(maxsize=20)(orig)
1220        hits, misses, maxsize, currsize = f.cache_info()
1221        self.assertEqual(maxsize, 20)
1222        self.assertEqual(currsize, 0)
1223        self.assertEqual(hits, 0)
1224        self.assertEqual(misses, 0)
1225
1226        domain = range(5)
1227        for i in range(1000):
1228            x, y = choice(domain), choice(domain)
1229            actual = f(x, y)
1230            expected = orig(x, y)
1231            self.assertEqual(actual, expected)
1232        hits, misses, maxsize, currsize = f.cache_info()
1233        self.assertTrue(hits > misses)
1234        self.assertEqual(hits + misses, 1000)
1235        self.assertEqual(currsize, 20)
1236
1237        f.cache_clear()   # test clearing
1238        hits, misses, maxsize, currsize = f.cache_info()
1239        self.assertEqual(hits, 0)
1240        self.assertEqual(misses, 0)
1241        self.assertEqual(currsize, 0)
1242        f(x, y)
1243        hits, misses, maxsize, currsize = f.cache_info()
1244        self.assertEqual(hits, 0)
1245        self.assertEqual(misses, 1)
1246        self.assertEqual(currsize, 1)
1247
1248        # Test bypassing the cache
1249        self.assertIs(f.__wrapped__, orig)
1250        f.__wrapped__(x, y)
1251        hits, misses, maxsize, currsize = f.cache_info()
1252        self.assertEqual(hits, 0)
1253        self.assertEqual(misses, 1)
1254        self.assertEqual(currsize, 1)
1255
1256        # test size zero (which means "never-cache")
1257        @self.module.lru_cache(0)
1258        def f():
1259            nonlocal f_cnt
1260            f_cnt += 1
1261            return 20
1262        self.assertEqual(f.cache_info().maxsize, 0)
1263        f_cnt = 0
1264        for i in range(5):
1265            self.assertEqual(f(), 20)
1266        self.assertEqual(f_cnt, 5)
1267        hits, misses, maxsize, currsize = f.cache_info()
1268        self.assertEqual(hits, 0)
1269        self.assertEqual(misses, 5)
1270        self.assertEqual(currsize, 0)
1271
1272        # test size one
1273        @self.module.lru_cache(1)
1274        def f():
1275            nonlocal f_cnt
1276            f_cnt += 1
1277            return 20
1278        self.assertEqual(f.cache_info().maxsize, 1)
1279        f_cnt = 0
1280        for i in range(5):
1281            self.assertEqual(f(), 20)
1282        self.assertEqual(f_cnt, 1)
1283        hits, misses, maxsize, currsize = f.cache_info()
1284        self.assertEqual(hits, 4)
1285        self.assertEqual(misses, 1)
1286        self.assertEqual(currsize, 1)
1287
1288        # test size two
1289        @self.module.lru_cache(2)
1290        def f(x):
1291            nonlocal f_cnt
1292            f_cnt += 1
1293            return x*10
1294        self.assertEqual(f.cache_info().maxsize, 2)
1295        f_cnt = 0
1296        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1297            #    *  *              *                          *
1298            self.assertEqual(f(x), x*10)
1299        self.assertEqual(f_cnt, 4)
1300        hits, misses, maxsize, currsize = f.cache_info()
1301        self.assertEqual(hits, 12)
1302        self.assertEqual(misses, 4)
1303        self.assertEqual(currsize, 2)
1304
1305    def test_lru_no_args(self):
1306        @self.module.lru_cache
1307        def square(x):
1308            return x ** 2
1309
1310        self.assertEqual(list(map(square, [10, 20, 10])),
1311                         [100, 400, 100])
1312        self.assertEqual(square.cache_info().hits, 1)
1313        self.assertEqual(square.cache_info().misses, 2)
1314        self.assertEqual(square.cache_info().maxsize, 128)
1315        self.assertEqual(square.cache_info().currsize, 2)
1316
1317    def test_lru_bug_35780(self):
1318        # C version of the lru_cache was not checking to see if
1319        # the user function call has already modified the cache
1320        # (this arises in recursive calls and in multi-threading).
1321        # This cause the cache to have orphan links not referenced
1322        # by the cache dictionary.
1323
1324        once = True                 # Modified by f(x) below
1325
1326        @self.module.lru_cache(maxsize=10)
1327        def f(x):
1328            nonlocal once
1329            rv = f'.{x}.'
1330            if x == 20 and once:
1331                once = False
1332                rv = f(x)
1333            return rv
1334
1335        # Fill the cache
1336        for x in range(15):
1337            self.assertEqual(f(x), f'.{x}.')
1338        self.assertEqual(f.cache_info().currsize, 10)
1339
1340        # Make a recursive call and make sure the cache remains full
1341        self.assertEqual(f(20), '.20.')
1342        self.assertEqual(f.cache_info().currsize, 10)
1343
1344    def test_lru_bug_36650(self):
1345        # C version of lru_cache was treating a call with an empty **kwargs
1346        # dictionary as being distinct from a call with no keywords at all.
1347        # This did not result in an incorrect answer, but it did trigger
1348        # an unexpected cache miss.
1349
1350        @self.module.lru_cache()
1351        def f(x):
1352            pass
1353
1354        f(0)
1355        f(0, **{})
1356        self.assertEqual(f.cache_info().hits, 1)
1357
1358    def test_lru_hash_only_once(self):
1359        # To protect against weird reentrancy bugs and to improve
1360        # efficiency when faced with slow __hash__ methods, the
1361        # LRU cache guarantees that it will only call __hash__
1362        # only once per use as an argument to the cached function.
1363
1364        @self.module.lru_cache(maxsize=1)
1365        def f(x, y):
1366            return x * 3 + y
1367
1368        # Simulate the integer 5
1369        mock_int = unittest.mock.Mock()
1370        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1371        mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1372
1373        # Add to cache:  One use as an argument gives one call
1374        self.assertEqual(f(mock_int, 1), 16)
1375        self.assertEqual(mock_int.__hash__.call_count, 1)
1376        self.assertEqual(f.cache_info(), (0, 1, 1, 1))
1377
1378        # Cache hit: One use as an argument gives one additional call
1379        self.assertEqual(f(mock_int, 1), 16)
1380        self.assertEqual(mock_int.__hash__.call_count, 2)
1381        self.assertEqual(f.cache_info(), (1, 1, 1, 1))
1382
1383        # Cache eviction: No use as an argument gives no additional call
1384        self.assertEqual(f(6, 2), 20)
1385        self.assertEqual(mock_int.__hash__.call_count, 2)
1386        self.assertEqual(f.cache_info(), (1, 2, 1, 1))
1387
1388        # Cache miss: One use as an argument gives one additional call
1389        self.assertEqual(f(mock_int, 1), 16)
1390        self.assertEqual(mock_int.__hash__.call_count, 3)
1391        self.assertEqual(f.cache_info(), (1, 3, 1, 1))
1392
1393    def test_lru_reentrancy_with_len(self):
1394        # Test to make sure the LRU cache code isn't thrown-off by
1395        # caching the built-in len() function.  Since len() can be
1396        # cached, we shouldn't use it inside the lru code itself.
1397        old_len = builtins.len
1398        try:
1399            builtins.len = self.module.lru_cache(4)(len)
1400            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1401                self.assertEqual(len('abcdefghijklmn'[:i]), i)
1402        finally:
1403            builtins.len = old_len
1404
1405    def test_lru_star_arg_handling(self):
1406        # Test regression that arose in ea064ff3c10f
1407        @functools.lru_cache()
1408        def f(*args):
1409            return args
1410
1411        self.assertEqual(f(1, 2), (1, 2))
1412        self.assertEqual(f((1, 2)), ((1, 2),))
1413
1414    def test_lru_type_error(self):
1415        # Regression test for issue #28653.
1416        # lru_cache was leaking when one of the arguments
1417        # wasn't cacheable.
1418
1419        @functools.lru_cache(maxsize=None)
1420        def infinite_cache(o):
1421            pass
1422
1423        @functools.lru_cache(maxsize=10)
1424        def limited_cache(o):
1425            pass
1426
1427        with self.assertRaises(TypeError):
1428            infinite_cache([])
1429
1430        with self.assertRaises(TypeError):
1431            limited_cache([])
1432
1433    def test_lru_with_maxsize_none(self):
1434        @self.module.lru_cache(maxsize=None)
1435        def fib(n):
1436            if n < 2:
1437                return n
1438            return fib(n-1) + fib(n-2)
1439        self.assertEqual([fib(n) for n in range(16)],
1440            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1441        self.assertEqual(fib.cache_info(),
1442            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1443        fib.cache_clear()
1444        self.assertEqual(fib.cache_info(),
1445            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1446
1447    def test_lru_with_maxsize_negative(self):
1448        @self.module.lru_cache(maxsize=-10)
1449        def eq(n):
1450            return n
1451        for i in (0, 1):
1452            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1453        self.assertEqual(eq.cache_info(),
1454            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
1455
1456    def test_lru_with_exceptions(self):
1457        # Verify that user_function exceptions get passed through without
1458        # creating a hard-to-read chained exception.
1459        # http://bugs.python.org/issue13177
1460        for maxsize in (None, 128):
1461            @self.module.lru_cache(maxsize)
1462            def func(i):
1463                return 'abc'[i]
1464            self.assertEqual(func(0), 'a')
1465            with self.assertRaises(IndexError) as cm:
1466                func(15)
1467            self.assertIsNone(cm.exception.__context__)
1468            # Verify that the previous exception did not result in a cached entry
1469            with self.assertRaises(IndexError):
1470                func(15)
1471
1472    def test_lru_with_types(self):
1473        for maxsize in (None, 128):
1474            @self.module.lru_cache(maxsize=maxsize, typed=True)
1475            def square(x):
1476                return x * x
1477            self.assertEqual(square(3), 9)
1478            self.assertEqual(type(square(3)), type(9))
1479            self.assertEqual(square(3.0), 9.0)
1480            self.assertEqual(type(square(3.0)), type(9.0))
1481            self.assertEqual(square(x=3), 9)
1482            self.assertEqual(type(square(x=3)), type(9))
1483            self.assertEqual(square(x=3.0), 9.0)
1484            self.assertEqual(type(square(x=3.0)), type(9.0))
1485            self.assertEqual(square.cache_info().hits, 4)
1486            self.assertEqual(square.cache_info().misses, 4)
1487
1488    def test_lru_with_keyword_args(self):
1489        @self.module.lru_cache()
1490        def fib(n):
1491            if n < 2:
1492                return n
1493            return fib(n=n-1) + fib(n=n-2)
1494        self.assertEqual(
1495            [fib(n=number) for number in range(16)],
1496            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1497        )
1498        self.assertEqual(fib.cache_info(),
1499            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1500        fib.cache_clear()
1501        self.assertEqual(fib.cache_info(),
1502            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1503
1504    def test_lru_with_keyword_args_maxsize_none(self):
1505        @self.module.lru_cache(maxsize=None)
1506        def fib(n):
1507            if n < 2:
1508                return n
1509            return fib(n=n-1) + fib(n=n-2)
1510        self.assertEqual([fib(n=number) for number in range(16)],
1511            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1512        self.assertEqual(fib.cache_info(),
1513            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1514        fib.cache_clear()
1515        self.assertEqual(fib.cache_info(),
1516            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1517
1518    def test_kwargs_order(self):
1519        # PEP 468: Preserving Keyword Argument Order
1520        @self.module.lru_cache(maxsize=10)
1521        def f(**kwargs):
1522            return list(kwargs.items())
1523        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1524        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1525        self.assertEqual(f.cache_info(),
1526            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1527
1528    def test_lru_cache_decoration(self):
1529        def f(zomg: 'zomg_annotation'):
1530            """f doc string"""
1531            return 42
1532        g = self.module.lru_cache()(f)
1533        for attr in self.module.WRAPPER_ASSIGNMENTS:
1534            self.assertEqual(getattr(g, attr), getattr(f, attr))
1535
    def test_lru_cache_threaded(self):
        # Stress lru_cache from several threads at once.
        # Phase 1: n threads each call f(k, 0) m times with a distinct k, so
        # the cache must end up holding exactly n entries.
        # Phase 2: the same workers race a thread that repeatedly calls
        # cache_clear(); only the per-call results are checked there, no
        # cache statistics.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # All worker threads block on this event so they start together.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                # NOTE(review): a failing assert inside a worker thread may
                # not fail the test directly; confirm how support.start_threads
                # / threading.excepthook report it.
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        # A tiny switch interval forces frequent thread switches, making
        # races inside lru_cache more likely to surface.
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: it is not understood why the pure-Python counts may
                # fall short of exact equality, so only upper bounds are
                # checked for it.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                # One miss per distinct k; every other call is a hit.
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            # Restore directly via sys; setting went through support.
            sys.setswitchinterval(orig_si)
1583
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        # For each of m rounds, all n workers call f(i) with the same i.
        # f blocks on ``pause`` until the main thread also arrives, so all
        # n calls are in flight together and every one of them misses.
        # After each round the cache holds one more entry, and there have
        # been n more misses and no hits.
        n, m = 5, 7
        # Each barrier synchronizes the n workers plus the main thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                # Barriers are reset between uses so they can be reused in
                # the next round; the exact order here is load-bearing.
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1610
1611    def test_lru_cache_threaded3(self):
1612        @self.module.lru_cache(maxsize=2)
1613        def f(x):
1614            time.sleep(.01)
1615            return 3 * x
1616        def test(i, x):
1617            with self.subTest(thread=i):
1618                self.assertEqual(f(x), 3 * x, i)
1619        threads = [threading.Thread(target=test, args=(i, v))
1620                   for i, v in enumerate([1, 2, 2, 3, 2])]
1621        with support.start_threads(threads):
1622            pass
1623
1624    def test_need_for_rlock(self):
1625        # This will deadlock on an LRU cache that uses a regular lock
1626
1627        @self.module.lru_cache(maxsize=10)
1628        def test_func(x):
1629            'Used to demonstrate a reentrant lru_cache call within a single thread'
1630            return x
1631
1632        class DoubleEq:
1633            'Demonstrate a reentrant lru_cache call within a single thread'
1634            def __init__(self, x):
1635                self.x = x
1636            def __hash__(self):
1637                return self.x
1638            def __eq__(self, other):
1639                if self.x == 2:
1640                    test_func(DoubleEq(1))
1641                return self.x == other.x
1642
1643        test_func(DoubleEq(1))                      # Load the cache
1644        test_func(DoubleEq(2))                      # Load the cache
1645        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
1646                         DoubleEq(2))               # Verify the correct return value
1647
1648    def test_lru_method(self):
1649        class X(int):
1650            f_cnt = 0
1651            @self.module.lru_cache(2)
1652            def f(self, x):
1653                self.f_cnt += 1
1654                return x*10+self
1655        a = X(5)
1656        b = X(5)
1657        c = X(7)
1658        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1659
1660        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1661            self.assertEqual(a.f(x), x*10 + 5)
1662        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1663        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1664
1665        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1666            self.assertEqual(b.f(x), x*10 + 5)
1667        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1668        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1669
1670        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1671            self.assertEqual(c.f(x), x*10 + 7)
1672        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1673        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1674
1675        self.assertEqual(a.f.cache_info(), X.f.cache_info())
1676        self.assertEqual(b.f.cache_info(), X.f.cache_info())
1677        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1678
1679    def test_pickle(self):
1680        cls = self.__class__
1681        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1682            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1683                with self.subTest(proto=proto, func=f):
1684                    f_copy = pickle.loads(pickle.dumps(f, proto))
1685                    self.assertIs(f_copy, f)
1686
1687    def test_copy(self):
1688        cls = self.__class__
1689        def orig(x, y):
1690            return 3 * x + y
1691        part = self.module.partial(orig, 2)
1692        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1693                 self.module.lru_cache(2)(part))
1694        for f in funcs:
1695            with self.subTest(func=f):
1696                f_copy = copy.copy(f)
1697                self.assertIs(f_copy, f)
1698
1699    def test_deepcopy(self):
1700        cls = self.__class__
1701        def orig(x, y):
1702            return 3 * x + y
1703        part = self.module.partial(orig, 2)
1704        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1705                 self.module.lru_cache(2)(part))
1706        for f in funcs:
1707            with self.subTest(func=f):
1708                f_copy = copy.deepcopy(f)
1709                self.assertIs(f_copy, f)
1710
1711    def test_lru_cache_parameters(self):
1712        @self.module.lru_cache(maxsize=2)
1713        def f():
1714            return 1
1715        self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1716
1717        @self.module.lru_cache(maxsize=1000, typed=True)
1718        def f():
1719            return 1
1720        self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1721
1722    def test_lru_cache_weakrefable(self):
1723        @self.module.lru_cache
1724        def test_function(x):
1725            return x
1726
1727        class A:
1728            @self.module.lru_cache
1729            def test_method(self, x):
1730                return (self, x)
1731
1732            @staticmethod
1733            @self.module.lru_cache
1734            def test_staticmethod(x):
1735                return (self, x)
1736
1737        refs = [weakref.ref(test_function),
1738                weakref.ref(A.test_method),
1739                weakref.ref(A.test_staticmethod)]
1740
1741        for ref in refs:
1742            self.assertIsNotNone(ref())
1743
1744        del A
1745        del test_function
1746        gc.collect()
1747
1748        for ref in refs:
1749            self.assertIsNone(ref())
1750
1751
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Return 3*x + y, cached by the pure-Python lru_cache.

    Exposed to the LRU tests through TestLRUPy.cached_func; test_pickle
    and test_copy expect it to round-trip to the very same object.
    """
    tripled = 3 * x
    return tripled + y
1755
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Return 3*x + y, cached by the C-accelerated lru_cache.

    Exposed to the LRU tests through TestLRUC.cached_func; test_pickle
    and test_copy expect it to round-trip to the very same object.
    """
    tripled = 3 * x
    return tripled + y
1759
1760
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the pure-Python functools."""

    module = py_functools
    # Kept inside a 1-tuple; presumably so that attribute lookup never
    # turns the cached function into a bound method.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1773
1774
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the C-accelerated functools."""

    module = c_functools
    # Kept inside a 1-tuple; presumably so that attribute lookup never
    # turns the cached function into a bound method.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1787
1788
1789class TestSingleDispatch(unittest.TestCase):
1790    def test_simple_overloads(self):
1791        @functools.singledispatch
1792        def g(obj):
1793            return "base"
1794        def g_int(i):
1795            return "integer"
1796        g.register(int, g_int)
1797        self.assertEqual(g("str"), "base")
1798        self.assertEqual(g(1), "integer")
1799        self.assertEqual(g([1,2,3]), "base")
1800
1801    def test_mro(self):
1802        @functools.singledispatch
1803        def g(obj):
1804            return "base"
1805        class A:
1806            pass
1807        class C(A):
1808            pass
1809        class B(A):
1810            pass
1811        class D(C, B):
1812            pass
1813        def g_A(a):
1814            return "A"
1815        def g_B(b):
1816            return "B"
1817        g.register(A, g_A)
1818        g.register(B, g_B)
1819        self.assertEqual(g(A()), "A")
1820        self.assertEqual(g(B()), "B")
1821        self.assertEqual(g(C()), "A")
1822        self.assertEqual(g(D()), "B")
1823
1824    def test_register_decorator(self):
1825        @functools.singledispatch
1826        def g(obj):
1827            return "base"
1828        @g.register(int)
1829        def g_int(i):
1830            return "int %s" % (i,)
1831        self.assertEqual(g(""), "base")
1832        self.assertEqual(g(12), "int 12")
1833        self.assertIs(g.dispatch(int), g_int)
1834        self.assertIs(g.dispatch(object), g.dispatch(str))
1835        # Note: in the assert above this is not g.
1836        # @singledispatch returns the wrapper.
1837
1838    def test_wrapping_attributes(self):
1839        @functools.singledispatch
1840        def g(obj):
1841            "Simple test"
1842            return "Test"
1843        self.assertEqual(g.__name__, "g")
1844        if sys.flags.optimize < 2:
1845            self.assertEqual(g.__doc__, "Simple test")
1846
1847    @unittest.skipUnless(decimal, 'requires _decimal')
1848    @support.cpython_only
1849    def test_c_classes(self):
1850        @functools.singledispatch
1851        def g(obj):
1852            return "base"
1853        @g.register(decimal.DecimalException)
1854        def _(obj):
1855            return obj.args
1856        subn = decimal.Subnormal("Exponent < Emin")
1857        rnd = decimal.Rounded("Number got rounded")
1858        self.assertEqual(g(subn), ("Exponent < Emin",))
1859        self.assertEqual(g(rnd), ("Number got rounded",))
1860        @g.register(decimal.Subnormal)
1861        def _(obj):
1862            return "Too small to care."
1863        self.assertEqual(g(subn), "Too small to care.")
1864        self.assertEqual(g(rnd), ("Number got rounded",))
1865
1866    def test_compose_mro(self):
1867        # None of the examples in this test depend on haystack ordering.
1868        c = collections.abc
1869        mro = functools._compose_mro
1870        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1871        for haystack in permutations(bases):
1872            m = mro(dict, haystack)
1873            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1874                                 c.Collection, c.Sized, c.Iterable,
1875                                 c.Container, object])
1876        bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
1877        for haystack in permutations(bases):
1878            m = mro(collections.ChainMap, haystack)
1879            self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
1880                                 c.Collection, c.Sized, c.Iterable,
1881                                 c.Container, object])
1882
1883        # If there's a generic function with implementations registered for
1884        # both Sized and Container, passing a defaultdict to it results in an
1885        # ambiguous dispatch which will cause a RuntimeError (see
1886        # test_mro_conflicts).
1887        bases = [c.Container, c.Sized, str]
1888        for haystack in permutations(bases):
1889            m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1890            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1891                                 c.Container, object])
1892
1893        # MutableSequence below is registered directly on D. In other words, it
1894        # precedes MutableMapping which means single dispatch will always
1895        # choose MutableSequence here.
1896        class D(collections.defaultdict):
1897            pass
1898        c.MutableSequence.register(D)
1899        bases = [c.MutableSequence, c.MutableMapping]
1900        for haystack in permutations(bases):
1901            m = mro(D, bases)
1902            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
1903                                 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
1904                                 c.Collection, c.Sized, c.Iterable, c.Container,
1905                                 object])
1906
1907        # Container and Callable are registered on different base classes and
1908        # a generic function supporting both should always pick the Callable
1909        # implementation if a C instance is passed.
1910        class C(collections.defaultdict):
1911            def __call__(self):
1912                pass
1913        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1914        for haystack in permutations(bases):
1915            m = mro(C, haystack)
1916            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
1917                                 c.Collection, c.Sized, c.Iterable,
1918                                 c.Container, object])
1919
1920    def test_register_abc(self):
1921        c = collections.abc
1922        d = {"a": "b"}
1923        l = [1, 2, 3]
1924        s = {object(), None}
1925        f = frozenset(s)
1926        t = (1, 2, 3)
1927        @functools.singledispatch
1928        def g(obj):
1929            return "base"
1930        self.assertEqual(g(d), "base")
1931        self.assertEqual(g(l), "base")
1932        self.assertEqual(g(s), "base")
1933        self.assertEqual(g(f), "base")
1934        self.assertEqual(g(t), "base")
1935        g.register(c.Sized, lambda obj: "sized")
1936        self.assertEqual(g(d), "sized")
1937        self.assertEqual(g(l), "sized")
1938        self.assertEqual(g(s), "sized")
1939        self.assertEqual(g(f), "sized")
1940        self.assertEqual(g(t), "sized")
1941        g.register(c.MutableMapping, lambda obj: "mutablemapping")
1942        self.assertEqual(g(d), "mutablemapping")
1943        self.assertEqual(g(l), "sized")
1944        self.assertEqual(g(s), "sized")
1945        self.assertEqual(g(f), "sized")
1946        self.assertEqual(g(t), "sized")
1947        g.register(collections.ChainMap, lambda obj: "chainmap")
1948        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
1949        self.assertEqual(g(l), "sized")
1950        self.assertEqual(g(s), "sized")
1951        self.assertEqual(g(f), "sized")
1952        self.assertEqual(g(t), "sized")
1953        g.register(c.MutableSequence, lambda obj: "mutablesequence")
1954        self.assertEqual(g(d), "mutablemapping")
1955        self.assertEqual(g(l), "mutablesequence")
1956        self.assertEqual(g(s), "sized")
1957        self.assertEqual(g(f), "sized")
1958        self.assertEqual(g(t), "sized")
1959        g.register(c.MutableSet, lambda obj: "mutableset")
1960        self.assertEqual(g(d), "mutablemapping")
1961        self.assertEqual(g(l), "mutablesequence")
1962        self.assertEqual(g(s), "mutableset")
1963        self.assertEqual(g(f), "sized")
1964        self.assertEqual(g(t), "sized")
1965        g.register(c.Mapping, lambda obj: "mapping")
1966        self.assertEqual(g(d), "mutablemapping")  # not specific enough
1967        self.assertEqual(g(l), "mutablesequence")
1968        self.assertEqual(g(s), "mutableset")
1969        self.assertEqual(g(f), "sized")
1970        self.assertEqual(g(t), "sized")
1971        g.register(c.Sequence, lambda obj: "sequence")
1972        self.assertEqual(g(d), "mutablemapping")
1973        self.assertEqual(g(l), "mutablesequence")
1974        self.assertEqual(g(s), "mutableset")
1975        self.assertEqual(g(f), "sized")
1976        self.assertEqual(g(t), "sequence")
1977        g.register(c.Set, lambda obj: "set")
1978        self.assertEqual(g(d), "mutablemapping")
1979        self.assertEqual(g(l), "mutablesequence")
1980        self.assertEqual(g(s), "mutableset")
1981        self.assertEqual(g(f), "set")
1982        self.assertEqual(g(t), "sequence")
1983        g.register(dict, lambda obj: "dict")
1984        self.assertEqual(g(d), "dict")
1985        self.assertEqual(g(l), "mutablesequence")
1986        self.assertEqual(g(s), "mutableset")
1987        self.assertEqual(g(f), "set")
1988        self.assertEqual(g(t), "sequence")
1989        g.register(list, lambda obj: "list")
1990        self.assertEqual(g(d), "dict")
1991        self.assertEqual(g(l), "list")
1992        self.assertEqual(g(s), "mutableset")
1993        self.assertEqual(g(f), "set")
1994        self.assertEqual(g(t), "sequence")
1995        g.register(set, lambda obj: "concrete-set")
1996        self.assertEqual(g(d), "dict")
1997        self.assertEqual(g(l), "list")
1998        self.assertEqual(g(s), "concrete-set")
1999        self.assertEqual(g(f), "set")
2000        self.assertEqual(g(t), "sequence")
2001        g.register(frozenset, lambda obj: "frozen-set")
2002        self.assertEqual(g(d), "dict")
2003        self.assertEqual(g(l), "list")
2004        self.assertEqual(g(s), "concrete-set")
2005        self.assertEqual(g(f), "frozen-set")
2006        self.assertEqual(g(t), "sequence")
2007        g.register(tuple, lambda obj: "tuple")
2008        self.assertEqual(g(d), "dict")
2009        self.assertEqual(g(l), "list")
2010        self.assertEqual(g(s), "concrete-set")
2011        self.assertEqual(g(f), "frozen-set")
2012        self.assertEqual(g(t), "tuple")
2013
2014    def test_c3_abc(self):
2015        c = collections.abc
2016        mro = functools._c3_mro
2017        class A(object):
2018            pass
2019        class B(A):
2020            def __len__(self):
2021                return 0   # implies Sized
2022        @c.Container.register
2023        class C(object):
2024            pass
2025        class D(object):
2026            pass   # unrelated
2027        class X(D, C, B):
2028            def __call__(self):
2029                pass   # implies Callable
2030        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2031        for abcs in permutations([c.Sized, c.Callable, c.Container]):
2032            self.assertEqual(mro(X, abcs=abcs), expected)
2033        # unrelated ABCs don't appear in the resulting MRO
2034        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2035        self.assertEqual(mro(X, abcs=many_abcs), expected)
2036
2037    def test_false_meta(self):
2038        # see issue23572
2039        class MetaA(type):
2040            def __len__(self):
2041                return 0
2042        class A(metaclass=MetaA):
2043            pass
2044        class AA(A):
2045            pass
2046        @functools.singledispatch
2047        def fun(a):
2048            return 'base A'
2049        @fun.register(A)
2050        def _(a):
2051            return 'fun A'
2052        aa = AA()
2053        self.assertEqual(fun(aa), 'fun A')
2054
    def test_mro_conflicts(self):
        # Checks how singledispatch picks an implementation (or reports an
        # ambiguity) when a class matches several registered ABCs: some
        # through its explicit __mro__, some through ABC.register(), and some
        # inferred from special methods such as __len__().
        # NOTE: the ABC.register() calls below permanently add virtual
        # subclasses to collections.abc ABCs; only the locally defined
        # classes are registered.
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        # O subclasses Sized explicitly.
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # P has no special methods, so it matches only the ABCs it is
        # registered with.
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        # P is now registered with both Iterable and Container, and neither
        # ABC is a subclass of the other, so dispatch is ambiguous.  The
        # message may name the two ABCs in either order.
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        # Q is like O, but registered with Iterable instead of Container.
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        # h has implementations for Sized and Container only.
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # MutableSequence is registered directly on R.  R also inherits
        # MutableMapping from defaultdict, and the assertion below expects
        # the directly registered ABC to win.
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        # T has Sized explicitly in its bases, so a later Container
        # registration does not make h ambiguous for it.
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        # U is Sized only implicitly (through __len__).
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # V lists Sized before S in its bases.  Registering V with Container
        # changes j's choice from S's implementation to Container's.
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
2182
    def test_cache_invalidation(self):
        # Observes singledispatch's dispatch cache directly.  While the
        # swap_attr below is active, weakref.WeakKeyDictionary() returns the
        # TracingDict `td`, which g then uses as its cache.  The test checks
        # exactly when entries are read, written and cleared.
        from collections import UserDict
        import weakref

        class TracingDict(UserDict):
            # Records every key passed to __getitem__ (get_ops) and
            # __setitem__ (set_ops).  clear() empties the data without
            # recording anything.
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()

        td = TracingDict()
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            # The first dispatch for each type fills the cache (set_ops grows;
            # get_ops stays empty because the lookup missed).
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeated dispatches are served from the cache (get_ops grows).
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # Registering an implementation empties the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            # Registering an ABC implementation clears the cache, like any
            # other registration.
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # dispatch() reads the cache just like calling g does.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # Now that g has an ABC registered, an ABC registration elsewhere
            # invalidates the cache lazily.  The stale entries stay until the
            # next dispatch, which clears them and re-caches only that type.
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            # _clear_cache() empties the cache explicitly.
            g._clear_cache()
            self.assertEqual(len(td), 0)
2284
2285    def test_annotations(self):
2286        @functools.singledispatch
2287        def i(arg):
2288            return "base"
2289        @i.register
2290        def _(arg: collections.abc.Mapping):
2291            return "mapping"
2292        @i.register
2293        def _(arg: "collections.abc.Sequence"):
2294            return "sequence"
2295        self.assertEqual(i(None), "base")
2296        self.assertEqual(i({"a": 1}), "mapping")
2297        self.assertEqual(i([1, 2, 3]), "sequence")
2298        self.assertEqual(i((1, 2, 3)), "sequence")
2299        self.assertEqual(i("str"), "sequence")
2300
2301        # Registering classes as callables doesn't work with annotations,
2302        # you need to pass the type explicitly.
2303        @i.register(str)
2304        class _:
2305            def __init__(self, arg):
2306                self.arg = arg
2307
2308            def __eq__(self, other):
2309                return self.arg == other
2310        self.assertEqual(i("str"), "str")
2311
2312    def test_method_register(self):
2313        class A:
2314            @functools.singledispatchmethod
2315            def t(self, arg):
2316                self.arg = "base"
2317            @t.register(int)
2318            def _(self, arg):
2319                self.arg = "int"
2320            @t.register(str)
2321            def _(self, arg):
2322                self.arg = "str"
2323        a = A()
2324
2325        a.t(0)
2326        self.assertEqual(a.arg, "int")
2327        aa = A()
2328        self.assertFalse(hasattr(aa, 'arg'))
2329        a.t('')
2330        self.assertEqual(a.arg, "str")
2331        aa = A()
2332        self.assertFalse(hasattr(aa, 'arg'))
2333        a.t(0.0)
2334        self.assertEqual(a.arg, "base")
2335        aa = A()
2336        self.assertFalse(hasattr(aa, 'arg'))
2337
2338    def test_staticmethod_register(self):
2339        class A:
2340            @functools.singledispatchmethod
2341            @staticmethod
2342            def t(arg):
2343                return arg
2344            @t.register(int)
2345            @staticmethod
2346            def _(arg):
2347                return isinstance(arg, int)
2348            @t.register(str)
2349            @staticmethod
2350            def _(arg):
2351                return isinstance(arg, str)
2352        a = A()
2353
2354        self.assertTrue(A.t(0))
2355        self.assertTrue(A.t(''))
2356        self.assertEqual(A.t(0.0), 0.0)
2357
2358    def test_classmethod_register(self):
2359        class A:
2360            def __init__(self, arg):
2361                self.arg = arg
2362
2363            @functools.singledispatchmethod
2364            @classmethod
2365            def t(cls, arg):
2366                return cls("base")
2367            @t.register(int)
2368            @classmethod
2369            def _(cls, arg):
2370                return cls("int")
2371            @t.register(str)
2372            @classmethod
2373            def _(cls, arg):
2374                return cls("str")
2375
2376        self.assertEqual(A.t(0).arg, "int")
2377        self.assertEqual(A.t('').arg, "str")
2378        self.assertEqual(A.t(0.0).arg, "base")
2379
2380    def test_callable_register(self):
2381        class A:
2382            def __init__(self, arg):
2383                self.arg = arg
2384
2385            @functools.singledispatchmethod
2386            @classmethod
2387            def t(cls, arg):
2388                return cls("base")
2389
2390        @A.t.register(int)
2391        @classmethod
2392        def _(cls, arg):
2393            return cls("int")
2394        @A.t.register(str)
2395        @classmethod
2396        def _(cls, arg):
2397            return cls("str")
2398
2399        self.assertEqual(A.t(0).arg, "int")
2400        self.assertEqual(A.t('').arg, "str")
2401        self.assertEqual(A.t(0.0).arg, "base")
2402
2403    def test_abstractmethod_register(self):
2404        class Abstract(metaclass=abc.ABCMeta):
2405
2406            @functools.singledispatchmethod
2407            @abc.abstractmethod
2408            def add(self, x, y):
2409                pass
2410
2411        self.assertTrue(Abstract.add.__isabstractmethod__)
2412        self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)
2413
2414        with self.assertRaises(TypeError):
2415            Abstract()
2416
2417    def test_type_ann_register(self):
2418        class A:
2419            @functools.singledispatchmethod
2420            def t(self, arg):
2421                return "base"
2422            @t.register
2423            def _(self, arg: int):
2424                return "int"
2425            @t.register
2426            def _(self, arg: str):
2427                return "str"
2428        a = A()
2429
2430        self.assertEqual(a.t(0), "int")
2431        self.assertEqual(a.t(''), "str")
2432        self.assertEqual(a.t(0.0), "base")
2433
2434    def test_staticmethod_type_ann_register(self):
2435        class A:
2436            @functools.singledispatchmethod
2437            @staticmethod
2438            def t(arg):
2439                return arg
2440            @t.register
2441            @staticmethod
2442            def _(arg: int):
2443                return isinstance(arg, int)
2444            @t.register
2445            @staticmethod
2446            def _(arg: str):
2447                return isinstance(arg, str)
2448        a = A()
2449
2450        self.assertTrue(A.t(0))
2451        self.assertTrue(A.t(''))
2452        self.assertEqual(A.t(0.0), 0.0)
2453
2454    def test_classmethod_type_ann_register(self):
2455        class A:
2456            def __init__(self, arg):
2457                self.arg = arg
2458
2459            @functools.singledispatchmethod
2460            @classmethod
2461            def t(cls, arg):
2462                return cls("base")
2463            @t.register
2464            @classmethod
2465            def _(cls, arg: int):
2466                return cls("int")
2467            @t.register
2468            @classmethod
2469            def _(cls, arg: str):
2470                return cls("str")
2471
2472        self.assertEqual(A.t(0).arg, "int")
2473        self.assertEqual(A.t('').arg, "str")
2474        self.assertEqual(A.t(0.0).arg, "base")
2475
2476    def test_method_wrapping_attributes(self):
2477        class A:
2478            @functools.singledispatchmethod
2479            def func(self, arg: int) -> str:
2480                """My function docstring"""
2481                return str(arg)
2482            @functools.singledispatchmethod
2483            @classmethod
2484            def cls_func(cls, arg: int) -> str:
2485                """My function docstring"""
2486                return str(arg)
2487            @functools.singledispatchmethod
2488            @staticmethod
2489            def static_func(arg: int) -> str:
2490                """My function docstring"""
2491                return str(arg)
2492
2493        for meth in (
2494            A.func,
2495            A().func,
2496            A.cls_func,
2497            A().cls_func,
2498            A.static_func,
2499            A().static_func
2500        ):
2501            with self.subTest(meth=meth):
2502                self.assertEqual(meth.__doc__, 'My function docstring')
2503                self.assertEqual(meth.__annotations__['arg'], int)
2504
2505        self.assertEqual(A.func.__name__, 'func')
2506        self.assertEqual(A().func.__name__, 'func')
2507        self.assertEqual(A.cls_func.__name__, 'cls_func')
2508        self.assertEqual(A().cls_func.__name__, 'cls_func')
2509        self.assertEqual(A.static_func.__name__, 'static_func')
2510        self.assertEqual(A().static_func.__name__, 'static_func')
2511
2512    def test_double_wrapped_methods(self):
2513        def classmethod_friendly_decorator(func):
2514            wrapped = func.__func__
2515            @classmethod
2516            @functools.wraps(wrapped)
2517            def wrapper(*args, **kwargs):
2518                return wrapped(*args, **kwargs)
2519            return wrapper
2520
2521        class WithoutSingleDispatch:
2522            @classmethod
2523            @contextlib.contextmanager
2524            def cls_context_manager(cls, arg: int) -> str:
2525                try:
2526                    yield str(arg)
2527                finally:
2528                    return 'Done'
2529
2530            @classmethod_friendly_decorator
2531            @classmethod
2532            def decorated_classmethod(cls, arg: int) -> str:
2533                return str(arg)
2534
2535        class WithSingleDispatch:
2536            @functools.singledispatchmethod
2537            @classmethod
2538            @contextlib.contextmanager
2539            def cls_context_manager(cls, arg: int) -> str:
2540                """My function docstring"""
2541                try:
2542                    yield str(arg)
2543                finally:
2544                    return 'Done'
2545
2546            @functools.singledispatchmethod
2547            @classmethod_friendly_decorator
2548            @classmethod
2549            def decorated_classmethod(cls, arg: int) -> str:
2550                """My function docstring"""
2551                return str(arg)
2552
2553        # These are sanity checks
2554        # to test the test itself is working as expected
2555        with WithoutSingleDispatch.cls_context_manager(5) as foo:
2556            without_single_dispatch_foo = foo
2557
2558        with WithSingleDispatch.cls_context_manager(5) as foo:
2559            single_dispatch_foo = foo
2560
2561        self.assertEqual(without_single_dispatch_foo, single_dispatch_foo)
2562        self.assertEqual(single_dispatch_foo, '5')
2563
2564        self.assertEqual(
2565            WithoutSingleDispatch.decorated_classmethod(5),
2566            WithSingleDispatch.decorated_classmethod(5)
2567        )
2568
2569        self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5')
2570
2571        # Behavioural checks now follow
2572        for method_name in ('cls_context_manager', 'decorated_classmethod'):
2573            with self.subTest(method=method_name):
2574                self.assertEqual(
2575                    getattr(WithSingleDispatch, method_name).__name__,
2576                    getattr(WithoutSingleDispatch, method_name).__name__
2577                )
2578
2579                self.assertEqual(
2580                    getattr(WithSingleDispatch(), method_name).__name__,
2581                    getattr(WithoutSingleDispatch(), method_name).__name__
2582                )
2583
2584        for meth in (
2585            WithSingleDispatch.cls_context_manager,
2586            WithSingleDispatch().cls_context_manager,
2587            WithSingleDispatch.decorated_classmethod,
2588            WithSingleDispatch().decorated_classmethod
2589        ):
2590            with self.subTest(meth=meth):
2591                self.assertEqual(meth.__doc__, 'My function docstring')
2592                self.assertEqual(meth.__annotations__['arg'], int)
2593
2594        self.assertEqual(
2595            WithSingleDispatch.cls_context_manager.__name__,
2596            'cls_context_manager'
2597        )
2598        self.assertEqual(
2599            WithSingleDispatch().cls_context_manager.__name__,
2600            'cls_context_manager'
2601        )
2602        self.assertEqual(
2603            WithSingleDispatch.decorated_classmethod.__name__,
2604            'decorated_classmethod'
2605        )
2606        self.assertEqual(
2607            WithSingleDispatch().decorated_classmethod.__name__,
2608            'decorated_classmethod'
2609        )
2610
2611    def test_invalid_registrations(self):
2612        msg_prefix = "Invalid first argument to `register()`: "
2613        msg_suffix = (
2614            ". Use either `@register(some_class)` or plain `@register` on an "
2615            "annotated function."
2616        )
2617        @functools.singledispatch
2618        def i(arg):
2619            return "base"
2620        with self.assertRaises(TypeError) as exc:
2621            @i.register(42)
2622            def _(arg):
2623                return "I annotated with a non-type"
2624        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2625        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2626        with self.assertRaises(TypeError) as exc:
2627            @i.register
2628            def _(arg):
2629                return "I forgot to annotate"
2630        self.assertTrue(str(exc.exception).startswith(msg_prefix +
2631            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2632        ))
2633        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2634
2635        with self.assertRaises(TypeError) as exc:
2636            @i.register
2637            def _(arg: typing.Iterable[str]):
2638                # At runtime, dispatching on generics is impossible.
2639                # When registering implementations with singledispatch, avoid
2640                # types from `typing`. Instead, annotate with regular types
2641                # or ABCs.
2642                return "I annotated with a generic collection"
2643        self.assertTrue(str(exc.exception).startswith(
2644            "Invalid annotation for 'arg'."
2645        ))
2646        self.assertTrue(str(exc.exception).endswith(
2647            'typing.Iterable[str] is not a class.'
2648        ))
2649
2650    def test_invalid_positional_argument(self):
2651        @functools.singledispatch
2652        def f(*args):
2653            pass
2654        msg = 'f requires at least 1 positional argument'
2655        with self.assertRaisesRegex(TypeError, msg):
2656            f()
2657
2658
class CachedCostItem:
    """Item whose ``cost`` is computed once, under a lock, then cached."""

    # Class-level starting value; the first increment stores an instance
    # attribute that shadows it.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2671
2672
class OptionallyCachedCostItem:
    """Exposes one getter both uncached (``get_cost``) and cached (``cached_cost``)."""

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    # The cached attribute deliberately has a different name from the
    # function it wraps.
    cached_cost = py_functools.cached_property(get_cost)
2682
2683
class CachedCostItemWait:
    """Item whose ``cost`` getter waits on an event before incrementing."""

    def __init__(self, event):
        self.event = event
        self.lock = py_functools.RLock()
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        # Wait (for at most one second) until the caller sets the event.
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2697
2698
class CachedCostItemWithSlots:
    """Slotted class without ``__dict__``, so ``cached_property`` cannot cache.

    Accessing ``cost`` is expected to raise ``TypeError`` before the getter
    body ever runs.
    """

    # A one-element tuple.  The previous ``('_cost')`` was a bare string,
    # which only worked because ``__slots__`` special-cases a single string;
    # code that iterates ``__slots__`` would have seen individual characters.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2708
2709
class TestCachedProperty(unittest.TestCase):
    """Tests for the pure-Python ``functools.cached_property``."""

    def test_cached(self):
        obj = CachedCostItem()
        # First access computes 2; later accesses reuse it (never 3).
        for _ in range(2):
            self.assertEqual(obj.cost, 2)

    def test_cached_attribute_name_differs_from_func_name(self):
        obj = OptionallyCachedCostItem()
        self.assertEqual(obj.get_cost(), 2)
        self.assertEqual(obj.cached_cost, 3)  # computed and stored now
        self.assertEqual(obj.get_cost(), 4)   # direct calls bypass the cache
        self.assertEqual(obj.cached_cost, 3)  # still the stored value

    def test_threaded(self):
        start = threading.Event()
        obj = CachedCostItemWait(start)

        thread_count = 3

        saved_interval = sys.getswitchinterval()
        # A tiny switch interval makes the threads interleave aggressively.
        sys.setswitchinterval(1e-6)
        try:
            workers = [
                threading.Thread(target=lambda: obj.cost)
                for _ in range(thread_count)
            ]
            with support.start_threads(workers):
                start.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # Starting from 1, a value of 2 means the getter ran exactly once.
        self.assertEqual(obj.cost, 2)

    def test_object_with_slots(self):
        obj = CachedCostItemWithSlots()
        expected = ("No '__dict__' attribute on 'CachedCostItemWithSlots' "
                    "instance to cache 'cost' property.")
        with self.assertRaisesRegex(TypeError, expected):
            obj.cost

    def test_immutable_dict(self):
        # The property lives on the metaclass, so the cache target is the
        # class's own __dict__, which does not accept item assignment.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        expected = ("The '__dict__' attribute on 'MyMeta' instance does not "
                    "support item assignment for caching 'prop' property.")
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The RuntimeError from class creation chains the real TypeError.
        expected = TypeError(
            "Cannot assign the same cached_property to two different names ('a' and 'b')."
        )
        self.assertEqual(str(ctx.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        calls = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal calls
            calls += 1
            return calls

        class A:
            cp = _cp

        class B:
            cp = _cp

        first, second = A(), B()

        # Each instance caches independently; the first keeps its value.
        self.assertEqual(first.cp, 1)
        self.assertEqual(second.cp, 2)
        self.assertEqual(first.cp, 1)

    def test_set_name_not_called(self):
        prop = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Assigning after class creation means __set_name__ never runs.
        Foo.cp = prop

        expected = ("Cannot use cached_property instance without calling "
                    "__set_name__ on it.")
        with self.assertRaisesRegex(TypeError, expected):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        docstring = CachedCostItem.cost.__doc__
        self.assertEqual(docstring, "The cost of the item.")
2822
2823
# Allow running this test module directly as a script via unittest's runner.
if __name__ == '__main__':
    unittest.main()
2826