1"""functools.py - Tools for working with functions and callable objects
2"""
3# Python module wrapper for _functools C module
4# to allow utilities written in Python to be added
5# to the functools module.
6# Written by Nick Coghlan <ncoghlan at gmail.com>,
7# Raymond Hettinger <python at rcn.com>,
8# and Łukasz Langa <lukasz at langa.pl>.
9#   Copyright (C) 2006-2013 Python Software Foundation.
10# See C source code for _functools credits/copyright
11
12__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
13           'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial',
14           'partialmethod', 'singledispatch']
15
16try:
17    from _functools import reduce
18except ImportError:
19    pass
20from abc import get_cache_token
21from collections import namedtuple
22# import types, weakref  # Deferred to single_dispatch()
23from reprlib import recursive_repr
24from _thread import RLock
25
26
27################################################################################
28### update_wrapper() and wraps() decorator
29################################################################################
30
31# update_wrapper() and wraps() are tools to help write
32# wrapper functions that can handle naive introspection
33
WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
                       '__annotations__')
WRAPPER_UPDATES = ('__dict__',)
def update_wrapper(wrapper,
                   wrapped,
                   assigned = WRAPPER_ASSIGNMENTS,
                   updated = WRAPPER_UPDATES):
    """Update a wrapper function to look like the wrapped function

       wrapper is the function to be updated
       wrapped is the original function
       assigned is a tuple naming the attributes assigned directly
       from the wrapped function to the wrapper function (defaults to
       functools.WRAPPER_ASSIGNMENTS)
       updated is a tuple naming the attributes of the wrapper that
       are updated with the corresponding attribute from the wrapped
       function (defaults to functools.WRAPPER_UPDATES)
    """
    # Copy identity metadata; attributes missing on the wrapped object
    # (e.g. __name__ on a partial) are simply skipped.
    for name in assigned:
        try:
            attr_value = getattr(wrapped, name)
        except AttributeError:
            continue
        setattr(wrapper, name, attr_value)
    # Merge (not replace) mapping attributes such as __dict__.
    for name in updated:
        getattr(wrapper, name).update(getattr(wrapped, name, {}))
    # Issue #17482: assign __wrapped__ last so the __dict__ merge above
    # cannot inadvertently copy a stale __wrapped__ from the wrapped function.
    wrapper.__wrapped__ = wrapped
    # Return the wrapper so this can be used as a decorator via partial()
    return wrapper
66
def wraps(wrapped,
          assigned = WRAPPER_ASSIGNMENTS,
          updated = WRAPPER_UPDATES):
    """Decorator factory to apply update_wrapper() to a wrapper function

       Returns a decorator that invokes update_wrapper() with the decorated
       function as the wrapper argument and the arguments to wraps() as the
       remaining arguments. Default arguments are as for update_wrapper().
       This is a convenience function to simplify applying partial() to
       update_wrapper().
    """
    # The returned partial freezes wrapped/assigned/updated; the decorated
    # function is later supplied as the remaining positional argument
    # (the `wrapper` parameter of update_wrapper).
    return partial(update_wrapper, wrapped=wrapped,
                   assigned=assigned, updated=updated)
80
81
82################################################################################
83### total_ordering class decorator
84################################################################################
85
86# The total ordering functions all invoke the root magic method directly
87# rather than using the corresponding operator.  This avoids possible
88# infinite recursion that could occur when the operator dispatch logic
89# detects a NotImplemented result and then calls a reflected method.
90
91def _gt_from_lt(self, other, NotImplemented=NotImplemented):
92    'Return a > b.  Computed by @total_ordering from (not a < b) and (a != b).'
93    op_result = self.__lt__(other)
94    if op_result is NotImplemented:
95        return op_result
96    return not op_result and self != other
97
98def _le_from_lt(self, other, NotImplemented=NotImplemented):
99    'Return a <= b.  Computed by @total_ordering from (a < b) or (a == b).'
100    op_result = self.__lt__(other)
101    return op_result or self == other
102
103def _ge_from_lt(self, other, NotImplemented=NotImplemented):
104    'Return a >= b.  Computed by @total_ordering from (not a < b).'
105    op_result = self.__lt__(other)
106    if op_result is NotImplemented:
107        return op_result
108    return not op_result
109
110def _ge_from_le(self, other, NotImplemented=NotImplemented):
111    'Return a >= b.  Computed by @total_ordering from (not a <= b) or (a == b).'
112    op_result = self.__le__(other)
113    if op_result is NotImplemented:
114        return op_result
115    return not op_result or self == other
116
117def _lt_from_le(self, other, NotImplemented=NotImplemented):
118    'Return a < b.  Computed by @total_ordering from (a <= b) and (a != b).'
119    op_result = self.__le__(other)
120    if op_result is NotImplemented:
121        return op_result
122    return op_result and self != other
123
124def _gt_from_le(self, other, NotImplemented=NotImplemented):
125    'Return a > b.  Computed by @total_ordering from (not a <= b).'
126    op_result = self.__le__(other)
127    if op_result is NotImplemented:
128        return op_result
129    return not op_result
130
131def _lt_from_gt(self, other, NotImplemented=NotImplemented):
132    'Return a < b.  Computed by @total_ordering from (not a > b) and (a != b).'
133    op_result = self.__gt__(other)
134    if op_result is NotImplemented:
135        return op_result
136    return not op_result and self != other
137
138def _ge_from_gt(self, other, NotImplemented=NotImplemented):
139    'Return a >= b.  Computed by @total_ordering from (a > b) or (a == b).'
140    op_result = self.__gt__(other)
141    return op_result or self == other
142
143def _le_from_gt(self, other, NotImplemented=NotImplemented):
144    'Return a <= b.  Computed by @total_ordering from (not a > b).'
145    op_result = self.__gt__(other)
146    if op_result is NotImplemented:
147        return op_result
148    return not op_result
149
150def _le_from_ge(self, other, NotImplemented=NotImplemented):
151    'Return a <= b.  Computed by @total_ordering from (not a >= b) or (a == b).'
152    op_result = self.__ge__(other)
153    if op_result is NotImplemented:
154        return op_result
155    return not op_result or self == other
156
157def _gt_from_ge(self, other, NotImplemented=NotImplemented):
158    'Return a > b.  Computed by @total_ordering from (a >= b) and (a != b).'
159    op_result = self.__ge__(other)
160    if op_result is NotImplemented:
161        return op_result
162    return op_result and self != other
163
164def _lt_from_ge(self, other, NotImplemented=NotImplemented):
165    'Return a < b.  Computed by @total_ordering from (not a >= b).'
166    op_result = self.__ge__(other)
167    if op_result is NotImplemented:
168        return op_result
169    return not op_result
170
# For each root comparison method, the ordered list of derived methods that
# @total_ordering can synthesize from it, as (name, implementation) pairs.
_convert = {
    '__lt__': [('__gt__', _gt_from_lt),
               ('__le__', _le_from_lt),
               ('__ge__', _ge_from_lt)],
    '__le__': [('__ge__', _ge_from_le),
               ('__lt__', _lt_from_le),
               ('__gt__', _gt_from_le)],
    '__gt__': [('__lt__', _lt_from_gt),
               ('__ge__', _ge_from_gt),
               ('__le__', _le_from_gt)],
    '__ge__': [('__le__', _le_from_ge),
               ('__gt__', _gt_from_ge),
               ('__lt__', _lt_from_ge)]
}
185
def total_ordering(cls):
    """Class decorator that fills in missing ordering methods"""
    # Find user-defined comparisons (not those inherited from object).
    defined = {
        op
        for op in _convert
        if getattr(cls, op, None) is not getattr(object, op, None)
    }
    if not defined:
        raise ValueError('must define at least one ordering operation: < > <= >=')
    # max() over the dunder names prefers __lt__ to __le__ to __gt__ to __ge__.
    root = max(defined)
    for opname, opfunc in _convert[root]:
        if opname in defined:
            continue
        opfunc.__name__ = opname
        setattr(cls, opname, opfunc)
    return cls
198
199
200################################################################################
201### cmp_to_key() function converter
202################################################################################
203
def cmp_to_key(mycmp):
    """Convert a cmp= function into a key= function"""
    # Each comparison delegates to mycmp(a, b), interpreting the sign of
    # its return value the way list.sort(cmp=...) did in Python 2.
    class K(object):
        __slots__ = ['obj']
        def __init__(self, obj):
            self.obj = obj
        def __lt__(self, other):
            rank = mycmp(self.obj, other.obj)
            return rank < 0
        def __le__(self, other):
            rank = mycmp(self.obj, other.obj)
            return rank <= 0
        def __eq__(self, other):
            rank = mycmp(self.obj, other.obj)
            return rank == 0
        def __ge__(self, other):
            rank = mycmp(self.obj, other.obj)
            return rank >= 0
        def __gt__(self, other):
            rank = mycmp(self.obj, other.obj)
            return rank > 0
        # Deliberately unhashable: the wrapper exists only for ordering.
        __hash__ = None
    return K
222
223try:
224    from _functools import cmp_to_key
225except ImportError:
226    pass
227
228
229################################################################################
230### partial() argument application
231################################################################################
232
233# Purely functional, no descriptor behaviour
# Purely functional, no descriptor behaviour
class partial:
    """New function with partial application of the given arguments
    and keywords.
    """

    __slots__ = "func", "args", "keywords", "__dict__", "__weakref__"

    def __new__(*args, **keywords):
        # Written with bare *args so that keyword arguments actually named
        # 'cls' or 'func' can be forwarded to the underlying callable.
        if not args:
            raise TypeError("descriptor '__new__' of partial needs an argument")
        if len(args) < 2:
            raise TypeError("type 'partial' takes at least one argument")
        cls, func, *args = args
        if not callable(func):
            raise TypeError("the first argument must be callable")
        args = tuple(args)

        # Flatten nested partials so only one call happens at call time.
        # Only do this for actual partial instances: an unrelated callable
        # that merely happens to have func/args/keywords attributes must not
        # be unwrapped (the previous hasattr(func, "func") test wrongly
        # flattened such objects).
        if isinstance(func, partial):
            args = func.args + args
            tmpkw = func.keywords.copy()
            tmpkw.update(keywords)
            keywords = tmpkw
            del tmpkw
            func = func.func

        self = super(partial, cls).__new__(cls)

        self.func = func
        self.args = args
        self.keywords = keywords
        return self

    def __call__(*args, **keywords):
        if not args:
            raise TypeError("descriptor '__call__' of partial needs an argument")
        self, *args = args
        # Call-time keywords override the stored ones.
        newkeywords = self.keywords.copy()
        newkeywords.update(keywords)
        return self.func(*self.args, *args, **newkeywords)

    @recursive_repr()
    def __repr__(self):
        qualname = type(self).__qualname__
        args = [repr(self.func)]
        args.extend(repr(x) for x in self.args)
        args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
        if type(self).__module__ == "functools":
            return f"functools.{qualname}({', '.join(args)})"
        return f"{qualname}({', '.join(args)})"

    def __reduce__(self):
        # Pickle as (callable,) plus a 4-item state tuple consumed by
        # __setstate__; empty keywords/namespace collapse to None.
        return type(self), (self.func,), (self.func, self.args,
               self.keywords or None, self.__dict__ or None)

    def __setstate__(self, state):
        if not isinstance(state, tuple):
            raise TypeError("argument to __setstate__ must be a tuple")
        if len(state) != 4:
            raise TypeError(f"expected 4 items in state, got {len(state)}")
        func, args, kwds, namespace = state
        if (not callable(func) or not isinstance(args, tuple) or
           (kwds is not None and not isinstance(kwds, dict)) or
           (namespace is not None and not isinstance(namespace, dict))):
            raise TypeError("invalid partial state")

        args = tuple(args) # just in case it's a subclass
        if kwds is None:
            kwds = {}
        elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
            kwds = dict(kwds)
        if namespace is None:
            namespace = {}

        self.__dict__ = namespace
        self.func = func
        self.args = args
        self.keywords = kwds
311
312try:
313    from _functools import partial
314except ImportError:
315    pass
316
317# Descriptor version
# Descriptor version
class partialmethod(object):
    """Method descriptor with partial application of the given arguments
    and keywords.

    Supports wrapping existing descriptors and handles non-descriptor
    callables as instance methods.
    """

    def __init__(*args, **keywords):
        # Written with bare *args so that a wrapped callable may itself take
        # keyword arguments named 'self' or 'func'.
        if len(args) >= 2:
            self, func, *args = args
        elif not args:
            raise TypeError("descriptor '__init__' of partialmethod "
                            "needs an argument")
        elif 'func' in keywords:
            # partialmethod(func=f): func supplied by keyword.
            func = keywords.pop('func')
            self, *args = args
        else:
            raise TypeError("type 'partialmethod' takes at least one argument, "
                            "got %d" % (len(args)-1))
        args = tuple(args)

        # Accept non-callable descriptors (e.g. classmethod objects) too.
        if not callable(func) and not hasattr(func, "__get__"):
            raise TypeError("{!r} is not callable or a descriptor"
                                 .format(func))

        # func could be a descriptor like classmethod which isn't callable,
        # so we can't inherit from partial (it verifies func is callable)
        if isinstance(func, partialmethod):
            # flattening is mandatory in order to place cls/self before all
            # other arguments
            # it's also more efficient since only one function will be called
            self.func = func.func
            self.args = func.args + args
            self.keywords = func.keywords.copy()
            self.keywords.update(keywords)
        else:
            self.func = func
            self.args = args
            self.keywords = keywords

    def __repr__(self):
        args = ", ".join(map(repr, self.args))
        keywords = ", ".join("{}={!r}".format(k, v)
                                 for k, v in self.keywords.items())
        format_string = "{module}.{cls}({func}, {args}, {keywords})"
        return format_string.format(module=self.__class__.__module__,
                                    cls=self.__class__.__qualname__,
                                    func=self.func,
                                    args=args,
                                    keywords=keywords)

    def _make_unbound_method(self):
        # Fallback used when self.func is a plain callable (not a
        # descriptor): build a function that injects cls/self as the first
        # positional argument ahead of the frozen args.
        def _method(*args, **keywords):
            call_keywords = self.keywords.copy()
            call_keywords.update(keywords)
            cls_or_self, *rest = args
            call_args = (cls_or_self,) + self.args + tuple(rest)
            return self.func(*call_args, **call_keywords)
        _method.__isabstractmethod__ = self.__isabstractmethod__
        # Expose the originating partialmethod for introspection.
        _method._partialmethod = self
        return _method

    def __get__(self, obj, cls):
        get = getattr(self.func, "__get__", None)
        result = None
        if get is not None:
            new_func = get(obj, cls)
            if new_func is not self.func:
                # Assume __get__ returning something new indicates the
                # creation of an appropriate callable
                result = partial(new_func, *self.args, **self.keywords)
                try:
                    result.__self__ = new_func.__self__
                except AttributeError:
                    pass
        if result is None:
            # If the underlying descriptor didn't do anything, treat this
            # like an instance method
            result = self._make_unbound_method().__get__(obj, cls)
        return result

    @property
    def __isabstractmethod__(self):
        # Mirror the wrapped callable's abstractness so ABCs see through us.
        return getattr(self.func, "__isabstractmethod__", False)
403
404
405################################################################################
406### LRU Cache function decorator
407################################################################################
408
# Statistics tuple returned by the cache_info() method of an
# lru_cache()-wrapped function.
_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
410
411class _HashedSeq(list):
412    """ This class guarantees that hash() will be called no more than once
413        per element.  This is important because the lru_cache() will hash
414        the key multiple times on a cache miss.
415
416    """
417
418    __slots__ = 'hashvalue'
419
420    def __init__(self, tup, hash=hash):
421        self[:] = tup
422        self.hashvalue = hash(tup)
423
424    def __hash__(self):
425        return self.hashvalue
426
427def _make_key(args, kwds, typed,
428             kwd_mark = (object(),),
429             fasttypes = {int, str},
430             tuple=tuple, type=type, len=len):
431    """Make a cache key from optionally typed positional and keyword arguments
432
433    The key is constructed in a way that is flat as possible rather than
434    as a nested structure that would take more memory.
435
436    If there is only a single argument and its data type is known to cache
437    its hash value, then that argument is returned without a wrapper.  This
438    saves space and improves lookup speed.
439
440    """
441    # All of code below relies on kwds preserving the order input by the user.
442    # Formerly, we sorted() the kwds before looping.  The new way is *much*
443    # faster; however, it means that f(x=1, y=2) will now be treated as a
444    # distinct call from f(y=2, x=1) which will be cached separately.
445    key = args
446    if kwds:
447        key += kwd_mark
448        for item in kwds.items():
449            key += item
450    if typed:
451        key += tuple(type(v) for v in args)
452        if kwds:
453            key += tuple(type(v) for v in kwds.values())
454    elif len(key) == 1 and type(key[0]) in fasttypes:
455        return key[0]
456    return _HashedSeq(key)
457
def lru_cache(maxsize=128, typed=False):
    """Least-recently-used cache decorator.

    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound.

    If *typed* is True, arguments of different types will be cached separately.
    For example, f(3.0) and f(3) will be treated as distinct calls with
    distinct results.

    Arguments to the cached function must be hashable.

    View the cache statistics named tuple (hits, misses, maxsize, currsize)
    with f.cache_info().  Clear the cache and statistics with f.cache_clear().
    Access the underlying function with f.__wrapped__.

    The decorator may also be applied directly without arguments, i.e.
    @lru_cache, in which case the default maxsize of 128 is used.

    See:  http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)

    """

    # Users should only access the lru_cache through its public API:
    #       cache_info, cache_clear, and f.__wrapped__
    # The internals of the lru_cache are encapsulated for thread safety and
    # to allow the implementation to change (including a possible C version).

    if isinstance(maxsize, int):
        # Negative maxsize is treated as 0.
        if maxsize < 0:
            maxsize = 0
    elif callable(maxsize) and isinstance(typed, bool):
        # The user_function was passed in directly via the maxsize argument,
        # i.e. the decorator was applied without parentheses:  @lru_cache
        user_function, maxsize = maxsize, 128
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        return update_wrapper(wrapper, user_function)
    elif maxsize is not None:
        raise TypeError(
            'Expected first argument to be an integer, a callable, or None')

    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        return update_wrapper(wrapper, user_function)

    return decorating_function
497
def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    """Pure-Python factory for the caching wrapper produced by lru_cache().

    Picks one of three wrapper strategies depending on maxsize:
    0 (no caching), None (unbounded dict cache), or a positive bound
    (dict + circular doubly linked list tracking recency).
    """
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # Move the link to the front of the circular queue
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
                misses += 1
            # The lock is deliberately released while the user function runs,
            # so a concurrent thread may compute and insert the same key.
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # Getting here means that this same key was added to the
                    # cache while the lock was released.  Since the link
                    # update is already done, we need only return the
                    # computed result and update the count of misses.
                    pass
                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we're
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot
                else:
                    # Put result in a new link at the front of the queue.
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # Use the cache_len bound method instead of the len() function
                    # which could potentially be wrapped in an lru_cache itself.
                    full = (cache_len() >= maxsize)
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper
613
614try:
615    from _functools import _lru_cache_wrapper
616except ImportError:
617    pass
618
619
620################################################################################
621### singledispatch() - single-dispatch generic function decorator
622################################################################################
623
624def _c3_merge(sequences):
625    """Merges MROs in *sequences* to a single MRO using the C3 algorithm.
626
627    Adapted from http://www.python.org/download/releases/2.3/mro/.
628
629    """
630    result = []
631    while True:
632        sequences = [s for s in sequences if s]   # purge empty sequences
633        if not sequences:
634            return result
635        for s1 in sequences:   # find merge candidates among seq heads
636            candidate = s1[0]
637            for s2 in sequences:
638                if candidate in s2[1:]:
639                    candidate = None
640                    break      # reject the current head, it appears later
641            else:
642                break
643        if candidate is None:
644            raise RuntimeError("Inconsistent hierarchy")
645        result.append(candidate)
646        # remove the chosen candidate
647        for seq in sequences:
648            if seq[0] == candidate:
649                del seq[0]
650
def _c3_mro(cls, abcs=None):
    """Computes the method resolution order using extended C3 linearization.

    If no *abcs* are given, the algorithm works exactly like the built-in C3
    linearization used for method resolution.

    If given, *abcs* is a list of abstract base classes that should be inserted
    into the resulting MRO. Unrelated ABCs are ignored and don't end up in the
    result. The algorithm inserts ABCs where their functionality is introduced,
    i.e. issubclass(cls, abc) returns True for the class itself but returns
    False for all its direct base classes. Implicit ABCs for a given class
    (either registered or inferred from the presence of a special method like
    __len__) are inserted directly after the last ABC explicitly listed in the
    MRO of said class. If two implicit ABCs end up next to each other in the
    resulting MRO, their ordering depends on the order of types in *abcs*.

    """
    # Split the direct bases at the last explicit ABC (detected via the
    # presence of __abstractmethods__): bases up to that boundary are
    # considered "explicit", the rest "other".
    for i, base in enumerate(reversed(cls.__bases__)):
        if hasattr(base, '__abstractmethods__'):
            boundary = len(cls.__bases__) - i
            break   # Bases up to the last explicit ABC are considered first.
    else:
        boundary = 0
    abcs = list(abcs) if abcs else []
    explicit_bases = list(cls.__bases__[:boundary])
    abstract_bases = []
    other_bases = list(cls.__bases__[boundary:])
    for base in abcs:
        if issubclass(cls, base) and not any(
                issubclass(b, base) for b in cls.__bases__
            ):
            # If *cls* is the class that introduces behaviour described by
            # an ABC *base*, insert said ABC to its MRO.
            abstract_bases.append(base)
    for base in abstract_bases:
        # Each inserted ABC is consumed here so recursive calls below do not
        # insert it a second time.
        abcs.remove(base)
    explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases]
    abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases]
    other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases]
    # Standard C3: merge the class itself, the MROs of all bases (explicit,
    # then inserted abstract, then other), and finally the base lists.
    return _c3_merge(
        [[cls]] +
        explicit_c3_mros + abstract_c3_mros + other_c3_mros +
        [explicit_bases] + [abstract_bases] + [other_bases]
    )
695
def _compose_mro(cls, types):
    """Calculates the method resolution order for a given class *cls*.

    Includes relevant abstract base classes (with their respective bases) from
    the *types* iterable. Uses a modified C3 linearization algorithm.

    """
    bases = set(cls.__mro__)
    # Remove entries which are already present in the __mro__ or unrelated.
    def is_related(typ):
        return (typ not in bases and hasattr(typ, '__mro__')
                                 and issubclass(cls, typ))
    types = [n for n in types if is_related(n)]
    # Remove entries which are strict bases of other entries (they will end up
    # in the MRO anyway).
    def is_strict_base(typ):
        for other in types:
            if typ != other and typ in other.__mro__:
                return True
        return False
    types = [n for n in types if not is_strict_base(n)]
    # Subclasses of the ABCs in *types* which are also implemented by
    # *cls* can be used to stabilize ABC ordering.
    type_set = set(types)
    mro = []
    for typ in types:
        found = []
        for sub in typ.__subclasses__():
            if sub not in bases and issubclass(cls, sub):
                found.append([s for s in sub.__mro__ if s in type_set])
        if not found:
            mro.append(typ)
            continue
        # Favor subclasses with the biggest number of useful bases
        found.sort(key=len, reverse=True)
        for sub in found:
            for subcls in sub:
                if subcls not in mro:
                    mro.append(subcls)
    return _c3_mro(cls, abcs=mro)
736
def _find_impl(cls, registry):
    """Returns the best matching implementation from *registry* for type *cls*.

    Where there is no registered implementation for a specific type, its method
    resolution order is used to find a more generic implementation.

    Note: if *registry* does not contain an implementation for the base
    *object* type, this function may return None.

    """
    mro = _compose_mro(cls, registry.keys())
    match = None
    # Walk the composed MRO; the first registered type found is the match,
    # unless a second, unrelated implicit ABC matches equally well.
    for t in mro:
        if match is not None:
            # If *match* is an implicit ABC but there is another unrelated,
            # equally matching implicit ABC, refuse the temptation to guess.
            if (t in registry and t not in cls.__mro__
                              and match not in cls.__mro__
                              and not issubclass(match, t)):
                raise RuntimeError("Ambiguous dispatch: {} or {}".format(
                    match, t))
            break
        if t in registry:
            match = t
    return registry.get(match)
762
def singledispatch(func):
    """Single-dispatch generic function decorator.

    Transforms a function into a generic function, which can have different
    behaviours depending upon the type of its first argument. The decorated
    function acts as the default implementation, and additional
    implementations can be registered using the register() attribute of the
    generic function.
    """
    # There are many programs that use functools without singledispatch, so we
    # trade-off making singledispatch marginally slower for the benefit of
    # making start-up of such applications slightly faster.
    import types, weakref

    registry = {}                                  # class -> implementation
    dispatch_cache = weakref.WeakKeyDictionary()   # memoized dispatch results
    cache_token = None    # abc cache token; set once an ABC is registered

    def dispatch(cls):
        """generic_func.dispatch(cls) -> <function implementation>

        Runs the dispatch algorithm to return the best available implementation
        for the given *cls* registered on *generic_func*.

        """
        nonlocal cache_token
        # Invalidate the cache whenever the ABC registration graph changed.
        if cache_token is not None:
            current_token = get_cache_token()
            if cache_token != current_token:
                dispatch_cache.clear()
                cache_token = current_token
        try:
            impl = dispatch_cache[cls]
        except KeyError:
            try:
                impl = registry[cls]
            except KeyError:
                # No exact registration; fall back to the MRO-based search.
                impl = _find_impl(cls, registry)
            dispatch_cache[cls] = impl
        return impl

    def register(cls, func=None):
        """generic_func.register(cls, func) -> func

        Registers a new implementation for the given *cls* on a *generic_func*.

        """
        nonlocal cache_token
        if func is None:
            if isinstance(cls, type):
                # Used as @generic_func.register(some_class): return the
                # actual registering decorator.
                return lambda f: register(cls, f)
            ann = getattr(cls, '__annotations__', {})
            if not ann:
                raise TypeError(
                    f"Invalid first argument to `register()`: {cls!r}. "
                    f"Use either `@register(some_class)` or plain `@register` "
                    f"on an annotated function."
                )
            # Used as plain @generic_func.register on an annotated function:
            # infer the dispatch type from the first annotation.
            func = cls

            # only import typing if annotation parsing is necessary
            from typing import get_type_hints
            argname, cls = next(iter(get_type_hints(func).items()))
            assert isinstance(cls, type), (
                f"Invalid annotation for {argname!r}. {cls!r} is not a class."
            )
        registry[cls] = func
        # Registering an ABC enables token-based cache invalidation above.
        if cache_token is None and hasattr(cls, '__abstractmethods__'):
            cache_token = get_cache_token()
        dispatch_cache.clear()
        return func

    def wrapper(*args, **kw):
        if not args:
            raise TypeError(f'{funcname} requires at least '
                            '1 positional argument')

        return dispatch(args[0].__class__)(*args, **kw)

    funcname = getattr(func, '__name__', 'singledispatch function')
    # The undecorated function is the implementation for the `object` base,
    # i.e. the default when nothing more specific is registered.
    registry[object] = func
    wrapper.register = register
    wrapper.dispatch = dispatch
    wrapper.registry = types.MappingProxyType(registry)
    wrapper._clear_cache = dispatch_cache.clear
    update_wrapper(wrapper, func)
    return wrapper
850