1import abc 2import builtins 3import collections 4import collections.abc 5import copy 6from itertools import permutations 7import pickle 8from random import choice 9import sys 10from test import support 11import threading 12import time 13import typing 14import unittest 15import unittest.mock 16import os 17import weakref 18import gc 19from weakref import proxy 20import contextlib 21 22from test.support.script_helper import assert_python_ok 23 24import functools 25 26py_functools = support.import_fresh_module('functools', blocked=['_functools']) 27c_functools = support.import_fresh_module('functools', fresh=['_functools']) 28 29decimal = support.import_fresh_module('decimal', fresh=['_decimal']) 30 31@contextlib.contextmanager 32def replaced_module(name, replacement): 33 original_module = sys.modules[name] 34 sys.modules[name] = replacement 35 try: 36 yield 37 finally: 38 sys.modules[name] = original_module 39 40def capture(*args, **kw): 41 """capture all positional and keyword arguments""" 42 return args, kw 43 44 45def signature(part): 46 """ return the signature of a partial object """ 47 return (part.func, part.args, part.keywords, part.__dict__) 48 49class MyTuple(tuple): 50 pass 51 52class BadTuple(tuple): 53 def __add__(self, other): 54 return list(self) + list(other) 55 56class MyDict(dict): 57 pass 58 59 60class TestPartial: 61 62 def test_basic_examples(self): 63 p = self.partial(capture, 1, 2, a=10, b=20) 64 self.assertTrue(callable(p)) 65 self.assertEqual(p(3, 4, b=30, c=40), 66 ((1, 2, 3, 4), dict(a=10, b=30, c=40))) 67 p = self.partial(map, lambda x: x*10) 68 self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40]) 69 70 def test_attributes(self): 71 p = self.partial(capture, 1, 2, a=10, b=20) 72 # attributes should be readable 73 self.assertEqual(p.func, capture) 74 self.assertEqual(p.args, (1, 2)) 75 self.assertEqual(p.keywords, dict(a=10, b=20)) 76 77 def test_argument_checking(self): 78 self.assertRaises(TypeError, self.partial) # need at least a func arg 79 try: 80 self.partial(2)() 81 except TypeError: 82 pass 83 else: 84 self.fail('First arg not checked for callability') 85 86 def test_protection_of_callers_dict_argument(self): 87 # a caller's dictionary should not be altered by partial 88 def func(a=10, b=20): 89 return a 90 d = {'a':3} 91 p = self.partial(func, a=5) 92 self.assertEqual(p(**d), 3) 93 self.assertEqual(d, {'a':3}) 94 p(b=7) 95 self.assertEqual(d, {'a':3}) 96 97 def test_kwargs_copy(self): 98 # Issue #29532: Altering a kwarg dictionary passed to a constructor 99 # should not affect a partial object after creation 100 d = {'a': 3} 101 p = self.partial(capture, **d) 102 self.assertEqual(p(), ((), {'a': 3})) 103 d['a'] = 5 104 self.assertEqual(p(), ((), {'a': 3})) 105 106 def test_arg_combinations(self): 107 # exercise special code paths for zero args in either partial 108 # object or the caller 109 p = self.partial(capture) 110 self.assertEqual(p(), ((), {})) 111 self.assertEqual(p(1,2), ((1,2), {})) 112 p = self.partial(capture, 1, 2) 113 self.assertEqual(p(), ((1,2), {})) 114 self.assertEqual(p(3,4), ((1,2,3,4), {})) 115 116 def test_kw_combinations(self): 117 # exercise special code paths for no keyword args in 118 # either the partial object or the caller 119 p = self.partial(capture) 120 self.assertEqual(p.keywords, {}) 121 self.assertEqual(p(), ((), {})) 122 self.assertEqual(p(a=1), ((), {'a':1})) 123 p = self.partial(capture, a=1) 124 self.assertEqual(p.keywords, {'a':1}) 125 self.assertEqual(p(), ((), {'a':1})) 126 self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) 127 # keyword args in the call override those in the partial object 128 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) 129 130 def test_positional(self): 131 # make sure positional arguments are captured correctly 132 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: 133 p = self.partial(capture, *args) 134 expected = args + ('x',) 135 got, empty = p('x') 136 self.assertTrue(expected == got and empty == {}) 137 138 def test_keyword(self): 139 # make sure keyword arguments are captured correctly 140 for a in ['a', 0, None, 3.5]: 141 p = self.partial(capture, a=a) 142 expected = {'a':a,'x':None} 143 empty, got = p(x=None) 144 self.assertTrue(expected == got and empty == ()) 145 146 def test_no_side_effects(self): 147 # make sure there are no side effects that affect subsequent calls 148 p = self.partial(capture, 0, a=1) 149 args1, kw1 = p(1, b=2) 150 self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) 151 args2, kw2 = p() 152 self.assertTrue(args2 == (0,) and kw2 == {'a':1}) 153 154 def test_error_propagation(self): 155 def f(x, y): 156 x / y 157 self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0)) 158 self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0) 159 self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0) 160 self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1) 161 162 def test_weakref(self): 163 f = self.partial(int, base=16) 164 p = proxy(f) 165 self.assertEqual(f.func, p.func) 166 f = None 167 support.gc_collect() # For PyPy or other GCs. 168 self.assertRaises(ReferenceError, getattr, p, 'func') 169 170 def test_with_bound_and_unbound_methods(self): 171 data = list(map(str, range(10))) 172 join = self.partial(str.join, '') 173 self.assertEqual(join(data), '0123456789') 174 join = self.partial(''.join) 175 self.assertEqual(join(data), '0123456789') 176 177 def test_nested_optimization(self): 178 partial = self.partial 179 inner = partial(signature, 'asdf') 180 nested = partial(inner, bar=True) 181 flat = partial(signature, 'asdf', bar=True) 182 self.assertEqual(signature(nested), signature(flat)) 183 184 def test_nested_partial_with_attribute(self): 185 # see issue 25137 186 partial = self.partial 187 188 def foo(bar): 189 return bar 190 191 p = partial(foo, 'first') 192 p2 = partial(p, 'second') 193 p2.new_attr = 'spam' 194 self.assertEqual(p2.new_attr, 'spam') 195 196 def test_repr(self): 197 args = (object(), object()) 198 args_repr = ', '.join(repr(a) for a in args) 199 kwargs = {'a': object(), 'b': object()} 200 kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs), 201 'b={b!r}, a={a!r}'.format_map(kwargs)] 202 if self.partial in (c_functools.partial, py_functools.partial): 203 name = 'functools.partial' 204 else: 205 name = self.partial.__name__ 206 207 f = self.partial(capture) 208 self.assertEqual(f'{name}({capture!r})', repr(f)) 209 210 f = self.partial(capture, *args) 211 self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f)) 212 213 f = self.partial(capture, **kwargs) 214 self.assertIn(repr(f), 215 [f'{name}({capture!r}, {kwargs_repr})' 216 for kwargs_repr in kwargs_reprs]) 217 218 f = self.partial(capture, *args, **kwargs) 219 self.assertIn(repr(f), 220 [f'{name}({capture!r}, {args_repr}, {kwargs_repr})' 221 for kwargs_repr in kwargs_reprs]) 222 223 def test_recursive_repr(self): 224 if self.partial in (c_functools.partial, py_functools.partial): 225 name = 'functools.partial' 226 else: 227 name = self.partial.__name__ 228 229 f = self.partial(capture) 230 f.__setstate__((f, (), {}, {})) 231 try: 232 self.assertEqual(repr(f), '%s(...)' % (name,)) 233 finally: 234 f.__setstate__((capture, (), {}, {})) 235 236 f = self.partial(capture) 237 f.__setstate__((capture, (f,), {}, {})) 238 try: 239 self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,)) 240 finally: 241 f.__setstate__((capture, (), {}, {})) 242 243 f = self.partial(capture) 244 f.__setstate__((capture, (), {'a': f}, {})) 245 try: 246 self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,)) 247 finally: 248 f.__setstate__((capture, (), {}, {})) 249 250 def test_pickle(self): 251 with self.AllowPickle(): 252 f = self.partial(signature, ['asdf'], bar=[True]) 253 f.attr = [] 254 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 255 f_copy = pickle.loads(pickle.dumps(f, proto)) 256 self.assertEqual(signature(f_copy), signature(f)) 257 258 def test_copy(self): 259 f = self.partial(signature, ['asdf'], bar=[True]) 260 f.attr = [] 261 f_copy = copy.copy(f) 262 self.assertEqual(signature(f_copy), signature(f)) 263 self.assertIs(f_copy.attr, f.attr) 264 self.assertIs(f_copy.args, f.args) 265 self.assertIs(f_copy.keywords, f.keywords) 266 267 def test_deepcopy(self): 268 f = self.partial(signature, ['asdf'], bar=[True]) 269 f.attr = [] 270 f_copy = copy.deepcopy(f) 271 self.assertEqual(signature(f_copy), signature(f)) 272 self.assertIsNot(f_copy.attr, f.attr) 273 self.assertIsNot(f_copy.args, f.args) 274 self.assertIsNot(f_copy.args[0], f.args[0]) 275 self.assertIsNot(f_copy.keywords, f.keywords) 276 self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar']) 277 278 def test_setstate(self): 279 f = self.partial(signature) 280 f.__setstate__((capture, (1,), dict(a=10), dict(attr=[]))) 281 282 self.assertEqual(signature(f), 283 (capture, (1,), dict(a=10), dict(attr=[]))) 284 self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) 285 286 f.__setstate__((capture, (1,), dict(a=10), None)) 287 288 self.assertEqual(signature(f), (capture, (1,), dict(a=10), {})) 289 self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) 290 291 f.__setstate__((capture, (1,), None, None)) 292 #self.assertEqual(signature(f), (capture, (1,), {}, {})) 293 self.assertEqual(f(2, b=20), ((1, 2), {'b': 20})) 294 self.assertEqual(f(2), ((1, 2), {})) 295 self.assertEqual(f(), ((1,), {})) 296 297 f.__setstate__((capture, (), {}, None)) 298 self.assertEqual(signature(f), (capture, (), {}, {})) 299 self.assertEqual(f(2, b=20), ((2,), {'b': 20})) 300 self.assertEqual(f(2), ((2,), {})) 301 self.assertEqual(f(), ((), {})) 302 303 def test_setstate_errors(self): 304 f = self.partial(signature) 305 self.assertRaises(TypeError, f.__setstate__, (capture, (), {})) 306 self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None)) 307 self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None]) 308 self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None)) 309 self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None)) 310 self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None)) 311 self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None)) 312 313 def test_setstate_subclasses(self): 314 f = self.partial(signature) 315 f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None)) 316 s = signature(f) 317 self.assertEqual(s, (capture, (1,), dict(a=10), {})) 318 self.assertIs(type(s[1]), tuple) 319 self.assertIs(type(s[2]), dict) 320 r = f() 321 self.assertEqual(r, ((1,), {'a': 10})) 322 self.assertIs(type(r[0]), tuple) 323 self.assertIs(type(r[1]), dict) 324 325 f.__setstate__((capture, BadTuple((1,)), {}, None)) 326 s = signature(f) 327 self.assertEqual(s, (capture, (1,), {}, {})) 328 self.assertIs(type(s[1]), tuple) 329 r = f(2) 330 self.assertEqual(r, ((1, 2), {})) 331 self.assertIs(type(r[0]), tuple) 332 333 def test_recursive_pickle(self): 334 with self.AllowPickle(): 335 f = self.partial(capture) 336 f.__setstate__((f, (), {}, {})) 337 try: 338 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 339 with self.assertRaises(RecursionError): 340 pickle.dumps(f, proto) 341 finally: 342 f.__setstate__((capture, (), {}, {})) 343 344 f = self.partial(capture) 345 f.__setstate__((capture, (f,), {}, {})) 346 try: 347 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 348 f_copy = pickle.loads(pickle.dumps(f, proto)) 349 try: 350 self.assertIs(f_copy.args[0], f_copy) 351 finally: 352 f_copy.__setstate__((capture, (), {}, {})) 353 finally: 354 f.__setstate__((capture, (), {}, {})) 355 356 f = self.partial(capture) 357 f.__setstate__((capture, (), {'a': f}, {})) 358 try: 359 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 360 f_copy = pickle.loads(pickle.dumps(f, proto)) 361 try: 362 self.assertIs(f_copy.keywords['a'], f_copy) 363 finally: 364 f_copy.__setstate__((capture, (), {}, {})) 365 finally: 366 f.__setstate__((capture, (), {}, {})) 367 368 # Issue 6083: Reference counting bug 369 def test_setstate_refcount(self): 370 class BadSequence: 371 def __len__(self): 372 return 4 373 def __getitem__(self, key): 374 if key == 0: 375 return max 376 elif key == 1: 377 return tuple(range(1000000)) 378 elif key in (2, 3): 379 return {} 380 raise IndexError 381 382 f = self.partial(object) 383 self.assertRaises(TypeError, f.__setstate__, BadSequence()) 384 385@unittest.skipUnless(c_functools, 'requires the C _functools module') 386class TestPartialC(TestPartial, unittest.TestCase): 387 if c_functools: 388 partial = c_functools.partial 389 390 class AllowPickle: 391 def __enter__(self): 392 return self 393 def __exit__(self, type, value, tb): 394 return False 395 396 def test_attributes_unwritable(self): 397 # attributes should not be writable 398 p = self.partial(capture, 1, 2, a=10, b=20) 399 self.assertRaises(AttributeError, setattr, p, 'func', map) 400 self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) 401 self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) 402 403 p = self.partial(hex) 404 try: 405 del p.__dict__ 406 except TypeError: 407 pass 408 else: 409 self.fail('partial object allowed __dict__ to be deleted') 410 411 def test_manually_adding_non_string_keyword(self): 412 p = self.partial(capture) 413 # Adding a non-string/unicode keyword to partial kwargs 414 p.keywords[1234] = 'value' 415 r = repr(p) 416 self.assertIn('1234', r) 417 self.assertIn("'value'", r) 418 with self.assertRaises(TypeError): 419 p() 420 421 def test_keystr_replaces_value(self): 422 p = self.partial(capture) 423 424 class MutatesYourDict(object): 425 def __str__(self): 426 p.keywords[self] = ['sth2'] 427 return 'astr' 428 429 # Replacing the value during key formatting should keep the original 430 # value alive (at least long enough). 431 p.keywords[MutatesYourDict()] = ['sth'] 432 r = repr(p) 433 self.assertIn('astr', r) 434 self.assertIn("['sth']", r) 435 436 437class TestPartialPy(TestPartial, unittest.TestCase): 438 partial = py_functools.partial 439 440 class AllowPickle: 441 def __init__(self): 442 self._cm = replaced_module("functools", py_functools) 443 def __enter__(self): 444 return self._cm.__enter__() 445 def __exit__(self, type, value, tb): 446 return self._cm.__exit__(type, value, tb) 447 448if c_functools: 449 class CPartialSubclass(c_functools.partial): 450 pass 451 452class PyPartialSubclass(py_functools.partial): 453 pass 454 455@unittest.skipUnless(c_functools, 'requires the C _functools module') 456class TestPartialCSubclass(TestPartialC): 457 if c_functools: 458 partial = CPartialSubclass 459 460 # partial subclasses are not optimized for nested calls 461 test_nested_optimization = None 462 463class TestPartialPySubclass(TestPartialPy): 464 partial = PyPartialSubclass 465 466class TestPartialMethod(unittest.TestCase): 467 468 class A(object): 469 nothing = functools.partialmethod(capture) 470 positional = functools.partialmethod(capture, 1) 471 keywords = functools.partialmethod(capture, a=2) 472 both = functools.partialmethod(capture, 3, b=4) 473 spec_keywords = functools.partialmethod(capture, self=1, func=2) 474 475 nested = functools.partialmethod(positional, 5) 476 477 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7) 478 479 static = functools.partialmethod(staticmethod(capture), 8) 480 cls = functools.partialmethod(classmethod(capture), d=9) 481 482 a = A() 483 484 def test_arg_combinations(self): 485 self.assertEqual(self.a.nothing(), ((self.a,), {})) 486 self.assertEqual(self.a.nothing(5), ((self.a, 5), {})) 487 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6})) 488 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6})) 489 490 self.assertEqual(self.a.positional(), ((self.a, 1), {})) 491 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {})) 492 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6})) 493 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6})) 494 495 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2})) 496 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2})) 497 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6})) 498 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6})) 499 500 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4})) 501 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4})) 502 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6})) 503 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) 504 505 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) 506 507 self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2})) 508 509 def test_nested(self): 510 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {})) 511 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {})) 512 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7})) 513 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7})) 514 515 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7})) 516 517 def test_over_partial(self): 518 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6})) 519 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6})) 520 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8})) 521 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) 522 523 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) 524 525 def test_bound_method_introspection(self): 526 obj = self.a 527 self.assertIs(obj.both.__self__, obj) 528 self.assertIs(obj.nested.__self__, obj) 529 self.assertIs(obj.over_partial.__self__, obj) 530 self.assertIs(obj.cls.__self__, self.A) 531 self.assertIs(self.A.cls.__self__, self.A) 532 533 def test_unbound_method_retrieval(self): 534 obj = self.A 535 self.assertFalse(hasattr(obj.both, "__self__")) 536 self.assertFalse(hasattr(obj.nested, "__self__")) 537 self.assertFalse(hasattr(obj.over_partial, "__self__")) 538 self.assertFalse(hasattr(obj.static, "__self__")) 539 self.assertFalse(hasattr(self.a.static, "__self__")) 540 541 def test_descriptors(self): 542 for obj in [self.A, self.a]: 543 with self.subTest(obj=obj): 544 self.assertEqual(obj.static(), ((8,), {})) 545 self.assertEqual(obj.static(5), ((8, 5), {})) 546 self.assertEqual(obj.static(d=8), ((8,), {'d': 8})) 547 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8})) 548 549 self.assertEqual(obj.cls(), ((self.A,), {'d': 9})) 550 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9})) 551 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9})) 552 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9})) 553 554 def test_overriding_keywords(self): 555 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3})) 556 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3})) 557 558 def test_invalid_args(self): 559 with self.assertRaises(TypeError): 560 class B(object): 561 method = functools.partialmethod(None, 1) 562 with self.assertRaises(TypeError): 563 class B: 564 method = functools.partialmethod() 565 with self.assertRaises(TypeError): 566 class B: 567 method = functools.partialmethod(func=capture, a=1) 568 569 def test_repr(self): 570 self.assertEqual(repr(vars(self.A)['both']), 571 'functools.partialmethod({}, 3, b=4)'.format(capture)) 572 573 def test_abstract(self): 574 class Abstract(abc.ABCMeta): 575 576 @abc.abstractmethod 577 def add(self, x, y): 578 pass 579 580 add5 = functools.partialmethod(add, 5) 581 582 self.assertTrue(Abstract.add.__isabstractmethod__) 583 self.assertTrue(Abstract.add5.__isabstractmethod__) 584 585 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]: 586 self.assertFalse(getattr(func, '__isabstractmethod__', False)) 587 588 def test_positional_only(self): 589 def f(a, b, /): 590 return a + b 591 592 p = functools.partial(f, 1) 593 self.assertEqual(p(2), f(1, 2)) 594 595 596class TestUpdateWrapper(unittest.TestCase): 597 598 def check_wrapper(self, wrapper, wrapped, 599 assigned=functools.WRAPPER_ASSIGNMENTS, 600 updated=functools.WRAPPER_UPDATES): 601 # Check attributes were assigned 602 for name in assigned: 603 self.assertIs(getattr(wrapper, name), getattr(wrapped, name)) 604 # Check attributes were updated 605 for name in updated: 606 wrapper_attr = getattr(wrapper, name) 607 wrapped_attr = getattr(wrapped, name) 608 for key in wrapped_attr: 609 if name == "__dict__" and key == "__wrapped__": 610 # __wrapped__ is overwritten by the update code 611 continue 612 self.assertIs(wrapped_attr[key], wrapper_attr[key]) 613 # Check __wrapped__ 614 self.assertIs(wrapper.__wrapped__, wrapped) 615 616 617 def _default_update(self): 618 def f(a:'This is a new annotation'): 619 """This is a test""" 620 pass 621 f.attr = 'This is also a test' 622 f.__wrapped__ = "This is a bald faced lie" 623 def wrapper(b:'This is the prior annotation'): 624 pass 625 functools.update_wrapper(wrapper, f) 626 return wrapper, f 627 628 def test_default_update(self): 629 wrapper, f = self._default_update() 630 self.check_wrapper(wrapper, f) 631 self.assertIs(wrapper.__wrapped__, f) 632 self.assertEqual(wrapper.__name__, 'f') 633 self.assertEqual(wrapper.__qualname__, f.__qualname__) 634 self.assertEqual(wrapper.attr, 'This is also a test') 635 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') 636 self.assertNotIn('b', wrapper.__annotations__) 637 638 @unittest.skipIf(sys.flags.optimize >= 2, 639 "Docstrings are omitted with -O2 and above") 640 def test_default_update_doc(self): 641 wrapper, f = self._default_update() 642 self.assertEqual(wrapper.__doc__, 'This is a test') 643 644 def test_no_update(self): 645 def f(): 646 """This is a test""" 647 pass 648 f.attr = 'This is also a test' 649 def wrapper(): 650 pass 651 functools.update_wrapper(wrapper, f, (), ()) 652 self.check_wrapper(wrapper, f, (), ()) 653 self.assertEqual(wrapper.__name__, 'wrapper') 654 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 655 self.assertEqual(wrapper.__doc__, None) 656 self.assertEqual(wrapper.__annotations__, {}) 657 self.assertFalse(hasattr(wrapper, 'attr')) 658 659 def test_selective_update(self): 660 def f(): 661 pass 662 f.attr = 'This is a different test' 663 f.dict_attr = dict(a=1, b=2, c=3) 664 def wrapper(): 665 pass 666 wrapper.dict_attr = {} 667 assign = ('attr',) 668 update = ('dict_attr',) 669 functools.update_wrapper(wrapper, f, assign, update) 670 self.check_wrapper(wrapper, f, assign, update) 671 self.assertEqual(wrapper.__name__, 'wrapper') 672 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 673 self.assertEqual(wrapper.__doc__, None) 674 self.assertEqual(wrapper.attr, 'This is a different test') 675 self.assertEqual(wrapper.dict_attr, f.dict_attr) 676 677 def test_missing_attributes(self): 678 def f(): 679 pass 680 def wrapper(): 681 pass 682 wrapper.dict_attr = {} 683 assign = ('attr',) 684 update = ('dict_attr',) 685 # Missing attributes on wrapped object are ignored 686 functools.update_wrapper(wrapper, f, assign, update) 687 self.assertNotIn('attr', wrapper.__dict__) 688 self.assertEqual(wrapper.dict_attr, {}) 689 # Wrapper must have expected attributes for updating 690 del wrapper.dict_attr 691 with self.assertRaises(AttributeError): 692 functools.update_wrapper(wrapper, f, assign, update) 693 wrapper.dict_attr = 1 694 with self.assertRaises(AttributeError): 695 functools.update_wrapper(wrapper, f, assign, update) 696 697 @support.requires_docstrings 698 @unittest.skipIf(sys.flags.optimize >= 2, 699 "Docstrings are omitted with -O2 and above") 700 def test_builtin_update(self): 701 # Test for bug #1576241 702 def wrapper(): 703 pass 704 functools.update_wrapper(wrapper, max) 705 self.assertEqual(wrapper.__name__, 'max') 706 self.assertTrue(wrapper.__doc__.startswith('max(')) 707 self.assertEqual(wrapper.__annotations__, {}) 708 709 710class TestWraps(TestUpdateWrapper): 711 712 def _default_update(self): 713 def f(): 714 """This is a test""" 715 pass 716 f.attr = 'This is also a test' 717 f.__wrapped__ = "This is still a bald faced lie" 718 @functools.wraps(f) 719 def wrapper(): 720 pass 721 return wrapper, f 722 723 def test_default_update(self): 724 wrapper, f = self._default_update() 725 self.check_wrapper(wrapper, f) 726 self.assertEqual(wrapper.__name__, 'f') 727 self.assertEqual(wrapper.__qualname__, f.__qualname__) 728 self.assertEqual(wrapper.attr, 'This is also a test') 729 730 @unittest.skipIf(sys.flags.optimize >= 2, 731 "Docstrings are omitted with -O2 and above") 732 def test_default_update_doc(self): 733 wrapper, _ = self._default_update() 734 self.assertEqual(wrapper.__doc__, 'This is a test') 735 736 def test_no_update(self): 737 def f(): 738 """This is a test""" 739 pass 740 f.attr = 'This is also a test' 741 @functools.wraps(f, (), ()) 742 def wrapper(): 743 pass 744 self.check_wrapper(wrapper, f, (), ()) 745 self.assertEqual(wrapper.__name__, 'wrapper') 746 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 747 self.assertEqual(wrapper.__doc__, None) 748 self.assertFalse(hasattr(wrapper, 'attr')) 749 750 def test_selective_update(self): 751 def f(): 752 pass 753 f.attr = 'This is a different test' 754 f.dict_attr = dict(a=1, b=2, c=3) 755 def add_dict_attr(f): 756 f.dict_attr = {} 757 return f 758 assign = ('attr',) 759 update = ('dict_attr',) 760 @functools.wraps(f, assign, update) 761 @add_dict_attr 762 def wrapper(): 763 pass 764 self.check_wrapper(wrapper, f, assign, update) 765 self.assertEqual(wrapper.__name__, 'wrapper') 766 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 767 self.assertEqual(wrapper.__doc__, None) 768 self.assertEqual(wrapper.attr, 'This is a different test') 769 self.assertEqual(wrapper.dict_attr, f.dict_attr) 770 771 772class TestReduce: 773 def test_reduce(self): 774 class Squares: 775 def __init__(self, max): 776 self.max = max 777 self.sofar = [] 778 779 def __len__(self): 780 return len(self.sofar) 781 782 def __getitem__(self, i): 783 if not 0 <= i < self.max: raise IndexError 784 n = len(self.sofar) 785 while n <= i: 786 self.sofar.append(n*n) 787 n += 1 788 return self.sofar[i] 789 def add(x, y): 790 return x + y 791 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc') 792 self.assertEqual( 793 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []), 794 ['a','c','d','w'] 795 ) 796 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040) 797 self.assertEqual( 798 self.reduce(lambda x, y: x*y, range(2,21), 1), 799 2432902008176640000 800 ) 801 self.assertEqual(self.reduce(add, Squares(10)), 285) 802 self.assertEqual(self.reduce(add, Squares(10), 0), 285) 803 self.assertEqual(self.reduce(add, Squares(0), 0), 0) 804 self.assertRaises(TypeError, self.reduce) 805 self.assertRaises(TypeError, self.reduce, 42, 42) 806 self.assertRaises(TypeError, self.reduce, 42, 42, 42) 807 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item 808 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item 809 self.assertRaises(TypeError, self.reduce, 42, (42, 42)) 810 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value 811 self.assertRaises(TypeError, self.reduce, add, "") 812 self.assertRaises(TypeError, self.reduce, add, ()) 813 self.assertRaises(TypeError, self.reduce, add, object()) 814 815 class TestFailingIter: 816 def __iter__(self): 817 raise RuntimeError 818 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter()) 819 820 self.assertEqual(self.reduce(add, [], None), None) 821 self.assertEqual(self.reduce(add, [], 42), 42) 822 823 class BadSeq: 824 def __getitem__(self, index): 825 raise ValueError 826 self.assertRaises(ValueError, self.reduce, 42, BadSeq()) 827 828 # Test reduce()'s use of iterators. 829 def test_iterator_usage(self): 830 class SequenceClass: 831 def __init__(self, n): 832 self.n = n 833 def __getitem__(self, i): 834 if 0 <= i < self.n: 835 return i 836 else: 837 raise IndexError 838 839 from operator import add 840 self.assertEqual(self.reduce(add, SequenceClass(5)), 10) 841 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52) 842 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0)) 843 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42) 844 self.assertEqual(self.reduce(add, SequenceClass(1)), 0) 845 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42) 846 847 d = {"one": 1, "two": 2, "three": 3} 848 self.assertEqual(self.reduce(add, d), "".join(d.keys())) 849 850 851@unittest.skipUnless(c_functools, 'requires the C _functools module') 852class TestReduceC(TestReduce, unittest.TestCase): 853 if c_functools: 854 reduce = c_functools.reduce 855 856 857class TestReducePy(TestReduce, unittest.TestCase): 858 reduce = staticmethod(py_functools.reduce) 859 860 861class TestCmpToKey: 862 863 def test_cmp_to_key(self): 864 def cmp1(x, y): 865 return (x > y) - (x < y) 866 key = self.cmp_to_key(cmp1) 867 self.assertEqual(key(3), key(3)) 868 self.assertGreater(key(3), key(1)) 869 self.assertGreaterEqual(key(3), key(3)) 870 871 def cmp2(x, y): 872 return int(x) - int(y) 873 key = self.cmp_to_key(cmp2) 874 self.assertEqual(key(4.0), key('4')) 875 self.assertLess(key(2), key('35')) 876 self.assertLessEqual(key(2), key('35')) 877 self.assertNotEqual(key(2), key('35')) 878 879 def test_cmp_to_key_arguments(self): 880 def cmp1(x, y): 881 return (x > y) - (x < y) 882 key = self.cmp_to_key(mycmp=cmp1) 883 self.assertEqual(key(obj=3), key(obj=3)) 884 self.assertGreater(key(obj=3), key(obj=1)) 885 with self.assertRaises((TypeError, AttributeError)): 886 key(3) > 1 # rhs is not a K object 887 with self.assertRaises((TypeError, AttributeError)): 888 1 < key(3) # lhs is not a K object 889 with self.assertRaises(TypeError): 890 key = self.cmp_to_key() # too few args 891 with self.assertRaises(TypeError): 892 key = self.cmp_to_key(cmp1, None) # too many args 893 key = self.cmp_to_key(cmp1) 894 with self.assertRaises(TypeError): 895 key() # too few args 896 with self.assertRaises(TypeError): 897 key(None, None) # too many args 898 899 def test_bad_cmp(self): 900 def cmp1(x, y): 901 raise ZeroDivisionError 902 key = self.cmp_to_key(cmp1) 903 with self.assertRaises(ZeroDivisionError): 904 key(3) > key(1) 905 906 class BadCmp: 907 def __lt__(self, other): 908 raise ZeroDivisionError 909 def cmp1(x, y): 910 return BadCmp() 911 with self.assertRaises(ZeroDivisionError): 912 key(3) > key(1) 913 914 def test_obj_field(self): 915 def cmp1(x, y): 916 return (x > y) - (x < y) 917 key = self.cmp_to_key(mycmp=cmp1) 918 self.assertEqual(key(50).obj, 50) 919 920 def test_sort_int(self): 921 def mycmp(x, y): 922 return y - x 923 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)), 924 [4, 3, 2, 1, 0]) 925 926 def test_sort_int_str(self): 927 def mycmp(x, y): 928 x, y = int(x), int(y) 929 return (x > y) - (x < y) 930 values = [5, '3', 7, 2, '0', '1', 4, '10', 1] 931 values = sorted(values, key=self.cmp_to_key(mycmp)) 932 self.assertEqual([int(value) for value in values], 933 [0, 1, 1, 2, 3, 4, 5, 7, 10]) 934 935 def test_hash(self): 936 def mycmp(x, y): 937 return y - x 938 key = self.cmp_to_key(mycmp) 939 k = key(10) 940 self.assertRaises(TypeError, hash, k) 941 self.assertNotIsInstance(k, collections.abc.Hashable) 942 943 944@unittest.skipUnless(c_functools, 'requires the C _functools module') 945class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): 946 if c_functools: 947 cmp_to_key = c_functools.cmp_to_key 948 949 950class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase): 951 cmp_to_key = staticmethod(py_functools.cmp_to_key) 952 953 954class TestTotalOrdering(unittest.TestCase): 955 956 def test_total_ordering_lt(self): 957 @functools.total_ordering 958 class A: 959 def __init__(self, value): 960 self.value = value 961 def __lt__(self, other): 962 return self.value < other.value 963 def __eq__(self, other): 964 return self.value == other.value 965 self.assertTrue(A(1) < A(2)) 966 self.assertTrue(A(2) > A(1)) 967 self.assertTrue(A(1) <= A(2)) 968 self.assertTrue(A(2) >= A(1)) 969 self.assertTrue(A(2) <= A(2)) 970 self.assertTrue(A(2) >= A(2)) 971 self.assertFalse(A(1) > A(2)) 972 973 def test_total_ordering_le(self): 974 @functools.total_ordering 975 class A: 976 def __init__(self, value): 977 self.value = value 978 def __le__(self, other): 979 return self.value <= other.value 980 def __eq__(self, other): 981 return self.value == other.value 982 self.assertTrue(A(1) < A(2)) 983 self.assertTrue(A(2) > A(1)) 984 self.assertTrue(A(1) <= A(2)) 985 self.assertTrue(A(2) >= A(1)) 986 self.assertTrue(A(2) <= A(2)) 987 self.assertTrue(A(2) >= A(2)) 988 self.assertFalse(A(1) >= A(2)) 989 990 def test_total_ordering_gt(self): 991 @functools.total_ordering 992 class A: 993 def __init__(self, value): 994 self.value = value 995 def __gt__(self, other): 996 return self.value > other.value 997 def __eq__(self, other): 998 return self.value == other.value 999 self.assertTrue(A(1) < A(2)) 1000 self.assertTrue(A(2) > A(1)) 1001 self.assertTrue(A(1) <= A(2)) 1002 self.assertTrue(A(2) >= A(1)) 1003 self.assertTrue(A(2) <= A(2)) 1004 self.assertTrue(A(2) >= A(2)) 1005 self.assertFalse(A(2) < A(1)) 1006 1007 def test_total_ordering_ge(self): 1008 @functools.total_ordering 1009 class A: 1010 def __init__(self, value): 1011 self.value = value 1012 def __ge__(self, other): 1013 return self.value >= other.value 1014 def __eq__(self, other): 1015 return self.value == other.value 1016 self.assertTrue(A(1) < A(2)) 1017 self.assertTrue(A(2) > A(1)) 1018 self.assertTrue(A(1) <= A(2)) 1019 self.assertTrue(A(2) >= A(1)) 1020 self.assertTrue(A(2) <= A(2)) 1021 self.assertTrue(A(2) >= A(2)) 1022 self.assertFalse(A(2) <= A(1)) 1023 1024 def test_total_ordering_no_overwrite(self): 1025 # new methods should not overwrite existing 1026 @functools.total_ordering 1027 class A(int): 1028 pass 1029 self.assertTrue(A(1) < A(2)) 1030 self.assertTrue(A(2) > A(1)) 1031 self.assertTrue(A(1) <= A(2)) 1032 self.assertTrue(A(2) >= A(1)) 1033 self.assertTrue(A(2) <= A(2)) 1034 self.assertTrue(A(2) >= A(2)) 1035 1036 def test_no_operations_defined(self): 1037 with self.assertRaises(ValueError): 1038 @functools.total_ordering 1039 class A: 1040 pass 1041 1042 def test_type_error_when_not_implemented(self): 1043 # bug 10042; ensure stack overflow does not occur 1044 # when decorated types return NotImplemented 1045 @functools.total_ordering 1046 class ImplementsLessThan: 1047 def __init__(self, value): 1048 self.value = value 1049 def __eq__(self, other): 1050 if isinstance(other, ImplementsLessThan): 1051 return self.value == other.value 1052 return False 1053 def __lt__(self, other): 1054 if isinstance(other, ImplementsLessThan): 1055 return self.value < other.value 1056 return NotImplemented 1057 1058 @functools.total_ordering 1059 class ImplementsGreaterThan: 1060 def __init__(self, value): 1061 self.value = value 1062 def __eq__(self, other): 1063 if isinstance(other, ImplementsGreaterThan): 1064 return self.value == other.value 1065 return False 1066 def __gt__(self, other): 1067 if isinstance(other, ImplementsGreaterThan): 1068 return self.value > other.value 1069 return NotImplemented 1070 1071 @functools.total_ordering 1072 class ImplementsLessThanEqualTo: 1073 def __init__(self, value): 1074 self.value = value 1075 def __eq__(self, other): 1076 if isinstance(other, ImplementsLessThanEqualTo): 1077 return self.value == other.value 1078 return False 1079 def __le__(self, other): 1080 if isinstance(other, ImplementsLessThanEqualTo): 1081 return self.value <= other.value 1082 return NotImplemented 1083 1084 @functools.total_ordering 1085 class ImplementsGreaterThanEqualTo: 1086 def __init__(self, value): 1087 self.value = value 1088 def __eq__(self, other): 1089 if isinstance(other, ImplementsGreaterThanEqualTo): 1090 return self.value == other.value 1091 return False 1092 def __ge__(self, other): 1093 if isinstance(other, ImplementsGreaterThanEqualTo): 1094 return self.value >= other.value 1095 return NotImplemented 1096 1097 @functools.total_ordering 1098 class ComparatorNotImplemented: 1099 def __init__(self, value): 1100 self.value = value 1101 def __eq__(self, other): 1102 if isinstance(other, ComparatorNotImplemented): 1103 return self.value == other.value 1104 return False 1105 def __lt__(self, other): 1106 return NotImplemented 1107 1108 with self.subTest("LT < 1"), self.assertRaises(TypeError): 1109 ImplementsLessThan(-1) < 1 1110 1111 with self.subTest("LT < LE"), self.assertRaises(TypeError): 1112 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) 1113 1114 with self.subTest("LT < GT"), self.assertRaises(TypeError): 1115 ImplementsLessThan(1) < ImplementsGreaterThan(1) 1116 1117 with self.subTest("LE <= LT"), self.assertRaises(TypeError): 1118 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) 1119 1120 with self.subTest("LE <= GE"), self.assertRaises(TypeError): 1121 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) 1122 1123 with self.subTest("GT > GE"), self.assertRaises(TypeError): 1124 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) 1125 1126 with self.subTest("GT > LT"), self.assertRaises(TypeError): 1127 ImplementsGreaterThan(5) > ImplementsLessThan(5) 1128 1129 with self.subTest("GE >= GT"), self.assertRaises(TypeError): 1130 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) 1131 1132 with self.subTest("GE >= LE"), self.assertRaises(TypeError): 1133 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) 1134 1135 with self.subTest("GE when equal"): 1136 a = ComparatorNotImplemented(8) 1137 b = ComparatorNotImplemented(8) 1138 self.assertEqual(a, b) 1139 with self.assertRaises(TypeError): 1140 a >= b 1141 1142 with self.subTest("LE when equal"): 1143 a = ComparatorNotImplemented(9) 1144 b = ComparatorNotImplemented(9) 1145 self.assertEqual(a, b) 1146 with self.assertRaises(TypeError): 1147 a <= b 1148 1149 def test_pickle(self): 1150 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1151 for name in '__lt__', '__gt__', '__le__', '__ge__': 1152 with self.subTest(method=name, proto=proto): 1153 method = getattr(Orderable_LT, name) 1154 method_copy = pickle.loads(pickle.dumps(method, proto)) 1155 self.assertIs(method_copy, method) 1156 1157 1158 def test_total_ordering_for_metaclasses_issue_44605(self): 1159 1160 @functools.total_ordering 1161 class SortableMeta(type): 1162 def __new__(cls, name, bases, ns): 1163 return super().__new__(cls, name, bases, ns) 1164 1165 def __lt__(self, other): 1166 if not isinstance(other, SortableMeta): 1167 pass 1168 return self.__name__ < other.__name__ 1169 1170 def __eq__(self, other): 1171 if not isinstance(other, SortableMeta): 1172 pass 1173 return self.__name__ == other.__name__ 1174 1175 class B(metaclass=SortableMeta): 1176 pass 1177 1178 class A(metaclass=SortableMeta): 1179 pass 1180 1181 self.assertTrue(A < B) 1182 self.assertFalse(A > B) 1183 1184 1185@functools.total_ordering 1186class Orderable_LT: 1187 def __init__(self, value): 1188 self.value = value 1189 def __lt__(self, other): 1190 return self.value < other.value 1191 def __eq__(self, other): 1192 return self.value == other.value 1193 1194 1195class TestCache: 1196 # This tests that the pass-through is working as designed. 1197 # The underlying functionality is tested in TestLRU. 1198 1199 def test_cache(self): 1200 @self.module.cache 1201 def fib(n): 1202 if n < 2: 1203 return n 1204 return fib(n-1) + fib(n-2) 1205 self.assertEqual([fib(n) for n in range(16)], 1206 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1207 self.assertEqual(fib.cache_info(), 1208 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1209 fib.cache_clear() 1210 self.assertEqual(fib.cache_info(), 1211 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1212 1213 1214class TestLRU: 1215 1216 def test_lru(self): 1217 def orig(x, y): 1218 return 3 * x + y 1219 f = self.module.lru_cache(maxsize=20)(orig) 1220 hits, misses, maxsize, currsize = f.cache_info() 1221 self.assertEqual(maxsize, 20) 1222 self.assertEqual(currsize, 0) 1223 self.assertEqual(hits, 0) 1224 self.assertEqual(misses, 0) 1225 1226 domain = range(5) 1227 for i in range(1000): 1228 x, y = choice(domain), choice(domain) 1229 actual = f(x, y) 1230 expected = orig(x, y) 1231 self.assertEqual(actual, expected) 1232 hits, misses, maxsize, currsize = f.cache_info() 1233 self.assertTrue(hits > misses) 1234 self.assertEqual(hits + misses, 1000) 1235 self.assertEqual(currsize, 20) 1236 1237 f.cache_clear() # test clearing 1238 hits, misses, maxsize, currsize = f.cache_info() 1239 self.assertEqual(hits, 0) 1240 self.assertEqual(misses, 0) 1241 self.assertEqual(currsize, 0) 1242 f(x, y) 1243 hits, misses, maxsize, currsize = f.cache_info() 1244 self.assertEqual(hits, 0) 1245 self.assertEqual(misses, 1) 1246 self.assertEqual(currsize, 1) 1247 1248 # Test bypassing the cache 1249 self.assertIs(f.__wrapped__, orig) 1250 f.__wrapped__(x, y) 1251 hits, misses, maxsize, currsize = f.cache_info() 1252 self.assertEqual(hits, 0) 1253 self.assertEqual(misses, 1) 1254 self.assertEqual(currsize, 1) 1255 1256 # test size zero (which means "never-cache") 1257 @self.module.lru_cache(0) 1258 def f(): 1259 nonlocal f_cnt 1260 f_cnt += 1 1261 return 20 1262 self.assertEqual(f.cache_info().maxsize, 0) 1263 f_cnt = 0 1264 for i in range(5): 1265 self.assertEqual(f(), 20) 1266 self.assertEqual(f_cnt, 5) 1267 hits, misses, maxsize, currsize = f.cache_info() 1268 self.assertEqual(hits, 0) 1269 self.assertEqual(misses, 5) 1270 self.assertEqual(currsize, 0) 1271 1272 # test size one 1273 @self.module.lru_cache(1) 1274 def f(): 1275 nonlocal f_cnt 1276 f_cnt += 1 1277 return 20 1278 self.assertEqual(f.cache_info().maxsize, 1) 1279 f_cnt = 0 1280 for i in range(5): 1281 self.assertEqual(f(), 20) 1282 self.assertEqual(f_cnt, 1) 1283 hits, misses, maxsize, currsize = f.cache_info() 1284 self.assertEqual(hits, 4) 1285 self.assertEqual(misses, 1) 1286 self.assertEqual(currsize, 1) 1287 1288 # test size two 1289 @self.module.lru_cache(2) 1290 def f(x): 1291 nonlocal f_cnt 1292 f_cnt += 1 1293 return x*10 1294 self.assertEqual(f.cache_info().maxsize, 2) 1295 f_cnt = 0 1296 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: 1297 # * * * * 1298 self.assertEqual(f(x), x*10) 1299 self.assertEqual(f_cnt, 4) 1300 hits, misses, maxsize, currsize = f.cache_info() 1301 self.assertEqual(hits, 12) 1302 self.assertEqual(misses, 4) 1303 self.assertEqual(currsize, 2) 1304 1305 def test_lru_no_args(self): 1306 @self.module.lru_cache 1307 def square(x): 1308 return x ** 2 1309 1310 self.assertEqual(list(map(square, [10, 20, 10])), 1311 [100, 400, 100]) 1312 self.assertEqual(square.cache_info().hits, 1) 1313 self.assertEqual(square.cache_info().misses, 2) 1314 self.assertEqual(square.cache_info().maxsize, 128) 1315 self.assertEqual(square.cache_info().currsize, 2) 1316 1317 def test_lru_bug_35780(self): 1318 # C version of the lru_cache was not checking to see if 1319 # the user function call has already modified the cache 1320 # (this arises in recursive calls and in multi-threading). 1321 # This cause the cache to have orphan links not referenced 1322 # by the cache dictionary. 1323 1324 once = True # Modified by f(x) below 1325 1326 @self.module.lru_cache(maxsize=10) 1327 def f(x): 1328 nonlocal once 1329 rv = f'.{x}.' 1330 if x == 20 and once: 1331 once = False 1332 rv = f(x) 1333 return rv 1334 1335 # Fill the cache 1336 for x in range(15): 1337 self.assertEqual(f(x), f'.{x}.') 1338 self.assertEqual(f.cache_info().currsize, 10) 1339 1340 # Make a recursive call and make sure the cache remains full 1341 self.assertEqual(f(20), '.20.') 1342 self.assertEqual(f.cache_info().currsize, 10) 1343 1344 def test_lru_bug_36650(self): 1345 # C version of lru_cache was treating a call with an empty **kwargs 1346 # dictionary as being distinct from a call with no keywords at all. 1347 # This did not result in an incorrect answer, but it did trigger 1348 # an unexpected cache miss. 1349 1350 @self.module.lru_cache() 1351 def f(x): 1352 pass 1353 1354 f(0) 1355 f(0, **{}) 1356 self.assertEqual(f.cache_info().hits, 1) 1357 1358 def test_lru_hash_only_once(self): 1359 # To protect against weird reentrancy bugs and to improve 1360 # efficiency when faced with slow __hash__ methods, the 1361 # LRU cache guarantees that it will only call __hash__ 1362 # only once per use as an argument to the cached function. 1363 1364 @self.module.lru_cache(maxsize=1) 1365 def f(x, y): 1366 return x * 3 + y 1367 1368 # Simulate the integer 5 1369 mock_int = unittest.mock.Mock() 1370 mock_int.__mul__ = unittest.mock.Mock(return_value=15) 1371 mock_int.__hash__ = unittest.mock.Mock(return_value=999) 1372 1373 # Add to cache: One use as an argument gives one call 1374 self.assertEqual(f(mock_int, 1), 16) 1375 self.assertEqual(mock_int.__hash__.call_count, 1) 1376 self.assertEqual(f.cache_info(), (0, 1, 1, 1)) 1377 1378 # Cache hit: One use as an argument gives one additional call 1379 self.assertEqual(f(mock_int, 1), 16) 1380 self.assertEqual(mock_int.__hash__.call_count, 2) 1381 self.assertEqual(f.cache_info(), (1, 1, 1, 1)) 1382 1383 # Cache eviction: No use as an argument gives no additional call 1384 self.assertEqual(f(6, 2), 20) 1385 self.assertEqual(mock_int.__hash__.call_count, 2) 1386 self.assertEqual(f.cache_info(), (1, 2, 1, 1)) 1387 1388 # Cache miss: One use as an argument gives one additional call 1389 self.assertEqual(f(mock_int, 1), 16) 1390 self.assertEqual(mock_int.__hash__.call_count, 3) 1391 self.assertEqual(f.cache_info(), (1, 3, 1, 1)) 1392 1393 def test_lru_reentrancy_with_len(self): 1394 # Test to make sure the LRU cache code isn't thrown-off by 1395 # caching the built-in len() function. Since len() can be 1396 # cached, we shouldn't use it inside the lru code itself. 1397 old_len = builtins.len 1398 try: 1399 builtins.len = self.module.lru_cache(4)(len) 1400 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]: 1401 self.assertEqual(len('abcdefghijklmn'[:i]), i) 1402 finally: 1403 builtins.len = old_len 1404 1405 def test_lru_star_arg_handling(self): 1406 # Test regression that arose in ea064ff3c10f 1407 @functools.lru_cache() 1408 def f(*args): 1409 return args 1410 1411 self.assertEqual(f(1, 2), (1, 2)) 1412 self.assertEqual(f((1, 2)), ((1, 2),)) 1413 1414 def test_lru_type_error(self): 1415 # Regression test for issue #28653. 1416 # lru_cache was leaking when one of the arguments 1417 # wasn't cacheable. 1418 1419 @functools.lru_cache(maxsize=None) 1420 def infinite_cache(o): 1421 pass 1422 1423 @functools.lru_cache(maxsize=10) 1424 def limited_cache(o): 1425 pass 1426 1427 with self.assertRaises(TypeError): 1428 infinite_cache([]) 1429 1430 with self.assertRaises(TypeError): 1431 limited_cache([]) 1432 1433 def test_lru_with_maxsize_none(self): 1434 @self.module.lru_cache(maxsize=None) 1435 def fib(n): 1436 if n < 2: 1437 return n 1438 return fib(n-1) + fib(n-2) 1439 self.assertEqual([fib(n) for n in range(16)], 1440 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1441 self.assertEqual(fib.cache_info(), 1442 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1443 fib.cache_clear() 1444 self.assertEqual(fib.cache_info(), 1445 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1446 1447 def test_lru_with_maxsize_negative(self): 1448 @self.module.lru_cache(maxsize=-10) 1449 def eq(n): 1450 return n 1451 for i in (0, 1): 1452 self.assertEqual([eq(n) for n in range(150)], list(range(150))) 1453 self.assertEqual(eq.cache_info(), 1454 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0)) 1455 1456 def test_lru_with_exceptions(self): 1457 # Verify that user_function exceptions get passed through without 1458 # creating a hard-to-read chained exception. 1459 # http://bugs.python.org/issue13177 1460 for maxsize in (None, 128): 1461 @self.module.lru_cache(maxsize) 1462 def func(i): 1463 return 'abc'[i] 1464 self.assertEqual(func(0), 'a') 1465 with self.assertRaises(IndexError) as cm: 1466 func(15) 1467 self.assertIsNone(cm.exception.__context__) 1468 # Verify that the previous exception did not result in a cached entry 1469 with self.assertRaises(IndexError): 1470 func(15) 1471 1472 def test_lru_with_types(self): 1473 for maxsize in (None, 128): 1474 @self.module.lru_cache(maxsize=maxsize, typed=True) 1475 def square(x): 1476 return x * x 1477 self.assertEqual(square(3), 9) 1478 self.assertEqual(type(square(3)), type(9)) 1479 self.assertEqual(square(3.0), 9.0) 1480 self.assertEqual(type(square(3.0)), type(9.0)) 1481 self.assertEqual(square(x=3), 9) 1482 self.assertEqual(type(square(x=3)), type(9)) 1483 self.assertEqual(square(x=3.0), 9.0) 1484 self.assertEqual(type(square(x=3.0)), type(9.0)) 1485 self.assertEqual(square.cache_info().hits, 4) 1486 self.assertEqual(square.cache_info().misses, 4) 1487 1488 def test_lru_with_keyword_args(self): 1489 @self.module.lru_cache() 1490 def fib(n): 1491 if n < 2: 1492 return n 1493 return fib(n=n-1) + fib(n=n-2) 1494 self.assertEqual( 1495 [fib(n=number) for number in range(16)], 1496 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] 1497 ) 1498 self.assertEqual(fib.cache_info(), 1499 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) 1500 fib.cache_clear() 1501 self.assertEqual(fib.cache_info(), 1502 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) 1503 1504 def test_lru_with_keyword_args_maxsize_none(self): 1505 @self.module.lru_cache(maxsize=None) 1506 def fib(n): 1507 if n < 2: 1508 return n 1509 return fib(n=n-1) + fib(n=n-2) 1510 self.assertEqual([fib(n=number) for number in range(16)], 1511 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1512 self.assertEqual(fib.cache_info(), 1513 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1514 fib.cache_clear() 1515 self.assertEqual(fib.cache_info(), 1516 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1517 1518 def test_kwargs_order(self): 1519 # PEP 468: Preserving Keyword Argument Order 1520 @self.module.lru_cache(maxsize=10) 1521 def f(**kwargs): 1522 return list(kwargs.items()) 1523 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)]) 1524 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)]) 1525 self.assertEqual(f.cache_info(), 1526 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2)) 1527 1528 def test_lru_cache_decoration(self): 1529 def f(zomg: 'zomg_annotation'): 1530 """f doc string""" 1531 return 42 1532 g = self.module.lru_cache()(f) 1533 for attr in self.module.WRAPPER_ASSIGNMENTS: 1534 self.assertEqual(getattr(g, attr), getattr(f, attr)) 1535 1536 def test_lru_cache_threaded(self): 1537 n, m = 5, 11 1538 def orig(x, y): 1539 return 3 * x + y 1540 f = self.module.lru_cache(maxsize=n*m)(orig) 1541 hits, misses, maxsize, currsize = f.cache_info() 1542 self.assertEqual(currsize, 0) 1543 1544 start = threading.Event() 1545 def full(k): 1546 start.wait(10) 1547 for _ in range(m): 1548 self.assertEqual(f(k, 0), orig(k, 0)) 1549 1550 def clear(): 1551 start.wait(10) 1552 for _ in range(2*m): 1553 f.cache_clear() 1554 1555 orig_si = sys.getswitchinterval() 1556 support.setswitchinterval(1e-6) 1557 try: 1558 # create n threads in order to fill cache 1559 threads = [threading.Thread(target=full, args=[k]) 1560 for k in range(n)] 1561 with support.start_threads(threads): 1562 start.set() 1563 1564 hits, misses, maxsize, currsize = f.cache_info() 1565 if self.module is py_functools: 1566 # XXX: Why can be not equal? 1567 self.assertLessEqual(misses, n) 1568 self.assertLessEqual(hits, m*n - misses) 1569 else: 1570 self.assertEqual(misses, n) 1571 self.assertEqual(hits, m*n - misses) 1572 self.assertEqual(currsize, n) 1573 1574 # create n threads in order to fill cache and 1 to clear it 1575 threads = [threading.Thread(target=clear)] 1576 threads += [threading.Thread(target=full, args=[k]) 1577 for k in range(n)] 1578 start.clear() 1579 with support.start_threads(threads): 1580 start.set() 1581 finally: 1582 sys.setswitchinterval(orig_si) 1583 1584 def test_lru_cache_threaded2(self): 1585 # Simultaneous call with the same arguments 1586 n, m = 5, 7 1587 start = threading.Barrier(n+1) 1588 pause = threading.Barrier(n+1) 1589 stop = threading.Barrier(n+1) 1590 @self.module.lru_cache(maxsize=m*n) 1591 def f(x): 1592 pause.wait(10) 1593 return 3 * x 1594 self.assertEqual(f.cache_info(), (0, 0, m*n, 0)) 1595 def test(): 1596 for i in range(m): 1597 start.wait(10) 1598 self.assertEqual(f(i), 3 * i) 1599 stop.wait(10) 1600 threads = [threading.Thread(target=test) for k in range(n)] 1601 with support.start_threads(threads): 1602 for i in range(m): 1603 start.wait(10) 1604 stop.reset() 1605 pause.wait(10) 1606 start.reset() 1607 stop.wait(10) 1608 pause.reset() 1609 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) 1610 1611 def test_lru_cache_threaded3(self): 1612 @self.module.lru_cache(maxsize=2) 1613 def f(x): 1614 time.sleep(.01) 1615 return 3 * x 1616 def test(i, x): 1617 with self.subTest(thread=i): 1618 self.assertEqual(f(x), 3 * x, i) 1619 threads = [threading.Thread(target=test, args=(i, v)) 1620 for i, v in enumerate([1, 2, 2, 3, 2])] 1621 with support.start_threads(threads): 1622 pass 1623 1624 def test_need_for_rlock(self): 1625 # This will deadlock on an LRU cache that uses a regular lock 1626 1627 @self.module.lru_cache(maxsize=10) 1628 def test_func(x): 1629 'Used to demonstrate a reentrant lru_cache call within a single thread' 1630 return x 1631 1632 class DoubleEq: 1633 'Demonstrate a reentrant lru_cache call within a single thread' 1634 def __init__(self, x): 1635 self.x = x 1636 def __hash__(self): 1637 return self.x 1638 def __eq__(self, other): 1639 if self.x == 2: 1640 test_func(DoubleEq(1)) 1641 return self.x == other.x 1642 1643 test_func(DoubleEq(1)) # Load the cache 1644 test_func(DoubleEq(2)) # Load the cache 1645 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call 1646 DoubleEq(2)) # Verify the correct return value 1647 1648 def test_lru_method(self): 1649 class X(int): 1650 f_cnt = 0 1651 @self.module.lru_cache(2) 1652 def f(self, x): 1653 self.f_cnt += 1 1654 return x*10+self 1655 a = X(5) 1656 b = X(5) 1657 c = X(7) 1658 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0)) 1659 1660 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3: 1661 self.assertEqual(a.f(x), x*10 + 5) 1662 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0)) 1663 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2)) 1664 1665 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2: 1666 self.assertEqual(b.f(x), x*10 + 5) 1667 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0)) 1668 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2)) 1669 1670 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1: 1671 self.assertEqual(c.f(x), x*10 + 7) 1672 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) 1673 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) 1674 1675 self.assertEqual(a.f.cache_info(), X.f.cache_info()) 1676 self.assertEqual(b.f.cache_info(), X.f.cache_info()) 1677 self.assertEqual(c.f.cache_info(), X.f.cache_info()) 1678 1679 def test_pickle(self): 1680 cls = self.__class__ 1681 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth: 1682 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1683 with self.subTest(proto=proto, func=f): 1684 f_copy = pickle.loads(pickle.dumps(f, proto)) 1685 self.assertIs(f_copy, f) 1686 1687 def test_copy(self): 1688 cls = self.__class__ 1689 def orig(x, y): 1690 return 3 * x + y 1691 part = self.module.partial(orig, 2) 1692 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1693 self.module.lru_cache(2)(part)) 1694 for f in funcs: 1695 with self.subTest(func=f): 1696 f_copy = copy.copy(f) 1697 self.assertIs(f_copy, f) 1698 1699 def test_deepcopy(self): 1700 cls = self.__class__ 1701 def orig(x, y): 1702 return 3 * x + y 1703 part = self.module.partial(orig, 2) 1704 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1705 self.module.lru_cache(2)(part)) 1706 for f in funcs: 1707 with self.subTest(func=f): 1708 f_copy = copy.deepcopy(f) 1709 self.assertIs(f_copy, f) 1710 1711 def test_lru_cache_parameters(self): 1712 @self.module.lru_cache(maxsize=2) 1713 def f(): 1714 return 1 1715 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False}) 1716 1717 @self.module.lru_cache(maxsize=1000, typed=True) 1718 def f(): 1719 return 1 1720 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True}) 1721 1722 def test_lru_cache_weakrefable(self): 1723 @self.module.lru_cache 1724 def test_function(x): 1725 return x 1726 1727 class A: 1728 @self.module.lru_cache 1729 def test_method(self, x): 1730 return (self, x) 1731 1732 @staticmethod 1733 @self.module.lru_cache 1734 def test_staticmethod(x): 1735 return (self, x) 1736 1737 refs = [weakref.ref(test_function), 1738 weakref.ref(A.test_method), 1739 weakref.ref(A.test_staticmethod)] 1740 1741 for ref in refs: 1742 self.assertIsNotNone(ref()) 1743 1744 del A 1745 del test_function 1746 gc.collect() 1747 1748 for ref in refs: 1749 self.assertIsNone(ref()) 1750 1751 1752@py_functools.lru_cache() 1753def py_cached_func(x, y): 1754 return 3 * x + y 1755 1756@c_functools.lru_cache() 1757def c_cached_func(x, y): 1758 return 3 * x + y 1759 1760 1761class TestLRUPy(TestLRU, unittest.TestCase): 1762 module = py_functools 1763 cached_func = py_cached_func, 1764 1765 @module.lru_cache() 1766 def cached_meth(self, x, y): 1767 return 3 * x + y 1768 1769 @staticmethod 1770 @module.lru_cache() 1771 def cached_staticmeth(x, y): 1772 return 3 * x + y 1773 1774 1775class TestLRUC(TestLRU, unittest.TestCase): 1776 module = c_functools 1777 cached_func = c_cached_func, 1778 1779 @module.lru_cache() 1780 def cached_meth(self, x, y): 1781 return 3 * x + y 1782 1783 @staticmethod 1784 @module.lru_cache() 1785 def cached_staticmeth(x, y): 1786 return 3 * x + y 1787 1788 1789class TestSingleDispatch(unittest.TestCase): 1790 def test_simple_overloads(self): 1791 @functools.singledispatch 1792 def g(obj): 1793 return "base" 1794 def g_int(i): 1795 return "integer" 1796 g.register(int, g_int) 1797 self.assertEqual(g("str"), "base") 1798 self.assertEqual(g(1), "integer") 1799 self.assertEqual(g([1,2,3]), "base") 1800 1801 def test_mro(self): 1802 @functools.singledispatch 1803 def g(obj): 1804 return "base" 1805 class A: 1806 pass 1807 class C(A): 1808 pass 1809 class B(A): 1810 pass 1811 class D(C, B): 1812 pass 1813 def g_A(a): 1814 return "A" 1815 def g_B(b): 1816 return "B" 1817 g.register(A, g_A) 1818 g.register(B, g_B) 1819 self.assertEqual(g(A()), "A") 1820 self.assertEqual(g(B()), "B") 1821 self.assertEqual(g(C()), "A") 1822 self.assertEqual(g(D()), "B") 1823 1824 def test_register_decorator(self): 1825 @functools.singledispatch 1826 def g(obj): 1827 return "base" 1828 @g.register(int) 1829 def g_int(i): 1830 return "int %s" % (i,) 1831 self.assertEqual(g(""), "base") 1832 self.assertEqual(g(12), "int 12") 1833 self.assertIs(g.dispatch(int), g_int) 1834 self.assertIs(g.dispatch(object), g.dispatch(str)) 1835 # Note: in the assert above this is not g. 1836 # @singledispatch returns the wrapper. 1837 1838 def test_wrapping_attributes(self): 1839 @functools.singledispatch 1840 def g(obj): 1841 "Simple test" 1842 return "Test" 1843 self.assertEqual(g.__name__, "g") 1844 if sys.flags.optimize < 2: 1845 self.assertEqual(g.__doc__, "Simple test") 1846 1847 @unittest.skipUnless(decimal, 'requires _decimal') 1848 @support.cpython_only 1849 def test_c_classes(self): 1850 @functools.singledispatch 1851 def g(obj): 1852 return "base" 1853 @g.register(decimal.DecimalException) 1854 def _(obj): 1855 return obj.args 1856 subn = decimal.Subnormal("Exponent < Emin") 1857 rnd = decimal.Rounded("Number got rounded") 1858 self.assertEqual(g(subn), ("Exponent < Emin",)) 1859 self.assertEqual(g(rnd), ("Number got rounded",)) 1860 @g.register(decimal.Subnormal) 1861 def _(obj): 1862 return "Too small to care." 1863 self.assertEqual(g(subn), "Too small to care.") 1864 self.assertEqual(g(rnd), ("Number got rounded",)) 1865 1866 def test_compose_mro(self): 1867 # None of the examples in this test depend on haystack ordering. 1868 c = collections.abc 1869 mro = functools._compose_mro 1870 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] 1871 for haystack in permutations(bases): 1872 m = mro(dict, haystack) 1873 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, 1874 c.Collection, c.Sized, c.Iterable, 1875 c.Container, object]) 1876 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict] 1877 for haystack in permutations(bases): 1878 m = mro(collections.ChainMap, haystack) 1879 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping, 1880 c.Collection, c.Sized, c.Iterable, 1881 c.Container, object]) 1882 1883 # If there's a generic function with implementations registered for 1884 # both Sized and Container, passing a defaultdict to it results in an 1885 # ambiguous dispatch which will cause a RuntimeError (see 1886 # test_mro_conflicts). 1887 bases = [c.Container, c.Sized, str] 1888 for haystack in permutations(bases): 1889 m = mro(collections.defaultdict, [c.Sized, c.Container, str]) 1890 self.assertEqual(m, [collections.defaultdict, dict, c.Sized, 1891 c.Container, object]) 1892 1893 # MutableSequence below is registered directly on D. In other words, it 1894 # precedes MutableMapping which means single dispatch will always 1895 # choose MutableSequence here. 1896 class D(collections.defaultdict): 1897 pass 1898 c.MutableSequence.register(D) 1899 bases = [c.MutableSequence, c.MutableMapping] 1900 for haystack in permutations(bases): 1901 m = mro(D, bases) 1902 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, 1903 collections.defaultdict, dict, c.MutableMapping, c.Mapping, 1904 c.Collection, c.Sized, c.Iterable, c.Container, 1905 object]) 1906 1907 # Container and Callable are registered on different base classes and 1908 # a generic function supporting both should always pick the Callable 1909 # implementation if a C instance is passed. 1910 class C(collections.defaultdict): 1911 def __call__(self): 1912 pass 1913 bases = [c.Sized, c.Callable, c.Container, c.Mapping] 1914 for haystack in permutations(bases): 1915 m = mro(C, haystack) 1916 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping, 1917 c.Collection, c.Sized, c.Iterable, 1918 c.Container, object]) 1919 1920 def test_register_abc(self): 1921 c = collections.abc 1922 d = {"a": "b"} 1923 l = [1, 2, 3] 1924 s = {object(), None} 1925 f = frozenset(s) 1926 t = (1, 2, 3) 1927 @functools.singledispatch 1928 def g(obj): 1929 return "base" 1930 self.assertEqual(g(d), "base") 1931 self.assertEqual(g(l), "base") 1932 self.assertEqual(g(s), "base") 1933 self.assertEqual(g(f), "base") 1934 self.assertEqual(g(t), "base") 1935 g.register(c.Sized, lambda obj: "sized") 1936 self.assertEqual(g(d), "sized") 1937 self.assertEqual(g(l), "sized") 1938 self.assertEqual(g(s), "sized") 1939 self.assertEqual(g(f), "sized") 1940 self.assertEqual(g(t), "sized") 1941 g.register(c.MutableMapping, lambda obj: "mutablemapping") 1942 self.assertEqual(g(d), "mutablemapping") 1943 self.assertEqual(g(l), "sized") 1944 self.assertEqual(g(s), "sized") 1945 self.assertEqual(g(f), "sized") 1946 self.assertEqual(g(t), "sized") 1947 g.register(collections.ChainMap, lambda obj: "chainmap") 1948 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered 1949 self.assertEqual(g(l), "sized") 1950 self.assertEqual(g(s), "sized") 1951 self.assertEqual(g(f), "sized") 1952 self.assertEqual(g(t), "sized") 1953 g.register(c.MutableSequence, lambda obj: "mutablesequence") 1954 self.assertEqual(g(d), "mutablemapping") 1955 self.assertEqual(g(l), "mutablesequence") 1956 self.assertEqual(g(s), "sized") 1957 self.assertEqual(g(f), "sized") 1958 self.assertEqual(g(t), "sized") 1959 g.register(c.MutableSet, lambda obj: "mutableset") 1960 self.assertEqual(g(d), "mutablemapping") 1961 self.assertEqual(g(l), "mutablesequence") 1962 self.assertEqual(g(s), "mutableset") 1963 self.assertEqual(g(f), "sized") 1964 self.assertEqual(g(t), "sized") 1965 g.register(c.Mapping, lambda obj: "mapping") 1966 self.assertEqual(g(d), "mutablemapping") # not specific enough 1967 self.assertEqual(g(l), "mutablesequence") 1968 self.assertEqual(g(s), "mutableset") 1969 self.assertEqual(g(f), "sized") 1970 self.assertEqual(g(t), "sized") 1971 g.register(c.Sequence, lambda obj: "sequence") 1972 self.assertEqual(g(d), "mutablemapping") 1973 self.assertEqual(g(l), "mutablesequence") 1974 self.assertEqual(g(s), "mutableset") 1975 self.assertEqual(g(f), "sized") 1976 self.assertEqual(g(t), "sequence") 1977 g.register(c.Set, lambda obj: "set") 1978 self.assertEqual(g(d), "mutablemapping") 1979 self.assertEqual(g(l), "mutablesequence") 1980 self.assertEqual(g(s), "mutableset") 1981 self.assertEqual(g(f), "set") 1982 self.assertEqual(g(t), "sequence") 1983 g.register(dict, lambda obj: "dict") 1984 self.assertEqual(g(d), "dict") 1985 self.assertEqual(g(l), "mutablesequence") 1986 self.assertEqual(g(s), "mutableset") 1987 self.assertEqual(g(f), "set") 1988 self.assertEqual(g(t), "sequence") 1989 g.register(list, lambda obj: "list") 1990 self.assertEqual(g(d), "dict") 1991 self.assertEqual(g(l), "list") 1992 self.assertEqual(g(s), "mutableset") 1993 self.assertEqual(g(f), "set") 1994 self.assertEqual(g(t), "sequence") 1995 g.register(set, lambda obj: "concrete-set") 1996 self.assertEqual(g(d), "dict") 1997 self.assertEqual(g(l), "list") 1998 self.assertEqual(g(s), "concrete-set") 1999 self.assertEqual(g(f), "set") 2000 self.assertEqual(g(t), "sequence") 2001 g.register(frozenset, lambda obj: "frozen-set") 2002 self.assertEqual(g(d), "dict") 2003 self.assertEqual(g(l), "list") 2004 self.assertEqual(g(s), "concrete-set") 2005 self.assertEqual(g(f), "frozen-set") 2006 self.assertEqual(g(t), "sequence") 2007 g.register(tuple, lambda obj: "tuple") 2008 self.assertEqual(g(d), "dict") 2009 self.assertEqual(g(l), "list") 2010 self.assertEqual(g(s), "concrete-set") 2011 self.assertEqual(g(f), "frozen-set") 2012 self.assertEqual(g(t), "tuple") 2013 2014 def test_c3_abc(self): 2015 c = collections.abc 2016 mro = functools._c3_mro 2017 class A(object): 2018 pass 2019 class B(A): 2020 def __len__(self): 2021 return 0 # implies Sized 2022 @c.Container.register 2023 class C(object): 2024 pass 2025 class D(object): 2026 pass # unrelated 2027 class X(D, C, B): 2028 def __call__(self): 2029 pass # implies Callable 2030 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] 2031 for abcs in permutations([c.Sized, c.Callable, c.Container]): 2032 self.assertEqual(mro(X, abcs=abcs), expected) 2033 # unrelated ABCs don't appear in the resulting MRO 2034 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] 2035 self.assertEqual(mro(X, abcs=many_abcs), expected) 2036 2037 def test_false_meta(self): 2038 # see issue23572 2039 class MetaA(type): 2040 def __len__(self): 2041 return 0 2042 class A(metaclass=MetaA): 2043 pass 2044 class AA(A): 2045 pass 2046 @functools.singledispatch 2047 def fun(a): 2048 return 'base A' 2049 @fun.register(A) 2050 def _(a): 2051 return 'fun A' 2052 aa = AA() 2053 self.assertEqual(fun(aa), 'fun A') 2054 2055 def test_mro_conflicts(self): 2056 c = collections.abc 2057 @functools.singledispatch 2058 def g(arg): 2059 return "base" 2060 class O(c.Sized): 2061 def __len__(self): 2062 return 0 2063 o = O() 2064 self.assertEqual(g(o), "base") 2065 g.register(c.Iterable, lambda arg: "iterable") 2066 g.register(c.Container, lambda arg: "container") 2067 g.register(c.Sized, lambda arg: "sized") 2068 g.register(c.Set, lambda arg: "set") 2069 self.assertEqual(g(o), "sized") 2070 c.Iterable.register(O) 2071 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ 2072 c.Container.register(O) 2073 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ 2074 c.Set.register(O) 2075 self.assertEqual(g(o), "set") # because c.Set is a subclass of 2076 # c.Sized and c.Container 2077 class P: 2078 pass 2079 p = P() 2080 self.assertEqual(g(p), "base") 2081 c.Iterable.register(P) 2082 self.assertEqual(g(p), "iterable") 2083 c.Container.register(P) 2084 with self.assertRaises(RuntimeError) as re_one: 2085 g(p) 2086 self.assertIn( 2087 str(re_one.exception), 2088 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2089 "or <class 'collections.abc.Iterable'>"), 2090 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " 2091 "or <class 'collections.abc.Container'>")), 2092 ) 2093 class Q(c.Sized): 2094 def __len__(self): 2095 return 0 2096 q = Q() 2097 self.assertEqual(g(q), "sized") 2098 c.Iterable.register(Q) 2099 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ 2100 c.Set.register(Q) 2101 self.assertEqual(g(q), "set") # because c.Set is a subclass of 2102 # c.Sized and c.Iterable 2103 @functools.singledispatch 2104 def h(arg): 2105 return "base" 2106 @h.register(c.Sized) 2107 def _(arg): 2108 return "sized" 2109 @h.register(c.Container) 2110 def _(arg): 2111 return "container" 2112 # Even though Sized and Container are explicit bases of MutableMapping, 2113 # this ABC is implicitly registered on defaultdict which makes all of 2114 # MutableMapping's bases implicit as well from defaultdict's 2115 # perspective. 2116 with self.assertRaises(RuntimeError) as re_two: 2117 h(collections.defaultdict(lambda: 0)) 2118 self.assertIn( 2119 str(re_two.exception), 2120 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2121 "or <class 'collections.abc.Sized'>"), 2122 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2123 "or <class 'collections.abc.Container'>")), 2124 ) 2125 class R(collections.defaultdict): 2126 pass 2127 c.MutableSequence.register(R) 2128 @functools.singledispatch 2129 def i(arg): 2130 return "base" 2131 @i.register(c.MutableMapping) 2132 def _(arg): 2133 return "mapping" 2134 @i.register(c.MutableSequence) 2135 def _(arg): 2136 return "sequence" 2137 r = R() 2138 self.assertEqual(i(r), "sequence") 2139 class S: 2140 pass 2141 class T(S, c.Sized): 2142 def __len__(self): 2143 return 0 2144 t = T() 2145 self.assertEqual(h(t), "sized") 2146 c.Container.register(T) 2147 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO 2148 class U: 2149 def __len__(self): 2150 return 0 2151 u = U() 2152 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred 2153 # from the existence of __len__() 2154 c.Container.register(U) 2155 # There is no preference for registered versus inferred ABCs. 2156 with self.assertRaises(RuntimeError) as re_three: 2157 h(u) 2158 self.assertIn( 2159 str(re_three.exception), 2160 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2161 "or <class 'collections.abc.Sized'>"), 2162 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2163 "or <class 'collections.abc.Container'>")), 2164 ) 2165 class V(c.Sized, S): 2166 def __len__(self): 2167 return 0 2168 @functools.singledispatch 2169 def j(arg): 2170 return "base" 2171 @j.register(S) 2172 def _(arg): 2173 return "s" 2174 @j.register(c.Container) 2175 def _(arg): 2176 return "container" 2177 v = V() 2178 self.assertEqual(j(v), "s") 2179 c.Container.register(V) 2180 self.assertEqual(j(v), "container") # because it ends up right after 2181 # Sized in the MRO 2182 2183 def test_cache_invalidation(self): 2184 from collections import UserDict 2185 import weakref 2186 2187 class TracingDict(UserDict): 2188 def __init__(self, *args, **kwargs): 2189 super(TracingDict, self).__init__(*args, **kwargs) 2190 self.set_ops = [] 2191 self.get_ops = [] 2192 def __getitem__(self, key): 2193 result = self.data[key] 2194 self.get_ops.append(key) 2195 return result 2196 def __setitem__(self, key, value): 2197 self.set_ops.append(key) 2198 self.data[key] = value 2199 def clear(self): 2200 self.data.clear() 2201 2202 td = TracingDict() 2203 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td): 2204 c = collections.abc 2205 @functools.singledispatch 2206 def g(arg): 2207 return "base" 2208 d = {} 2209 l = [] 2210 self.assertEqual(len(td), 0) 2211 self.assertEqual(g(d), "base") 2212 self.assertEqual(len(td), 1) 2213 self.assertEqual(td.get_ops, []) 2214 self.assertEqual(td.set_ops, [dict]) 2215 self.assertEqual(td.data[dict], g.registry[object]) 2216 self.assertEqual(g(l), "base") 2217 self.assertEqual(len(td), 2) 2218 self.assertEqual(td.get_ops, []) 2219 self.assertEqual(td.set_ops, [dict, list]) 2220 self.assertEqual(td.data[dict], g.registry[object]) 2221 self.assertEqual(td.data[list], g.registry[object]) 2222 self.assertEqual(td.data[dict], td.data[list]) 2223 self.assertEqual(g(l), "base") 2224 self.assertEqual(g(d), "base") 2225 self.assertEqual(td.get_ops, [list, dict]) 2226 self.assertEqual(td.set_ops, [dict, list]) 2227 g.register(list, lambda arg: "list") 2228 self.assertEqual(td.get_ops, [list, dict]) 2229 self.assertEqual(len(td), 0) 2230 self.assertEqual(g(d), "base") 2231 self.assertEqual(len(td), 1) 2232 self.assertEqual(td.get_ops, [list, dict]) 2233 self.assertEqual(td.set_ops, [dict, list, dict]) 2234 self.assertEqual(td.data[dict], 2235 functools._find_impl(dict, g.registry)) 2236 self.assertEqual(g(l), "list") 2237 self.assertEqual(len(td), 2) 2238 self.assertEqual(td.get_ops, [list, dict]) 2239 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2240 self.assertEqual(td.data[list], 2241 functools._find_impl(list, g.registry)) 2242 class X: 2243 pass 2244 c.MutableMapping.register(X) # Will not invalidate the cache, 2245 # not using ABCs yet. 2246 self.assertEqual(g(d), "base") 2247 self.assertEqual(g(l), "list") 2248 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2249 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2250 g.register(c.Sized, lambda arg: "sized") 2251 self.assertEqual(len(td), 0) 2252 self.assertEqual(g(d), "sized") 2253 self.assertEqual(len(td), 1) 2254 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2255 self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) 2256 self.assertEqual(g(l), "list") 2257 self.assertEqual(len(td), 2) 2258 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2259 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2260 self.assertEqual(g(l), "list") 2261 self.assertEqual(g(d), "sized") 2262 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) 2263 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2264 g.dispatch(list) 2265 g.dispatch(dict) 2266 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, 2267 list, dict]) 2268 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2269 c.MutableSet.register(X) # Will invalidate the cache. 2270 self.assertEqual(len(td), 2) # Stale cache. 2271 self.assertEqual(g(l), "list") 2272 self.assertEqual(len(td), 1) 2273 g.register(c.MutableMapping, lambda arg: "mutablemapping") 2274 self.assertEqual(len(td), 0) 2275 self.assertEqual(g(d), "mutablemapping") 2276 self.assertEqual(len(td), 1) 2277 self.assertEqual(g(l), "list") 2278 self.assertEqual(len(td), 2) 2279 g.register(dict, lambda arg: "dict") 2280 self.assertEqual(g(d), "dict") 2281 self.assertEqual(g(l), "list") 2282 g._clear_cache() 2283 self.assertEqual(len(td), 0) 2284 2285 def test_annotations(self): 2286 @functools.singledispatch 2287 def i(arg): 2288 return "base" 2289 @i.register 2290 def _(arg: collections.abc.Mapping): 2291 return "mapping" 2292 @i.register 2293 def _(arg: "collections.abc.Sequence"): 2294 return "sequence" 2295 self.assertEqual(i(None), "base") 2296 self.assertEqual(i({"a": 1}), "mapping") 2297 self.assertEqual(i([1, 2, 3]), "sequence") 2298 self.assertEqual(i((1, 2, 3)), "sequence") 2299 self.assertEqual(i("str"), "sequence") 2300 2301 # Registering classes as callables doesn't work with annotations, 2302 # you need to pass the type explicitly. 2303 @i.register(str) 2304 class _: 2305 def __init__(self, arg): 2306 self.arg = arg 2307 2308 def __eq__(self, other): 2309 return self.arg == other 2310 self.assertEqual(i("str"), "str") 2311 2312 def test_method_register(self): 2313 class A: 2314 @functools.singledispatchmethod 2315 def t(self, arg): 2316 self.arg = "base" 2317 @t.register(int) 2318 def _(self, arg): 2319 self.arg = "int" 2320 @t.register(str) 2321 def _(self, arg): 2322 self.arg = "str" 2323 a = A() 2324 2325 a.t(0) 2326 self.assertEqual(a.arg, "int") 2327 aa = A() 2328 self.assertFalse(hasattr(aa, 'arg')) 2329 a.t('') 2330 self.assertEqual(a.arg, "str") 2331 aa = A() 2332 self.assertFalse(hasattr(aa, 'arg')) 2333 a.t(0.0) 2334 self.assertEqual(a.arg, "base") 2335 aa = A() 2336 self.assertFalse(hasattr(aa, 'arg')) 2337 2338 def test_staticmethod_register(self): 2339 class A: 2340 @functools.singledispatchmethod 2341 @staticmethod 2342 def t(arg): 2343 return arg 2344 @t.register(int) 2345 @staticmethod 2346 def _(arg): 2347 return isinstance(arg, int) 2348 @t.register(str) 2349 @staticmethod 2350 def _(arg): 2351 return isinstance(arg, str) 2352 a = A() 2353 2354 self.assertTrue(A.t(0)) 2355 self.assertTrue(A.t('')) 2356 self.assertEqual(A.t(0.0), 0.0) 2357 2358 def test_classmethod_register(self): 2359 class A: 2360 def __init__(self, arg): 2361 self.arg = arg 2362 2363 @functools.singledispatchmethod 2364 @classmethod 2365 def t(cls, arg): 2366 return cls("base") 2367 @t.register(int) 2368 @classmethod 2369 def _(cls, arg): 2370 return cls("int") 2371 @t.register(str) 2372 @classmethod 2373 def _(cls, arg): 2374 return cls("str") 2375 2376 self.assertEqual(A.t(0).arg, "int") 2377 self.assertEqual(A.t('').arg, "str") 2378 self.assertEqual(A.t(0.0).arg, "base") 2379 2380 def test_callable_register(self): 2381 class A: 2382 def __init__(self, arg): 2383 self.arg = arg 2384 2385 @functools.singledispatchmethod 2386 @classmethod 2387 def t(cls, arg): 2388 return cls("base") 2389 2390 @A.t.register(int) 2391 @classmethod 2392 def _(cls, arg): 2393 return cls("int") 2394 @A.t.register(str) 2395 @classmethod 2396 def _(cls, arg): 2397 return cls("str") 2398 2399 self.assertEqual(A.t(0).arg, "int") 2400 self.assertEqual(A.t('').arg, "str") 2401 self.assertEqual(A.t(0.0).arg, "base") 2402 2403 def test_abstractmethod_register(self): 2404 class Abstract(metaclass=abc.ABCMeta): 2405 2406 @functools.singledispatchmethod 2407 @abc.abstractmethod 2408 def add(self, x, y): 2409 pass 2410 2411 self.assertTrue(Abstract.add.__isabstractmethod__) 2412 self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__) 2413 2414 with self.assertRaises(TypeError): 2415 Abstract() 2416 2417 def test_type_ann_register(self): 2418 class A: 2419 @functools.singledispatchmethod 2420 def t(self, arg): 2421 return "base" 2422 @t.register 2423 def _(self, arg: int): 2424 return "int" 2425 @t.register 2426 def _(self, arg: str): 2427 return "str" 2428 a = A() 2429 2430 self.assertEqual(a.t(0), "int") 2431 self.assertEqual(a.t(''), "str") 2432 self.assertEqual(a.t(0.0), "base") 2433 2434 def test_staticmethod_type_ann_register(self): 2435 class A: 2436 @functools.singledispatchmethod 2437 @staticmethod 2438 def t(arg): 2439 return arg 2440 @t.register 2441 @staticmethod 2442 def _(arg: int): 2443 return isinstance(arg, int) 2444 @t.register 2445 @staticmethod 2446 def _(arg: str): 2447 return isinstance(arg, str) 2448 a = A() 2449 2450 self.assertTrue(A.t(0)) 2451 self.assertTrue(A.t('')) 2452 self.assertEqual(A.t(0.0), 0.0) 2453 2454 def test_classmethod_type_ann_register(self): 2455 class A: 2456 def __init__(self, arg): 2457 self.arg = arg 2458 2459 @functools.singledispatchmethod 2460 @classmethod 2461 def t(cls, arg): 2462 return cls("base") 2463 @t.register 2464 @classmethod 2465 def _(cls, arg: int): 2466 return cls("int") 2467 @t.register 2468 @classmethod 2469 def _(cls, arg: str): 2470 return cls("str") 2471 2472 self.assertEqual(A.t(0).arg, "int") 2473 self.assertEqual(A.t('').arg, "str") 2474 self.assertEqual(A.t(0.0).arg, "base") 2475 2476 def test_method_wrapping_attributes(self): 2477 class A: 2478 @functools.singledispatchmethod 2479 def func(self, arg: int) -> str: 2480 """My function docstring""" 2481 return str(arg) 2482 @functools.singledispatchmethod 2483 @classmethod 2484 def cls_func(cls, arg: int) -> str: 2485 """My function docstring""" 2486 return str(arg) 2487 @functools.singledispatchmethod 2488 @staticmethod 2489 def static_func(arg: int) -> str: 2490 """My function docstring""" 2491 return str(arg) 2492 2493 for meth in ( 2494 A.func, 2495 A().func, 2496 A.cls_func, 2497 A().cls_func, 2498 A.static_func, 2499 A().static_func 2500 ): 2501 with self.subTest(meth=meth): 2502 self.assertEqual(meth.__doc__, 'My function docstring') 2503 self.assertEqual(meth.__annotations__['arg'], int) 2504 2505 self.assertEqual(A.func.__name__, 'func') 2506 self.assertEqual(A().func.__name__, 'func') 2507 self.assertEqual(A.cls_func.__name__, 'cls_func') 2508 self.assertEqual(A().cls_func.__name__, 'cls_func') 2509 self.assertEqual(A.static_func.__name__, 'static_func') 2510 self.assertEqual(A().static_func.__name__, 'static_func') 2511 2512 def test_double_wrapped_methods(self): 2513 def classmethod_friendly_decorator(func): 2514 wrapped = func.__func__ 2515 @classmethod 2516 @functools.wraps(wrapped) 2517 def wrapper(*args, **kwargs): 2518 return wrapped(*args, **kwargs) 2519 return wrapper 2520 2521 class WithoutSingleDispatch: 2522 @classmethod 2523 @contextlib.contextmanager 2524 def cls_context_manager(cls, arg: int) -> str: 2525 try: 2526 yield str(arg) 2527 finally: 2528 return 'Done' 2529 2530 @classmethod_friendly_decorator 2531 @classmethod 2532 def decorated_classmethod(cls, arg: int) -> str: 2533 return str(arg) 2534 2535 class WithSingleDispatch: 2536 @functools.singledispatchmethod 2537 @classmethod 2538 @contextlib.contextmanager 2539 def cls_context_manager(cls, arg: int) -> str: 2540 """My function docstring""" 2541 try: 2542 yield str(arg) 2543 finally: 2544 return 'Done' 2545 2546 @functools.singledispatchmethod 2547 @classmethod_friendly_decorator 2548 @classmethod 2549 def decorated_classmethod(cls, arg: int) -> str: 2550 """My function docstring""" 2551 return str(arg) 2552 2553 # These are sanity checks 2554 # to test the test itself is working as expected 2555 with WithoutSingleDispatch.cls_context_manager(5) as foo: 2556 without_single_dispatch_foo = foo 2557 2558 with WithSingleDispatch.cls_context_manager(5) as foo: 2559 single_dispatch_foo = foo 2560 2561 self.assertEqual(without_single_dispatch_foo, single_dispatch_foo) 2562 self.assertEqual(single_dispatch_foo, '5') 2563 2564 self.assertEqual( 2565 WithoutSingleDispatch.decorated_classmethod(5), 2566 WithSingleDispatch.decorated_classmethod(5) 2567 ) 2568 2569 self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5') 2570 2571 # Behavioural checks now follow 2572 for method_name in ('cls_context_manager', 'decorated_classmethod'): 2573 with self.subTest(method=method_name): 2574 self.assertEqual( 2575 getattr(WithSingleDispatch, method_name).__name__, 2576 getattr(WithoutSingleDispatch, method_name).__name__ 2577 ) 2578 2579 self.assertEqual( 2580 getattr(WithSingleDispatch(), method_name).__name__, 2581 getattr(WithoutSingleDispatch(), method_name).__name__ 2582 ) 2583 2584 for meth in ( 2585 WithSingleDispatch.cls_context_manager, 2586 WithSingleDispatch().cls_context_manager, 2587 WithSingleDispatch.decorated_classmethod, 2588 WithSingleDispatch().decorated_classmethod 2589 ): 2590 with self.subTest(meth=meth): 2591 self.assertEqual(meth.__doc__, 'My function docstring') 2592 self.assertEqual(meth.__annotations__['arg'], int) 2593 2594 self.assertEqual( 2595 WithSingleDispatch.cls_context_manager.__name__, 2596 'cls_context_manager' 2597 ) 2598 self.assertEqual( 2599 WithSingleDispatch().cls_context_manager.__name__, 2600 'cls_context_manager' 2601 ) 2602 self.assertEqual( 2603 WithSingleDispatch.decorated_classmethod.__name__, 2604 'decorated_classmethod' 2605 ) 2606 self.assertEqual( 2607 WithSingleDispatch().decorated_classmethod.__name__, 2608 'decorated_classmethod' 2609 ) 2610 2611 def test_invalid_registrations(self): 2612 msg_prefix = "Invalid first argument to `register()`: " 2613 msg_suffix = ( 2614 ". Use either `@register(some_class)` or plain `@register` on an " 2615 "annotated function." 2616 ) 2617 @functools.singledispatch 2618 def i(arg): 2619 return "base" 2620 with self.assertRaises(TypeError) as exc: 2621 @i.register(42) 2622 def _(arg): 2623 return "I annotated with a non-type" 2624 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42")) 2625 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2626 with self.assertRaises(TypeError) as exc: 2627 @i.register 2628 def _(arg): 2629 return "I forgot to annotate" 2630 self.assertTrue(str(exc.exception).startswith(msg_prefix + 2631 "<function TestSingleDispatch.test_invalid_registrations.<locals>._" 2632 )) 2633 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2634 2635 with self.assertRaises(TypeError) as exc: 2636 @i.register 2637 def _(arg: typing.Iterable[str]): 2638 # At runtime, dispatching on generics is impossible. 2639 # When registering implementations with singledispatch, avoid 2640 # types from `typing`. Instead, annotate with regular types 2641 # or ABCs. 2642 return "I annotated with a generic collection" 2643 self.assertTrue(str(exc.exception).startswith( 2644 "Invalid annotation for 'arg'." 2645 )) 2646 self.assertTrue(str(exc.exception).endswith( 2647 'typing.Iterable[str] is not a class.' 2648 )) 2649 2650 def test_invalid_positional_argument(self): 2651 @functools.singledispatch 2652 def f(*args): 2653 pass 2654 msg = 'f requires at least 1 positional argument' 2655 with self.assertRaisesRegex(TypeError, msg): 2656 f() 2657 2658 2659class CachedCostItem: 2660 _cost = 1 2661 2662 def __init__(self): 2663 self.lock = py_functools.RLock() 2664 2665 @py_functools.cached_property 2666 def cost(self): 2667 """The cost of the item.""" 2668 with self.lock: 2669 self._cost += 1 2670 return self._cost 2671 2672 2673class OptionallyCachedCostItem: 2674 _cost = 1 2675 2676 def get_cost(self): 2677 """The cost of the item.""" 2678 self._cost += 1 2679 return self._cost 2680 2681 cached_cost = py_functools.cached_property(get_cost) 2682 2683 2684class CachedCostItemWait: 2685 2686 def __init__(self, event): 2687 self._cost = 1 2688 self.lock = py_functools.RLock() 2689 self.event = event 2690 2691 @py_functools.cached_property 2692 def cost(self): 2693 self.event.wait(1) 2694 with self.lock: 2695 self._cost += 1 2696 return self._cost 2697 2698 2699class CachedCostItemWithSlots: 2700 __slots__ = ('_cost') 2701 2702 def __init__(self): 2703 self._cost = 1 2704 2705 @py_functools.cached_property 2706 def cost(self): 2707 raise RuntimeError('never called, slots not supported') 2708 2709 2710class TestCachedProperty(unittest.TestCase): 2711 def test_cached(self): 2712 item = CachedCostItem() 2713 self.assertEqual(item.cost, 2) 2714 self.assertEqual(item.cost, 2) # not 3 2715 2716 def test_cached_attribute_name_differs_from_func_name(self): 2717 item = OptionallyCachedCostItem() 2718 self.assertEqual(item.get_cost(), 2) 2719 self.assertEqual(item.cached_cost, 3) 2720 self.assertEqual(item.get_cost(), 4) 2721 self.assertEqual(item.cached_cost, 3) 2722 2723 def test_threaded(self): 2724 go = threading.Event() 2725 item = CachedCostItemWait(go) 2726 2727 num_threads = 3 2728 2729 orig_si = sys.getswitchinterval() 2730 sys.setswitchinterval(1e-6) 2731 try: 2732 threads = [ 2733 threading.Thread(target=lambda: item.cost) 2734 for k in range(num_threads) 2735 ] 2736 with support.start_threads(threads): 2737 go.set() 2738 finally: 2739 sys.setswitchinterval(orig_si) 2740 2741 self.assertEqual(item.cost, 2) 2742 2743 def test_object_with_slots(self): 2744 item = CachedCostItemWithSlots() 2745 with self.assertRaisesRegex( 2746 TypeError, 2747 "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.", 2748 ): 2749 item.cost 2750 2751 def test_immutable_dict(self): 2752 class MyMeta(type): 2753 @py_functools.cached_property 2754 def prop(self): 2755 return True 2756 2757 class MyClass(metaclass=MyMeta): 2758 pass 2759 2760 with self.assertRaisesRegex( 2761 TypeError, 2762 "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.", 2763 ): 2764 MyClass.prop 2765 2766 def test_reuse_different_names(self): 2767 """Disallow this case because decorated function a would not be cached.""" 2768 with self.assertRaises(RuntimeError) as ctx: 2769 class ReusedCachedProperty: 2770 @py_functools.cached_property 2771 def a(self): 2772 pass 2773 2774 b = a 2775 2776 self.assertEqual( 2777 str(ctx.exception.__context__), 2778 str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")) 2779 ) 2780 2781 def test_reuse_same_name(self): 2782 """Reusing a cached_property on different classes under the same name is OK.""" 2783 counter = 0 2784 2785 @py_functools.cached_property 2786 def _cp(_self): 2787 nonlocal counter 2788 counter += 1 2789 return counter 2790 2791 class A: 2792 cp = _cp 2793 2794 class B: 2795 cp = _cp 2796 2797 a = A() 2798 b = B() 2799 2800 self.assertEqual(a.cp, 1) 2801 self.assertEqual(b.cp, 2) 2802 self.assertEqual(a.cp, 1) 2803 2804 def test_set_name_not_called(self): 2805 cp = py_functools.cached_property(lambda s: None) 2806 class Foo: 2807 pass 2808 2809 Foo.cp = cp 2810 2811 with self.assertRaisesRegex( 2812 TypeError, 2813 "Cannot use cached_property instance without calling __set_name__ on it.", 2814 ): 2815 Foo().cp 2816 2817 def test_access_from_class(self): 2818 self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property) 2819 2820 def test_doc(self): 2821 self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.") 2822 2823 2824if __name__ == '__main__': 2825 unittest.main() 2826