1import abc 2import builtins 3import collections 4import collections.abc 5import copy 6from itertools import permutations 7import pickle 8from random import choice 9import sys 10from test import support 11import threading 12import time 13import typing 14import unittest 15import unittest.mock 16from weakref import proxy 17import contextlib 18 19import functools 20 21py_functools = support.import_fresh_module('functools', blocked=['_functools']) 22c_functools = support.import_fresh_module('functools', fresh=['_functools']) 23 24decimal = support.import_fresh_module('decimal', fresh=['_decimal']) 25 26@contextlib.contextmanager 27def replaced_module(name, replacement): 28 original_module = sys.modules[name] 29 sys.modules[name] = replacement 30 try: 31 yield 32 finally: 33 sys.modules[name] = original_module 34 35def capture(*args, **kw): 36 """capture all positional and keyword arguments""" 37 return args, kw 38 39 40def signature(part): 41 """ return the signature of a partial object """ 42 return (part.func, part.args, part.keywords, part.__dict__) 43 44class MyTuple(tuple): 45 pass 46 47class BadTuple(tuple): 48 def __add__(self, other): 49 return list(self) + list(other) 50 51class MyDict(dict): 52 pass 53 54 55class TestPartial: 56 57 def test_basic_examples(self): 58 p = self.partial(capture, 1, 2, a=10, b=20) 59 self.assertTrue(callable(p)) 60 self.assertEqual(p(3, 4, b=30, c=40), 61 ((1, 2, 3, 4), dict(a=10, b=30, c=40))) 62 p = self.partial(map, lambda x: x*10) 63 self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40]) 64 65 def test_attributes(self): 66 p = self.partial(capture, 1, 2, a=10, b=20) 67 # attributes should be readable 68 self.assertEqual(p.func, capture) 69 self.assertEqual(p.args, (1, 2)) 70 self.assertEqual(p.keywords, dict(a=10, b=20)) 71 72 def test_argument_checking(self): 73 self.assertRaises(TypeError, self.partial) # need at least a func arg 74 try: 75 self.partial(2)() 76 except TypeError: 77 pass 78 else: 79 self.fail('First arg not checked for callability') 80 81 def test_protection_of_callers_dict_argument(self): 82 # a caller's dictionary should not be altered by partial 83 def func(a=10, b=20): 84 return a 85 d = {'a':3} 86 p = self.partial(func, a=5) 87 self.assertEqual(p(**d), 3) 88 self.assertEqual(d, {'a':3}) 89 p(b=7) 90 self.assertEqual(d, {'a':3}) 91 92 def test_kwargs_copy(self): 93 # Issue #29532: Altering a kwarg dictionary passed to a constructor 94 # should not affect a partial object after creation 95 d = {'a': 3} 96 p = self.partial(capture, **d) 97 self.assertEqual(p(), ((), {'a': 3})) 98 d['a'] = 5 99 self.assertEqual(p(), ((), {'a': 3})) 100 101 def test_arg_combinations(self): 102 # exercise special code paths for zero args in either partial 103 # object or the caller 104 p = self.partial(capture) 105 self.assertEqual(p(), ((), {})) 106 self.assertEqual(p(1,2), ((1,2), {})) 107 p = self.partial(capture, 1, 2) 108 self.assertEqual(p(), ((1,2), {})) 109 self.assertEqual(p(3,4), ((1,2,3,4), {})) 110 111 def test_kw_combinations(self): 112 # exercise special code paths for no keyword args in 113 # either the partial object or the caller 114 p = self.partial(capture) 115 self.assertEqual(p.keywords, {}) 116 self.assertEqual(p(), ((), {})) 117 self.assertEqual(p(a=1), ((), {'a':1})) 118 p = self.partial(capture, a=1) 119 self.assertEqual(p.keywords, {'a':1}) 120 self.assertEqual(p(), ((), {'a':1})) 121 self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) 122 # keyword args in the call override those in the partial object 123 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) 124 125 def test_positional(self): 126 # make sure positional arguments are captured correctly 127 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: 128 p = self.partial(capture, *args) 129 expected = args + ('x',) 130 got, empty = p('x') 131 self.assertTrue(expected == got and empty == {}) 132 133 def test_keyword(self): 134 # make sure keyword arguments are captured correctly 135 for a in ['a', 0, None, 3.5]: 136 p = self.partial(capture, a=a) 137 expected = {'a':a,'x':None} 138 empty, got = p(x=None) 139 self.assertTrue(expected == got and empty == ()) 140 141 def test_no_side_effects(self): 142 # make sure there are no side effects that affect subsequent calls 143 p = self.partial(capture, 0, a=1) 144 args1, kw1 = p(1, b=2) 145 self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) 146 args2, kw2 = p() 147 self.assertTrue(args2 == (0,) and kw2 == {'a':1}) 148 149 def test_error_propagation(self): 150 def f(x, y): 151 x / y 152 self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0)) 153 self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0) 154 self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0) 155 self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1) 156 157 def test_weakref(self): 158 f = self.partial(int, base=16) 159 p = proxy(f) 160 self.assertEqual(f.func, p.func) 161 f = None 162 self.assertRaises(ReferenceError, getattr, p, 'func') 163 164 def test_with_bound_and_unbound_methods(self): 165 data = list(map(str, range(10))) 166 join = self.partial(str.join, '') 167 self.assertEqual(join(data), '0123456789') 168 join = self.partial(''.join) 169 self.assertEqual(join(data), '0123456789') 170 171 def test_nested_optimization(self): 172 partial = self.partial 173 inner = partial(signature, 'asdf') 174 nested = partial(inner, bar=True) 175 flat = partial(signature, 'asdf', bar=True) 176 self.assertEqual(signature(nested), signature(flat)) 177 178 def test_nested_partial_with_attribute(self): 179 # see issue 25137 180 partial = self.partial 181 182 def foo(bar): 183 return bar 184 185 p = partial(foo, 'first') 186 p2 = partial(p, 'second') 187 p2.new_attr = 'spam' 188 self.assertEqual(p2.new_attr, 'spam') 189 190 def test_repr(self): 191 args = (object(), object()) 192 args_repr = ', '.join(repr(a) for a in args) 193 kwargs = {'a': object(), 'b': object()} 194 kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs), 195 'b={b!r}, a={a!r}'.format_map(kwargs)] 196 if self.partial in (c_functools.partial, py_functools.partial): 197 name = 'functools.partial' 198 else: 199 name = self.partial.__name__ 200 201 f = self.partial(capture) 202 self.assertEqual(f'{name}({capture!r})', repr(f)) 203 204 f = self.partial(capture, *args) 205 self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f)) 206 207 f = self.partial(capture, **kwargs) 208 self.assertIn(repr(f), 209 [f'{name}({capture!r}, {kwargs_repr})' 210 for kwargs_repr in kwargs_reprs]) 211 212 f = self.partial(capture, *args, **kwargs) 213 self.assertIn(repr(f), 214 [f'{name}({capture!r}, {args_repr}, {kwargs_repr})' 215 for kwargs_repr in kwargs_reprs]) 216 217 def test_recursive_repr(self): 218 if self.partial in (c_functools.partial, py_functools.partial): 219 name = 'functools.partial' 220 else: 221 name = self.partial.__name__ 222 223 f = self.partial(capture) 224 f.__setstate__((f, (), {}, {})) 225 try: 226 self.assertEqual(repr(f), '%s(...)' % (name,)) 227 finally: 228 f.__setstate__((capture, (), {}, {})) 229 230 f = self.partial(capture) 231 f.__setstate__((capture, (f,), {}, {})) 232 try: 233 self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,)) 234 finally: 235 f.__setstate__((capture, (), {}, {})) 236 237 f = self.partial(capture) 238 f.__setstate__((capture, (), {'a': f}, {})) 239 try: 240 self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,)) 241 finally: 242 f.__setstate__((capture, (), {}, {})) 243 244 def test_pickle(self): 245 with self.AllowPickle(): 246 f = self.partial(signature, ['asdf'], bar=[True]) 247 f.attr = [] 248 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 249 f_copy = pickle.loads(pickle.dumps(f, proto)) 250 self.assertEqual(signature(f_copy), signature(f)) 251 252 def test_copy(self): 253 f = self.partial(signature, ['asdf'], bar=[True]) 254 f.attr = [] 255 f_copy = copy.copy(f) 256 self.assertEqual(signature(f_copy), signature(f)) 257 self.assertIs(f_copy.attr, f.attr) 258 self.assertIs(f_copy.args, f.args) 259 self.assertIs(f_copy.keywords, f.keywords) 260 261 def test_deepcopy(self): 262 f = self.partial(signature, ['asdf'], bar=[True]) 263 f.attr = [] 264 f_copy = copy.deepcopy(f) 265 self.assertEqual(signature(f_copy), signature(f)) 266 self.assertIsNot(f_copy.attr, f.attr) 267 self.assertIsNot(f_copy.args, f.args) 268 self.assertIsNot(f_copy.args[0], f.args[0]) 269 self.assertIsNot(f_copy.keywords, f.keywords) 270 self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar']) 271 272 def test_setstate(self): 273 f = self.partial(signature) 274 f.__setstate__((capture, (1,), dict(a=10), dict(attr=[]))) 275 276 self.assertEqual(signature(f), 277 (capture, (1,), dict(a=10), dict(attr=[]))) 278 self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) 279 280 f.__setstate__((capture, (1,), dict(a=10), None)) 281 282 self.assertEqual(signature(f), (capture, (1,), dict(a=10), {})) 283 self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) 284 285 f.__setstate__((capture, (1,), None, None)) 286 #self.assertEqual(signature(f), (capture, (1,), {}, {})) 287 self.assertEqual(f(2, b=20), ((1, 2), {'b': 20})) 288 self.assertEqual(f(2), ((1, 2), {})) 289 self.assertEqual(f(), ((1,), {})) 290 291 f.__setstate__((capture, (), {}, None)) 292 self.assertEqual(signature(f), (capture, (), {}, {})) 293 self.assertEqual(f(2, b=20), ((2,), {'b': 20})) 294 self.assertEqual(f(2), ((2,), {})) 295 self.assertEqual(f(), ((), {})) 296 297 def test_setstate_errors(self): 298 f = self.partial(signature) 299 self.assertRaises(TypeError, f.__setstate__, (capture, (), {})) 300 self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None)) 301 self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None]) 302 self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None)) 303 self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None)) 304 self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None)) 305 self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None)) 306 307 def test_setstate_subclasses(self): 308 f = self.partial(signature) 309 f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None)) 310 s = signature(f) 311 self.assertEqual(s, (capture, (1,), dict(a=10), {})) 312 self.assertIs(type(s[1]), tuple) 313 self.assertIs(type(s[2]), dict) 314 r = f() 315 self.assertEqual(r, ((1,), {'a': 10})) 316 self.assertIs(type(r[0]), tuple) 317 self.assertIs(type(r[1]), dict) 318 319 f.__setstate__((capture, BadTuple((1,)), {}, None)) 320 s = signature(f) 321 self.assertEqual(s, (capture, (1,), {}, {})) 322 self.assertIs(type(s[1]), tuple) 323 r = f(2) 324 self.assertEqual(r, ((1, 2), {})) 325 self.assertIs(type(r[0]), tuple) 326 327 def test_recursive_pickle(self): 328 with self.AllowPickle(): 329 f = self.partial(capture) 330 f.__setstate__((f, (), {}, {})) 331 try: 332 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 333 with self.assertRaises(RecursionError): 334 pickle.dumps(f, proto) 335 finally: 336 f.__setstate__((capture, (), {}, {})) 337 338 f = self.partial(capture) 339 f.__setstate__((capture, (f,), {}, {})) 340 try: 341 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 342 f_copy = pickle.loads(pickle.dumps(f, proto)) 343 try: 344 self.assertIs(f_copy.args[0], f_copy) 345 finally: 346 f_copy.__setstate__((capture, (), {}, {})) 347 finally: 348 f.__setstate__((capture, (), {}, {})) 349 350 f = self.partial(capture) 351 f.__setstate__((capture, (), {'a': f}, {})) 352 try: 353 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 354 f_copy = pickle.loads(pickle.dumps(f, proto)) 355 try: 356 self.assertIs(f_copy.keywords['a'], f_copy) 357 finally: 358 f_copy.__setstate__((capture, (), {}, {})) 359 finally: 360 f.__setstate__((capture, (), {}, {})) 361 362 # Issue 6083: Reference counting bug 363 def test_setstate_refcount(self): 364 class BadSequence: 365 def __len__(self): 366 return 4 367 def __getitem__(self, key): 368 if key == 0: 369 return max 370 elif key == 1: 371 return tuple(range(1000000)) 372 elif key in (2, 3): 373 return {} 374 raise IndexError 375 376 f = self.partial(object) 377 self.assertRaises(TypeError, f.__setstate__, BadSequence()) 378 379@unittest.skipUnless(c_functools, 'requires the C _functools module') 380class TestPartialC(TestPartial, unittest.TestCase): 381 if c_functools: 382 partial = c_functools.partial 383 384 class AllowPickle: 385 def __enter__(self): 386 return self 387 def __exit__(self, type, value, tb): 388 return False 389 390 def test_attributes_unwritable(self): 391 # attributes should not be writable 392 p = self.partial(capture, 1, 2, a=10, b=20) 393 self.assertRaises(AttributeError, setattr, p, 'func', map) 394 self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) 395 self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) 396 397 p = self.partial(hex) 398 try: 399 del p.__dict__ 400 except TypeError: 401 pass 402 else: 403 self.fail('partial object allowed __dict__ to be deleted') 404 405 def test_manually_adding_non_string_keyword(self): 406 p = self.partial(capture) 407 # Adding a non-string/unicode keyword to partial kwargs 408 p.keywords[1234] = 'value' 409 r = repr(p) 410 self.assertIn('1234', r) 411 self.assertIn("'value'", r) 412 with self.assertRaises(TypeError): 413 p() 414 415 def test_keystr_replaces_value(self): 416 p = self.partial(capture) 417 418 class MutatesYourDict(object): 419 def __str__(self): 420 p.keywords[self] = ['sth2'] 421 return 'astr' 422 423 # Replacing the value during key formatting should keep the original 424 # value alive (at least long enough). 425 p.keywords[MutatesYourDict()] = ['sth'] 426 r = repr(p) 427 self.assertIn('astr', r) 428 self.assertIn("['sth']", r) 429 430 431class TestPartialPy(TestPartial, unittest.TestCase): 432 partial = py_functools.partial 433 434 class AllowPickle: 435 def __init__(self): 436 self._cm = replaced_module("functools", py_functools) 437 def __enter__(self): 438 return self._cm.__enter__() 439 def __exit__(self, type, value, tb): 440 return self._cm.__exit__(type, value, tb) 441 442if c_functools: 443 class CPartialSubclass(c_functools.partial): 444 pass 445 446class PyPartialSubclass(py_functools.partial): 447 pass 448 449@unittest.skipUnless(c_functools, 'requires the C _functools module') 450class TestPartialCSubclass(TestPartialC): 451 if c_functools: 452 partial = CPartialSubclass 453 454 # partial subclasses are not optimized for nested calls 455 test_nested_optimization = None 456 457class TestPartialPySubclass(TestPartialPy): 458 partial = PyPartialSubclass 459 460class TestPartialMethod(unittest.TestCase): 461 462 class A(object): 463 nothing = functools.partialmethod(capture) 464 positional = functools.partialmethod(capture, 1) 465 keywords = functools.partialmethod(capture, a=2) 466 both = functools.partialmethod(capture, 3, b=4) 467 spec_keywords = functools.partialmethod(capture, self=1, func=2) 468 469 nested = functools.partialmethod(positional, 5) 470 471 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7) 472 473 static = functools.partialmethod(staticmethod(capture), 8) 474 cls = functools.partialmethod(classmethod(capture), d=9) 475 476 a = A() 477 478 def test_arg_combinations(self): 479 self.assertEqual(self.a.nothing(), ((self.a,), {})) 480 self.assertEqual(self.a.nothing(5), ((self.a, 5), {})) 481 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6})) 482 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6})) 483 484 self.assertEqual(self.a.positional(), ((self.a, 1), {})) 485 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {})) 486 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6})) 487 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6})) 488 489 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2})) 490 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2})) 491 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6})) 492 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6})) 493 494 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4})) 495 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4})) 496 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6})) 497 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) 498 499 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) 500 501 self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2})) 502 503 def test_nested(self): 504 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {})) 505 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {})) 506 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7})) 507 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7})) 508 509 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7})) 510 511 def test_over_partial(self): 512 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6})) 513 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6})) 514 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8})) 515 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) 516 517 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) 518 519 def test_bound_method_introspection(self): 520 obj = self.a 521 self.assertIs(obj.both.__self__, obj) 522 self.assertIs(obj.nested.__self__, obj) 523 self.assertIs(obj.over_partial.__self__, obj) 524 self.assertIs(obj.cls.__self__, self.A) 525 self.assertIs(self.A.cls.__self__, self.A) 526 527 def test_unbound_method_retrieval(self): 528 obj = self.A 529 self.assertFalse(hasattr(obj.both, "__self__")) 530 self.assertFalse(hasattr(obj.nested, "__self__")) 531 self.assertFalse(hasattr(obj.over_partial, "__self__")) 532 self.assertFalse(hasattr(obj.static, "__self__")) 533 self.assertFalse(hasattr(self.a.static, "__self__")) 534 535 def test_descriptors(self): 536 for obj in [self.A, self.a]: 537 with self.subTest(obj=obj): 538 self.assertEqual(obj.static(), ((8,), {})) 539 self.assertEqual(obj.static(5), ((8, 5), {})) 540 self.assertEqual(obj.static(d=8), ((8,), {'d': 8})) 541 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8})) 542 543 self.assertEqual(obj.cls(), ((self.A,), {'d': 9})) 544 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9})) 545 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9})) 546 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9})) 547 548 def test_overriding_keywords(self): 549 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3})) 550 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3})) 551 552 def test_invalid_args(self): 553 with self.assertRaises(TypeError): 554 class B(object): 555 method = functools.partialmethod(None, 1) 556 with self.assertRaises(TypeError): 557 class B: 558 method = functools.partialmethod() 559 class B: 560 method = functools.partialmethod(func=capture, a=1) 561 b = B() 562 self.assertEqual(b.method(2, x=3), ((b, 2), {'a': 1, 'x': 3})) 563 564 def test_repr(self): 565 self.assertEqual(repr(vars(self.A)['both']), 566 'functools.partialmethod({}, 3, b=4)'.format(capture)) 567 568 def test_abstract(self): 569 class Abstract(abc.ABCMeta): 570 571 @abc.abstractmethod 572 def add(self, x, y): 573 pass 574 575 add5 = functools.partialmethod(add, 5) 576 577 self.assertTrue(Abstract.add.__isabstractmethod__) 578 self.assertTrue(Abstract.add5.__isabstractmethod__) 579 580 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]: 581 self.assertFalse(getattr(func, '__isabstractmethod__', False)) 582 583 584class TestUpdateWrapper(unittest.TestCase): 585 586 def check_wrapper(self, wrapper, wrapped, 587 assigned=functools.WRAPPER_ASSIGNMENTS, 588 updated=functools.WRAPPER_UPDATES): 589 # Check attributes were assigned 590 for name in assigned: 591 self.assertIs(getattr(wrapper, name), getattr(wrapped, name)) 592 # Check attributes were updated 593 for name in updated: 594 wrapper_attr = getattr(wrapper, name) 595 wrapped_attr = getattr(wrapped, name) 596 for key in wrapped_attr: 597 if name == "__dict__" and key == "__wrapped__": 598 # __wrapped__ is overwritten by the update code 599 continue 600 self.assertIs(wrapped_attr[key], wrapper_attr[key]) 601 # Check __wrapped__ 602 self.assertIs(wrapper.__wrapped__, wrapped) 603 604 605 def _default_update(self): 606 def f(a:'This is a new annotation'): 607 """This is a test""" 608 pass 609 f.attr = 'This is also a test' 610 f.__wrapped__ = "This is a bald faced lie" 611 def wrapper(b:'This is the prior annotation'): 612 pass 613 functools.update_wrapper(wrapper, f) 614 return wrapper, f 615 616 def test_default_update(self): 617 wrapper, f = self._default_update() 618 self.check_wrapper(wrapper, f) 619 self.assertIs(wrapper.__wrapped__, f) 620 self.assertEqual(wrapper.__name__, 'f') 621 self.assertEqual(wrapper.__qualname__, f.__qualname__) 622 self.assertEqual(wrapper.attr, 'This is also a test') 623 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') 624 self.assertNotIn('b', wrapper.__annotations__) 625 626 @unittest.skipIf(sys.flags.optimize >= 2, 627 "Docstrings are omitted with -O2 and above") 628 def test_default_update_doc(self): 629 wrapper, f = self._default_update() 630 self.assertEqual(wrapper.__doc__, 'This is a test') 631 632 def test_no_update(self): 633 def f(): 634 """This is a test""" 635 pass 636 f.attr = 'This is also a test' 637 def wrapper(): 638 pass 639 functools.update_wrapper(wrapper, f, (), ()) 640 self.check_wrapper(wrapper, f, (), ()) 641 self.assertEqual(wrapper.__name__, 'wrapper') 642 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 643 self.assertEqual(wrapper.__doc__, None) 644 self.assertEqual(wrapper.__annotations__, {}) 645 self.assertFalse(hasattr(wrapper, 'attr')) 646 647 def test_selective_update(self): 648 def f(): 649 pass 650 f.attr = 'This is a different test' 651 f.dict_attr = dict(a=1, b=2, c=3) 652 def wrapper(): 653 pass 654 wrapper.dict_attr = {} 655 assign = ('attr',) 656 update = ('dict_attr',) 657 functools.update_wrapper(wrapper, f, assign, update) 658 self.check_wrapper(wrapper, f, assign, update) 659 self.assertEqual(wrapper.__name__, 'wrapper') 660 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 661 self.assertEqual(wrapper.__doc__, None) 662 self.assertEqual(wrapper.attr, 'This is a different test') 663 self.assertEqual(wrapper.dict_attr, f.dict_attr) 664 665 def test_missing_attributes(self): 666 def f(): 667 pass 668 def wrapper(): 669 pass 670 wrapper.dict_attr = {} 671 assign = ('attr',) 672 update = ('dict_attr',) 673 # Missing attributes on wrapped object are ignored 674 functools.update_wrapper(wrapper, f, assign, update) 675 self.assertNotIn('attr', wrapper.__dict__) 676 self.assertEqual(wrapper.dict_attr, {}) 677 # Wrapper must have expected attributes for updating 678 del wrapper.dict_attr 679 with self.assertRaises(AttributeError): 680 functools.update_wrapper(wrapper, f, assign, update) 681 wrapper.dict_attr = 1 682 with self.assertRaises(AttributeError): 683 functools.update_wrapper(wrapper, f, assign, update) 684 685 @support.requires_docstrings 686 @unittest.skipIf(sys.flags.optimize >= 2, 687 "Docstrings are omitted with -O2 and above") 688 def test_builtin_update(self): 689 # Test for bug #1576241 690 def wrapper(): 691 pass 692 functools.update_wrapper(wrapper, max) 693 self.assertEqual(wrapper.__name__, 'max') 694 self.assertTrue(wrapper.__doc__.startswith('max(')) 695 self.assertEqual(wrapper.__annotations__, {}) 696 697 698class TestWraps(TestUpdateWrapper): 699 700 def _default_update(self): 701 def f(): 702 """This is a test""" 703 pass 704 f.attr = 'This is also a test' 705 f.__wrapped__ = "This is still a bald faced lie" 706 @functools.wraps(f) 707 def wrapper(): 708 pass 709 return wrapper, f 710 711 def test_default_update(self): 712 wrapper, f = self._default_update() 713 self.check_wrapper(wrapper, f) 714 self.assertEqual(wrapper.__name__, 'f') 715 self.assertEqual(wrapper.__qualname__, f.__qualname__) 716 self.assertEqual(wrapper.attr, 'This is also a test') 717 718 @unittest.skipIf(sys.flags.optimize >= 2, 719 "Docstrings are omitted with -O2 and above") 720 def test_default_update_doc(self): 721 wrapper, _ = self._default_update() 722 self.assertEqual(wrapper.__doc__, 'This is a test') 723 724 def test_no_update(self): 725 def f(): 726 """This is a test""" 727 pass 728 f.attr = 'This is also a test' 729 @functools.wraps(f, (), ()) 730 def wrapper(): 731 pass 732 self.check_wrapper(wrapper, f, (), ()) 733 self.assertEqual(wrapper.__name__, 'wrapper') 734 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 735 self.assertEqual(wrapper.__doc__, None) 736 self.assertFalse(hasattr(wrapper, 'attr')) 737 738 def test_selective_update(self): 739 def f(): 740 pass 741 f.attr = 'This is a different test' 742 f.dict_attr = dict(a=1, b=2, c=3) 743 def add_dict_attr(f): 744 f.dict_attr = {} 745 return f 746 assign = ('attr',) 747 update = ('dict_attr',) 748 @functools.wraps(f, assign, update) 749 @add_dict_attr 750 def wrapper(): 751 pass 752 self.check_wrapper(wrapper, f, assign, update) 753 self.assertEqual(wrapper.__name__, 'wrapper') 754 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 755 self.assertEqual(wrapper.__doc__, None) 756 self.assertEqual(wrapper.attr, 'This is a different test') 757 self.assertEqual(wrapper.dict_attr, f.dict_attr) 758 759@unittest.skipUnless(c_functools, 'requires the C _functools module') 760class TestReduce(unittest.TestCase): 761 if c_functools: 762 func = c_functools.reduce 763 764 def test_reduce(self): 765 class Squares: 766 def __init__(self, max): 767 self.max = max 768 self.sofar = [] 769 770 def __len__(self): 771 return len(self.sofar) 772 773 def __getitem__(self, i): 774 if not 0 <= i < self.max: raise IndexError 775 n = len(self.sofar) 776 while n <= i: 777 self.sofar.append(n*n) 778 n += 1 779 return self.sofar[i] 780 def add(x, y): 781 return x + y 782 self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc') 783 self.assertEqual( 784 self.func(add, [['a', 'c'], [], ['d', 'w']], []), 785 ['a','c','d','w'] 786 ) 787 self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040) 788 self.assertEqual( 789 self.func(lambda x, y: x*y, range(2,21), 1), 790 2432902008176640000 791 ) 792 self.assertEqual(self.func(add, Squares(10)), 285) 793 self.assertEqual(self.func(add, Squares(10), 0), 285) 794 self.assertEqual(self.func(add, Squares(0), 0), 0) 795 self.assertRaises(TypeError, self.func) 796 self.assertRaises(TypeError, self.func, 42, 42) 797 self.assertRaises(TypeError, self.func, 42, 42, 42) 798 self.assertEqual(self.func(42, "1"), "1") # func is never called with one item 799 self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item 800 self.assertRaises(TypeError, self.func, 42, (42, 42)) 801 self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value 802 self.assertRaises(TypeError, self.func, add, "") 803 self.assertRaises(TypeError, self.func, add, ()) 804 self.assertRaises(TypeError, self.func, add, object()) 805 806 class TestFailingIter: 807 def __iter__(self): 808 raise RuntimeError 809 self.assertRaises(RuntimeError, self.func, add, TestFailingIter()) 810 811 self.assertEqual(self.func(add, [], None), None) 812 self.assertEqual(self.func(add, [], 42), 42) 813 814 class BadSeq: 815 def __getitem__(self, index): 816 raise ValueError 817 self.assertRaises(ValueError, self.func, 42, BadSeq()) 818 819 # Test reduce()'s use of iterators. 820 def test_iterator_usage(self): 821 class SequenceClass: 822 def __init__(self, n): 823 self.n = n 824 def __getitem__(self, i): 825 if 0 <= i < self.n: 826 return i 827 else: 828 raise IndexError 829 830 from operator import add 831 self.assertEqual(self.func(add, SequenceClass(5)), 10) 832 self.assertEqual(self.func(add, SequenceClass(5), 42), 52) 833 self.assertRaises(TypeError, self.func, add, SequenceClass(0)) 834 self.assertEqual(self.func(add, SequenceClass(0), 42), 42) 835 self.assertEqual(self.func(add, SequenceClass(1)), 0) 836 self.assertEqual(self.func(add, SequenceClass(1), 42), 42) 837 838 d = {"one": 1, "two": 2, "three": 3} 839 self.assertEqual(self.func(add, d), "".join(d.keys())) 840 841 842class TestCmpToKey: 843 844 def test_cmp_to_key(self): 845 def cmp1(x, y): 846 return (x > y) - (x < y) 847 key = self.cmp_to_key(cmp1) 848 self.assertEqual(key(3), key(3)) 849 self.assertGreater(key(3), key(1)) 850 self.assertGreaterEqual(key(3), key(3)) 851 852 def cmp2(x, y): 853 return int(x) - int(y) 854 key = self.cmp_to_key(cmp2) 855 self.assertEqual(key(4.0), key('4')) 856 self.assertLess(key(2), key('35')) 857 self.assertLessEqual(key(2), key('35')) 858 self.assertNotEqual(key(2), key('35')) 859 860 def test_cmp_to_key_arguments(self): 861 def cmp1(x, y): 862 return (x > y) - (x < y) 863 key = self.cmp_to_key(mycmp=cmp1) 864 self.assertEqual(key(obj=3), key(obj=3)) 865 self.assertGreater(key(obj=3), key(obj=1)) 866 with self.assertRaises((TypeError, AttributeError)): 867 key(3) > 1 # rhs is not a K object 868 with self.assertRaises((TypeError, AttributeError)): 869 1 < key(3) # lhs is not a K object 870 with self.assertRaises(TypeError): 871 key = self.cmp_to_key() # too few args 872 with self.assertRaises(TypeError): 873 key = self.cmp_to_key(cmp1, None) # too many args 874 key = self.cmp_to_key(cmp1) 875 with self.assertRaises(TypeError): 876 key() # too few args 877 with self.assertRaises(TypeError): 878 key(None, None) # too many args 879 880 def test_bad_cmp(self): 881 def cmp1(x, y): 882 raise ZeroDivisionError 883 key = self.cmp_to_key(cmp1) 884 with self.assertRaises(ZeroDivisionError): 885 key(3) > key(1) 886 887 class BadCmp: 888 def __lt__(self, other): 889 raise ZeroDivisionError 890 def cmp1(x, y): 891 return BadCmp() 892 with self.assertRaises(ZeroDivisionError): 893 key(3) > key(1) 894 895 def test_obj_field(self): 896 def cmp1(x, y): 897 return (x > y) - (x < y) 898 key = self.cmp_to_key(mycmp=cmp1) 899 self.assertEqual(key(50).obj, 50) 900 901 def test_sort_int(self): 902 def mycmp(x, y): 903 return y - x 904 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)), 905 [4, 3, 2, 1, 0]) 906 907 def test_sort_int_str(self): 908 def mycmp(x, y): 909 x, y = int(x), int(y) 910 return (x > y) - (x < y) 911 values = [5, '3', 7, 2, '0', '1', 4, '10', 1] 912 values = sorted(values, key=self.cmp_to_key(mycmp)) 913 self.assertEqual([int(value) for value in values], 914 [0, 1, 1, 2, 3, 4, 5, 7, 10]) 915 916 def test_hash(self): 917 def mycmp(x, y): 918 return y - x 919 key = self.cmp_to_key(mycmp) 920 k = key(10) 921 self.assertRaises(TypeError, hash, k) 922 self.assertNotIsInstance(k, collections.abc.Hashable) 923 924 925@unittest.skipUnless(c_functools, 'requires the C _functools module') 926class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): 927 if c_functools: 928 cmp_to_key = c_functools.cmp_to_key 929 930 931class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase): 932 cmp_to_key = staticmethod(py_functools.cmp_to_key) 933 934 935class TestTotalOrdering(unittest.TestCase): 936 937 def test_total_ordering_lt(self): 938 @functools.total_ordering 939 class A: 940 def __init__(self, value): 941 self.value = value 942 def __lt__(self, other): 943 return self.value < other.value 944 def __eq__(self, other): 945 return self.value == other.value 946 self.assertTrue(A(1) < A(2)) 947 self.assertTrue(A(2) > A(1)) 948 self.assertTrue(A(1) <= A(2)) 949 self.assertTrue(A(2) >= A(1)) 950 self.assertTrue(A(2) <= A(2)) 951 self.assertTrue(A(2) >= A(2)) 952 self.assertFalse(A(1) > A(2)) 953 954 def test_total_ordering_le(self): 955 @functools.total_ordering 956 class A: 957 def __init__(self, value): 958 self.value = value 959 def __le__(self, other): 960 return self.value <= other.value 961 def __eq__(self, other): 962 return self.value == other.value 963 self.assertTrue(A(1) < A(2)) 964 self.assertTrue(A(2) > A(1)) 965 self.assertTrue(A(1) <= A(2)) 966 self.assertTrue(A(2) >= A(1)) 967 self.assertTrue(A(2) <= A(2)) 968 self.assertTrue(A(2) >= A(2)) 969 self.assertFalse(A(1) >= A(2)) 970 971 def test_total_ordering_gt(self): 972 @functools.total_ordering 973 class A: 974 def __init__(self, value): 975 self.value = value 976 def __gt__(self, other): 977 return self.value > other.value 978 def __eq__(self, other): 979 return self.value == other.value 980 self.assertTrue(A(1) < A(2)) 981 self.assertTrue(A(2) > A(1)) 982 self.assertTrue(A(1) <= A(2)) 983 self.assertTrue(A(2) >= A(1)) 984 self.assertTrue(A(2) <= A(2)) 985 self.assertTrue(A(2) >= A(2)) 986 self.assertFalse(A(2) < A(1)) 987 988 def test_total_ordering_ge(self): 989 @functools.total_ordering 990 class A: 991 def __init__(self, value): 992 self.value = value 993 def __ge__(self, other): 994 return self.value >= other.value 995 def __eq__(self, other): 996 return self.value == other.value 997 self.assertTrue(A(1) < A(2)) 998 self.assertTrue(A(2) > A(1)) 999 self.assertTrue(A(1) <= A(2)) 1000 self.assertTrue(A(2) >= A(1)) 1001 self.assertTrue(A(2) <= A(2)) 1002 self.assertTrue(A(2) >= A(2)) 1003 self.assertFalse(A(2) <= A(1)) 1004 1005 def test_total_ordering_no_overwrite(self): 1006 # new methods should not overwrite existing 1007 @functools.total_ordering 1008 class A(int): 1009 pass 1010 self.assertTrue(A(1) < A(2)) 1011 self.assertTrue(A(2) > A(1)) 1012 self.assertTrue(A(1) <= A(2)) 1013 self.assertTrue(A(2) >= A(1)) 1014 self.assertTrue(A(2) <= A(2)) 1015 self.assertTrue(A(2) >= A(2)) 1016 1017 def test_no_operations_defined(self): 1018 with self.assertRaises(ValueError): 1019 @functools.total_ordering 1020 class A: 1021 pass 1022 1023 def test_type_error_when_not_implemented(self): 1024 # bug 10042; ensure stack overflow does not occur 1025 # when decorated types return NotImplemented 1026 @functools.total_ordering 1027 class ImplementsLessThan: 1028 def __init__(self, value): 1029 self.value = value 1030 def __eq__(self, other): 1031 if isinstance(other, ImplementsLessThan): 1032 return self.value == other.value 1033 return False 1034 def __lt__(self, other): 1035 if isinstance(other, ImplementsLessThan): 1036 return self.value < other.value 1037 return NotImplemented 1038 1039 @functools.total_ordering 1040 class ImplementsGreaterThan: 1041 def __init__(self, value): 1042 self.value = value 1043 def __eq__(self, other): 1044 if isinstance(other, ImplementsGreaterThan): 1045 return self.value == other.value 1046 return False 1047 def __gt__(self, other): 1048 if isinstance(other, ImplementsGreaterThan): 1049 return self.value > other.value 1050 return NotImplemented 1051 1052 @functools.total_ordering 1053 class ImplementsLessThanEqualTo: 1054 def __init__(self, value): 1055 self.value = value 1056 def __eq__(self, other): 1057 if isinstance(other, ImplementsLessThanEqualTo): 1058 return self.value == other.value 1059 return False 1060 def __le__(self, other): 1061 if isinstance(other, ImplementsLessThanEqualTo): 1062 return self.value <= other.value 1063 return NotImplemented 1064 1065 @functools.total_ordering 1066 class ImplementsGreaterThanEqualTo: 1067 def __init__(self, value): 1068 self.value = value 1069 def __eq__(self, other): 1070 if isinstance(other, ImplementsGreaterThanEqualTo): 1071 return self.value == other.value 1072 return False 1073 def __ge__(self, other): 1074 if isinstance(other, ImplementsGreaterThanEqualTo): 1075 return self.value >= other.value 1076 return NotImplemented 1077 1078 @functools.total_ordering 1079 class ComparatorNotImplemented: 1080 def __init__(self, value): 1081 self.value = value 1082 def __eq__(self, other): 1083 if isinstance(other, ComparatorNotImplemented): 1084 return self.value == other.value 1085 return False 1086 def __lt__(self, other): 1087 return NotImplemented 1088 1089 with self.subTest("LT < 1"), self.assertRaises(TypeError): 1090 ImplementsLessThan(-1) < 1 1091 1092 with self.subTest("LT < LE"), self.assertRaises(TypeError): 1093 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) 1094 1095 with self.subTest("LT < GT"), self.assertRaises(TypeError): 1096 ImplementsLessThan(1) < ImplementsGreaterThan(1) 1097 1098 with self.subTest("LE <= LT"), self.assertRaises(TypeError): 1099 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) 1100 1101 with self.subTest("LE <= GE"), self.assertRaises(TypeError): 1102 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) 1103 1104 with self.subTest("GT > GE"), self.assertRaises(TypeError): 1105 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) 1106 1107 with self.subTest("GT > LT"), self.assertRaises(TypeError): 1108 ImplementsGreaterThan(5) > ImplementsLessThan(5) 1109 1110 with self.subTest("GE >= GT"), self.assertRaises(TypeError): 1111 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) 1112 1113 with self.subTest("GE >= LE"), self.assertRaises(TypeError): 1114 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) 1115 1116 with self.subTest("GE when equal"): 1117 a = ComparatorNotImplemented(8) 1118 b = ComparatorNotImplemented(8) 1119 self.assertEqual(a, b) 1120 with self.assertRaises(TypeError): 1121 a >= b 1122 1123 with self.subTest("LE when equal"): 1124 a = ComparatorNotImplemented(9) 1125 b = ComparatorNotImplemented(9) 1126 self.assertEqual(a, b) 1127 with self.assertRaises(TypeError): 1128 a <= b 1129 1130 def test_pickle(self): 1131 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1132 for name in '__lt__', '__gt__', '__le__', '__ge__': 1133 with self.subTest(method=name, proto=proto): 1134 method = getattr(Orderable_LT, name) 1135 method_copy = pickle.loads(pickle.dumps(method, proto)) 1136 self.assertIs(method_copy, method) 1137 1138@functools.total_ordering 1139class Orderable_LT: 1140 def __init__(self, value): 1141 self.value = value 1142 def __lt__(self, other): 1143 return self.value < other.value 1144 def __eq__(self, other): 1145 return self.value == other.value 1146 1147 1148class TestLRU: 1149 1150 def test_lru(self): 1151 def orig(x, y): 1152 return 3 * x + y 1153 f = self.module.lru_cache(maxsize=20)(orig) 1154 hits, misses, maxsize, currsize = f.cache_info() 1155 self.assertEqual(maxsize, 20) 1156 self.assertEqual(currsize, 0) 1157 self.assertEqual(hits, 0) 1158 self.assertEqual(misses, 0) 1159 1160 domain = range(5) 1161 for i in range(1000): 1162 x, y = choice(domain), choice(domain) 1163 actual = f(x, y) 1164 expected = orig(x, y) 1165 self.assertEqual(actual, expected) 1166 hits, misses, maxsize, currsize = f.cache_info() 1167 self.assertTrue(hits > misses) 1168 self.assertEqual(hits + misses, 1000) 1169 self.assertEqual(currsize, 20) 1170 1171 f.cache_clear() # test clearing 1172 hits, misses, maxsize, currsize = f.cache_info() 1173 self.assertEqual(hits, 0) 1174 self.assertEqual(misses, 0) 1175 self.assertEqual(currsize, 0) 1176 f(x, y) 1177 hits, misses, maxsize, currsize = f.cache_info() 1178 self.assertEqual(hits, 0) 1179 self.assertEqual(misses, 1) 1180 self.assertEqual(currsize, 1) 1181 1182 # Test bypassing the cache 1183 self.assertIs(f.__wrapped__, orig) 1184 f.__wrapped__(x, y) 1185 hits, misses, maxsize, currsize = f.cache_info() 1186 self.assertEqual(hits, 0) 1187 self.assertEqual(misses, 1) 1188 self.assertEqual(currsize, 1) 1189 1190 # test size zero (which means "never-cache") 1191 @self.module.lru_cache(0) 1192 def f(): 1193 nonlocal f_cnt 1194 f_cnt += 1 1195 return 20 1196 self.assertEqual(f.cache_info().maxsize, 0) 1197 f_cnt = 0 1198 for i in range(5): 1199 self.assertEqual(f(), 20) 1200 self.assertEqual(f_cnt, 5) 1201 hits, misses, maxsize, currsize = f.cache_info() 1202 self.assertEqual(hits, 0) 1203 self.assertEqual(misses, 5) 1204 self.assertEqual(currsize, 0) 1205 1206 # test size one 1207 @self.module.lru_cache(1) 1208 def f(): 1209 nonlocal f_cnt 1210 f_cnt += 1 1211 return 20 1212 self.assertEqual(f.cache_info().maxsize, 1) 1213 f_cnt = 0 1214 for i in range(5): 1215 self.assertEqual(f(), 20) 1216 self.assertEqual(f_cnt, 1) 1217 hits, misses, maxsize, currsize = f.cache_info() 1218 self.assertEqual(hits, 4) 1219 self.assertEqual(misses, 1) 1220 self.assertEqual(currsize, 1) 1221 1222 # test size two 1223 @self.module.lru_cache(2) 1224 def f(x): 1225 nonlocal f_cnt 1226 f_cnt += 1 1227 return x*10 1228 self.assertEqual(f.cache_info().maxsize, 2) 1229 f_cnt = 0 1230 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: 1231 # * * * * 1232 self.assertEqual(f(x), x*10) 1233 self.assertEqual(f_cnt, 4) 1234 hits, misses, maxsize, currsize = f.cache_info() 1235 self.assertEqual(hits, 12) 1236 self.assertEqual(misses, 4) 1237 self.assertEqual(currsize, 2) 1238 1239 def test_lru_bug_35780(self): 1240 # C version of the lru_cache was not checking to see if 1241 # the user function call has already modified the cache 1242 # (this arises in recursive calls and in multi-threading). 1243 # This cause the cache to have orphan links not referenced 1244 # by the cache dictionary. 1245 1246 once = True # Modified by f(x) below 1247 1248 @self.module.lru_cache(maxsize=10) 1249 def f(x): 1250 nonlocal once 1251 rv = f'.{x}.' 1252 if x == 20 and once: 1253 once = False 1254 rv = f(x) 1255 return rv 1256 1257 # Fill the cache 1258 for x in range(15): 1259 self.assertEqual(f(x), f'.{x}.') 1260 self.assertEqual(f.cache_info().currsize, 10) 1261 1262 # Make a recursive call and make sure the cache remains full 1263 self.assertEqual(f(20), '.20.') 1264 self.assertEqual(f.cache_info().currsize, 10) 1265 1266 def test_lru_bug_36650(self): 1267 # C version of lru_cache was treating a call with an empty **kwargs 1268 # dictionary as being distinct from a call with no keywords at all. 1269 # This did not result in an incorrect answer, but it did trigger 1270 # an unexpected cache miss. 1271 1272 @self.module.lru_cache() 1273 def f(x): 1274 pass 1275 1276 f(0) 1277 f(0, **{}) 1278 self.assertEqual(f.cache_info().hits, 1) 1279 1280 def test_lru_hash_only_once(self): 1281 # To protect against weird reentrancy bugs and to improve 1282 # efficiency when faced with slow __hash__ methods, the 1283 # LRU cache guarantees that it will only call __hash__ 1284 # only once per use as an argument to the cached function. 1285 1286 @self.module.lru_cache(maxsize=1) 1287 def f(x, y): 1288 return x * 3 + y 1289 1290 # Simulate the integer 5 1291 mock_int = unittest.mock.Mock() 1292 mock_int.__mul__ = unittest.mock.Mock(return_value=15) 1293 mock_int.__hash__ = unittest.mock.Mock(return_value=999) 1294 1295 # Add to cache: One use as an argument gives one call 1296 self.assertEqual(f(mock_int, 1), 16) 1297 self.assertEqual(mock_int.__hash__.call_count, 1) 1298 self.assertEqual(f.cache_info(), (0, 1, 1, 1)) 1299 1300 # Cache hit: One use as an argument gives one additional call 1301 self.assertEqual(f(mock_int, 1), 16) 1302 self.assertEqual(mock_int.__hash__.call_count, 2) 1303 self.assertEqual(f.cache_info(), (1, 1, 1, 1)) 1304 1305 # Cache eviction: No use as an argument gives no additional call 1306 self.assertEqual(f(6, 2), 20) 1307 self.assertEqual(mock_int.__hash__.call_count, 2) 1308 self.assertEqual(f.cache_info(), (1, 2, 1, 1)) 1309 1310 # Cache miss: One use as an argument gives one additional call 1311 self.assertEqual(f(mock_int, 1), 16) 1312 self.assertEqual(mock_int.__hash__.call_count, 3) 1313 self.assertEqual(f.cache_info(), (1, 3, 1, 1)) 1314 1315 def test_lru_reentrancy_with_len(self): 1316 # Test to make sure the LRU cache code isn't thrown-off by 1317 # caching the built-in len() function. Since len() can be 1318 # cached, we shouldn't use it inside the lru code itself. 1319 old_len = builtins.len 1320 try: 1321 builtins.len = self.module.lru_cache(4)(len) 1322 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]: 1323 self.assertEqual(len('abcdefghijklmn'[:i]), i) 1324 finally: 1325 builtins.len = old_len 1326 1327 def test_lru_star_arg_handling(self): 1328 # Test regression that arose in ea064ff3c10f 1329 @functools.lru_cache() 1330 def f(*args): 1331 return args 1332 1333 self.assertEqual(f(1, 2), (1, 2)) 1334 self.assertEqual(f((1, 2)), ((1, 2),)) 1335 1336 def test_lru_type_error(self): 1337 # Regression test for issue #28653. 1338 # lru_cache was leaking when one of the arguments 1339 # wasn't cacheable. 1340 1341 @functools.lru_cache(maxsize=None) 1342 def infinite_cache(o): 1343 pass 1344 1345 @functools.lru_cache(maxsize=10) 1346 def limited_cache(o): 1347 pass 1348 1349 with self.assertRaises(TypeError): 1350 infinite_cache([]) 1351 1352 with self.assertRaises(TypeError): 1353 limited_cache([]) 1354 1355 def test_lru_with_maxsize_none(self): 1356 @self.module.lru_cache(maxsize=None) 1357 def fib(n): 1358 if n < 2: 1359 return n 1360 return fib(n-1) + fib(n-2) 1361 self.assertEqual([fib(n) for n in range(16)], 1362 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1363 self.assertEqual(fib.cache_info(), 1364 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1365 fib.cache_clear() 1366 self.assertEqual(fib.cache_info(), 1367 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1368 1369 def test_lru_with_maxsize_negative(self): 1370 @self.module.lru_cache(maxsize=-10) 1371 def eq(n): 1372 return n 1373 for i in (0, 1): 1374 self.assertEqual([eq(n) for n in range(150)], list(range(150))) 1375 self.assertEqual(eq.cache_info(), 1376 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0)) 1377 1378 def test_lru_with_exceptions(self): 1379 # Verify that user_function exceptions get passed through without 1380 # creating a hard-to-read chained exception. 1381 # http://bugs.python.org/issue13177 1382 for maxsize in (None, 128): 1383 @self.module.lru_cache(maxsize) 1384 def func(i): 1385 return 'abc'[i] 1386 self.assertEqual(func(0), 'a') 1387 with self.assertRaises(IndexError) as cm: 1388 func(15) 1389 self.assertIsNone(cm.exception.__context__) 1390 # Verify that the previous exception did not result in a cached entry 1391 with self.assertRaises(IndexError): 1392 func(15) 1393 1394 def test_lru_with_types(self): 1395 for maxsize in (None, 128): 1396 @self.module.lru_cache(maxsize=maxsize, typed=True) 1397 def square(x): 1398 return x * x 1399 self.assertEqual(square(3), 9) 1400 self.assertEqual(type(square(3)), type(9)) 1401 self.assertEqual(square(3.0), 9.0) 1402 self.assertEqual(type(square(3.0)), type(9.0)) 1403 self.assertEqual(square(x=3), 9) 1404 self.assertEqual(type(square(x=3)), type(9)) 1405 self.assertEqual(square(x=3.0), 9.0) 1406 self.assertEqual(type(square(x=3.0)), type(9.0)) 1407 self.assertEqual(square.cache_info().hits, 4) 1408 self.assertEqual(square.cache_info().misses, 4) 1409 1410 def test_lru_with_keyword_args(self): 1411 @self.module.lru_cache() 1412 def fib(n): 1413 if n < 2: 1414 return n 1415 return fib(n=n-1) + fib(n=n-2) 1416 self.assertEqual( 1417 [fib(n=number) for number in range(16)], 1418 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] 1419 ) 1420 self.assertEqual(fib.cache_info(), 1421 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) 1422 fib.cache_clear() 1423 self.assertEqual(fib.cache_info(), 1424 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) 1425 1426 def test_lru_with_keyword_args_maxsize_none(self): 1427 @self.module.lru_cache(maxsize=None) 1428 def fib(n): 1429 if n < 2: 1430 return n 1431 return fib(n=n-1) + fib(n=n-2) 1432 self.assertEqual([fib(n=number) for number in range(16)], 1433 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1434 self.assertEqual(fib.cache_info(), 1435 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1436 fib.cache_clear() 1437 self.assertEqual(fib.cache_info(), 1438 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1439 1440 def test_kwargs_order(self): 1441 # PEP 468: Preserving Keyword Argument Order 1442 @self.module.lru_cache(maxsize=10) 1443 def f(**kwargs): 1444 return list(kwargs.items()) 1445 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)]) 1446 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)]) 1447 self.assertEqual(f.cache_info(), 1448 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2)) 1449 1450 def test_lru_cache_decoration(self): 1451 def f(zomg: 'zomg_annotation'): 1452 """f doc string""" 1453 return 42 1454 g = self.module.lru_cache()(f) 1455 for attr in self.module.WRAPPER_ASSIGNMENTS: 1456 self.assertEqual(getattr(g, attr), getattr(f, attr)) 1457 1458 def test_lru_cache_threaded(self): 1459 n, m = 5, 11 1460 def orig(x, y): 1461 return 3 * x + y 1462 f = self.module.lru_cache(maxsize=n*m)(orig) 1463 hits, misses, maxsize, currsize = f.cache_info() 1464 self.assertEqual(currsize, 0) 1465 1466 start = threading.Event() 1467 def full(k): 1468 start.wait(10) 1469 for _ in range(m): 1470 self.assertEqual(f(k, 0), orig(k, 0)) 1471 1472 def clear(): 1473 start.wait(10) 1474 for _ in range(2*m): 1475 f.cache_clear() 1476 1477 orig_si = sys.getswitchinterval() 1478 support.setswitchinterval(1e-6) 1479 try: 1480 # create n threads in order to fill cache 1481 threads = [threading.Thread(target=full, args=[k]) 1482 for k in range(n)] 1483 with support.start_threads(threads): 1484 start.set() 1485 1486 hits, misses, maxsize, currsize = f.cache_info() 1487 if self.module is py_functools: 1488 # XXX: Why can be not equal? 1489 self.assertLessEqual(misses, n) 1490 self.assertLessEqual(hits, m*n - misses) 1491 else: 1492 self.assertEqual(misses, n) 1493 self.assertEqual(hits, m*n - misses) 1494 self.assertEqual(currsize, n) 1495 1496 # create n threads in order to fill cache and 1 to clear it 1497 threads = [threading.Thread(target=clear)] 1498 threads += [threading.Thread(target=full, args=[k]) 1499 for k in range(n)] 1500 start.clear() 1501 with support.start_threads(threads): 1502 start.set() 1503 finally: 1504 sys.setswitchinterval(orig_si) 1505 1506 def test_lru_cache_threaded2(self): 1507 # Simultaneous call with the same arguments 1508 n, m = 5, 7 1509 start = threading.Barrier(n+1) 1510 pause = threading.Barrier(n+1) 1511 stop = threading.Barrier(n+1) 1512 @self.module.lru_cache(maxsize=m*n) 1513 def f(x): 1514 pause.wait(10) 1515 return 3 * x 1516 self.assertEqual(f.cache_info(), (0, 0, m*n, 0)) 1517 def test(): 1518 for i in range(m): 1519 start.wait(10) 1520 self.assertEqual(f(i), 3 * i) 1521 stop.wait(10) 1522 threads = [threading.Thread(target=test) for k in range(n)] 1523 with support.start_threads(threads): 1524 for i in range(m): 1525 start.wait(10) 1526 stop.reset() 1527 pause.wait(10) 1528 start.reset() 1529 stop.wait(10) 1530 pause.reset() 1531 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) 1532 1533 def test_lru_cache_threaded3(self): 1534 @self.module.lru_cache(maxsize=2) 1535 def f(x): 1536 time.sleep(.01) 1537 return 3 * x 1538 def test(i, x): 1539 with self.subTest(thread=i): 1540 self.assertEqual(f(x), 3 * x, i) 1541 threads = [threading.Thread(target=test, args=(i, v)) 1542 for i, v in enumerate([1, 2, 2, 3, 2])] 1543 with support.start_threads(threads): 1544 pass 1545 1546 def test_need_for_rlock(self): 1547 # This will deadlock on an LRU cache that uses a regular lock 1548 1549 @self.module.lru_cache(maxsize=10) 1550 def test_func(x): 1551 'Used to demonstrate a reentrant lru_cache call within a single thread' 1552 return x 1553 1554 class DoubleEq: 1555 'Demonstrate a reentrant lru_cache call within a single thread' 1556 def __init__(self, x): 1557 self.x = x 1558 def __hash__(self): 1559 return self.x 1560 def __eq__(self, other): 1561 if self.x == 2: 1562 test_func(DoubleEq(1)) 1563 return self.x == other.x 1564 1565 test_func(DoubleEq(1)) # Load the cache 1566 test_func(DoubleEq(2)) # Load the cache 1567 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call 1568 DoubleEq(2)) # Verify the correct return value 1569 1570 def test_early_detection_of_bad_call(self): 1571 # Issue #22184 1572 with self.assertRaises(TypeError): 1573 @functools.lru_cache 1574 def f(): 1575 pass 1576 1577 def test_lru_method(self): 1578 class X(int): 1579 f_cnt = 0 1580 @self.module.lru_cache(2) 1581 def f(self, x): 1582 self.f_cnt += 1 1583 return x*10+self 1584 a = X(5) 1585 b = X(5) 1586 c = X(7) 1587 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0)) 1588 1589 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3: 1590 self.assertEqual(a.f(x), x*10 + 5) 1591 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0)) 1592 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2)) 1593 1594 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2: 1595 self.assertEqual(b.f(x), x*10 + 5) 1596 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0)) 1597 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2)) 1598 1599 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1: 1600 self.assertEqual(c.f(x), x*10 + 7) 1601 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) 1602 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) 1603 1604 self.assertEqual(a.f.cache_info(), X.f.cache_info()) 1605 self.assertEqual(b.f.cache_info(), X.f.cache_info()) 1606 self.assertEqual(c.f.cache_info(), X.f.cache_info()) 1607 1608 def test_pickle(self): 1609 cls = self.__class__ 1610 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth: 1611 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1612 with self.subTest(proto=proto, func=f): 1613 f_copy = pickle.loads(pickle.dumps(f, proto)) 1614 self.assertIs(f_copy, f) 1615 1616 def test_copy(self): 1617 cls = self.__class__ 1618 def orig(x, y): 1619 return 3 * x + y 1620 part = self.module.partial(orig, 2) 1621 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1622 self.module.lru_cache(2)(part)) 1623 for f in funcs: 1624 with self.subTest(func=f): 1625 f_copy = copy.copy(f) 1626 self.assertIs(f_copy, f) 1627 1628 def test_deepcopy(self): 1629 cls = self.__class__ 1630 def orig(x, y): 1631 return 3 * x + y 1632 part = self.module.partial(orig, 2) 1633 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1634 self.module.lru_cache(2)(part)) 1635 for f in funcs: 1636 with self.subTest(func=f): 1637 f_copy = copy.deepcopy(f) 1638 self.assertIs(f_copy, f) 1639 1640 1641@py_functools.lru_cache() 1642def py_cached_func(x, y): 1643 return 3 * x + y 1644 1645@c_functools.lru_cache() 1646def c_cached_func(x, y): 1647 return 3 * x + y 1648 1649 1650class TestLRUPy(TestLRU, unittest.TestCase): 1651 module = py_functools 1652 cached_func = py_cached_func, 1653 1654 @module.lru_cache() 1655 def cached_meth(self, x, y): 1656 return 3 * x + y 1657 1658 @staticmethod 1659 @module.lru_cache() 1660 def cached_staticmeth(x, y): 1661 return 3 * x + y 1662 1663 1664class TestLRUC(TestLRU, unittest.TestCase): 1665 module = c_functools 1666 cached_func = c_cached_func, 1667 1668 @module.lru_cache() 1669 def cached_meth(self, x, y): 1670 return 3 * x + y 1671 1672 @staticmethod 1673 @module.lru_cache() 1674 def cached_staticmeth(x, y): 1675 return 3 * x + y 1676 1677 1678class TestSingleDispatch(unittest.TestCase): 1679 def test_simple_overloads(self): 1680 @functools.singledispatch 1681 def g(obj): 1682 return "base" 1683 def g_int(i): 1684 return "integer" 1685 g.register(int, g_int) 1686 self.assertEqual(g("str"), "base") 1687 self.assertEqual(g(1), "integer") 1688 self.assertEqual(g([1,2,3]), "base") 1689 1690 def test_mro(self): 1691 @functools.singledispatch 1692 def g(obj): 1693 return "base" 1694 class A: 1695 pass 1696 class C(A): 1697 pass 1698 class B(A): 1699 pass 1700 class D(C, B): 1701 pass 1702 def g_A(a): 1703 return "A" 1704 def g_B(b): 1705 return "B" 1706 g.register(A, g_A) 1707 g.register(B, g_B) 1708 self.assertEqual(g(A()), "A") 1709 self.assertEqual(g(B()), "B") 1710 self.assertEqual(g(C()), "A") 1711 self.assertEqual(g(D()), "B") 1712 1713 def test_register_decorator(self): 1714 @functools.singledispatch 1715 def g(obj): 1716 return "base" 1717 @g.register(int) 1718 def g_int(i): 1719 return "int %s" % (i,) 1720 self.assertEqual(g(""), "base") 1721 self.assertEqual(g(12), "int 12") 1722 self.assertIs(g.dispatch(int), g_int) 1723 self.assertIs(g.dispatch(object), g.dispatch(str)) 1724 # Note: in the assert above this is not g. 1725 # @singledispatch returns the wrapper. 1726 1727 def test_wrapping_attributes(self): 1728 @functools.singledispatch 1729 def g(obj): 1730 "Simple test" 1731 return "Test" 1732 self.assertEqual(g.__name__, "g") 1733 if sys.flags.optimize < 2: 1734 self.assertEqual(g.__doc__, "Simple test") 1735 1736 @unittest.skipUnless(decimal, 'requires _decimal') 1737 @support.cpython_only 1738 def test_c_classes(self): 1739 @functools.singledispatch 1740 def g(obj): 1741 return "base" 1742 @g.register(decimal.DecimalException) 1743 def _(obj): 1744 return obj.args 1745 subn = decimal.Subnormal("Exponent < Emin") 1746 rnd = decimal.Rounded("Number got rounded") 1747 self.assertEqual(g(subn), ("Exponent < Emin",)) 1748 self.assertEqual(g(rnd), ("Number got rounded",)) 1749 @g.register(decimal.Subnormal) 1750 def _(obj): 1751 return "Too small to care." 1752 self.assertEqual(g(subn), "Too small to care.") 1753 self.assertEqual(g(rnd), ("Number got rounded",)) 1754 1755 def test_compose_mro(self): 1756 # None of the examples in this test depend on haystack ordering. 1757 c = collections.abc 1758 mro = functools._compose_mro 1759 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] 1760 for haystack in permutations(bases): 1761 m = mro(dict, haystack) 1762 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, 1763 c.Collection, c.Sized, c.Iterable, 1764 c.Container, object]) 1765 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict] 1766 for haystack in permutations(bases): 1767 m = mro(collections.ChainMap, haystack) 1768 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping, 1769 c.Collection, c.Sized, c.Iterable, 1770 c.Container, object]) 1771 1772 # If there's a generic function with implementations registered for 1773 # both Sized and Container, passing a defaultdict to it results in an 1774 # ambiguous dispatch which will cause a RuntimeError (see 1775 # test_mro_conflicts). 1776 bases = [c.Container, c.Sized, str] 1777 for haystack in permutations(bases): 1778 m = mro(collections.defaultdict, [c.Sized, c.Container, str]) 1779 self.assertEqual(m, [collections.defaultdict, dict, c.Sized, 1780 c.Container, object]) 1781 1782 # MutableSequence below is registered directly on D. In other words, it 1783 # precedes MutableMapping which means single dispatch will always 1784 # choose MutableSequence here. 1785 class D(collections.defaultdict): 1786 pass 1787 c.MutableSequence.register(D) 1788 bases = [c.MutableSequence, c.MutableMapping] 1789 for haystack in permutations(bases): 1790 m = mro(D, bases) 1791 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, 1792 collections.defaultdict, dict, c.MutableMapping, c.Mapping, 1793 c.Collection, c.Sized, c.Iterable, c.Container, 1794 object]) 1795 1796 # Container and Callable are registered on different base classes and 1797 # a generic function supporting both should always pick the Callable 1798 # implementation if a C instance is passed. 1799 class C(collections.defaultdict): 1800 def __call__(self): 1801 pass 1802 bases = [c.Sized, c.Callable, c.Container, c.Mapping] 1803 for haystack in permutations(bases): 1804 m = mro(C, haystack) 1805 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping, 1806 c.Collection, c.Sized, c.Iterable, 1807 c.Container, object]) 1808 1809 def test_register_abc(self): 1810 c = collections.abc 1811 d = {"a": "b"} 1812 l = [1, 2, 3] 1813 s = {object(), None} 1814 f = frozenset(s) 1815 t = (1, 2, 3) 1816 @functools.singledispatch 1817 def g(obj): 1818 return "base" 1819 self.assertEqual(g(d), "base") 1820 self.assertEqual(g(l), "base") 1821 self.assertEqual(g(s), "base") 1822 self.assertEqual(g(f), "base") 1823 self.assertEqual(g(t), "base") 1824 g.register(c.Sized, lambda obj: "sized") 1825 self.assertEqual(g(d), "sized") 1826 self.assertEqual(g(l), "sized") 1827 self.assertEqual(g(s), "sized") 1828 self.assertEqual(g(f), "sized") 1829 self.assertEqual(g(t), "sized") 1830 g.register(c.MutableMapping, lambda obj: "mutablemapping") 1831 self.assertEqual(g(d), "mutablemapping") 1832 self.assertEqual(g(l), "sized") 1833 self.assertEqual(g(s), "sized") 1834 self.assertEqual(g(f), "sized") 1835 self.assertEqual(g(t), "sized") 1836 g.register(collections.ChainMap, lambda obj: "chainmap") 1837 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered 1838 self.assertEqual(g(l), "sized") 1839 self.assertEqual(g(s), "sized") 1840 self.assertEqual(g(f), "sized") 1841 self.assertEqual(g(t), "sized") 1842 g.register(c.MutableSequence, lambda obj: "mutablesequence") 1843 self.assertEqual(g(d), "mutablemapping") 1844 self.assertEqual(g(l), "mutablesequence") 1845 self.assertEqual(g(s), "sized") 1846 self.assertEqual(g(f), "sized") 1847 self.assertEqual(g(t), "sized") 1848 g.register(c.MutableSet, lambda obj: "mutableset") 1849 self.assertEqual(g(d), "mutablemapping") 1850 self.assertEqual(g(l), "mutablesequence") 1851 self.assertEqual(g(s), "mutableset") 1852 self.assertEqual(g(f), "sized") 1853 self.assertEqual(g(t), "sized") 1854 g.register(c.Mapping, lambda obj: "mapping") 1855 self.assertEqual(g(d), "mutablemapping") # not specific enough 1856 self.assertEqual(g(l), "mutablesequence") 1857 self.assertEqual(g(s), "mutableset") 1858 self.assertEqual(g(f), "sized") 1859 self.assertEqual(g(t), "sized") 1860 g.register(c.Sequence, lambda obj: "sequence") 1861 self.assertEqual(g(d), "mutablemapping") 1862 self.assertEqual(g(l), "mutablesequence") 1863 self.assertEqual(g(s), "mutableset") 1864 self.assertEqual(g(f), "sized") 1865 self.assertEqual(g(t), "sequence") 1866 g.register(c.Set, lambda obj: "set") 1867 self.assertEqual(g(d), "mutablemapping") 1868 self.assertEqual(g(l), "mutablesequence") 1869 self.assertEqual(g(s), "mutableset") 1870 self.assertEqual(g(f), "set") 1871 self.assertEqual(g(t), "sequence") 1872 g.register(dict, lambda obj: "dict") 1873 self.assertEqual(g(d), "dict") 1874 self.assertEqual(g(l), "mutablesequence") 1875 self.assertEqual(g(s), "mutableset") 1876 self.assertEqual(g(f), "set") 1877 self.assertEqual(g(t), "sequence") 1878 g.register(list, lambda obj: "list") 1879 self.assertEqual(g(d), "dict") 1880 self.assertEqual(g(l), "list") 1881 self.assertEqual(g(s), "mutableset") 1882 self.assertEqual(g(f), "set") 1883 self.assertEqual(g(t), "sequence") 1884 g.register(set, lambda obj: "concrete-set") 1885 self.assertEqual(g(d), "dict") 1886 self.assertEqual(g(l), "list") 1887 self.assertEqual(g(s), "concrete-set") 1888 self.assertEqual(g(f), "set") 1889 self.assertEqual(g(t), "sequence") 1890 g.register(frozenset, lambda obj: "frozen-set") 1891 self.assertEqual(g(d), "dict") 1892 self.assertEqual(g(l), "list") 1893 self.assertEqual(g(s), "concrete-set") 1894 self.assertEqual(g(f), "frozen-set") 1895 self.assertEqual(g(t), "sequence") 1896 g.register(tuple, lambda obj: "tuple") 1897 self.assertEqual(g(d), "dict") 1898 self.assertEqual(g(l), "list") 1899 self.assertEqual(g(s), "concrete-set") 1900 self.assertEqual(g(f), "frozen-set") 1901 self.assertEqual(g(t), "tuple") 1902 1903 def test_c3_abc(self): 1904 c = collections.abc 1905 mro = functools._c3_mro 1906 class A(object): 1907 pass 1908 class B(A): 1909 def __len__(self): 1910 return 0 # implies Sized 1911 @c.Container.register 1912 class C(object): 1913 pass 1914 class D(object): 1915 pass # unrelated 1916 class X(D, C, B): 1917 def __call__(self): 1918 pass # implies Callable 1919 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] 1920 for abcs in permutations([c.Sized, c.Callable, c.Container]): 1921 self.assertEqual(mro(X, abcs=abcs), expected) 1922 # unrelated ABCs don't appear in the resulting MRO 1923 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] 1924 self.assertEqual(mro(X, abcs=many_abcs), expected) 1925 1926 def test_false_meta(self): 1927 # see issue23572 1928 class MetaA(type): 1929 def __len__(self): 1930 return 0 1931 class A(metaclass=MetaA): 1932 pass 1933 class AA(A): 1934 pass 1935 @functools.singledispatch 1936 def fun(a): 1937 return 'base A' 1938 @fun.register(A) 1939 def _(a): 1940 return 'fun A' 1941 aa = AA() 1942 self.assertEqual(fun(aa), 'fun A') 1943 1944 def test_mro_conflicts(self): 1945 c = collections.abc 1946 @functools.singledispatch 1947 def g(arg): 1948 return "base" 1949 class O(c.Sized): 1950 def __len__(self): 1951 return 0 1952 o = O() 1953 self.assertEqual(g(o), "base") 1954 g.register(c.Iterable, lambda arg: "iterable") 1955 g.register(c.Container, lambda arg: "container") 1956 g.register(c.Sized, lambda arg: "sized") 1957 g.register(c.Set, lambda arg: "set") 1958 self.assertEqual(g(o), "sized") 1959 c.Iterable.register(O) 1960 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ 1961 c.Container.register(O) 1962 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ 1963 c.Set.register(O) 1964 self.assertEqual(g(o), "set") # because c.Set is a subclass of 1965 # c.Sized and c.Container 1966 class P: 1967 pass 1968 p = P() 1969 self.assertEqual(g(p), "base") 1970 c.Iterable.register(P) 1971 self.assertEqual(g(p), "iterable") 1972 c.Container.register(P) 1973 with self.assertRaises(RuntimeError) as re_one: 1974 g(p) 1975 self.assertIn( 1976 str(re_one.exception), 1977 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 1978 "or <class 'collections.abc.Iterable'>"), 1979 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " 1980 "or <class 'collections.abc.Container'>")), 1981 ) 1982 class Q(c.Sized): 1983 def __len__(self): 1984 return 0 1985 q = Q() 1986 self.assertEqual(g(q), "sized") 1987 c.Iterable.register(Q) 1988 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ 1989 c.Set.register(Q) 1990 self.assertEqual(g(q), "set") # because c.Set is a subclass of 1991 # c.Sized and c.Iterable 1992 @functools.singledispatch 1993 def h(arg): 1994 return "base" 1995 @h.register(c.Sized) 1996 def _(arg): 1997 return "sized" 1998 @h.register(c.Container) 1999 def _(arg): 2000 return "container" 2001 # Even though Sized and Container are explicit bases of MutableMapping, 2002 # this ABC is implicitly registered on defaultdict which makes all of 2003 # MutableMapping's bases implicit as well from defaultdict's 2004 # perspective. 2005 with self.assertRaises(RuntimeError) as re_two: 2006 h(collections.defaultdict(lambda: 0)) 2007 self.assertIn( 2008 str(re_two.exception), 2009 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2010 "or <class 'collections.abc.Sized'>"), 2011 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2012 "or <class 'collections.abc.Container'>")), 2013 ) 2014 class R(collections.defaultdict): 2015 pass 2016 c.MutableSequence.register(R) 2017 @functools.singledispatch 2018 def i(arg): 2019 return "base" 2020 @i.register(c.MutableMapping) 2021 def _(arg): 2022 return "mapping" 2023 @i.register(c.MutableSequence) 2024 def _(arg): 2025 return "sequence" 2026 r = R() 2027 self.assertEqual(i(r), "sequence") 2028 class S: 2029 pass 2030 class T(S, c.Sized): 2031 def __len__(self): 2032 return 0 2033 t = T() 2034 self.assertEqual(h(t), "sized") 2035 c.Container.register(T) 2036 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO 2037 class U: 2038 def __len__(self): 2039 return 0 2040 u = U() 2041 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred 2042 # from the existence of __len__() 2043 c.Container.register(U) 2044 # There is no preference for registered versus inferred ABCs. 2045 with self.assertRaises(RuntimeError) as re_three: 2046 h(u) 2047 self.assertIn( 2048 str(re_three.exception), 2049 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2050 "or <class 'collections.abc.Sized'>"), 2051 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2052 "or <class 'collections.abc.Container'>")), 2053 ) 2054 class V(c.Sized, S): 2055 def __len__(self): 2056 return 0 2057 @functools.singledispatch 2058 def j(arg): 2059 return "base" 2060 @j.register(S) 2061 def _(arg): 2062 return "s" 2063 @j.register(c.Container) 2064 def _(arg): 2065 return "container" 2066 v = V() 2067 self.assertEqual(j(v), "s") 2068 c.Container.register(V) 2069 self.assertEqual(j(v), "container") # because it ends up right after 2070 # Sized in the MRO 2071 2072 def test_cache_invalidation(self): 2073 from collections import UserDict 2074 import weakref 2075 2076 class TracingDict(UserDict): 2077 def __init__(self, *args, **kwargs): 2078 super(TracingDict, self).__init__(*args, **kwargs) 2079 self.set_ops = [] 2080 self.get_ops = [] 2081 def __getitem__(self, key): 2082 result = self.data[key] 2083 self.get_ops.append(key) 2084 return result 2085 def __setitem__(self, key, value): 2086 self.set_ops.append(key) 2087 self.data[key] = value 2088 def clear(self): 2089 self.data.clear() 2090 2091 td = TracingDict() 2092 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td): 2093 c = collections.abc 2094 @functools.singledispatch 2095 def g(arg): 2096 return "base" 2097 d = {} 2098 l = [] 2099 self.assertEqual(len(td), 0) 2100 self.assertEqual(g(d), "base") 2101 self.assertEqual(len(td), 1) 2102 self.assertEqual(td.get_ops, []) 2103 self.assertEqual(td.set_ops, [dict]) 2104 self.assertEqual(td.data[dict], g.registry[object]) 2105 self.assertEqual(g(l), "base") 2106 self.assertEqual(len(td), 2) 2107 self.assertEqual(td.get_ops, []) 2108 self.assertEqual(td.set_ops, [dict, list]) 2109 self.assertEqual(td.data[dict], g.registry[object]) 2110 self.assertEqual(td.data[list], g.registry[object]) 2111 self.assertEqual(td.data[dict], td.data[list]) 2112 self.assertEqual(g(l), "base") 2113 self.assertEqual(g(d), "base") 2114 self.assertEqual(td.get_ops, [list, dict]) 2115 self.assertEqual(td.set_ops, [dict, list]) 2116 g.register(list, lambda arg: "list") 2117 self.assertEqual(td.get_ops, [list, dict]) 2118 self.assertEqual(len(td), 0) 2119 self.assertEqual(g(d), "base") 2120 self.assertEqual(len(td), 1) 2121 self.assertEqual(td.get_ops, [list, dict]) 2122 self.assertEqual(td.set_ops, [dict, list, dict]) 2123 self.assertEqual(td.data[dict], 2124 functools._find_impl(dict, g.registry)) 2125 self.assertEqual(g(l), "list") 2126 self.assertEqual(len(td), 2) 2127 self.assertEqual(td.get_ops, [list, dict]) 2128 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2129 self.assertEqual(td.data[list], 2130 functools._find_impl(list, g.registry)) 2131 class X: 2132 pass 2133 c.MutableMapping.register(X) # Will not invalidate the cache, 2134 # not using ABCs yet. 2135 self.assertEqual(g(d), "base") 2136 self.assertEqual(g(l), "list") 2137 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2138 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2139 g.register(c.Sized, lambda arg: "sized") 2140 self.assertEqual(len(td), 0) 2141 self.assertEqual(g(d), "sized") 2142 self.assertEqual(len(td), 1) 2143 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2144 self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) 2145 self.assertEqual(g(l), "list") 2146 self.assertEqual(len(td), 2) 2147 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2148 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2149 self.assertEqual(g(l), "list") 2150 self.assertEqual(g(d), "sized") 2151 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) 2152 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2153 g.dispatch(list) 2154 g.dispatch(dict) 2155 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, 2156 list, dict]) 2157 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2158 c.MutableSet.register(X) # Will invalidate the cache. 2159 self.assertEqual(len(td), 2) # Stale cache. 2160 self.assertEqual(g(l), "list") 2161 self.assertEqual(len(td), 1) 2162 g.register(c.MutableMapping, lambda arg: "mutablemapping") 2163 self.assertEqual(len(td), 0) 2164 self.assertEqual(g(d), "mutablemapping") 2165 self.assertEqual(len(td), 1) 2166 self.assertEqual(g(l), "list") 2167 self.assertEqual(len(td), 2) 2168 g.register(dict, lambda arg: "dict") 2169 self.assertEqual(g(d), "dict") 2170 self.assertEqual(g(l), "list") 2171 g._clear_cache() 2172 self.assertEqual(len(td), 0) 2173 2174 def test_annotations(self): 2175 @functools.singledispatch 2176 def i(arg): 2177 return "base" 2178 @i.register 2179 def _(arg: collections.abc.Mapping): 2180 return "mapping" 2181 @i.register 2182 def _(arg: "collections.abc.Sequence"): 2183 return "sequence" 2184 self.assertEqual(i(None), "base") 2185 self.assertEqual(i({"a": 1}), "mapping") 2186 self.assertEqual(i([1, 2, 3]), "sequence") 2187 self.assertEqual(i((1, 2, 3)), "sequence") 2188 self.assertEqual(i("str"), "sequence") 2189 2190 # Registering classes as callables doesn't work with annotations, 2191 # you need to pass the type explicitly. 2192 @i.register(str) 2193 class _: 2194 def __init__(self, arg): 2195 self.arg = arg 2196 2197 def __eq__(self, other): 2198 return self.arg == other 2199 self.assertEqual(i("str"), "str") 2200 2201 def test_invalid_registrations(self): 2202 msg_prefix = "Invalid first argument to `register()`: " 2203 msg_suffix = ( 2204 ". Use either `@register(some_class)` or plain `@register` on an " 2205 "annotated function." 2206 ) 2207 @functools.singledispatch 2208 def i(arg): 2209 return "base" 2210 with self.assertRaises(TypeError) as exc: 2211 @i.register(42) 2212 def _(arg): 2213 return "I annotated with a non-type" 2214 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42")) 2215 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2216 with self.assertRaises(TypeError) as exc: 2217 @i.register 2218 def _(arg): 2219 return "I forgot to annotate" 2220 self.assertTrue(str(exc.exception).startswith(msg_prefix + 2221 "<function TestSingleDispatch.test_invalid_registrations.<locals>._" 2222 )) 2223 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2224 2225 # FIXME: The following will only work after PEP 560 is implemented. 2226 return 2227 2228 with self.assertRaises(TypeError) as exc: 2229 @i.register 2230 def _(arg: typing.Iterable[str]): 2231 # At runtime, dispatching on generics is impossible. 2232 # When registering implementations with singledispatch, avoid 2233 # types from `typing`. Instead, annotate with regular types 2234 # or ABCs. 2235 return "I annotated with a generic collection" 2236 self.assertTrue(str(exc.exception).startswith(msg_prefix + 2237 "<function TestSingleDispatch.test_invalid_registrations.<locals>._" 2238 )) 2239 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2240 2241 def test_invalid_positional_argument(self): 2242 @functools.singledispatch 2243 def f(*args): 2244 pass 2245 msg = 'f requires at least 1 positional argument' 2246 with self.assertRaisesRegex(TypeError, msg): 2247 f() 2248 2249if __name__ == '__main__': 2250 unittest.main() 2251