"""Tests for functools, run against both the C accelerator (_functools)
and the pure-Python implementation where both are available."""

import abc
import builtins
import collections
import collections.abc
import copy
from itertools import permutations
import pickle
from random import choice
import sys
from test import support
import threading
import time
import typing
import unittest
import unittest.mock
from weakref import proxy
import contextlib

import functools

# Pure-Python functools (the _functools accelerator is blocked) and the
# C-accelerated variant (falsy when _functools is unavailable; C-specific
# test classes below are skipped in that case).
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

decimal = support.import_fresh_module('decimal', fresh=['_decimal'])

@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily bind sys.modules[name] to *replacement*, restoring it on exit."""
    original_module = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        sys.modules[name] = original_module

def capture(*args, **kw):
    """Return the positional and keyword arguments it was called with."""
    return args, kw


def signature(part):
    """Return (func, args, keywords, __dict__) of a partial object."""
    return (part.func, part.args, part.keywords, part.__dict__)

# Container subclasses used by test_setstate_subclasses to check that
# __setstate__ normalises its args/keywords to exact tuple/dict types.
class MyTuple(tuple):
    pass

class BadTuple(tuple):
    # Deliberately broken concatenation: returns a list, not a tuple.
    def __add__(self, other):
        return list(self) + list(other)

class MyDict(dict):
    pass


class TestPartial:
    """Tests shared by every partial implementation.

    Concrete subclasses mix this into unittest.TestCase and provide the
    ``partial`` callable under test plus an ``AllowPickle`` context manager
    that the pickling tests run inside.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        support.gc_collect()  # For PyPy or other non-refcounting GCs.
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned, so accept either order.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case makes the partial refer to itself; the finally clause
        # breaks the cycle again.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None keywords entry is normalised to an empty dict.
        f.__setstate__((capture, (1,), None, None))
        self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence (not a tuple) whose items would form a valid state;
        # __setstate__ must reject it with TypeError.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())

@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # The C partial needs no special setup to be pickled: no-op manager.
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)


class TestPartialPy(TestPartial, unittest.TestCase):
    partial = py_functools.partial

    class AllowPickle:
        # Pickle looks classes up by module name, so while pickling make
        # sys.modules['functools'] the pure-Python module; otherwise
        # functools.partial would resolve to a different class.
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)
        def __enter__(self):
            return self._cm.__enter__()
        def __exit__(self, type, value, tb):
            return self._cm.__exit__(type, value, tb)

if c_functools:
    class CPartialSubclass(c_functools.partial):
        pass

class PyPartialSubclass(py_functools.partial):
    pass

@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None

class TestPartialPySubclass(TestPartialPy):
    partial = PyPartialSubclass

class TestPartialMethod(unittest.TestCase):

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertWarns(DeprecationWarning):
            class B:
                method = functools.partialmethod(func=capture, a=1)
        b = B()
        self.assertEqual(b.method(2, x=3), ((b, 2), {'a': 1, 'x': 3}))

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Use ABCMeta as the metaclass (not as a base class) so Abstract is
        # an ordinary abstract class, as the test intends.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))


class TestUpdateWrapper(unittest.TestCase):

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})


class TestWraps(TestUpdateWrapper):

    def _default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)


class TestReduce:
    """reduce() tests shared by the C and Python implementations.

    Concrete subclasses provide ``reduce`` as a class attribute.
    """

    def test_reduce(self):
        # Sequence-protocol object yielding 0, 1, 4, ... up to max items.
        class Squares:
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max: raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]
        def add(x, y):
            return x + y
        self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            self.reduce(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(self.reduce(add, Squares(10)), 285)
        self.assertEqual(self.reduce(add, Squares(10), 0), 285)
        self.assertEqual(self.reduce(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.reduce)
        self.assertRaises(TypeError, self.reduce, 42, 42)
        self.assertRaises(TypeError, self.reduce, 42, 42, 42)
        self.assertEqual(self.reduce(42, "1"), "1")  # func is never called with one item
        self.assertEqual(self.reduce(42, "", "1"), "1")  # func is never called with one item
        self.assertRaises(TypeError, self.reduce, 42, (42, 42))
        self.assertRaises(TypeError, self.reduce, add, [])  # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, self.reduce, add, "")
        self.assertRaises(TypeError, self.reduce, add, ())
        self.assertRaises(TypeError, self.reduce, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())

        self.assertEqual(self.reduce(add, [], None), None)
        self.assertEqual(self.reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
        self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
        self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
        self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.reduce(add, d), "".join(d.keys()))


@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    if c_functools:
        reduce = c_functools.reduce


class TestReducePy(TestReduce, unittest.TestCase):
    reduce = staticmethod(py_functools.reduce)


class TestCmpToKey:
    """cmp_to_key() tests shared by the C and Python implementations.

    Concrete subclasses provide ``cmp_to_key`` as a class attribute.
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1  # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)  # lhs is not a K object
        with self.assertRaises(TypeError):
            key = self.cmp_to_key()  # too few args
        with self.assertRaises(TypeError):
            key = self.cmp_to_key(cmp1, None)  # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()  # too few args
        with self.assertRaises(TypeError):
            key(None, None)  # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # *result* against 0 also propagates.  The key must be rebuilt from
        # cmp2; reusing the old key would only re-test cmp1.  "<" is used
        # so that "BadCmp() < 0" reaches BadCmp.__lt__.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, collections.abc.Hashable)


@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key


class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    cmp_to_key = staticmethod(py_functools.cmp_to_key)


class TestTotalOrdering(unittest.TestCase):

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return
False 1075 def __le__(self, other): 1076 if isinstance(other, ImplementsLessThanEqualTo): 1077 return self.value <= other.value 1078 return NotImplemented 1079 1080 @functools.total_ordering 1081 class ImplementsGreaterThanEqualTo: 1082 def __init__(self, value): 1083 self.value = value 1084 def __eq__(self, other): 1085 if isinstance(other, ImplementsGreaterThanEqualTo): 1086 return self.value == other.value 1087 return False 1088 def __ge__(self, other): 1089 if isinstance(other, ImplementsGreaterThanEqualTo): 1090 return self.value >= other.value 1091 return NotImplemented 1092 1093 @functools.total_ordering 1094 class ComparatorNotImplemented: 1095 def __init__(self, value): 1096 self.value = value 1097 def __eq__(self, other): 1098 if isinstance(other, ComparatorNotImplemented): 1099 return self.value == other.value 1100 return False 1101 def __lt__(self, other): 1102 return NotImplemented 1103 1104 with self.subTest("LT < 1"), self.assertRaises(TypeError): 1105 ImplementsLessThan(-1) < 1 1106 1107 with self.subTest("LT < LE"), self.assertRaises(TypeError): 1108 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) 1109 1110 with self.subTest("LT < GT"), self.assertRaises(TypeError): 1111 ImplementsLessThan(1) < ImplementsGreaterThan(1) 1112 1113 with self.subTest("LE <= LT"), self.assertRaises(TypeError): 1114 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) 1115 1116 with self.subTest("LE <= GE"), self.assertRaises(TypeError): 1117 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) 1118 1119 with self.subTest("GT > GE"), self.assertRaises(TypeError): 1120 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) 1121 1122 with self.subTest("GT > LT"), self.assertRaises(TypeError): 1123 ImplementsGreaterThan(5) > ImplementsLessThan(5) 1124 1125 with self.subTest("GE >= GT"), self.assertRaises(TypeError): 1126 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) 1127 1128 with self.subTest("GE >= LE"), self.assertRaises(TypeError): 
1129 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) 1130 1131 with self.subTest("GE when equal"): 1132 a = ComparatorNotImplemented(8) 1133 b = ComparatorNotImplemented(8) 1134 self.assertEqual(a, b) 1135 with self.assertRaises(TypeError): 1136 a >= b 1137 1138 with self.subTest("LE when equal"): 1139 a = ComparatorNotImplemented(9) 1140 b = ComparatorNotImplemented(9) 1141 self.assertEqual(a, b) 1142 with self.assertRaises(TypeError): 1143 a <= b 1144 1145 def test_pickle(self): 1146 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1147 for name in '__lt__', '__gt__', '__le__', '__ge__': 1148 with self.subTest(method=name, proto=proto): 1149 method = getattr(Orderable_LT, name) 1150 method_copy = pickle.loads(pickle.dumps(method, proto)) 1151 self.assertIs(method_copy, method) 1152 1153@functools.total_ordering 1154class Orderable_LT: 1155 def __init__(self, value): 1156 self.value = value 1157 def __lt__(self, other): 1158 return self.value < other.value 1159 def __eq__(self, other): 1160 return self.value == other.value 1161 1162 1163class TestLRU: 1164 1165 def test_lru(self): 1166 def orig(x, y): 1167 return 3 * x + y 1168 f = self.module.lru_cache(maxsize=20)(orig) 1169 hits, misses, maxsize, currsize = f.cache_info() 1170 self.assertEqual(maxsize, 20) 1171 self.assertEqual(currsize, 0) 1172 self.assertEqual(hits, 0) 1173 self.assertEqual(misses, 0) 1174 1175 domain = range(5) 1176 for i in range(1000): 1177 x, y = choice(domain), choice(domain) 1178 actual = f(x, y) 1179 expected = orig(x, y) 1180 self.assertEqual(actual, expected) 1181 hits, misses, maxsize, currsize = f.cache_info() 1182 self.assertTrue(hits > misses) 1183 self.assertEqual(hits + misses, 1000) 1184 self.assertEqual(currsize, 20) 1185 1186 f.cache_clear() # test clearing 1187 hits, misses, maxsize, currsize = f.cache_info() 1188 self.assertEqual(hits, 0) 1189 self.assertEqual(misses, 0) 1190 self.assertEqual(currsize, 0) 1191 f(x, y) 1192 hits, misses, maxsize, currsize = 
f.cache_info() 1193 self.assertEqual(hits, 0) 1194 self.assertEqual(misses, 1) 1195 self.assertEqual(currsize, 1) 1196 1197 # Test bypassing the cache 1198 self.assertIs(f.__wrapped__, orig) 1199 f.__wrapped__(x, y) 1200 hits, misses, maxsize, currsize = f.cache_info() 1201 self.assertEqual(hits, 0) 1202 self.assertEqual(misses, 1) 1203 self.assertEqual(currsize, 1) 1204 1205 # test size zero (which means "never-cache") 1206 @self.module.lru_cache(0) 1207 def f(): 1208 nonlocal f_cnt 1209 f_cnt += 1 1210 return 20 1211 self.assertEqual(f.cache_info().maxsize, 0) 1212 f_cnt = 0 1213 for i in range(5): 1214 self.assertEqual(f(), 20) 1215 self.assertEqual(f_cnt, 5) 1216 hits, misses, maxsize, currsize = f.cache_info() 1217 self.assertEqual(hits, 0) 1218 self.assertEqual(misses, 5) 1219 self.assertEqual(currsize, 0) 1220 1221 # test size one 1222 @self.module.lru_cache(1) 1223 def f(): 1224 nonlocal f_cnt 1225 f_cnt += 1 1226 return 20 1227 self.assertEqual(f.cache_info().maxsize, 1) 1228 f_cnt = 0 1229 for i in range(5): 1230 self.assertEqual(f(), 20) 1231 self.assertEqual(f_cnt, 1) 1232 hits, misses, maxsize, currsize = f.cache_info() 1233 self.assertEqual(hits, 4) 1234 self.assertEqual(misses, 1) 1235 self.assertEqual(currsize, 1) 1236 1237 # test size two 1238 @self.module.lru_cache(2) 1239 def f(x): 1240 nonlocal f_cnt 1241 f_cnt += 1 1242 return x*10 1243 self.assertEqual(f.cache_info().maxsize, 2) 1244 f_cnt = 0 1245 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: 1246 # * * * * 1247 self.assertEqual(f(x), x*10) 1248 self.assertEqual(f_cnt, 4) 1249 hits, misses, maxsize, currsize = f.cache_info() 1250 self.assertEqual(hits, 12) 1251 self.assertEqual(misses, 4) 1252 self.assertEqual(currsize, 2) 1253 1254 def test_lru_no_args(self): 1255 @self.module.lru_cache 1256 def square(x): 1257 return x ** 2 1258 1259 self.assertEqual(list(map(square, [10, 20, 10])), 1260 [100, 400, 100]) 1261 self.assertEqual(square.cache_info().hits, 1) 1262 
self.assertEqual(square.cache_info().misses, 2) 1263 self.assertEqual(square.cache_info().maxsize, 128) 1264 self.assertEqual(square.cache_info().currsize, 2) 1265 1266 def test_lru_bug_35780(self): 1267 # C version of the lru_cache was not checking to see if 1268 # the user function call has already modified the cache 1269 # (this arises in recursive calls and in multi-threading). 1270 # This cause the cache to have orphan links not referenced 1271 # by the cache dictionary. 1272 1273 once = True # Modified by f(x) below 1274 1275 @self.module.lru_cache(maxsize=10) 1276 def f(x): 1277 nonlocal once 1278 rv = f'.{x}.' 1279 if x == 20 and once: 1280 once = False 1281 rv = f(x) 1282 return rv 1283 1284 # Fill the cache 1285 for x in range(15): 1286 self.assertEqual(f(x), f'.{x}.') 1287 self.assertEqual(f.cache_info().currsize, 10) 1288 1289 # Make a recursive call and make sure the cache remains full 1290 self.assertEqual(f(20), '.20.') 1291 self.assertEqual(f.cache_info().currsize, 10) 1292 1293 def test_lru_bug_36650(self): 1294 # C version of lru_cache was treating a call with an empty **kwargs 1295 # dictionary as being distinct from a call with no keywords at all. 1296 # This did not result in an incorrect answer, but it did trigger 1297 # an unexpected cache miss. 1298 1299 @self.module.lru_cache() 1300 def f(x): 1301 pass 1302 1303 f(0) 1304 f(0, **{}) 1305 self.assertEqual(f.cache_info().hits, 1) 1306 1307 def test_lru_hash_only_once(self): 1308 # To protect against weird reentrancy bugs and to improve 1309 # efficiency when faced with slow __hash__ methods, the 1310 # LRU cache guarantees that it will only call __hash__ 1311 # only once per use as an argument to the cached function. 
1312 1313 @self.module.lru_cache(maxsize=1) 1314 def f(x, y): 1315 return x * 3 + y 1316 1317 # Simulate the integer 5 1318 mock_int = unittest.mock.Mock() 1319 mock_int.__mul__ = unittest.mock.Mock(return_value=15) 1320 mock_int.__hash__ = unittest.mock.Mock(return_value=999) 1321 1322 # Add to cache: One use as an argument gives one call 1323 self.assertEqual(f(mock_int, 1), 16) 1324 self.assertEqual(mock_int.__hash__.call_count, 1) 1325 self.assertEqual(f.cache_info(), (0, 1, 1, 1)) 1326 1327 # Cache hit: One use as an argument gives one additional call 1328 self.assertEqual(f(mock_int, 1), 16) 1329 self.assertEqual(mock_int.__hash__.call_count, 2) 1330 self.assertEqual(f.cache_info(), (1, 1, 1, 1)) 1331 1332 # Cache eviction: No use as an argument gives no additional call 1333 self.assertEqual(f(6, 2), 20) 1334 self.assertEqual(mock_int.__hash__.call_count, 2) 1335 self.assertEqual(f.cache_info(), (1, 2, 1, 1)) 1336 1337 # Cache miss: One use as an argument gives one additional call 1338 self.assertEqual(f(mock_int, 1), 16) 1339 self.assertEqual(mock_int.__hash__.call_count, 3) 1340 self.assertEqual(f.cache_info(), (1, 3, 1, 1)) 1341 1342 def test_lru_reentrancy_with_len(self): 1343 # Test to make sure the LRU cache code isn't thrown-off by 1344 # caching the built-in len() function. Since len() can be 1345 # cached, we shouldn't use it inside the lru code itself. 1346 old_len = builtins.len 1347 try: 1348 builtins.len = self.module.lru_cache(4)(len) 1349 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]: 1350 self.assertEqual(len('abcdefghijklmn'[:i]), i) 1351 finally: 1352 builtins.len = old_len 1353 1354 def test_lru_star_arg_handling(self): 1355 # Test regression that arose in ea064ff3c10f 1356 @functools.lru_cache() 1357 def f(*args): 1358 return args 1359 1360 self.assertEqual(f(1, 2), (1, 2)) 1361 self.assertEqual(f((1, 2)), ((1, 2),)) 1362 1363 def test_lru_type_error(self): 1364 # Regression test for issue #28653. 
1365 # lru_cache was leaking when one of the arguments 1366 # wasn't cacheable. 1367 1368 @functools.lru_cache(maxsize=None) 1369 def infinite_cache(o): 1370 pass 1371 1372 @functools.lru_cache(maxsize=10) 1373 def limited_cache(o): 1374 pass 1375 1376 with self.assertRaises(TypeError): 1377 infinite_cache([]) 1378 1379 with self.assertRaises(TypeError): 1380 limited_cache([]) 1381 1382 def test_lru_with_maxsize_none(self): 1383 @self.module.lru_cache(maxsize=None) 1384 def fib(n): 1385 if n < 2: 1386 return n 1387 return fib(n-1) + fib(n-2) 1388 self.assertEqual([fib(n) for n in range(16)], 1389 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1390 self.assertEqual(fib.cache_info(), 1391 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1392 fib.cache_clear() 1393 self.assertEqual(fib.cache_info(), 1394 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1395 1396 def test_lru_with_maxsize_negative(self): 1397 @self.module.lru_cache(maxsize=-10) 1398 def eq(n): 1399 return n 1400 for i in (0, 1): 1401 self.assertEqual([eq(n) for n in range(150)], list(range(150))) 1402 self.assertEqual(eq.cache_info(), 1403 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0)) 1404 1405 def test_lru_with_exceptions(self): 1406 # Verify that user_function exceptions get passed through without 1407 # creating a hard-to-read chained exception. 
1408 # http://bugs.python.org/issue13177 1409 for maxsize in (None, 128): 1410 @self.module.lru_cache(maxsize) 1411 def func(i): 1412 return 'abc'[i] 1413 self.assertEqual(func(0), 'a') 1414 with self.assertRaises(IndexError) as cm: 1415 func(15) 1416 self.assertIsNone(cm.exception.__context__) 1417 # Verify that the previous exception did not result in a cached entry 1418 with self.assertRaises(IndexError): 1419 func(15) 1420 1421 def test_lru_with_types(self): 1422 for maxsize in (None, 128): 1423 @self.module.lru_cache(maxsize=maxsize, typed=True) 1424 def square(x): 1425 return x * x 1426 self.assertEqual(square(3), 9) 1427 self.assertEqual(type(square(3)), type(9)) 1428 self.assertEqual(square(3.0), 9.0) 1429 self.assertEqual(type(square(3.0)), type(9.0)) 1430 self.assertEqual(square(x=3), 9) 1431 self.assertEqual(type(square(x=3)), type(9)) 1432 self.assertEqual(square(x=3.0), 9.0) 1433 self.assertEqual(type(square(x=3.0)), type(9.0)) 1434 self.assertEqual(square.cache_info().hits, 4) 1435 self.assertEqual(square.cache_info().misses, 4) 1436 1437 def test_lru_with_keyword_args(self): 1438 @self.module.lru_cache() 1439 def fib(n): 1440 if n < 2: 1441 return n 1442 return fib(n=n-1) + fib(n=n-2) 1443 self.assertEqual( 1444 [fib(n=number) for number in range(16)], 1445 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] 1446 ) 1447 self.assertEqual(fib.cache_info(), 1448 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) 1449 fib.cache_clear() 1450 self.assertEqual(fib.cache_info(), 1451 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) 1452 1453 def test_lru_with_keyword_args_maxsize_none(self): 1454 @self.module.lru_cache(maxsize=None) 1455 def fib(n): 1456 if n < 2: 1457 return n 1458 return fib(n=n-1) + fib(n=n-2) 1459 self.assertEqual([fib(n=number) for number in range(16)], 1460 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1461 self.assertEqual(fib.cache_info(), 1462 
self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1463 fib.cache_clear() 1464 self.assertEqual(fib.cache_info(), 1465 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1466 1467 def test_kwargs_order(self): 1468 # PEP 468: Preserving Keyword Argument Order 1469 @self.module.lru_cache(maxsize=10) 1470 def f(**kwargs): 1471 return list(kwargs.items()) 1472 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)]) 1473 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)]) 1474 self.assertEqual(f.cache_info(), 1475 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2)) 1476 1477 def test_lru_cache_decoration(self): 1478 def f(zomg: 'zomg_annotation'): 1479 """f doc string""" 1480 return 42 1481 g = self.module.lru_cache()(f) 1482 for attr in self.module.WRAPPER_ASSIGNMENTS: 1483 self.assertEqual(getattr(g, attr), getattr(f, attr)) 1484 1485 def test_lru_cache_threaded(self): 1486 n, m = 5, 11 1487 def orig(x, y): 1488 return 3 * x + y 1489 f = self.module.lru_cache(maxsize=n*m)(orig) 1490 hits, misses, maxsize, currsize = f.cache_info() 1491 self.assertEqual(currsize, 0) 1492 1493 start = threading.Event() 1494 def full(k): 1495 start.wait(10) 1496 for _ in range(m): 1497 self.assertEqual(f(k, 0), orig(k, 0)) 1498 1499 def clear(): 1500 start.wait(10) 1501 for _ in range(2*m): 1502 f.cache_clear() 1503 1504 orig_si = sys.getswitchinterval() 1505 support.setswitchinterval(1e-6) 1506 try: 1507 # create n threads in order to fill cache 1508 threads = [threading.Thread(target=full, args=[k]) 1509 for k in range(n)] 1510 with support.start_threads(threads): 1511 start.set() 1512 1513 hits, misses, maxsize, currsize = f.cache_info() 1514 if self.module is py_functools: 1515 # XXX: Why can be not equal? 
1516 self.assertLessEqual(misses, n) 1517 self.assertLessEqual(hits, m*n - misses) 1518 else: 1519 self.assertEqual(misses, n) 1520 self.assertEqual(hits, m*n - misses) 1521 self.assertEqual(currsize, n) 1522 1523 # create n threads in order to fill cache and 1 to clear it 1524 threads = [threading.Thread(target=clear)] 1525 threads += [threading.Thread(target=full, args=[k]) 1526 for k in range(n)] 1527 start.clear() 1528 with support.start_threads(threads): 1529 start.set() 1530 finally: 1531 sys.setswitchinterval(orig_si) 1532 1533 def test_lru_cache_threaded2(self): 1534 # Simultaneous call with the same arguments 1535 n, m = 5, 7 1536 start = threading.Barrier(n+1) 1537 pause = threading.Barrier(n+1) 1538 stop = threading.Barrier(n+1) 1539 @self.module.lru_cache(maxsize=m*n) 1540 def f(x): 1541 pause.wait(10) 1542 return 3 * x 1543 self.assertEqual(f.cache_info(), (0, 0, m*n, 0)) 1544 def test(): 1545 for i in range(m): 1546 start.wait(10) 1547 self.assertEqual(f(i), 3 * i) 1548 stop.wait(10) 1549 threads = [threading.Thread(target=test) for k in range(n)] 1550 with support.start_threads(threads): 1551 for i in range(m): 1552 start.wait(10) 1553 stop.reset() 1554 pause.wait(10) 1555 start.reset() 1556 stop.wait(10) 1557 pause.reset() 1558 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) 1559 1560 def test_lru_cache_threaded3(self): 1561 @self.module.lru_cache(maxsize=2) 1562 def f(x): 1563 time.sleep(.01) 1564 return 3 * x 1565 def test(i, x): 1566 with self.subTest(thread=i): 1567 self.assertEqual(f(x), 3 * x, i) 1568 threads = [threading.Thread(target=test, args=(i, v)) 1569 for i, v in enumerate([1, 2, 2, 3, 2])] 1570 with support.start_threads(threads): 1571 pass 1572 1573 def test_need_for_rlock(self): 1574 # This will deadlock on an LRU cache that uses a regular lock 1575 1576 @self.module.lru_cache(maxsize=10) 1577 def test_func(x): 1578 'Used to demonstrate a reentrant lru_cache call within a single thread' 1579 return x 1580 1581 class 
DoubleEq: 1582 'Demonstrate a reentrant lru_cache call within a single thread' 1583 def __init__(self, x): 1584 self.x = x 1585 def __hash__(self): 1586 return self.x 1587 def __eq__(self, other): 1588 if self.x == 2: 1589 test_func(DoubleEq(1)) 1590 return self.x == other.x 1591 1592 test_func(DoubleEq(1)) # Load the cache 1593 test_func(DoubleEq(2)) # Load the cache 1594 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call 1595 DoubleEq(2)) # Verify the correct return value 1596 1597 def test_lru_method(self): 1598 class X(int): 1599 f_cnt = 0 1600 @self.module.lru_cache(2) 1601 def f(self, x): 1602 self.f_cnt += 1 1603 return x*10+self 1604 a = X(5) 1605 b = X(5) 1606 c = X(7) 1607 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0)) 1608 1609 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3: 1610 self.assertEqual(a.f(x), x*10 + 5) 1611 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0)) 1612 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2)) 1613 1614 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2: 1615 self.assertEqual(b.f(x), x*10 + 5) 1616 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0)) 1617 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2)) 1618 1619 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1: 1620 self.assertEqual(c.f(x), x*10 + 7) 1621 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) 1622 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) 1623 1624 self.assertEqual(a.f.cache_info(), X.f.cache_info()) 1625 self.assertEqual(b.f.cache_info(), X.f.cache_info()) 1626 self.assertEqual(c.f.cache_info(), X.f.cache_info()) 1627 1628 def test_pickle(self): 1629 cls = self.__class__ 1630 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth: 1631 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1632 with self.subTest(proto=proto, func=f): 1633 f_copy = pickle.loads(pickle.dumps(f, proto)) 1634 self.assertIs(f_copy, f) 1635 1636 def test_copy(self): 1637 cls = self.__class__ 1638 def orig(x, y): 1639 return 3 * x + y 1640 part = 
self.module.partial(orig, 2) 1641 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1642 self.module.lru_cache(2)(part)) 1643 for f in funcs: 1644 with self.subTest(func=f): 1645 f_copy = copy.copy(f) 1646 self.assertIs(f_copy, f) 1647 1648 def test_deepcopy(self): 1649 cls = self.__class__ 1650 def orig(x, y): 1651 return 3 * x + y 1652 part = self.module.partial(orig, 2) 1653 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1654 self.module.lru_cache(2)(part)) 1655 for f in funcs: 1656 with self.subTest(func=f): 1657 f_copy = copy.deepcopy(f) 1658 self.assertIs(f_copy, f) 1659 1660 1661@py_functools.lru_cache() 1662def py_cached_func(x, y): 1663 return 3 * x + y 1664 1665@c_functools.lru_cache() 1666def c_cached_func(x, y): 1667 return 3 * x + y 1668 1669 1670class TestLRUPy(TestLRU, unittest.TestCase): 1671 module = py_functools 1672 cached_func = py_cached_func, 1673 1674 @module.lru_cache() 1675 def cached_meth(self, x, y): 1676 return 3 * x + y 1677 1678 @staticmethod 1679 @module.lru_cache() 1680 def cached_staticmeth(x, y): 1681 return 3 * x + y 1682 1683 1684class TestLRUC(TestLRU, unittest.TestCase): 1685 module = c_functools 1686 cached_func = c_cached_func, 1687 1688 @module.lru_cache() 1689 def cached_meth(self, x, y): 1690 return 3 * x + y 1691 1692 @staticmethod 1693 @module.lru_cache() 1694 def cached_staticmeth(x, y): 1695 return 3 * x + y 1696 1697 1698class TestSingleDispatch(unittest.TestCase): 1699 def test_simple_overloads(self): 1700 @functools.singledispatch 1701 def g(obj): 1702 return "base" 1703 def g_int(i): 1704 return "integer" 1705 g.register(int, g_int) 1706 self.assertEqual(g("str"), "base") 1707 self.assertEqual(g(1), "integer") 1708 self.assertEqual(g([1,2,3]), "base") 1709 1710 def test_mro(self): 1711 @functools.singledispatch 1712 def g(obj): 1713 return "base" 1714 class A: 1715 pass 1716 class C(A): 1717 pass 1718 class B(A): 1719 pass 1720 class D(C, B): 1721 pass 1722 def g_A(a): 1723 
return "A" 1724 def g_B(b): 1725 return "B" 1726 g.register(A, g_A) 1727 g.register(B, g_B) 1728 self.assertEqual(g(A()), "A") 1729 self.assertEqual(g(B()), "B") 1730 self.assertEqual(g(C()), "A") 1731 self.assertEqual(g(D()), "B") 1732 1733 def test_register_decorator(self): 1734 @functools.singledispatch 1735 def g(obj): 1736 return "base" 1737 @g.register(int) 1738 def g_int(i): 1739 return "int %s" % (i,) 1740 self.assertEqual(g(""), "base") 1741 self.assertEqual(g(12), "int 12") 1742 self.assertIs(g.dispatch(int), g_int) 1743 self.assertIs(g.dispatch(object), g.dispatch(str)) 1744 # Note: in the assert above this is not g. 1745 # @singledispatch returns the wrapper. 1746 1747 def test_wrapping_attributes(self): 1748 @functools.singledispatch 1749 def g(obj): 1750 "Simple test" 1751 return "Test" 1752 self.assertEqual(g.__name__, "g") 1753 if sys.flags.optimize < 2: 1754 self.assertEqual(g.__doc__, "Simple test") 1755 1756 @unittest.skipUnless(decimal, 'requires _decimal') 1757 @support.cpython_only 1758 def test_c_classes(self): 1759 @functools.singledispatch 1760 def g(obj): 1761 return "base" 1762 @g.register(decimal.DecimalException) 1763 def _(obj): 1764 return obj.args 1765 subn = decimal.Subnormal("Exponent < Emin") 1766 rnd = decimal.Rounded("Number got rounded") 1767 self.assertEqual(g(subn), ("Exponent < Emin",)) 1768 self.assertEqual(g(rnd), ("Number got rounded",)) 1769 @g.register(decimal.Subnormal) 1770 def _(obj): 1771 return "Too small to care." 1772 self.assertEqual(g(subn), "Too small to care.") 1773 self.assertEqual(g(rnd), ("Number got rounded",)) 1774 1775 def test_compose_mro(self): 1776 # None of the examples in this test depend on haystack ordering. 
1777 c = collections.abc 1778 mro = functools._compose_mro 1779 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] 1780 for haystack in permutations(bases): 1781 m = mro(dict, haystack) 1782 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, 1783 c.Collection, c.Sized, c.Iterable, 1784 c.Container, object]) 1785 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict] 1786 for haystack in permutations(bases): 1787 m = mro(collections.ChainMap, haystack) 1788 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping, 1789 c.Collection, c.Sized, c.Iterable, 1790 c.Container, object]) 1791 1792 # If there's a generic function with implementations registered for 1793 # both Sized and Container, passing a defaultdict to it results in an 1794 # ambiguous dispatch which will cause a RuntimeError (see 1795 # test_mro_conflicts). 1796 bases = [c.Container, c.Sized, str] 1797 for haystack in permutations(bases): 1798 m = mro(collections.defaultdict, [c.Sized, c.Container, str]) 1799 self.assertEqual(m, [collections.defaultdict, dict, c.Sized, 1800 c.Container, object]) 1801 1802 # MutableSequence below is registered directly on D. In other words, it 1803 # precedes MutableMapping which means single dispatch will always 1804 # choose MutableSequence here. 1805 class D(collections.defaultdict): 1806 pass 1807 c.MutableSequence.register(D) 1808 bases = [c.MutableSequence, c.MutableMapping] 1809 for haystack in permutations(bases): 1810 m = mro(D, bases) 1811 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, 1812 collections.defaultdict, dict, c.MutableMapping, c.Mapping, 1813 c.Collection, c.Sized, c.Iterable, c.Container, 1814 object]) 1815 1816 # Container and Callable are registered on different base classes and 1817 # a generic function supporting both should always pick the Callable 1818 # implementation if a C instance is passed. 
1819 class C(collections.defaultdict): 1820 def __call__(self): 1821 pass 1822 bases = [c.Sized, c.Callable, c.Container, c.Mapping] 1823 for haystack in permutations(bases): 1824 m = mro(C, haystack) 1825 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping, 1826 c.Collection, c.Sized, c.Iterable, 1827 c.Container, object]) 1828 1829 def test_register_abc(self): 1830 c = collections.abc 1831 d = {"a": "b"} 1832 l = [1, 2, 3] 1833 s = {object(), None} 1834 f = frozenset(s) 1835 t = (1, 2, 3) 1836 @functools.singledispatch 1837 def g(obj): 1838 return "base" 1839 self.assertEqual(g(d), "base") 1840 self.assertEqual(g(l), "base") 1841 self.assertEqual(g(s), "base") 1842 self.assertEqual(g(f), "base") 1843 self.assertEqual(g(t), "base") 1844 g.register(c.Sized, lambda obj: "sized") 1845 self.assertEqual(g(d), "sized") 1846 self.assertEqual(g(l), "sized") 1847 self.assertEqual(g(s), "sized") 1848 self.assertEqual(g(f), "sized") 1849 self.assertEqual(g(t), "sized") 1850 g.register(c.MutableMapping, lambda obj: "mutablemapping") 1851 self.assertEqual(g(d), "mutablemapping") 1852 self.assertEqual(g(l), "sized") 1853 self.assertEqual(g(s), "sized") 1854 self.assertEqual(g(f), "sized") 1855 self.assertEqual(g(t), "sized") 1856 g.register(collections.ChainMap, lambda obj: "chainmap") 1857 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered 1858 self.assertEqual(g(l), "sized") 1859 self.assertEqual(g(s), "sized") 1860 self.assertEqual(g(f), "sized") 1861 self.assertEqual(g(t), "sized") 1862 g.register(c.MutableSequence, lambda obj: "mutablesequence") 1863 self.assertEqual(g(d), "mutablemapping") 1864 self.assertEqual(g(l), "mutablesequence") 1865 self.assertEqual(g(s), "sized") 1866 self.assertEqual(g(f), "sized") 1867 self.assertEqual(g(t), "sized") 1868 g.register(c.MutableSet, lambda obj: "mutableset") 1869 self.assertEqual(g(d), "mutablemapping") 1870 self.assertEqual(g(l), "mutablesequence") 1871 self.assertEqual(g(s), 
"mutableset") 1872 self.assertEqual(g(f), "sized") 1873 self.assertEqual(g(t), "sized") 1874 g.register(c.Mapping, lambda obj: "mapping") 1875 self.assertEqual(g(d), "mutablemapping") # not specific enough 1876 self.assertEqual(g(l), "mutablesequence") 1877 self.assertEqual(g(s), "mutableset") 1878 self.assertEqual(g(f), "sized") 1879 self.assertEqual(g(t), "sized") 1880 g.register(c.Sequence, lambda obj: "sequence") 1881 self.assertEqual(g(d), "mutablemapping") 1882 self.assertEqual(g(l), "mutablesequence") 1883 self.assertEqual(g(s), "mutableset") 1884 self.assertEqual(g(f), "sized") 1885 self.assertEqual(g(t), "sequence") 1886 g.register(c.Set, lambda obj: "set") 1887 self.assertEqual(g(d), "mutablemapping") 1888 self.assertEqual(g(l), "mutablesequence") 1889 self.assertEqual(g(s), "mutableset") 1890 self.assertEqual(g(f), "set") 1891 self.assertEqual(g(t), "sequence") 1892 g.register(dict, lambda obj: "dict") 1893 self.assertEqual(g(d), "dict") 1894 self.assertEqual(g(l), "mutablesequence") 1895 self.assertEqual(g(s), "mutableset") 1896 self.assertEqual(g(f), "set") 1897 self.assertEqual(g(t), "sequence") 1898 g.register(list, lambda obj: "list") 1899 self.assertEqual(g(d), "dict") 1900 self.assertEqual(g(l), "list") 1901 self.assertEqual(g(s), "mutableset") 1902 self.assertEqual(g(f), "set") 1903 self.assertEqual(g(t), "sequence") 1904 g.register(set, lambda obj: "concrete-set") 1905 self.assertEqual(g(d), "dict") 1906 self.assertEqual(g(l), "list") 1907 self.assertEqual(g(s), "concrete-set") 1908 self.assertEqual(g(f), "set") 1909 self.assertEqual(g(t), "sequence") 1910 g.register(frozenset, lambda obj: "frozen-set") 1911 self.assertEqual(g(d), "dict") 1912 self.assertEqual(g(l), "list") 1913 self.assertEqual(g(s), "concrete-set") 1914 self.assertEqual(g(f), "frozen-set") 1915 self.assertEqual(g(t), "sequence") 1916 g.register(tuple, lambda obj: "tuple") 1917 self.assertEqual(g(d), "dict") 1918 self.assertEqual(g(l), "list") 1919 self.assertEqual(g(s), 
"concrete-set") 1920 self.assertEqual(g(f), "frozen-set") 1921 self.assertEqual(g(t), "tuple") 1922 1923 def test_c3_abc(self): 1924 c = collections.abc 1925 mro = functools._c3_mro 1926 class A(object): 1927 pass 1928 class B(A): 1929 def __len__(self): 1930 return 0 # implies Sized 1931 @c.Container.register 1932 class C(object): 1933 pass 1934 class D(object): 1935 pass # unrelated 1936 class X(D, C, B): 1937 def __call__(self): 1938 pass # implies Callable 1939 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] 1940 for abcs in permutations([c.Sized, c.Callable, c.Container]): 1941 self.assertEqual(mro(X, abcs=abcs), expected) 1942 # unrelated ABCs don't appear in the resulting MRO 1943 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] 1944 self.assertEqual(mro(X, abcs=many_abcs), expected) 1945 1946 def test_false_meta(self): 1947 # see issue23572 1948 class MetaA(type): 1949 def __len__(self): 1950 return 0 1951 class A(metaclass=MetaA): 1952 pass 1953 class AA(A): 1954 pass 1955 @functools.singledispatch 1956 def fun(a): 1957 return 'base A' 1958 @fun.register(A) 1959 def _(a): 1960 return 'fun A' 1961 aa = AA() 1962 self.assertEqual(fun(aa), 'fun A') 1963 1964 def test_mro_conflicts(self): 1965 c = collections.abc 1966 @functools.singledispatch 1967 def g(arg): 1968 return "base" 1969 class O(c.Sized): 1970 def __len__(self): 1971 return 0 1972 o = O() 1973 self.assertEqual(g(o), "base") 1974 g.register(c.Iterable, lambda arg: "iterable") 1975 g.register(c.Container, lambda arg: "container") 1976 g.register(c.Sized, lambda arg: "sized") 1977 g.register(c.Set, lambda arg: "set") 1978 self.assertEqual(g(o), "sized") 1979 c.Iterable.register(O) 1980 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ 1981 c.Container.register(O) 1982 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ 1983 c.Set.register(O) 1984 self.assertEqual(g(o), "set") # because c.Set is a subclass of 1985 # c.Sized and 
c.Container 1986 class P: 1987 pass 1988 p = P() 1989 self.assertEqual(g(p), "base") 1990 c.Iterable.register(P) 1991 self.assertEqual(g(p), "iterable") 1992 c.Container.register(P) 1993 with self.assertRaises(RuntimeError) as re_one: 1994 g(p) 1995 self.assertIn( 1996 str(re_one.exception), 1997 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 1998 "or <class 'collections.abc.Iterable'>"), 1999 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " 2000 "or <class 'collections.abc.Container'>")), 2001 ) 2002 class Q(c.Sized): 2003 def __len__(self): 2004 return 0 2005 q = Q() 2006 self.assertEqual(g(q), "sized") 2007 c.Iterable.register(Q) 2008 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ 2009 c.Set.register(Q) 2010 self.assertEqual(g(q), "set") # because c.Set is a subclass of 2011 # c.Sized and c.Iterable 2012 @functools.singledispatch 2013 def h(arg): 2014 return "base" 2015 @h.register(c.Sized) 2016 def _(arg): 2017 return "sized" 2018 @h.register(c.Container) 2019 def _(arg): 2020 return "container" 2021 # Even though Sized and Container are explicit bases of MutableMapping, 2022 # this ABC is implicitly registered on defaultdict which makes all of 2023 # MutableMapping's bases implicit as well from defaultdict's 2024 # perspective. 
2025 with self.assertRaises(RuntimeError) as re_two: 2026 h(collections.defaultdict(lambda: 0)) 2027 self.assertIn( 2028 str(re_two.exception), 2029 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2030 "or <class 'collections.abc.Sized'>"), 2031 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2032 "or <class 'collections.abc.Container'>")), 2033 ) 2034 class R(collections.defaultdict): 2035 pass 2036 c.MutableSequence.register(R) 2037 @functools.singledispatch 2038 def i(arg): 2039 return "base" 2040 @i.register(c.MutableMapping) 2041 def _(arg): 2042 return "mapping" 2043 @i.register(c.MutableSequence) 2044 def _(arg): 2045 return "sequence" 2046 r = R() 2047 self.assertEqual(i(r), "sequence") 2048 class S: 2049 pass 2050 class T(S, c.Sized): 2051 def __len__(self): 2052 return 0 2053 t = T() 2054 self.assertEqual(h(t), "sized") 2055 c.Container.register(T) 2056 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO 2057 class U: 2058 def __len__(self): 2059 return 0 2060 u = U() 2061 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred 2062 # from the existence of __len__() 2063 c.Container.register(U) 2064 # There is no preference for registered versus inferred ABCs. 
2065 with self.assertRaises(RuntimeError) as re_three: 2066 h(u) 2067 self.assertIn( 2068 str(re_three.exception), 2069 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2070 "or <class 'collections.abc.Sized'>"), 2071 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2072 "or <class 'collections.abc.Container'>")), 2073 ) 2074 class V(c.Sized, S): 2075 def __len__(self): 2076 return 0 2077 @functools.singledispatch 2078 def j(arg): 2079 return "base" 2080 @j.register(S) 2081 def _(arg): 2082 return "s" 2083 @j.register(c.Container) 2084 def _(arg): 2085 return "container" 2086 v = V() 2087 self.assertEqual(j(v), "s") 2088 c.Container.register(V) 2089 self.assertEqual(j(v), "container") # because it ends up right after 2090 # Sized in the MRO 2091 2092 def test_cache_invalidation(self): 2093 from collections import UserDict 2094 import weakref 2095 2096 class TracingDict(UserDict): 2097 def __init__(self, *args, **kwargs): 2098 super(TracingDict, self).__init__(*args, **kwargs) 2099 self.set_ops = [] 2100 self.get_ops = [] 2101 def __getitem__(self, key): 2102 result = self.data[key] 2103 self.get_ops.append(key) 2104 return result 2105 def __setitem__(self, key, value): 2106 self.set_ops.append(key) 2107 self.data[key] = value 2108 def clear(self): 2109 self.data.clear() 2110 2111 td = TracingDict() 2112 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td): 2113 c = collections.abc 2114 @functools.singledispatch 2115 def g(arg): 2116 return "base" 2117 d = {} 2118 l = [] 2119 self.assertEqual(len(td), 0) 2120 self.assertEqual(g(d), "base") 2121 self.assertEqual(len(td), 1) 2122 self.assertEqual(td.get_ops, []) 2123 self.assertEqual(td.set_ops, [dict]) 2124 self.assertEqual(td.data[dict], g.registry[object]) 2125 self.assertEqual(g(l), "base") 2126 self.assertEqual(len(td), 2) 2127 self.assertEqual(td.get_ops, []) 2128 self.assertEqual(td.set_ops, [dict, list]) 2129 self.assertEqual(td.data[dict], g.registry[object]) 2130 
self.assertEqual(td.data[list], g.registry[object]) 2131 self.assertEqual(td.data[dict], td.data[list]) 2132 self.assertEqual(g(l), "base") 2133 self.assertEqual(g(d), "base") 2134 self.assertEqual(td.get_ops, [list, dict]) 2135 self.assertEqual(td.set_ops, [dict, list]) 2136 g.register(list, lambda arg: "list") 2137 self.assertEqual(td.get_ops, [list, dict]) 2138 self.assertEqual(len(td), 0) 2139 self.assertEqual(g(d), "base") 2140 self.assertEqual(len(td), 1) 2141 self.assertEqual(td.get_ops, [list, dict]) 2142 self.assertEqual(td.set_ops, [dict, list, dict]) 2143 self.assertEqual(td.data[dict], 2144 functools._find_impl(dict, g.registry)) 2145 self.assertEqual(g(l), "list") 2146 self.assertEqual(len(td), 2) 2147 self.assertEqual(td.get_ops, [list, dict]) 2148 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2149 self.assertEqual(td.data[list], 2150 functools._find_impl(list, g.registry)) 2151 class X: 2152 pass 2153 c.MutableMapping.register(X) # Will not invalidate the cache, 2154 # not using ABCs yet. 
2155 self.assertEqual(g(d), "base") 2156 self.assertEqual(g(l), "list") 2157 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2158 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2159 g.register(c.Sized, lambda arg: "sized") 2160 self.assertEqual(len(td), 0) 2161 self.assertEqual(g(d), "sized") 2162 self.assertEqual(len(td), 1) 2163 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2164 self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) 2165 self.assertEqual(g(l), "list") 2166 self.assertEqual(len(td), 2) 2167 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2168 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2169 self.assertEqual(g(l), "list") 2170 self.assertEqual(g(d), "sized") 2171 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) 2172 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2173 g.dispatch(list) 2174 g.dispatch(dict) 2175 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, 2176 list, dict]) 2177 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2178 c.MutableSet.register(X) # Will invalidate the cache. 2179 self.assertEqual(len(td), 2) # Stale cache. 
2180 self.assertEqual(g(l), "list") 2181 self.assertEqual(len(td), 1) 2182 g.register(c.MutableMapping, lambda arg: "mutablemapping") 2183 self.assertEqual(len(td), 0) 2184 self.assertEqual(g(d), "mutablemapping") 2185 self.assertEqual(len(td), 1) 2186 self.assertEqual(g(l), "list") 2187 self.assertEqual(len(td), 2) 2188 g.register(dict, lambda arg: "dict") 2189 self.assertEqual(g(d), "dict") 2190 self.assertEqual(g(l), "list") 2191 g._clear_cache() 2192 self.assertEqual(len(td), 0) 2193 2194 def test_annotations(self): 2195 @functools.singledispatch 2196 def i(arg): 2197 return "base" 2198 @i.register 2199 def _(arg: collections.abc.Mapping): 2200 return "mapping" 2201 @i.register 2202 def _(arg: "collections.abc.Sequence"): 2203 return "sequence" 2204 self.assertEqual(i(None), "base") 2205 self.assertEqual(i({"a": 1}), "mapping") 2206 self.assertEqual(i([1, 2, 3]), "sequence") 2207 self.assertEqual(i((1, 2, 3)), "sequence") 2208 self.assertEqual(i("str"), "sequence") 2209 2210 # Registering classes as callables doesn't work with annotations, 2211 # you need to pass the type explicitly. 
2212 @i.register(str) 2213 class _: 2214 def __init__(self, arg): 2215 self.arg = arg 2216 2217 def __eq__(self, other): 2218 return self.arg == other 2219 self.assertEqual(i("str"), "str") 2220 2221 def test_method_register(self): 2222 class A: 2223 @functools.singledispatchmethod 2224 def t(self, arg): 2225 self.arg = "base" 2226 @t.register(int) 2227 def _(self, arg): 2228 self.arg = "int" 2229 @t.register(str) 2230 def _(self, arg): 2231 self.arg = "str" 2232 a = A() 2233 2234 a.t(0) 2235 self.assertEqual(a.arg, "int") 2236 aa = A() 2237 self.assertFalse(hasattr(aa, 'arg')) 2238 a.t('') 2239 self.assertEqual(a.arg, "str") 2240 aa = A() 2241 self.assertFalse(hasattr(aa, 'arg')) 2242 a.t(0.0) 2243 self.assertEqual(a.arg, "base") 2244 aa = A() 2245 self.assertFalse(hasattr(aa, 'arg')) 2246 2247 def test_staticmethod_register(self): 2248 class A: 2249 @functools.singledispatchmethod 2250 @staticmethod 2251 def t(arg): 2252 return arg 2253 @t.register(int) 2254 @staticmethod 2255 def _(arg): 2256 return isinstance(arg, int) 2257 @t.register(str) 2258 @staticmethod 2259 def _(arg): 2260 return isinstance(arg, str) 2261 a = A() 2262 2263 self.assertTrue(A.t(0)) 2264 self.assertTrue(A.t('')) 2265 self.assertEqual(A.t(0.0), 0.0) 2266 2267 def test_classmethod_register(self): 2268 class A: 2269 def __init__(self, arg): 2270 self.arg = arg 2271 2272 @functools.singledispatchmethod 2273 @classmethod 2274 def t(cls, arg): 2275 return cls("base") 2276 @t.register(int) 2277 @classmethod 2278 def _(cls, arg): 2279 return cls("int") 2280 @t.register(str) 2281 @classmethod 2282 def _(cls, arg): 2283 return cls("str") 2284 2285 self.assertEqual(A.t(0).arg, "int") 2286 self.assertEqual(A.t('').arg, "str") 2287 self.assertEqual(A.t(0.0).arg, "base") 2288 2289 def test_callable_register(self): 2290 class A: 2291 def __init__(self, arg): 2292 self.arg = arg 2293 2294 @functools.singledispatchmethod 2295 @classmethod 2296 def t(cls, arg): 2297 return cls("base") 2298 2299 
@A.t.register(int) 2300 @classmethod 2301 def _(cls, arg): 2302 return cls("int") 2303 @A.t.register(str) 2304 @classmethod 2305 def _(cls, arg): 2306 return cls("str") 2307 2308 self.assertEqual(A.t(0).arg, "int") 2309 self.assertEqual(A.t('').arg, "str") 2310 self.assertEqual(A.t(0.0).arg, "base") 2311 2312 def test_abstractmethod_register(self): 2313 class Abstract(abc.ABCMeta): 2314 2315 @functools.singledispatchmethod 2316 @abc.abstractmethod 2317 def add(self, x, y): 2318 pass 2319 2320 self.assertTrue(Abstract.add.__isabstractmethod__) 2321 2322 def test_type_ann_register(self): 2323 class A: 2324 @functools.singledispatchmethod 2325 def t(self, arg): 2326 return "base" 2327 @t.register 2328 def _(self, arg: int): 2329 return "int" 2330 @t.register 2331 def _(self, arg: str): 2332 return "str" 2333 a = A() 2334 2335 self.assertEqual(a.t(0), "int") 2336 self.assertEqual(a.t(''), "str") 2337 self.assertEqual(a.t(0.0), "base") 2338 2339 def test_invalid_registrations(self): 2340 msg_prefix = "Invalid first argument to `register()`: " 2341 msg_suffix = ( 2342 ". Use either `@register(some_class)` or plain `@register` on an " 2343 "annotated function." 2344 ) 2345 @functools.singledispatch 2346 def i(arg): 2347 return "base" 2348 with self.assertRaises(TypeError) as exc: 2349 @i.register(42) 2350 def _(arg): 2351 return "I annotated with a non-type" 2352 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42")) 2353 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2354 with self.assertRaises(TypeError) as exc: 2355 @i.register 2356 def _(arg): 2357 return "I forgot to annotate" 2358 self.assertTrue(str(exc.exception).startswith(msg_prefix + 2359 "<function TestSingleDispatch.test_invalid_registrations.<locals>._" 2360 )) 2361 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2362 2363 with self.assertRaises(TypeError) as exc: 2364 @i.register 2365 def _(arg: typing.Iterable[str]): 2366 # At runtime, dispatching on generics is impossible. 
2367 # When registering implementations with singledispatch, avoid 2368 # types from `typing`. Instead, annotate with regular types 2369 # or ABCs. 2370 return "I annotated with a generic collection" 2371 self.assertTrue(str(exc.exception).startswith( 2372 "Invalid annotation for 'arg'." 2373 )) 2374 self.assertTrue(str(exc.exception).endswith( 2375 'typing.Iterable[str] is not a class.' 2376 )) 2377 2378 def test_invalid_positional_argument(self): 2379 @functools.singledispatch 2380 def f(*args): 2381 pass 2382 msg = 'f requires at least 1 positional argument' 2383 with self.assertRaisesRegex(TypeError, msg): 2384 f() 2385 2386 2387class CachedCostItem: 2388 _cost = 1 2389 2390 def __init__(self): 2391 self.lock = py_functools.RLock() 2392 2393 @py_functools.cached_property 2394 def cost(self): 2395 """The cost of the item.""" 2396 with self.lock: 2397 self._cost += 1 2398 return self._cost 2399 2400 2401class OptionallyCachedCostItem: 2402 _cost = 1 2403 2404 def get_cost(self): 2405 """The cost of the item.""" 2406 self._cost += 1 2407 return self._cost 2408 2409 cached_cost = py_functools.cached_property(get_cost) 2410 2411 2412class CachedCostItemWait: 2413 2414 def __init__(self, event): 2415 self._cost = 1 2416 self.lock = py_functools.RLock() 2417 self.event = event 2418 2419 @py_functools.cached_property 2420 def cost(self): 2421 self.event.wait(1) 2422 with self.lock: 2423 self._cost += 1 2424 return self._cost 2425 2426 2427class CachedCostItemWithSlots: 2428 __slots__ = ('_cost') 2429 2430 def __init__(self): 2431 self._cost = 1 2432 2433 @py_functools.cached_property 2434 def cost(self): 2435 raise RuntimeError('never called, slots not supported') 2436 2437 2438class TestCachedProperty(unittest.TestCase): 2439 def test_cached(self): 2440 item = CachedCostItem() 2441 self.assertEqual(item.cost, 2) 2442 self.assertEqual(item.cost, 2) # not 3 2443 2444 def test_cached_attribute_name_differs_from_func_name(self): 2445 item = OptionallyCachedCostItem() 2446 
self.assertEqual(item.get_cost(), 2) 2447 self.assertEqual(item.cached_cost, 3) 2448 self.assertEqual(item.get_cost(), 4) 2449 self.assertEqual(item.cached_cost, 3) 2450 2451 def test_threaded(self): 2452 go = threading.Event() 2453 item = CachedCostItemWait(go) 2454 2455 num_threads = 3 2456 2457 orig_si = sys.getswitchinterval() 2458 sys.setswitchinterval(1e-6) 2459 try: 2460 threads = [ 2461 threading.Thread(target=lambda: item.cost) 2462 for k in range(num_threads) 2463 ] 2464 with support.start_threads(threads): 2465 go.set() 2466 finally: 2467 sys.setswitchinterval(orig_si) 2468 2469 self.assertEqual(item.cost, 2) 2470 2471 def test_object_with_slots(self): 2472 item = CachedCostItemWithSlots() 2473 with self.assertRaisesRegex( 2474 TypeError, 2475 "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.", 2476 ): 2477 item.cost 2478 2479 def test_immutable_dict(self): 2480 class MyMeta(type): 2481 @py_functools.cached_property 2482 def prop(self): 2483 return True 2484 2485 class MyClass(metaclass=MyMeta): 2486 pass 2487 2488 with self.assertRaisesRegex( 2489 TypeError, 2490 "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.", 2491 ): 2492 MyClass.prop 2493 2494 def test_reuse_different_names(self): 2495 """Disallow this case because decorated function a would not be cached.""" 2496 with self.assertRaises(RuntimeError) as ctx: 2497 class ReusedCachedProperty: 2498 @py_functools.cached_property 2499 def a(self): 2500 pass 2501 2502 b = a 2503 2504 self.assertEqual( 2505 str(ctx.exception.__context__), 2506 str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")) 2507 ) 2508 2509 def test_reuse_same_name(self): 2510 """Reusing a cached_property on different classes under the same name is OK.""" 2511 counter = 0 2512 2513 @py_functools.cached_property 2514 def _cp(_self): 2515 nonlocal counter 2516 counter += 1 2517 return counter 2518 
2519 class A: 2520 cp = _cp 2521 2522 class B: 2523 cp = _cp 2524 2525 a = A() 2526 b = B() 2527 2528 self.assertEqual(a.cp, 1) 2529 self.assertEqual(b.cp, 2) 2530 self.assertEqual(a.cp, 1) 2531 2532 def test_set_name_not_called(self): 2533 cp = py_functools.cached_property(lambda s: None) 2534 class Foo: 2535 pass 2536 2537 Foo.cp = cp 2538 2539 with self.assertRaisesRegex( 2540 TypeError, 2541 "Cannot use cached_property instance without calling __set_name__ on it.", 2542 ): 2543 Foo().cp 2544 2545 def test_access_from_class(self): 2546 self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property) 2547 2548 def test_doc(self): 2549 self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.") 2550 2551 2552if __name__ == '__main__': 2553 unittest.main() 2554