"""functools.py - Tools for working with functions and callable objects
"""
# Python module wrapper for _functools C module
# to allow utilities written in Python to be added
# to the functools module.
# Written by Nick Coghlan <ncoghlan at gmail.com>,
# Raymond Hettinger <python at rcn.com>,
# and Łukasz Langa <lukasz at langa.pl>.
#   Copyright (C) 2006-2013 Python Software Foundation.
# See C source code for _functools credits/copyright

__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
           'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial',
           'partialmethod', 'singledispatch']

from abc import get_cache_token
from collections import namedtuple
# import types, weakref  # Deferred to singledispatch()
from reprlib import recursive_repr
from _thread import RLock


################################################################################
### update_wrapper() and wraps() decorator
################################################################################

# update_wrapper() and wraps() are tools to help write
# wrapper functions that can handle naive introspection

WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
                       '__annotations__')
WRAPPER_UPDATES = ('__dict__',)
def update_wrapper(wrapper,
                   wrapped,
                   assigned = WRAPPER_ASSIGNMENTS,
                   updated = WRAPPER_UPDATES):
    """Update a wrapper function to look like the wrapped function

       wrapper is the function to be updated
       wrapped is the original function
       assigned is a tuple naming the attributes assigned directly
       from the wrapped function to the wrapper function (defaults to
       functools.WRAPPER_ASSIGNMENTS)
       updated is a tuple naming the attributes of the wrapper that
       are updated with the corresponding attribute from the wrapped
       function (defaults to functools.WRAPPER_UPDATES)

       Attributes named in *assigned* that *wrapped* does not have are
       skipped.  Finally wrapper.__wrapped__ is set to *wrapped* and
       *wrapper* is returned.
    """
    for attr in assigned:
        try:
            value = getattr(wrapped, attr)
        except AttributeError:
            # Missing attributes on the wrapped object are deliberately
            # ignored rather than treated as an error.
            pass
        else:
            setattr(wrapper, attr, value)
    for attr in updated:
        getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
    # Issue #17482: set __wrapped__ last so we don't inadvertently copy it
    # from the wrapped function when updating __dict__
    wrapper.__wrapped__ = wrapped
    # Return the wrapper so this can be used as a decorator via partial()
    return wrapper

def wraps(wrapped,
          assigned = WRAPPER_ASSIGNMENTS,
          updated = WRAPPER_UPDATES):
    """Decorator factory to apply update_wrapper() to a wrapper function

       Returns a decorator that invokes update_wrapper() with the decorated
       function as the wrapper argument and the arguments to wraps() as the
       remaining arguments. Default arguments are as for update_wrapper().
       This is a convenience function to simplify applying partial() to
       update_wrapper().
    """
    return partial(update_wrapper, wrapped=wrapped,
                   assigned=assigned, updated=updated)


################################################################################
### reduce() sequence to a single item
################################################################################

# Private sentinel: lets reduce() tell "no initial value given" apart from an
# explicit initial value of None.
_initial_missing = object()

def reduce(function, sequence, initial=_initial_missing):
    """
    reduce(function, sequence[, initial]) -> value

    Apply a function of two arguments cumulatively to the items of a sequence,
    from left to right, so as to reduce the sequence to a single value.
    For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates
    ((((1+2)+3)+4)+5).  If initial is present, it is placed before the items
    of the sequence in the calculation, and serves as a default when the
    sequence is empty.

    Raises TypeError if the sequence is empty and no initial value is given.
    This pure-Python version is replaced by the _functools C implementation
    whenever that is importable.
    """
    it = iter(sequence)

    if initial is _initial_missing:
        try:
            value = next(it)
        except StopIteration:
            raise TypeError(
                "reduce() of empty sequence with no initial value") from None
    else:
        value = initial

    for element in it:
        value = function(value, element)

    return value

try:
    from _functools import reduce
except ImportError:
    pass


################################################################################
### total_ordering class decorator
################################################################################

# The total ordering functions all invoke the root magic method directly
# rather than using the corresponding operator.  This avoids possible
# infinite recursion that could occur when the operator dispatch logic
# detects a NotImplemented result and then calls a reflected method.
#
# Every helper tests for NotImplemented explicitly instead of relying on its
# truthiness: using NotImplemented in a boolean context is deprecated
# (DeprecationWarning since Python 3.9, TypeError since 3.14).

def _gt_from_lt(self, other, NotImplemented=NotImplemented):
    'Return a > b.  Computed by @total_ordering from (not a < b) and (a != b).'
    op_result = self.__lt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result and self != other

def _le_from_lt(self, other, NotImplemented=NotImplemented):
    'Return a <= b.  Computed by @total_ordering from (a < b) or (a == b).'
    op_result = self.__lt__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result or self == other

def _ge_from_lt(self, other, NotImplemented=NotImplemented):
    'Return a >= b.  Computed by @total_ordering from (not a < b).'
    op_result = self.__lt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

def _ge_from_le(self, other, NotImplemented=NotImplemented):
    'Return a >= b.  Computed by @total_ordering from (not a <= b) or (a == b).'
    op_result = self.__le__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result or self == other

def _lt_from_le(self, other, NotImplemented=NotImplemented):
    'Return a < b.  Computed by @total_ordering from (a <= b) and (a != b).'
    op_result = self.__le__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result and self != other

def _gt_from_le(self, other, NotImplemented=NotImplemented):
    'Return a > b.  Computed by @total_ordering from (not a <= b).'
    op_result = self.__le__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

def _lt_from_gt(self, other, NotImplemented=NotImplemented):
    'Return a < b.  Computed by @total_ordering from (not a > b) and (a != b).'
    op_result = self.__gt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result and self != other

def _ge_from_gt(self, other, NotImplemented=NotImplemented):
    'Return a >= b.  Computed by @total_ordering from (a > b) or (a == b).'
    op_result = self.__gt__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result or self == other

def _le_from_gt(self, other, NotImplemented=NotImplemented):
    'Return a <= b.  Computed by @total_ordering from (not a > b).'
    op_result = self.__gt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

def _le_from_ge(self, other, NotImplemented=NotImplemented):
    'Return a <= b.  Computed by @total_ordering from (not a >= b) or (a == b).'
    op_result = self.__ge__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result or self == other

def _gt_from_ge(self, other, NotImplemented=NotImplemented):
    'Return a > b.  Computed by @total_ordering from (a >= b) and (a != b).'
    op_result = self.__ge__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result and self != other

def _lt_from_ge(self, other, NotImplemented=NotImplemented):
    'Return a < b.  Computed by @total_ordering from (not a >= b).'
    op_result = self.__ge__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

# Maps each possible root comparison to the (name, implementation) pairs
# that total_ordering() can derive from it.
_convert = {
    '__lt__': [('__gt__', _gt_from_lt),
               ('__le__', _le_from_lt),
               ('__ge__', _ge_from_lt)],
    '__le__': [('__ge__', _ge_from_le),
               ('__lt__', _lt_from_le),
               ('__gt__', _gt_from_le)],
    '__gt__': [('__lt__', _lt_from_gt),
               ('__ge__', _ge_from_gt),
               ('__le__', _le_from_gt)],
    '__ge__': [('__le__', _le_from_ge),
               ('__gt__', _gt_from_ge),
               ('__lt__', _lt_from_ge)]
}

def total_ordering(cls):
    """Class decorator that fills in missing ordering methods

    *cls* must provide at least one of __lt__, __le__, __gt__ or __ge__
    other than the one inherited from object; otherwise ValueError is
    raised.  The remaining ordering methods are derived from that root
    (and from ==/!=) and set on *cls* only if it does not already provide
    them.  Returns *cls*.
    """
    # Find user-defined comparisons (not those inherited from object).
    roots = {op for op in _convert if getattr(cls, op, None) is not getattr(object, op, None)}
    if not roots:
        raise ValueError('must define at least one ordering operation: < > <= >=')
    root = max(roots)       # prefer __lt__ to __le__ to __gt__ to __ge__
    for opname, opfunc in _convert[root]:
        if opname not in roots:
            # The helpers are shared module-level functions; renaming them
            # only makes introspection of the installed method read naturally.
            opfunc.__name__ = opname
            setattr(cls, opname, opfunc)
    return cls


################################################################################
### cmp_to_key() function converter
################################################################################

def cmp_to_key(mycmp):
    """Convert a cmp= function into a key= function

    *mycmp(a, b)* must return a number that is negative, zero or positive
    when a < b, a == b or a > b respectively.  The returned class wraps a
    value so that its rich comparisons delegate to *mycmp*; instances are
    unhashable.
    """
    class K(object):
        __slots__ = ['obj']
        def __init__(self, obj):
            self.obj = obj
        def __lt__(self, other):
            return mycmp(self.obj, other.obj) < 0
        def __gt__(self, other):
            return mycmp(self.obj, other.obj) > 0
        def __eq__(self, other):
            return mycmp(self.obj, other.obj) == 0
        def __le__(self, other):
            return mycmp(self.obj, other.obj) <= 0
        def __ge__(self, other):
            return mycmp(self.obj, other.obj) >= 0
        __hash__ = None
    return K

try:
    from _functools import cmp_to_key
except ImportError:
    pass


################################################################################
### partial() argument application
################################################################################

# Purely functional, no descriptor behaviour
class partial:
    """New function with partial application of the given arguments
    and keywords.
    """

    __slots__ = "func", "args", "keywords", "__dict__", "__weakref__"

    def __new__(*args, **keywords):
        # Everything is taken through *args so that "cls" and "func" remain
        # usable as keyword arguments destined for the wrapped callable.
        if not args:
            raise TypeError("descriptor '__new__' of partial needs an argument")
        if len(args) < 2:
            raise TypeError("type 'partial' takes at least one argument")
        cls, func, *args = args
        if not callable(func):
            raise TypeError("the first argument must be callable")
        args = tuple(args)

        # Flatten nested partial objects so calls go through a single layer.
        # Only genuine partial instances are flattened: an arbitrary callable
        # that merely has a "func" attribute need not have "args" and
        # "keywords", and would otherwise make this step fail.
        if isinstance(func, partial):
            args = func.args + args
            keywords = {**func.keywords, **keywords}
            func = func.func

        self = super(partial, cls).__new__(cls)

        self.func = func
        self.args = args
        self.keywords = keywords
        return self

    def __call__(*args, **keywords):
        if not args:
            raise TypeError("descriptor '__call__' of partial needs an argument")
        self, *args = args
        # Call-time keywords override the frozen ones.
        newkeywords = {**self.keywords, **keywords}
        return self.func(*self.args, *args, **newkeywords)

    @recursive_repr()
    def __repr__(self):
        qualname = type(self).__qualname__
        args = [repr(self.func)]
        args.extend(repr(x) for x in self.args)
        args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
        if type(self).__module__ == "functools":
            return f"functools.{qualname}({', '.join(args)})"
        return f"{qualname}({', '.join(args)})"

    def __reduce__(self):
        return type(self), (self.func,), (self.func, self.args,
               self.keywords or None, self.__dict__ or None)

    def __setstate__(self, state):
        """Restore from the 4-tuple produced by __reduce__().

        Raises TypeError if *state* is not a well-formed
        (func, args, kwds, namespace) tuple.
        """
        if not isinstance(state, tuple):
            raise TypeError("argument to __setstate__ must be a tuple")
        if len(state) != 4:
            raise TypeError(f"expected 4 items in state, got {len(state)}")
        func, args, kwds, namespace = state
        if (not callable(func) or not isinstance(args, tuple) or
           (kwds is not None and not isinstance(kwds, dict)) or
           (namespace is not None and not isinstance(namespace, dict))):
            raise TypeError("invalid partial state")

        args = tuple(args) # just in case it's a subclass
        if kwds is None:
            kwds = {}
        elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
            kwds = dict(kwds)
        if namespace is None:
            namespace = {}

        self.__dict__ = namespace
        self.func = func
        self.args = args
        self.keywords = kwds

try:
    from _functools import partial
except ImportError:
    pass

# Descriptor version
class partialmethod(object):
    """Method descriptor with partial application of the given arguments
    and keywords.

    Supports wrapping existing descriptors and handles non-descriptor
    callables as instance methods.
    """

    def __init__(*args, **keywords):
        if len(args) >= 2:
            self, func, *args = args
        elif not args:
            raise TypeError("descriptor '__init__' of partialmethod "
                            "needs an argument")
        elif 'func' in keywords:
            func = keywords.pop('func')
            self, *args = args
        else:
            raise TypeError("type 'partialmethod' takes at least one argument, "
                            "got %d" % (len(args)-1))
        args = tuple(args)

        if not callable(func) and not hasattr(func, "__get__"):
            raise TypeError("{!r} is not callable or a descriptor"
                                 .format(func))

        # func could be a descriptor like classmethod which isn't callable,
        # so we can't inherit from partial (it verifies func is callable)
        if isinstance(func, partialmethod):
            # flattening is mandatory in order to place cls/self before all
            # other arguments
            # it's also more efficient since only one function will be called
            self.func = func.func
            self.args = func.args + args
            self.keywords = {**func.keywords, **keywords}
        else:
            self.func = func
            self.args = args
            self.keywords = keywords

    def __repr__(self):
        # Build the argument list piecewise so that empty args/keywords do
        # not leave stray ", , " separators, and use repr() for func.
        cls = type(self)
        parts = [repr(self.func)]
        parts.extend(map(repr, self.args))
        parts.extend(f"{k}={v!r}" for k, v in self.keywords.items())
        return f"{cls.__module__}.{cls.__qualname__}({', '.join(parts)})"

    def _make_unbound_method(self):
        """Return a plain function that inserts the frozen arguments after
        its first (cls/self) argument before calling self.func."""
        def _method(*args, **keywords):
            if not args:
                raise TypeError("partialmethod wrapper requires a cls or "
                                "self argument")
            call_keywords = {**self.keywords, **keywords}
            cls_or_self, *rest = args
            # cls/self must come before the frozen positional arguments.
            call_args = (cls_or_self,) + self.args + tuple(rest)
            return self.func(*call_args, **call_keywords)
        _method.__isabstractmethod__ = self.__isabstractmethod__
        _method._partialmethod = self
        return _method

    def __get__(self, obj, cls):
        get = getattr(self.func, "__get__", None)
        result = None
        if get is not None:
            new_func = get(obj, cls)
            if new_func is not self.func:
                # Assume __get__ returning something new indicates the
                # creation of an appropriate callable
                result = partial(new_func, *self.args, **self.keywords)
                try:
                    result.__self__ = new_func.__self__
                except AttributeError:
                    pass
        if result is None:
            # If the underlying descriptor didn't do anything, treat this
            # like an instance method
            result = self._make_unbound_method().__get__(obj, cls)
        return result

    @property
    def __isabstractmethod__(self):
        return getattr(self.func, "__isabstractmethod__", False)


################################################################################
### LRU Cache function decorator
################################################################################

_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])

class _HashedSeq(list):
    """ This class guarantees that hash() will be called no more than once
        per element.  This is important because the lru_cache() will hash
        the key multiple times on a cache miss.

    """

    __slots__ = 'hashvalue'

    def __init__(self, tup, hash=hash):
        self[:] = tup
        self.hashvalue = hash(tup)

    def __hash__(self):
        return self.hashvalue

def _make_key(args, kwds, typed,
             kwd_mark = (object(),),
             fasttypes = {int, str},
             tuple=tuple, type=type, len=len):
    """Make a cache key from optionally typed positional and keyword arguments

    The key is constructed in a way that is flat as possible rather than
    as a nested structure that would take more memory.

    If there is only a single argument and its data type is known to cache
    its hash value, then that argument is returned without a wrapper.  This
    saves space and improves lookup speed.

    """
    # All of code below relies on kwds preserving the order input by the user.
    # Formerly, we sorted() the kwds before looping.  The new way is *much*
    # faster; however, it means that f(x=1, y=2) will now be treated as a
    # distinct call from f(y=2, x=1) which will be cached separately.
    key = args
    if kwds:
        key += kwd_mark
        for item in kwds.items():
            key += item
    if typed:
        key += tuple(type(v) for v in args)
        if kwds:
            key += tuple(type(v) for v in kwds.values())
    elif len(key) == 1 and type(key[0]) in fasttypes:
        return key[0]
    return _HashedSeq(key)

def lru_cache(maxsize=128, typed=False):
    """Least-recently-used cache decorator.

    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound.

    If *typed* is True, arguments of different types will be cached separately.
    For example, f(3.0) and f(3) will be treated as distinct calls with
    distinct results.

    Arguments to the cached function must be hashable.

    May be used either as @lru_cache(maxsize, typed) or directly as a bare
    @lru_cache, in which case the default maxsize of 128 applies.

    View the cache statistics named tuple (hits, misses, maxsize, currsize)
    with f.cache_info().  Clear the cache and statistics with f.cache_clear().
    Access the underlying function with f.__wrapped__.

    See:  http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)

    """

    # Users should only access the lru_cache through its public API:
    #       cache_info, cache_clear, and f.__wrapped__
    # The internals of the lru_cache are encapsulated for thread safety and
    # to allow the implementation to change (including a possible C version).

    if isinstance(maxsize, int):
        # Negative maxsize is treated as 0.
        if maxsize < 0:
            maxsize = 0
    elif callable(maxsize) and isinstance(typed, bool):
        # The user function was passed directly as the first argument,
        # i.e. the decorator was applied as a bare @lru_cache.
        user_function, maxsize = maxsize, 128
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        return update_wrapper(wrapper, user_function)
    elif maxsize is not None:
        raise TypeError(
            'Expected first argument to be an integer, a callable, or None')

    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        return update_wrapper(wrapper, user_function)

    return decorating_function

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    """Build the caching wrapper around *user_function*.

    maxsize == 0 disables caching (only misses are counted), maxsize None
    gives an unbounded dict cache, and any positive maxsize gives a bounded
    cache whose recency order is kept in a circular doubly linked list.
    The returned wrapper exposes cache_info() and cache_clear().
    """
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # Move the link to the front of the circular queue
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
                misses += 1
            # The user function runs without holding the lock.
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # Getting here means that this same key was added to the
                    # cache while the lock was released.  Since the link
                    # update is already done, we need only return the
                    # computed result.
                    pass
                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we're
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot
                else:
                    # Put result in a new link at the front of the queue.
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # Use the cache_len bound method instead of the len() function
                    # which could potentially be wrapped in an lru_cache itself.
                    full = (cache_len() >= maxsize)
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

try:
    from _functools import _lru_cache_wrapper
except ImportError:
    pass


################################################################################
### singledispatch() - single-dispatch generic function decorator
################################################################################

def _c3_merge(sequences):
    """Merges MROs in *sequences* to a single MRO using the C3 algorithm.

    Adapted from http://www.python.org/download/releases/2.3/mro/.

    The lists inside *sequences* are consumed (their heads are deleted) as
    the merge proceeds.  Raises RuntimeError if no consistent order exists.

    """
    result = []
    while True:
        sequences = [s for s in sequences if s]   # purge empty sequences
        if not sequences:
            return result
        for s1 in sequences:   # find merge candidates among seq heads
            candidate = s1[0]
            for s2 in sequences:
                if candidate in s2[1:]:
                    candidate = None
                    break      # reject the current head, it appears later
            else:
                break
        if candidate is None:
            raise RuntimeError("Inconsistent hierarchy")
        result.append(candidate)
        # remove the chosen candidate
        for seq in sequences:
            if seq[0] == candidate:
                del seq[0]

def _c3_mro(cls, abcs=None):
    """Computes the method resolution order using extended C3 linearization.

    If no *abcs* are given, the algorithm works exactly like the built-in C3
    linearization used for method resolution.

    If given, *abcs* is a list of abstract base classes that should be inserted
    into the resulting MRO. Unrelated ABCs are ignored and don't end up in the
    result. The algorithm inserts ABCs where their functionality is introduced,
    i.e. issubclass(cls, abc) returns True for the class itself but returns
    False for all its direct base classes. Implicit ABCs for a given class
    (either registered or inferred from the presence of a special method like
    __len__) are inserted directly after the last ABC explicitly listed in the
    MRO of said class. If two implicit ABCs end up next to each other in the
    resulting MRO, their ordering depends on the order of types in *abcs*.

    """
    for i, base in enumerate(reversed(cls.__bases__)):
        if hasattr(base, '__abstractmethods__'):
            boundary = len(cls.__bases__) - i
            break   # Bases up to the last explicit ABC are considered first.
    else:
        boundary = 0
    abcs = list(abcs) if abcs else []
    explicit_bases = list(cls.__bases__[:boundary])
    abstract_bases = []
    other_bases = list(cls.__bases__[boundary:])
    for base in abcs:
        if issubclass(cls, base) and not any(
                issubclass(b, base) for b in cls.__bases__
            ):
            # If *cls* is the class that introduces behaviour described by
            # an ABC *base*, insert said ABC to its MRO.
            abstract_bases.append(base)
    for base in abstract_bases:
        abcs.remove(base)
    explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases]
    abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases]
    other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases]
    return _c3_merge(
        [[cls]] +
        explicit_c3_mros + abstract_c3_mros + other_c3_mros +
        [explicit_bases] + [abstract_bases] + [other_bases]
    )

def _compose_mro(cls, types):
    """Calculates the method resolution order for a given class *cls*.

    Includes relevant abstract base classes (with their respective bases) from
    the *types* iterable. Uses a modified C3 linearization algorithm.

    """
    bases = set(cls.__mro__)
    # Remove entries which are already present in the __mro__ or unrelated.
    def is_related(typ):
        return (typ not in bases and hasattr(typ, '__mro__')
                                 and issubclass(cls, typ))
    types = [n for n in types if is_related(n)]
    # Remove entries which are strict bases of other entries (they will end up
    # in the MRO anyway).
    def is_strict_base(typ):
        for other in types:
            if typ != other and typ in other.__mro__:
                return True
        return False
    types = [n for n in types if not is_strict_base(n)]
    # Subclasses of the ABCs in *types* which are also implemented by
    # *cls* can be used to stabilize ABC ordering.
    type_set = set(types)
    mro = []
    for typ in types:
        found = []
        for sub in typ.__subclasses__():
            if sub not in bases and issubclass(cls, sub):
                found.append([s for s in sub.__mro__ if s in type_set])
        if not found:
            mro.append(typ)
            continue
        # Favor subclasses with the biggest number of useful bases
        found.sort(key=len, reverse=True)
        for sub in found:
            for subcls in sub:
                if subcls not in mro:
                    mro.append(subcls)
    return _c3_mro(cls, abcs=mro)

def _find_impl(cls, registry):
    """Returns the best matching implementation from *registry* for type *cls*.

    Where there is no registered implementation for a specific type, its method
    resolution order is used to find a more generic implementation.

    Note: if *registry* does not contain an implementation for the base
    *object* type, this function may return None.

    Raises RuntimeError when two unrelated implicit ABCs match equally well.

    """
    mro = _compose_mro(cls, registry.keys())
    match = None
    for t in mro:
        if match is not None:
            # If *match* is an implicit ABC but there is another unrelated,
            # equally matching implicit ABC, refuse the temptation to guess.
            if (t in registry and t not in cls.__mro__
                              and match not in cls.__mro__
                              and not issubclass(match, t)):
                raise RuntimeError("Ambiguous dispatch: {} or {}".format(
                    match, t))
            break
        if t in registry:
            match = t
    return registry.get(match)

def singledispatch(func):
    """Single-dispatch generic function decorator.

    Transforms a function into a generic function, which can have different
    behaviours depending upon the type of its first argument. The decorated
    function acts as the default implementation, and additional
    implementations can be registered using the register() attribute of the
    generic function.
    """
    # There are many programs that use functools without singledispatch, so we
    # trade-off making singledispatch marginally slower for the benefit of
    # making start-up of such applications slightly faster.
    import types, weakref

    registry = {}
    dispatch_cache = weakref.WeakKeyDictionary()
    cache_token = None

    def dispatch(cls):
        """generic_func.dispatch(cls) -> <function implementation>

        Runs the dispatch algorithm to return the best available implementation
        for the given *cls* registered on *generic_func*.

        """
        nonlocal cache_token
        # Once an ABC has been registered, ABC registrations elsewhere can
        # change dispatch results; the abc cache token detects that.
        if cache_token is not None:
            current_token = get_cache_token()
            if cache_token != current_token:
                dispatch_cache.clear()
                cache_token = current_token
        try:
            impl = dispatch_cache[cls]
        except KeyError:
            try:
                impl = registry[cls]
            except KeyError:
                impl = _find_impl(cls, registry)
            dispatch_cache[cls] = impl
        return impl

    def register(cls, func=None):
        """generic_func.register(cls, func) -> func

        Registers a new implementation for the given *cls* on a *generic_func*.
        Also usable as @register(cls) or as a bare @register on a function
        whose first parameter is annotated with a class.

        Raises TypeError if used bare on an unannotated function or if the
        annotation is not a class.

        """
        nonlocal cache_token
        if func is None:
            if isinstance(cls, type):
                return lambda f: register(cls, f)
            ann = getattr(cls, '__annotations__', {})
            if not ann:
                raise TypeError(
                    f"Invalid first argument to `register()`: {cls!r}. "
                    f"Use either `@register(some_class)` or plain `@register` "
                    f"on an annotated function."
                )
            func = cls

            # only import typing if annotation parsing is necessary
            from typing import get_type_hints
            argname, cls = next(iter(get_type_hints(func).items()))
            # A real check rather than an assert: asserts vanish under -O.
            if not isinstance(cls, type):
                raise TypeError(
                    f"Invalid annotation for {argname!r}. "
                    f"{cls!r} is not a class."
                )
        registry[cls] = func
        if cache_token is None and hasattr(cls, '__abstractmethods__'):
            cache_token = get_cache_token()
        dispatch_cache.clear()
        return func

    def wrapper(*args, **kw):
        if not args:
            raise TypeError(f'{funcname} requires at least '
                            '1 positional argument')

        return dispatch(args[0].__class__)(*args, **kw)

    funcname = getattr(func, '__name__', 'singledispatch function')
    registry[object] = func
    wrapper.register = register
    wrapper.dispatch = dispatch
    wrapper.registry = types.MappingProxyType(registry)
    wrapper._clear_cache = dispatch_cache.clear
    update_wrapper(wrapper, func)
    return wrapper