1import abc
2import builtins
3import collections
4import collections.abc
5import copy
6from itertools import permutations
7import pickle
8from random import choice
9import sys
10from test import support
11import threading
12import time
13import typing
14import unittest
15import unittest.mock
16from weakref import proxy
17import contextlib
18
19import functools
20
# Two freshly imported copies of functools: one with the C accelerator
# module _functools blocked (so the pure-Python implementations are used)
# and one with _functools imported fresh.  c_functools is tested for
# truthiness elsewhere in this file (``if c_functools:`` / skipUnless), so it
# may be falsy when the C module is unavailable -- TODO confirm that
# import_fresh_module returns None in that case.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# decimal freshly imported together with its C module _decimal; presumably
# used by tests further down this file (not visible in this chunk).
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    On exit, including exit via an exception, the previous entry is put back.
    If *name* was not in ``sys.modules`` beforehand, the temporary entry is
    removed instead of failing up front with ``KeyError``.

    In this file, ``TestPartialPy.AllowPickle`` uses it to point the name
    ``functools`` at the pure-Python module while pickling tests run.
    """
    missing = object()  # sentinel: None is a legal sys.modules value
    original_module = sys.modules.get(name, missing)
    sys.modules[name] = replacement
    try:
        yield
    finally:
        if original_module is missing:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = original_module
34
def capture(*args, **kw):
    """Echo a call back as ``(positional_args_tuple, keyword_args_dict)``."""
    received = (args, kw)
    return received
38
39
def signature(part):
    """Snapshot a partial object for comparison.

    Returns the 4-tuple ``(func, args, keywords, __dict__)`` read from *part*.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
43
class MyTuple(tuple):
    """Behaviour-free ``tuple`` subclass, fed into partial state in tests."""
46
class BadTuple(tuple):
    """``tuple`` subclass whose ``+`` returns a ``list``, not a tuple."""

    def __add__(self, other):
        # Concatenate both operands into a fresh list.
        return [*self, *other]
50
class MyDict(dict):
    """Behaviour-free ``dict`` subclass, fed into partial state in tests."""
53
54
class TestPartial:
    """Implementation-independent tests for ``functools.partial``.

    This is a mixin.  Concrete subclasses also inherit ``unittest.TestCase``
    and supply two class attributes:

    * ``partial`` -- the partial type under test;
    * ``AllowPickle`` -- a context-manager class entered around the tests
      that pickle partial objects.
    """

    def test_basic_examples(self):
        # Stored and call-time arguments are merged; call-time keywords
        # override stored ones.
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must raise TypeError, either when
        # the partial is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        self.assertEqual(p(1, b=2), ((0, 1), {'a': 1, 'b': 2}))
        self.assertEqual(p(), ((0,), {'a': 1}))

    def test_error_propagation(self):
        # Exceptions raised by the wrapped callable propagate unchanged,
        # however the arguments are split between partial and call.
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        # Partials support weak references; dropping the only strong
        # reference invalidates the proxy.
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial should look exactly like one flat partial
        # over the innermost function.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down; accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        # Both stock implementations report themselves as functools.partial;
        # subclasses are expected to show their own class name.
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        # A partial that contains itself (as func, in args, or in keywords)
        # must show '...' at the cycle.  Each case resets the state in
        # `finally` to break the reference cycle.
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        # A pickle round-trip under every protocol preserves func, args,
        # keywords and instance attributes.
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        # A shallow copy shares args, keywords and attribute values.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        # A deep copy duplicates args, keywords and attributes, including
        # the mutable objects inside them.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        # __setstate__ takes (func, args, keywords, dict); None for the
        # keywords or the dict behaves like an empty mapping.
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        # Malformed state raises TypeError.  Covered cases: wrong length, a
        # list instead of a tuple, non-callable func, non-tuple args and
        # non-dict keywords.
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are normalised to exact
        # tuple/dict.  A tuple subclass with an odd __add__ (BadTuple) must
        # not leak into the arguments of a call.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        # A partial whose func is itself cannot be pickled (RecursionError).
        # Self-references inside args or keywords survive a round-trip.
        # State is reset in `finally` blocks to break the cycles.
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence-like object that is not a real tuple must be rejected
        # by __setstate__ with TypeError.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
378
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C ``partial``, plus checks
    specific to the C implementation."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # No-op context manager: the pickling tests run against the module
        # as imported, without swapping anything in sys.modules.
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # __dict__ may not be deleted either.
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        # repr still works, but calling with a non-string keyword fails.
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            # Formatting this key replaces its own value in p.keywords.
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
429
430
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python ``partial``."""

    partial = py_functools.partial

    class AllowPickle:
        # pickle locates classes by module and qualified name.  While
        # pickling, make sys.modules['functools'] the pure-Python module
        # this partial came from.
        def __init__(self):
            self._module_swap = replaced_module("functools", py_functools)

        def __enter__(self):
            swap = self._module_swap
            return swap.__enter__()

        def __exit__(self, type, value, tb):
            swap = self._module_swap
            return swap.__exit__(type, value, tb)
441
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial (only defined when the C
        implementation was imported)."""
445
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial."""
448
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial tests against a subclass of the C partial."""

    if c_functools:
        partial = CPartialSubclass

    # Nested partials are only flattened for the exact partial type, not for
    # subclasses.  Shadowing the inherited test with a non-callable keeps
    # unittest from collecting it.
    test_nested_optimization = None
456
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial tests against a subclass of it."""

    partial = PyPartialSubclass
459
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod`` used as a class attribute."""

    class A(object):
        # Each attribute pre-binds different arguments to `capture`, so the
        # tests can see exactly which arguments reach the wrapped callable.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' must be usable as ordinary keyword arguments.
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        # partialmethod wrapping another partialmethod.
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a plain partial.
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping staticmethod / classmethod descriptors.
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    # Instance shared by most tests below.
    a = A()

    def test_arg_combinations(self):
        # The instance comes first, then stored args, then call args.
        # Call-time keywords are merged with stored keywords.
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        # Stored args of both layers are applied, outermost layer last.
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        # Access through an instance exposes __self__.  It is the instance,
        # or the class for the classmethod-based attribute.
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        # Access through the class, and staticmethod-based attributes via
        # either route, expose no __self__.
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        # Wrapped staticmethod/classmethod keep their binding rules whether
        # accessed through the class or an instance.
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        # Call-time keywords override stored ones.
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # Covered: a non-callable or missing func raises TypeError, and
        # passing func by keyword emits DeprecationWarning but still works.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertWarns(DeprecationWarning):
            class B:
                method = functools.partialmethod(func=capture, a=1)
        b = B()
        self.assertEqual(b.method(2, x=3), ((b, 2), {'a': 1, 'x': 3}))

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # __isabstractmethod__ propagates through partialmethod.  Abstract
        # uses ABCMeta as its metaclass so it is a real abstract class; it
        # must not be a subclass of ABCMeta.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta records both the method and the partialmethod built on it.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        # None of the concrete partialmethods on A are abstract.
        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        # partial can pre-fill positional-only parameters.
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
590
591
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``.

    The fixture functions' names, docstrings, annotations and attributes are
    the data under test.  Subclasses (TestWraps) override some fixtures and
    tests to go through ``functools.wraps`` instead.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Shared assertion helper.  It checks three things:
        # - every attribute named in *assigned* on *wrapper* is the very same
        #   object as on *wrapped*;
        # - every entry of each mapping named in *updated* was copied across
        #   (except __dict__['__wrapped__']);
        # - wrapper.__wrapped__ is *wrapped*.
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        # Fixture.  It builds `f` with a docstring, an annotation, a custom
        # attribute and a misleading __wrapped__ string.  It builds `wrapper`
        # with its own annotation.  It applies update_wrapper with default
        # arguments and returns (wrapper, f).
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        # Default update copies name/qualname/attributes and sets __wrapped__
        # to f itself, overriding the fake string.  The wrapper's annotation
        # 'b' is gone after the update.
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty assigned/updated tuples leave the wrapper's own metadata
        # untouched; check_wrapper still requires __wrapped__ to be set.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only the named attribute is assigned and only the named mapping
        # is updated; __name__/__qualname__/__doc__ are left alone.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        # A non-mapping value for an updated attribute also fails.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        # Wrapping a builtin copies its name and docstring; the wrapper's
        # __annotations__ stays empty.
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
704
705
class TestWraps(TestUpdateWrapper):
    """Variant of TestUpdateWrapper that goes through the ``functools.wraps``
    decorator.  Tests not overridden here are inherited unchanged."""

    def _default_update(self):
        # Fixture.  `f` has a docstring, a custom attribute and a misleading
        # __wrapped__ string; `wrapper` is decorated with @functools.wraps(f).
        # Returns (wrapper, f).
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # wraps() with empty assigned/updated leaves the wrapper's metadata
        # untouched apart from __wrapped__.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        # Inner decorator: gives the wrapper the `dict_attr` mapping that
        # wraps() then updates.
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
766
767
768class TestReduce:
769    def test_reduce(self):
770        class Squares:
771            def __init__(self, max):
772                self.max = max
773                self.sofar = []
774
775            def __len__(self):
776                return len(self.sofar)
777
778            def __getitem__(self, i):
779                if not 0 <= i < self.max: raise IndexError
780                n = len(self.sofar)
781                while n <= i:
782                    self.sofar.append(n*n)
783                    n += 1
784                return self.sofar[i]
785        def add(x, y):
786            return x + y
787        self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
788        self.assertEqual(
789            self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
790            ['a','c','d','w']
791        )
792        self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
793        self.assertEqual(
794            self.reduce(lambda x, y: x*y, range(2,21), 1),
795            2432902008176640000
796        )
797        self.assertEqual(self.reduce(add, Squares(10)), 285)
798        self.assertEqual(self.reduce(add, Squares(10), 0), 285)
799        self.assertEqual(self.reduce(add, Squares(0), 0), 0)
800        self.assertRaises(TypeError, self.reduce)
801        self.assertRaises(TypeError, self.reduce, 42, 42)
802        self.assertRaises(TypeError, self.reduce, 42, 42, 42)
803        self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
804        self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
805        self.assertRaises(TypeError, self.reduce, 42, (42, 42))
806        self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
807        self.assertRaises(TypeError, self.reduce, add, "")
808        self.assertRaises(TypeError, self.reduce, add, ())
809        self.assertRaises(TypeError, self.reduce, add, object())
810
811        class TestFailingIter:
812            def __iter__(self):
813                raise RuntimeError
814        self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
815
816        self.assertEqual(self.reduce(add, [], None), None)
817        self.assertEqual(self.reduce(add, [], 42), 42)
818
819        class BadSeq:
820            def __getitem__(self, index):
821                raise ValueError
822        self.assertRaises(ValueError, self.reduce, 42, BadSeq())
823
824    # Test reduce()'s use of iterators.
825    def test_iterator_usage(self):
826        class SequenceClass:
827            def __init__(self, n):
828                self.n = n
829            def __getitem__(self, i):
830                if 0 <= i < self.n:
831                    return i
832                else:
833                    raise IndexError
834
835        from operator import add
836        self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
837        self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
838        self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
839        self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
840        self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
841        self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
842
843        d = {"one": 1, "two": 2, "three": 3}
844        self.assertEqual(self.reduce(add, d), "".join(d.keys()))
845
846
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the C accelerator."""
    # Bind only when _functools imported; otherwise the decorator skips us.
    # A builtin function does not bind ``self``, so no staticmethod is needed.
    if c_functools is not None:
        reduce = c_functools.reduce
851
852
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the pure-Python version."""
    # staticmethod keeps the plain Python function from being bound to the
    # test instance when looked up as ``self.reduce``.
    reduce = staticmethod(py_functools.reduce)
855
856
857class TestCmpToKey:
858
859    def test_cmp_to_key(self):
860        def cmp1(x, y):
861            return (x > y) - (x < y)
862        key = self.cmp_to_key(cmp1)
863        self.assertEqual(key(3), key(3))
864        self.assertGreater(key(3), key(1))
865        self.assertGreaterEqual(key(3), key(3))
866
867        def cmp2(x, y):
868            return int(x) - int(y)
869        key = self.cmp_to_key(cmp2)
870        self.assertEqual(key(4.0), key('4'))
871        self.assertLess(key(2), key('35'))
872        self.assertLessEqual(key(2), key('35'))
873        self.assertNotEqual(key(2), key('35'))
874
875    def test_cmp_to_key_arguments(self):
876        def cmp1(x, y):
877            return (x > y) - (x < y)
878        key = self.cmp_to_key(mycmp=cmp1)
879        self.assertEqual(key(obj=3), key(obj=3))
880        self.assertGreater(key(obj=3), key(obj=1))
881        with self.assertRaises((TypeError, AttributeError)):
882            key(3) > 1    # rhs is not a K object
883        with self.assertRaises((TypeError, AttributeError)):
884            1 < key(3)    # lhs is not a K object
885        with self.assertRaises(TypeError):
886            key = self.cmp_to_key()             # too few args
887        with self.assertRaises(TypeError):
888            key = self.cmp_to_key(cmp1, None)   # too many args
889        key = self.cmp_to_key(cmp1)
890        with self.assertRaises(TypeError):
891            key()                                    # too few args
892        with self.assertRaises(TypeError):
893            key(None, None)                          # too many args
894
895    def test_bad_cmp(self):
896        def cmp1(x, y):
897            raise ZeroDivisionError
898        key = self.cmp_to_key(cmp1)
899        with self.assertRaises(ZeroDivisionError):
900            key(3) > key(1)
901
902        class BadCmp:
903            def __lt__(self, other):
904                raise ZeroDivisionError
905        def cmp1(x, y):
906            return BadCmp()
907        with self.assertRaises(ZeroDivisionError):
908            key(3) > key(1)
909
910    def test_obj_field(self):
911        def cmp1(x, y):
912            return (x > y) - (x < y)
913        key = self.cmp_to_key(mycmp=cmp1)
914        self.assertEqual(key(50).obj, 50)
915
916    def test_sort_int(self):
917        def mycmp(x, y):
918            return y - x
919        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
920                         [4, 3, 2, 1, 0])
921
922    def test_sort_int_str(self):
923        def mycmp(x, y):
924            x, y = int(x), int(y)
925            return (x > y) - (x < y)
926        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
927        values = sorted(values, key=self.cmp_to_key(mycmp))
928        self.assertEqual([int(value) for value in values],
929                         [0, 1, 1, 2, 3, 4, 5, 7, 10])
930
931    def test_hash(self):
932        def mycmp(x, y):
933            return y - x
934        key = self.cmp_to_key(mycmp)
935        k = key(10)
936        self.assertRaises(TypeError, hash, k)
937        self.assertNotIsInstance(k, collections.abc.Hashable)
938
939
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C accelerator."""
    # Bind only when _functools imported; otherwise the decorator skips us.
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key
944
945
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python version."""
    # staticmethod keeps the plain function from being bound to the test
    # instance when looked up as ``self.cmp_to_key``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
948
949
class TestTotalOrdering(unittest.TestCase):
    """Tests for the ``functools.total_ordering`` class decorator."""

    def _assert_total_order(self, cls):
        """Assert every rich comparison between ``cls(1)``/``cls(2)`` agrees
        with integer ordering, in both the true and the false direction.

        Covering all four operators both ways exercises each method that
        total_ordering derives from whichever root the class defines.
        """
        lo, hi, same = cls(1), cls(2), cls(2)
        # Comparisons that must hold.
        self.assertTrue(lo < hi)
        self.assertTrue(hi > lo)
        self.assertTrue(lo <= hi)
        self.assertTrue(hi >= lo)
        self.assertTrue(hi <= same)
        self.assertTrue(hi >= same)
        # Comparisons that must not hold.
        self.assertFalse(lo > hi)
        self.assertFalse(lo >= hi)
        self.assertFalse(hi < lo)
        self.assertFalse(hi <= lo)
        self.assertFalse(hi < same)
        self.assertFalse(hi > same)

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing: int already provides all
        # four comparisons, so nothing may be installed on the subclass.
        # (Ordering checks alone cannot detect an overwrite, because derived
        # methods would also order ints correctly.)
        @functools.total_ordering
        class A(int):
            pass
        self._assert_total_order(A)
        for name in ('__lt__', '__gt__', '__le__', '__ge__'):
            self.assertNotIn(name, vars(A))

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        with self.subTest("LT < 1"), self.assertRaises(TypeError):
            ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"), self.assertRaises(TypeError):
            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"), self.assertRaises(TypeError):
            ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"), self.assertRaises(TypeError):
            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"), self.assertRaises(TypeError):
            ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b

    def test_pickle(self):
        # Both hand-written and derived comparison methods of a module-level
        # class must round-trip through pickle to the very same object.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in '__lt__', '__gt__', '__le__', '__ge__':
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    method_copy = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(method_copy, method)
1152
@functools.total_ordering
class Orderable_LT:
    """Ordered wrapper around ``value`` that defines only ``<`` and ``==``.

    total_ordering derives ``>``, ``<=`` and ``>=``.  The class lives at
    module level, presumably so TestTotalOrdering.test_pickle can pickle its
    methods by reference.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1161
1162
1163class TestLRU:
1164
1165    def test_lru(self):
1166        def orig(x, y):
1167            return 3 * x + y
1168        f = self.module.lru_cache(maxsize=20)(orig)
1169        hits, misses, maxsize, currsize = f.cache_info()
1170        self.assertEqual(maxsize, 20)
1171        self.assertEqual(currsize, 0)
1172        self.assertEqual(hits, 0)
1173        self.assertEqual(misses, 0)
1174
1175        domain = range(5)
1176        for i in range(1000):
1177            x, y = choice(domain), choice(domain)
1178            actual = f(x, y)
1179            expected = orig(x, y)
1180            self.assertEqual(actual, expected)
1181        hits, misses, maxsize, currsize = f.cache_info()
1182        self.assertTrue(hits > misses)
1183        self.assertEqual(hits + misses, 1000)
1184        self.assertEqual(currsize, 20)
1185
1186        f.cache_clear()   # test clearing
1187        hits, misses, maxsize, currsize = f.cache_info()
1188        self.assertEqual(hits, 0)
1189        self.assertEqual(misses, 0)
1190        self.assertEqual(currsize, 0)
1191        f(x, y)
1192        hits, misses, maxsize, currsize = f.cache_info()
1193        self.assertEqual(hits, 0)
1194        self.assertEqual(misses, 1)
1195        self.assertEqual(currsize, 1)
1196
1197        # Test bypassing the cache
1198        self.assertIs(f.__wrapped__, orig)
1199        f.__wrapped__(x, y)
1200        hits, misses, maxsize, currsize = f.cache_info()
1201        self.assertEqual(hits, 0)
1202        self.assertEqual(misses, 1)
1203        self.assertEqual(currsize, 1)
1204
1205        # test size zero (which means "never-cache")
1206        @self.module.lru_cache(0)
1207        def f():
1208            nonlocal f_cnt
1209            f_cnt += 1
1210            return 20
1211        self.assertEqual(f.cache_info().maxsize, 0)
1212        f_cnt = 0
1213        for i in range(5):
1214            self.assertEqual(f(), 20)
1215        self.assertEqual(f_cnt, 5)
1216        hits, misses, maxsize, currsize = f.cache_info()
1217        self.assertEqual(hits, 0)
1218        self.assertEqual(misses, 5)
1219        self.assertEqual(currsize, 0)
1220
1221        # test size one
1222        @self.module.lru_cache(1)
1223        def f():
1224            nonlocal f_cnt
1225            f_cnt += 1
1226            return 20
1227        self.assertEqual(f.cache_info().maxsize, 1)
1228        f_cnt = 0
1229        for i in range(5):
1230            self.assertEqual(f(), 20)
1231        self.assertEqual(f_cnt, 1)
1232        hits, misses, maxsize, currsize = f.cache_info()
1233        self.assertEqual(hits, 4)
1234        self.assertEqual(misses, 1)
1235        self.assertEqual(currsize, 1)
1236
1237        # test size two
1238        @self.module.lru_cache(2)
1239        def f(x):
1240            nonlocal f_cnt
1241            f_cnt += 1
1242            return x*10
1243        self.assertEqual(f.cache_info().maxsize, 2)
1244        f_cnt = 0
1245        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1246            #    *  *              *                          *
1247            self.assertEqual(f(x), x*10)
1248        self.assertEqual(f_cnt, 4)
1249        hits, misses, maxsize, currsize = f.cache_info()
1250        self.assertEqual(hits, 12)
1251        self.assertEqual(misses, 4)
1252        self.assertEqual(currsize, 2)
1253
1254    def test_lru_no_args(self):
1255        @self.module.lru_cache
1256        def square(x):
1257            return x ** 2
1258
1259        self.assertEqual(list(map(square, [10, 20, 10])),
1260                         [100, 400, 100])
1261        self.assertEqual(square.cache_info().hits, 1)
1262        self.assertEqual(square.cache_info().misses, 2)
1263        self.assertEqual(square.cache_info().maxsize, 128)
1264        self.assertEqual(square.cache_info().currsize, 2)
1265
1266    def test_lru_bug_35780(self):
1267        # C version of the lru_cache was not checking to see if
1268        # the user function call has already modified the cache
1269        # (this arises in recursive calls and in multi-threading).
1270        # This cause the cache to have orphan links not referenced
1271        # by the cache dictionary.
1272
1273        once = True                 # Modified by f(x) below
1274
1275        @self.module.lru_cache(maxsize=10)
1276        def f(x):
1277            nonlocal once
1278            rv = f'.{x}.'
1279            if x == 20 and once:
1280                once = False
1281                rv = f(x)
1282            return rv
1283
1284        # Fill the cache
1285        for x in range(15):
1286            self.assertEqual(f(x), f'.{x}.')
1287        self.assertEqual(f.cache_info().currsize, 10)
1288
1289        # Make a recursive call and make sure the cache remains full
1290        self.assertEqual(f(20), '.20.')
1291        self.assertEqual(f.cache_info().currsize, 10)
1292
1293    def test_lru_bug_36650(self):
1294        # C version of lru_cache was treating a call with an empty **kwargs
1295        # dictionary as being distinct from a call with no keywords at all.
1296        # This did not result in an incorrect answer, but it did trigger
1297        # an unexpected cache miss.
1298
1299        @self.module.lru_cache()
1300        def f(x):
1301            pass
1302
1303        f(0)
1304        f(0, **{})
1305        self.assertEqual(f.cache_info().hits, 1)
1306
1307    def test_lru_hash_only_once(self):
1308        # To protect against weird reentrancy bugs and to improve
1309        # efficiency when faced with slow __hash__ methods, the
1310        # LRU cache guarantees that it will only call __hash__
1311        # only once per use as an argument to the cached function.
1312
1313        @self.module.lru_cache(maxsize=1)
1314        def f(x, y):
1315            return x * 3 + y
1316
1317        # Simulate the integer 5
1318        mock_int = unittest.mock.Mock()
1319        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1320        mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1321
1322        # Add to cache:  One use as an argument gives one call
1323        self.assertEqual(f(mock_int, 1), 16)
1324        self.assertEqual(mock_int.__hash__.call_count, 1)
1325        self.assertEqual(f.cache_info(), (0, 1, 1, 1))
1326
1327        # Cache hit: One use as an argument gives one additional call
1328        self.assertEqual(f(mock_int, 1), 16)
1329        self.assertEqual(mock_int.__hash__.call_count, 2)
1330        self.assertEqual(f.cache_info(), (1, 1, 1, 1))
1331
1332        # Cache eviction: No use as an argument gives no additional call
1333        self.assertEqual(f(6, 2), 20)
1334        self.assertEqual(mock_int.__hash__.call_count, 2)
1335        self.assertEqual(f.cache_info(), (1, 2, 1, 1))
1336
1337        # Cache miss: One use as an argument gives one additional call
1338        self.assertEqual(f(mock_int, 1), 16)
1339        self.assertEqual(mock_int.__hash__.call_count, 3)
1340        self.assertEqual(f.cache_info(), (1, 3, 1, 1))
1341
1342    def test_lru_reentrancy_with_len(self):
1343        # Test to make sure the LRU cache code isn't thrown-off by
1344        # caching the built-in len() function.  Since len() can be
1345        # cached, we shouldn't use it inside the lru code itself.
1346        old_len = builtins.len
1347        try:
1348            builtins.len = self.module.lru_cache(4)(len)
1349            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1350                self.assertEqual(len('abcdefghijklmn'[:i]), i)
1351        finally:
1352            builtins.len = old_len
1353
1354    def test_lru_star_arg_handling(self):
1355        # Test regression that arose in ea064ff3c10f
1356        @functools.lru_cache()
1357        def f(*args):
1358            return args
1359
1360        self.assertEqual(f(1, 2), (1, 2))
1361        self.assertEqual(f((1, 2)), ((1, 2),))
1362
1363    def test_lru_type_error(self):
1364        # Regression test for issue #28653.
1365        # lru_cache was leaking when one of the arguments
1366        # wasn't cacheable.
1367
1368        @functools.lru_cache(maxsize=None)
1369        def infinite_cache(o):
1370            pass
1371
1372        @functools.lru_cache(maxsize=10)
1373        def limited_cache(o):
1374            pass
1375
1376        with self.assertRaises(TypeError):
1377            infinite_cache([])
1378
1379        with self.assertRaises(TypeError):
1380            limited_cache([])
1381
1382    def test_lru_with_maxsize_none(self):
1383        @self.module.lru_cache(maxsize=None)
1384        def fib(n):
1385            if n < 2:
1386                return n
1387            return fib(n-1) + fib(n-2)
1388        self.assertEqual([fib(n) for n in range(16)],
1389            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1390        self.assertEqual(fib.cache_info(),
1391            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1392        fib.cache_clear()
1393        self.assertEqual(fib.cache_info(),
1394            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1395
1396    def test_lru_with_maxsize_negative(self):
1397        @self.module.lru_cache(maxsize=-10)
1398        def eq(n):
1399            return n
1400        for i in (0, 1):
1401            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1402        self.assertEqual(eq.cache_info(),
1403            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
1404
1405    def test_lru_with_exceptions(self):
1406        # Verify that user_function exceptions get passed through without
1407        # creating a hard-to-read chained exception.
1408        # http://bugs.python.org/issue13177
1409        for maxsize in (None, 128):
1410            @self.module.lru_cache(maxsize)
1411            def func(i):
1412                return 'abc'[i]
1413            self.assertEqual(func(0), 'a')
1414            with self.assertRaises(IndexError) as cm:
1415                func(15)
1416            self.assertIsNone(cm.exception.__context__)
1417            # Verify that the previous exception did not result in a cached entry
1418            with self.assertRaises(IndexError):
1419                func(15)
1420
1421    def test_lru_with_types(self):
1422        for maxsize in (None, 128):
1423            @self.module.lru_cache(maxsize=maxsize, typed=True)
1424            def square(x):
1425                return x * x
1426            self.assertEqual(square(3), 9)
1427            self.assertEqual(type(square(3)), type(9))
1428            self.assertEqual(square(3.0), 9.0)
1429            self.assertEqual(type(square(3.0)), type(9.0))
1430            self.assertEqual(square(x=3), 9)
1431            self.assertEqual(type(square(x=3)), type(9))
1432            self.assertEqual(square(x=3.0), 9.0)
1433            self.assertEqual(type(square(x=3.0)), type(9.0))
1434            self.assertEqual(square.cache_info().hits, 4)
1435            self.assertEqual(square.cache_info().misses, 4)
1436
1437    def test_lru_with_keyword_args(self):
1438        @self.module.lru_cache()
1439        def fib(n):
1440            if n < 2:
1441                return n
1442            return fib(n=n-1) + fib(n=n-2)
1443        self.assertEqual(
1444            [fib(n=number) for number in range(16)],
1445            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1446        )
1447        self.assertEqual(fib.cache_info(),
1448            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1449        fib.cache_clear()
1450        self.assertEqual(fib.cache_info(),
1451            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1452
1453    def test_lru_with_keyword_args_maxsize_none(self):
1454        @self.module.lru_cache(maxsize=None)
1455        def fib(n):
1456            if n < 2:
1457                return n
1458            return fib(n=n-1) + fib(n=n-2)
1459        self.assertEqual([fib(n=number) for number in range(16)],
1460            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1461        self.assertEqual(fib.cache_info(),
1462            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1463        fib.cache_clear()
1464        self.assertEqual(fib.cache_info(),
1465            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1466
1467    def test_kwargs_order(self):
1468        # PEP 468: Preserving Keyword Argument Order
1469        @self.module.lru_cache(maxsize=10)
1470        def f(**kwargs):
1471            return list(kwargs.items())
1472        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1473        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1474        self.assertEqual(f.cache_info(),
1475            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1476
1477    def test_lru_cache_decoration(self):
1478        def f(zomg: 'zomg_annotation'):
1479            """f doc string"""
1480            return 42
1481        g = self.module.lru_cache()(f)
1482        for attr in self.module.WRAPPER_ASSIGNMENTS:
1483            self.assertEqual(getattr(g, attr), getattr(f, attr))
1484
    def test_lru_cache_threaded(self):
        # Stress a shared cache from several threads.  Phase 1: n threads each
        # call f(k, 0) m times with their own key k, filling the cache.
        # Phase 2: the same workers run while one extra thread repeatedly
        # calls cache_clear().  Phase 2 asserts nothing afterwards; it only
        # has to finish without error.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # Every worker blocks on this event so the threads start together.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                # NOTE(review): a failing assertion here raises in the worker
                # thread; whether that fails the test depends on the thread
                # harness — confirm.
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # A tiny switch interval forces frequent thread switches, making
        # interleavings (and thus races in the cache) more likely.
        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: Why can be not equal?
                # NOTE(review): the pure-Python implementation is only held to
                # upper bounds here, presumably because its hit/miss counters
                # can lose updates under contention — confirm.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                # The C implementation must count exactly one miss per key.
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            # Restore the interpreter-wide switch interval.
            sys.setswitchinterval(orig_si)
1532
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        # n worker threads call f(i) at the same moment for each of m rounds.
        # f blocks on ``pause`` until the main thread joins it, so all n calls
        # are inside the user function at once.  Per the assertion below, each
        # round adds n misses, no hits, and one cache entry.
        n, m = 5, 7
        # Each barrier has n + 1 parties: the n workers plus the main thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                # Release the workers into f(i), then reset each barrier only
                # once it has been passed, ready for the next round.
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1559
1560    def test_lru_cache_threaded3(self):
1561        @self.module.lru_cache(maxsize=2)
1562        def f(x):
1563            time.sleep(.01)
1564            return 3 * x
1565        def test(i, x):
1566            with self.subTest(thread=i):
1567                self.assertEqual(f(x), 3 * x, i)
1568        threads = [threading.Thread(target=test, args=(i, v))
1569                   for i, v in enumerate([1, 2, 2, 3, 2])]
1570        with support.start_threads(threads):
1571            pass
1572
1573    def test_need_for_rlock(self):
1574        # This will deadlock on an LRU cache that uses a regular lock
1575
1576        @self.module.lru_cache(maxsize=10)
1577        def test_func(x):
1578            'Used to demonstrate a reentrant lru_cache call within a single thread'
1579            return x
1580
1581        class DoubleEq:
1582            'Demonstrate a reentrant lru_cache call within a single thread'
1583            def __init__(self, x):
1584                self.x = x
1585            def __hash__(self):
1586                return self.x
1587            def __eq__(self, other):
1588                if self.x == 2:
1589                    test_func(DoubleEq(1))
1590                return self.x == other.x
1591
1592        test_func(DoubleEq(1))                      # Load the cache
1593        test_func(DoubleEq(2))                      # Load the cache
1594        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
1595                         DoubleEq(2))               # Verify the correct return value
1596
1597    def test_lru_method(self):
1598        class X(int):
1599            f_cnt = 0
1600            @self.module.lru_cache(2)
1601            def f(self, x):
1602                self.f_cnt += 1
1603                return x*10+self
1604        a = X(5)
1605        b = X(5)
1606        c = X(7)
1607        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1608
1609        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1610            self.assertEqual(a.f(x), x*10 + 5)
1611        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1612        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1613
1614        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1615            self.assertEqual(b.f(x), x*10 + 5)
1616        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1617        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1618
1619        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1620            self.assertEqual(c.f(x), x*10 + 7)
1621        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1622        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1623
1624        self.assertEqual(a.f.cache_info(), X.f.cache_info())
1625        self.assertEqual(b.f.cache_info(), X.f.cache_info())
1626        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1627
1628    def test_pickle(self):
1629        cls = self.__class__
1630        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1631            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1632                with self.subTest(proto=proto, func=f):
1633                    f_copy = pickle.loads(pickle.dumps(f, proto))
1634                    self.assertIs(f_copy, f)
1635
1636    def test_copy(self):
1637        cls = self.__class__
1638        def orig(x, y):
1639            return 3 * x + y
1640        part = self.module.partial(orig, 2)
1641        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1642                 self.module.lru_cache(2)(part))
1643        for f in funcs:
1644            with self.subTest(func=f):
1645                f_copy = copy.copy(f)
1646                self.assertIs(f_copy, f)
1647
1648    def test_deepcopy(self):
1649        cls = self.__class__
1650        def orig(x, y):
1651            return 3 * x + y
1652        part = self.module.partial(orig, 2)
1653        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1654                 self.module.lru_cache(2)(part))
1655        for f in funcs:
1656            with self.subTest(func=f):
1657                f_copy = copy.deepcopy(f)
1658                self.assertIs(f_copy, f)
1659
1660
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Module-level so pickle can find it by name; exposed to the shared
    # LRU tests via TestLRUPy.cached_func.
    tripled = 3 * x
    return tripled + y
1664
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Module-level so pickle can find it by name; exposed to the shared
    # LRU tests via TestLRUC.cached_func.
    tripled = 3 * x
    return tripled + y
1668
1669
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU suite against the pure-Python functools
    # (imported with _functools blocked).
    module = py_functools
    # Stored in a 1-tuple; the shared tests read it as cached_func[0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1682
1683
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU suite against functools backed by the C
    # accelerator module _functools.
    module = c_functools
    # Stored in a 1-tuple; the shared tests read it as cached_func[0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1696
1697
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and functools.singledispatchmethod."""

    def test_simple_overloads(self):
        # Explicit register(cls, func); unregistered types use the base.
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        # D(C, B) has MRO D, C, B, A, object, so B's implementation is
        # reached before A's.
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        # Unregistered types fall back to the implementation registered for
        # object. That is the original undecorated function, not g itself:
        # @singledispatch returns a wrapper around it.
        self.assertIs(g.dispatch(object), g.dispatch(str))

    def test_wrapping_attributes(self):
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Docstrings are stripped when running under -OO.
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Dispatch must also work on exception classes implemented in C.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering; every
        # loop below feeds each permutation to _compose_mro to verify that.
        c = collections.abc
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
        for haystack in permutations(bases):
            m = mro(collections.ChainMap, haystack)
            self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(collections.defaultdict, haystack)
            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
                                 c.Container, object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(collections.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(collections.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        # Registering progressively more specific ABCs and then concrete
        # classes must always route each object to the most specific match.
        c = collections.abc
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(collections.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        # _c3_mro must splice registered and implicitly-implemented ABCs
        # into the MRO next to the class that introduces them.
        c = collections.abc
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # See bpo-23572: MetaA.__len__ returning 0 makes the classes A and
        # AA falsy, and dispatch must not depend on class truthiness.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        # Swap the dispatch cache (a WeakKeyDictionary) for a dict that
        # records every get/set so the exact caching and invalidation
        # behaviour can be asserted.
        from collections import UserDict
        import weakref

        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()

        td = TracingDict()
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            g._clear_cache()
            self.assertEqual(len(td), 0)

    def test_annotations(self):
        # Plain @register picks the dispatch type from the first parameter's
        # annotation, including string (forward-reference) annotations.
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register
        def _(arg: collections.abc.Mapping):
            return "mapping"
        @i.register
        def _(arg: "collections.abc.Sequence"):
            return "sequence"
        self.assertEqual(i(None), "base")
        self.assertEqual(i({"a": 1}), "mapping")
        self.assertEqual(i([1, 2, 3]), "sequence")
        self.assertEqual(i((1, 2, 3)), "sequence")
        self.assertEqual(i("str"), "sequence")

        # Registering classes as callables doesn't work with annotations,
        # you need to pass the type explicitly.
        @i.register(str)
        class _:
            def __init__(self, arg):
                self.arg = arg

            def __eq__(self, other):
                return self.arg == other
        self.assertEqual(i("str"), "str")

    def test_method_register(self):
        # singledispatchmethod dispatches on the first argument after self,
        # and only mutates the instance it was called on.
        class A:
            @functools.singledispatchmethod
            def t(self, arg):
                self.arg = "base"
            @t.register(int)
            def _(self, arg):
                self.arg = "int"
            @t.register(str)
            def _(self, arg):
                self.arg = "str"
        a = A()

        a.t(0)
        self.assertEqual(a.arg, "int")
        aa = A()
        self.assertFalse(hasattr(aa, 'arg'))
        a.t('')
        self.assertEqual(a.arg, "str")
        aa = A()
        self.assertFalse(hasattr(aa, 'arg'))
        a.t(0.0)
        self.assertEqual(a.arg, "base")
        aa = A()
        self.assertFalse(hasattr(aa, 'arg'))

    def test_staticmethod_register(self):
        class A:
            @functools.singledispatchmethod
            @staticmethod
            def t(arg):
                return arg
            @t.register(int)
            @staticmethod
            def _(arg):
                return isinstance(arg, int)
            @t.register(str)
            @staticmethod
            def _(arg):
                return isinstance(arg, str)
        a = A()

        self.assertTrue(A.t(0))
        self.assertTrue(A.t(''))
        self.assertEqual(A.t(0.0), 0.0)
        # The same dispatch must work when looked up through an instance.
        self.assertTrue(a.t(0))
        self.assertTrue(a.t(''))
        self.assertEqual(a.t(0.0), 0.0)

    def test_classmethod_register(self):
        class A:
            def __init__(self, arg):
                self.arg = arg

            @functools.singledispatchmethod
            @classmethod
            def t(cls, arg):
                return cls("base")
            @t.register(int)
            @classmethod
            def _(cls, arg):
                return cls("int")
            @t.register(str)
            @classmethod
            def _(cls, arg):
                return cls("str")

        self.assertEqual(A.t(0).arg, "int")
        self.assertEqual(A.t('').arg, "str")
        self.assertEqual(A.t(0.0).arg, "base")

    def test_callable_register(self):
        # Implementations may also be registered from outside the class body.
        class A:
            def __init__(self, arg):
                self.arg = arg

            @functools.singledispatchmethod
            @classmethod
            def t(cls, arg):
                return cls("base")

        @A.t.register(int)
        @classmethod
        def _(cls, arg):
            return cls("int")
        @A.t.register(str)
        @classmethod
        def _(cls, arg):
            return cls("str")

        self.assertEqual(A.t(0).arg, "int")
        self.assertEqual(A.t('').arg, "str")
        self.assertEqual(A.t(0.0).arg, "base")

    def test_abstractmethod_register(self):
        # ABCMeta must be the *metaclass* of Abstract (not its base class)
        # for the abstract singledispatchmethod to be collected as abstract.
        class Abstract(metaclass=abc.ABCMeta):

            @functools.singledispatchmethod
            @abc.abstractmethod
            def add(self, x, y):
                pass

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)
        self.assertIn('add', Abstract.__abstractmethods__)

    def test_type_ann_register(self):
        class A:
            @functools.singledispatchmethod
            def t(self, arg):
                return "base"
            @t.register
            def _(self, arg: int):
                return "int"
            @t.register
            def _(self, arg: str):
                return "str"
        a = A()

        self.assertEqual(a.t(0), "int")
        self.assertEqual(a.t(''), "str")
        self.assertEqual(a.t(0.0), "base")

    def test_invalid_registrations(self):
        msg_prefix = "Invalid first argument to `register()`: "
        msg_suffix = (
            ". Use either `@register(some_class)` or plain `@register` on an "
            "annotated function."
        )
        @functools.singledispatch
        def i(arg):
            return "base"
        with self.assertRaises(TypeError) as exc:
            @i.register(42)
            def _(arg):
                return "I annotated with a non-type"
        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
        self.assertTrue(str(exc.exception).endswith(msg_suffix))
        with self.assertRaises(TypeError) as exc:
            @i.register
            def _(arg):
                return "I forgot to annotate"
        self.assertTrue(str(exc.exception).startswith(msg_prefix +
            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
        ))
        self.assertTrue(str(exc.exception).endswith(msg_suffix))

        with self.assertRaises(TypeError) as exc:
            @i.register
            def _(arg: typing.Iterable[str]):
                # At runtime, dispatching on generics is impossible.
                # When registering implementations with singledispatch, avoid
                # types from `typing`. Instead, annotate with regular types
                # or ABCs.
                return "I annotated with a generic collection"
        self.assertTrue(str(exc.exception).startswith(
            "Invalid annotation for 'arg'."
        ))
        self.assertTrue(str(exc.exception).endswith(
            'typing.Iterable[str] is not a class.'
        ))

    def test_invalid_positional_argument(self):
        # Calling with no positional argument leaves nothing to dispatch on.
        @functools.singledispatch
        def f(*args):
            pass
        msg = 'f requires at least 1 positional argument'
        with self.assertRaisesRegex(TypeError, msg):
            f()
2385
2386
class CachedCostItem:
    """Holder whose ``cost`` is computed via the pure-Python cached_property.

    Each evaluation of the getter bumps ``_cost`` by one under ``self.lock``.
    A cached value of 2 therefore means the getter ran exactly once.
    """

    # Starting value. The first increment in ``cost`` writes an instance
    # attribute that shadows this class attribute.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        guard = self.lock
        with guard:
            self._cost = self._cost + 1
        return self._cost
2399
2400
class OptionallyCachedCostItem:
    """Exposes one getter both uncached (``get_cost``) and cached (``cached_cost``)."""

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        bumped = self._cost + 1
        self._cost = bumped
        return bumped

    # Wrapping the plain method caches the result under the attribute name
    # 'cached_cost', which differs from the wrapped function's own name.
    cached_cost = py_functools.cached_property(get_cost)
2410
2411
class CachedCostItemWait:
    """cached_property holder whose getter first blocks on an event.

    Lets several threads be released at once onto the first ``cost`` access.
    """

    def __init__(self, event):
        self.event = event
        self.lock = py_functools.RLock()
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        # Block until the event is set, but never longer than one second.
        self.event.wait(1)
        guard = self.lock
        with guard:
            self._cost = self._cost + 1
        return self._cost
2425
2426
class CachedCostItemWithSlots:
    """cached_property holder whose instances have no ``__dict__``.

    Declaring ``__slots__`` removes the per-instance ``__dict__`` that
    cached_property would store its value in. Accessing ``cost`` is therefore
    expected to fail before the getter ever runs.
    """

    # A real one-element tuple. The former ``('_cost')`` was just the string
    # '_cost' in parentheses, which only worked because a bare string is
    # accepted as a single slot name.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        # Must never execute: the lookup fails first for lack of __dict__.
        raise RuntimeError('never called, slots not supported')
2436
2437
class TestCachedProperty(unittest.TestCase):
    """Behavioral tests for the pure-Python ``functools.cached_property``."""

    def test_cached(self):
        cached_item = CachedCostItem()
        # The first access runs the getter (1 -> 2); the second must hit the
        # cache instead of producing 3.
        for _ in range(2):
            self.assertEqual(cached_item.cost, 2)

    def test_cached_attribute_name_differs_from_func_name(self):
        obj = OptionallyCachedCostItem()
        # get_cost() recomputes on every call, while cached_cost keeps the
        # value from its first evaluation.
        self.assertEqual(obj.get_cost(), 2)
        self.assertEqual(obj.cached_cost, 3)
        self.assertEqual(obj.get_cost(), 4)
        self.assertEqual(obj.cached_cost, 3)

    def test_threaded(self):
        gate = threading.Event()
        shared = CachedCostItemWait(gate)
        thread_count = 3

        def read_cost():
            return shared.cost

        saved_interval = sys.getswitchinterval()
        # Shrink the switch interval so the threads switch as often as
        # possible (presumably to encourage them to race on the first access).
        sys.setswitchinterval(1e-6)
        try:
            workers = [threading.Thread(target=read_cost)
                       for _ in range(thread_count)]
            with support.start_threads(workers):
                gate.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # The getter increments the cost, so a cached 2 shows its increment
        # took effect only once.
        self.assertEqual(shared.cost, 2)

    def test_object_with_slots(self):
        slotted = CachedCostItemWithSlots()
        expected_msg = "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property."
        with self.assertRaisesRegex(TypeError, expected_msg):
            slotted.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        # A class's __dict__ rejects item assignment, so caching `prop` on the
        # metaclass instance must fail with a TypeError.
        expected_msg = "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property."
        with self.assertRaisesRegex(TypeError, expected_msg):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        expected = str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b')."))
        with self.assertRaises(RuntimeError) as caught:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The underlying TypeError is attached as the context of the
        # RuntimeError raised during class creation.
        self.assertEqual(str(caught.exception.__context__), expected)

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        calls = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal calls
            calls += 1
            return calls

        class A:
            cp = _cp

        class B:
            cp = _cp

        first, second = A(), B()

        self.assertEqual(first.cp, 1)
        self.assertEqual(second.cp, 2)
        # The value cached on `first` is unaffected by the access on `second`.
        self.assertEqual(first.cp, 1)

    def test_set_name_not_called(self):
        unnamed = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Attaching after class creation bypasses __set_name__.
        Foo.cp = unnamed

        expected_msg = "Cannot use cached_property instance without calling __set_name__ on it."
        with self.assertRaisesRegex(TypeError, expected_msg):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        doc = CachedCostItem.cost.__doc__
        self.assertEqual(doc, "The cost of the item.")
2550
2551
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
2554