1import abc
2import builtins
3import collections
4import collections.abc
5import copy
6from itertools import permutations
7import pickle
8from random import choice
9import sys
10from test import support
11import threading
12import time
13import typing
14import unittest
15import unittest.mock
16from weakref import proxy
17import contextlib
18
19import functools
20
# Two independently imported copies of functools: one with the _functools C
# accelerator blocked (pure-Python implementation) and one that freshly
# imports _functools.  Later code guards C-only tests with `if c_functools:`,
# so c_functools is expected to be falsy when the C module is unavailable
# (presumably import_fresh_module returns None then -- confirm in test.support).
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# Fresh decimal module using the _decimal accelerator.  Not referenced in this
# part of the file -- NOTE(review): presumably used by tests further down.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    On exit, normal or via an exception, the previous entry is put back.
    If *name* was not in ``sys.modules`` beforehand, the temporary entry is
    removed on exit instead of failing up front with ``KeyError``.
    """
    missing = object()  # private sentinel: a stored entry may legitimately be None
    original_module = sys.modules.get(name, missing)
    sys.modules[name] = replacement
    try:
        yield
    finally:
        if original_module is missing:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = original_module
34
def capture(*args, **kw):
    """Echo back the call's arguments as an ``(args, kwargs)`` pair.

    Useful as the target of a partial object: the result shows exactly
    which positional and keyword arguments reached the underlying callable.
    """
    received = (args, kw)
    return received
38
39
def signature(part):
    """Summarise a partial object as ``(func, args, keywords, __dict__)``.

    Two partial objects with equal signatures wrap the same callable with
    the same frozen arguments and carry the same extra attributes.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
43
class MyTuple(tuple):
    """Plain tuple subclass with no overridden behaviour."""
46
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by producing a list."""

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
50
class MyDict(dict):
    """Plain dict subclass with no overridden behaviour."""
53
54
class TestPartial:
    """Implementation-independent tests for ``functools.partial``.

    This is a mixin, not a TestCase.  Concrete subclasses combine it with
    ``unittest.TestCase`` and provide ``partial`` (the partial type under
    test) and ``AllowPickle`` (a context manager wrapped around the pickling
    tests).
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # The first argument must be callable.  Construction and call are
        # both inside the block, so the test passes wherever the
        # implementation performs the check.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # separate assertEqual calls so a failure shows the differing values
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Force a collection so the referent is gone even on interpreters
        # that do not free objects as soon as their refcount drops.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # keyword order in the repr is not pinned down, accept either order
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case makes the partial refer to itself; the finally clause
        # breaks the cycle again.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this signature check is disabled in the original;
        # re-enable once confirmed to hold for every implementation tested.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # BadTuple.__add__ returns a list; the stored args must be a real
        # tuple so that calling still yields a tuple of positional args.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
378
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation, plus
    checks that only apply to the C type."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: the C partial needs no module swapping
        before it can be pickled."""

        def __enter__(self):
            return self

        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        """func/args/keywords cannot be rebound and __dict__ cannot be deleted."""
        obj = self.partial(capture, 1, 2, a=10, b=20)
        readonly = (
            ('func', map),
            ('args', (1, 2)),
            ('keywords', dict(a=1, b=2)),
        )
        for attr, new_value in readonly:
            self.assertRaises(AttributeError, setattr, obj, attr, new_value)

        obj = self.partial(hex)
        try:
            del obj.__dict__
        except TypeError:
            return
        self.fail('partial object allowed __dict__ to be deleted')

    def test_manually_adding_non_string_keyword(self):
        """repr() copes with a non-string keyword key; calling rejects it."""
        part = self.partial(capture)
        # Inject a non-string key straight into the keywords dict.
        part.keywords[1234] = 'value'
        text = repr(part)
        for fragment in ('1234', "'value'"):
            self.assertIn(fragment, text)
        with self.assertRaises(TypeError):
            part()

    def test_keystr_replaces_value(self):
        """repr() keeps the original value alive while formatting a key
        whose __str__ replaces that value in the keywords dict."""
        part = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                part.keywords[self] = ['sth2']
                return 'astr'

        part.keywords[MutatesYourDict()] = ['sth']
        text = repr(part)
        self.assertIn('astr', text)
        self.assertIn("['sth']", text)
429
430
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Make ``sys.modules['functools']`` refer to ``py_functools`` for
        the duration of the ``with`` block.

        NOTE(review): presumably needed because pickle resolves the partial
        class by its module name, which would otherwise find the
        C-accelerated ``functools`` -- confirm.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            swap = self._swap
            return swap.__enter__()

        def __exit__(self, type, value, tb):
            swap = self._swap
            return swap.__exit__(type, value, tb)
441
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial type (only defined when the
        C module is available)."""
445
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial type."""
448
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the C partial tests against a trivial subclass of the C type."""

    if c_functools:
        partial = CPartialSubclass

    # Nested partial objects are not flattened for subclasses, so the
    # inherited nested-optimization test is switched off.
    test_nested_optimization = None
456
class TestPartialPySubclass(TestPartialPy):
    """Run the pure-Python partial tests against a trivial subclass."""

    partial = PyPartialSubclass
459
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod`` used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        class B:
            method = functools.partialmethod(func=capture, a=1)
        b = B()
        self.assertEqual(b.method(2, x=3), ((b, 2), {'a': 1, 'x': 3}))

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be an abstract *class* (abc.ABC), not a metaclass:
        # subclassing abc.ABCMeta would only create another metaclass, and
        # abstract-method collection would never run on this body.
        class Abstract(abc.ABC):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
582
583
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* was updated from *wrapped*.

        Every attribute named in *assigned* must be the very same object on
        both functions.  Every mapping attribute named in *updated* must hold
        each entry of the wrapped object's mapping, except
        ``__dict__['__wrapped__']``, which update_wrapper overwrites.  Finally,
        ``wrapper.__wrapped__`` must be *wrapped* itself.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        for attr in updated:
            target = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            skip_wrapped_key = (attr == "__dict__")
            for key in source:
                if skip_wrapped_key and key == "__wrapped__":
                    continue
                self.assertIs(source[key], target[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Build ``(wrapper, f)`` with update_wrapper(wrapper, f) applied using defaults."""
        def f(a: 'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        # Deliberately bogus value: update_wrapper must point __wrapped__ at f.
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b: 'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        # Identity metadata and custom attributes come from f ...
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # ... and the annotations are replaced outright, not merged.
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        def wrapper():
            pass

        functools.update_wrapper(wrapper, f, assigned=(), updated=())
        self.check_wrapper(wrapper, f, assigned=(), updated=())
        # With nothing to assign or update, wrapper keeps its own metadata.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        # Only the selected attributes were transferred.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        # Attributes absent from the wrapped function are silently skipped.
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})

        # The wrapper itself, however, must provide an updatable mapping:
        # first when the attribute is missing entirely ...
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        # ... then when it exists but has no update() method.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: wrapping a builtin such as max()
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
696
697
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper scenarios through the ``functools.wraps``
    decorator factory; unmodified inherited tests still use update_wrapper."""

    def _default_update(self):
        """Build ``(wrapper, f)`` with wraps(f) applied to wrapper."""
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # Deliberately bogus value that wraps() must overwrite with f.
        f.__wrapped__ = "This is still a bald faced lie"

        def wrapper():
            pass
        wrapper = functools.wraps(f)(wrapper)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        def wrapper():
            pass
        wrapper = functools.wraps(f, assigned=(), updated=())(wrapper)

        self.check_wrapper(wrapper, f, assigned=(), updated=())
        # Nothing was transferred, so wrapper keeps its own metadata.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # give the wrapper an (empty) mapping for wraps() to update
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        def wrapper():
            pass
        wrapper = functools.wraps(f, assign, update)(add_dict_attr(wrapper))

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
758
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for the C-accelerated ``functools.reduce``."""

    if c_functools:
        # NOTE(review): presumably a C builtin, which does not bind to the
        # instance, so self.func(...) calls reduce without an extra self.
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of the squares 0, 1, 4, ... with `max` items,
            # computed on demand by __getitem__.  __len__ reports only how
            # many squares have been computed so far.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                while len(self.sofar) <= i:
                    k = len(self.sofar)
                    self.sofar.append(k * k)
                return self.sofar[i]

        def add(x, y):
            return x + y

        # (reduce arguments, expected result)
        expected_results = [
            ((add, ['a', 'b', 'c'], ''), 'abc'),
            ((add, [['a', 'c'], [], ['d', 'w']], []), ['a', 'c', 'd', 'w']),
            ((lambda x, y: x*y, range(2, 8), 1), 5040),
            ((lambda x, y: x*y, range(2, 21), 1), 2432902008176640000),
            ((add, Squares(10)), 285),
            ((add, Squares(10), 0), 285),
            ((add, Squares(0), 0), 0),
        ]
        for call_args, expected in expected_results:
            self.assertEqual(self.func(*call_args), expected)

        # Missing arguments, or a non-iterable sequence.
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        # With a single item the function is never called.
        self.assertEqual(self.func(42, "1"), "1")
        self.assertEqual(self.func(42, "", "1"), "1")
        # With two items the non-callable function is called.
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        # An empty sequence without an initial value is an error.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, self.func, add, empty)
        self.assertRaises(TypeError, self.func, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, TestFailingIter())

        # An empty sequence yields the initial value unchanged.
        self.assertIsNone(self.func(add, [], None))
        self.assertEqual(self.func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    def test_iterator_usage(self):
        """reduce() iterates objects supporting only the __getitem__ protocol."""
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        # An empty sequence needs an initial value.
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys.
        words = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, words), "".join(words))
840
841
842class TestCmpToKey:
843
844    def test_cmp_to_key(self):
845        def cmp1(x, y):
846            return (x > y) - (x < y)
847        key = self.cmp_to_key(cmp1)
848        self.assertEqual(key(3), key(3))
849        self.assertGreater(key(3), key(1))
850        self.assertGreaterEqual(key(3), key(3))
851
852        def cmp2(x, y):
853            return int(x) - int(y)
854        key = self.cmp_to_key(cmp2)
855        self.assertEqual(key(4.0), key('4'))
856        self.assertLess(key(2), key('35'))
857        self.assertLessEqual(key(2), key('35'))
858        self.assertNotEqual(key(2), key('35'))
859
860    def test_cmp_to_key_arguments(self):
861        def cmp1(x, y):
862            return (x > y) - (x < y)
863        key = self.cmp_to_key(mycmp=cmp1)
864        self.assertEqual(key(obj=3), key(obj=3))
865        self.assertGreater(key(obj=3), key(obj=1))
866        with self.assertRaises((TypeError, AttributeError)):
867            key(3) > 1    # rhs is not a K object
868        with self.assertRaises((TypeError, AttributeError)):
869            1 < key(3)    # lhs is not a K object
870        with self.assertRaises(TypeError):
871            key = self.cmp_to_key()             # too few args
872        with self.assertRaises(TypeError):
873            key = self.cmp_to_key(cmp1, None)   # too many args
874        key = self.cmp_to_key(cmp1)
875        with self.assertRaises(TypeError):
876            key()                                    # too few args
877        with self.assertRaises(TypeError):
878            key(None, None)                          # too many args
879
880    def test_bad_cmp(self):
881        def cmp1(x, y):
882            raise ZeroDivisionError
883        key = self.cmp_to_key(cmp1)
884        with self.assertRaises(ZeroDivisionError):
885            key(3) > key(1)
886
887        class BadCmp:
888            def __lt__(self, other):
889                raise ZeroDivisionError
890        def cmp1(x, y):
891            return BadCmp()
892        with self.assertRaises(ZeroDivisionError):
893            key(3) > key(1)
894
895    def test_obj_field(self):
896        def cmp1(x, y):
897            return (x > y) - (x < y)
898        key = self.cmp_to_key(mycmp=cmp1)
899        self.assertEqual(key(50).obj, 50)
900
901    def test_sort_int(self):
902        def mycmp(x, y):
903            return y - x
904        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
905                         [4, 3, 2, 1, 0])
906
907    def test_sort_int_str(self):
908        def mycmp(x, y):
909            x, y = int(x), int(y)
910            return (x > y) - (x < y)
911        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
912        values = sorted(values, key=self.cmp_to_key(mycmp))
913        self.assertEqual([int(value) for value in values],
914                         [0, 1, 1, 2, 3, 4, 5, 7, 10])
915
916    def test_hash(self):
917        def mycmp(x, y):
918            return y - x
919        key = self.cmp_to_key(mycmp)
920        k = key(10)
921        self.assertRaises(TypeError, hash, k)
922        self.assertNotIsInstance(k, collections.abc.Hashable)
923
924
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C accelerator."""

    if c_functools is not None:
        # Guarded so the class body still executes when _functools is
        # unavailable; the skip decorator then disables the tests.
        cmp_to_key = c_functools.cmp_to_key
929
930
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python functools."""
    # staticmethod() stops the plain Python function from being bound as a
    # method when the tests look it up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
933
934
935class TestTotalOrdering(unittest.TestCase):
936
937    def test_total_ordering_lt(self):
938        @functools.total_ordering
939        class A:
940            def __init__(self, value):
941                self.value = value
942            def __lt__(self, other):
943                return self.value < other.value
944            def __eq__(self, other):
945                return self.value == other.value
946        self.assertTrue(A(1) < A(2))
947        self.assertTrue(A(2) > A(1))
948        self.assertTrue(A(1) <= A(2))
949        self.assertTrue(A(2) >= A(1))
950        self.assertTrue(A(2) <= A(2))
951        self.assertTrue(A(2) >= A(2))
952        self.assertFalse(A(1) > A(2))
953
954    def test_total_ordering_le(self):
955        @functools.total_ordering
956        class A:
957            def __init__(self, value):
958                self.value = value
959            def __le__(self, other):
960                return self.value <= other.value
961            def __eq__(self, other):
962                return self.value == other.value
963        self.assertTrue(A(1) < A(2))
964        self.assertTrue(A(2) > A(1))
965        self.assertTrue(A(1) <= A(2))
966        self.assertTrue(A(2) >= A(1))
967        self.assertTrue(A(2) <= A(2))
968        self.assertTrue(A(2) >= A(2))
969        self.assertFalse(A(1) >= A(2))
970
971    def test_total_ordering_gt(self):
972        @functools.total_ordering
973        class A:
974            def __init__(self, value):
975                self.value = value
976            def __gt__(self, other):
977                return self.value > other.value
978            def __eq__(self, other):
979                return self.value == other.value
980        self.assertTrue(A(1) < A(2))
981        self.assertTrue(A(2) > A(1))
982        self.assertTrue(A(1) <= A(2))
983        self.assertTrue(A(2) >= A(1))
984        self.assertTrue(A(2) <= A(2))
985        self.assertTrue(A(2) >= A(2))
986        self.assertFalse(A(2) < A(1))
987
988    def test_total_ordering_ge(self):
989        @functools.total_ordering
990        class A:
991            def __init__(self, value):
992                self.value = value
993            def __ge__(self, other):
994                return self.value >= other.value
995            def __eq__(self, other):
996                return self.value == other.value
997        self.assertTrue(A(1) < A(2))
998        self.assertTrue(A(2) > A(1))
999        self.assertTrue(A(1) <= A(2))
1000        self.assertTrue(A(2) >= A(1))
1001        self.assertTrue(A(2) <= A(2))
1002        self.assertTrue(A(2) >= A(2))
1003        self.assertFalse(A(2) <= A(1))
1004
1005    def test_total_ordering_no_overwrite(self):
1006        # new methods should not overwrite existing
1007        @functools.total_ordering
1008        class A(int):
1009            pass
1010        self.assertTrue(A(1) < A(2))
1011        self.assertTrue(A(2) > A(1))
1012        self.assertTrue(A(1) <= A(2))
1013        self.assertTrue(A(2) >= A(1))
1014        self.assertTrue(A(2) <= A(2))
1015        self.assertTrue(A(2) >= A(2))
1016
1017    def test_no_operations_defined(self):
1018        with self.assertRaises(ValueError):
1019            @functools.total_ordering
1020            class A:
1021                pass
1022
1023    def test_type_error_when_not_implemented(self):
1024        # bug 10042; ensure stack overflow does not occur
1025        # when decorated types return NotImplemented
1026        @functools.total_ordering
1027        class ImplementsLessThan:
1028            def __init__(self, value):
1029                self.value = value
1030            def __eq__(self, other):
1031                if isinstance(other, ImplementsLessThan):
1032                    return self.value == other.value
1033                return False
1034            def __lt__(self, other):
1035                if isinstance(other, ImplementsLessThan):
1036                    return self.value < other.value
1037                return NotImplemented
1038
1039        @functools.total_ordering
1040        class ImplementsGreaterThan:
1041            def __init__(self, value):
1042                self.value = value
1043            def __eq__(self, other):
1044                if isinstance(other, ImplementsGreaterThan):
1045                    return self.value == other.value
1046                return False
1047            def __gt__(self, other):
1048                if isinstance(other, ImplementsGreaterThan):
1049                    return self.value > other.value
1050                return NotImplemented
1051
1052        @functools.total_ordering
1053        class ImplementsLessThanEqualTo:
1054            def __init__(self, value):
1055                self.value = value
1056            def __eq__(self, other):
1057                if isinstance(other, ImplementsLessThanEqualTo):
1058                    return self.value == other.value
1059                return False
1060            def __le__(self, other):
1061                if isinstance(other, ImplementsLessThanEqualTo):
1062                    return self.value <= other.value
1063                return NotImplemented
1064
1065        @functools.total_ordering
1066        class ImplementsGreaterThanEqualTo:
1067            def __init__(self, value):
1068                self.value = value
1069            def __eq__(self, other):
1070                if isinstance(other, ImplementsGreaterThanEqualTo):
1071                    return self.value == other.value
1072                return False
1073            def __ge__(self, other):
1074                if isinstance(other, ImplementsGreaterThanEqualTo):
1075                    return self.value >= other.value
1076                return NotImplemented
1077
1078        @functools.total_ordering
1079        class ComparatorNotImplemented:
1080            def __init__(self, value):
1081                self.value = value
1082            def __eq__(self, other):
1083                if isinstance(other, ComparatorNotImplemented):
1084                    return self.value == other.value
1085                return False
1086            def __lt__(self, other):
1087                return NotImplemented
1088
1089        with self.subTest("LT < 1"), self.assertRaises(TypeError):
1090            ImplementsLessThan(-1) < 1
1091
1092        with self.subTest("LT < LE"), self.assertRaises(TypeError):
1093            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1094
1095        with self.subTest("LT < GT"), self.assertRaises(TypeError):
1096            ImplementsLessThan(1) < ImplementsGreaterThan(1)
1097
1098        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1099            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1100
1101        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1102            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1103
1104        with self.subTest("GT > GE"), self.assertRaises(TypeError):
1105            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1106
1107        with self.subTest("GT > LT"), self.assertRaises(TypeError):
1108            ImplementsGreaterThan(5) > ImplementsLessThan(5)
1109
1110        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1111            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1112
1113        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1114            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1115
1116        with self.subTest("GE when equal"):
1117            a = ComparatorNotImplemented(8)
1118            b = ComparatorNotImplemented(8)
1119            self.assertEqual(a, b)
1120            with self.assertRaises(TypeError):
1121                a >= b
1122
1123        with self.subTest("LE when equal"):
1124            a = ComparatorNotImplemented(9)
1125            b = ComparatorNotImplemented(9)
1126            self.assertEqual(a, b)
1127            with self.assertRaises(TypeError):
1128                a <= b
1129
1130    def test_pickle(self):
1131        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1132            for name in '__lt__', '__gt__', '__le__', '__ge__':
1133                with self.subTest(method=name, proto=proto):
1134                    method = getattr(Orderable_LT, name)
1135                    method_copy = pickle.loads(pickle.dumps(method, proto))
1136                    self.assertIs(method_copy, method)
1137
@functools.total_ordering
class Orderable_LT:
    """Module-level orderable type used by TestTotalOrdering.test_pickle.

    Only __eq__ and __lt__ are written by hand; total_ordering derives the
    remaining rich comparisons from __lt__.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1146
1147
1148class TestLRU:
1149
    def test_lru(self):
        """Check hit/miss accounting, clearing, bypass and tiny maxsizes."""
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=20)(orig)
        # cache_info() unpacks as (hits, misses, maxsize, currsize).
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 1000 random calls over 5 * 5 = 25 distinct argument pairs.  There
        # are more pairs than cache slots (20), so the cache ends up full
        # (with overwhelming probability), and repeats make hits dominate.
        domain = range(5)
        for i in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertTrue(hits > misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        # (x, y) still hold the last pair drawn in the loop above; after
        # clearing, this call must be a miss.
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        # __wrapped__ is orig itself, so calling it leaves the statistics
        # unchanged.
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        # f_cnt counts real invocations of the function body; it is bound
        # below before the first call.
        @self.module.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        # Five identical calls: one real invocation, then four hits.
        @self.module.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        # In the call sequence below, '*' marks the four calls that miss
        # (including evictions of the least recently used key).
        @self.module.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            #    *  *              *                          *
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)
1238
1239    def test_lru_bug_35780(self):
1240        # C version of the lru_cache was not checking to see if
1241        # the user function call has already modified the cache
1242        # (this arises in recursive calls and in multi-threading).
1243        # This cause the cache to have orphan links not referenced
1244        # by the cache dictionary.
1245
1246        once = True                 # Modified by f(x) below
1247
1248        @self.module.lru_cache(maxsize=10)
1249        def f(x):
1250            nonlocal once
1251            rv = f'.{x}.'
1252            if x == 20 and once:
1253                once = False
1254                rv = f(x)
1255            return rv
1256
1257        # Fill the cache
1258        for x in range(15):
1259            self.assertEqual(f(x), f'.{x}.')
1260        self.assertEqual(f.cache_info().currsize, 10)
1261
1262        # Make a recursive call and make sure the cache remains full
1263        self.assertEqual(f(20), '.20.')
1264        self.assertEqual(f.cache_info().currsize, 10)
1265
1266    def test_lru_bug_36650(self):
1267        # C version of lru_cache was treating a call with an empty **kwargs
1268        # dictionary as being distinct from a call with no keywords at all.
1269        # This did not result in an incorrect answer, but it did trigger
1270        # an unexpected cache miss.
1271
1272        @self.module.lru_cache()
1273        def f(x):
1274            pass
1275
1276        f(0)
1277        f(0, **{})
1278        self.assertEqual(f.cache_info().hits, 1)
1279
    def test_lru_hash_only_once(self):
        """Each argument's __hash__ is called once per use in a cached call."""
        # To protect against weird reentrancy bugs and to improve
        # efficiency when faced with slow __hash__ methods, the
        # LRU cache guarantees that it will only call __hash__
        # only once per use as an argument to the cached function.

        @self.module.lru_cache(maxsize=1)
        def f(x, y):
            return x * 3 + y

        # Simulate the integer 5
        # Only __mul__ (for x * 3 -> 15, so f returns 15 + 1 == 16) and
        # __hash__ (for building the cache key) are needed.
        mock_int = unittest.mock.Mock()
        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
        mock_int.__hash__ = unittest.mock.Mock(return_value=999)

        # cache_info() tuples below are (hits, misses, maxsize, currsize).

        # Add to cache:  One use as an argument gives one call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 1)
        self.assertEqual(f.cache_info(), (0, 1, 1, 1))

        # Cache hit: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 1, 1, 1))

        # Cache eviction: No use as an argument gives no additional call
        self.assertEqual(f(6, 2), 20)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 2, 1, 1))

        # Cache miss: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 3)
        self.assertEqual(f.cache_info(), (1, 3, 1, 1))
1314
1315    def test_lru_reentrancy_with_len(self):
1316        # Test to make sure the LRU cache code isn't thrown-off by
1317        # caching the built-in len() function.  Since len() can be
1318        # cached, we shouldn't use it inside the lru code itself.
1319        old_len = builtins.len
1320        try:
1321            builtins.len = self.module.lru_cache(4)(len)
1322            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1323                self.assertEqual(len('abcdefghijklmn'[:i]), i)
1324        finally:
1325            builtins.len = old_len
1326
1327    def test_lru_star_arg_handling(self):
1328        # Test regression that arose in ea064ff3c10f
1329        @functools.lru_cache()
1330        def f(*args):
1331            return args
1332
1333        self.assertEqual(f(1, 2), (1, 2))
1334        self.assertEqual(f((1, 2)), ((1, 2),))
1335
1336    def test_lru_type_error(self):
1337        # Regression test for issue #28653.
1338        # lru_cache was leaking when one of the arguments
1339        # wasn't cacheable.
1340
1341        @functools.lru_cache(maxsize=None)
1342        def infinite_cache(o):
1343            pass
1344
1345        @functools.lru_cache(maxsize=10)
1346        def limited_cache(o):
1347            pass
1348
1349        with self.assertRaises(TypeError):
1350            infinite_cache([])
1351
1352        with self.assertRaises(TypeError):
1353            limited_cache([])
1354
1355    def test_lru_with_maxsize_none(self):
1356        @self.module.lru_cache(maxsize=None)
1357        def fib(n):
1358            if n < 2:
1359                return n
1360            return fib(n-1) + fib(n-2)
1361        self.assertEqual([fib(n) for n in range(16)],
1362            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1363        self.assertEqual(fib.cache_info(),
1364            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1365        fib.cache_clear()
1366        self.assertEqual(fib.cache_info(),
1367            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1368
1369    def test_lru_with_maxsize_negative(self):
1370        @self.module.lru_cache(maxsize=-10)
1371        def eq(n):
1372            return n
1373        for i in (0, 1):
1374            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1375        self.assertEqual(eq.cache_info(),
1376            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
1377
1378    def test_lru_with_exceptions(self):
1379        # Verify that user_function exceptions get passed through without
1380        # creating a hard-to-read chained exception.
1381        # http://bugs.python.org/issue13177
1382        for maxsize in (None, 128):
1383            @self.module.lru_cache(maxsize)
1384            def func(i):
1385                return 'abc'[i]
1386            self.assertEqual(func(0), 'a')
1387            with self.assertRaises(IndexError) as cm:
1388                func(15)
1389            self.assertIsNone(cm.exception.__context__)
1390            # Verify that the previous exception did not result in a cached entry
1391            with self.assertRaises(IndexError):
1392                func(15)
1393
1394    def test_lru_with_types(self):
1395        for maxsize in (None, 128):
1396            @self.module.lru_cache(maxsize=maxsize, typed=True)
1397            def square(x):
1398                return x * x
1399            self.assertEqual(square(3), 9)
1400            self.assertEqual(type(square(3)), type(9))
1401            self.assertEqual(square(3.0), 9.0)
1402            self.assertEqual(type(square(3.0)), type(9.0))
1403            self.assertEqual(square(x=3), 9)
1404            self.assertEqual(type(square(x=3)), type(9))
1405            self.assertEqual(square(x=3.0), 9.0)
1406            self.assertEqual(type(square(x=3.0)), type(9.0))
1407            self.assertEqual(square.cache_info().hits, 4)
1408            self.assertEqual(square.cache_info().misses, 4)
1409
1410    def test_lru_with_keyword_args(self):
1411        @self.module.lru_cache()
1412        def fib(n):
1413            if n < 2:
1414                return n
1415            return fib(n=n-1) + fib(n=n-2)
1416        self.assertEqual(
1417            [fib(n=number) for number in range(16)],
1418            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1419        )
1420        self.assertEqual(fib.cache_info(),
1421            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1422        fib.cache_clear()
1423        self.assertEqual(fib.cache_info(),
1424            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1425
1426    def test_lru_with_keyword_args_maxsize_none(self):
1427        @self.module.lru_cache(maxsize=None)
1428        def fib(n):
1429            if n < 2:
1430                return n
1431            return fib(n=n-1) + fib(n=n-2)
1432        self.assertEqual([fib(n=number) for number in range(16)],
1433            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1434        self.assertEqual(fib.cache_info(),
1435            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1436        fib.cache_clear()
1437        self.assertEqual(fib.cache_info(),
1438            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1439
1440    def test_kwargs_order(self):
1441        # PEP 468: Preserving Keyword Argument Order
1442        @self.module.lru_cache(maxsize=10)
1443        def f(**kwargs):
1444            return list(kwargs.items())
1445        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1446        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1447        self.assertEqual(f.cache_info(),
1448            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1449
1450    def test_lru_cache_decoration(self):
1451        def f(zomg: 'zomg_annotation'):
1452            """f doc string"""
1453            return 42
1454        g = self.module.lru_cache()(f)
1455        for attr in self.module.WRAPPER_ASSIGNMENTS:
1456            self.assertEqual(getattr(g, attr), getattr(f, attr))
1457
    def test_lru_cache_threaded(self):
        """Concurrent cache fills, then concurrent fills racing cache_clear()."""
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # Worker threads block on this event so that they start together.
        start = threading.Event()
        def full(k):
            start.wait(10)
            # Each thread uses its own key k, m times.
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # A tiny switch interval makes thread switches, and therefore races
        # inside the cache, far more frequent.  It is restored in finally.
        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: Why can be not equal?
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            # (no cache statistics are asserted after this phase)
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)
1505
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        n, m = 5, 7
        # Each barrier synchronises the n worker threads with the main thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            # Hold every worker inside f until all n workers (and the main
            # thread) arrive, so all n calls for the same x are in flight at
            # once and none of them can be served from the cache.
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            # The main thread drives m rounds and resets each barrier while
            # the workers are parked at a different one.
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # Round i adds n misses (one per thread) but only one cache
                # entry, because every thread used the same key i.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1532
1533    def test_lru_cache_threaded3(self):
1534        @self.module.lru_cache(maxsize=2)
1535        def f(x):
1536            time.sleep(.01)
1537            return 3 * x
1538        def test(i, x):
1539            with self.subTest(thread=i):
1540                self.assertEqual(f(x), 3 * x, i)
1541        threads = [threading.Thread(target=test, args=(i, v))
1542                   for i, v in enumerate([1, 2, 2, 3, 2])]
1543        with support.start_threads(threads):
1544            pass
1545
1546    def test_need_for_rlock(self):
1547        # This will deadlock on an LRU cache that uses a regular lock
1548
1549        @self.module.lru_cache(maxsize=10)
1550        def test_func(x):
1551            'Used to demonstrate a reentrant lru_cache call within a single thread'
1552            return x
1553
1554        class DoubleEq:
1555            'Demonstrate a reentrant lru_cache call within a single thread'
1556            def __init__(self, x):
1557                self.x = x
1558            def __hash__(self):
1559                return self.x
1560            def __eq__(self, other):
1561                if self.x == 2:
1562                    test_func(DoubleEq(1))
1563                return self.x == other.x
1564
1565        test_func(DoubleEq(1))                      # Load the cache
1566        test_func(DoubleEq(2))                      # Load the cache
1567        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
1568                         DoubleEq(2))               # Verify the correct return value
1569
1570    def test_early_detection_of_bad_call(self):
1571        # Issue #22184
1572        with self.assertRaises(TypeError):
1573            @functools.lru_cache
1574            def f():
1575                pass
1576
    def test_lru_method(self):
        """lru_cache applied to a method of an int subclass."""
        class X(int):
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                # Counts real (uncached) invocations per instance: the +=
                # stores an instance attribute shadowing the class-level 0.
                self.f_cnt += 1
                return x*10+self
        a = X(5)
        b = X(5)
        c = X(7)
        # One cache belongs to the decorated function X.f, so its statistics
        # accumulate across all instances.  Tuples are
        # (hits, misses, maxsize, currsize).
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        # Accessed through any instance, the method reports the same
        # shared statistics.
        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1607
1608    def test_pickle(self):
1609        cls = self.__class__
1610        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1611            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1612                with self.subTest(proto=proto, func=f):
1613                    f_copy = pickle.loads(pickle.dumps(f, proto))
1614                    self.assertIs(f_copy, f)
1615
1616    def test_copy(self):
1617        cls = self.__class__
1618        def orig(x, y):
1619            return 3 * x + y
1620        part = self.module.partial(orig, 2)
1621        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1622                 self.module.lru_cache(2)(part))
1623        for f in funcs:
1624            with self.subTest(func=f):
1625                f_copy = copy.copy(f)
1626                self.assertIs(f_copy, f)
1627
1628    def test_deepcopy(self):
1629        cls = self.__class__
1630        def orig(x, y):
1631            return 3 * x + y
1632        part = self.module.partial(orig, 2)
1633        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1634                 self.module.lru_cache(2)(part))
1635        for f in funcs:
1636            with self.subTest(func=f):
1637                f_copy = copy.deepcopy(f)
1638                self.assertIs(f_copy, f)
1639
1640
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Module-level fixture cached with the pure-Python lru_cache; the LRU
    # pickle/copy tests expect it back as the identical object, presumably
    # why it lives at module scope (so pickle can find it by name).
    scaled = 3 * x
    return scaled + y
1644
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Module-level fixture cached with the C-accelerated lru_cache; the LRU
    # pickle/copy tests expect it back as the identical object, presumably
    # why it lives at module scope (so pickle can find it by name).
    scaled = 3 * x
    return scaled + y
1648
1649
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU suite against functools with _functools
    # blocked, i.e. the pure-Python implementation.
    module = py_functools
    cached_func = (py_cached_func,)

    # Cached instance method used by the shared pickle/copy tests.
    @module.lru_cache()
    def cached_meth(self, x, y):
        scaled = 3 * x
        return scaled + y

    # Cached function exposed as a static method (lru_cache is applied
    # first, staticmethod wraps the result).
    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        scaled = 3 * x
        return scaled + y
1662
1663
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU suite against functools freshly imported with
    # its _functools C accelerator.
    module = c_functools
    cached_func = (c_cached_func,)

    # Cached instance method used by the shared pickle/copy tests.
    @module.lru_cache()
    def cached_meth(self, x, y):
        scaled = 3 * x
        return scaled + y

    # Cached function exposed as a static method (lru_cache is applied
    # first, staticmethod wraps the result).
    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        scaled = 3 * x
        return scaled + y
1676
1677
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and the private helpers it uses to
    order candidate implementations (_compose_mro, _c3_mro, _find_impl).
    """

    def test_simple_overloads(self):
        # An implementation registered for int is used only for ints; every
        # other argument type falls back to the base function.
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        # Dispatch follows the argument class's MRO: D(C, B) resolves to
        # D, C, B, A, object, so B's implementation beats A's.
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        # register(cls) works as a decorator and hands back the decorated
        # function itself, which dispatch() then reports for that class.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: in the assert above this is not g.
        # @singledispatch returns the wrapper.

    def test_wrapping_attributes(self):
        # The dispatcher carries the wrapped function's name and docstring;
        # docstrings are stripped under -OO, hence the optimize check.
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Exception classes implemented in C (_decimal) dispatch like any
        # other class: a registration for the more specific Subnormal
        # overrides the one for their common base DecimalException.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering.
        # Each loop therefore feeds every permutation (never a fixed list)
        # to _compose_mro so that independence is actually exercised.
        c = collections.abc
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
        for haystack in permutations(bases):
            m = mro(collections.ChainMap, haystack)
            self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(collections.defaultdict, haystack)
            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
                                 c.Container, object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(collections.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(collections.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        # ABC registrations are honoured for builtin containers: the most
        # specific registered ABC wins, and a registration for the concrete
        # class itself wins over any ABC.
        c = collections.abc
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(collections.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        # _c3_mro places each ABC (registered, or implied by __len__ /
        # __call__) right after the class that introduces it, and leaves
        # unrelated ABCs out of the result.
        c = collections.abc
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        # A metaclass that defines __len__ must not derail dispatch for
        # instances of the classes it creates.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        # Explicit bases in __mro__ take precedence over ABCs registered
        # later; two equally specific implicit ABCs are reported as an
        # ambiguous dispatch (RuntimeError).
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        # Swap in a tracing dict for the dispatch cache so the test can see
        # exactly when singledispatch reads, fills and clears it.
        from collections import UserDict
        import weakref

        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()

        td = TracingDict()
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            g._clear_cache()
            self.assertEqual(len(td), 0)

    def test_annotations(self):
        # Plain @register takes the dispatch type from the first parameter's
        # annotation; string annotations are resolved as well.
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register
        def _(arg: collections.abc.Mapping):
            return "mapping"
        @i.register
        def _(arg: "collections.abc.Sequence"):
            return "sequence"
        self.assertEqual(i(None), "base")
        self.assertEqual(i({"a": 1}), "mapping")
        self.assertEqual(i([1, 2, 3]), "sequence")
        self.assertEqual(i((1, 2, 3)), "sequence")
        self.assertEqual(i("str"), "sequence")

        # Registering classes as callables doesn't work with annotations,
        # you need to pass the type explicitly.
        @i.register(str)
        class _:
            def __init__(self, arg):
                self.arg = arg

            def __eq__(self, other):
                return self.arg == other
        self.assertEqual(i("str"), "str")

    def test_invalid_registrations(self):
        # register() rejects a non-type first argument and an un-annotated
        # function with a TypeError that explains both valid forms.
        msg_prefix = "Invalid first argument to `register()`: "
        msg_suffix = (
            ". Use either `@register(some_class)` or plain `@register` on an "
            "annotated function."
        )
        @functools.singledispatch
        def i(arg):
            return "base"
        with self.assertRaises(TypeError) as exc:
            @i.register(42)
            def _(arg):
                return "I annotated with a non-type"
        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
        self.assertTrue(str(exc.exception).endswith(msg_suffix))
        with self.assertRaises(TypeError) as exc:
            @i.register
            def _(arg):
                return "I forgot to annotate"
        self.assertTrue(str(exc.exception).startswith(msg_prefix +
            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
        ))
        self.assertTrue(str(exc.exception).endswith(msg_suffix))

        # FIXME: The following will only work after PEP 560 is implemented.
        # Until then, everything below this early return is skipped.
        return

        with self.assertRaises(TypeError) as exc:
            @i.register
            def _(arg: typing.Iterable[str]):
                # At runtime, dispatching on generics is impossible.
                # When registering implementations with singledispatch, avoid
                # types from `typing`. Instead, annotate with regular types
                # or ABCs.
                return "I annotated with a generic collection"
        self.assertTrue(str(exc.exception).startswith(msg_prefix +
            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
        ))
        self.assertTrue(str(exc.exception).endswith(msg_suffix))

    def test_invalid_positional_argument(self):
        # Calling a dispatcher without any positional argument leaves nothing
        # to dispatch on, so it must raise a descriptive TypeError.
        @functools.singledispatch
        def f(*args):
            pass
        msg = 'f requires at least 1 positional argument'
        with self.assertRaisesRegex(TypeError, msg):
            f()
2248
if __name__ == "__main__":
    # Allow running this module directly as a script.
    unittest.main()
2251