1import abc
2import builtins
3import collections
4import collections.abc
5import copy
6from itertools import permutations
7import pickle
8from random import choice
9import sys
10from test import support
11import threading
12import time
13import typing
14import unittest
15import unittest.mock
16import os
17import weakref
18import gc
19from weakref import proxy
20import contextlib
21
22from test.support import import_helper
23from test.support import threading_helper
24from test.support.script_helper import assert_python_ok
25
26import functools
27
# Pure-Python functools: the _functools accelerator is blocked, so the module
# has to fall back to its Python-level definitions.
py_functools = import_helper.import_fresh_module('functools',
                                                 blocked=['_functools'])
# Freshly imported functools without blocking anything; presumably this picks
# up the _functools C accelerator when it is available.  The C-specific test
# classes are additionally guarded by skipUnless(c_functools, ...).
c_functools = import_helper.import_fresh_module('functools')

# NOTE(review): `decimal` is not referenced in this part of the file;
# presumably used by tests further down -- confirm before removing.
decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
33
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    On exit the previous entry is restored, even if the body raised.  If
    *name* was not present in ``sys.modules`` on entry, the temporary entry
    is removed again on exit rather than failing with ``KeyError``.
    """
    # A private sentinel keeps "absent" distinct from a stored None value
    # (sys.modules may legitimately map a name to None).
    missing = object()
    original_module = sys.modules.get(name, missing)
    sys.modules[name] = replacement
    try:
        yield
    finally:
        if original_module is missing:
            # Nothing to restore: drop the temporary entry (tolerating the
            # body having already removed it).
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = original_module
42
def capture(*args, **kw):
    """Echo a call back: return its positional tuple and keyword dict."""
    captured = (args, kw)
    return captured
46
47
def signature(part):
    """Summarize a partial object as (func, args, keywords, __dict__)."""
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, vars(part))
51
class MyTuple(tuple):
    """Trivial tuple subclass with no overrides."""
54
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by returning a list."""

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
58
class MyDict(dict):
    """Trivial dict subclass with no overrides."""
61
62
class TestPartial:
    """Behavior shared by every ``partial`` implementation under test.

    This is a mixin: concrete subclasses combine it with
    ``unittest.TestCase`` and supply two class attributes:

    * ``partial`` -- the partial type being tested;
    * ``AllowPickle`` -- a context-manager class entered around the
      pickling tests.
    """

    def _partial_repr_name(self):
        """Return the type name that repr() of a partial should start with."""
        # Both stock implementations report themselves as 'functools.partial';
        # any other type (a subclass) is expected to use its own __name__.
        if self.partial in (c_functools.partial, py_functools.partial):
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a':a,'x':None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        self.assertEqual(p(1, b=2), ((0,1), {'a':1,'b':2}))
        self.assertEqual(p(), ((0,), {'a':1}))

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        support.gc_collect()  # For PyPy or other GCs.
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial should be flattened into a single partial
        # over the innermost function.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._partial_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._partial_repr_name()

        # Each case makes the partial reference itself, then resets the
        # state in `finally` so the reference cycle does not outlive the test.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # None for keywords (and for the namespace) means "empty".
        f.__setstate__((capture, (1,), None, None))
        self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are normalized to exact types.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # A tuple subclass with a broken __add__ must not leak into calls.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests, plus C-specific ones, on the C type."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: the C partial is pickled without any setup."""
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # ... and the instance __dict__ must not be deletable.
        p = self.partial(hex)
        with self.assertRaises(
                TypeError, msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        # repr() still works with the bogus key; calling the partial does not.
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
438
439
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Swap ``sys.modules['functools']`` for the pure-Python module.

        NOTE(review): presumably required so that pickle's by-name lookup of
        ``functools.partial`` resolves to the pure-Python class being pickled.
        """

        def __enter__(self):
            self._swap = replaced_module("functools", py_functools)
            return self._swap.__enter__()

        def __exit__(self, type, value, tb):
            return self._swap.__exit__(type, value, tb)
450
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Plain subclass of the C ``partial`` type, with no overrides."""
454
class PyPartialSubclass(py_functools.partial):
    """Plain subclass of the pure-Python ``partial``, with no overrides."""
457
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Repeat the C partial tests using a trivial subclass of the type."""

    if c_functools:
        partial = CPartialSubclass

    # Subclasses do not get the nested-partial flattening, so the inherited
    # test is disabled; a non-callable attribute is not collected as a test.
    test_nested_optimization = None
465
class TestPartialPySubclass(TestPartialPy):
    """Repeat the pure-Python partial tests using a trivial subclass."""

    partial = PyPartialSubclass
468
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' as keywords must not clash with partialmethod's
        # own parameters.
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # ABCMeta must be the *metaclass* here: inheriting from it would
        # define a new metaclass instead of an abstract class.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # partialmethod propagates abstractness, so ABCMeta records both.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        # NOTE(review): this exercises functools.partial, not partialmethod.
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
597
598
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper (reused by TestWraps)."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the metadata copied from *wrapped*.

        Every attribute named in *assigned* must be the very same object on
        both; every key of each mapping named in *updated* must have been
        merged into the wrapper's mapping; ``__wrapped__`` must point back
        at *wrapped*.
        """
        for attr_name in assigned:
            actual = getattr(wrapper, attr_name)
            self.assertIs(actual, getattr(wrapped, attr_name))
        for attr_name in updated:
            merged = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            is_instance_dict = attr_name == "__dict__"
            for key in source:
                # update_wrapper overwrites __wrapped__ after merging __dict__,
                # so that single key is expected to differ.
                if is_instance_dict and key == "__wrapped__":
                    continue
                self.assertIs(source[key], merged[key])
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        """Build (wrapper, f) using update_wrapper with default arguments."""
        def f(a: 'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b: 'This is the prior annotation'):
            pass

        return functools.update_wrapper(wrapper, f), f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertIs(wrapper.__wrapped__, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        def wrapper():
            pass

        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        # Attributes missing on the wrapped object are silently skipped.
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # The wrapper itself must provide an updatable mapping for each
        # name in `update`: first remove it, then make it a non-mapping.
        del wrapper.dict_attr
        self.assertRaises(AttributeError,
                          functools.update_wrapper, wrapper, f, assign, update)
        wrapper.dict_attr = 1
        self.assertRaises(AttributeError,
                          functools.update_wrapper, wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Regression test for bug #1576241: wrapping a builtin function.
        def wrapper():
            pass

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
711
712
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        """Build (wrapper, f) using @functools.wraps with default arguments."""
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # Give the wrapper a mapping for wraps() to merge into.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
773
774
775class TestReduce:
776    def test_reduce(self):
777        class Squares:
778            def __init__(self, max):
779                self.max = max
780                self.sofar = []
781
782            def __len__(self):
783                return len(self.sofar)
784
785            def __getitem__(self, i):
786                if not 0 <= i < self.max: raise IndexError
787                n = len(self.sofar)
788                while n <= i:
789                    self.sofar.append(n*n)
790                    n += 1
791                return self.sofar[i]
792        def add(x, y):
793            return x + y
794        self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
795        self.assertEqual(
796            self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
797            ['a','c','d','w']
798        )
799        self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
800        self.assertEqual(
801            self.reduce(lambda x, y: x*y, range(2,21), 1),
802            2432902008176640000
803        )
804        self.assertEqual(self.reduce(add, Squares(10)), 285)
805        self.assertEqual(self.reduce(add, Squares(10), 0), 285)
806        self.assertEqual(self.reduce(add, Squares(0), 0), 0)
807        self.assertRaises(TypeError, self.reduce)
808        self.assertRaises(TypeError, self.reduce, 42, 42)
809        self.assertRaises(TypeError, self.reduce, 42, 42, 42)
810        self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
811        self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
812        self.assertRaises(TypeError, self.reduce, 42, (42, 42))
813        self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
814        self.assertRaises(TypeError, self.reduce, add, "")
815        self.assertRaises(TypeError, self.reduce, add, ())
816        self.assertRaises(TypeError, self.reduce, add, object())
817
818        class TestFailingIter:
819            def __iter__(self):
820                raise RuntimeError
821        self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
822
823        self.assertEqual(self.reduce(add, [], None), None)
824        self.assertEqual(self.reduce(add, [], 42), 42)
825
826        class BadSeq:
827            def __getitem__(self, index):
828                raise ValueError
829        self.assertRaises(ValueError, self.reduce, 42, BadSeq())
830
831    # Test reduce()'s use of iterators.
832    def test_iterator_usage(self):
833        class SequenceClass:
834            def __init__(self, n):
835                self.n = n
836            def __getitem__(self, i):
837                if 0 <= i < self.n:
838                    return i
839                else:
840                    raise IndexError
841
842        from operator import add
843        self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
844        self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
845        self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
846        self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
847        self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
848        self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
849
850        d = {"one": 1, "two": 2, "three": 3}
851        self.assertEqual(self.reduce(add, d), "".join(d.keys()))
852
853
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared TestReduce suite against the C reduce()."""

    # Guarded because c_functools is falsy when the C accelerator is not
    # available; skipUnless above then skips the whole class.  No
    # staticmethod() wrapper is used, presumably because the C function does
    # not bind as a method when looked up through an instance.
    if c_functools:
        reduce = c_functools.reduce
858
859
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared TestReduce suite against the pure-Python reduce()."""

    # staticmethod() keeps the plain Python function from being bound to the
    # test instance when it is accessed as ``self.reduce``.
    reduce = staticmethod(py_functools.reduce)
862
863
864class TestCmpToKey:
865
866    def test_cmp_to_key(self):
867        def cmp1(x, y):
868            return (x > y) - (x < y)
869        key = self.cmp_to_key(cmp1)
870        self.assertEqual(key(3), key(3))
871        self.assertGreater(key(3), key(1))
872        self.assertGreaterEqual(key(3), key(3))
873
874        def cmp2(x, y):
875            return int(x) - int(y)
876        key = self.cmp_to_key(cmp2)
877        self.assertEqual(key(4.0), key('4'))
878        self.assertLess(key(2), key('35'))
879        self.assertLessEqual(key(2), key('35'))
880        self.assertNotEqual(key(2), key('35'))
881
882    def test_cmp_to_key_arguments(self):
883        def cmp1(x, y):
884            return (x > y) - (x < y)
885        key = self.cmp_to_key(mycmp=cmp1)
886        self.assertEqual(key(obj=3), key(obj=3))
887        self.assertGreater(key(obj=3), key(obj=1))
888        with self.assertRaises((TypeError, AttributeError)):
889            key(3) > 1    # rhs is not a K object
890        with self.assertRaises((TypeError, AttributeError)):
891            1 < key(3)    # lhs is not a K object
892        with self.assertRaises(TypeError):
893            key = self.cmp_to_key()             # too few args
894        with self.assertRaises(TypeError):
895            key = self.cmp_to_key(cmp1, None)   # too many args
896        key = self.cmp_to_key(cmp1)
897        with self.assertRaises(TypeError):
898            key()                                    # too few args
899        with self.assertRaises(TypeError):
900            key(None, None)                          # too many args
901
902    def test_bad_cmp(self):
903        def cmp1(x, y):
904            raise ZeroDivisionError
905        key = self.cmp_to_key(cmp1)
906        with self.assertRaises(ZeroDivisionError):
907            key(3) > key(1)
908
909        class BadCmp:
910            def __lt__(self, other):
911                raise ZeroDivisionError
912        def cmp1(x, y):
913            return BadCmp()
914        with self.assertRaises(ZeroDivisionError):
915            key(3) > key(1)
916
917    def test_obj_field(self):
918        def cmp1(x, y):
919            return (x > y) - (x < y)
920        key = self.cmp_to_key(mycmp=cmp1)
921        self.assertEqual(key(50).obj, 50)
922
923    def test_sort_int(self):
924        def mycmp(x, y):
925            return y - x
926        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
927                         [4, 3, 2, 1, 0])
928
929    def test_sort_int_str(self):
930        def mycmp(x, y):
931            x, y = int(x), int(y)
932            return (x > y) - (x < y)
933        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
934        values = sorted(values, key=self.cmp_to_key(mycmp))
935        self.assertEqual([int(value) for value in values],
936                         [0, 1, 1, 2, 3, 4, 5, 7, 10])
937
938    def test_hash(self):
939        def mycmp(x, y):
940            return y - x
941        key = self.cmp_to_key(mycmp)
942        k = key(10)
943        self.assertRaises(TypeError, hash, k)
944        self.assertNotIsInstance(k, collections.abc.Hashable)
945
946
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared TestCmpToKey suite against the C cmp_to_key()."""

    # Guarded because c_functools is falsy when the C accelerator is not
    # available; skipUnless above then skips the whole class.  No
    # staticmethod() wrapper is used, presumably because the C function does
    # not bind as a method when looked up through an instance.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key

    @support.cpython_only
    def test_disallow_instantiation(self):
        # Ensure that the type disallows instantiation (bpo-43916).
        # type(cmp_to_key(None)) is the key-wrapper type returned by the C
        # cmp_to_key(); it must not be constructible directly.
        support.check_disallow_instantiation(
            self, type(c_functools.cmp_to_key(None))
        )
958
959
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Shared TestCmpToKey tests, run on the pure-Python cmp_to_key()."""

    # staticmethod() keeps the plain Python function from being bound to the
    # test instance when it is accessed as ``self.cmp_to_key``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
962
963
964class TestTotalOrdering(unittest.TestCase):
965
966    def test_total_ordering_lt(self):
967        @functools.total_ordering
968        class A:
969            def __init__(self, value):
970                self.value = value
971            def __lt__(self, other):
972                return self.value < other.value
973            def __eq__(self, other):
974                return self.value == other.value
975        self.assertTrue(A(1) < A(2))
976        self.assertTrue(A(2) > A(1))
977        self.assertTrue(A(1) <= A(2))
978        self.assertTrue(A(2) >= A(1))
979        self.assertTrue(A(2) <= A(2))
980        self.assertTrue(A(2) >= A(2))
981        self.assertFalse(A(1) > A(2))
982
983    def test_total_ordering_le(self):
984        @functools.total_ordering
985        class A:
986            def __init__(self, value):
987                self.value = value
988            def __le__(self, other):
989                return self.value <= other.value
990            def __eq__(self, other):
991                return self.value == other.value
992        self.assertTrue(A(1) < A(2))
993        self.assertTrue(A(2) > A(1))
994        self.assertTrue(A(1) <= A(2))
995        self.assertTrue(A(2) >= A(1))
996        self.assertTrue(A(2) <= A(2))
997        self.assertTrue(A(2) >= A(2))
998        self.assertFalse(A(1) >= A(2))
999
1000    def test_total_ordering_gt(self):
1001        @functools.total_ordering
1002        class A:
1003            def __init__(self, value):
1004                self.value = value
1005            def __gt__(self, other):
1006                return self.value > other.value
1007            def __eq__(self, other):
1008                return self.value == other.value
1009        self.assertTrue(A(1) < A(2))
1010        self.assertTrue(A(2) > A(1))
1011        self.assertTrue(A(1) <= A(2))
1012        self.assertTrue(A(2) >= A(1))
1013        self.assertTrue(A(2) <= A(2))
1014        self.assertTrue(A(2) >= A(2))
1015        self.assertFalse(A(2) < A(1))
1016
1017    def test_total_ordering_ge(self):
1018        @functools.total_ordering
1019        class A:
1020            def __init__(self, value):
1021                self.value = value
1022            def __ge__(self, other):
1023                return self.value >= other.value
1024            def __eq__(self, other):
1025                return self.value == other.value
1026        self.assertTrue(A(1) < A(2))
1027        self.assertTrue(A(2) > A(1))
1028        self.assertTrue(A(1) <= A(2))
1029        self.assertTrue(A(2) >= A(1))
1030        self.assertTrue(A(2) <= A(2))
1031        self.assertTrue(A(2) >= A(2))
1032        self.assertFalse(A(2) <= A(1))
1033
1034    def test_total_ordering_no_overwrite(self):
1035        # new methods should not overwrite existing
1036        @functools.total_ordering
1037        class A(int):
1038            pass
1039        self.assertTrue(A(1) < A(2))
1040        self.assertTrue(A(2) > A(1))
1041        self.assertTrue(A(1) <= A(2))
1042        self.assertTrue(A(2) >= A(1))
1043        self.assertTrue(A(2) <= A(2))
1044        self.assertTrue(A(2) >= A(2))
1045
1046    def test_no_operations_defined(self):
1047        with self.assertRaises(ValueError):
1048            @functools.total_ordering
1049            class A:
1050                pass
1051
1052    def test_type_error_when_not_implemented(self):
1053        # bug 10042; ensure stack overflow does not occur
1054        # when decorated types return NotImplemented
1055        @functools.total_ordering
1056        class ImplementsLessThan:
1057            def __init__(self, value):
1058                self.value = value
1059            def __eq__(self, other):
1060                if isinstance(other, ImplementsLessThan):
1061                    return self.value == other.value
1062                return False
1063            def __lt__(self, other):
1064                if isinstance(other, ImplementsLessThan):
1065                    return self.value < other.value
1066                return NotImplemented
1067
1068        @functools.total_ordering
1069        class ImplementsGreaterThan:
1070            def __init__(self, value):
1071                self.value = value
1072            def __eq__(self, other):
1073                if isinstance(other, ImplementsGreaterThan):
1074                    return self.value == other.value
1075                return False
1076            def __gt__(self, other):
1077                if isinstance(other, ImplementsGreaterThan):
1078                    return self.value > other.value
1079                return NotImplemented
1080
1081        @functools.total_ordering
1082        class ImplementsLessThanEqualTo:
1083            def __init__(self, value):
1084                self.value = value
1085            def __eq__(self, other):
1086                if isinstance(other, ImplementsLessThanEqualTo):
1087                    return self.value == other.value
1088                return False
1089            def __le__(self, other):
1090                if isinstance(other, ImplementsLessThanEqualTo):
1091                    return self.value <= other.value
1092                return NotImplemented
1093
1094        @functools.total_ordering
1095        class ImplementsGreaterThanEqualTo:
1096            def __init__(self, value):
1097                self.value = value
1098            def __eq__(self, other):
1099                if isinstance(other, ImplementsGreaterThanEqualTo):
1100                    return self.value == other.value
1101                return False
1102            def __ge__(self, other):
1103                if isinstance(other, ImplementsGreaterThanEqualTo):
1104                    return self.value >= other.value
1105                return NotImplemented
1106
1107        @functools.total_ordering
1108        class ComparatorNotImplemented:
1109            def __init__(self, value):
1110                self.value = value
1111            def __eq__(self, other):
1112                if isinstance(other, ComparatorNotImplemented):
1113                    return self.value == other.value
1114                return False
1115            def __lt__(self, other):
1116                return NotImplemented
1117
1118        with self.subTest("LT < 1"), self.assertRaises(TypeError):
1119            ImplementsLessThan(-1) < 1
1120
1121        with self.subTest("LT < LE"), self.assertRaises(TypeError):
1122            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1123
1124        with self.subTest("LT < GT"), self.assertRaises(TypeError):
1125            ImplementsLessThan(1) < ImplementsGreaterThan(1)
1126
1127        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1128            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1129
1130        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1131            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1132
1133        with self.subTest("GT > GE"), self.assertRaises(TypeError):
1134            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1135
1136        with self.subTest("GT > LT"), self.assertRaises(TypeError):
1137            ImplementsGreaterThan(5) > ImplementsLessThan(5)
1138
1139        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1140            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1141
1142        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1143            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1144
1145        with self.subTest("GE when equal"):
1146            a = ComparatorNotImplemented(8)
1147            b = ComparatorNotImplemented(8)
1148            self.assertEqual(a, b)
1149            with self.assertRaises(TypeError):
1150                a >= b
1151
1152        with self.subTest("LE when equal"):
1153            a = ComparatorNotImplemented(9)
1154            b = ComparatorNotImplemented(9)
1155            self.assertEqual(a, b)
1156            with self.assertRaises(TypeError):
1157                a <= b
1158
1159    def test_pickle(self):
1160        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1161            for name in '__lt__', '__gt__', '__le__', '__ge__':
1162                with self.subTest(method=name, proto=proto):
1163                    method = getattr(Orderable_LT, name)
1164                    method_copy = pickle.loads(pickle.dumps(method, proto))
1165                    self.assertIs(method_copy, method)
1166
1167
1168    def test_total_ordering_for_metaclasses_issue_44605(self):
1169
1170        @functools.total_ordering
1171        class SortableMeta(type):
1172            def __new__(cls, name, bases, ns):
1173                return super().__new__(cls, name, bases, ns)
1174
1175            def __lt__(self, other):
1176                if not isinstance(other, SortableMeta):
1177                    pass
1178                return self.__name__ < other.__name__
1179
1180            def __eq__(self, other):
1181                if not isinstance(other, SortableMeta):
1182                    pass
1183                return self.__name__ == other.__name__
1184
1185        class B(metaclass=SortableMeta):
1186            pass
1187
1188        class A(metaclass=SortableMeta):
1189            pass
1190
1191        self.assertTrue(A < B)
1192        self.assertFalse(A > B)
1193
1194
@functools.total_ordering
class Orderable_LT:
    """Value holder ordered by ``value``, defining only __lt__ and __eq__.

    The remaining rich comparisons are generated by total_ordering.  It is
    used by TestTotalOrdering.test_pickle and lives at module level,
    presumably so its methods can be pickled by reference.
    """

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        mine, theirs = self.value, other.value
        return mine < theirs

    def __eq__(self, other):
        mine, theirs = self.value, other.value
        return mine == theirs
1203
1204
1205class TestCache:
1206    # This tests that the pass-through is working as designed.
1207    # The underlying functionality is tested in TestLRU.
1208
1209    def test_cache(self):
1210        @self.module.cache
1211        def fib(n):
1212            if n < 2:
1213                return n
1214            return fib(n-1) + fib(n-2)
1215        self.assertEqual([fib(n) for n in range(16)],
1216            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1217        self.assertEqual(fib.cache_info(),
1218            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1219        fib.cache_clear()
1220        self.assertEqual(fib.cache_info(),
1221            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1222
1223
1224class TestLRU:
1225
1226    def test_lru(self):
1227        def orig(x, y):
1228            return 3 * x + y
1229        f = self.module.lru_cache(maxsize=20)(orig)
1230        hits, misses, maxsize, currsize = f.cache_info()
1231        self.assertEqual(maxsize, 20)
1232        self.assertEqual(currsize, 0)
1233        self.assertEqual(hits, 0)
1234        self.assertEqual(misses, 0)
1235
1236        domain = range(5)
1237        for i in range(1000):
1238            x, y = choice(domain), choice(domain)
1239            actual = f(x, y)
1240            expected = orig(x, y)
1241            self.assertEqual(actual, expected)
1242        hits, misses, maxsize, currsize = f.cache_info()
1243        self.assertTrue(hits > misses)
1244        self.assertEqual(hits + misses, 1000)
1245        self.assertEqual(currsize, 20)
1246
1247        f.cache_clear()   # test clearing
1248        hits, misses, maxsize, currsize = f.cache_info()
1249        self.assertEqual(hits, 0)
1250        self.assertEqual(misses, 0)
1251        self.assertEqual(currsize, 0)
1252        f(x, y)
1253        hits, misses, maxsize, currsize = f.cache_info()
1254        self.assertEqual(hits, 0)
1255        self.assertEqual(misses, 1)
1256        self.assertEqual(currsize, 1)
1257
1258        # Test bypassing the cache
1259        self.assertIs(f.__wrapped__, orig)
1260        f.__wrapped__(x, y)
1261        hits, misses, maxsize, currsize = f.cache_info()
1262        self.assertEqual(hits, 0)
1263        self.assertEqual(misses, 1)
1264        self.assertEqual(currsize, 1)
1265
1266        # test size zero (which means "never-cache")
1267        @self.module.lru_cache(0)
1268        def f():
1269            nonlocal f_cnt
1270            f_cnt += 1
1271            return 20
1272        self.assertEqual(f.cache_info().maxsize, 0)
1273        f_cnt = 0
1274        for i in range(5):
1275            self.assertEqual(f(), 20)
1276        self.assertEqual(f_cnt, 5)
1277        hits, misses, maxsize, currsize = f.cache_info()
1278        self.assertEqual(hits, 0)
1279        self.assertEqual(misses, 5)
1280        self.assertEqual(currsize, 0)
1281
1282        # test size one
1283        @self.module.lru_cache(1)
1284        def f():
1285            nonlocal f_cnt
1286            f_cnt += 1
1287            return 20
1288        self.assertEqual(f.cache_info().maxsize, 1)
1289        f_cnt = 0
1290        for i in range(5):
1291            self.assertEqual(f(), 20)
1292        self.assertEqual(f_cnt, 1)
1293        hits, misses, maxsize, currsize = f.cache_info()
1294        self.assertEqual(hits, 4)
1295        self.assertEqual(misses, 1)
1296        self.assertEqual(currsize, 1)
1297
1298        # test size two
1299        @self.module.lru_cache(2)
1300        def f(x):
1301            nonlocal f_cnt
1302            f_cnt += 1
1303            return x*10
1304        self.assertEqual(f.cache_info().maxsize, 2)
1305        f_cnt = 0
1306        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1307            #    *  *              *                          *
1308            self.assertEqual(f(x), x*10)
1309        self.assertEqual(f_cnt, 4)
1310        hits, misses, maxsize, currsize = f.cache_info()
1311        self.assertEqual(hits, 12)
1312        self.assertEqual(misses, 4)
1313        self.assertEqual(currsize, 2)
1314
1315    def test_lru_no_args(self):
1316        @self.module.lru_cache
1317        def square(x):
1318            return x ** 2
1319
1320        self.assertEqual(list(map(square, [10, 20, 10])),
1321                         [100, 400, 100])
1322        self.assertEqual(square.cache_info().hits, 1)
1323        self.assertEqual(square.cache_info().misses, 2)
1324        self.assertEqual(square.cache_info().maxsize, 128)
1325        self.assertEqual(square.cache_info().currsize, 2)
1326
1327    def test_lru_bug_35780(self):
1328        # C version of the lru_cache was not checking to see if
1329        # the user function call has already modified the cache
1330        # (this arises in recursive calls and in multi-threading).
1331        # This cause the cache to have orphan links not referenced
1332        # by the cache dictionary.
1333
1334        once = True                 # Modified by f(x) below
1335
1336        @self.module.lru_cache(maxsize=10)
1337        def f(x):
1338            nonlocal once
1339            rv = f'.{x}.'
1340            if x == 20 and once:
1341                once = False
1342                rv = f(x)
1343            return rv
1344
1345        # Fill the cache
1346        for x in range(15):
1347            self.assertEqual(f(x), f'.{x}.')
1348        self.assertEqual(f.cache_info().currsize, 10)
1349
1350        # Make a recursive call and make sure the cache remains full
1351        self.assertEqual(f(20), '.20.')
1352        self.assertEqual(f.cache_info().currsize, 10)
1353
1354    def test_lru_bug_36650(self):
1355        # C version of lru_cache was treating a call with an empty **kwargs
1356        # dictionary as being distinct from a call with no keywords at all.
1357        # This did not result in an incorrect answer, but it did trigger
1358        # an unexpected cache miss.
1359
1360        @self.module.lru_cache()
1361        def f(x):
1362            pass
1363
1364        f(0)
1365        f(0, **{})
1366        self.assertEqual(f.cache_info().hits, 1)
1367
1368    def test_lru_hash_only_once(self):
1369        # To protect against weird reentrancy bugs and to improve
1370        # efficiency when faced with slow __hash__ methods, the
1371        # LRU cache guarantees that it will only call __hash__
1372        # only once per use as an argument to the cached function.
1373
1374        @self.module.lru_cache(maxsize=1)
1375        def f(x, y):
1376            return x * 3 + y
1377
1378        # Simulate the integer 5
1379        mock_int = unittest.mock.Mock()
1380        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1381        mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1382
1383        # Add to cache:  One use as an argument gives one call
1384        self.assertEqual(f(mock_int, 1), 16)
1385        self.assertEqual(mock_int.__hash__.call_count, 1)
1386        self.assertEqual(f.cache_info(), (0, 1, 1, 1))
1387
1388        # Cache hit: One use as an argument gives one additional call
1389        self.assertEqual(f(mock_int, 1), 16)
1390        self.assertEqual(mock_int.__hash__.call_count, 2)
1391        self.assertEqual(f.cache_info(), (1, 1, 1, 1))
1392
1393        # Cache eviction: No use as an argument gives no additional call
1394        self.assertEqual(f(6, 2), 20)
1395        self.assertEqual(mock_int.__hash__.call_count, 2)
1396        self.assertEqual(f.cache_info(), (1, 2, 1, 1))
1397
1398        # Cache miss: One use as an argument gives one additional call
1399        self.assertEqual(f(mock_int, 1), 16)
1400        self.assertEqual(mock_int.__hash__.call_count, 3)
1401        self.assertEqual(f.cache_info(), (1, 3, 1, 1))
1402
1403    def test_lru_reentrancy_with_len(self):
1404        # Test to make sure the LRU cache code isn't thrown-off by
1405        # caching the built-in len() function.  Since len() can be
1406        # cached, we shouldn't use it inside the lru code itself.
1407        old_len = builtins.len
1408        try:
1409            builtins.len = self.module.lru_cache(4)(len)
1410            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1411                self.assertEqual(len('abcdefghijklmn'[:i]), i)
1412        finally:
1413            builtins.len = old_len
1414
1415    def test_lru_star_arg_handling(self):
1416        # Test regression that arose in ea064ff3c10f
1417        @functools.lru_cache()
1418        def f(*args):
1419            return args
1420
1421        self.assertEqual(f(1, 2), (1, 2))
1422        self.assertEqual(f((1, 2)), ((1, 2),))
1423
1424    def test_lru_type_error(self):
1425        # Regression test for issue #28653.
1426        # lru_cache was leaking when one of the arguments
1427        # wasn't cacheable.
1428
1429        @functools.lru_cache(maxsize=None)
1430        def infinite_cache(o):
1431            pass
1432
1433        @functools.lru_cache(maxsize=10)
1434        def limited_cache(o):
1435            pass
1436
1437        with self.assertRaises(TypeError):
1438            infinite_cache([])
1439
1440        with self.assertRaises(TypeError):
1441            limited_cache([])
1442
1443    def test_lru_with_maxsize_none(self):
1444        @self.module.lru_cache(maxsize=None)
1445        def fib(n):
1446            if n < 2:
1447                return n
1448            return fib(n-1) + fib(n-2)
1449        self.assertEqual([fib(n) for n in range(16)],
1450            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1451        self.assertEqual(fib.cache_info(),
1452            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1453        fib.cache_clear()
1454        self.assertEqual(fib.cache_info(),
1455            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1456
1457    def test_lru_with_maxsize_negative(self):
1458        @self.module.lru_cache(maxsize=-10)
1459        def eq(n):
1460            return n
1461        for i in (0, 1):
1462            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1463        self.assertEqual(eq.cache_info(),
1464            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
1465
1466    def test_lru_with_exceptions(self):
1467        # Verify that user_function exceptions get passed through without
1468        # creating a hard-to-read chained exception.
1469        # http://bugs.python.org/issue13177
1470        for maxsize in (None, 128):
1471            @self.module.lru_cache(maxsize)
1472            def func(i):
1473                return 'abc'[i]
1474            self.assertEqual(func(0), 'a')
1475            with self.assertRaises(IndexError) as cm:
1476                func(15)
1477            self.assertIsNone(cm.exception.__context__)
1478            # Verify that the previous exception did not result in a cached entry
1479            with self.assertRaises(IndexError):
1480                func(15)
1481
1482    def test_lru_with_types(self):
1483        for maxsize in (None, 128):
1484            @self.module.lru_cache(maxsize=maxsize, typed=True)
1485            def square(x):
1486                return x * x
1487            self.assertEqual(square(3), 9)
1488            self.assertEqual(type(square(3)), type(9))
1489            self.assertEqual(square(3.0), 9.0)
1490            self.assertEqual(type(square(3.0)), type(9.0))
1491            self.assertEqual(square(x=3), 9)
1492            self.assertEqual(type(square(x=3)), type(9))
1493            self.assertEqual(square(x=3.0), 9.0)
1494            self.assertEqual(type(square(x=3.0)), type(9.0))
1495            self.assertEqual(square.cache_info().hits, 4)
1496            self.assertEqual(square.cache_info().misses, 4)
1497
1498    def test_lru_cache_typed_is_not_recursive(self):
1499        cached = self.module.lru_cache(typed=True)(repr)
1500
1501        self.assertEqual(cached(1), '1')
1502        self.assertEqual(cached(True), 'True')
1503        self.assertEqual(cached(1.0), '1.0')
1504        self.assertEqual(cached(0), '0')
1505        self.assertEqual(cached(False), 'False')
1506        self.assertEqual(cached(0.0), '0.0')
1507
1508        self.assertEqual(cached((1,)), '(1,)')
1509        self.assertEqual(cached((True,)), '(1,)')
1510        self.assertEqual(cached((1.0,)), '(1,)')
1511        self.assertEqual(cached((0,)), '(0,)')
1512        self.assertEqual(cached((False,)), '(0,)')
1513        self.assertEqual(cached((0.0,)), '(0,)')
1514
1515        class T(tuple):
1516            pass
1517
1518        self.assertEqual(cached(T((1,))), '(1,)')
1519        self.assertEqual(cached(T((True,))), '(1,)')
1520        self.assertEqual(cached(T((1.0,))), '(1,)')
1521        self.assertEqual(cached(T((0,))), '(0,)')
1522        self.assertEqual(cached(T((False,))), '(0,)')
1523        self.assertEqual(cached(T((0.0,))), '(0,)')
1524
1525    def test_lru_with_keyword_args(self):
1526        @self.module.lru_cache()
1527        def fib(n):
1528            if n < 2:
1529                return n
1530            return fib(n=n-1) + fib(n=n-2)
1531        self.assertEqual(
1532            [fib(n=number) for number in range(16)],
1533            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1534        )
1535        self.assertEqual(fib.cache_info(),
1536            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1537        fib.cache_clear()
1538        self.assertEqual(fib.cache_info(),
1539            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1540
1541    def test_lru_with_keyword_args_maxsize_none(self):
1542        @self.module.lru_cache(maxsize=None)
1543        def fib(n):
1544            if n < 2:
1545                return n
1546            return fib(n=n-1) + fib(n=n-2)
1547        self.assertEqual([fib(n=number) for number in range(16)],
1548            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1549        self.assertEqual(fib.cache_info(),
1550            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1551        fib.cache_clear()
1552        self.assertEqual(fib.cache_info(),
1553            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1554
1555    def test_kwargs_order(self):
1556        # PEP 468: Preserving Keyword Argument Order
1557        @self.module.lru_cache(maxsize=10)
1558        def f(**kwargs):
1559            return list(kwargs.items())
1560        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1561        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1562        self.assertEqual(f.cache_info(),
1563            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1564
1565    def test_lru_cache_decoration(self):
1566        def f(zomg: 'zomg_annotation'):
1567            """f doc string"""
1568            return 42
1569        g = self.module.lru_cache()(f)
1570        for attr in self.module.WRAPPER_ASSIGNMENTS:
1571            self.assertEqual(getattr(g, attr), getattr(f, attr))
1572
    def test_lru_cache_threaded(self):
        """Exercise one lru_cache from several threads with a tiny switch
        interval.

        Phase 1: n threads each call f(k, 0) m times with a distinct k, and
        the cache must end up holding exactly n entries.  Phase 2: the same
        fillers run again alongside a thread that repeatedly clears the
        cache.  That phase makes no assertions about cache_info().
        """
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # Workers block on this event so that they all start together once
        # the main thread sets it.
        start = threading.Event()
        def full(k):
            # Repeatedly request the single key (k, 0): the first call is
            # expected to miss and the remaining m - 1 to hit.
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            # Races cache_clear() against the fillers in phase 2.
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        # A very small switch interval makes the interpreter switch threads
        # as often as possible, maximizing interleavings inside the cache.
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with threading_helper.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: the pure-Python implementation is only held to upper
                # bounds here because its counts were observed not to be
                # exact; the reason is not established in this test.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                # The C implementation must count exactly one miss per key.
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            # Re-arm the event so the new batch waits for the next set().
            start.clear()
            with threading_helper.start_threads(threads):
                start.set()
        finally:
            # Always restore the interpreter-wide switch interval.
            sys.setswitchinterval(orig_si)
1620
    def test_lru_cache_threaded2(self):
        """n threads call f(i) with the same argument at the same moment.

        f's body blocks on the ``pause`` barrier until all n workers and the
        main thread have arrived.  Every call for a round is therefore in
        flight before any of them can store a result, so each one misses and
        the hit count stays 0.
        """
        # Simultaneous call with the same arguments
        n, m = 5, 7
        # Each barrier is sized for the n worker threads plus the main thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            # Hold every caller here until the whole round has entered f.
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            # Worker: in each round i, wait for the go signal, call f(i),
            # then wait for the round to end.
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with threading_helper.start_threads(threads):
            # The main thread takes part in every barrier to pace the rounds.
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # n misses per round, and one new entry per distinct i.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1647
1648    def test_lru_cache_threaded3(self):
1649        @self.module.lru_cache(maxsize=2)
1650        def f(x):
1651            time.sleep(.01)
1652            return 3 * x
1653        def test(i, x):
1654            with self.subTest(thread=i):
1655                self.assertEqual(f(x), 3 * x, i)
1656        threads = [threading.Thread(target=test, args=(i, v))
1657                   for i, v in enumerate([1, 2, 2, 3, 2])]
1658        with threading_helper.start_threads(threads):
1659            pass
1660
1661    def test_need_for_rlock(self):
1662        # This will deadlock on an LRU cache that uses a regular lock
1663
1664        @self.module.lru_cache(maxsize=10)
1665        def test_func(x):
1666            'Used to demonstrate a reentrant lru_cache call within a single thread'
1667            return x
1668
1669        class DoubleEq:
1670            'Demonstrate a reentrant lru_cache call within a single thread'
1671            def __init__(self, x):
1672                self.x = x
1673            def __hash__(self):
1674                return self.x
1675            def __eq__(self, other):
1676                if self.x == 2:
1677                    test_func(DoubleEq(1))
1678                return self.x == other.x
1679
1680        test_func(DoubleEq(1))                      # Load the cache
1681        test_func(DoubleEq(2))                      # Load the cache
1682        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
1683                         DoubleEq(2))               # Verify the correct return value
1684
1685    def test_lru_method(self):
1686        class X(int):
1687            f_cnt = 0
1688            @self.module.lru_cache(2)
1689            def f(self, x):
1690                self.f_cnt += 1
1691                return x*10+self
1692        a = X(5)
1693        b = X(5)
1694        c = X(7)
1695        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1696
1697        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1698            self.assertEqual(a.f(x), x*10 + 5)
1699        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1700        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1701
1702        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1703            self.assertEqual(b.f(x), x*10 + 5)
1704        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1705        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1706
1707        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1708            self.assertEqual(c.f(x), x*10 + 7)
1709        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1710        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1711
1712        self.assertEqual(a.f.cache_info(), X.f.cache_info())
1713        self.assertEqual(b.f.cache_info(), X.f.cache_info())
1714        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1715
1716    def test_pickle(self):
1717        cls = self.__class__
1718        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1719            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1720                with self.subTest(proto=proto, func=f):
1721                    f_copy = pickle.loads(pickle.dumps(f, proto))
1722                    self.assertIs(f_copy, f)
1723
1724    def test_copy(self):
1725        cls = self.__class__
1726        def orig(x, y):
1727            return 3 * x + y
1728        part = self.module.partial(orig, 2)
1729        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1730                 self.module.lru_cache(2)(part))
1731        for f in funcs:
1732            with self.subTest(func=f):
1733                f_copy = copy.copy(f)
1734                self.assertIs(f_copy, f)
1735
1736    def test_deepcopy(self):
1737        cls = self.__class__
1738        def orig(x, y):
1739            return 3 * x + y
1740        part = self.module.partial(orig, 2)
1741        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1742                 self.module.lru_cache(2)(part))
1743        for f in funcs:
1744            with self.subTest(func=f):
1745                f_copy = copy.deepcopy(f)
1746                self.assertIs(f_copy, f)
1747
1748    def test_lru_cache_parameters(self):
1749        @self.module.lru_cache(maxsize=2)
1750        def f():
1751            return 1
1752        self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1753
1754        @self.module.lru_cache(maxsize=1000, typed=True)
1755        def f():
1756            return 1
1757        self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1758
1759    def test_lru_cache_weakrefable(self):
1760        @self.module.lru_cache
1761        def test_function(x):
1762            return x
1763
1764        class A:
1765            @self.module.lru_cache
1766            def test_method(self, x):
1767                return (self, x)
1768
1769            @staticmethod
1770            @self.module.lru_cache
1771            def test_staticmethod(x):
1772                return (self, x)
1773
1774        refs = [weakref.ref(test_function),
1775                weakref.ref(A.test_method),
1776                weakref.ref(A.test_staticmethod)]
1777
1778        for ref in refs:
1779            self.assertIsNotNone(ref())
1780
1781        del A
1782        del test_function
1783        gc.collect()
1784
1785        for ref in refs:
1786            self.assertIsNone(ref())
1787
1788
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Module-level function cached with the pure-Python lru_cache.

    TestLRUPy exposes it as ``cached_func[0]`` for the pickle/copy tests.
    """
    tripled = 3 * x
    return tripled + y
1792
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Module-level function cached with c_functools' lru_cache.

    TestLRUC exposes it as ``cached_func[0]`` for the pickle/copy tests.
    """
    tripled = 3 * x
    return tripled + y
1796
1797
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU tests against the pure-Python functools."""
    module = py_functools
    # Kept inside a 1-tuple (tests read cached_func[0]); presumably so the
    # cached function is never treated as a method of this class.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1810
1811
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU tests against c_functools (the functools
    module imported with its _functools accelerator available)."""
    module = c_functools
    # Kept inside a 1-tuple (tests read cached_func[0]); presumably so the
    # cached function is never treated as a method of this class.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1824
1825
1826class TestSingleDispatch(unittest.TestCase):
1827    def test_simple_overloads(self):
1828        @functools.singledispatch
1829        def g(obj):
1830            return "base"
1831        def g_int(i):
1832            return "integer"
1833        g.register(int, g_int)
1834        self.assertEqual(g("str"), "base")
1835        self.assertEqual(g(1), "integer")
1836        self.assertEqual(g([1,2,3]), "base")
1837
1838    def test_mro(self):
1839        @functools.singledispatch
1840        def g(obj):
1841            return "base"
1842        class A:
1843            pass
1844        class C(A):
1845            pass
1846        class B(A):
1847            pass
1848        class D(C, B):
1849            pass
1850        def g_A(a):
1851            return "A"
1852        def g_B(b):
1853            return "B"
1854        g.register(A, g_A)
1855        g.register(B, g_B)
1856        self.assertEqual(g(A()), "A")
1857        self.assertEqual(g(B()), "B")
1858        self.assertEqual(g(C()), "A")
1859        self.assertEqual(g(D()), "B")
1860
1861    def test_register_decorator(self):
1862        @functools.singledispatch
1863        def g(obj):
1864            return "base"
1865        @g.register(int)
1866        def g_int(i):
1867            return "int %s" % (i,)
1868        self.assertEqual(g(""), "base")
1869        self.assertEqual(g(12), "int 12")
1870        self.assertIs(g.dispatch(int), g_int)
1871        self.assertIs(g.dispatch(object), g.dispatch(str))
1872        # Note: in the assert above this is not g.
1873        # @singledispatch returns the wrapper.
1874
1875    def test_wrapping_attributes(self):
1876        @functools.singledispatch
1877        def g(obj):
1878            "Simple test"
1879            return "Test"
1880        self.assertEqual(g.__name__, "g")
1881        if sys.flags.optimize < 2:
1882            self.assertEqual(g.__doc__, "Simple test")
1883
1884    @unittest.skipUnless(decimal, 'requires _decimal')
1885    @support.cpython_only
1886    def test_c_classes(self):
1887        @functools.singledispatch
1888        def g(obj):
1889            return "base"
1890        @g.register(decimal.DecimalException)
1891        def _(obj):
1892            return obj.args
1893        subn = decimal.Subnormal("Exponent < Emin")
1894        rnd = decimal.Rounded("Number got rounded")
1895        self.assertEqual(g(subn), ("Exponent < Emin",))
1896        self.assertEqual(g(rnd), ("Number got rounded",))
1897        @g.register(decimal.Subnormal)
1898        def _(obj):
1899            return "Too small to care."
1900        self.assertEqual(g(subn), "Too small to care.")
1901        self.assertEqual(g(rnd), ("Number got rounded",))
1902
1903    def test_compose_mro(self):
1904        # None of the examples in this test depend on haystack ordering.
1905        c = collections.abc
1906        mro = functools._compose_mro
1907        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1908        for haystack in permutations(bases):
1909            m = mro(dict, haystack)
1910            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1911                                 c.Collection, c.Sized, c.Iterable,
1912                                 c.Container, object])
1913        bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
1914        for haystack in permutations(bases):
1915            m = mro(collections.ChainMap, haystack)
1916            self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
1917                                 c.Collection, c.Sized, c.Iterable,
1918                                 c.Container, object])
1919
1920        # If there's a generic function with implementations registered for
1921        # both Sized and Container, passing a defaultdict to it results in an
1922        # ambiguous dispatch which will cause a RuntimeError (see
1923        # test_mro_conflicts).
1924        bases = [c.Container, c.Sized, str]
1925        for haystack in permutations(bases):
1926            m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1927            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1928                                 c.Container, object])
1929
1930        # MutableSequence below is registered directly on D. In other words, it
1931        # precedes MutableMapping which means single dispatch will always
1932        # choose MutableSequence here.
1933        class D(collections.defaultdict):
1934            pass
1935        c.MutableSequence.register(D)
1936        bases = [c.MutableSequence, c.MutableMapping]
1937        for haystack in permutations(bases):
1938            m = mro(D, bases)
1939            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
1940                                 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
1941                                 c.Collection, c.Sized, c.Iterable, c.Container,
1942                                 object])
1943
1944        # Container and Callable are registered on different base classes and
1945        # a generic function supporting both should always pick the Callable
1946        # implementation if a C instance is passed.
1947        class C(collections.defaultdict):
1948            def __call__(self):
1949                pass
1950        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1951        for haystack in permutations(bases):
1952            m = mro(C, haystack)
1953            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
1954                                 c.Collection, c.Sized, c.Iterable,
1955                                 c.Container, object])
1956
1957    def test_register_abc(self):
1958        c = collections.abc
1959        d = {"a": "b"}
1960        l = [1, 2, 3]
1961        s = {object(), None}
1962        f = frozenset(s)
1963        t = (1, 2, 3)
1964        @functools.singledispatch
1965        def g(obj):
1966            return "base"
1967        self.assertEqual(g(d), "base")
1968        self.assertEqual(g(l), "base")
1969        self.assertEqual(g(s), "base")
1970        self.assertEqual(g(f), "base")
1971        self.assertEqual(g(t), "base")
1972        g.register(c.Sized, lambda obj: "sized")
1973        self.assertEqual(g(d), "sized")
1974        self.assertEqual(g(l), "sized")
1975        self.assertEqual(g(s), "sized")
1976        self.assertEqual(g(f), "sized")
1977        self.assertEqual(g(t), "sized")
1978        g.register(c.MutableMapping, lambda obj: "mutablemapping")
1979        self.assertEqual(g(d), "mutablemapping")
1980        self.assertEqual(g(l), "sized")
1981        self.assertEqual(g(s), "sized")
1982        self.assertEqual(g(f), "sized")
1983        self.assertEqual(g(t), "sized")
1984        g.register(collections.ChainMap, lambda obj: "chainmap")
1985        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
1986        self.assertEqual(g(l), "sized")
1987        self.assertEqual(g(s), "sized")
1988        self.assertEqual(g(f), "sized")
1989        self.assertEqual(g(t), "sized")
1990        g.register(c.MutableSequence, lambda obj: "mutablesequence")
1991        self.assertEqual(g(d), "mutablemapping")
1992        self.assertEqual(g(l), "mutablesequence")
1993        self.assertEqual(g(s), "sized")
1994        self.assertEqual(g(f), "sized")
1995        self.assertEqual(g(t), "sized")
1996        g.register(c.MutableSet, lambda obj: "mutableset")
1997        self.assertEqual(g(d), "mutablemapping")
1998        self.assertEqual(g(l), "mutablesequence")
1999        self.assertEqual(g(s), "mutableset")
2000        self.assertEqual(g(f), "sized")
2001        self.assertEqual(g(t), "sized")
2002        g.register(c.Mapping, lambda obj: "mapping")
2003        self.assertEqual(g(d), "mutablemapping")  # not specific enough
2004        self.assertEqual(g(l), "mutablesequence")
2005        self.assertEqual(g(s), "mutableset")
2006        self.assertEqual(g(f), "sized")
2007        self.assertEqual(g(t), "sized")
2008        g.register(c.Sequence, lambda obj: "sequence")
2009        self.assertEqual(g(d), "mutablemapping")
2010        self.assertEqual(g(l), "mutablesequence")
2011        self.assertEqual(g(s), "mutableset")
2012        self.assertEqual(g(f), "sized")
2013        self.assertEqual(g(t), "sequence")
2014        g.register(c.Set, lambda obj: "set")
2015        self.assertEqual(g(d), "mutablemapping")
2016        self.assertEqual(g(l), "mutablesequence")
2017        self.assertEqual(g(s), "mutableset")
2018        self.assertEqual(g(f), "set")
2019        self.assertEqual(g(t), "sequence")
2020        g.register(dict, lambda obj: "dict")
2021        self.assertEqual(g(d), "dict")
2022        self.assertEqual(g(l), "mutablesequence")
2023        self.assertEqual(g(s), "mutableset")
2024        self.assertEqual(g(f), "set")
2025        self.assertEqual(g(t), "sequence")
2026        g.register(list, lambda obj: "list")
2027        self.assertEqual(g(d), "dict")
2028        self.assertEqual(g(l), "list")
2029        self.assertEqual(g(s), "mutableset")
2030        self.assertEqual(g(f), "set")
2031        self.assertEqual(g(t), "sequence")
2032        g.register(set, lambda obj: "concrete-set")
2033        self.assertEqual(g(d), "dict")
2034        self.assertEqual(g(l), "list")
2035        self.assertEqual(g(s), "concrete-set")
2036        self.assertEqual(g(f), "set")
2037        self.assertEqual(g(t), "sequence")
2038        g.register(frozenset, lambda obj: "frozen-set")
2039        self.assertEqual(g(d), "dict")
2040        self.assertEqual(g(l), "list")
2041        self.assertEqual(g(s), "concrete-set")
2042        self.assertEqual(g(f), "frozen-set")
2043        self.assertEqual(g(t), "sequence")
2044        g.register(tuple, lambda obj: "tuple")
2045        self.assertEqual(g(d), "dict")
2046        self.assertEqual(g(l), "list")
2047        self.assertEqual(g(s), "concrete-set")
2048        self.assertEqual(g(f), "frozen-set")
2049        self.assertEqual(g(t), "tuple")
2050
2051    def test_c3_abc(self):
2052        c = collections.abc
2053        mro = functools._c3_mro
2054        class A(object):
2055            pass
2056        class B(A):
2057            def __len__(self):
2058                return 0   # implies Sized
2059        @c.Container.register
2060        class C(object):
2061            pass
2062        class D(object):
2063            pass   # unrelated
2064        class X(D, C, B):
2065            def __call__(self):
2066                pass   # implies Callable
2067        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2068        for abcs in permutations([c.Sized, c.Callable, c.Container]):
2069            self.assertEqual(mro(X, abcs=abcs), expected)
2070        # unrelated ABCs don't appear in the resulting MRO
2071        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2072        self.assertEqual(mro(X, abcs=many_abcs), expected)
2073
2074    def test_false_meta(self):
2075        # see issue23572
2076        class MetaA(type):
2077            def __len__(self):
2078                return 0
2079        class A(metaclass=MetaA):
2080            pass
2081        class AA(A):
2082            pass
2083        @functools.singledispatch
2084        def fun(a):
2085            return 'base A'
2086        @fun.register(A)
2087        def _(a):
2088            return 'fun A'
2089        aa = AA()
2090        self.assertEqual(fun(aa), 'fun A')
2091
    def test_mro_conflicts(self):
        """How dispatch ranks several applicable ABCs, and when it gives up.

        ABCs that appear explicitly in a class's __mro__ win over ABCs that
        are only registered or inferred.  Two equally ranked, unrelated ABCs
        make dispatch raise RuntimeError("Ambiguous dispatch: ...").  The
        register() calls below modify the shared collections.abc classes.
        """
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        # O inherits Sized explicitly.
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # P has no ABC in its __mro__; only registrations apply to it.
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        # Iterable and Container are unrelated and both registered on P.
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # The two ABCs may be named in either order in the message.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # MutableSequence registered directly on R wins over the
        # MutableMapping that R gets through defaultdict.
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # V lists c.Sized before the plain class S in its bases.
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
2219
    def test_cache_invalidation(self):
        """singledispatch fills its per-type dispatch cache lazily and
        empties it when implementations are registered or when ABC
        registrations change (once ABCs are involved).

        weakref.WeakKeyDictionary is swapped for a factory returning a
        TracingDict.  The cache that singledispatch creates while the swap
        is active is therefore ``td``, so every cache read and write can be
        asserted on.
        """
        from collections import UserDict
        # NOTE(review): redundant with the module-level ``import weakref``;
        # both names refer to the same module object.
        import weakref

        class TracingDict(UserDict):
            # Records the key of every store (set_ops) and of every
            # *successful* lookup (get_ops); a KeyError skips the append.
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            # Clears the storage directly, presumably so that clearing does
            # not go through __getitem__ and show up in get_ops.
            def clear(self):
                self.data.clear()

        td = TracingDict()
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            # First call per type: a miss (not recorded) followed by a store.
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeat calls are served from the cache (recorded as gets).
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # Registering an implementation empties the cache immediately.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            # From here on an ABC is registered with g, so later ABC
            # registrations do affect the cache.
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # dispatch() reads through the same cache.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # Invalidation is lazy: the stale entries remain until the next
            # call, which drops them and re-caches only its own type.
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            # Private helper that empties the dispatch cache.
            g._clear_cache()
            self.assertEqual(len(td), 0)
2321
2322    def test_annotations(self):
2323        @functools.singledispatch
2324        def i(arg):
2325            return "base"
2326        @i.register
2327        def _(arg: collections.abc.Mapping):
2328            return "mapping"
2329        @i.register
2330        def _(arg: "collections.abc.Sequence"):
2331            return "sequence"
2332        self.assertEqual(i(None), "base")
2333        self.assertEqual(i({"a": 1}), "mapping")
2334        self.assertEqual(i([1, 2, 3]), "sequence")
2335        self.assertEqual(i((1, 2, 3)), "sequence")
2336        self.assertEqual(i("str"), "sequence")
2337
2338        # Registering classes as callables doesn't work with annotations,
2339        # you need to pass the type explicitly.
2340        @i.register(str)
2341        class _:
2342            def __init__(self, arg):
2343                self.arg = arg
2344
2345            def __eq__(self, other):
2346                return self.arg == other
2347        self.assertEqual(i("str"), "str")
2348
2349    def test_method_register(self):
2350        class A:
2351            @functools.singledispatchmethod
2352            def t(self, arg):
2353                self.arg = "base"
2354            @t.register(int)
2355            def _(self, arg):
2356                self.arg = "int"
2357            @t.register(str)
2358            def _(self, arg):
2359                self.arg = "str"
2360        a = A()
2361
2362        a.t(0)
2363        self.assertEqual(a.arg, "int")
2364        aa = A()
2365        self.assertFalse(hasattr(aa, 'arg'))
2366        a.t('')
2367        self.assertEqual(a.arg, "str")
2368        aa = A()
2369        self.assertFalse(hasattr(aa, 'arg'))
2370        a.t(0.0)
2371        self.assertEqual(a.arg, "base")
2372        aa = A()
2373        self.assertFalse(hasattr(aa, 'arg'))
2374
2375    def test_staticmethod_register(self):
2376        class A:
2377            @functools.singledispatchmethod
2378            @staticmethod
2379            def t(arg):
2380                return arg
2381            @t.register(int)
2382            @staticmethod
2383            def _(arg):
2384                return isinstance(arg, int)
2385            @t.register(str)
2386            @staticmethod
2387            def _(arg):
2388                return isinstance(arg, str)
2389        a = A()
2390
2391        self.assertTrue(A.t(0))
2392        self.assertTrue(A.t(''))
2393        self.assertEqual(A.t(0.0), 0.0)
2394
2395    def test_classmethod_register(self):
2396        class A:
2397            def __init__(self, arg):
2398                self.arg = arg
2399
2400            @functools.singledispatchmethod
2401            @classmethod
2402            def t(cls, arg):
2403                return cls("base")
2404            @t.register(int)
2405            @classmethod
2406            def _(cls, arg):
2407                return cls("int")
2408            @t.register(str)
2409            @classmethod
2410            def _(cls, arg):
2411                return cls("str")
2412
2413        self.assertEqual(A.t(0).arg, "int")
2414        self.assertEqual(A.t('').arg, "str")
2415        self.assertEqual(A.t(0.0).arg, "base")
2416
2417    def test_callable_register(self):
2418        class A:
2419            def __init__(self, arg):
2420                self.arg = arg
2421
2422            @functools.singledispatchmethod
2423            @classmethod
2424            def t(cls, arg):
2425                return cls("base")
2426
2427        @A.t.register(int)
2428        @classmethod
2429        def _(cls, arg):
2430            return cls("int")
2431        @A.t.register(str)
2432        @classmethod
2433        def _(cls, arg):
2434            return cls("str")
2435
2436        self.assertEqual(A.t(0).arg, "int")
2437        self.assertEqual(A.t('').arg, "str")
2438        self.assertEqual(A.t(0.0).arg, "base")
2439
2440    def test_abstractmethod_register(self):
2441        class Abstract(metaclass=abc.ABCMeta):
2442
2443            @functools.singledispatchmethod
2444            @abc.abstractmethod
2445            def add(self, x, y):
2446                pass
2447
2448        self.assertTrue(Abstract.add.__isabstractmethod__)
2449        self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)
2450
2451        with self.assertRaises(TypeError):
2452            Abstract()
2453
2454    def test_type_ann_register(self):
2455        class A:
2456            @functools.singledispatchmethod
2457            def t(self, arg):
2458                return "base"
2459            @t.register
2460            def _(self, arg: int):
2461                return "int"
2462            @t.register
2463            def _(self, arg: str):
2464                return "str"
2465        a = A()
2466
2467        self.assertEqual(a.t(0), "int")
2468        self.assertEqual(a.t(''), "str")
2469        self.assertEqual(a.t(0.0), "base")
2470
2471    def test_staticmethod_type_ann_register(self):
2472        class A:
2473            @functools.singledispatchmethod
2474            @staticmethod
2475            def t(arg):
2476                return arg
2477            @t.register
2478            @staticmethod
2479            def _(arg: int):
2480                return isinstance(arg, int)
2481            @t.register
2482            @staticmethod
2483            def _(arg: str):
2484                return isinstance(arg, str)
2485        a = A()
2486
2487        self.assertTrue(A.t(0))
2488        self.assertTrue(A.t(''))
2489        self.assertEqual(A.t(0.0), 0.0)
2490
2491    def test_classmethod_type_ann_register(self):
2492        class A:
2493            def __init__(self, arg):
2494                self.arg = arg
2495
2496            @functools.singledispatchmethod
2497            @classmethod
2498            def t(cls, arg):
2499                return cls("base")
2500            @t.register
2501            @classmethod
2502            def _(cls, arg: int):
2503                return cls("int")
2504            @t.register
2505            @classmethod
2506            def _(cls, arg: str):
2507                return cls("str")
2508
2509        self.assertEqual(A.t(0).arg, "int")
2510        self.assertEqual(A.t('').arg, "str")
2511        self.assertEqual(A.t(0.0).arg, "base")
2512
2513    def test_method_wrapping_attributes(self):
2514        class A:
2515            @functools.singledispatchmethod
2516            def func(self, arg: int) -> str:
2517                """My function docstring"""
2518                return str(arg)
2519            @functools.singledispatchmethod
2520            @classmethod
2521            def cls_func(cls, arg: int) -> str:
2522                """My function docstring"""
2523                return str(arg)
2524            @functools.singledispatchmethod
2525            @staticmethod
2526            def static_func(arg: int) -> str:
2527                """My function docstring"""
2528                return str(arg)
2529
2530        for meth in (
2531            A.func,
2532            A().func,
2533            A.cls_func,
2534            A().cls_func,
2535            A.static_func,
2536            A().static_func
2537        ):
2538            with self.subTest(meth=meth):
2539                self.assertEqual(meth.__doc__, 'My function docstring')
2540                self.assertEqual(meth.__annotations__['arg'], int)
2541
2542        self.assertEqual(A.func.__name__, 'func')
2543        self.assertEqual(A().func.__name__, 'func')
2544        self.assertEqual(A.cls_func.__name__, 'cls_func')
2545        self.assertEqual(A().cls_func.__name__, 'cls_func')
2546        self.assertEqual(A.static_func.__name__, 'static_func')
2547        self.assertEqual(A().static_func.__name__, 'static_func')
2548
2549    def test_double_wrapped_methods(self):
2550        def classmethod_friendly_decorator(func):
2551            wrapped = func.__func__
2552            @classmethod
2553            @functools.wraps(wrapped)
2554            def wrapper(*args, **kwargs):
2555                return wrapped(*args, **kwargs)
2556            return wrapper
2557
2558        class WithoutSingleDispatch:
2559            @classmethod
2560            @contextlib.contextmanager
2561            def cls_context_manager(cls, arg: int) -> str:
2562                try:
2563                    yield str(arg)
2564                finally:
2565                    return 'Done'
2566
2567            @classmethod_friendly_decorator
2568            @classmethod
2569            def decorated_classmethod(cls, arg: int) -> str:
2570                return str(arg)
2571
2572        class WithSingleDispatch:
2573            @functools.singledispatchmethod
2574            @classmethod
2575            @contextlib.contextmanager
2576            def cls_context_manager(cls, arg: int) -> str:
2577                """My function docstring"""
2578                try:
2579                    yield str(arg)
2580                finally:
2581                    return 'Done'
2582
2583            @functools.singledispatchmethod
2584            @classmethod_friendly_decorator
2585            @classmethod
2586            def decorated_classmethod(cls, arg: int) -> str:
2587                """My function docstring"""
2588                return str(arg)
2589
2590        # These are sanity checks
2591        # to test the test itself is working as expected
2592        with WithoutSingleDispatch.cls_context_manager(5) as foo:
2593            without_single_dispatch_foo = foo
2594
2595        with WithSingleDispatch.cls_context_manager(5) as foo:
2596            single_dispatch_foo = foo
2597
2598        self.assertEqual(without_single_dispatch_foo, single_dispatch_foo)
2599        self.assertEqual(single_dispatch_foo, '5')
2600
2601        self.assertEqual(
2602            WithoutSingleDispatch.decorated_classmethod(5),
2603            WithSingleDispatch.decorated_classmethod(5)
2604        )
2605
2606        self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5')
2607
2608        # Behavioural checks now follow
2609        for method_name in ('cls_context_manager', 'decorated_classmethod'):
2610            with self.subTest(method=method_name):
2611                self.assertEqual(
2612                    getattr(WithSingleDispatch, method_name).__name__,
2613                    getattr(WithoutSingleDispatch, method_name).__name__
2614                )
2615
2616                self.assertEqual(
2617                    getattr(WithSingleDispatch(), method_name).__name__,
2618                    getattr(WithoutSingleDispatch(), method_name).__name__
2619                )
2620
2621        for meth in (
2622            WithSingleDispatch.cls_context_manager,
2623            WithSingleDispatch().cls_context_manager,
2624            WithSingleDispatch.decorated_classmethod,
2625            WithSingleDispatch().decorated_classmethod
2626        ):
2627            with self.subTest(meth=meth):
2628                self.assertEqual(meth.__doc__, 'My function docstring')
2629                self.assertEqual(meth.__annotations__['arg'], int)
2630
2631        self.assertEqual(
2632            WithSingleDispatch.cls_context_manager.__name__,
2633            'cls_context_manager'
2634        )
2635        self.assertEqual(
2636            WithSingleDispatch().cls_context_manager.__name__,
2637            'cls_context_manager'
2638        )
2639        self.assertEqual(
2640            WithSingleDispatch.decorated_classmethod.__name__,
2641            'decorated_classmethod'
2642        )
2643        self.assertEqual(
2644            WithSingleDispatch().decorated_classmethod.__name__,
2645            'decorated_classmethod'
2646        )
2647
2648    def test_invalid_registrations(self):
2649        msg_prefix = "Invalid first argument to `register()`: "
2650        msg_suffix = (
2651            ". Use either `@register(some_class)` or plain `@register` on an "
2652            "annotated function."
2653        )
2654        @functools.singledispatch
2655        def i(arg):
2656            return "base"
2657        with self.assertRaises(TypeError) as exc:
2658            @i.register(42)
2659            def _(arg):
2660                return "I annotated with a non-type"
2661        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2662        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2663        with self.assertRaises(TypeError) as exc:
2664            @i.register
2665            def _(arg):
2666                return "I forgot to annotate"
2667        self.assertTrue(str(exc.exception).startswith(msg_prefix +
2668            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2669        ))
2670        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2671
2672        with self.assertRaises(TypeError) as exc:
2673            @i.register
2674            def _(arg: typing.Iterable[str]):
2675                # At runtime, dispatching on generics is impossible.
2676                # When registering implementations with singledispatch, avoid
2677                # types from `typing`. Instead, annotate with regular types
2678                # or ABCs.
2679                return "I annotated with a generic collection"
2680        self.assertTrue(str(exc.exception).startswith(
2681            "Invalid annotation for 'arg'."
2682        ))
2683        self.assertTrue(str(exc.exception).endswith(
2684            'typing.Iterable[str] is not a class.'
2685        ))
2686
2687    def test_invalid_positional_argument(self):
2688        @functools.singledispatch
2689        def f(*args):
2690            pass
2691        msg = 'f requires at least 1 positional argument'
2692        with self.assertRaisesRegex(TypeError, msg):
2693            f()
2694
2695
class CachedCostItem:
    # Starting value; the first access to ``cost`` stores an incremented
    # copy on the instance.
    _cost = 1

    def __init__(self):
        # Held while ``cost`` bumps ``_cost``.
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2708
2709
class OptionallyCachedCostItem:
    # Starting value; every get_cost() call stores an incremented copy on
    # the instance.
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    # The same function exposed under a different attribute name, computed
    # once on first access and cached afterwards.
    cached_cost = py_functools.cached_property(get_cost)
2719
2720
class CachedCostItemWait:

    def __init__(self, event):
        self.event = event
        self.lock = py_functools.RLock()
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        # Block until ``event`` is set (for at most one second) so that
        # concurrent callers can race on the first computation.
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2734
2735
class CachedCostItemWithSlots:
    """Slotted class (no instance ``__dict__``) used to check that
    cached_property refuses to work without a place to cache its value."""

    # Trailing comma makes this a one-element tuple; the previous
    # ``('_cost')`` was just a parenthesised string, which only worked
    # because __slots__ also accepts a single string.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2745
2746
class TestCachedProperty(unittest.TestCase):
    def test_cached(self):
        item = CachedCostItem()
        self.assertEqual(item.cost, 2)
        self.assertEqual(item.cost, 2) # not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        # A tiny switch interval makes the threads interleave aggressively
        # while they race on the first access to ``item.cost``.
        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            threads = [
                threading.Thread(target=lambda: item.cost)
                for _ in range(num_threads)
            ]
            with threading_helper.start_threads(threads):
                go.set()
        finally:
            sys.setswitchinterval(orig_si)

        # The value must have been computed exactly once.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        with self.assertRaisesRegex(
                TypeError,
                "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.",
        ):
            item.cost

    def test_immutable_dict(self):
        # A class's __dict__ is a read-only mappingproxy, so caching a
        # property defined on a metaclass must fail.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        with self.assertRaisesRegex(
            TypeError,
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.",
        ):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        self.assertEqual(
            str(ctx.exception.__context__),
            str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b')."))
        )

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a = A()
        b = B()

        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        # Assigning after class creation skips __set_name__, so the
        # property has no attribute name to cache under.
        cp = py_functools.cached_property(lambda s: None)
        class Foo:
            pass

        Foo.cp = cp

        with self.assertRaisesRegex(
                TypeError,
                "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property)

    def test_doc(self):
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2859
2860
if __name__ == "__main__":
    # Run every test case in this module when it is executed as a script.
    unittest.main()
2863