1import abc 2import builtins 3import collections 4import collections.abc 5import copy 6from itertools import permutations 7import pickle 8from random import choice 9import sys 10from test import support 11import threading 12import time 13import typing 14import unittest 15import unittest.mock 16import os 17import weakref 18import gc 19from weakref import proxy 20import contextlib 21 22from test.support import import_helper 23from test.support import threading_helper 24from test.support.script_helper import assert_python_ok 25 26import functools 27 28py_functools = import_helper.import_fresh_module('functools', 29 blocked=['_functools']) 30c_functools = import_helper.import_fresh_module('functools') 31 32decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal']) 33 34@contextlib.contextmanager 35def replaced_module(name, replacement): 36 original_module = sys.modules[name] 37 sys.modules[name] = replacement 38 try: 39 yield 40 finally: 41 sys.modules[name] = original_module 42 43def capture(*args, **kw): 44 """capture all positional and keyword arguments""" 45 return args, kw 46 47 48def signature(part): 49 """ return the signature of a partial object """ 50 return (part.func, part.args, part.keywords, part.__dict__) 51 52class MyTuple(tuple): 53 pass 54 55class BadTuple(tuple): 56 def __add__(self, other): 57 return list(self) + list(other) 58 59class MyDict(dict): 60 pass 61 62 63class TestPartial: 64 65 def test_basic_examples(self): 66 p = self.partial(capture, 1, 2, a=10, b=20) 67 self.assertTrue(callable(p)) 68 self.assertEqual(p(3, 4, b=30, c=40), 69 ((1, 2, 3, 4), dict(a=10, b=30, c=40))) 70 p = self.partial(map, lambda x: x*10) 71 self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40]) 72 73 def test_attributes(self): 74 p = self.partial(capture, 1, 2, a=10, b=20) 75 # attributes should be readable 76 self.assertEqual(p.func, capture) 77 self.assertEqual(p.args, (1, 2)) 78 self.assertEqual(p.keywords, dict(a=10, b=20)) 79 80 def 
test_argument_checking(self): 81 self.assertRaises(TypeError, self.partial) # need at least a func arg 82 try: 83 self.partial(2)() 84 except TypeError: 85 pass 86 else: 87 self.fail('First arg not checked for callability') 88 89 def test_protection_of_callers_dict_argument(self): 90 # a caller's dictionary should not be altered by partial 91 def func(a=10, b=20): 92 return a 93 d = {'a':3} 94 p = self.partial(func, a=5) 95 self.assertEqual(p(**d), 3) 96 self.assertEqual(d, {'a':3}) 97 p(b=7) 98 self.assertEqual(d, {'a':3}) 99 100 def test_kwargs_copy(self): 101 # Issue #29532: Altering a kwarg dictionary passed to a constructor 102 # should not affect a partial object after creation 103 d = {'a': 3} 104 p = self.partial(capture, **d) 105 self.assertEqual(p(), ((), {'a': 3})) 106 d['a'] = 5 107 self.assertEqual(p(), ((), {'a': 3})) 108 109 def test_arg_combinations(self): 110 # exercise special code paths for zero args in either partial 111 # object or the caller 112 p = self.partial(capture) 113 self.assertEqual(p(), ((), {})) 114 self.assertEqual(p(1,2), ((1,2), {})) 115 p = self.partial(capture, 1, 2) 116 self.assertEqual(p(), ((1,2), {})) 117 self.assertEqual(p(3,4), ((1,2,3,4), {})) 118 119 def test_kw_combinations(self): 120 # exercise special code paths for no keyword args in 121 # either the partial object or the caller 122 p = self.partial(capture) 123 self.assertEqual(p.keywords, {}) 124 self.assertEqual(p(), ((), {})) 125 self.assertEqual(p(a=1), ((), {'a':1})) 126 p = self.partial(capture, a=1) 127 self.assertEqual(p.keywords, {'a':1}) 128 self.assertEqual(p(), ((), {'a':1})) 129 self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) 130 # keyword args in the call override those in the partial object 131 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) 132 133 def test_positional(self): 134 # make sure positional arguments are captured correctly 135 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: 136 p = self.partial(capture, *args) 137 expected = args + 
('x',) 138 got, empty = p('x') 139 self.assertTrue(expected == got and empty == {}) 140 141 def test_keyword(self): 142 # make sure keyword arguments are captured correctly 143 for a in ['a', 0, None, 3.5]: 144 p = self.partial(capture, a=a) 145 expected = {'a':a,'x':None} 146 empty, got = p(x=None) 147 self.assertTrue(expected == got and empty == ()) 148 149 def test_no_side_effects(self): 150 # make sure there are no side effects that affect subsequent calls 151 p = self.partial(capture, 0, a=1) 152 args1, kw1 = p(1, b=2) 153 self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) 154 args2, kw2 = p() 155 self.assertTrue(args2 == (0,) and kw2 == {'a':1}) 156 157 def test_error_propagation(self): 158 def f(x, y): 159 x / y 160 self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0)) 161 self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0) 162 self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0) 163 self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1) 164 165 def test_weakref(self): 166 f = self.partial(int, base=16) 167 p = proxy(f) 168 self.assertEqual(f.func, p.func) 169 f = None 170 support.gc_collect() # For PyPy or other GCs. 
171 self.assertRaises(ReferenceError, getattr, p, 'func') 172 173 def test_with_bound_and_unbound_methods(self): 174 data = list(map(str, range(10))) 175 join = self.partial(str.join, '') 176 self.assertEqual(join(data), '0123456789') 177 join = self.partial(''.join) 178 self.assertEqual(join(data), '0123456789') 179 180 def test_nested_optimization(self): 181 partial = self.partial 182 inner = partial(signature, 'asdf') 183 nested = partial(inner, bar=True) 184 flat = partial(signature, 'asdf', bar=True) 185 self.assertEqual(signature(nested), signature(flat)) 186 187 def test_nested_partial_with_attribute(self): 188 # see issue 25137 189 partial = self.partial 190 191 def foo(bar): 192 return bar 193 194 p = partial(foo, 'first') 195 p2 = partial(p, 'second') 196 p2.new_attr = 'spam' 197 self.assertEqual(p2.new_attr, 'spam') 198 199 def test_repr(self): 200 args = (object(), object()) 201 args_repr = ', '.join(repr(a) for a in args) 202 kwargs = {'a': object(), 'b': object()} 203 kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs), 204 'b={b!r}, a={a!r}'.format_map(kwargs)] 205 if self.partial in (c_functools.partial, py_functools.partial): 206 name = 'functools.partial' 207 else: 208 name = self.partial.__name__ 209 210 f = self.partial(capture) 211 self.assertEqual(f'{name}({capture!r})', repr(f)) 212 213 f = self.partial(capture, *args) 214 self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f)) 215 216 f = self.partial(capture, **kwargs) 217 self.assertIn(repr(f), 218 [f'{name}({capture!r}, {kwargs_repr})' 219 for kwargs_repr in kwargs_reprs]) 220 221 f = self.partial(capture, *args, **kwargs) 222 self.assertIn(repr(f), 223 [f'{name}({capture!r}, {args_repr}, {kwargs_repr})' 224 for kwargs_repr in kwargs_reprs]) 225 226 def test_recursive_repr(self): 227 if self.partial in (c_functools.partial, py_functools.partial): 228 name = 'functools.partial' 229 else: 230 name = self.partial.__name__ 231 232 f = self.partial(capture) 233 f.__setstate__((f, (), 
{}, {})) 234 try: 235 self.assertEqual(repr(f), '%s(...)' % (name,)) 236 finally: 237 f.__setstate__((capture, (), {}, {})) 238 239 f = self.partial(capture) 240 f.__setstate__((capture, (f,), {}, {})) 241 try: 242 self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,)) 243 finally: 244 f.__setstate__((capture, (), {}, {})) 245 246 f = self.partial(capture) 247 f.__setstate__((capture, (), {'a': f}, {})) 248 try: 249 self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,)) 250 finally: 251 f.__setstate__((capture, (), {}, {})) 252 253 def test_pickle(self): 254 with self.AllowPickle(): 255 f = self.partial(signature, ['asdf'], bar=[True]) 256 f.attr = [] 257 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 258 f_copy = pickle.loads(pickle.dumps(f, proto)) 259 self.assertEqual(signature(f_copy), signature(f)) 260 261 def test_copy(self): 262 f = self.partial(signature, ['asdf'], bar=[True]) 263 f.attr = [] 264 f_copy = copy.copy(f) 265 self.assertEqual(signature(f_copy), signature(f)) 266 self.assertIs(f_copy.attr, f.attr) 267 self.assertIs(f_copy.args, f.args) 268 self.assertIs(f_copy.keywords, f.keywords) 269 270 def test_deepcopy(self): 271 f = self.partial(signature, ['asdf'], bar=[True]) 272 f.attr = [] 273 f_copy = copy.deepcopy(f) 274 self.assertEqual(signature(f_copy), signature(f)) 275 self.assertIsNot(f_copy.attr, f.attr) 276 self.assertIsNot(f_copy.args, f.args) 277 self.assertIsNot(f_copy.args[0], f.args[0]) 278 self.assertIsNot(f_copy.keywords, f.keywords) 279 self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar']) 280 281 def test_setstate(self): 282 f = self.partial(signature) 283 f.__setstate__((capture, (1,), dict(a=10), dict(attr=[]))) 284 285 self.assertEqual(signature(f), 286 (capture, (1,), dict(a=10), dict(attr=[]))) 287 self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) 288 289 f.__setstate__((capture, (1,), dict(a=10), None)) 290 291 self.assertEqual(signature(f), (capture, (1,), dict(a=10), {})) 292 self.assertEqual(f(2, 
b=20), ((1, 2), {'a': 10, 'b': 20})) 293 294 f.__setstate__((capture, (1,), None, None)) 295 #self.assertEqual(signature(f), (capture, (1,), {}, {})) 296 self.assertEqual(f(2, b=20), ((1, 2), {'b': 20})) 297 self.assertEqual(f(2), ((1, 2), {})) 298 self.assertEqual(f(), ((1,), {})) 299 300 f.__setstate__((capture, (), {}, None)) 301 self.assertEqual(signature(f), (capture, (), {}, {})) 302 self.assertEqual(f(2, b=20), ((2,), {'b': 20})) 303 self.assertEqual(f(2), ((2,), {})) 304 self.assertEqual(f(), ((), {})) 305 306 def test_setstate_errors(self): 307 f = self.partial(signature) 308 self.assertRaises(TypeError, f.__setstate__, (capture, (), {})) 309 self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None)) 310 self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None]) 311 self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None)) 312 self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None)) 313 self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None)) 314 self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None)) 315 316 def test_setstate_subclasses(self): 317 f = self.partial(signature) 318 f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None)) 319 s = signature(f) 320 self.assertEqual(s, (capture, (1,), dict(a=10), {})) 321 self.assertIs(type(s[1]), tuple) 322 self.assertIs(type(s[2]), dict) 323 r = f() 324 self.assertEqual(r, ((1,), {'a': 10})) 325 self.assertIs(type(r[0]), tuple) 326 self.assertIs(type(r[1]), dict) 327 328 f.__setstate__((capture, BadTuple((1,)), {}, None)) 329 s = signature(f) 330 self.assertEqual(s, (capture, (1,), {}, {})) 331 self.assertIs(type(s[1]), tuple) 332 r = f(2) 333 self.assertEqual(r, ((1, 2), {})) 334 self.assertIs(type(r[0]), tuple) 335 336 def test_recursive_pickle(self): 337 with self.AllowPickle(): 338 f = self.partial(capture) 339 f.__setstate__((f, (), {}, {})) 340 try: 341 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 342 with 
self.assertRaises(RecursionError): 343 pickle.dumps(f, proto) 344 finally: 345 f.__setstate__((capture, (), {}, {})) 346 347 f = self.partial(capture) 348 f.__setstate__((capture, (f,), {}, {})) 349 try: 350 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 351 f_copy = pickle.loads(pickle.dumps(f, proto)) 352 try: 353 self.assertIs(f_copy.args[0], f_copy) 354 finally: 355 f_copy.__setstate__((capture, (), {}, {})) 356 finally: 357 f.__setstate__((capture, (), {}, {})) 358 359 f = self.partial(capture) 360 f.__setstate__((capture, (), {'a': f}, {})) 361 try: 362 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 363 f_copy = pickle.loads(pickle.dumps(f, proto)) 364 try: 365 self.assertIs(f_copy.keywords['a'], f_copy) 366 finally: 367 f_copy.__setstate__((capture, (), {}, {})) 368 finally: 369 f.__setstate__((capture, (), {}, {})) 370 371 # Issue 6083: Reference counting bug 372 def test_setstate_refcount(self): 373 class BadSequence: 374 def __len__(self): 375 return 4 376 def __getitem__(self, key): 377 if key == 0: 378 return max 379 elif key == 1: 380 return tuple(range(1000000)) 381 elif key in (2, 3): 382 return {} 383 raise IndexError 384 385 f = self.partial(object) 386 self.assertRaises(TypeError, f.__setstate__, BadSequence()) 387 388@unittest.skipUnless(c_functools, 'requires the C _functools module') 389class TestPartialC(TestPartial, unittest.TestCase): 390 if c_functools: 391 partial = c_functools.partial 392 393 class AllowPickle: 394 def __enter__(self): 395 return self 396 def __exit__(self, type, value, tb): 397 return False 398 399 def test_attributes_unwritable(self): 400 # attributes should not be writable 401 p = self.partial(capture, 1, 2, a=10, b=20) 402 self.assertRaises(AttributeError, setattr, p, 'func', map) 403 self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) 404 self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) 405 406 p = self.partial(hex) 407 try: 408 del p.__dict__ 409 except TypeError: 410 pass 411 
else: 412 self.fail('partial object allowed __dict__ to be deleted') 413 414 def test_manually_adding_non_string_keyword(self): 415 p = self.partial(capture) 416 # Adding a non-string/unicode keyword to partial kwargs 417 p.keywords[1234] = 'value' 418 r = repr(p) 419 self.assertIn('1234', r) 420 self.assertIn("'value'", r) 421 with self.assertRaises(TypeError): 422 p() 423 424 def test_keystr_replaces_value(self): 425 p = self.partial(capture) 426 427 class MutatesYourDict(object): 428 def __str__(self): 429 p.keywords[self] = ['sth2'] 430 return 'astr' 431 432 # Replacing the value during key formatting should keep the original 433 # value alive (at least long enough). 434 p.keywords[MutatesYourDict()] = ['sth'] 435 r = repr(p) 436 self.assertIn('astr', r) 437 self.assertIn("['sth']", r) 438 439 440class TestPartialPy(TestPartial, unittest.TestCase): 441 partial = py_functools.partial 442 443 class AllowPickle: 444 def __init__(self): 445 self._cm = replaced_module("functools", py_functools) 446 def __enter__(self): 447 return self._cm.__enter__() 448 def __exit__(self, type, value, tb): 449 return self._cm.__exit__(type, value, tb) 450 451if c_functools: 452 class CPartialSubclass(c_functools.partial): 453 pass 454 455class PyPartialSubclass(py_functools.partial): 456 pass 457 458@unittest.skipUnless(c_functools, 'requires the C _functools module') 459class TestPartialCSubclass(TestPartialC): 460 if c_functools: 461 partial = CPartialSubclass 462 463 # partial subclasses are not optimized for nested calls 464 test_nested_optimization = None 465 466class TestPartialPySubclass(TestPartialPy): 467 partial = PyPartialSubclass 468 469class TestPartialMethod(unittest.TestCase): 470 471 class A(object): 472 nothing = functools.partialmethod(capture) 473 positional = functools.partialmethod(capture, 1) 474 keywords = functools.partialmethod(capture, a=2) 475 both = functools.partialmethod(capture, 3, b=4) 476 spec_keywords = functools.partialmethod(capture, self=1, 
func=2) 477 478 nested = functools.partialmethod(positional, 5) 479 480 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7) 481 482 static = functools.partialmethod(staticmethod(capture), 8) 483 cls = functools.partialmethod(classmethod(capture), d=9) 484 485 a = A() 486 487 def test_arg_combinations(self): 488 self.assertEqual(self.a.nothing(), ((self.a,), {})) 489 self.assertEqual(self.a.nothing(5), ((self.a, 5), {})) 490 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6})) 491 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6})) 492 493 self.assertEqual(self.a.positional(), ((self.a, 1), {})) 494 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {})) 495 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6})) 496 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6})) 497 498 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2})) 499 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2})) 500 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6})) 501 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6})) 502 503 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4})) 504 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4})) 505 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6})) 506 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) 507 508 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) 509 510 self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2})) 511 512 def test_nested(self): 513 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {})) 514 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {})) 515 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7})) 516 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7})) 517 518 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7})) 519 520 def 
test_over_partial(self): 521 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6})) 522 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6})) 523 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8})) 524 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) 525 526 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) 527 528 def test_bound_method_introspection(self): 529 obj = self.a 530 self.assertIs(obj.both.__self__, obj) 531 self.assertIs(obj.nested.__self__, obj) 532 self.assertIs(obj.over_partial.__self__, obj) 533 self.assertIs(obj.cls.__self__, self.A) 534 self.assertIs(self.A.cls.__self__, self.A) 535 536 def test_unbound_method_retrieval(self): 537 obj = self.A 538 self.assertFalse(hasattr(obj.both, "__self__")) 539 self.assertFalse(hasattr(obj.nested, "__self__")) 540 self.assertFalse(hasattr(obj.over_partial, "__self__")) 541 self.assertFalse(hasattr(obj.static, "__self__")) 542 self.assertFalse(hasattr(self.a.static, "__self__")) 543 544 def test_descriptors(self): 545 for obj in [self.A, self.a]: 546 with self.subTest(obj=obj): 547 self.assertEqual(obj.static(), ((8,), {})) 548 self.assertEqual(obj.static(5), ((8, 5), {})) 549 self.assertEqual(obj.static(d=8), ((8,), {'d': 8})) 550 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8})) 551 552 self.assertEqual(obj.cls(), ((self.A,), {'d': 9})) 553 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9})) 554 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9})) 555 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9})) 556 557 def test_overriding_keywords(self): 558 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3})) 559 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3})) 560 561 def test_invalid_args(self): 562 with self.assertRaises(TypeError): 563 class B(object): 564 method = functools.partialmethod(None, 1) 565 with 
self.assertRaises(TypeError): 566 class B: 567 method = functools.partialmethod() 568 with self.assertRaises(TypeError): 569 class B: 570 method = functools.partialmethod(func=capture, a=1) 571 572 def test_repr(self): 573 self.assertEqual(repr(vars(self.A)['both']), 574 'functools.partialmethod({}, 3, b=4)'.format(capture)) 575 576 def test_abstract(self): 577 class Abstract(abc.ABCMeta): 578 579 @abc.abstractmethod 580 def add(self, x, y): 581 pass 582 583 add5 = functools.partialmethod(add, 5) 584 585 self.assertTrue(Abstract.add.__isabstractmethod__) 586 self.assertTrue(Abstract.add5.__isabstractmethod__) 587 588 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]: 589 self.assertFalse(getattr(func, '__isabstractmethod__', False)) 590 591 def test_positional_only(self): 592 def f(a, b, /): 593 return a + b 594 595 p = functools.partial(f, 1) 596 self.assertEqual(p(2), f(1, 2)) 597 598 599class TestUpdateWrapper(unittest.TestCase): 600 601 def check_wrapper(self, wrapper, wrapped, 602 assigned=functools.WRAPPER_ASSIGNMENTS, 603 updated=functools.WRAPPER_UPDATES): 604 # Check attributes were assigned 605 for name in assigned: 606 self.assertIs(getattr(wrapper, name), getattr(wrapped, name)) 607 # Check attributes were updated 608 for name in updated: 609 wrapper_attr = getattr(wrapper, name) 610 wrapped_attr = getattr(wrapped, name) 611 for key in wrapped_attr: 612 if name == "__dict__" and key == "__wrapped__": 613 # __wrapped__ is overwritten by the update code 614 continue 615 self.assertIs(wrapped_attr[key], wrapper_attr[key]) 616 # Check __wrapped__ 617 self.assertIs(wrapper.__wrapped__, wrapped) 618 619 620 def _default_update(self): 621 def f(a:'This is a new annotation'): 622 """This is a test""" 623 pass 624 f.attr = 'This is also a test' 625 f.__wrapped__ = "This is a bald faced lie" 626 def wrapper(b:'This is the prior annotation'): 627 pass 628 functools.update_wrapper(wrapper, f) 629 return wrapper, f 630 631 def 
test_default_update(self): 632 wrapper, f = self._default_update() 633 self.check_wrapper(wrapper, f) 634 self.assertIs(wrapper.__wrapped__, f) 635 self.assertEqual(wrapper.__name__, 'f') 636 self.assertEqual(wrapper.__qualname__, f.__qualname__) 637 self.assertEqual(wrapper.attr, 'This is also a test') 638 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') 639 self.assertNotIn('b', wrapper.__annotations__) 640 641 @unittest.skipIf(sys.flags.optimize >= 2, 642 "Docstrings are omitted with -O2 and above") 643 def test_default_update_doc(self): 644 wrapper, f = self._default_update() 645 self.assertEqual(wrapper.__doc__, 'This is a test') 646 647 def test_no_update(self): 648 def f(): 649 """This is a test""" 650 pass 651 f.attr = 'This is also a test' 652 def wrapper(): 653 pass 654 functools.update_wrapper(wrapper, f, (), ()) 655 self.check_wrapper(wrapper, f, (), ()) 656 self.assertEqual(wrapper.__name__, 'wrapper') 657 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 658 self.assertEqual(wrapper.__doc__, None) 659 self.assertEqual(wrapper.__annotations__, {}) 660 self.assertFalse(hasattr(wrapper, 'attr')) 661 662 def test_selective_update(self): 663 def f(): 664 pass 665 f.attr = 'This is a different test' 666 f.dict_attr = dict(a=1, b=2, c=3) 667 def wrapper(): 668 pass 669 wrapper.dict_attr = {} 670 assign = ('attr',) 671 update = ('dict_attr',) 672 functools.update_wrapper(wrapper, f, assign, update) 673 self.check_wrapper(wrapper, f, assign, update) 674 self.assertEqual(wrapper.__name__, 'wrapper') 675 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 676 self.assertEqual(wrapper.__doc__, None) 677 self.assertEqual(wrapper.attr, 'This is a different test') 678 self.assertEqual(wrapper.dict_attr, f.dict_attr) 679 680 def test_missing_attributes(self): 681 def f(): 682 pass 683 def wrapper(): 684 pass 685 wrapper.dict_attr = {} 686 assign = ('attr',) 687 update = ('dict_attr',) 688 # Missing attributes on wrapped object 
are ignored 689 functools.update_wrapper(wrapper, f, assign, update) 690 self.assertNotIn('attr', wrapper.__dict__) 691 self.assertEqual(wrapper.dict_attr, {}) 692 # Wrapper must have expected attributes for updating 693 del wrapper.dict_attr 694 with self.assertRaises(AttributeError): 695 functools.update_wrapper(wrapper, f, assign, update) 696 wrapper.dict_attr = 1 697 with self.assertRaises(AttributeError): 698 functools.update_wrapper(wrapper, f, assign, update) 699 700 @support.requires_docstrings 701 @unittest.skipIf(sys.flags.optimize >= 2, 702 "Docstrings are omitted with -O2 and above") 703 def test_builtin_update(self): 704 # Test for bug #1576241 705 def wrapper(): 706 pass 707 functools.update_wrapper(wrapper, max) 708 self.assertEqual(wrapper.__name__, 'max') 709 self.assertTrue(wrapper.__doc__.startswith('max(')) 710 self.assertEqual(wrapper.__annotations__, {}) 711 712 713class TestWraps(TestUpdateWrapper): 714 715 def _default_update(self): 716 def f(): 717 """This is a test""" 718 pass 719 f.attr = 'This is also a test' 720 f.__wrapped__ = "This is still a bald faced lie" 721 @functools.wraps(f) 722 def wrapper(): 723 pass 724 return wrapper, f 725 726 def test_default_update(self): 727 wrapper, f = self._default_update() 728 self.check_wrapper(wrapper, f) 729 self.assertEqual(wrapper.__name__, 'f') 730 self.assertEqual(wrapper.__qualname__, f.__qualname__) 731 self.assertEqual(wrapper.attr, 'This is also a test') 732 733 @unittest.skipIf(sys.flags.optimize >= 2, 734 "Docstrings are omitted with -O2 and above") 735 def test_default_update_doc(self): 736 wrapper, _ = self._default_update() 737 self.assertEqual(wrapper.__doc__, 'This is a test') 738 739 def test_no_update(self): 740 def f(): 741 """This is a test""" 742 pass 743 f.attr = 'This is also a test' 744 @functools.wraps(f, (), ()) 745 def wrapper(): 746 pass 747 self.check_wrapper(wrapper, f, (), ()) 748 self.assertEqual(wrapper.__name__, 'wrapper') 749 
self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 750 self.assertEqual(wrapper.__doc__, None) 751 self.assertFalse(hasattr(wrapper, 'attr')) 752 753 def test_selective_update(self): 754 def f(): 755 pass 756 f.attr = 'This is a different test' 757 f.dict_attr = dict(a=1, b=2, c=3) 758 def add_dict_attr(f): 759 f.dict_attr = {} 760 return f 761 assign = ('attr',) 762 update = ('dict_attr',) 763 @functools.wraps(f, assign, update) 764 @add_dict_attr 765 def wrapper(): 766 pass 767 self.check_wrapper(wrapper, f, assign, update) 768 self.assertEqual(wrapper.__name__, 'wrapper') 769 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 770 self.assertEqual(wrapper.__doc__, None) 771 self.assertEqual(wrapper.attr, 'This is a different test') 772 self.assertEqual(wrapper.dict_attr, f.dict_attr) 773 774 775class TestReduce: 776 def test_reduce(self): 777 class Squares: 778 def __init__(self, max): 779 self.max = max 780 self.sofar = [] 781 782 def __len__(self): 783 return len(self.sofar) 784 785 def __getitem__(self, i): 786 if not 0 <= i < self.max: raise IndexError 787 n = len(self.sofar) 788 while n <= i: 789 self.sofar.append(n*n) 790 n += 1 791 return self.sofar[i] 792 def add(x, y): 793 return x + y 794 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc') 795 self.assertEqual( 796 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []), 797 ['a','c','d','w'] 798 ) 799 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040) 800 self.assertEqual( 801 self.reduce(lambda x, y: x*y, range(2,21), 1), 802 2432902008176640000 803 ) 804 self.assertEqual(self.reduce(add, Squares(10)), 285) 805 self.assertEqual(self.reduce(add, Squares(10), 0), 285) 806 self.assertEqual(self.reduce(add, Squares(0), 0), 0) 807 self.assertRaises(TypeError, self.reduce) 808 self.assertRaises(TypeError, self.reduce, 42, 42) 809 self.assertRaises(TypeError, self.reduce, 42, 42, 42) 810 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item 811 
self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item 812 self.assertRaises(TypeError, self.reduce, 42, (42, 42)) 813 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value 814 self.assertRaises(TypeError, self.reduce, add, "") 815 self.assertRaises(TypeError, self.reduce, add, ()) 816 self.assertRaises(TypeError, self.reduce, add, object()) 817 818 class TestFailingIter: 819 def __iter__(self): 820 raise RuntimeError 821 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter()) 822 823 self.assertEqual(self.reduce(add, [], None), None) 824 self.assertEqual(self.reduce(add, [], 42), 42) 825 826 class BadSeq: 827 def __getitem__(self, index): 828 raise ValueError 829 self.assertRaises(ValueError, self.reduce, 42, BadSeq()) 830 831 # Test reduce()'s use of iterators. 832 def test_iterator_usage(self): 833 class SequenceClass: 834 def __init__(self, n): 835 self.n = n 836 def __getitem__(self, i): 837 if 0 <= i < self.n: 838 return i 839 else: 840 raise IndexError 841 842 from operator import add 843 self.assertEqual(self.reduce(add, SequenceClass(5)), 10) 844 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52) 845 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0)) 846 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42) 847 self.assertEqual(self.reduce(add, SequenceClass(1)), 0) 848 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42) 849 850 d = {"one": 1, "two": 2, "three": 3} 851 self.assertEqual(self.reduce(add, d), "".join(d.keys())) 852 853 854@unittest.skipUnless(c_functools, 'requires the C _functools module') 855class TestReduceC(TestReduce, unittest.TestCase): 856 if c_functools: 857 reduce = c_functools.reduce 858 859 860class TestReducePy(TestReduce, unittest.TestCase): 861 reduce = staticmethod(py_functools.reduce) 862 863 864class TestCmpToKey: 865 866 def test_cmp_to_key(self): 867 def cmp1(x, y): 868 return (x > y) - (x 
< y) 869 key = self.cmp_to_key(cmp1) 870 self.assertEqual(key(3), key(3)) 871 self.assertGreater(key(3), key(1)) 872 self.assertGreaterEqual(key(3), key(3)) 873 874 def cmp2(x, y): 875 return int(x) - int(y) 876 key = self.cmp_to_key(cmp2) 877 self.assertEqual(key(4.0), key('4')) 878 self.assertLess(key(2), key('35')) 879 self.assertLessEqual(key(2), key('35')) 880 self.assertNotEqual(key(2), key('35')) 881 882 def test_cmp_to_key_arguments(self): 883 def cmp1(x, y): 884 return (x > y) - (x < y) 885 key = self.cmp_to_key(mycmp=cmp1) 886 self.assertEqual(key(obj=3), key(obj=3)) 887 self.assertGreater(key(obj=3), key(obj=1)) 888 with self.assertRaises((TypeError, AttributeError)): 889 key(3) > 1 # rhs is not a K object 890 with self.assertRaises((TypeError, AttributeError)): 891 1 < key(3) # lhs is not a K object 892 with self.assertRaises(TypeError): 893 key = self.cmp_to_key() # too few args 894 with self.assertRaises(TypeError): 895 key = self.cmp_to_key(cmp1, None) # too many args 896 key = self.cmp_to_key(cmp1) 897 with self.assertRaises(TypeError): 898 key() # too few args 899 with self.assertRaises(TypeError): 900 key(None, None) # too many args 901 902 def test_bad_cmp(self): 903 def cmp1(x, y): 904 raise ZeroDivisionError 905 key = self.cmp_to_key(cmp1) 906 with self.assertRaises(ZeroDivisionError): 907 key(3) > key(1) 908 909 class BadCmp: 910 def __lt__(self, other): 911 raise ZeroDivisionError 912 def cmp1(x, y): 913 return BadCmp() 914 with self.assertRaises(ZeroDivisionError): 915 key(3) > key(1) 916 917 def test_obj_field(self): 918 def cmp1(x, y): 919 return (x > y) - (x < y) 920 key = self.cmp_to_key(mycmp=cmp1) 921 self.assertEqual(key(50).obj, 50) 922 923 def test_sort_int(self): 924 def mycmp(x, y): 925 return y - x 926 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)), 927 [4, 3, 2, 1, 0]) 928 929 def test_sort_int_str(self): 930 def mycmp(x, y): 931 x, y = int(x), int(y) 932 return (x > y) - (x < y) 933 values = [5, '3', 7, 2, '0', 
'1', 4, '10', 1] 934 values = sorted(values, key=self.cmp_to_key(mycmp)) 935 self.assertEqual([int(value) for value in values], 936 [0, 1, 1, 2, 3, 4, 5, 7, 10]) 937 938 def test_hash(self): 939 def mycmp(x, y): 940 return y - x 941 key = self.cmp_to_key(mycmp) 942 k = key(10) 943 self.assertRaises(TypeError, hash, k) 944 self.assertNotIsInstance(k, collections.abc.Hashable) 945 946 947@unittest.skipUnless(c_functools, 'requires the C _functools module') 948class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): 949 if c_functools: 950 cmp_to_key = c_functools.cmp_to_key 951 952 @support.cpython_only 953 def test_disallow_instantiation(self): 954 # Ensure that the type disallows instantiation (bpo-43916) 955 support.check_disallow_instantiation( 956 self, type(c_functools.cmp_to_key(None)) 957 ) 958 959 960class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase): 961 cmp_to_key = staticmethod(py_functools.cmp_to_key) 962 963 964class TestTotalOrdering(unittest.TestCase): 965 966 def test_total_ordering_lt(self): 967 @functools.total_ordering 968 class A: 969 def __init__(self, value): 970 self.value = value 971 def __lt__(self, other): 972 return self.value < other.value 973 def __eq__(self, other): 974 return self.value == other.value 975 self.assertTrue(A(1) < A(2)) 976 self.assertTrue(A(2) > A(1)) 977 self.assertTrue(A(1) <= A(2)) 978 self.assertTrue(A(2) >= A(1)) 979 self.assertTrue(A(2) <= A(2)) 980 self.assertTrue(A(2) >= A(2)) 981 self.assertFalse(A(1) > A(2)) 982 983 def test_total_ordering_le(self): 984 @functools.total_ordering 985 class A: 986 def __init__(self, value): 987 self.value = value 988 def __le__(self, other): 989 return self.value <= other.value 990 def __eq__(self, other): 991 return self.value == other.value 992 self.assertTrue(A(1) < A(2)) 993 self.assertTrue(A(2) > A(1)) 994 self.assertTrue(A(1) <= A(2)) 995 self.assertTrue(A(2) >= A(1)) 996 self.assertTrue(A(2) <= A(2)) 997 self.assertTrue(A(2) >= A(2)) 998 self.assertFalse(A(1) >= A(2)) 999 
1000 def test_total_ordering_gt(self): 1001 @functools.total_ordering 1002 class A: 1003 def __init__(self, value): 1004 self.value = value 1005 def __gt__(self, other): 1006 return self.value > other.value 1007 def __eq__(self, other): 1008 return self.value == other.value 1009 self.assertTrue(A(1) < A(2)) 1010 self.assertTrue(A(2) > A(1)) 1011 self.assertTrue(A(1) <= A(2)) 1012 self.assertTrue(A(2) >= A(1)) 1013 self.assertTrue(A(2) <= A(2)) 1014 self.assertTrue(A(2) >= A(2)) 1015 self.assertFalse(A(2) < A(1)) 1016 1017 def test_total_ordering_ge(self): 1018 @functools.total_ordering 1019 class A: 1020 def __init__(self, value): 1021 self.value = value 1022 def __ge__(self, other): 1023 return self.value >= other.value 1024 def __eq__(self, other): 1025 return self.value == other.value 1026 self.assertTrue(A(1) < A(2)) 1027 self.assertTrue(A(2) > A(1)) 1028 self.assertTrue(A(1) <= A(2)) 1029 self.assertTrue(A(2) >= A(1)) 1030 self.assertTrue(A(2) <= A(2)) 1031 self.assertTrue(A(2) >= A(2)) 1032 self.assertFalse(A(2) <= A(1)) 1033 1034 def test_total_ordering_no_overwrite(self): 1035 # new methods should not overwrite existing 1036 @functools.total_ordering 1037 class A(int): 1038 pass 1039 self.assertTrue(A(1) < A(2)) 1040 self.assertTrue(A(2) > A(1)) 1041 self.assertTrue(A(1) <= A(2)) 1042 self.assertTrue(A(2) >= A(1)) 1043 self.assertTrue(A(2) <= A(2)) 1044 self.assertTrue(A(2) >= A(2)) 1045 1046 def test_no_operations_defined(self): 1047 with self.assertRaises(ValueError): 1048 @functools.total_ordering 1049 class A: 1050 pass 1051 1052 def test_type_error_when_not_implemented(self): 1053 # bug 10042; ensure stack overflow does not occur 1054 # when decorated types return NotImplemented 1055 @functools.total_ordering 1056 class ImplementsLessThan: 1057 def __init__(self, value): 1058 self.value = value 1059 def __eq__(self, other): 1060 if isinstance(other, ImplementsLessThan): 1061 return self.value == other.value 1062 return False 1063 def __lt__(self, 
other): 1064 if isinstance(other, ImplementsLessThan): 1065 return self.value < other.value 1066 return NotImplemented 1067 1068 @functools.total_ordering 1069 class ImplementsGreaterThan: 1070 def __init__(self, value): 1071 self.value = value 1072 def __eq__(self, other): 1073 if isinstance(other, ImplementsGreaterThan): 1074 return self.value == other.value 1075 return False 1076 def __gt__(self, other): 1077 if isinstance(other, ImplementsGreaterThan): 1078 return self.value > other.value 1079 return NotImplemented 1080 1081 @functools.total_ordering 1082 class ImplementsLessThanEqualTo: 1083 def __init__(self, value): 1084 self.value = value 1085 def __eq__(self, other): 1086 if isinstance(other, ImplementsLessThanEqualTo): 1087 return self.value == other.value 1088 return False 1089 def __le__(self, other): 1090 if isinstance(other, ImplementsLessThanEqualTo): 1091 return self.value <= other.value 1092 return NotImplemented 1093 1094 @functools.total_ordering 1095 class ImplementsGreaterThanEqualTo: 1096 def __init__(self, value): 1097 self.value = value 1098 def __eq__(self, other): 1099 if isinstance(other, ImplementsGreaterThanEqualTo): 1100 return self.value == other.value 1101 return False 1102 def __ge__(self, other): 1103 if isinstance(other, ImplementsGreaterThanEqualTo): 1104 return self.value >= other.value 1105 return NotImplemented 1106 1107 @functools.total_ordering 1108 class ComparatorNotImplemented: 1109 def __init__(self, value): 1110 self.value = value 1111 def __eq__(self, other): 1112 if isinstance(other, ComparatorNotImplemented): 1113 return self.value == other.value 1114 return False 1115 def __lt__(self, other): 1116 return NotImplemented 1117 1118 with self.subTest("LT < 1"), self.assertRaises(TypeError): 1119 ImplementsLessThan(-1) < 1 1120 1121 with self.subTest("LT < LE"), self.assertRaises(TypeError): 1122 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) 1123 1124 with self.subTest("LT < GT"), self.assertRaises(TypeError): 
1125 ImplementsLessThan(1) < ImplementsGreaterThan(1) 1126 1127 with self.subTest("LE <= LT"), self.assertRaises(TypeError): 1128 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) 1129 1130 with self.subTest("LE <= GE"), self.assertRaises(TypeError): 1131 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) 1132 1133 with self.subTest("GT > GE"), self.assertRaises(TypeError): 1134 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) 1135 1136 with self.subTest("GT > LT"), self.assertRaises(TypeError): 1137 ImplementsGreaterThan(5) > ImplementsLessThan(5) 1138 1139 with self.subTest("GE >= GT"), self.assertRaises(TypeError): 1140 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) 1141 1142 with self.subTest("GE >= LE"), self.assertRaises(TypeError): 1143 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) 1144 1145 with self.subTest("GE when equal"): 1146 a = ComparatorNotImplemented(8) 1147 b = ComparatorNotImplemented(8) 1148 self.assertEqual(a, b) 1149 with self.assertRaises(TypeError): 1150 a >= b 1151 1152 with self.subTest("LE when equal"): 1153 a = ComparatorNotImplemented(9) 1154 b = ComparatorNotImplemented(9) 1155 self.assertEqual(a, b) 1156 with self.assertRaises(TypeError): 1157 a <= b 1158 1159 def test_pickle(self): 1160 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1161 for name in '__lt__', '__gt__', '__le__', '__ge__': 1162 with self.subTest(method=name, proto=proto): 1163 method = getattr(Orderable_LT, name) 1164 method_copy = pickle.loads(pickle.dumps(method, proto)) 1165 self.assertIs(method_copy, method) 1166 1167 1168 def test_total_ordering_for_metaclasses_issue_44605(self): 1169 1170 @functools.total_ordering 1171 class SortableMeta(type): 1172 def __new__(cls, name, bases, ns): 1173 return super().__new__(cls, name, bases, ns) 1174 1175 def __lt__(self, other): 1176 if not isinstance(other, SortableMeta): 1177 pass 1178 return self.__name__ < other.__name__ 1179 1180 def __eq__(self, other): 1181 
if not isinstance(other, SortableMeta): 1182 pass 1183 return self.__name__ == other.__name__ 1184 1185 class B(metaclass=SortableMeta): 1186 pass 1187 1188 class A(metaclass=SortableMeta): 1189 pass 1190 1191 self.assertTrue(A < B) 1192 self.assertFalse(A > B) 1193 1194 1195@functools.total_ordering 1196class Orderable_LT: 1197 def __init__(self, value): 1198 self.value = value 1199 def __lt__(self, other): 1200 return self.value < other.value 1201 def __eq__(self, other): 1202 return self.value == other.value 1203 1204 1205class TestCache: 1206 # This tests that the pass-through is working as designed. 1207 # The underlying functionality is tested in TestLRU. 1208 1209 def test_cache(self): 1210 @self.module.cache 1211 def fib(n): 1212 if n < 2: 1213 return n 1214 return fib(n-1) + fib(n-2) 1215 self.assertEqual([fib(n) for n in range(16)], 1216 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1217 self.assertEqual(fib.cache_info(), 1218 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1219 fib.cache_clear() 1220 self.assertEqual(fib.cache_info(), 1221 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1222 1223 1224class TestLRU: 1225 1226 def test_lru(self): 1227 def orig(x, y): 1228 return 3 * x + y 1229 f = self.module.lru_cache(maxsize=20)(orig) 1230 hits, misses, maxsize, currsize = f.cache_info() 1231 self.assertEqual(maxsize, 20) 1232 self.assertEqual(currsize, 0) 1233 self.assertEqual(hits, 0) 1234 self.assertEqual(misses, 0) 1235 1236 domain = range(5) 1237 for i in range(1000): 1238 x, y = choice(domain), choice(domain) 1239 actual = f(x, y) 1240 expected = orig(x, y) 1241 self.assertEqual(actual, expected) 1242 hits, misses, maxsize, currsize = f.cache_info() 1243 self.assertTrue(hits > misses) 1244 self.assertEqual(hits + misses, 1000) 1245 self.assertEqual(currsize, 20) 1246 1247 f.cache_clear() # test clearing 1248 hits, misses, maxsize, currsize = f.cache_info() 1249 self.assertEqual(hits, 0) 
1250 self.assertEqual(misses, 0) 1251 self.assertEqual(currsize, 0) 1252 f(x, y) 1253 hits, misses, maxsize, currsize = f.cache_info() 1254 self.assertEqual(hits, 0) 1255 self.assertEqual(misses, 1) 1256 self.assertEqual(currsize, 1) 1257 1258 # Test bypassing the cache 1259 self.assertIs(f.__wrapped__, orig) 1260 f.__wrapped__(x, y) 1261 hits, misses, maxsize, currsize = f.cache_info() 1262 self.assertEqual(hits, 0) 1263 self.assertEqual(misses, 1) 1264 self.assertEqual(currsize, 1) 1265 1266 # test size zero (which means "never-cache") 1267 @self.module.lru_cache(0) 1268 def f(): 1269 nonlocal f_cnt 1270 f_cnt += 1 1271 return 20 1272 self.assertEqual(f.cache_info().maxsize, 0) 1273 f_cnt = 0 1274 for i in range(5): 1275 self.assertEqual(f(), 20) 1276 self.assertEqual(f_cnt, 5) 1277 hits, misses, maxsize, currsize = f.cache_info() 1278 self.assertEqual(hits, 0) 1279 self.assertEqual(misses, 5) 1280 self.assertEqual(currsize, 0) 1281 1282 # test size one 1283 @self.module.lru_cache(1) 1284 def f(): 1285 nonlocal f_cnt 1286 f_cnt += 1 1287 return 20 1288 self.assertEqual(f.cache_info().maxsize, 1) 1289 f_cnt = 0 1290 for i in range(5): 1291 self.assertEqual(f(), 20) 1292 self.assertEqual(f_cnt, 1) 1293 hits, misses, maxsize, currsize = f.cache_info() 1294 self.assertEqual(hits, 4) 1295 self.assertEqual(misses, 1) 1296 self.assertEqual(currsize, 1) 1297 1298 # test size two 1299 @self.module.lru_cache(2) 1300 def f(x): 1301 nonlocal f_cnt 1302 f_cnt += 1 1303 return x*10 1304 self.assertEqual(f.cache_info().maxsize, 2) 1305 f_cnt = 0 1306 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: 1307 # * * * * 1308 self.assertEqual(f(x), x*10) 1309 self.assertEqual(f_cnt, 4) 1310 hits, misses, maxsize, currsize = f.cache_info() 1311 self.assertEqual(hits, 12) 1312 self.assertEqual(misses, 4) 1313 self.assertEqual(currsize, 2) 1314 1315 def test_lru_no_args(self): 1316 @self.module.lru_cache 1317 def square(x): 1318 return x ** 2 1319 1320 
self.assertEqual(list(map(square, [10, 20, 10])), 1321 [100, 400, 100]) 1322 self.assertEqual(square.cache_info().hits, 1) 1323 self.assertEqual(square.cache_info().misses, 2) 1324 self.assertEqual(square.cache_info().maxsize, 128) 1325 self.assertEqual(square.cache_info().currsize, 2) 1326 1327 def test_lru_bug_35780(self): 1328 # C version of the lru_cache was not checking to see if 1329 # the user function call has already modified the cache 1330 # (this arises in recursive calls and in multi-threading). 1331 # This cause the cache to have orphan links not referenced 1332 # by the cache dictionary. 1333 1334 once = True # Modified by f(x) below 1335 1336 @self.module.lru_cache(maxsize=10) 1337 def f(x): 1338 nonlocal once 1339 rv = f'.{x}.' 1340 if x == 20 and once: 1341 once = False 1342 rv = f(x) 1343 return rv 1344 1345 # Fill the cache 1346 for x in range(15): 1347 self.assertEqual(f(x), f'.{x}.') 1348 self.assertEqual(f.cache_info().currsize, 10) 1349 1350 # Make a recursive call and make sure the cache remains full 1351 self.assertEqual(f(20), '.20.') 1352 self.assertEqual(f.cache_info().currsize, 10) 1353 1354 def test_lru_bug_36650(self): 1355 # C version of lru_cache was treating a call with an empty **kwargs 1356 # dictionary as being distinct from a call with no keywords at all. 1357 # This did not result in an incorrect answer, but it did trigger 1358 # an unexpected cache miss. 1359 1360 @self.module.lru_cache() 1361 def f(x): 1362 pass 1363 1364 f(0) 1365 f(0, **{}) 1366 self.assertEqual(f.cache_info().hits, 1) 1367 1368 def test_lru_hash_only_once(self): 1369 # To protect against weird reentrancy bugs and to improve 1370 # efficiency when faced with slow __hash__ methods, the 1371 # LRU cache guarantees that it will only call __hash__ 1372 # only once per use as an argument to the cached function. 
1373 1374 @self.module.lru_cache(maxsize=1) 1375 def f(x, y): 1376 return x * 3 + y 1377 1378 # Simulate the integer 5 1379 mock_int = unittest.mock.Mock() 1380 mock_int.__mul__ = unittest.mock.Mock(return_value=15) 1381 mock_int.__hash__ = unittest.mock.Mock(return_value=999) 1382 1383 # Add to cache: One use as an argument gives one call 1384 self.assertEqual(f(mock_int, 1), 16) 1385 self.assertEqual(mock_int.__hash__.call_count, 1) 1386 self.assertEqual(f.cache_info(), (0, 1, 1, 1)) 1387 1388 # Cache hit: One use as an argument gives one additional call 1389 self.assertEqual(f(mock_int, 1), 16) 1390 self.assertEqual(mock_int.__hash__.call_count, 2) 1391 self.assertEqual(f.cache_info(), (1, 1, 1, 1)) 1392 1393 # Cache eviction: No use as an argument gives no additional call 1394 self.assertEqual(f(6, 2), 20) 1395 self.assertEqual(mock_int.__hash__.call_count, 2) 1396 self.assertEqual(f.cache_info(), (1, 2, 1, 1)) 1397 1398 # Cache miss: One use as an argument gives one additional call 1399 self.assertEqual(f(mock_int, 1), 16) 1400 self.assertEqual(mock_int.__hash__.call_count, 3) 1401 self.assertEqual(f.cache_info(), (1, 3, 1, 1)) 1402 1403 def test_lru_reentrancy_with_len(self): 1404 # Test to make sure the LRU cache code isn't thrown-off by 1405 # caching the built-in len() function. Since len() can be 1406 # cached, we shouldn't use it inside the lru code itself. 1407 old_len = builtins.len 1408 try: 1409 builtins.len = self.module.lru_cache(4)(len) 1410 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]: 1411 self.assertEqual(len('abcdefghijklmn'[:i]), i) 1412 finally: 1413 builtins.len = old_len 1414 1415 def test_lru_star_arg_handling(self): 1416 # Test regression that arose in ea064ff3c10f 1417 @functools.lru_cache() 1418 def f(*args): 1419 return args 1420 1421 self.assertEqual(f(1, 2), (1, 2)) 1422 self.assertEqual(f((1, 2)), ((1, 2),)) 1423 1424 def test_lru_type_error(self): 1425 # Regression test for issue #28653. 
1426 # lru_cache was leaking when one of the arguments 1427 # wasn't cacheable. 1428 1429 @functools.lru_cache(maxsize=None) 1430 def infinite_cache(o): 1431 pass 1432 1433 @functools.lru_cache(maxsize=10) 1434 def limited_cache(o): 1435 pass 1436 1437 with self.assertRaises(TypeError): 1438 infinite_cache([]) 1439 1440 with self.assertRaises(TypeError): 1441 limited_cache([]) 1442 1443 def test_lru_with_maxsize_none(self): 1444 @self.module.lru_cache(maxsize=None) 1445 def fib(n): 1446 if n < 2: 1447 return n 1448 return fib(n-1) + fib(n-2) 1449 self.assertEqual([fib(n) for n in range(16)], 1450 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1451 self.assertEqual(fib.cache_info(), 1452 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1453 fib.cache_clear() 1454 self.assertEqual(fib.cache_info(), 1455 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1456 1457 def test_lru_with_maxsize_negative(self): 1458 @self.module.lru_cache(maxsize=-10) 1459 def eq(n): 1460 return n 1461 for i in (0, 1): 1462 self.assertEqual([eq(n) for n in range(150)], list(range(150))) 1463 self.assertEqual(eq.cache_info(), 1464 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0)) 1465 1466 def test_lru_with_exceptions(self): 1467 # Verify that user_function exceptions get passed through without 1468 # creating a hard-to-read chained exception. 
1469 # http://bugs.python.org/issue13177 1470 for maxsize in (None, 128): 1471 @self.module.lru_cache(maxsize) 1472 def func(i): 1473 return 'abc'[i] 1474 self.assertEqual(func(0), 'a') 1475 with self.assertRaises(IndexError) as cm: 1476 func(15) 1477 self.assertIsNone(cm.exception.__context__) 1478 # Verify that the previous exception did not result in a cached entry 1479 with self.assertRaises(IndexError): 1480 func(15) 1481 1482 def test_lru_with_types(self): 1483 for maxsize in (None, 128): 1484 @self.module.lru_cache(maxsize=maxsize, typed=True) 1485 def square(x): 1486 return x * x 1487 self.assertEqual(square(3), 9) 1488 self.assertEqual(type(square(3)), type(9)) 1489 self.assertEqual(square(3.0), 9.0) 1490 self.assertEqual(type(square(3.0)), type(9.0)) 1491 self.assertEqual(square(x=3), 9) 1492 self.assertEqual(type(square(x=3)), type(9)) 1493 self.assertEqual(square(x=3.0), 9.0) 1494 self.assertEqual(type(square(x=3.0)), type(9.0)) 1495 self.assertEqual(square.cache_info().hits, 4) 1496 self.assertEqual(square.cache_info().misses, 4) 1497 1498 def test_lru_cache_typed_is_not_recursive(self): 1499 cached = self.module.lru_cache(typed=True)(repr) 1500 1501 self.assertEqual(cached(1), '1') 1502 self.assertEqual(cached(True), 'True') 1503 self.assertEqual(cached(1.0), '1.0') 1504 self.assertEqual(cached(0), '0') 1505 self.assertEqual(cached(False), 'False') 1506 self.assertEqual(cached(0.0), '0.0') 1507 1508 self.assertEqual(cached((1,)), '(1,)') 1509 self.assertEqual(cached((True,)), '(1,)') 1510 self.assertEqual(cached((1.0,)), '(1,)') 1511 self.assertEqual(cached((0,)), '(0,)') 1512 self.assertEqual(cached((False,)), '(0,)') 1513 self.assertEqual(cached((0.0,)), '(0,)') 1514 1515 class T(tuple): 1516 pass 1517 1518 self.assertEqual(cached(T((1,))), '(1,)') 1519 self.assertEqual(cached(T((True,))), '(1,)') 1520 self.assertEqual(cached(T((1.0,))), '(1,)') 1521 self.assertEqual(cached(T((0,))), '(0,)') 1522 self.assertEqual(cached(T((False,))), '(0,)') 1523 
self.assertEqual(cached(T((0.0,))), '(0,)') 1524 1525 def test_lru_with_keyword_args(self): 1526 @self.module.lru_cache() 1527 def fib(n): 1528 if n < 2: 1529 return n 1530 return fib(n=n-1) + fib(n=n-2) 1531 self.assertEqual( 1532 [fib(n=number) for number in range(16)], 1533 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] 1534 ) 1535 self.assertEqual(fib.cache_info(), 1536 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) 1537 fib.cache_clear() 1538 self.assertEqual(fib.cache_info(), 1539 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) 1540 1541 def test_lru_with_keyword_args_maxsize_none(self): 1542 @self.module.lru_cache(maxsize=None) 1543 def fib(n): 1544 if n < 2: 1545 return n 1546 return fib(n=n-1) + fib(n=n-2) 1547 self.assertEqual([fib(n=number) for number in range(16)], 1548 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1549 self.assertEqual(fib.cache_info(), 1550 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1551 fib.cache_clear() 1552 self.assertEqual(fib.cache_info(), 1553 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1554 1555 def test_kwargs_order(self): 1556 # PEP 468: Preserving Keyword Argument Order 1557 @self.module.lru_cache(maxsize=10) 1558 def f(**kwargs): 1559 return list(kwargs.items()) 1560 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)]) 1561 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)]) 1562 self.assertEqual(f.cache_info(), 1563 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2)) 1564 1565 def test_lru_cache_decoration(self): 1566 def f(zomg: 'zomg_annotation'): 1567 """f doc string""" 1568 return 42 1569 g = self.module.lru_cache()(f) 1570 for attr in self.module.WRAPPER_ASSIGNMENTS: 1571 self.assertEqual(getattr(g, attr), getattr(f, attr)) 1572 1573 def test_lru_cache_threaded(self): 1574 n, m = 5, 11 1575 def orig(x, y): 1576 return 3 * x + y 1577 f = self.module.lru_cache(maxsize=n*m)(orig) 1578 
def test_lru_cache_threaded2(self):
    # Simultaneous call with the same arguments
    #
    # n worker threads each call f(i) for i in range(m).  Three barriers,
    # each sized for the n workers plus the main thread, keep every round
    # in lock step.
    n, m = 5, 7
    start = threading.Barrier(n+1)
    pause = threading.Barrier(n+1)
    stop = threading.Barrier(n+1)
    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        # Block inside the cached function until all n workers and the
        # main thread have reached this barrier, so the calls overlap.
        pause.wait(10)
        return 3 * x
    # Nothing has been called yet: no hits, no misses, empty cache.
    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
    def test():
        # Worker body: one cached call per round, bracketed by the
        # start/stop barriers.
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)
    threads = [threading.Thread(target=test) for k in range(n)]
    with threading_helper.start_threads(threads):
        for i in range(m):
            # Each barrier is reset only after the barrier that follows
            # it in the round has been passed.
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            # After round i: still no hits, n misses per round so far,
            # and one cache entry per distinct argument.
            self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
self.assertEqual(c.f(x), x*10 + 7) 1709 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) 1710 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) 1711 1712 self.assertEqual(a.f.cache_info(), X.f.cache_info()) 1713 self.assertEqual(b.f.cache_info(), X.f.cache_info()) 1714 self.assertEqual(c.f.cache_info(), X.f.cache_info()) 1715 1716 def test_pickle(self): 1717 cls = self.__class__ 1718 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth: 1719 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1720 with self.subTest(proto=proto, func=f): 1721 f_copy = pickle.loads(pickle.dumps(f, proto)) 1722 self.assertIs(f_copy, f) 1723 1724 def test_copy(self): 1725 cls = self.__class__ 1726 def orig(x, y): 1727 return 3 * x + y 1728 part = self.module.partial(orig, 2) 1729 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1730 self.module.lru_cache(2)(part)) 1731 for f in funcs: 1732 with self.subTest(func=f): 1733 f_copy = copy.copy(f) 1734 self.assertIs(f_copy, f) 1735 1736 def test_deepcopy(self): 1737 cls = self.__class__ 1738 def orig(x, y): 1739 return 3 * x + y 1740 part = self.module.partial(orig, 2) 1741 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1742 self.module.lru_cache(2)(part)) 1743 for f in funcs: 1744 with self.subTest(func=f): 1745 f_copy = copy.deepcopy(f) 1746 self.assertIs(f_copy, f) 1747 1748 def test_lru_cache_parameters(self): 1749 @self.module.lru_cache(maxsize=2) 1750 def f(): 1751 return 1 1752 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False}) 1753 1754 @self.module.lru_cache(maxsize=1000, typed=True) 1755 def f(): 1756 return 1 1757 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True}) 1758 1759 def test_lru_cache_weakrefable(self): 1760 @self.module.lru_cache 1761 def test_function(x): 1762 return x 1763 1764 class A: 1765 @self.module.lru_cache 1766 def test_method(self, x): 1767 return (self, x) 1768 1769 @staticmethod 1770 
@self.module.lru_cache 1771 def test_staticmethod(x): 1772 return (self, x) 1773 1774 refs = [weakref.ref(test_function), 1775 weakref.ref(A.test_method), 1776 weakref.ref(A.test_staticmethod)] 1777 1778 for ref in refs: 1779 self.assertIsNotNone(ref()) 1780 1781 del A 1782 del test_function 1783 gc.collect() 1784 1785 for ref in refs: 1786 self.assertIsNone(ref()) 1787 1788 1789@py_functools.lru_cache() 1790def py_cached_func(x, y): 1791 return 3 * x + y 1792 1793@c_functools.lru_cache() 1794def c_cached_func(x, y): 1795 return 3 * x + y 1796 1797 1798class TestLRUPy(TestLRU, unittest.TestCase): 1799 module = py_functools 1800 cached_func = py_cached_func, 1801 1802 @module.lru_cache() 1803 def cached_meth(self, x, y): 1804 return 3 * x + y 1805 1806 @staticmethod 1807 @module.lru_cache() 1808 def cached_staticmeth(x, y): 1809 return 3 * x + y 1810 1811 1812class TestLRUC(TestLRU, unittest.TestCase): 1813 module = c_functools 1814 cached_func = c_cached_func, 1815 1816 @module.lru_cache() 1817 def cached_meth(self, x, y): 1818 return 3 * x + y 1819 1820 @staticmethod 1821 @module.lru_cache() 1822 def cached_staticmeth(x, y): 1823 return 3 * x + y 1824 1825 1826class TestSingleDispatch(unittest.TestCase): 1827 def test_simple_overloads(self): 1828 @functools.singledispatch 1829 def g(obj): 1830 return "base" 1831 def g_int(i): 1832 return "integer" 1833 g.register(int, g_int) 1834 self.assertEqual(g("str"), "base") 1835 self.assertEqual(g(1), "integer") 1836 self.assertEqual(g([1,2,3]), "base") 1837 1838 def test_mro(self): 1839 @functools.singledispatch 1840 def g(obj): 1841 return "base" 1842 class A: 1843 pass 1844 class C(A): 1845 pass 1846 class B(A): 1847 pass 1848 class D(C, B): 1849 pass 1850 def g_A(a): 1851 return "A" 1852 def g_B(b): 1853 return "B" 1854 g.register(A, g_A) 1855 g.register(B, g_B) 1856 self.assertEqual(g(A()), "A") 1857 self.assertEqual(g(B()), "B") 1858 self.assertEqual(g(C()), "A") 1859 self.assertEqual(g(D()), "B") 1860 1861 def 
test_register_decorator(self): 1862 @functools.singledispatch 1863 def g(obj): 1864 return "base" 1865 @g.register(int) 1866 def g_int(i): 1867 return "int %s" % (i,) 1868 self.assertEqual(g(""), "base") 1869 self.assertEqual(g(12), "int 12") 1870 self.assertIs(g.dispatch(int), g_int) 1871 self.assertIs(g.dispatch(object), g.dispatch(str)) 1872 # Note: in the assert above this is not g. 1873 # @singledispatch returns the wrapper. 1874 1875 def test_wrapping_attributes(self): 1876 @functools.singledispatch 1877 def g(obj): 1878 "Simple test" 1879 return "Test" 1880 self.assertEqual(g.__name__, "g") 1881 if sys.flags.optimize < 2: 1882 self.assertEqual(g.__doc__, "Simple test") 1883 1884 @unittest.skipUnless(decimal, 'requires _decimal') 1885 @support.cpython_only 1886 def test_c_classes(self): 1887 @functools.singledispatch 1888 def g(obj): 1889 return "base" 1890 @g.register(decimal.DecimalException) 1891 def _(obj): 1892 return obj.args 1893 subn = decimal.Subnormal("Exponent < Emin") 1894 rnd = decimal.Rounded("Number got rounded") 1895 self.assertEqual(g(subn), ("Exponent < Emin",)) 1896 self.assertEqual(g(rnd), ("Number got rounded",)) 1897 @g.register(decimal.Subnormal) 1898 def _(obj): 1899 return "Too small to care." 1900 self.assertEqual(g(subn), "Too small to care.") 1901 self.assertEqual(g(rnd), ("Number got rounded",)) 1902 1903 def test_compose_mro(self): 1904 # None of the examples in this test depend on haystack ordering. 
1905 c = collections.abc 1906 mro = functools._compose_mro 1907 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] 1908 for haystack in permutations(bases): 1909 m = mro(dict, haystack) 1910 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, 1911 c.Collection, c.Sized, c.Iterable, 1912 c.Container, object]) 1913 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict] 1914 for haystack in permutations(bases): 1915 m = mro(collections.ChainMap, haystack) 1916 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping, 1917 c.Collection, c.Sized, c.Iterable, 1918 c.Container, object]) 1919 1920 # If there's a generic function with implementations registered for 1921 # both Sized and Container, passing a defaultdict to it results in an 1922 # ambiguous dispatch which will cause a RuntimeError (see 1923 # test_mro_conflicts). 1924 bases = [c.Container, c.Sized, str] 1925 for haystack in permutations(bases): 1926 m = mro(collections.defaultdict, [c.Sized, c.Container, str]) 1927 self.assertEqual(m, [collections.defaultdict, dict, c.Sized, 1928 c.Container, object]) 1929 1930 # MutableSequence below is registered directly on D. In other words, it 1931 # precedes MutableMapping which means single dispatch will always 1932 # choose MutableSequence here. 1933 class D(collections.defaultdict): 1934 pass 1935 c.MutableSequence.register(D) 1936 bases = [c.MutableSequence, c.MutableMapping] 1937 for haystack in permutations(bases): 1938 m = mro(D, bases) 1939 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, 1940 collections.defaultdict, dict, c.MutableMapping, c.Mapping, 1941 c.Collection, c.Sized, c.Iterable, c.Container, 1942 object]) 1943 1944 # Container and Callable are registered on different base classes and 1945 # a generic function supporting both should always pick the Callable 1946 # implementation if a C instance is passed. 
1947 class C(collections.defaultdict): 1948 def __call__(self): 1949 pass 1950 bases = [c.Sized, c.Callable, c.Container, c.Mapping] 1951 for haystack in permutations(bases): 1952 m = mro(C, haystack) 1953 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping, 1954 c.Collection, c.Sized, c.Iterable, 1955 c.Container, object]) 1956 1957 def test_register_abc(self): 1958 c = collections.abc 1959 d = {"a": "b"} 1960 l = [1, 2, 3] 1961 s = {object(), None} 1962 f = frozenset(s) 1963 t = (1, 2, 3) 1964 @functools.singledispatch 1965 def g(obj): 1966 return "base" 1967 self.assertEqual(g(d), "base") 1968 self.assertEqual(g(l), "base") 1969 self.assertEqual(g(s), "base") 1970 self.assertEqual(g(f), "base") 1971 self.assertEqual(g(t), "base") 1972 g.register(c.Sized, lambda obj: "sized") 1973 self.assertEqual(g(d), "sized") 1974 self.assertEqual(g(l), "sized") 1975 self.assertEqual(g(s), "sized") 1976 self.assertEqual(g(f), "sized") 1977 self.assertEqual(g(t), "sized") 1978 g.register(c.MutableMapping, lambda obj: "mutablemapping") 1979 self.assertEqual(g(d), "mutablemapping") 1980 self.assertEqual(g(l), "sized") 1981 self.assertEqual(g(s), "sized") 1982 self.assertEqual(g(f), "sized") 1983 self.assertEqual(g(t), "sized") 1984 g.register(collections.ChainMap, lambda obj: "chainmap") 1985 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered 1986 self.assertEqual(g(l), "sized") 1987 self.assertEqual(g(s), "sized") 1988 self.assertEqual(g(f), "sized") 1989 self.assertEqual(g(t), "sized") 1990 g.register(c.MutableSequence, lambda obj: "mutablesequence") 1991 self.assertEqual(g(d), "mutablemapping") 1992 self.assertEqual(g(l), "mutablesequence") 1993 self.assertEqual(g(s), "sized") 1994 self.assertEqual(g(f), "sized") 1995 self.assertEqual(g(t), "sized") 1996 g.register(c.MutableSet, lambda obj: "mutableset") 1997 self.assertEqual(g(d), "mutablemapping") 1998 self.assertEqual(g(l), "mutablesequence") 1999 self.assertEqual(g(s), 
"mutableset") 2000 self.assertEqual(g(f), "sized") 2001 self.assertEqual(g(t), "sized") 2002 g.register(c.Mapping, lambda obj: "mapping") 2003 self.assertEqual(g(d), "mutablemapping") # not specific enough 2004 self.assertEqual(g(l), "mutablesequence") 2005 self.assertEqual(g(s), "mutableset") 2006 self.assertEqual(g(f), "sized") 2007 self.assertEqual(g(t), "sized") 2008 g.register(c.Sequence, lambda obj: "sequence") 2009 self.assertEqual(g(d), "mutablemapping") 2010 self.assertEqual(g(l), "mutablesequence") 2011 self.assertEqual(g(s), "mutableset") 2012 self.assertEqual(g(f), "sized") 2013 self.assertEqual(g(t), "sequence") 2014 g.register(c.Set, lambda obj: "set") 2015 self.assertEqual(g(d), "mutablemapping") 2016 self.assertEqual(g(l), "mutablesequence") 2017 self.assertEqual(g(s), "mutableset") 2018 self.assertEqual(g(f), "set") 2019 self.assertEqual(g(t), "sequence") 2020 g.register(dict, lambda obj: "dict") 2021 self.assertEqual(g(d), "dict") 2022 self.assertEqual(g(l), "mutablesequence") 2023 self.assertEqual(g(s), "mutableset") 2024 self.assertEqual(g(f), "set") 2025 self.assertEqual(g(t), "sequence") 2026 g.register(list, lambda obj: "list") 2027 self.assertEqual(g(d), "dict") 2028 self.assertEqual(g(l), "list") 2029 self.assertEqual(g(s), "mutableset") 2030 self.assertEqual(g(f), "set") 2031 self.assertEqual(g(t), "sequence") 2032 g.register(set, lambda obj: "concrete-set") 2033 self.assertEqual(g(d), "dict") 2034 self.assertEqual(g(l), "list") 2035 self.assertEqual(g(s), "concrete-set") 2036 self.assertEqual(g(f), "set") 2037 self.assertEqual(g(t), "sequence") 2038 g.register(frozenset, lambda obj: "frozen-set") 2039 self.assertEqual(g(d), "dict") 2040 self.assertEqual(g(l), "list") 2041 self.assertEqual(g(s), "concrete-set") 2042 self.assertEqual(g(f), "frozen-set") 2043 self.assertEqual(g(t), "sequence") 2044 g.register(tuple, lambda obj: "tuple") 2045 self.assertEqual(g(d), "dict") 2046 self.assertEqual(g(l), "list") 2047 self.assertEqual(g(s), 
"concrete-set") 2048 self.assertEqual(g(f), "frozen-set") 2049 self.assertEqual(g(t), "tuple") 2050 2051 def test_c3_abc(self): 2052 c = collections.abc 2053 mro = functools._c3_mro 2054 class A(object): 2055 pass 2056 class B(A): 2057 def __len__(self): 2058 return 0 # implies Sized 2059 @c.Container.register 2060 class C(object): 2061 pass 2062 class D(object): 2063 pass # unrelated 2064 class X(D, C, B): 2065 def __call__(self): 2066 pass # implies Callable 2067 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] 2068 for abcs in permutations([c.Sized, c.Callable, c.Container]): 2069 self.assertEqual(mro(X, abcs=abcs), expected) 2070 # unrelated ABCs don't appear in the resulting MRO 2071 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] 2072 self.assertEqual(mro(X, abcs=many_abcs), expected) 2073 2074 def test_false_meta(self): 2075 # see issue23572 2076 class MetaA(type): 2077 def __len__(self): 2078 return 0 2079 class A(metaclass=MetaA): 2080 pass 2081 class AA(A): 2082 pass 2083 @functools.singledispatch 2084 def fun(a): 2085 return 'base A' 2086 @fun.register(A) 2087 def _(a): 2088 return 'fun A' 2089 aa = AA() 2090 self.assertEqual(fun(aa), 'fun A') 2091 2092 def test_mro_conflicts(self): 2093 c = collections.abc 2094 @functools.singledispatch 2095 def g(arg): 2096 return "base" 2097 class O(c.Sized): 2098 def __len__(self): 2099 return 0 2100 o = O() 2101 self.assertEqual(g(o), "base") 2102 g.register(c.Iterable, lambda arg: "iterable") 2103 g.register(c.Container, lambda arg: "container") 2104 g.register(c.Sized, lambda arg: "sized") 2105 g.register(c.Set, lambda arg: "set") 2106 self.assertEqual(g(o), "sized") 2107 c.Iterable.register(O) 2108 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ 2109 c.Container.register(O) 2110 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ 2111 c.Set.register(O) 2112 self.assertEqual(g(o), "set") # because c.Set is a subclass of 2113 # c.Sized and 
c.Container 2114 class P: 2115 pass 2116 p = P() 2117 self.assertEqual(g(p), "base") 2118 c.Iterable.register(P) 2119 self.assertEqual(g(p), "iterable") 2120 c.Container.register(P) 2121 with self.assertRaises(RuntimeError) as re_one: 2122 g(p) 2123 self.assertIn( 2124 str(re_one.exception), 2125 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2126 "or <class 'collections.abc.Iterable'>"), 2127 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " 2128 "or <class 'collections.abc.Container'>")), 2129 ) 2130 class Q(c.Sized): 2131 def __len__(self): 2132 return 0 2133 q = Q() 2134 self.assertEqual(g(q), "sized") 2135 c.Iterable.register(Q) 2136 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ 2137 c.Set.register(Q) 2138 self.assertEqual(g(q), "set") # because c.Set is a subclass of 2139 # c.Sized and c.Iterable 2140 @functools.singledispatch 2141 def h(arg): 2142 return "base" 2143 @h.register(c.Sized) 2144 def _(arg): 2145 return "sized" 2146 @h.register(c.Container) 2147 def _(arg): 2148 return "container" 2149 # Even though Sized and Container are explicit bases of MutableMapping, 2150 # this ABC is implicitly registered on defaultdict which makes all of 2151 # MutableMapping's bases implicit as well from defaultdict's 2152 # perspective. 
2153 with self.assertRaises(RuntimeError) as re_two: 2154 h(collections.defaultdict(lambda: 0)) 2155 self.assertIn( 2156 str(re_two.exception), 2157 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2158 "or <class 'collections.abc.Sized'>"), 2159 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2160 "or <class 'collections.abc.Container'>")), 2161 ) 2162 class R(collections.defaultdict): 2163 pass 2164 c.MutableSequence.register(R) 2165 @functools.singledispatch 2166 def i(arg): 2167 return "base" 2168 @i.register(c.MutableMapping) 2169 def _(arg): 2170 return "mapping" 2171 @i.register(c.MutableSequence) 2172 def _(arg): 2173 return "sequence" 2174 r = R() 2175 self.assertEqual(i(r), "sequence") 2176 class S: 2177 pass 2178 class T(S, c.Sized): 2179 def __len__(self): 2180 return 0 2181 t = T() 2182 self.assertEqual(h(t), "sized") 2183 c.Container.register(T) 2184 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO 2185 class U: 2186 def __len__(self): 2187 return 0 2188 u = U() 2189 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred 2190 # from the existence of __len__() 2191 c.Container.register(U) 2192 # There is no preference for registered versus inferred ABCs. 
2193 with self.assertRaises(RuntimeError) as re_three: 2194 h(u) 2195 self.assertIn( 2196 str(re_three.exception), 2197 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2198 "or <class 'collections.abc.Sized'>"), 2199 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2200 "or <class 'collections.abc.Container'>")), 2201 ) 2202 class V(c.Sized, S): 2203 def __len__(self): 2204 return 0 2205 @functools.singledispatch 2206 def j(arg): 2207 return "base" 2208 @j.register(S) 2209 def _(arg): 2210 return "s" 2211 @j.register(c.Container) 2212 def _(arg): 2213 return "container" 2214 v = V() 2215 self.assertEqual(j(v), "s") 2216 c.Container.register(V) 2217 self.assertEqual(j(v), "container") # because it ends up right after 2218 # Sized in the MRO 2219 2220 def test_cache_invalidation(self): 2221 from collections import UserDict 2222 import weakref 2223 2224 class TracingDict(UserDict): 2225 def __init__(self, *args, **kwargs): 2226 super(TracingDict, self).__init__(*args, **kwargs) 2227 self.set_ops = [] 2228 self.get_ops = [] 2229 def __getitem__(self, key): 2230 result = self.data[key] 2231 self.get_ops.append(key) 2232 return result 2233 def __setitem__(self, key, value): 2234 self.set_ops.append(key) 2235 self.data[key] = value 2236 def clear(self): 2237 self.data.clear() 2238 2239 td = TracingDict() 2240 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td): 2241 c = collections.abc 2242 @functools.singledispatch 2243 def g(arg): 2244 return "base" 2245 d = {} 2246 l = [] 2247 self.assertEqual(len(td), 0) 2248 self.assertEqual(g(d), "base") 2249 self.assertEqual(len(td), 1) 2250 self.assertEqual(td.get_ops, []) 2251 self.assertEqual(td.set_ops, [dict]) 2252 self.assertEqual(td.data[dict], g.registry[object]) 2253 self.assertEqual(g(l), "base") 2254 self.assertEqual(len(td), 2) 2255 self.assertEqual(td.get_ops, []) 2256 self.assertEqual(td.set_ops, [dict, list]) 2257 self.assertEqual(td.data[dict], g.registry[object]) 2258 
self.assertEqual(td.data[list], g.registry[object]) 2259 self.assertEqual(td.data[dict], td.data[list]) 2260 self.assertEqual(g(l), "base") 2261 self.assertEqual(g(d), "base") 2262 self.assertEqual(td.get_ops, [list, dict]) 2263 self.assertEqual(td.set_ops, [dict, list]) 2264 g.register(list, lambda arg: "list") 2265 self.assertEqual(td.get_ops, [list, dict]) 2266 self.assertEqual(len(td), 0) 2267 self.assertEqual(g(d), "base") 2268 self.assertEqual(len(td), 1) 2269 self.assertEqual(td.get_ops, [list, dict]) 2270 self.assertEqual(td.set_ops, [dict, list, dict]) 2271 self.assertEqual(td.data[dict], 2272 functools._find_impl(dict, g.registry)) 2273 self.assertEqual(g(l), "list") 2274 self.assertEqual(len(td), 2) 2275 self.assertEqual(td.get_ops, [list, dict]) 2276 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2277 self.assertEqual(td.data[list], 2278 functools._find_impl(list, g.registry)) 2279 class X: 2280 pass 2281 c.MutableMapping.register(X) # Will not invalidate the cache, 2282 # not using ABCs yet. 
2283 self.assertEqual(g(d), "base") 2284 self.assertEqual(g(l), "list") 2285 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2286 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2287 g.register(c.Sized, lambda arg: "sized") 2288 self.assertEqual(len(td), 0) 2289 self.assertEqual(g(d), "sized") 2290 self.assertEqual(len(td), 1) 2291 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2292 self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) 2293 self.assertEqual(g(l), "list") 2294 self.assertEqual(len(td), 2) 2295 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2296 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2297 self.assertEqual(g(l), "list") 2298 self.assertEqual(g(d), "sized") 2299 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) 2300 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2301 g.dispatch(list) 2302 g.dispatch(dict) 2303 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, 2304 list, dict]) 2305 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2306 c.MutableSet.register(X) # Will invalidate the cache. 2307 self.assertEqual(len(td), 2) # Stale cache. 
2308 self.assertEqual(g(l), "list") 2309 self.assertEqual(len(td), 1) 2310 g.register(c.MutableMapping, lambda arg: "mutablemapping") 2311 self.assertEqual(len(td), 0) 2312 self.assertEqual(g(d), "mutablemapping") 2313 self.assertEqual(len(td), 1) 2314 self.assertEqual(g(l), "list") 2315 self.assertEqual(len(td), 2) 2316 g.register(dict, lambda arg: "dict") 2317 self.assertEqual(g(d), "dict") 2318 self.assertEqual(g(l), "list") 2319 g._clear_cache() 2320 self.assertEqual(len(td), 0) 2321 2322 def test_annotations(self): 2323 @functools.singledispatch 2324 def i(arg): 2325 return "base" 2326 @i.register 2327 def _(arg: collections.abc.Mapping): 2328 return "mapping" 2329 @i.register 2330 def _(arg: "collections.abc.Sequence"): 2331 return "sequence" 2332 self.assertEqual(i(None), "base") 2333 self.assertEqual(i({"a": 1}), "mapping") 2334 self.assertEqual(i([1, 2, 3]), "sequence") 2335 self.assertEqual(i((1, 2, 3)), "sequence") 2336 self.assertEqual(i("str"), "sequence") 2337 2338 # Registering classes as callables doesn't work with annotations, 2339 # you need to pass the type explicitly. 
2340 @i.register(str) 2341 class _: 2342 def __init__(self, arg): 2343 self.arg = arg 2344 2345 def __eq__(self, other): 2346 return self.arg == other 2347 self.assertEqual(i("str"), "str") 2348 2349 def test_method_register(self): 2350 class A: 2351 @functools.singledispatchmethod 2352 def t(self, arg): 2353 self.arg = "base" 2354 @t.register(int) 2355 def _(self, arg): 2356 self.arg = "int" 2357 @t.register(str) 2358 def _(self, arg): 2359 self.arg = "str" 2360 a = A() 2361 2362 a.t(0) 2363 self.assertEqual(a.arg, "int") 2364 aa = A() 2365 self.assertFalse(hasattr(aa, 'arg')) 2366 a.t('') 2367 self.assertEqual(a.arg, "str") 2368 aa = A() 2369 self.assertFalse(hasattr(aa, 'arg')) 2370 a.t(0.0) 2371 self.assertEqual(a.arg, "base") 2372 aa = A() 2373 self.assertFalse(hasattr(aa, 'arg')) 2374 2375 def test_staticmethod_register(self): 2376 class A: 2377 @functools.singledispatchmethod 2378 @staticmethod 2379 def t(arg): 2380 return arg 2381 @t.register(int) 2382 @staticmethod 2383 def _(arg): 2384 return isinstance(arg, int) 2385 @t.register(str) 2386 @staticmethod 2387 def _(arg): 2388 return isinstance(arg, str) 2389 a = A() 2390 2391 self.assertTrue(A.t(0)) 2392 self.assertTrue(A.t('')) 2393 self.assertEqual(A.t(0.0), 0.0) 2394 2395 def test_classmethod_register(self): 2396 class A: 2397 def __init__(self, arg): 2398 self.arg = arg 2399 2400 @functools.singledispatchmethod 2401 @classmethod 2402 def t(cls, arg): 2403 return cls("base") 2404 @t.register(int) 2405 @classmethod 2406 def _(cls, arg): 2407 return cls("int") 2408 @t.register(str) 2409 @classmethod 2410 def _(cls, arg): 2411 return cls("str") 2412 2413 self.assertEqual(A.t(0).arg, "int") 2414 self.assertEqual(A.t('').arg, "str") 2415 self.assertEqual(A.t(0.0).arg, "base") 2416 2417 def test_callable_register(self): 2418 class A: 2419 def __init__(self, arg): 2420 self.arg = arg 2421 2422 @functools.singledispatchmethod 2423 @classmethod 2424 def t(cls, arg): 2425 return cls("base") 2426 2427 
@A.t.register(int) 2428 @classmethod 2429 def _(cls, arg): 2430 return cls("int") 2431 @A.t.register(str) 2432 @classmethod 2433 def _(cls, arg): 2434 return cls("str") 2435 2436 self.assertEqual(A.t(0).arg, "int") 2437 self.assertEqual(A.t('').arg, "str") 2438 self.assertEqual(A.t(0.0).arg, "base") 2439 2440 def test_abstractmethod_register(self): 2441 class Abstract(metaclass=abc.ABCMeta): 2442 2443 @functools.singledispatchmethod 2444 @abc.abstractmethod 2445 def add(self, x, y): 2446 pass 2447 2448 self.assertTrue(Abstract.add.__isabstractmethod__) 2449 self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__) 2450 2451 with self.assertRaises(TypeError): 2452 Abstract() 2453 2454 def test_type_ann_register(self): 2455 class A: 2456 @functools.singledispatchmethod 2457 def t(self, arg): 2458 return "base" 2459 @t.register 2460 def _(self, arg: int): 2461 return "int" 2462 @t.register 2463 def _(self, arg: str): 2464 return "str" 2465 a = A() 2466 2467 self.assertEqual(a.t(0), "int") 2468 self.assertEqual(a.t(''), "str") 2469 self.assertEqual(a.t(0.0), "base") 2470 2471 def test_staticmethod_type_ann_register(self): 2472 class A: 2473 @functools.singledispatchmethod 2474 @staticmethod 2475 def t(arg): 2476 return arg 2477 @t.register 2478 @staticmethod 2479 def _(arg: int): 2480 return isinstance(arg, int) 2481 @t.register 2482 @staticmethod 2483 def _(arg: str): 2484 return isinstance(arg, str) 2485 a = A() 2486 2487 self.assertTrue(A.t(0)) 2488 self.assertTrue(A.t('')) 2489 self.assertEqual(A.t(0.0), 0.0) 2490 2491 def test_classmethod_type_ann_register(self): 2492 class A: 2493 def __init__(self, arg): 2494 self.arg = arg 2495 2496 @functools.singledispatchmethod 2497 @classmethod 2498 def t(cls, arg): 2499 return cls("base") 2500 @t.register 2501 @classmethod 2502 def _(cls, arg: int): 2503 return cls("int") 2504 @t.register 2505 @classmethod 2506 def _(cls, arg: str): 2507 return cls("str") 2508 2509 self.assertEqual(A.t(0).arg, "int") 2510 
self.assertEqual(A.t('').arg, "str") 2511 self.assertEqual(A.t(0.0).arg, "base") 2512 2513 def test_method_wrapping_attributes(self): 2514 class A: 2515 @functools.singledispatchmethod 2516 def func(self, arg: int) -> str: 2517 """My function docstring""" 2518 return str(arg) 2519 @functools.singledispatchmethod 2520 @classmethod 2521 def cls_func(cls, arg: int) -> str: 2522 """My function docstring""" 2523 return str(arg) 2524 @functools.singledispatchmethod 2525 @staticmethod 2526 def static_func(arg: int) -> str: 2527 """My function docstring""" 2528 return str(arg) 2529 2530 for meth in ( 2531 A.func, 2532 A().func, 2533 A.cls_func, 2534 A().cls_func, 2535 A.static_func, 2536 A().static_func 2537 ): 2538 with self.subTest(meth=meth): 2539 self.assertEqual(meth.__doc__, 'My function docstring') 2540 self.assertEqual(meth.__annotations__['arg'], int) 2541 2542 self.assertEqual(A.func.__name__, 'func') 2543 self.assertEqual(A().func.__name__, 'func') 2544 self.assertEqual(A.cls_func.__name__, 'cls_func') 2545 self.assertEqual(A().cls_func.__name__, 'cls_func') 2546 self.assertEqual(A.static_func.__name__, 'static_func') 2547 self.assertEqual(A().static_func.__name__, 'static_func') 2548 2549 def test_double_wrapped_methods(self): 2550 def classmethod_friendly_decorator(func): 2551 wrapped = func.__func__ 2552 @classmethod 2553 @functools.wraps(wrapped) 2554 def wrapper(*args, **kwargs): 2555 return wrapped(*args, **kwargs) 2556 return wrapper 2557 2558 class WithoutSingleDispatch: 2559 @classmethod 2560 @contextlib.contextmanager 2561 def cls_context_manager(cls, arg: int) -> str: 2562 try: 2563 yield str(arg) 2564 finally: 2565 return 'Done' 2566 2567 @classmethod_friendly_decorator 2568 @classmethod 2569 def decorated_classmethod(cls, arg: int) -> str: 2570 return str(arg) 2571 2572 class WithSingleDispatch: 2573 @functools.singledispatchmethod 2574 @classmethod 2575 @contextlib.contextmanager 2576 def cls_context_manager(cls, arg: int) -> str: 2577 """My 
function docstring""" 2578 try: 2579 yield str(arg) 2580 finally: 2581 return 'Done' 2582 2583 @functools.singledispatchmethod 2584 @classmethod_friendly_decorator 2585 @classmethod 2586 def decorated_classmethod(cls, arg: int) -> str: 2587 """My function docstring""" 2588 return str(arg) 2589 2590 # These are sanity checks 2591 # to test the test itself is working as expected 2592 with WithoutSingleDispatch.cls_context_manager(5) as foo: 2593 without_single_dispatch_foo = foo 2594 2595 with WithSingleDispatch.cls_context_manager(5) as foo: 2596 single_dispatch_foo = foo 2597 2598 self.assertEqual(without_single_dispatch_foo, single_dispatch_foo) 2599 self.assertEqual(single_dispatch_foo, '5') 2600 2601 self.assertEqual( 2602 WithoutSingleDispatch.decorated_classmethod(5), 2603 WithSingleDispatch.decorated_classmethod(5) 2604 ) 2605 2606 self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5') 2607 2608 # Behavioural checks now follow 2609 for method_name in ('cls_context_manager', 'decorated_classmethod'): 2610 with self.subTest(method=method_name): 2611 self.assertEqual( 2612 getattr(WithSingleDispatch, method_name).__name__, 2613 getattr(WithoutSingleDispatch, method_name).__name__ 2614 ) 2615 2616 self.assertEqual( 2617 getattr(WithSingleDispatch(), method_name).__name__, 2618 getattr(WithoutSingleDispatch(), method_name).__name__ 2619 ) 2620 2621 for meth in ( 2622 WithSingleDispatch.cls_context_manager, 2623 WithSingleDispatch().cls_context_manager, 2624 WithSingleDispatch.decorated_classmethod, 2625 WithSingleDispatch().decorated_classmethod 2626 ): 2627 with self.subTest(meth=meth): 2628 self.assertEqual(meth.__doc__, 'My function docstring') 2629 self.assertEqual(meth.__annotations__['arg'], int) 2630 2631 self.assertEqual( 2632 WithSingleDispatch.cls_context_manager.__name__, 2633 'cls_context_manager' 2634 ) 2635 self.assertEqual( 2636 WithSingleDispatch().cls_context_manager.__name__, 2637 'cls_context_manager' 2638 ) 2639 self.assertEqual( 
2640 WithSingleDispatch.decorated_classmethod.__name__, 2641 'decorated_classmethod' 2642 ) 2643 self.assertEqual( 2644 WithSingleDispatch().decorated_classmethod.__name__, 2645 'decorated_classmethod' 2646 ) 2647 2648 def test_invalid_registrations(self): 2649 msg_prefix = "Invalid first argument to `register()`: " 2650 msg_suffix = ( 2651 ". Use either `@register(some_class)` or plain `@register` on an " 2652 "annotated function." 2653 ) 2654 @functools.singledispatch 2655 def i(arg): 2656 return "base" 2657 with self.assertRaises(TypeError) as exc: 2658 @i.register(42) 2659 def _(arg): 2660 return "I annotated with a non-type" 2661 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42")) 2662 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2663 with self.assertRaises(TypeError) as exc: 2664 @i.register 2665 def _(arg): 2666 return "I forgot to annotate" 2667 self.assertTrue(str(exc.exception).startswith(msg_prefix + 2668 "<function TestSingleDispatch.test_invalid_registrations.<locals>._" 2669 )) 2670 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2671 2672 with self.assertRaises(TypeError) as exc: 2673 @i.register 2674 def _(arg: typing.Iterable[str]): 2675 # At runtime, dispatching on generics is impossible. 2676 # When registering implementations with singledispatch, avoid 2677 # types from `typing`. Instead, annotate with regular types 2678 # or ABCs. 2679 return "I annotated with a generic collection" 2680 self.assertTrue(str(exc.exception).startswith( 2681 "Invalid annotation for 'arg'." 2682 )) 2683 self.assertTrue(str(exc.exception).endswith( 2684 'typing.Iterable[str] is not a class.' 
2685 )) 2686 2687 def test_invalid_positional_argument(self): 2688 @functools.singledispatch 2689 def f(*args): 2690 pass 2691 msg = 'f requires at least 1 positional argument' 2692 with self.assertRaisesRegex(TypeError, msg): 2693 f() 2694 2695 2696class CachedCostItem: 2697 _cost = 1 2698 2699 def __init__(self): 2700 self.lock = py_functools.RLock() 2701 2702 @py_functools.cached_property 2703 def cost(self): 2704 """The cost of the item.""" 2705 with self.lock: 2706 self._cost += 1 2707 return self._cost 2708 2709 2710class OptionallyCachedCostItem: 2711 _cost = 1 2712 2713 def get_cost(self): 2714 """The cost of the item.""" 2715 self._cost += 1 2716 return self._cost 2717 2718 cached_cost = py_functools.cached_property(get_cost) 2719 2720 2721class CachedCostItemWait: 2722 2723 def __init__(self, event): 2724 self._cost = 1 2725 self.lock = py_functools.RLock() 2726 self.event = event 2727 2728 @py_functools.cached_property 2729 def cost(self): 2730 self.event.wait(1) 2731 with self.lock: 2732 self._cost += 1 2733 return self._cost 2734 2735 2736class CachedCostItemWithSlots: 2737 __slots__ = ('_cost') 2738 2739 def __init__(self): 2740 self._cost = 1 2741 2742 @py_functools.cached_property 2743 def cost(self): 2744 raise RuntimeError('never called, slots not supported') 2745 2746 2747class TestCachedProperty(unittest.TestCase): 2748 def test_cached(self): 2749 item = CachedCostItem() 2750 self.assertEqual(item.cost, 2) 2751 self.assertEqual(item.cost, 2) # not 3 2752 2753 def test_cached_attribute_name_differs_from_func_name(self): 2754 item = OptionallyCachedCostItem() 2755 self.assertEqual(item.get_cost(), 2) 2756 self.assertEqual(item.cached_cost, 3) 2757 self.assertEqual(item.get_cost(), 4) 2758 self.assertEqual(item.cached_cost, 3) 2759 2760 def test_threaded(self): 2761 go = threading.Event() 2762 item = CachedCostItemWait(go) 2763 2764 num_threads = 3 2765 2766 orig_si = sys.getswitchinterval() 2767 sys.setswitchinterval(1e-6) 2768 try: 2769 threads = 
[ 2770 threading.Thread(target=lambda: item.cost) 2771 for k in range(num_threads) 2772 ] 2773 with threading_helper.start_threads(threads): 2774 go.set() 2775 finally: 2776 sys.setswitchinterval(orig_si) 2777 2778 self.assertEqual(item.cost, 2) 2779 2780 def test_object_with_slots(self): 2781 item = CachedCostItemWithSlots() 2782 with self.assertRaisesRegex( 2783 TypeError, 2784 "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.", 2785 ): 2786 item.cost 2787 2788 def test_immutable_dict(self): 2789 class MyMeta(type): 2790 @py_functools.cached_property 2791 def prop(self): 2792 return True 2793 2794 class MyClass(metaclass=MyMeta): 2795 pass 2796 2797 with self.assertRaisesRegex( 2798 TypeError, 2799 "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.", 2800 ): 2801 MyClass.prop 2802 2803 def test_reuse_different_names(self): 2804 """Disallow this case because decorated function a would not be cached.""" 2805 with self.assertRaises(RuntimeError) as ctx: 2806 class ReusedCachedProperty: 2807 @py_functools.cached_property 2808 def a(self): 2809 pass 2810 2811 b = a 2812 2813 self.assertEqual( 2814 str(ctx.exception.__context__), 2815 str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")) 2816 ) 2817 2818 def test_reuse_same_name(self): 2819 """Reusing a cached_property on different classes under the same name is OK.""" 2820 counter = 0 2821 2822 @py_functools.cached_property 2823 def _cp(_self): 2824 nonlocal counter 2825 counter += 1 2826 return counter 2827 2828 class A: 2829 cp = _cp 2830 2831 class B: 2832 cp = _cp 2833 2834 a = A() 2835 b = B() 2836 2837 self.assertEqual(a.cp, 1) 2838 self.assertEqual(b.cp, 2) 2839 self.assertEqual(a.cp, 1) 2840 2841 def test_set_name_not_called(self): 2842 cp = py_functools.cached_property(lambda s: None) 2843 class Foo: 2844 pass 2845 2846 Foo.cp = cp 2847 2848 with self.assertRaisesRegex( 
2849 TypeError, 2850 "Cannot use cached_property instance without calling __set_name__ on it.", 2851 ): 2852 Foo().cp 2853 2854 def test_access_from_class(self): 2855 self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property) 2856 2857 def test_doc(self): 2858 self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.") 2859 2860 2861if __name__ == '__main__': 2862 unittest.main() 2863