1# util/langhelpers.py
2# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Routines to help with the creation, loading and introspection of
9modules, classes, hierarchies, attributes, functions, and methods.
10
11"""
12from functools import update_wrapper
13import hashlib
14import inspect
15import itertools
16import operator
17import re
18import sys
19import types
20import warnings
21
22from . import _collections
23from . import compat
24from .. import exc
25
26
def md5_hex(x):
    """Return the hexadecimal MD5 digest of ``x``.

    When ``compat.py3k`` is set, ``x`` is encoded to UTF-8 before hashing;
    otherwise it is hashed as given.

    """
    payload = x.encode("utf-8") if compat.py3k else x
    return hashlib.md5(payload).hexdigest()
33
34
class safe_reraise(object):
    """Context manager that re-raises the in-flight exception once the
    enclosed handler code has finished.

    The exception info that is current on entry is captured, so it survives
    whatever happens inside the block (for example a coroutine context
    switch).

    e.g.::

        try:
            sess.commit()
        except:
            with safe_reraise():
                sess.rollback()

    With ``warn_only=True`` the captured exception is not re-raised when
    the block completes without error.

    """

    __slots__ = ("warn_only", "_exc_info")

    def __init__(self, warn_only=False):
        self.warn_only = warn_only

    def __enter__(self):
        # snapshot the exception currently being handled
        self._exc_info = sys.exc_info()

    def __exit__(self, type_, value, traceback):
        # see #2703 for notes
        if type_ is not None:
            # the handler block itself failed
            captured = self._exc_info
            if not compat.py3k and captured and captured[1]:
                # emulate Py3K's behavior of telling us when an exception
                # occurs in an exception handler.
                warn(
                    "An exception has occurred during handling of a "
                    "previous exception.  The previous exception "
                    "is:\n %s %s\n" % (captured[0], captured[1])
                )
            self._exc_info = None  # remove potential circular references
            compat.reraise(type_, value, traceback)
        else:
            # the handler block succeeded; restore the original exception
            exc_type, exc_value, exc_tb = self._exc_info
            self._exc_info = None  # remove potential circular references
            if not self.warn_only:
                compat.reraise(exc_type, exc_value, exc_tb)
79
80
def decode_slice(slc):
    """Break a slice object, as passed to ``__getitem__``, into a
    ``(start, stop, step)`` tuple.

    Any member that provides ``__index__()`` is converted through it;
    members without it (such as ``None``) are passed through unchanged.

    """

    def _as_index(bound):
        return bound.__index__() if hasattr(bound, "__index__") else bound

    return tuple(
        _as_index(bound) for bound in (slc.start, slc.stop, slc.step)
    )
93
94
def _unique_symbols(used, *bases):
    """Yield, for each name in ``bases``, a symbol not present in ``used``.

    The base name itself is tried first, followed by the base with the
    suffixes ``0`` through ``999``.  Each yielded symbol is recorded as
    taken so that later bases cannot collide with it.  ``NameError`` is
    raised when every candidate for a base is already in use.

    """
    taken = set(used)
    for base in bases:
        candidates = itertools.chain(
            (base,), (base + str(i) for i in range(1000))
        )
        for candidate in candidates:
            if candidate in taken:
                continue
            taken.add(candidate)
            yield candidate
            break
        else:
            raise NameError("exhausted namespace for symbol base %s" % base)
109
110
def map_bits(fn, n):
    """Yield ``fn(bit)`` for every set bit of ``n``, lowest bit first."""

    remaining = n
    while remaining:
        # two's complement trick: isolates the lowest set bit
        lowest = remaining & -remaining
        yield fn(lowest)
        remaining ^= lowest
118
119
def decorator(target):
    """A signature-matching decorator factory.

    ``target`` is called as ``target(fn, *args, **kw)``.  The returned
    decorator generates, via source code, a wrapper whose argument
    signature is the same as that of the decorated function and which
    forwards its arguments to ``target``.

    """

    def decorate(fn):
        if not inspect.isfunction(fn):
            raise Exception("not a decoratable function")

        argspec = compat.inspect_getfullargspec(fn)
        # names that the generated source must not shadow
        taken_names = tuple(argspec[0]) + argspec[1:3] + (fn.__name__,)
        targ_name, fn_name = _unique_symbols(taken_names, "target", "fn")

        template_vars = {"target": targ_name, "fn": fn_name}
        template_vars.update(format_argspec_plus(argspec, grouped=False))
        template_vars["name"] = fn.__name__

        source = (
            "def %(name)s(%(args)s):\n"
            "    return %(target)s(%(fn)s, %(apply_kw)s)\n"
        ) % template_vars

        wrapper = _exec_code_in_env(
            source, {targ_name: target, fn_name: fn}, fn.__name__
        )
        # the generated "def" carries names only; copy real default values
        wrapper.__defaults__ = getattr(fn, "im_func", fn).__defaults__
        wrapper.__wrapped__ = fn
        return update_wrapper(wrapper, fn)

    return update_wrapper(decorate, target)
148
149
150def _exec_code_in_env(code, env, fn_name):
151    exec(code, env)
152    return env[fn_name]
153
154
def public_factory(target, location):
    """Produce a wrapping function for the given cls or classmethod.

    The generated function has the same arguments as ``target`` (or as
    its ``__init__`` when ``target`` is a class), minus the first one, and
    takes over the original docstring.  The original callable's docstring
    is replaced by a short note pointing at ``location``, so that the
    public function serves as the documentation.

    """
    if isinstance(target, type):
        callable_ = target
        fn = target.__init__
        doc = (
            "Construct a new :class:`.%s` object. \n\n"
            "This constructor is mirrored as a public API function; "
            "see :func:`~%s` "
            "for a full usage and argument description."
            % (target.__name__, location)
        )
    else:
        callable_ = target
        fn = target
        doc = (
            "This function is mirrored; see :func:`~%s` "
            "for a description of arguments." % location
        )

    public_name = location.split(".")[-1]
    argspec = compat.inspect_getfullargspec(fn)
    # drop the leading self / cls argument
    del argspec[0][0]
    template_vars = format_argspec_plus(argspec, grouped=False)
    template_vars["name"] = public_name
    source = (
        "def %(name)s(%(args)s):\n" "    return cls(%(apply_kw)s)\n"
    ) % template_vars

    namespace = {"cls": callable_, "symbol": symbol}
    exec(source, namespace)
    decorated = namespace[public_name]
    decorated.__doc__ = fn.__doc__
    decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0]

    # method objects don't accept __doc__ assignment; go to the function
    if compat.py2k or hasattr(fn, "__func__"):
        doc_holder = fn.__func__
    else:
        doc_holder = fn
    doc_holder.__doc__ = doc
    return decorated
201
202
class PluginLoader(object):
    """Registry of named plugin loader callables for one entry point group.

    Loaders are looked up in the local registry first, then obtained from
    the optional ``auto_fn`` callable, and finally searched for among
    ``pkg_resources`` entry points when that package is importable.

    """

    def __init__(self, group, auto_fn=None):
        self.group = group
        self.impls = {}
        self.auto_fn = auto_fn

    def clear(self):
        """Forget every registered or cached loader."""
        self.impls.clear()

    def load(self, name):
        """Return the plugin called ``name``, invoking its loader.

        Raises ``exc.NoSuchModuleError`` when no loader can be located.

        """
        if name in self.impls:
            factory = self.impls[name]
            return factory()

        if self.auto_fn:
            factory = self.auto_fn(name)
            if factory:
                self.impls[name] = factory
                return factory()

        try:
            import pkg_resources
        except ImportError:
            pass
        else:
            # only the first matching entry point is used
            for entry_point in pkg_resources.iter_entry_points(
                self.group, name
            ):
                self.impls[name] = entry_point.load
                return entry_point.load()

        raise exc.NoSuchModuleError(
            "Can't load plugin: %s:%s" % (self.group, name)
        )

    def register(self, name, modulepath, objname):
        """Register a lazy loader which imports ``modulepath`` and returns
        its ``objname`` attribute."""

        def load():
            module = compat.import_(modulepath)
            # walk from the imported module down to the leaf module
            for token in modulepath.split(".")[1:]:
                module = getattr(module, token)
            return getattr(module, objname)

        self.impls[name] = load
243
244
def get_cls_kwargs(cls, _set=None):
    r"""Return the full set of inherited kwargs for the given `cls`.

    The ``__init__`` defined directly on the class is inspected and its
    argument names collected.  When that ``__init__`` has a \**kwargs
    catch-all, or when the class defines no Python-level ``__init__`` at
    all, the base classes are examined recursively in the same way.

    ``_set`` is the accumulator used by the recursion.  A recursive call
    returns ``None`` to signal that a constructor without \**kwargs was
    reached, which ends the scan of the remaining bases at that level.

    """
    is_root_call = _set is None
    if is_root_call:
        _set = set()

    init = cls.__dict__.get("__init__", False)
    inspectable = (
        bool(init)
        and isinstance(init, types.FunctionType)
        and isinstance(init.__code__, types.CodeType)
    )

    descend = True
    if inspectable:
        arg_names, accepts_kw = inspect_func_args(init)
        _set.update(arg_names)
        if not accepts_kw:
            if not is_root_call:
                return None
            descend = False

    if descend:
        for base in cls.__bases__:
            if get_cls_kwargs(base, _set) is None:
                break

    _set.discard("self")
    return _set
283
284
try:
    # TODO: who doesn't have this constant?
    from inspect import CO_VARKEYWORDS

    def inspect_func_args(fn):
        """Return ``(positional argument names, accepts **kw)`` for ``fn``,
        read directly from its code object."""
        code = fn.__code__
        positional = list(code.co_varnames[: code.co_argcount])
        accepts_kw = bool(code.co_flags & CO_VARKEYWORDS)
        return positional, accepts_kw


except ImportError:

    def inspect_func_args(fn):
        """Return ``(positional argument names, accepts **kw)`` for ``fn``,
        using the compat argspec helper."""
        arg_names, _varargs, keywords, _defaults = compat.inspect_getargspec(
            fn
        )
        return arg_names, bool(keywords)
303
304
def get_func_kwargs(func):
    """Return the argument names accepted by the given `func`.

    This is the first element of the argspec produced by
    ``compat.inspect_getargspec()``.

    """
    argspec = compat.inspect_getargspec(func)
    return argspec[0]
314
315
def get_callable_argspec(fn, no_self=False, _is_init=False):
    """Return the argument signature for any callable.

    All pure-Python callables are accepted, including
    functions, methods, classes, objects with __call__;
    builtins and other edge cases like functools.partial() objects
    raise a TypeError.

    :param fn: the callable to inspect.
    :param no_self: when True, the leading ``self`` argument is removed
     from the result for bound methods and for ``__init__`` methods
     reached through a class.
    :param _is_init: internal flag, set when recursing from a class into
     its ``__init__``.
    :return: an argspec as produced by ``compat.inspect_getargspec()`` or
     ``compat.ArgSpec``.
    :raises TypeError: if the callable can't be inspected.

    """
    if inspect.isbuiltin(fn):
        raise TypeError("Can't inspect builtin: %s" % fn)
    elif inspect.isfunction(fn):
        spec = compat.inspect_getargspec(fn)
        if _is_init and no_self:
            return compat.ArgSpec(
                spec.args[1:], spec.varargs, spec.keywords, spec.defaults
            )
        return spec
    elif inspect.ismethod(fn):
        spec = compat.inspect_getargspec(fn.__func__)
        # Compare against None rather than truth-testing the instance:
        # an instance may be falsy (e.g. an empty container) or may raise
        # from __bool__ / __nonzero__, yet the method is still bound.
        if no_self and (_is_init or fn.__self__ is not None):
            return compat.ArgSpec(
                spec.args[1:], spec.varargs, spec.keywords, spec.defaults
            )
        return spec
    elif inspect.isclass(fn):
        return get_callable_argspec(
            fn.__init__, no_self=no_self, _is_init=True
        )
    elif hasattr(fn, "__func__"):
        return compat.inspect_getargspec(fn.__func__)
    elif hasattr(fn, "__call__") and inspect.ismethod(fn.__call__):
        # instance of a class with a Python-level __call__
        return get_callable_argspec(fn.__call__, no_self=no_self)
    else:
        raise TypeError("Can't inspect callable: %s" % fn)
356
357
def format_argspec_plus(fn, grouped=True):
    """Returns a dictionary of formatted, introspected function arguments.

    An enhanced variant of inspect.formatargspec to support code generation.

    fn
       An inspectable callable or tuple of inspect getargspec() results.
    grouped
      Defaults to True; include (parens, around, argument) lists

    Returns:

    args
      Full inspect.formatargspec for fn
    self_arg
      The name of the first positional argument, varargs[0], or None
      if the function defines no positional arguments.
    apply_pos
      args, re-written in calling rather than receiving syntax.  Arguments are
      passed positionally.
    apply_kw
      Like apply_pos, except keyword-ish args are passed as keywords.

    Example::

      >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
      {'args': '(self, a, b, c=3, **d)',
       'self_arg': 'self',
       'apply_kw': '(self, a, b, c=c, **d)',
       'apply_pos': '(self, a, b, c, **d)'}

    """
    # a non-callable is taken to be an already-computed argspec
    spec = compat.inspect_getfullargspec(fn) if compat.callable(fn) else fn
    args = compat.inspect_formatargspec(*spec)

    if spec[0]:
        self_arg = spec[0][0]
    elif spec[1]:
        self_arg = "%s[0]" % spec[1]
    else:
        self_arg = None

    if compat.py3k:
        # spec[4] holds the keyword-only argument names
        apply_pos = compat.inspect_formatargspec(
            spec[0], spec[1], spec[2], None, spec[4]
        )
        name_args = spec[0] + spec[4]
        default_sources = (spec[3], spec[4])
    else:
        apply_pos = compat.inspect_formatargspec(spec[0], spec[1], spec[2])
        name_args = spec[0]
        default_sources = (spec[3],)

    num_defaults = sum(len(source) for source in default_sources if source)
    defaulted_vals = name_args[0 - num_defaults :] if num_defaults else ()

    # each defaulted name is rendered as "name=name"
    apply_kw = compat.inspect_formatargspec(
        name_args,
        spec[1],
        spec[2],
        defaulted_vals,
        formatvalue=lambda x: "=" + x,
    )

    if not grouped:
        # strip the enclosing parenthesis from each formatted list
        args = args[1:-1]
        apply_pos = apply_pos[1:-1]
        apply_kw = apply_kw[1:-1]

    return dict(
        args=args,
        self_arg=self_arg,
        apply_pos=apply_pos,
        apply_kw=apply_kw,
    )
446
447
def format_argspec_init(method, grouped=True):
    """format_argspec_plus with considerations for typical __init__ methods

    Delegates to format_argspec_plus, substituting fixed signatures for
    the cases it can't handle::

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    """
    if method is not object.__init__:
        try:
            return format_argspec_plus(method, grouped=grouped)
        except TypeError:
            if grouped:
                args = "(self, *args, **kwargs)"
            else:
                args = "self, *args, **kwargs"
    else:
        args = "(self)" if grouped else "self"

    return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args)
470
471
def getargspec_init(method):
    """inspect.getargspec with considerations for typical __init__ methods

    When the method can't be inspected (a TypeError is raised), a fixed
    argspec tuple is returned instead::

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    """
    try:
        return compat.inspect_getargspec(method)
    except TypeError:
        if method is object.__init__:
            fallback = (["self"], None, None, None)
        else:
            fallback = (["self"], "args", "kwargs", None)
        return fallback
488
489
def unbound_method_to_callable(func_or_cls):
    """Adjust the incoming callable such that a 'self' argument is not
    required.

    A method object whose ``__self__`` is ``None`` (an unbound method) is
    replaced by its underlying function; anything else, including a bound
    method, is returned unchanged.

    """

    # Test identity against None rather than truthiness: a method bound to
    # a falsy instance (e.g. an empty container) is still bound and must
    # not be unwrapped, otherwise the caller would need to supply "self".
    if (
        isinstance(func_or_cls, types.MethodType)
        and func_or_cls.__self__ is None
    ):
        return func_or_cls.__func__
    else:
        return func_or_cls
500
501
def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
    """Produce a __repr__() based on direct association of the __init__()
    specification vs. same-named attributes present.

    :param obj: the object to represent.
    :param additional_kw: sequence of ``(name, default)`` pairs for extra
     attributes to render in keyword form.
    :param to_inspect: object or list of objects whose ``__init__``
     signatures are consulted; defaults to ``obj`` itself.  Required
     arguments of the first entry are rendered positionally; all other
     arguments are rendered as keywords.
    :param omit_kwarg: names of keyword arguments to leave out.
    :return: a string of the form ``ClassName(arg, ..., key=value, ...)``.
     A keyword is rendered only when the attribute is present on ``obj``
     and differs from its default.

    """
    if to_inspect is None:
        to_inspect = [obj]
    else:
        to_inspect = _collections.to_list(to_inspect)

    missing = object()

    pos_args = []
    kw_args = _collections.OrderedDict()
    vargs = None
    for i, insp in enumerate(to_inspect):
        try:
            (_args, _vargs, vkw, defaults) = compat.inspect_getargspec(
                insp.__init__
            )
        except TypeError:
            # uninspectable __init__; nothing to contribute
            continue

        default_len = len(defaults) if defaults else 0

        # Arguments without a default, minus "self".  With no defaults at
        # all, a slice end of "-0" would be 0 and select nothing, so the
        # open-ended slice has to be used instead.
        if default_len:
            required_args = _args[1:-default_len]
        else:
            required_args = _args[1:]

        if i == 0:
            if _vargs:
                vargs = _vargs
            pos_args.extend(required_args)
        else:
            kw_args.update([(arg, missing) for arg in required_args])

        if default_len:
            kw_args.update(
                [
                    (arg, default)
                    for arg, default in zip(_args[-default_len:], defaults)
                ]
            )

    output = []

    output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)

    if vargs is not None and hasattr(obj, vargs):
        output.extend([repr(val) for val in getattr(obj, vargs)])

    def _append_if_changed(arg, defval):
        # best effort: attribute access or comparison on arbitrary objects
        # may raise, and a __repr__ should not fail because of that
        try:
            val = getattr(obj, arg, missing)
            if val is not missing and val != defval:
                output.append("%s=%r" % (arg, val))
        except Exception:
            pass

    for arg, defval in kw_args.items():
        if arg in omit_kwarg:
            continue
        _append_if_changed(arg, defval)

    if additional_kw:
        for arg, defval in additional_kw:
            _append_if_changed(arg, defval)

    return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
572
573
class portable_instancemethod(object):
    """Serializable stand-in for a bound method.

    The method is stored as its owning object plus the method name and is
    looked up again on every call.  Keyword arguments given at
    construction are applied to each call, taking precedence over those
    passed by the caller.

    """

    __slots__ = "target", "name", "kwargs", "__weakref__"

    def __init__(self, meth, kwargs=()):
        self.target = meth.__self__
        self.name = meth.__name__
        self.kwargs = kwargs

    def __getstate__(self):
        return dict(target=self.target, name=self.name, kwargs=self.kwargs)

    def __setstate__(self, state):
        self.target = state["target"]
        self.name = state["name"]
        # "kwargs" may be missing from the state
        self.kwargs = state.get("kwargs", ())

    def __call__(self, *arg, **kw):
        kw.update(self.kwargs)
        bound_method = getattr(self.target, self.name)
        return bound_method(*arg, **kw)
602
603
def class_hierarchy(cls):
    """Return an unordered sequence of all classes related to cls.

    Both the bases and the subclasses of every class encountered are
    followed, so diamond hierarchies are traversed.

    Fibs slightly: subclasses of builtin types are not returned.  Thus
    class_hierarchy(class A(object)) returns (A, object), not A plus every
    class systemwide that derives from object.

    Old-style classes are discarded and hierarchies rooted on them
    will not be descended.

    """
    if compat.py2k and isinstance(cls, types.ClassType):
        return list()

    seen = {cls}
    pending = list(cls.__mro__)
    while pending:
        current = pending.pop()

        if compat.py2k and isinstance(current, types.ClassType):
            continue

        # walk upwards
        for base in current.__bases__:
            if base in seen:
                continue
            if compat.py2k and isinstance(base, types.ClassType):
                continue
            pending.append(base)
            seen.add(base)

        # builtin types are not descended into
        builtin_module = "builtins" if compat.py3k else "__builtin__"
        if current.__module__ == builtin_module or not hasattr(
            current, "__subclasses__"
        ):
            continue

        # walk downwards
        for sub in current.__subclasses__():
            if sub not in seen:
                pending.append(sub)
                seen.add(sub)
    return list(seen)
653
654
def iterate_attributes(cls):
    """Yield ``(key, value)`` for each name in ``dir(cls)``, with the value
    read from the ``__dict__`` of the first class in the MRO defining it.

    getattr() is deliberately not used, so that class-sensitive
    descriptors (i.e. property.__get__()) are not called.

    """
    for key in dir(cls):
        owner = next(
            (klass for klass in cls.__mro__ if key in klass.__dict__), None
        )
        if owner is not None:
            yield (key, owner.__dict__[key])
669
670
def monkeypatch_proxied_specials(
    into_cls,
    from_cls,
    skip=None,
    only=None,
    name="self.proxy",
    from_instance=None,
):
    """Automates delegation of __specials__ for a proxying type.

    For each selected double-underscore method of ``from_cls``, a
    forwarding function is generated from source text and set on
    ``into_cls``.  The generated function calls the same-named method on
    the expression given by ``name``.

    into_cls
      The class that receives the generated methods.
    from_cls
      The class whose special methods are mirrored.
    skip
      Names to leave out when the methods are discovered automatically;
      defaults to a fixed tuple of names.
    only
      If given, exactly these names are used and no discovery or
      ``skip`` filtering takes place.
    name
      Source expression that the generated method delegates to.
    from_instance
      If not None, bound to ``name`` in the namespace in which the
      generated source is executed.

    """

    if only:
        dunders = only
    else:
        if skip is None:
            skip = (
                "__slots__",
                "__del__",
                "__getattribute__",
                "__metaclass__",
                "__getstate__",
                "__setstate__",
            )
        # every __dunder__ name on from_cls that into_cls does not already
        # have and that is not skipped
        dunders = [
            m
            for m in dir(from_cls)
            if (
                m.startswith("__")
                and m.endswith("__")
                and not hasattr(into_cls, m)
                and m not in skip
            )
        ]

    for method in dunders:
        try:
            fn = getattr(from_cls, method)
            # non-callable attributes are not proxied
            if not hasattr(fn, "__call__"):
                continue
            # use the underlying function if an "im_func" attribute exists
            fn = getattr(fn, "im_func", fn)
        except AttributeError:
            continue
        try:
            spec = compat.inspect_getargspec(fn)
            # fn_args: the full positional argument list for the "def";
            # d_args: the same list without its first argument, for the
            # delegated call
            fn_args = compat.inspect_formatargspec(spec[0])
            d_args = compat.inspect_formatargspec(spec[0][1:])
        except TypeError:
            # not inspectable; use a generic signature
            fn_args = "(self, *args, **kw)"
            d_args = "(*args, **kw)"

        # NOTE: the template is filled from locals(), so the local names
        # "method", "fn_args", "name" and "d_args" are part of the template
        py = (
            "def %(method)s%(fn_args)s: "
            "return %(name)s.%(method)s%(d_args)s" % locals()
        )

        # when from_instance is given, the generated code sees it under
        # "name"; otherwise the namespace starts out empty
        env = from_instance is not None and {name: from_instance} or {}
        compat.exec_(py, env)
        try:
            # the generated signature has names only; copy the defaults
            env[method].__defaults__ = fn.__defaults__
        except AttributeError:
            pass
        setattr(into_cls, method, env[method])
732
733
def methods_equivalent(meth1, meth2):
    """Return True if the two methods are the same implementation.

    Each argument is reduced to its ``__func__`` attribute when it has
    one, and the two results are compared by identity.

    """
    impl1 = getattr(meth1, "__func__", meth1)
    impl2 = getattr(meth2, "__func__", meth2)
    return impl1 is impl2
740
741
def as_interface(obj, cls=None, methods=None, required=None):
    """Ensure basic interface compliance for an instance or dict of callables.

    Checks that ``obj`` implements public methods of ``cls`` or has members
    listed in ``methods``. If ``required`` is not supplied, implementing at
    least one interface method is sufficient. Methods present on ``obj`` that
    are not in the interface are ignored.

    If ``obj`` is a dict and ``dict`` does not meet the interface
    requirements, the keys of the dictionary are inspected. Keys present in
    ``obj`` that are not in the interface will raise TypeErrors.

    Raises TypeError if ``obj`` does not meet the interface criteria.

    In all passing cases, an object with callable members is returned.  In the
    simple case, ``obj`` is returned as-is; if dict processing kicks in then
    an anonymous class is returned.

    obj
      A type, instance, or dictionary of callables.
    cls
      Optional, a type.  All public methods of cls are considered the
      interface.  An ``obj`` instance of cls will always pass, ignoring
      ``required``.
    methods
      Optional, a sequence of method names to consider as the interface.
    required
      Optional, a sequence of mandatory implementations. If omitted, an
      ``obj`` that provides at least one interface method is considered
      sufficient.  As a convenience, required may be a type, in which case
      all public methods of the type are required.

    """
    if not cls and not methods:
        raise TypeError("a class or collection of method names are required")

    if isinstance(cls, type) and isinstance(obj, cls):
        return obj

    if methods:
        interface = set(methods)
    else:
        interface = set(
            member for member in dir(cls) if not member.startswith("_")
        )
    implemented = set(dir(obj))

    # ">=" demands all of "required"; with nothing required, ">" against
    # the empty set demands at least one interface member
    if isinstance(required, type):
        required, complies = interface, operator.ge
    elif not required:
        required, complies = set(), operator.gt
    else:
        required, complies = set(required), operator.ge

    if complies(implemented.intersection(interface), required):
        return obj

    # No dict duck typing here.
    if not isinstance(obj, dict):
        qualifier = "any of" if complies is operator.gt else "all of"
        raise TypeError(
            "%r does not implement %s: %s"
            % (obj, qualifier, ", ".join(interface))
        )

    class AnonymousInterface(object):
        """A callable-holding shell."""

    if cls:
        AnonymousInterface.__name__ = "Anonymous" + cls.__name__

    found = set()
    for method, impl in dictlike_iteritems(obj):
        if method not in interface:
            raise TypeError("%r: unknown in this interface" % method)
        if not compat.callable(impl):
            raise TypeError("%r=%r is not callable" % (method, impl))
        setattr(AnonymousInterface, method, staticmethod(impl))
        found.add(method)

    if not complies(found, required):
        raise TypeError(
            "dictionary does not contain required keys %s"
            % ", ".join(required - found)
        )
    return AnonymousInterface
826
827
class memoized_property(object):
    """A read-only @property that is only evaluated once.

    The computed value is stored in the instance ``__dict__`` under the
    name of the decorated function.

    """

    def __init__(self, fget, doc=None):
        self.fget = fget
        self.__doc__ = doc or fget.__doc__
        self.__name__ = fget.__name__

    def __get__(self, obj, cls):
        if obj is None:
            # accessed on the class
            return self
        value = self.fget(obj)
        obj.__dict__[self.__name__] = value
        return value

    def _reset(self, obj):
        """Discard the value memoized on ``obj`` for this property."""
        memoized_property.reset(obj, self.__name__)

    @classmethod
    def reset(cls, obj, name):
        """Discard the value memoized on ``obj`` under ``name``, if any."""
        obj.__dict__.pop(name, None)
848
849
def memoized_instancemethod(fn):
    """Decorate a method so that its return value is memoized.

    The first call stores, in the instance ``__dict__``, a function that
    returns the computed result.  Best applied to no-arg methods:
    memoization is not sensitive to argument values, and the same value
    is returned even when called with different arguments.

    """

    def oneshot(self, *args, **kw):
        cached = fn(self, *args, **kw)

        def memo(*a, **kw):
            return cached

        memo.__name__ = fn.__name__
        memo.__doc__ = fn.__doc__
        # the instance attribute shadows the class-level method from now on
        self.__dict__[fn.__name__] = memo
        return cached

    return update_wrapper(oneshot, fn)
871
872
class group_expirable_memoized_property(object):
    """A family of @memoized_properties that can be expired in tandem.

    An instance acts as a decorator; the name of every function it
    decorates is recorded in ``attributes``.

    """

    def __init__(self, attributes=()):
        self.attributes = list(attributes) if attributes else []

    def expire_instance(self, instance):
        """Expire all memoized properties for *instance*."""
        instance_state = instance.__dict__
        for name in self.attributes:
            instance_state.pop(name, None)

    def __call__(self, fn):
        """Record ``fn`` and return it as a memoized_property."""
        self.attributes.append(fn.__name__)
        return memoized_property(fn)

    def method(self, fn):
        """Record ``fn`` and return it as a memoized_instancemethod."""
        self.attributes.append(fn.__name__)
        return memoized_instancemethod(fn)
894
895
class MemoizedSlots(object):
    """Apply memoized items to an object using a __getattr__ scheme.

    This allows the functionality of memoized_property and
    memoized_instancemethod to be available to a class using __slots__.

    When attribute ``key`` isn't found, ``_memoized_attr_<key>()`` is
    called and its result stored under ``key``; failing that,
    ``_memoized_method_<key>`` is wrapped so that its first result is
    stored under ``key`` as a function returning that result.

    """

    __slots__ = ()

    def _fallback_getattr(self, key):
        raise AttributeError(key)

    def __getattr__(self, key):
        # the hasattr() probes below re-enter this method
        if key.startswith("_memoized"):
            raise AttributeError(key)

        attr_factory_name = "_memoized_attr_%s" % key
        if hasattr(self, attr_factory_name):
            value = getattr(self, attr_factory_name)()
            setattr(self, key, value)
            return value

        method_name = "_memoized_method_%s" % key
        if not hasattr(self, method_name):
            return self._fallback_getattr(key)

        fn = getattr(self, method_name)

        def oneshot(*args, **kw):
            cached = fn(*args, **kw)

            def memo(*a, **kw):
                return cached

            memo.__name__ = fn.__name__
            memo.__doc__ = fn.__doc__
            setattr(self, key, memo)
            return cached

        oneshot.__doc__ = fn.__doc__
        return oneshot
934
935
def dependency_for(modulename, add_to_all=False):
    """Return a decorator that installs the decorated object as an
    attribute of the module named by the dotted path ``modulename``.

    With ``add_to_all`` set, the object's name is also appended to that
    module's ``__all__`` when it has one.  The decorated object itself is
    returned unchanged.

    """

    def decorate(obj):
        package_path, _, leaf = modulename.rpartition(".")
        package = compat.import_(package_path, globals(), locals(), [leaf])
        module = getattr(package, leaf)
        setattr(module, obj.__name__, obj)
        if add_to_all and hasattr(module, "__all__"):
            module.__all__.append(obj.__name__)
        return obj

    return decorate
949
950
951class dependencies(object):
952    """Apply imported dependencies as arguments to a function.
953
954    E.g.::
955
956        @util.dependencies(
957            "sqlalchemy.sql.widget",
958            "sqlalchemy.engine.default"
        )
960        def some_func(self, widget, default, arg1, arg2, **kw):
961            # ...
962
963    Rationale is so that the impact of a dependency cycle can be
964    associated directly with the few functions that cause the cycle,
965    and not pollute the module-level namespace.
966
967    """
968
969    def __init__(self, *deps):
970        self.import_deps = []
971        for dep in deps:
972            tokens = dep.split(".")
973            self.import_deps.append(
974                dependencies._importlater(".".join(tokens[0:-1]), tokens[-1])
975            )
976
977    def __call__(self, fn):
978        import_deps = self.import_deps
979        spec = compat.inspect_getfullargspec(fn)
980
981        spec_zero = list(spec[0])
982        hasself = spec_zero[0] in ("self", "cls")
983
984        for i in range(len(import_deps)):
985            spec[0][i + (1 if hasself else 0)] = "import_deps[%r]" % i
986
987        inner_spec = format_argspec_plus(spec, grouped=False)
988
989        for impname in import_deps:
990            del spec_zero[1 if hasself else 0]
991        spec[0][:] = spec_zero
992
993        outer_spec = format_argspec_plus(spec, grouped=False)
994
995        code = "lambda %(args)s: fn(%(apply_kw)s)" % {
996            "args": outer_spec["args"],
997            "apply_kw": inner_spec["apply_kw"],
998        }
999
1000        decorated = eval(code, locals())
1001        decorated.__defaults__ = getattr(fn, "im_func", fn).__defaults__
1002        return update_wrapper(decorated, fn)
1003
1004    @classmethod
1005    def resolve_all(cls, path):
1006        for m in list(dependencies._unresolved):
1007            if m._full_path.startswith(path):
1008                m._resolve()
1009
1010    _unresolved = set()
1011    _by_key = {}
1012
1013    class _importlater(object):
1014        _unresolved = set()
1015
1016        _by_key = {}
1017
1018        def __new__(cls, path, addtl):
1019            key = path + "." + addtl
1020            if key in dependencies._by_key:
1021                return dependencies._by_key[key]
1022            else:
1023                dependencies._by_key[key] = imp = object.__new__(cls)
1024                return imp
1025
1026        def __init__(self, path, addtl):
1027            self._il_path = path
1028            self._il_addtl = addtl
1029            dependencies._unresolved.add(self)
1030
1031        @property
1032        def _full_path(self):
1033            return self._il_path + "." + self._il_addtl
1034
1035        @memoized_property
1036        def module(self):
1037            if self in dependencies._unresolved:
1038                raise ImportError(
1039                    "importlater.resolve_all() hasn't "
1040                    "been called (this is %s %s)"
1041                    % (self._il_path, self._il_addtl)
1042                )
1043
1044            return getattr(self._initial_import, self._il_addtl)
1045
1046        def _resolve(self):
1047            dependencies._unresolved.discard(self)
1048            self._initial_import = compat.import_(
1049                self._il_path, globals(), locals(), [self._il_addtl]
1050            )
1051
1052        def __getattr__(self, key):
1053            if key == "module":
1054                raise ImportError(
1055                    "Could not resolve module %s" % self._full_path
1056                )
1057            try:
1058                attr = getattr(self.module, key)
1059            except AttributeError:
1060                raise AttributeError(
1061                    "Module %s has no attribute '%s'" % (self._full_path, key)
1062                )
1063            self.__dict__[key] = attr
1064            return attr
1065
1066
1067# from paste.deploy.converters
def asbool(obj):
    """Interpret ``obj`` as a boolean.

    Strings are stripped and lower-cased, then matched against the
    recognized "true" and "false" words; any other string raises
    ``ValueError``.  Non-string objects are passed to ``bool()``.

    """
    if not isinstance(obj, compat.string_types):
        return bool(obj)

    text = obj.strip().lower()
    if text in ("true", "yes", "on", "y", "t", "1"):
        return True
    if text in ("false", "no", "off", "n", "f", "0"):
        return False
    raise ValueError("String is not true/false: %r" % text)
1078
1079
def bool_or_str(*text):
    """Build a converter that accepts booleans plus some literal values.

    The returned callable hands back its argument unchanged when it is
    one of the values given in ``text``; anything else is converted
    with ``asbool()``.

    """

    def bool_or_value(obj):
        return obj if obj in text else asbool(obj)

    return bool_or_value
1093
1094
def asint(value):
    """Convert ``value`` with ``int()``, passing ``None`` through as is."""

    return value if value is None else int(value)
1101
1102
def coerce_kw_type(kw, key, type_, flexi_bool=True):
    r"""Coerce ``kw[key]`` in place to ``type\_`` when that is needed.

    Nothing happens if ``key`` is absent, or if the current value is
    ``None`` or already an instance of ``type\_``.  When ``type\_`` is
    ``bool`` and ``flexi_bool`` is True the value is converted with
    ``asbool()`` (so the string '0' is considered false); otherwise
    ``type\_`` itself is called with the value.
    """

    if key not in kw:
        return

    current = kw[key]
    if isinstance(current, type_) or current is None:
        return

    converter = asbool if (type_ is bool and flexi_bool) else type_
    kw[key] = converter(current)
1114
1115
def constructor_copy(obj, cls, *args, **kw):
    """Build a ``cls`` instance, taking constructor arguments from ``obj``.

    The names returned by ``get_cls_kwargs(cls)`` that were not passed
    explicitly in ``kw`` are filled in from ``obj.__dict__`` wherever
    ``obj`` has a value stored under that name.

    """

    missing = get_cls_kwargs(cls).difference(kw)
    for name in missing:
        if name in obj.__dict__:
            kw[name] = obj.__dict__[name]
    return cls(*args, **kw)
1128
1129
def counter():
    """Return a function that yields 1, 2, 3, ... on successive calls.

    Each call advances a shared ``itertools.count`` while holding a
    lock.

    """

    sequence = itertools.count(1)
    lock = compat.threading.Lock()

    # named _next to avoid the 2to3 "next" transformation...
    def _next():
        with lock:
            return next(sequence)

    return _next
1145
1146
def duck_type_collection(specimen, default=None):
    """Guess which basic collection type ``specimen`` is or imitates.

    ``specimen`` may be an instance or a class.  An ``__emulates__``
    attribute, when present, decides the answer (any ``set`` subclass is
    reported as the builtin ``set``).  Otherwise the check is first by
    type against list, set and dict, then by the presence of an
    ``append``, ``add`` or ``set`` attribute.  ``default`` is returned
    when nothing matches.
    """

    if hasattr(specimen, "__emulates__"):
        emulated = specimen.__emulates__
        # canonicalize set vs sets.Set to a standard: the builtin set
        if emulated is not None and issubclass(emulated, set):
            return set
        return emulated

    # classes are tested with issubclass, instances with isinstance
    matches = issubclass if isinstance(specimen, type) else isinstance
    for basic_type in (list, set, dict):
        if matches(specimen, basic_type):
            return basic_type

    by_method = (("append", list), ("add", set), ("set", dict))
    for method_name, basic_type in by_method:
        if hasattr(specimen, method_name):
            return basic_type
    return default
1178
1179
def assert_arg_type(arg, argtype, name):
    """Return ``arg`` if it is an instance of ``argtype``.

    ``argtype`` is a type or a tuple of types.  On a mismatch,
    ``exc.ArgumentError`` is raised with a message that mentions
    ``name``, the expected type(s) and the actual type of ``arg``.

    """
    if isinstance(arg, argtype):
        return arg

    if isinstance(argtype, tuple):
        raise exc.ArgumentError(
            "Argument '%s' is expected to be one of type %s, got '%s'"
            % (name, " or ".join("'%s'" % a for a in argtype), type(arg))
        )

    raise exc.ArgumentError(
        "Argument '%s' is expected to be of type '%s', got '%s'"
        % (name, argtype, type(arg))
    )
1194
1195
def dictlike_iteritems(dictlike):
    """Return the (key, value) pairs of almost any dict-like object.

    When ``compat.py3k`` is set and the object has ``items()``, the
    result is a list of its items; otherwise ``iteritems()`` or
    ``items()`` is used when available.  Objects lacking those are
    walked via ``iterkeys()`` or ``keys()`` combined with
    ``__getitem__`` or ``get``.  ``TypeError`` is raised when none of
    these are available.

    """

    if compat.py3k:
        if hasattr(dictlike, "items"):
            return list(dictlike.items())
    elif hasattr(dictlike, "iteritems"):
        return dictlike.iteritems()
    elif hasattr(dictlike, "items"):
        return iter(dictlike.items())

    # prefer __getitem__, falling back to get()
    fallback = getattr(dictlike, "get", None)
    getter = getattr(dictlike, "__getitem__", fallback)
    if getter is None:
        raise TypeError("Object '%r' is not dict-like" % dictlike)

    if hasattr(dictlike, "iterkeys"):

        def iterator():
            for key in dictlike.iterkeys():
                yield key, getter(key)

        return iterator()

    if hasattr(dictlike, "keys"):
        return ((key, getter(key)) for key in dictlike.keys())

    raise TypeError("Object '%r' is not dict-like" % dictlike)
1223
1224
class classproperty(property):
    """A ``@property`` variant whose getter receives the class.

    The getter is called with the owning class for both class-level and
    instance-level access.

    The decorator is currently special when using the declarative
    module, but note that the
    :class:`~.sqlalchemy.ext.declarative.declared_attr`
    decorator should be used for this purpose with declarative.

    """

    def __init__(self, fget, *arg, **kw):
        super(classproperty, self).__init__(fget, *arg, **kw)
        # expose the getter's docstring on the descriptor itself
        self.__doc__ = fget.__doc__

    def __get__(self, instance, owner):
        # the instance, if any, is ignored; the getter always gets the class
        return self.fget(owner)
1242
1243
class hybridproperty(object):
    """Descriptor that calls one function for class or instance access.

    Instance access returns ``func(instance)``.  Class access returns
    ``func(owner)`` after assigning the function's docstring to the
    ``__doc__`` of that result.

    """

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        if instance is not None:
            return self.func(instance)

        clsval = self.func(owner)
        clsval.__doc__ = self.func.__doc__
        return clsval
1255
1256
class hybridmethod(object):
    """Decorate a function as cls- or instance- level.

    The function is bound to the class when accessed on the class, and
    to the instance when accessed on an instance.

    """

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        if instance is not None:
            return self.func.__get__(instance, owner)
        return self.func.__get__(owner, owner.__class__)
1268
1269
class _symbol(int):
    """Integer subclass that carries a ``name`` and a symbol-style repr."""

    def __new__(cls, name, doc=None, canonical=None):
        """Construct a new named symbol."""
        assert isinstance(name, compat.string_types)
        # without an explicit integer value, fall back to the name's hash
        value = hash(name) if canonical is None else canonical
        sym = int.__new__(_symbol, value)
        sym.name = name
        if doc:
            sym.__doc__ = doc
        return sym

    def __reduce__(self):
        # rebuild through symbol() using the name and the integer value
        return symbol, (self.name, "x", int(self))

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return "symbol(%r)" % self.name


# present the class under the public name
_symbol.__name__ = "symbol"
1293
1294
class symbol(object):
    """A constant symbol.

    >>> symbol('foo') is symbol('foo')
    True
    >>> symbol('foo')
    symbol('foo')

    A slight refinement of the MAGICCOOKIE=object() pattern.  The primary
    advantage of symbol() is its repr().  They are also singletons.

    Repeated calls of symbol('name') will all return the same instance;
    the ``doc`` and ``canonical`` arguments only take effect on the call
    that first creates the symbol for a given name.

    The optional ``doc`` argument assigns to ``__doc__``.  This
    is strictly so that Sphinx autoattr picks up the docstring we want
    (it doesn't appear to pick up the in-module docstring if the datamember
    is in a different module - autoattribute also blows up completely).
    If Sphinx fixes/improves this then we would no longer need
    ``doc`` here.

    """

    # registry of created symbols, keyed on name
    symbols = {}
    # guards lookups in and additions to the registry
    _lock = compat.threading.Lock()

    def __new__(cls, name, doc=None, canonical=None):
        """Return the symbol registered under ``name``, creating it if
        it does not exist yet.

        """
        # acquire and release through the same ``cls._lock`` reference so
        # the two can never refer to different lock objects
        with cls._lock:
            sym = cls.symbols.get(name)
            if sym is None:
                cls.symbols[name] = sym = _symbol(name, doc, canonical)
            return sym
1329
1330
# next value to be handed out by set_creation_order()
_creation_order = 1


def set_creation_order(instance):
    """Stamp ``instance`` with the next '_creation_order' number.

    The numbers increase by one per call, so instances can later be
    sorted in the order they were created (typically within a single
    thread; the counter is not particularly threadsafe).

    """
    global _creation_order
    current = _creation_order
    instance._creation_order = current
    _creation_order = current + 1
1345
1346
def warn_exception(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return its result.

    An ``Exception`` raised by the call is not propagated; it is
    reported through ``warn()`` instead and ``None`` is returned.

    """
    try:
        result = func(*args, **kwargs)
    except Exception:
        exc_type, exc_value = sys.exc_info()[:2]
        warn("%s('%s') ignored" % (exc_type, exc_value))
    else:
        return result
1356
1357
def ellipses_string(value, len_=25):
    """Shorten ``value`` to ``len_`` items followed by "...".

    Values that are not longer than ``len_`` are returned unchanged, as
    is any value for which measuring or slicing raises ``TypeError``.

    """
    try:
        too_long = len(value) > len_
        return "%s..." % value[0:len_] if too_long else value
    except TypeError:
        return value
1366
1367
class _hash_limit_string(compat.text_type):
    """A string subclass that can only be hashed on a maximum amount
    of unique values.

    This is used for warnings so that we can send out parameterized warnings
    without the __warningregistry__ of the module, or the non-overridable
    "once" registry within warnings.py, overloading memory.

    The hash is derived from the un-interpolated template together with
    the hash of the interpolated text modulo ``num``.  Equality is
    defined as equality of hashes.

    """

    def __new__(cls, value, num, args):
        interpolated = value % args
        interpolated += (
            " (this warning may be suppressed after %d occurrences)" % num
        )
        self = super(_hash_limit_string, cls).__new__(cls, interpolated)
        bucket = hash(interpolated) % num
        self._hash = hash("%s_%d" % (value, bucket))
        return self

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        return hash(other) == hash(self)
1392
1393
def warn(msg):
    """Issue a warning.

    ``msg`` is handed to ``warnings.warn()`` with
    :class:`.exc.SAWarning` as the category and a stacklevel of 2.

    """
    warnings.warn(msg, category=exc.SAWarning, stacklevel=2)
1402
1403
def warn_limited(msg, args):
    """Issue a warning built from a template, limiting registrations.

    When ``args`` is non-empty, ``msg`` is treated as a ``%`` template
    and wrapped, together with ``args``, in ``_hash_limit_string`` with
    a limit of 10.  With empty ``args`` the message is emitted as given.

    """
    message = _hash_limit_string(msg, 10, args) if args else msg
    warnings.warn(message, category=exc.SAWarning, stacklevel=2)
1412
1413
def only_once(fn):
    """Wrap ``fn`` so that only the first call reaches it.

    The first call returns whatever ``fn`` returns; every later call is
    a no-op returning ``None``.  The function is consumed before it is
    invoked, so it is not retried if that first call raises.

    """

    pending = [fn]

    def go(*arg, **kw):
        if not pending:
            return None
        return pending.pop()(*arg, **kw)

    return go
1426
1427
# default suffix filter for chop_traceback()
_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
# default prefix filter for chop_traceback()
_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")


def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE):
    """Trim matching lines from both ends of a traceback.

    :param tb:
      a list of traceback lines as returned by ``traceback.format_stack()``

    :param exclude_prefix:
      a regular expression object; leading lines of ``tb`` are dropped
      for as long as it matches

    :param exclude_suffix:
      a regular expression object; trailing lines of ``tb`` are dropped
      for as long as it matches

    :return: the remaining slice of ``tb``
    """
    # [lo, hi) is the half-open range of lines still being kept
    lo, hi = 0, len(tb)
    while lo < hi and exclude_prefix.search(tb[lo]):
        lo += 1
    while lo < hi and exclude_suffix.search(tb[hi - 1]):
        hi -= 1
    return tb[lo:hi]
1452
1453
# the type of the ``None`` singleton
NoneType = type(None)
1455
1456
def attrsetter(attrname):
    """Return a function ``set(obj, value)`` that performs
    ``obj.<attrname> = value``.

    The setter is produced by executing generated source code, with
    ``attrname`` interpolated directly into the assignment target.

    """
    code = "def set(obj, value):    obj.%s = value" % attrname
    env = dict(locals())
    exec(code, env)
    return env["set"]
1462
1463
class EnsureKWArgType(type):
    r"""Metaclass that makes selected functions tolerate \**kw arguments.

    If the class being created has a truthy ``ensure_kwarg`` attribute,
    it is used as a regular expression matched against each name in the
    class dictionary.  A matching function whose argspec reports no
    ``keywords`` is replaced by a wrapper that accepts keyword
    arguments and discards them.

    """

    def __init__(cls, clsname, bases, clsdict):
        pattern = cls.ensure_kwarg
        if pattern:
            for attr_name in clsdict:
                if not re.match(pattern, attr_name):
                    continue
                method = clsdict[attr_name]
                if compat.inspect_getargspec(method).keywords:
                    # already accepts **kw; leave as is
                    continue
                wrapped = cls._wrap_w_kw(method)
                clsdict[attr_name] = wrapped
                setattr(cls, attr_name, wrapped)
        super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict)

    def _wrap_w_kw(self, fn):
        """Return a wrapper for ``fn`` that forwards only positional
        arguments, dropping any keyword arguments.

        """

        def wrap(*arg, **kw):
            return fn(*arg)

        return update_wrapper(wrap, fn)
1488
1489
def wrap_callable(wrapper, fn):
    """Augment functools.update_wrapper() to work with objects with
    a ``__call__()`` method.

    If ``fn`` has a ``__name__`` this is plain ``update_wrapper()``.
    Otherwise ``wrapper`` is named after the class of ``fn``, takes the
    ``__module__`` of ``fn`` if there is one, and takes its docstring
    from ``fn.__call__`` or, failing that, from ``fn`` itself.

    :param wrapper:
      the function to be updated; it is also the return value

    :param fn:
      object with __call__ method

    """
    if hasattr(fn, "__name__"):
        return update_wrapper(wrapper, fn)

    wrapper.__name__ = fn.__class__.__name__
    if hasattr(fn, "__module__"):
        wrapper.__module__ = fn.__module__

    # prefer the docstring of __call__, then that of the object
    doc = getattr(fn.__call__, "__doc__", None) or fn.__doc__
    if doc:
        wrapper.__doc__ = doc

    return wrapper
1512
1513
def quoted_token_parser(value):
    """Split a dotted identifier into tokens, honoring double quotes.

    A dot inside a double-quoted section does not split.  The quote
    characters themselves are dropped, except that a doubled quote
    inside a quoted section yields one literal quote character.

    E.g.::

        >>> quoted_token_parser("name")
        ['name']
        >>> quoted_token_parser("schema.name")
        ['schema', 'name']
        >>> quoted_token_parser('"Schema"."Name"')
        ['Schema', 'Name']
        >>> quoted_token_parser('"Schema"."Name""Foo"')
        ['Schema', 'Name"Foo']

    """

    if '"' not in value:
        return value.split(".")

    tokens = []
    current = []
    in_quotes = False
    pos = 0
    end = len(value)
    while pos < end:
        char = value[pos]
        pos += 1
        if char == '"':
            if in_quotes and pos < end and value[pos] == '"':
                # doubled quote inside quotes: keep one, skip the other
                current.append('"')
                pos += 1
            else:
                in_quotes = not in_quotes
        elif char == "." and not in_quotes:
            tokens.append("".join(current))
            current = []
        else:
            current.append(char)

    tokens.append("".join(current))
    return tokens
1556