1import abc
2import builtins
3import collections
4import collections.abc
5import copy
6from itertools import permutations
7import pickle
8from random import choice
9import sys
10from test import support
11import threading
12import time
13import typing
14import unittest
15import unittest.mock
16import os
17import weakref
18import gc
19from weakref import proxy
20import contextlib
21
22from test.support import import_helper
23from test.support import threading_helper
24from test.support.script_helper import assert_python_ok
25
26import functools
27
28py_functools = import_helper.import_fresh_module('functools',
29                                                 blocked=['_functools'])
30c_functools = import_helper.import_fresh_module('functools')
31
32decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
33
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    The current entry is looked up before anything is changed, so an
    unknown *name* raises KeyError and leaves ``sys.modules`` untouched.
    The saved entry is put back when the ``with`` block exits, whether it
    exits normally or with an exception.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Unconditional restore: a failing test must not leak the stand-in.
        sys.modules[name] = saved
42
def capture(*args, **kw):
    """Echo back the arguments this function was called with.

    Returns a pair ``(args, kw)``: the tuple of positional arguments and
    the dict of keyword arguments, exactly as received.
    """
    received = (args, kw)
    return received
46
47
def signature(part):
    """Summarize a partial object as a comparable 4-tuple.

    Returns ``(func, args, keywords, __dict__)`` taken from *part*.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
51
class MyTuple(tuple):
    """Plain tuple subclass; used as partial ``args`` in __setstate__ tests."""
54
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by returning a list."""

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
58
class MyDict(dict):
    """Plain dict subclass; used as partial ``keywords`` in __setstate__ tests."""
61
62
class TestPartial:
    """Tests shared by every ``partial`` implementation under test.

    This is a mixin: concrete subclasses also inherit from
    ``unittest.TestCase`` and supply two class attributes:

    * ``partial`` -- the partial type being tested;
    * ``AllowPickle`` -- a context-manager class entered around the
      pickling tests.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is constructed or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        support.gc_collect()  # For PyPy or other GCs.
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial should be flattened into a single partial
        # over the innermost function.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept both.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        # Self-references in func, args or keywords are shown as '...'.
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            # Break the reference cycle so the object can be freed.
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        # A shallow copy shares args, keywords and instance attributes.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        # A deep copy duplicates args, keywords and instance attributes.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # A keywords value of None is normalized to an empty dict.
        self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        # The state must be a 4-tuple (func, args, keywords, dict) with a
        # callable func, a tuple of args and a dict (or None) of keywords.
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are converted to exact
        # tuple/dict, and a subclass's overridden __add__ is not used.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        # A partial whose func is itself cannot be pickled; one that refers
        # to itself only through args or keywords round-trips with the
        # cycle preserved.
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A non-tuple sequence that hands out fresh objects from
        # __getitem__ must simply be rejected with TypeError.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
387
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C ``partial`` type."""
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # No module swapping is done for the C type: this is a no-op
        # context manager that never suppresses exceptions.
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # __dict__ may not be deleted either.
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
438
439
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python ``partial``."""
    partial = py_functools.partial

    class AllowPickle:
        """Point ``sys.modules['functools']`` at the pure-Python module.

        Pickle resolves classes by module and qualified name, so while the
        ``with`` block runs, ``functools.partial`` must name the class
        under test.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, *exc_info):
            return self._swap.__exit__(*exc_info)
450
if c_functools:
    # Only defined when the C accelerator is present.
    class CPartialSubclass(c_functools.partial):
        """Subclass of the C ``partial`` type that adds nothing."""
454
class PyPartialSubclass(py_functools.partial):
    """Subclass of the pure-Python ``partial`` class that adds nothing."""
457
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial tests against a trivial subclass."""
    if c_functools:
        partial = CPartialSubclass

    # Nested partials built from a subclass are not flattened, so the
    # inherited flattening test does not apply here and is disabled.
    test_nested_optimization = None
465
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial tests against a trivial subclass."""
    partial = PyPartialSubclass
468
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    # Each attribute pre-binds a different mix of arguments to ``capture``,
    # so the tests can see exactly what reaches the underlying callable.
    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Build a genuine abstract class (ABCMeta as its *metaclass*) rather
        # than a subclass of ABCMeta, so that ABCMeta's own collection of
        # abstract names is exercised as well.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # partialmethod forwards __isabstractmethod__ from the wrapped
        # function, so ABCMeta records both names as abstract.
        self.assertEqual(Abstract.__abstractmethods__,
                         frozenset({'add', 'add5'}))

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
597
598
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper; TestWraps reuses the helpers."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # "assigned" attributes must be copied across by identity.
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        # Every entry of each "updated" mapping on *wrapped* must now also be
        # present, by identity, in the matching mapping on *wrapper*.  The
        # exception is __dict__['__wrapped__'], which the update code
        # overwrites; it is checked separately at the end.
        for attr in updated:
            ours = getattr(wrapper, attr)
            theirs = getattr(wrapped, attr)
            for key in theirs:
                overwritten = attr == "__dict__" and key == "__wrapped__"
                if not overwritten:
                    self.assertIs(theirs[key], ours[key])
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        # The wrapped function deliberately carries an annotation, a
        # docstring, an extra attribute and a bogus __wrapped__ value.
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b:'This is the prior annotation'):
            ...

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # With empty assigned/updated tuples nothing but __wrapped__ is set.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        def wrapper():
            ...

        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            ...
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            ...
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            ...

        def wrapper():
            ...
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        # Attributes missing from the wrapped object are silently skipped.
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # ...but the wrapper must provide an updatable mapping for each
        # "updated" name: a missing one, or a non-mapping, is an error.
        del wrapper.dict_attr
        self.assertRaises(AttributeError, functools.update_wrapper,
                          wrapper, f, assign, update)
        wrapper.dict_attr = 1
        self.assertRaises(AttributeError, functools.update_wrapper,
                          wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: builtins can be wrapped too.
        def wrapper():
            ...

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
711
712
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        # Wrapped function with a docstring, an extra attribute and a bogus
        # __wrapped__ value, decorated via wraps() instead of update_wrapper().
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            ...

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty assigned/updated tuples: only __wrapped__ is set.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            ...

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            ...
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # The wrapper needs an existing dict_attr mapping for the
            # update step, so give it an empty one before wraps() runs.
            func.dict_attr = {}
            return func

        assign = ('attr',)
        update = ('dict_attr',)

        def wrapper():
            ...
        wrapper = functools.wraps(f, assign, update)(add_dict_attr(wrapper))

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
773
774
775class TestReduce:
776    def test_reduce(self):
777        class Squares:
778            def __init__(self, max):
779                self.max = max
780                self.sofar = []
781
782            def __len__(self):
783                return len(self.sofar)
784
785            def __getitem__(self, i):
786                if not 0 <= i < self.max: raise IndexError
787                n = len(self.sofar)
788                while n <= i:
789                    self.sofar.append(n*n)
790                    n += 1
791                return self.sofar[i]
792        def add(x, y):
793            return x + y
794        self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
795        self.assertEqual(
796            self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
797            ['a','c','d','w']
798        )
799        self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
800        self.assertEqual(
801            self.reduce(lambda x, y: x*y, range(2,21), 1),
802            2432902008176640000
803        )
804        self.assertEqual(self.reduce(add, Squares(10)), 285)
805        self.assertEqual(self.reduce(add, Squares(10), 0), 285)
806        self.assertEqual(self.reduce(add, Squares(0), 0), 0)
807        self.assertRaises(TypeError, self.reduce)
808        self.assertRaises(TypeError, self.reduce, 42, 42)
809        self.assertRaises(TypeError, self.reduce, 42, 42, 42)
810        self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
811        self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
812        self.assertRaises(TypeError, self.reduce, 42, (42, 42))
813        self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
814        self.assertRaises(TypeError, self.reduce, add, "")
815        self.assertRaises(TypeError, self.reduce, add, ())
816        self.assertRaises(TypeError, self.reduce, add, object())
817
818        class TestFailingIter:
819            def __iter__(self):
820                raise RuntimeError
821        self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
822
823        self.assertEqual(self.reduce(add, [], None), None)
824        self.assertEqual(self.reduce(add, [], 42), 42)
825
826        class BadSeq:
827            def __getitem__(self, index):
828                raise ValueError
829        self.assertRaises(ValueError, self.reduce, 42, BadSeq())
830
831    # Test reduce()'s use of iterators.
832    def test_iterator_usage(self):
833        class SequenceClass:
834            def __init__(self, n):
835                self.n = n
836            def __getitem__(self, i):
837                if 0 <= i < self.n:
838                    return i
839                else:
840                    raise IndexError
841
842        from operator import add
843        self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
844        self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
845        self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
846        self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
847        self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
848        self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
849
850        d = {"one": 1, "two": 2, "three": 3}
851        self.assertEqual(self.reduce(add, d), "".join(d.keys()))
852
853
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared TestReduce checks against the C reduce()."""

    # The class body executes even when the skip decorator applies, so the
    # implementation is bound only when the C module could be imported.
    if c_functools:
        reduce = staticmethod(c_functools.reduce)
858
859
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared TestReduce checks against the pure-Python reduce()."""

    # staticmethod stops the plain Python function from binding to self.
    reduce = staticmethod(py_functools.reduce)
862
863
864class TestCmpToKey:
865
866    def test_cmp_to_key(self):
867        def cmp1(x, y):
868            return (x > y) - (x < y)
869        key = self.cmp_to_key(cmp1)
870        self.assertEqual(key(3), key(3))
871        self.assertGreater(key(3), key(1))
872        self.assertGreaterEqual(key(3), key(3))
873
874        def cmp2(x, y):
875            return int(x) - int(y)
876        key = self.cmp_to_key(cmp2)
877        self.assertEqual(key(4.0), key('4'))
878        self.assertLess(key(2), key('35'))
879        self.assertLessEqual(key(2), key('35'))
880        self.assertNotEqual(key(2), key('35'))
881
882    def test_cmp_to_key_arguments(self):
883        def cmp1(x, y):
884            return (x > y) - (x < y)
885        key = self.cmp_to_key(mycmp=cmp1)
886        self.assertEqual(key(obj=3), key(obj=3))
887        self.assertGreater(key(obj=3), key(obj=1))
888        with self.assertRaises((TypeError, AttributeError)):
889            key(3) > 1    # rhs is not a K object
890        with self.assertRaises((TypeError, AttributeError)):
891            1 < key(3)    # lhs is not a K object
892        with self.assertRaises(TypeError):
893            key = self.cmp_to_key()             # too few args
894        with self.assertRaises(TypeError):
895            key = self.cmp_to_key(cmp1, None)   # too many args
896        key = self.cmp_to_key(cmp1)
897        with self.assertRaises(TypeError):
898            key()                                    # too few args
899        with self.assertRaises(TypeError):
900            key(None, None)                          # too many args
901
902    def test_bad_cmp(self):
903        def cmp1(x, y):
904            raise ZeroDivisionError
905        key = self.cmp_to_key(cmp1)
906        with self.assertRaises(ZeroDivisionError):
907            key(3) > key(1)
908
909        class BadCmp:
910            def __lt__(self, other):
911                raise ZeroDivisionError
912        def cmp1(x, y):
913            return BadCmp()
914        with self.assertRaises(ZeroDivisionError):
915            key(3) > key(1)
916
917    def test_obj_field(self):
918        def cmp1(x, y):
919            return (x > y) - (x < y)
920        key = self.cmp_to_key(mycmp=cmp1)
921        self.assertEqual(key(50).obj, 50)
922
923    def test_sort_int(self):
924        def mycmp(x, y):
925            return y - x
926        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
927                         [4, 3, 2, 1, 0])
928
929    def test_sort_int_str(self):
930        def mycmp(x, y):
931            x, y = int(x), int(y)
932            return (x > y) - (x < y)
933        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
934        values = sorted(values, key=self.cmp_to_key(mycmp))
935        self.assertEqual([int(value) for value in values],
936                         [0, 1, 1, 2, 3, 4, 5, 7, 10])
937
938    def test_hash(self):
939        def mycmp(x, y):
940            return y - x
941        key = self.cmp_to_key(mycmp)
942        k = key(10)
943        self.assertRaises(TypeError, hash, k)
944        self.assertNotIsInstance(k, collections.abc.Hashable)
945
946
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared TestCmpToKey checks against the C cmp_to_key()."""

    # The class body runs even when the skip decorator fires, so bind the
    # implementation only when the C accelerator was importable.
    if c_functools:
        cmp_to_key = staticmethod(c_functools.cmp_to_key)

    @support.cpython_only
    def test_disallow_instantiation(self):
        # bpo-43916: the key type produced by the C cmp_to_key() must
        # refuse direct instantiation.
        key_type = type(c_functools.cmp_to_key(None))
        support.check_disallow_instantiation(self, key_type)
958
959
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared TestCmpToKey checks against the pure-Python version."""

    # staticmethod stops the plain Python function from binding to self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
962
963
964class TestTotalOrdering(unittest.TestCase):
965
966    def test_total_ordering_lt(self):
967        @functools.total_ordering
968        class A:
969            def __init__(self, value):
970                self.value = value
971            def __lt__(self, other):
972                return self.value < other.value
973            def __eq__(self, other):
974                return self.value == other.value
975        self.assertTrue(A(1) < A(2))
976        self.assertTrue(A(2) > A(1))
977        self.assertTrue(A(1) <= A(2))
978        self.assertTrue(A(2) >= A(1))
979        self.assertTrue(A(2) <= A(2))
980        self.assertTrue(A(2) >= A(2))
981        self.assertFalse(A(1) > A(2))
982
983    def test_total_ordering_le(self):
984        @functools.total_ordering
985        class A:
986            def __init__(self, value):
987                self.value = value
988            def __le__(self, other):
989                return self.value <= other.value
990            def __eq__(self, other):
991                return self.value == other.value
992        self.assertTrue(A(1) < A(2))
993        self.assertTrue(A(2) > A(1))
994        self.assertTrue(A(1) <= A(2))
995        self.assertTrue(A(2) >= A(1))
996        self.assertTrue(A(2) <= A(2))
997        self.assertTrue(A(2) >= A(2))
998        self.assertFalse(A(1) >= A(2))
999
1000    def test_total_ordering_gt(self):
1001        @functools.total_ordering
1002        class A:
1003            def __init__(self, value):
1004                self.value = value
1005            def __gt__(self, other):
1006                return self.value > other.value
1007            def __eq__(self, other):
1008                return self.value == other.value
1009        self.assertTrue(A(1) < A(2))
1010        self.assertTrue(A(2) > A(1))
1011        self.assertTrue(A(1) <= A(2))
1012        self.assertTrue(A(2) >= A(1))
1013        self.assertTrue(A(2) <= A(2))
1014        self.assertTrue(A(2) >= A(2))
1015        self.assertFalse(A(2) < A(1))
1016
1017    def test_total_ordering_ge(self):
1018        @functools.total_ordering
1019        class A:
1020            def __init__(self, value):
1021                self.value = value
1022            def __ge__(self, other):
1023                return self.value >= other.value
1024            def __eq__(self, other):
1025                return self.value == other.value
1026        self.assertTrue(A(1) < A(2))
1027        self.assertTrue(A(2) > A(1))
1028        self.assertTrue(A(1) <= A(2))
1029        self.assertTrue(A(2) >= A(1))
1030        self.assertTrue(A(2) <= A(2))
1031        self.assertTrue(A(2) >= A(2))
1032        self.assertFalse(A(2) <= A(1))
1033
1034    def test_total_ordering_no_overwrite(self):
1035        # new methods should not overwrite existing
1036        @functools.total_ordering
1037        class A(int):
1038            pass
1039        self.assertTrue(A(1) < A(2))
1040        self.assertTrue(A(2) > A(1))
1041        self.assertTrue(A(1) <= A(2))
1042        self.assertTrue(A(2) >= A(1))
1043        self.assertTrue(A(2) <= A(2))
1044        self.assertTrue(A(2) >= A(2))
1045
1046    def test_no_operations_defined(self):
1047        with self.assertRaises(ValueError):
1048            @functools.total_ordering
1049            class A:
1050                pass
1051
1052    def test_type_error_when_not_implemented(self):
1053        # bug 10042; ensure stack overflow does not occur
1054        # when decorated types return NotImplemented
1055        @functools.total_ordering
1056        class ImplementsLessThan:
1057            def __init__(self, value):
1058                self.value = value
1059            def __eq__(self, other):
1060                if isinstance(other, ImplementsLessThan):
1061                    return self.value == other.value
1062                return False
1063            def __lt__(self, other):
1064                if isinstance(other, ImplementsLessThan):
1065                    return self.value < other.value
1066                return NotImplemented
1067
1068        @functools.total_ordering
1069        class ImplementsGreaterThan:
1070            def __init__(self, value):
1071                self.value = value
1072            def __eq__(self, other):
1073                if isinstance(other, ImplementsGreaterThan):
1074                    return self.value == other.value
1075                return False
1076            def __gt__(self, other):
1077                if isinstance(other, ImplementsGreaterThan):
1078                    return self.value > other.value
1079                return NotImplemented
1080
1081        @functools.total_ordering
1082        class ImplementsLessThanEqualTo:
1083            def __init__(self, value):
1084                self.value = value
1085            def __eq__(self, other):
1086                if isinstance(other, ImplementsLessThanEqualTo):
1087                    return self.value == other.value
1088                return False
1089            def __le__(self, other):
1090                if isinstance(other, ImplementsLessThanEqualTo):
1091                    return self.value <= other.value
1092                return NotImplemented
1093
1094        @functools.total_ordering
1095        class ImplementsGreaterThanEqualTo:
1096            def __init__(self, value):
1097                self.value = value
1098            def __eq__(self, other):
1099                if isinstance(other, ImplementsGreaterThanEqualTo):
1100                    return self.value == other.value
1101                return False
1102            def __ge__(self, other):
1103                if isinstance(other, ImplementsGreaterThanEqualTo):
1104                    return self.value >= other.value
1105                return NotImplemented
1106
1107        @functools.total_ordering
1108        class ComparatorNotImplemented:
1109            def __init__(self, value):
1110                self.value = value
1111            def __eq__(self, other):
1112                if isinstance(other, ComparatorNotImplemented):
1113                    return self.value == other.value
1114                return False
1115            def __lt__(self, other):
1116                return NotImplemented
1117
1118        with self.subTest("LT < 1"), self.assertRaises(TypeError):
1119            ImplementsLessThan(-1) < 1
1120
1121        with self.subTest("LT < LE"), self.assertRaises(TypeError):
1122            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1123
1124        with self.subTest("LT < GT"), self.assertRaises(TypeError):
1125            ImplementsLessThan(1) < ImplementsGreaterThan(1)
1126
1127        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1128            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1129
1130        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1131            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1132
1133        with self.subTest("GT > GE"), self.assertRaises(TypeError):
1134            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1135
1136        with self.subTest("GT > LT"), self.assertRaises(TypeError):
1137            ImplementsGreaterThan(5) > ImplementsLessThan(5)
1138
1139        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1140            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1141
1142        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1143            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1144
1145        with self.subTest("GE when equal"):
1146            a = ComparatorNotImplemented(8)
1147            b = ComparatorNotImplemented(8)
1148            self.assertEqual(a, b)
1149            with self.assertRaises(TypeError):
1150                a >= b
1151
1152        with self.subTest("LE when equal"):
1153            a = ComparatorNotImplemented(9)
1154            b = ComparatorNotImplemented(9)
1155            self.assertEqual(a, b)
1156            with self.assertRaises(TypeError):
1157                a <= b
1158
1159    def test_pickle(self):
1160        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1161            for name in '__lt__', '__gt__', '__le__', '__ge__':
1162                with self.subTest(method=name, proto=proto):
1163                    method = getattr(Orderable_LT, name)
1164                    method_copy = pickle.loads(pickle.dumps(method, proto))
1165                    self.assertIs(method_copy, method)
1166
1167
1168    def test_total_ordering_for_metaclasses_issue_44605(self):
1169
1170        @functools.total_ordering
1171        class SortableMeta(type):
1172            def __new__(cls, name, bases, ns):
1173                return super().__new__(cls, name, bases, ns)
1174
1175            def __lt__(self, other):
1176                if not isinstance(other, SortableMeta):
1177                    pass
1178                return self.__name__ < other.__name__
1179
1180            def __eq__(self, other):
1181                if not isinstance(other, SortableMeta):
1182                    pass
1183                return self.__name__ == other.__name__
1184
1185        class B(metaclass=SortableMeta):
1186            pass
1187
1188        class A(metaclass=SortableMeta):
1189            pass
1190
1191        self.assertTrue(A < B)
1192        self.assertFalse(A > B)
1193
1194
@functools.total_ordering
class Orderable_LT:
    """Value holder ordered via __lt__ and __eq__ only.

    Defined at module level so TestTotalOrdering.test_pickle can pickle the
    comparison methods that total_ordering fills in.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1203
1204
1205class TestCache:
1206    # This tests that the pass-through is working as designed.
1207    # The underlying functionality is tested in TestLRU.
1208
1209    def test_cache(self):
1210        @self.module.cache
1211        def fib(n):
1212            if n < 2:
1213                return n
1214            return fib(n-1) + fib(n-2)
1215        self.assertEqual([fib(n) for n in range(16)],
1216            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1217        self.assertEqual(fib.cache_info(),
1218            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1219        fib.cache_clear()
1220        self.assertEqual(fib.cache_info(),
1221            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1222
1223
1224class TestLRU:
1225
1226    def test_lru(self):
1227        def orig(x, y):
1228            return 3 * x + y
1229        f = self.module.lru_cache(maxsize=20)(orig)
1230        hits, misses, maxsize, currsize = f.cache_info()
1231        self.assertEqual(maxsize, 20)
1232        self.assertEqual(currsize, 0)
1233        self.assertEqual(hits, 0)
1234        self.assertEqual(misses, 0)
1235
1236        domain = range(5)
1237        for i in range(1000):
1238            x, y = choice(domain), choice(domain)
1239            actual = f(x, y)
1240            expected = orig(x, y)
1241            self.assertEqual(actual, expected)
1242        hits, misses, maxsize, currsize = f.cache_info()
1243        self.assertTrue(hits > misses)
1244        self.assertEqual(hits + misses, 1000)
1245        self.assertEqual(currsize, 20)
1246
1247        f.cache_clear()   # test clearing
1248        hits, misses, maxsize, currsize = f.cache_info()
1249        self.assertEqual(hits, 0)
1250        self.assertEqual(misses, 0)
1251        self.assertEqual(currsize, 0)
1252        f(x, y)
1253        hits, misses, maxsize, currsize = f.cache_info()
1254        self.assertEqual(hits, 0)
1255        self.assertEqual(misses, 1)
1256        self.assertEqual(currsize, 1)
1257
1258        # Test bypassing the cache
1259        self.assertIs(f.__wrapped__, orig)
1260        f.__wrapped__(x, y)
1261        hits, misses, maxsize, currsize = f.cache_info()
1262        self.assertEqual(hits, 0)
1263        self.assertEqual(misses, 1)
1264        self.assertEqual(currsize, 1)
1265
1266        # test size zero (which means "never-cache")
1267        @self.module.lru_cache(0)
1268        def f():
1269            nonlocal f_cnt
1270            f_cnt += 1
1271            return 20
1272        self.assertEqual(f.cache_info().maxsize, 0)
1273        f_cnt = 0
1274        for i in range(5):
1275            self.assertEqual(f(), 20)
1276        self.assertEqual(f_cnt, 5)
1277        hits, misses, maxsize, currsize = f.cache_info()
1278        self.assertEqual(hits, 0)
1279        self.assertEqual(misses, 5)
1280        self.assertEqual(currsize, 0)
1281
1282        # test size one
1283        @self.module.lru_cache(1)
1284        def f():
1285            nonlocal f_cnt
1286            f_cnt += 1
1287            return 20
1288        self.assertEqual(f.cache_info().maxsize, 1)
1289        f_cnt = 0
1290        for i in range(5):
1291            self.assertEqual(f(), 20)
1292        self.assertEqual(f_cnt, 1)
1293        hits, misses, maxsize, currsize = f.cache_info()
1294        self.assertEqual(hits, 4)
1295        self.assertEqual(misses, 1)
1296        self.assertEqual(currsize, 1)
1297
1298        # test size two
1299        @self.module.lru_cache(2)
1300        def f(x):
1301            nonlocal f_cnt
1302            f_cnt += 1
1303            return x*10
1304        self.assertEqual(f.cache_info().maxsize, 2)
1305        f_cnt = 0
1306        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1307            #    *  *              *                          *
1308            self.assertEqual(f(x), x*10)
1309        self.assertEqual(f_cnt, 4)
1310        hits, misses, maxsize, currsize = f.cache_info()
1311        self.assertEqual(hits, 12)
1312        self.assertEqual(misses, 4)
1313        self.assertEqual(currsize, 2)
1314
1315    def test_lru_no_args(self):
1316        @self.module.lru_cache
1317        def square(x):
1318            return x ** 2
1319
1320        self.assertEqual(list(map(square, [10, 20, 10])),
1321                         [100, 400, 100])
1322        self.assertEqual(square.cache_info().hits, 1)
1323        self.assertEqual(square.cache_info().misses, 2)
1324        self.assertEqual(square.cache_info().maxsize, 128)
1325        self.assertEqual(square.cache_info().currsize, 2)
1326
1327    def test_lru_bug_35780(self):
1328        # C version of the lru_cache was not checking to see if
1329        # the user function call has already modified the cache
1330        # (this arises in recursive calls and in multi-threading).
1331        # This cause the cache to have orphan links not referenced
1332        # by the cache dictionary.
1333
1334        once = True                 # Modified by f(x) below
1335
1336        @self.module.lru_cache(maxsize=10)
1337        def f(x):
1338            nonlocal once
1339            rv = f'.{x}.'
1340            if x == 20 and once:
1341                once = False
1342                rv = f(x)
1343            return rv
1344
1345        # Fill the cache
1346        for x in range(15):
1347            self.assertEqual(f(x), f'.{x}.')
1348        self.assertEqual(f.cache_info().currsize, 10)
1349
1350        # Make a recursive call and make sure the cache remains full
1351        self.assertEqual(f(20), '.20.')
1352        self.assertEqual(f.cache_info().currsize, 10)
1353
1354    def test_lru_bug_36650(self):
1355        # C version of lru_cache was treating a call with an empty **kwargs
1356        # dictionary as being distinct from a call with no keywords at all.
1357        # This did not result in an incorrect answer, but it did trigger
1358        # an unexpected cache miss.
1359
1360        @self.module.lru_cache()
1361        def f(x):
1362            pass
1363
1364        f(0)
1365        f(0, **{})
1366        self.assertEqual(f.cache_info().hits, 1)
1367
1368    def test_lru_hash_only_once(self):
1369        # To protect against weird reentrancy bugs and to improve
1370        # efficiency when faced with slow __hash__ methods, the
1371        # LRU cache guarantees that it will only call __hash__
1372        # only once per use as an argument to the cached function.
1373
1374        @self.module.lru_cache(maxsize=1)
1375        def f(x, y):
1376            return x * 3 + y
1377
1378        # Simulate the integer 5
1379        mock_int = unittest.mock.Mock()
1380        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1381        mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1382
1383        # Add to cache:  One use as an argument gives one call
1384        self.assertEqual(f(mock_int, 1), 16)
1385        self.assertEqual(mock_int.__hash__.call_count, 1)
1386        self.assertEqual(f.cache_info(), (0, 1, 1, 1))
1387
1388        # Cache hit: One use as an argument gives one additional call
1389        self.assertEqual(f(mock_int, 1), 16)
1390        self.assertEqual(mock_int.__hash__.call_count, 2)
1391        self.assertEqual(f.cache_info(), (1, 1, 1, 1))
1392
1393        # Cache eviction: No use as an argument gives no additional call
1394        self.assertEqual(f(6, 2), 20)
1395        self.assertEqual(mock_int.__hash__.call_count, 2)
1396        self.assertEqual(f.cache_info(), (1, 2, 1, 1))
1397
1398        # Cache miss: One use as an argument gives one additional call
1399        self.assertEqual(f(mock_int, 1), 16)
1400        self.assertEqual(mock_int.__hash__.call_count, 3)
1401        self.assertEqual(f.cache_info(), (1, 3, 1, 1))
1402
1403    def test_lru_reentrancy_with_len(self):
1404        # Test to make sure the LRU cache code isn't thrown-off by
1405        # caching the built-in len() function.  Since len() can be
1406        # cached, we shouldn't use it inside the lru code itself.
1407        old_len = builtins.len
1408        try:
1409            builtins.len = self.module.lru_cache(4)(len)
1410            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1411                self.assertEqual(len('abcdefghijklmn'[:i]), i)
1412        finally:
1413            builtins.len = old_len
1414
1415    def test_lru_star_arg_handling(self):
1416        # Test regression that arose in ea064ff3c10f
1417        @functools.lru_cache()
1418        def f(*args):
1419            return args
1420
1421        self.assertEqual(f(1, 2), (1, 2))
1422        self.assertEqual(f((1, 2)), ((1, 2),))
1423
1424    def test_lru_type_error(self):
1425        # Regression test for issue #28653.
1426        # lru_cache was leaking when one of the arguments
1427        # wasn't cacheable.
1428
1429        @functools.lru_cache(maxsize=None)
1430        def infinite_cache(o):
1431            pass
1432
1433        @functools.lru_cache(maxsize=10)
1434        def limited_cache(o):
1435            pass
1436
1437        with self.assertRaises(TypeError):
1438            infinite_cache([])
1439
1440        with self.assertRaises(TypeError):
1441            limited_cache([])
1442
1443    def test_lru_with_maxsize_none(self):
1444        @self.module.lru_cache(maxsize=None)
1445        def fib(n):
1446            if n < 2:
1447                return n
1448            return fib(n-1) + fib(n-2)
1449        self.assertEqual([fib(n) for n in range(16)],
1450            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1451        self.assertEqual(fib.cache_info(),
1452            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1453        fib.cache_clear()
1454        self.assertEqual(fib.cache_info(),
1455            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1456
1457    def test_lru_with_maxsize_negative(self):
1458        @self.module.lru_cache(maxsize=-10)
1459        def eq(n):
1460            return n
1461        for i in (0, 1):
1462            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1463        self.assertEqual(eq.cache_info(),
1464            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
1465
1466    def test_lru_with_exceptions(self):
1467        # Verify that user_function exceptions get passed through without
1468        # creating a hard-to-read chained exception.
1469        # http://bugs.python.org/issue13177
1470        for maxsize in (None, 128):
1471            @self.module.lru_cache(maxsize)
1472            def func(i):
1473                return 'abc'[i]
1474            self.assertEqual(func(0), 'a')
1475            with self.assertRaises(IndexError) as cm:
1476                func(15)
1477            self.assertIsNone(cm.exception.__context__)
1478            # Verify that the previous exception did not result in a cached entry
1479            with self.assertRaises(IndexError):
1480                func(15)
1481
1482    def test_lru_with_types(self):
1483        for maxsize in (None, 128):
1484            @self.module.lru_cache(maxsize=maxsize, typed=True)
1485            def square(x):
1486                return x * x
1487            self.assertEqual(square(3), 9)
1488            self.assertEqual(type(square(3)), type(9))
1489            self.assertEqual(square(3.0), 9.0)
1490            self.assertEqual(type(square(3.0)), type(9.0))
1491            self.assertEqual(square(x=3), 9)
1492            self.assertEqual(type(square(x=3)), type(9))
1493            self.assertEqual(square(x=3.0), 9.0)
1494            self.assertEqual(type(square(x=3.0)), type(9.0))
1495            self.assertEqual(square.cache_info().hits, 4)
1496            self.assertEqual(square.cache_info().misses, 4)
1497
1498    def test_lru_with_keyword_args(self):
1499        @self.module.lru_cache()
1500        def fib(n):
1501            if n < 2:
1502                return n
1503            return fib(n=n-1) + fib(n=n-2)
1504        self.assertEqual(
1505            [fib(n=number) for number in range(16)],
1506            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1507        )
1508        self.assertEqual(fib.cache_info(),
1509            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1510        fib.cache_clear()
1511        self.assertEqual(fib.cache_info(),
1512            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1513
1514    def test_lru_with_keyword_args_maxsize_none(self):
1515        @self.module.lru_cache(maxsize=None)
1516        def fib(n):
1517            if n < 2:
1518                return n
1519            return fib(n=n-1) + fib(n=n-2)
1520        self.assertEqual([fib(n=number) for number in range(16)],
1521            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1522        self.assertEqual(fib.cache_info(),
1523            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1524        fib.cache_clear()
1525        self.assertEqual(fib.cache_info(),
1526            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1527
1528    def test_kwargs_order(self):
1529        # PEP 468: Preserving Keyword Argument Order
1530        @self.module.lru_cache(maxsize=10)
1531        def f(**kwargs):
1532            return list(kwargs.items())
1533        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1534        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1535        self.assertEqual(f.cache_info(),
1536            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1537
1538    def test_lru_cache_decoration(self):
1539        def f(zomg: 'zomg_annotation'):
1540            """f doc string"""
1541            return 42
1542        g = self.module.lru_cache()(f)
1543        for attr in self.module.WRAPPER_ASSIGNMENTS:
1544            self.assertEqual(getattr(g, attr), getattr(f, attr))
1545
    def test_lru_cache_threaded(self):
        """Concurrent calls and cache_clear() must not break lru_cache.

        Phase 1: n threads each call f(k, 0) m times with their own key k,
        then the cache statistics are checked.  Phase 2: the same workers
        run again while one extra thread repeatedly clears the cache; only
        the per-call results are checked there.
        """
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        # Room for n*m entries, while only n distinct keys are ever used.
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # Every worker blocks on this event so they all start together.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # A tiny switch interval presumably forces frequent thread switches
        # to expose races; the original interval is restored in `finally`.
        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with threading_helper.start_threads(threads):
                start.set()

            # Expected: one miss per key, hits for all remaining calls,
            # and exactly n entries left in the cache.
            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: only upper bounds are asserted for the pure-Python
                # version; why its counts may differ from the exact values
                # is an open question.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it;
            # results must stay correct while the cache is cleared
            # concurrently (no statistics are checked in this phase).
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with threading_helper.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)
1593
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        # In each of m rounds all n workers call f(i) with the same i.  f
        # blocks on `pause` until every worker (and the main thread) has
        # reached it, so no call can finish and populate the cache before
        # the others have started.  Every call in a round therefore misses,
        # and each round adds exactly one entry.
        n, m = 5, 7
        # Each barrier synchronises the n workers plus the main thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with threading_helper.start_threads(threads):
            for i in range(m):
                start.wait(10)   # release the workers into round i
                stop.reset()
                pause.wait(10)   # all n workers are now inside f(i)
                start.reset()
                stop.wait(10)    # all n calls of round i have returned
                pause.reset()
                # Still no hits: n more misses and one more entry per round.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1620
1621    def test_lru_cache_threaded3(self):
1622        @self.module.lru_cache(maxsize=2)
1623        def f(x):
1624            time.sleep(.01)
1625            return 3 * x
1626        def test(i, x):
1627            with self.subTest(thread=i):
1628                self.assertEqual(f(x), 3 * x, i)
1629        threads = [threading.Thread(target=test, args=(i, v))
1630                   for i, v in enumerate([1, 2, 2, 3, 2])]
1631        with threading_helper.start_threads(threads):
1632            pass
1633
1634    def test_need_for_rlock(self):
1635        # This will deadlock on an LRU cache that uses a regular lock
1636
1637        @self.module.lru_cache(maxsize=10)
1638        def test_func(x):
1639            'Used to demonstrate a reentrant lru_cache call within a single thread'
1640            return x
1641
1642        class DoubleEq:
1643            'Demonstrate a reentrant lru_cache call within a single thread'
1644            def __init__(self, x):
1645                self.x = x
1646            def __hash__(self):
1647                return self.x
1648            def __eq__(self, other):
1649                if self.x == 2:
1650                    test_func(DoubleEq(1))
1651                return self.x == other.x
1652
1653        test_func(DoubleEq(1))                      # Load the cache
1654        test_func(DoubleEq(2))                      # Load the cache
1655        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
1656                         DoubleEq(2))               # Verify the correct return value
1657
1658    def test_lru_method(self):
1659        class X(int):
1660            f_cnt = 0
1661            @self.module.lru_cache(2)
1662            def f(self, x):
1663                self.f_cnt += 1
1664                return x*10+self
1665        a = X(5)
1666        b = X(5)
1667        c = X(7)
1668        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1669
1670        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1671            self.assertEqual(a.f(x), x*10 + 5)
1672        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1673        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1674
1675        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1676            self.assertEqual(b.f(x), x*10 + 5)
1677        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1678        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1679
1680        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1681            self.assertEqual(c.f(x), x*10 + 7)
1682        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1683        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1684
1685        self.assertEqual(a.f.cache_info(), X.f.cache_info())
1686        self.assertEqual(b.f.cache_info(), X.f.cache_info())
1687        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1688
1689    def test_pickle(self):
1690        cls = self.__class__
1691        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1692            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1693                with self.subTest(proto=proto, func=f):
1694                    f_copy = pickle.loads(pickle.dumps(f, proto))
1695                    self.assertIs(f_copy, f)
1696
1697    def test_copy(self):
1698        cls = self.__class__
1699        def orig(x, y):
1700            return 3 * x + y
1701        part = self.module.partial(orig, 2)
1702        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1703                 self.module.lru_cache(2)(part))
1704        for f in funcs:
1705            with self.subTest(func=f):
1706                f_copy = copy.copy(f)
1707                self.assertIs(f_copy, f)
1708
1709    def test_deepcopy(self):
1710        cls = self.__class__
1711        def orig(x, y):
1712            return 3 * x + y
1713        part = self.module.partial(orig, 2)
1714        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1715                 self.module.lru_cache(2)(part))
1716        for f in funcs:
1717            with self.subTest(func=f):
1718                f_copy = copy.deepcopy(f)
1719                self.assertIs(f_copy, f)
1720
1721    def test_lru_cache_parameters(self):
1722        @self.module.lru_cache(maxsize=2)
1723        def f():
1724            return 1
1725        self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1726
1727        @self.module.lru_cache(maxsize=1000, typed=True)
1728        def f():
1729            return 1
1730        self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1731
1732    def test_lru_cache_weakrefable(self):
1733        @self.module.lru_cache
1734        def test_function(x):
1735            return x
1736
1737        class A:
1738            @self.module.lru_cache
1739            def test_method(self, x):
1740                return (self, x)
1741
1742            @staticmethod
1743            @self.module.lru_cache
1744            def test_staticmethod(x):
1745                return (self, x)
1746
1747        refs = [weakref.ref(test_function),
1748                weakref.ref(A.test_method),
1749                weakref.ref(A.test_staticmethod)]
1750
1751        for ref in refs:
1752            self.assertIsNotNone(ref())
1753
1754        del A
1755        del test_function
1756        gc.collect()
1757
1758        for ref in refs:
1759            self.assertIsNone(ref())
1760
1761
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Return ``3 * x + y``, cached with the pure-Python lru_cache.

    Defined at module level so pickle can locate it by qualified name.
    """
    scaled = 3 * x
    return scaled + y
1765
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Return ``3 * x + y``, cached with the C-accelerated lru_cache.

    Defined at module level so pickle can locate it by qualified name.
    """
    scaled = 3 * x
    return scaled + y
1769
1770
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the pure-Python functools."""

    module = py_functools
    # The module-level cached function, wrapped in a 1-tuple; the shared
    # tests read it back as ``cached_func[0]``.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1783
1784
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the C-accelerated functools."""

    module = c_functools
    # The module-level cached function, wrapped in a 1-tuple; the shared
    # tests read it back as ``cached_func[0]``.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1797
1798
1799class TestSingleDispatch(unittest.TestCase):
1800    def test_simple_overloads(self):
1801        @functools.singledispatch
1802        def g(obj):
1803            return "base"
1804        def g_int(i):
1805            return "integer"
1806        g.register(int, g_int)
1807        self.assertEqual(g("str"), "base")
1808        self.assertEqual(g(1), "integer")
1809        self.assertEqual(g([1,2,3]), "base")
1810
1811    def test_mro(self):
1812        @functools.singledispatch
1813        def g(obj):
1814            return "base"
1815        class A:
1816            pass
1817        class C(A):
1818            pass
1819        class B(A):
1820            pass
1821        class D(C, B):
1822            pass
1823        def g_A(a):
1824            return "A"
1825        def g_B(b):
1826            return "B"
1827        g.register(A, g_A)
1828        g.register(B, g_B)
1829        self.assertEqual(g(A()), "A")
1830        self.assertEqual(g(B()), "B")
1831        self.assertEqual(g(C()), "A")
1832        self.assertEqual(g(D()), "B")
1833
1834    def test_register_decorator(self):
1835        @functools.singledispatch
1836        def g(obj):
1837            return "base"
1838        @g.register(int)
1839        def g_int(i):
1840            return "int %s" % (i,)
1841        self.assertEqual(g(""), "base")
1842        self.assertEqual(g(12), "int 12")
1843        self.assertIs(g.dispatch(int), g_int)
1844        self.assertIs(g.dispatch(object), g.dispatch(str))
1845        # Note: in the assert above this is not g.
1846        # @singledispatch returns the wrapper.
1847
1848    def test_wrapping_attributes(self):
1849        @functools.singledispatch
1850        def g(obj):
1851            "Simple test"
1852            return "Test"
1853        self.assertEqual(g.__name__, "g")
1854        if sys.flags.optimize < 2:
1855            self.assertEqual(g.__doc__, "Simple test")
1856
1857    @unittest.skipUnless(decimal, 'requires _decimal')
1858    @support.cpython_only
1859    def test_c_classes(self):
1860        @functools.singledispatch
1861        def g(obj):
1862            return "base"
1863        @g.register(decimal.DecimalException)
1864        def _(obj):
1865            return obj.args
1866        subn = decimal.Subnormal("Exponent < Emin")
1867        rnd = decimal.Rounded("Number got rounded")
1868        self.assertEqual(g(subn), ("Exponent < Emin",))
1869        self.assertEqual(g(rnd), ("Number got rounded",))
1870        @g.register(decimal.Subnormal)
1871        def _(obj):
1872            return "Too small to care."
1873        self.assertEqual(g(subn), "Too small to care.")
1874        self.assertEqual(g(rnd), ("Number got rounded",))
1875
1876    def test_compose_mro(self):
1877        # None of the examples in this test depend on haystack ordering.
1878        c = collections.abc
1879        mro = functools._compose_mro
1880        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1881        for haystack in permutations(bases):
1882            m = mro(dict, haystack)
1883            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1884                                 c.Collection, c.Sized, c.Iterable,
1885                                 c.Container, object])
1886        bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
1887        for haystack in permutations(bases):
1888            m = mro(collections.ChainMap, haystack)
1889            self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
1890                                 c.Collection, c.Sized, c.Iterable,
1891                                 c.Container, object])
1892
1893        # If there's a generic function with implementations registered for
1894        # both Sized and Container, passing a defaultdict to it results in an
1895        # ambiguous dispatch which will cause a RuntimeError (see
1896        # test_mro_conflicts).
1897        bases = [c.Container, c.Sized, str]
1898        for haystack in permutations(bases):
1899            m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1900            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1901                                 c.Container, object])
1902
1903        # MutableSequence below is registered directly on D. In other words, it
1904        # precedes MutableMapping which means single dispatch will always
1905        # choose MutableSequence here.
1906        class D(collections.defaultdict):
1907            pass
1908        c.MutableSequence.register(D)
1909        bases = [c.MutableSequence, c.MutableMapping]
1910        for haystack in permutations(bases):
1911            m = mro(D, bases)
1912            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
1913                                 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
1914                                 c.Collection, c.Sized, c.Iterable, c.Container,
1915                                 object])
1916
1917        # Container and Callable are registered on different base classes and
1918        # a generic function supporting both should always pick the Callable
1919        # implementation if a C instance is passed.
1920        class C(collections.defaultdict):
1921            def __call__(self):
1922                pass
1923        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1924        for haystack in permutations(bases):
1925            m = mro(C, haystack)
1926            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
1927                                 c.Collection, c.Sized, c.Iterable,
1928                                 c.Container, object])
1929
1930    def test_register_abc(self):
1931        c = collections.abc
1932        d = {"a": "b"}
1933        l = [1, 2, 3]
1934        s = {object(), None}
1935        f = frozenset(s)
1936        t = (1, 2, 3)
1937        @functools.singledispatch
1938        def g(obj):
1939            return "base"
1940        self.assertEqual(g(d), "base")
1941        self.assertEqual(g(l), "base")
1942        self.assertEqual(g(s), "base")
1943        self.assertEqual(g(f), "base")
1944        self.assertEqual(g(t), "base")
1945        g.register(c.Sized, lambda obj: "sized")
1946        self.assertEqual(g(d), "sized")
1947        self.assertEqual(g(l), "sized")
1948        self.assertEqual(g(s), "sized")
1949        self.assertEqual(g(f), "sized")
1950        self.assertEqual(g(t), "sized")
1951        g.register(c.MutableMapping, lambda obj: "mutablemapping")
1952        self.assertEqual(g(d), "mutablemapping")
1953        self.assertEqual(g(l), "sized")
1954        self.assertEqual(g(s), "sized")
1955        self.assertEqual(g(f), "sized")
1956        self.assertEqual(g(t), "sized")
1957        g.register(collections.ChainMap, lambda obj: "chainmap")
1958        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
1959        self.assertEqual(g(l), "sized")
1960        self.assertEqual(g(s), "sized")
1961        self.assertEqual(g(f), "sized")
1962        self.assertEqual(g(t), "sized")
1963        g.register(c.MutableSequence, lambda obj: "mutablesequence")
1964        self.assertEqual(g(d), "mutablemapping")
1965        self.assertEqual(g(l), "mutablesequence")
1966        self.assertEqual(g(s), "sized")
1967        self.assertEqual(g(f), "sized")
1968        self.assertEqual(g(t), "sized")
1969        g.register(c.MutableSet, lambda obj: "mutableset")
1970        self.assertEqual(g(d), "mutablemapping")
1971        self.assertEqual(g(l), "mutablesequence")
1972        self.assertEqual(g(s), "mutableset")
1973        self.assertEqual(g(f), "sized")
1974        self.assertEqual(g(t), "sized")
1975        g.register(c.Mapping, lambda obj: "mapping")
1976        self.assertEqual(g(d), "mutablemapping")  # not specific enough
1977        self.assertEqual(g(l), "mutablesequence")
1978        self.assertEqual(g(s), "mutableset")
1979        self.assertEqual(g(f), "sized")
1980        self.assertEqual(g(t), "sized")
1981        g.register(c.Sequence, lambda obj: "sequence")
1982        self.assertEqual(g(d), "mutablemapping")
1983        self.assertEqual(g(l), "mutablesequence")
1984        self.assertEqual(g(s), "mutableset")
1985        self.assertEqual(g(f), "sized")
1986        self.assertEqual(g(t), "sequence")
1987        g.register(c.Set, lambda obj: "set")
1988        self.assertEqual(g(d), "mutablemapping")
1989        self.assertEqual(g(l), "mutablesequence")
1990        self.assertEqual(g(s), "mutableset")
1991        self.assertEqual(g(f), "set")
1992        self.assertEqual(g(t), "sequence")
1993        g.register(dict, lambda obj: "dict")
1994        self.assertEqual(g(d), "dict")
1995        self.assertEqual(g(l), "mutablesequence")
1996        self.assertEqual(g(s), "mutableset")
1997        self.assertEqual(g(f), "set")
1998        self.assertEqual(g(t), "sequence")
1999        g.register(list, lambda obj: "list")
2000        self.assertEqual(g(d), "dict")
2001        self.assertEqual(g(l), "list")
2002        self.assertEqual(g(s), "mutableset")
2003        self.assertEqual(g(f), "set")
2004        self.assertEqual(g(t), "sequence")
2005        g.register(set, lambda obj: "concrete-set")
2006        self.assertEqual(g(d), "dict")
2007        self.assertEqual(g(l), "list")
2008        self.assertEqual(g(s), "concrete-set")
2009        self.assertEqual(g(f), "set")
2010        self.assertEqual(g(t), "sequence")
2011        g.register(frozenset, lambda obj: "frozen-set")
2012        self.assertEqual(g(d), "dict")
2013        self.assertEqual(g(l), "list")
2014        self.assertEqual(g(s), "concrete-set")
2015        self.assertEqual(g(f), "frozen-set")
2016        self.assertEqual(g(t), "sequence")
2017        g.register(tuple, lambda obj: "tuple")
2018        self.assertEqual(g(d), "dict")
2019        self.assertEqual(g(l), "list")
2020        self.assertEqual(g(s), "concrete-set")
2021        self.assertEqual(g(f), "frozen-set")
2022        self.assertEqual(g(t), "tuple")
2023
2024    def test_c3_abc(self):
2025        c = collections.abc
2026        mro = functools._c3_mro
2027        class A(object):
2028            pass
2029        class B(A):
2030            def __len__(self):
2031                return 0   # implies Sized
2032        @c.Container.register
2033        class C(object):
2034            pass
2035        class D(object):
2036            pass   # unrelated
2037        class X(D, C, B):
2038            def __call__(self):
2039                pass   # implies Callable
2040        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2041        for abcs in permutations([c.Sized, c.Callable, c.Container]):
2042            self.assertEqual(mro(X, abcs=abcs), expected)
2043        # unrelated ABCs don't appear in the resulting MRO
2044        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2045        self.assertEqual(mro(X, abcs=many_abcs), expected)
2046
2047    def test_false_meta(self):
2048        # see issue23572
2049        class MetaA(type):
2050            def __len__(self):
2051                return 0
2052        class A(metaclass=MetaA):
2053            pass
2054        class AA(A):
2055            pass
2056        @functools.singledispatch
2057        def fun(a):
2058            return 'base A'
2059        @fun.register(A)
2060        def _(a):
2061            return 'fun A'
2062        aa = AA()
2063        self.assertEqual(fun(aa), 'fun A')
2064
    def test_mro_conflicts(self):
        # Check how singledispatch chooses between ABC implementations that a
        # class gets as explicit bases, via ABC.register(), or by inference
        # from its methods, and when that choice is ambiguous (RuntimeError).
        # NOTE: the ABC.register() calls below modify the global
        # collections.abc registries (for classes local to this test).
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # P only gains ABCs through registration; two unrelated registered
        # ABCs with implementations make dispatch ambiguous.
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # Either order of the two ABC names in the message is accepted.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # An ABC registered directly on R takes precedence over one R only
        # gets through its base class.
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # A concrete base (S) competes with a registered ABC (Container).
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
2192
    def test_cache_invalidation(self):
        # Trace every read and write of singledispatch's per-type dispatch
        # cache.  The test checks when entries are added, reused, and dropped
        # after new registrations or ABC registry changes.
        from collections import UserDict
        # NOTE: the file also imports weakref at module level; this local
        # import binds the same module object that swap_attr patches below.
        import weakref

        class TracingDict(UserDict):
            # Records the keys of successful reads (get_ops) and of all
            # writes (set_ops).
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                # The lookup runs before the append, so a missing key raises
                # KeyError without being recorded.
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                # Clear the backing dict directly, presumably so clearing does
                # not go through the inherited MutableMapping.clear and add
                # entries to get_ops.
                self.data.clear()

        td = TracingDict()
        # While patched, weakref.WeakKeyDictionary() returns td, so the
        # dispatch cache of g (created when it is decorated below) is td.
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeated calls are served from the cache (reads, no writes).
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # A new registration empties the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # dispatch() reads the same cache as a call does.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            # The stale entries are only dropped on the next dispatch, which
            # then re-caches just the type being dispatched.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            g._clear_cache()
            self.assertEqual(len(td), 0)
2294
2295    def test_annotations(self):
2296        @functools.singledispatch
2297        def i(arg):
2298            return "base"
2299        @i.register
2300        def _(arg: collections.abc.Mapping):
2301            return "mapping"
2302        @i.register
2303        def _(arg: "collections.abc.Sequence"):
2304            return "sequence"
2305        self.assertEqual(i(None), "base")
2306        self.assertEqual(i({"a": 1}), "mapping")
2307        self.assertEqual(i([1, 2, 3]), "sequence")
2308        self.assertEqual(i((1, 2, 3)), "sequence")
2309        self.assertEqual(i("str"), "sequence")
2310
2311        # Registering classes as callables doesn't work with annotations,
2312        # you need to pass the type explicitly.
2313        @i.register(str)
2314        class _:
2315            def __init__(self, arg):
2316                self.arg = arg
2317
2318            def __eq__(self, other):
2319                return self.arg == other
2320        self.assertEqual(i("str"), "str")
2321
2322    def test_method_register(self):
2323        class A:
2324            @functools.singledispatchmethod
2325            def t(self, arg):
2326                self.arg = "base"
2327            @t.register(int)
2328            def _(self, arg):
2329                self.arg = "int"
2330            @t.register(str)
2331            def _(self, arg):
2332                self.arg = "str"
2333        a = A()
2334
2335        a.t(0)
2336        self.assertEqual(a.arg, "int")
2337        aa = A()
2338        self.assertFalse(hasattr(aa, 'arg'))
2339        a.t('')
2340        self.assertEqual(a.arg, "str")
2341        aa = A()
2342        self.assertFalse(hasattr(aa, 'arg'))
2343        a.t(0.0)
2344        self.assertEqual(a.arg, "base")
2345        aa = A()
2346        self.assertFalse(hasattr(aa, 'arg'))
2347
2348    def test_staticmethod_register(self):
2349        class A:
2350            @functools.singledispatchmethod
2351            @staticmethod
2352            def t(arg):
2353                return arg
2354            @t.register(int)
2355            @staticmethod
2356            def _(arg):
2357                return isinstance(arg, int)
2358            @t.register(str)
2359            @staticmethod
2360            def _(arg):
2361                return isinstance(arg, str)
2362        a = A()
2363
2364        self.assertTrue(A.t(0))
2365        self.assertTrue(A.t(''))
2366        self.assertEqual(A.t(0.0), 0.0)
2367
2368    def test_classmethod_register(self):
2369        class A:
2370            def __init__(self, arg):
2371                self.arg = arg
2372
2373            @functools.singledispatchmethod
2374            @classmethod
2375            def t(cls, arg):
2376                return cls("base")
2377            @t.register(int)
2378            @classmethod
2379            def _(cls, arg):
2380                return cls("int")
2381            @t.register(str)
2382            @classmethod
2383            def _(cls, arg):
2384                return cls("str")
2385
2386        self.assertEqual(A.t(0).arg, "int")
2387        self.assertEqual(A.t('').arg, "str")
2388        self.assertEqual(A.t(0.0).arg, "base")
2389
2390    def test_callable_register(self):
2391        class A:
2392            def __init__(self, arg):
2393                self.arg = arg
2394
2395            @functools.singledispatchmethod
2396            @classmethod
2397            def t(cls, arg):
2398                return cls("base")
2399
2400        @A.t.register(int)
2401        @classmethod
2402        def _(cls, arg):
2403            return cls("int")
2404        @A.t.register(str)
2405        @classmethod
2406        def _(cls, arg):
2407            return cls("str")
2408
2409        self.assertEqual(A.t(0).arg, "int")
2410        self.assertEqual(A.t('').arg, "str")
2411        self.assertEqual(A.t(0.0).arg, "base")
2412
2413    def test_abstractmethod_register(self):
2414        class Abstract(metaclass=abc.ABCMeta):
2415
2416            @functools.singledispatchmethod
2417            @abc.abstractmethod
2418            def add(self, x, y):
2419                pass
2420
2421        self.assertTrue(Abstract.add.__isabstractmethod__)
2422        self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)
2423
2424        with self.assertRaises(TypeError):
2425            Abstract()
2426
2427    def test_type_ann_register(self):
2428        class A:
2429            @functools.singledispatchmethod
2430            def t(self, arg):
2431                return "base"
2432            @t.register
2433            def _(self, arg: int):
2434                return "int"
2435            @t.register
2436            def _(self, arg: str):
2437                return "str"
2438        a = A()
2439
2440        self.assertEqual(a.t(0), "int")
2441        self.assertEqual(a.t(''), "str")
2442        self.assertEqual(a.t(0.0), "base")
2443
2444    def test_staticmethod_type_ann_register(self):
2445        class A:
2446            @functools.singledispatchmethod
2447            @staticmethod
2448            def t(arg):
2449                return arg
2450            @t.register
2451            @staticmethod
2452            def _(arg: int):
2453                return isinstance(arg, int)
2454            @t.register
2455            @staticmethod
2456            def _(arg: str):
2457                return isinstance(arg, str)
2458        a = A()
2459
2460        self.assertTrue(A.t(0))
2461        self.assertTrue(A.t(''))
2462        self.assertEqual(A.t(0.0), 0.0)
2463
2464    def test_classmethod_type_ann_register(self):
2465        class A:
2466            def __init__(self, arg):
2467                self.arg = arg
2468
2469            @functools.singledispatchmethod
2470            @classmethod
2471            def t(cls, arg):
2472                return cls("base")
2473            @t.register
2474            @classmethod
2475            def _(cls, arg: int):
2476                return cls("int")
2477            @t.register
2478            @classmethod
2479            def _(cls, arg: str):
2480                return cls("str")
2481
2482        self.assertEqual(A.t(0).arg, "int")
2483        self.assertEqual(A.t('').arg, "str")
2484        self.assertEqual(A.t(0.0).arg, "base")
2485
2486    def test_method_wrapping_attributes(self):
2487        class A:
2488            @functools.singledispatchmethod
2489            def func(self, arg: int) -> str:
2490                """My function docstring"""
2491                return str(arg)
2492            @functools.singledispatchmethod
2493            @classmethod
2494            def cls_func(cls, arg: int) -> str:
2495                """My function docstring"""
2496                return str(arg)
2497            @functools.singledispatchmethod
2498            @staticmethod
2499            def static_func(arg: int) -> str:
2500                """My function docstring"""
2501                return str(arg)
2502
2503        for meth in (
2504            A.func,
2505            A().func,
2506            A.cls_func,
2507            A().cls_func,
2508            A.static_func,
2509            A().static_func
2510        ):
2511            with self.subTest(meth=meth):
2512                self.assertEqual(meth.__doc__, 'My function docstring')
2513                self.assertEqual(meth.__annotations__['arg'], int)
2514
2515        self.assertEqual(A.func.__name__, 'func')
2516        self.assertEqual(A().func.__name__, 'func')
2517        self.assertEqual(A.cls_func.__name__, 'cls_func')
2518        self.assertEqual(A().cls_func.__name__, 'cls_func')
2519        self.assertEqual(A.static_func.__name__, 'static_func')
2520        self.assertEqual(A().static_func.__name__, 'static_func')
2521
2522    def test_double_wrapped_methods(self):
2523        def classmethod_friendly_decorator(func):
2524            wrapped = func.__func__
2525            @classmethod
2526            @functools.wraps(wrapped)
2527            def wrapper(*args, **kwargs):
2528                return wrapped(*args, **kwargs)
2529            return wrapper
2530
2531        class WithoutSingleDispatch:
2532            @classmethod
2533            @contextlib.contextmanager
2534            def cls_context_manager(cls, arg: int) -> str:
2535                try:
2536                    yield str(arg)
2537                finally:
2538                    return 'Done'
2539
2540            @classmethod_friendly_decorator
2541            @classmethod
2542            def decorated_classmethod(cls, arg: int) -> str:
2543                return str(arg)
2544
2545        class WithSingleDispatch:
2546            @functools.singledispatchmethod
2547            @classmethod
2548            @contextlib.contextmanager
2549            def cls_context_manager(cls, arg: int) -> str:
2550                """My function docstring"""
2551                try:
2552                    yield str(arg)
2553                finally:
2554                    return 'Done'
2555
2556            @functools.singledispatchmethod
2557            @classmethod_friendly_decorator
2558            @classmethod
2559            def decorated_classmethod(cls, arg: int) -> str:
2560                """My function docstring"""
2561                return str(arg)
2562
2563        # These are sanity checks
2564        # to test the test itself is working as expected
2565        with WithoutSingleDispatch.cls_context_manager(5) as foo:
2566            without_single_dispatch_foo = foo
2567
2568        with WithSingleDispatch.cls_context_manager(5) as foo:
2569            single_dispatch_foo = foo
2570
2571        self.assertEqual(without_single_dispatch_foo, single_dispatch_foo)
2572        self.assertEqual(single_dispatch_foo, '5')
2573
2574        self.assertEqual(
2575            WithoutSingleDispatch.decorated_classmethod(5),
2576            WithSingleDispatch.decorated_classmethod(5)
2577        )
2578
2579        self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5')
2580
2581        # Behavioural checks now follow
2582        for method_name in ('cls_context_manager', 'decorated_classmethod'):
2583            with self.subTest(method=method_name):
2584                self.assertEqual(
2585                    getattr(WithSingleDispatch, method_name).__name__,
2586                    getattr(WithoutSingleDispatch, method_name).__name__
2587                )
2588
2589                self.assertEqual(
2590                    getattr(WithSingleDispatch(), method_name).__name__,
2591                    getattr(WithoutSingleDispatch(), method_name).__name__
2592                )
2593
2594        for meth in (
2595            WithSingleDispatch.cls_context_manager,
2596            WithSingleDispatch().cls_context_manager,
2597            WithSingleDispatch.decorated_classmethod,
2598            WithSingleDispatch().decorated_classmethod
2599        ):
2600            with self.subTest(meth=meth):
2601                self.assertEqual(meth.__doc__, 'My function docstring')
2602                self.assertEqual(meth.__annotations__['arg'], int)
2603
2604        self.assertEqual(
2605            WithSingleDispatch.cls_context_manager.__name__,
2606            'cls_context_manager'
2607        )
2608        self.assertEqual(
2609            WithSingleDispatch().cls_context_manager.__name__,
2610            'cls_context_manager'
2611        )
2612        self.assertEqual(
2613            WithSingleDispatch.decorated_classmethod.__name__,
2614            'decorated_classmethod'
2615        )
2616        self.assertEqual(
2617            WithSingleDispatch().decorated_classmethod.__name__,
2618            'decorated_classmethod'
2619        )
2620
2621    def test_invalid_registrations(self):
2622        msg_prefix = "Invalid first argument to `register()`: "
2623        msg_suffix = (
2624            ". Use either `@register(some_class)` or plain `@register` on an "
2625            "annotated function."
2626        )
2627        @functools.singledispatch
2628        def i(arg):
2629            return "base"
2630        with self.assertRaises(TypeError) as exc:
2631            @i.register(42)
2632            def _(arg):
2633                return "I annotated with a non-type"
2634        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2635        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2636        with self.assertRaises(TypeError) as exc:
2637            @i.register
2638            def _(arg):
2639                return "I forgot to annotate"
2640        self.assertTrue(str(exc.exception).startswith(msg_prefix +
2641            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2642        ))
2643        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2644
2645        with self.assertRaises(TypeError) as exc:
2646            @i.register
2647            def _(arg: typing.Iterable[str]):
2648                # At runtime, dispatching on generics is impossible.
2649                # When registering implementations with singledispatch, avoid
2650                # types from `typing`. Instead, annotate with regular types
2651                # or ABCs.
2652                return "I annotated with a generic collection"
2653        self.assertTrue(str(exc.exception).startswith(
2654            "Invalid annotation for 'arg'."
2655        ))
2656        self.assertTrue(str(exc.exception).endswith(
2657            'typing.Iterable[str] is not a class.'
2658        ))
2659
2660    def test_invalid_positional_argument(self):
2661        @functools.singledispatch
2662        def f(*args):
2663            pass
2664        msg = 'f requires at least 1 positional argument'
2665        with self.assertRaisesRegex(TypeError, msg):
2666            f()
2667
2668
class CachedCostItem:
    """Fixture whose ``cost`` increments a counter once, then is cached."""
    # Class-level starting value; the first `self._cost += 1` creates an
    # instance attribute, leaving this default untouched for other instances.
    _cost = 1

    def __init__(self):
        # Use the public threading.RLock rather than the RLock name that
        # py_functools happens to import for its own internal use.
        self.lock = threading.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost += 1
        return self._cost
2681
2682
class OptionallyCachedCostItem:
    # get_cost() recomputes on every call, while cached_cost wraps the very
    # same function in a cached_property stored under a different name.
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        bumped = self._cost + 1
        self._cost = bumped
        return bumped

    cached_cost = py_functools.cached_property(get_cost)
2692
2693
class CachedCostItemWait:
    """Fixture whose cached ``cost`` first waits on an event.

    Used by the threaded test: several threads block inside the property
    getter until *event* is set (or one second elapses).
    """

    def __init__(self, event):
        self._cost = 1
        # Public threading.RLock instead of py_functools' incidental import.
        self.lock = threading.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        # Bounded wait so a never-set event cannot hang the test.
        self.event.wait(1)
        with self.lock:
            self._cost += 1
        return self._cost
2707
2708
class CachedCostItemWithSlots:
    """Fixture without a per-instance ``__dict__``, so cached_property fails."""
    # A one-element tuple: the original `('_cost')` was just a parenthesised
    # string, which only worked because __slots__ also accepts a bare str.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2718
2719
class TestCachedProperty(unittest.TestCase):
    """Tests for ``cached_property`` as provided by ``py_functools``."""

    def test_cached(self):
        item = CachedCostItem()
        self.assertEqual(item.cost, 2)
        self.assertEqual(item.cost, 2) # not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # get_cost() keeps incrementing; cached_cost freezes at its first value.
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        # Shrink the thread switch interval so the threads interleave as
        # much as possible while racing on the property; restore it after.
        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            threads = [
                threading.Thread(target=lambda: item.cost)
                for _ in range(num_threads)
            ]
            with threading_helper.start_threads(threads):
                go.set()
        finally:
            sys.setswitchinterval(orig_si)

        # The getter must have run exactly once despite the concurrent access.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        with self.assertRaisesRegex(
                TypeError,
                "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.",
        ):
            item.cost

    def test_immutable_dict(self):
        # A class's __dict__ is a read-only mappingproxy, so a
        # cached_property defined on a metaclass cannot store its value.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        with self.assertRaisesRegex(
            TypeError,
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.",
        ):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        # The TypeError from __set_name__ surfaces wrapped in a RuntimeError,
        # so the original error is inspected through __context__.
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        self.assertEqual(
            str(ctx.exception.__context__),
            str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b')."))
        )

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a = A()
        b = B()

        # Each instance caches independently; a's value is not recomputed.
        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)
        class Foo:
            pass

        # Assigning after class creation bypasses __set_name__.
        Foo.cp = cp

        with self.assertRaisesRegex(
                TypeError,
                "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property)

    def test_doc(self):
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2832
2833
if __name__ == '__main__':
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
2836