1# util/langhelpers.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Routines to help with the creation, loading and introspection of
9modules, classes, hierarchies, attributes, functions, and methods.
10
11"""
12from functools import update_wrapper
13import hashlib
14import inspect
15import itertools
16import operator
17import re
18import sys
19import textwrap
20import types
21import warnings
22
23from . import _collections
24from . import compat
25from .. import exc
26
27
def md5_hex(x):
    """Return the hexadecimal MD5 digest of the text ``x``.

    On Python 3 the string is UTF-8 encoded before hashing; on Python 2 it
    is hashed as given.  (MD5 here is used for identifiers, not security.)
    """
    data = x.encode("utf-8") if compat.py3k else x
    return hashlib.md5(data).hexdigest()
34
35
class safe_reraise(object):
    """Reraise an exception after invoking some
    handler code.

    Stores the existing exception info before
    invoking so that it is maintained across a potential
    coroutine context switch.

    e.g.::

        try:
            sess.commit()
        except:
            with safe_reraise():
                sess.rollback()

    :param warn_only: if True, the captured exception is *not* re-raised
     when the handler block finishes without error; it is discarded.  An
     exception raised by the handler block itself is always propagated.

    """

    __slots__ = ("warn_only", "_exc_info")

    def __init__(self, warn_only=False):
        self.warn_only = warn_only

    def __enter__(self):
        # Capture the exception currently being handled; this is only
        # meaningful when entered from inside an ``except`` block.
        self._exc_info = sys.exc_info()

    def __exit__(self, type_, value, traceback):
        # see #2703 for notes
        if type_ is None:
            # The handler block succeeded: re-raise the originally captured
            # exception with its original traceback (unless warn_only).
            exc_type, exc_value, exc_tb = self._exc_info
            self._exc_info = None  # remove potential circular references
            if not self.warn_only:
                compat.raise_(
                    exc_value,
                    with_traceback=exc_tb,
                )
        else:
            # The handler block itself raised; that new exception wins.
            if not compat.py3k and self._exc_info and self._exc_info[1]:
                # emulate Py3K's behavior of telling us when an exception
                # occurs in an exception handler.
                warn(
                    "An exception has occurred during handling of a "
                    "previous exception.  The previous exception "
                    "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1])
                )
            self._exc_info = None  # remove potential circular references
            compat.raise_(value, with_traceback=traceback)
83
84
def clsname_as_plain_name(cls):
    """Convert a CamelCase class name into lower-case, space-separated words.

    Only capitalized words (one upper-case letter followed by one or more
    lower-case letters) are kept, e.g. ``FooBar`` -> ``"foo bar"``.
    """
    words = re.findall(r"([A-Z][a-z]+)", cls.__name__)
    lowered = [word.lower() for word in words]
    return " ".join(lowered)
89
90
def decode_slice(slc):
    """Decode a slice object as sent to ``__getitem__``.

    Returns a ``(start, stop, step)`` tuple where every component that
    supports ``__index__()`` has been converted through it; other values
    (including ``None``) are returned unchanged.
    """

    def _as_index(value):
        return value.__index__() if hasattr(value, "__index__") else value

    return tuple(_as_index(part) for part in (slc.start, slc.stop, slc.step))
103
104
105def _unique_symbols(used, *bases):
106    used = set(used)
107    for base in bases:
108        pool = itertools.chain(
109            (base,),
110            compat.itertools_imap(lambda i: base + str(i), range(1000)),
111        )
112        for sym in pool:
113            if sym not in used:
114                used.add(sym)
115                yield sym
116                break
117        else:
118            raise NameError("exhausted namespace for symbol base %s" % base)
119
120
def map_bits(fn, n):
    """Yield ``fn(bit)`` for each set bit of ``n``, lowest bit first.

    ``fn`` receives the isolated bit value (a power of two), not its
    position.
    """
    remaining = n
    while remaining:
        # two's-complement trick: isolates the lowest set bit
        lowest = remaining & -remaining
        yield fn(lowest)
        remaining ^= lowest
128
129
def decorator(target):
    """A signature-matching decorator factory.

    Returns a decorator which, applied to a function ``fn``, generates (via
    ``exec``) a wrapper whose signature mirrors ``fn``.  Calling the
    wrapper calls ``target(fn, <the call's arguments>)``, with arguments
    rendered by :func:`format_argspec_plus` (keyword-ish ones passed as
    keywords).

    :param target: callable receiving the wrapped function followed by the
     call's arguments.
    :return: the decorator, with ``target``'s metadata copied onto it.
    :raises Exception: (from the returned decorator) if applied to
     something that is neither a function nor a method.
    """

    def decorate(fn):
        if not inspect.isfunction(fn) and not inspect.ismethod(fn):
            raise Exception("not a decoratable function")

        spec = compat.inspect_getfullargspec(fn)
        # Every name bound in the generated function's scope must be avoided
        # when choosing the names under which ``target`` and ``fn`` are
        # injected; otherwise an argument would shadow them.  Keyword-only
        # arguments (spec[4]) are local names too and must be included.
        names = (
            tuple(spec[0])
            + spec[1:3]
            + tuple(spec[4] or ())
            + (fn.__name__,)
        )
        targ_name, fn_name = _unique_symbols(names, "target", "fn")

        metadata = dict(target=targ_name, fn=fn_name)
        metadata.update(format_argspec_plus(spec, grouped=False))
        metadata["name"] = fn.__name__
        code = (
            """\
def %(name)s(%(args)s):
    return %(target)s(%(fn)s, %(apply_kw)s)
"""
            % metadata
        )
        decorated = _exec_code_in_env(
            code, {targ_name: target, fn_name: fn}, fn.__name__
        )
        decorated.__defaults__ = getattr(fn, "im_func", fn).__defaults__
        decorated.__wrapped__ = fn
        return update_wrapper(decorated, fn)

    return update_wrapper(decorate, target)
159
160
161def _exec_code_in_env(code, env, fn_name):
162    exec(code, env)
163    return env[fn_name]
164
165
def public_factory(target, location, class_location=None):
    """Produce a wrapping function for the given cls or classmethod.

    Rationale here is so that the __init__ method of the
    class can serve as documentation for the function.

    :param target: a class, or another callable (presumably a classmethod,
     per the summary above).  For a class, the generated function's
     signature mirrors ``target.__init__``; otherwise it mirrors
     ``target``.  In both cases the first positional argument is dropped.
    :param location: dotted path appended directly to ``"sqlalchemy"``, so
     it is expected to begin with ``"."``.  Its last token becomes the
     generated function's name and the remainder its ``__module__``.
    :param class_location: optional path, also appended to
     ``"sqlalchemy"``, naming the class in generated docstrings; the
     class-constructor docstring defaults to ``".<target.__name__>"``.
    :return: the generated function, which calls ``target`` with the same
     arguments.
    :raises ImportError: if the derived module name is not already present
     in ``sys.modules``.

    Side effect: the docstring of the inspected function (or of its
    ``__func__``) is replaced with a short pointer to the public function,
    and a ``_linked_to`` attribute is set on it if not already present.

    """

    if isinstance(target, type):
        # Document the class constructor; the public function calls the
        # class itself.
        fn = target.__init__
        callable_ = target
        doc = (
            "Construct a new :class:`%s` object. \n\n"
            "This constructor is mirrored as a public API function; "
            "see :func:`sqlalchemy%s` "
            "for a full usage and argument description."
            % (
                class_location if class_location else ".%s" % target.__name__,
                location,
            )
        )
    else:
        fn = callable_ = target
        doc = (
            "This function is mirrored; see :func:`sqlalchemy%s` "
            "for a description of arguments." % location
        )

    location_name = location.split(".")[-1]
    spec = compat.inspect_getfullargspec(fn)
    # Drop the first positional argument (presumably ``self``/``cls``),
    # which the generated public function does not accept.
    del spec[0][0]
    metadata = format_argspec_plus(spec, grouped=False)
    metadata["name"] = location_name
    code = (
        """\
def %(name)s(%(args)s):
    return cls(%(apply_kw)s)
"""
        % metadata
    )
    # NOTE(review): ``symbol`` is not referenced by the template itself;
    # presumably it is needed when default values in ``args`` render as
    # ``symbol(...)`` -- confirm.
    env = {"cls": callable_, "symbol": symbol}
    exec(code, env)
    decorated = env[location_name]

    # ``_linked_to`` is set at the end of this function, so its presence
    # means ``fn`` was already published by an earlier public_factory()
    # call (presumably an inherited __init__); reuse that function's
    # docstring with an explanatory header.
    if hasattr(fn, "_linked_to"):
        linked_to, linked_to_location = fn._linked_to
        linked_to_doc = linked_to.__doc__
        if class_location is None:
            # NOTE(review): unlike the ".%s" default used above, this value
            # has no leading "." and starts with the full module path, so
            # the "sqlalchemy%s" formatting below may render a doubled
            # "sqlalchemysqlalchemy..." reference -- verify.
            class_location = "%s.%s" % (target.__module__, target.__name__)

        linked_to_doc = inject_docstring_text(
            linked_to_doc,
            ".. container:: inherited_member\n\n    "
            "This documentation is inherited from :func:`sqlalchemy%s`; "
            "this constructor, :func:`sqlalchemy%s`,   "
            "creates a :class:`sqlalchemy%s` object.  See that class for "
            "additional details describing this subclass."
            % (linked_to_location, location, class_location),
            1,
        )
        decorated.__doc__ = linked_to_doc
    else:
        decorated.__doc__ = fn.__doc__

    decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0]
    if decorated.__module__ not in sys.modules:
        raise ImportError(
            "public_factory location %s is not in sys.modules"
            % (decorated.__module__,)
        )

    # Methods (and every target on Python 2) keep their docstring and
    # attributes on the underlying function object, reached via __func__.
    if compat.py2k or hasattr(fn, "__func__"):
        fn.__func__.__doc__ = doc
        if not hasattr(fn.__func__, "_linked_to"):
            fn.__func__._linked_to = (decorated, location)
    else:
        fn.__doc__ = doc
        if not hasattr(fn, "_linked_to"):
            fn._linked_to = (decorated, location)

    return decorated
247
248
class PluginLoader(object):
    """Registry resolving named plugins for a setuptools entry-point group.

    :meth:`load` looks in this order: explicitly registered/cached loaders
    in ``impls``, then ``auto_fn(name)``, then ``pkg_resources`` entry
    points in ``group``.  A loader that is found is cached in ``impls``.
    """

    def __init__(self, group, auto_fn=None):
        self.group = group
        self.impls = {}
        self.auto_fn = auto_fn

    def clear(self):
        """Forget every cached or registered loader."""
        self.impls.clear()

    def load(self, name):
        """Return the plugin known as ``name`` by invoking its loader.

        :raises NoSuchModuleError: if no source provides the plugin.
        """
        if name in self.impls:
            return self.impls[name]()

        if self.auto_fn:
            found = self.auto_fn(name)
            if found:
                self.impls[name] = found
                return found()

        try:
            import pkg_resources
        except ImportError:
            pkg_resources = None

        if pkg_resources is not None:
            entry_points = pkg_resources.iter_entry_points(self.group, name)
            for entry_point in entry_points:
                self.impls[name] = entry_point.load
                return entry_point.load()

        raise exc.NoSuchModuleError(
            "Can't load plugin: %s:%s" % (self.group, name)
        )

    def register(self, name, modulepath, objname):
        """Register a lazy loader returning ``objname`` from ``modulepath``.

        Nothing is imported until the plugin is first loaded.
        """

        def load():
            module = compat.import_(modulepath)
            # The import presumably returns the top-level package (as
            # ``__import__`` does); walk down to the requested submodule.
            for attr in modulepath.split(".")[1:]:
                module = getattr(module, attr)
            return getattr(module, objname)

        self.impls[name] = load
289
290
291def _inspect_func_args(fn):
292    try:
293        co_varkeywords = inspect.CO_VARKEYWORDS
294    except AttributeError:
295        # https://docs.python.org/3/library/inspect.html
296        # The flags are specific to CPython, and may not be defined in other
297        # Python implementations. Furthermore, the flags are an implementation
298        # detail, and can be removed or deprecated in future Python releases.
299        spec = compat.inspect_getfullargspec(fn)
300        return spec[0], bool(spec[2])
301    else:
302        # use fn.__code__ plus flags to reduce method call overhead
303        co = fn.__code__
304        nargs = co.co_argcount
305        return (
306            list(co.co_varnames[:nargs]),
307            bool(co.co_flags & co_varkeywords),
308        )
309
310
def get_cls_kwargs(cls, _set=None):
    r"""Return the full set of inherited kwargs for the given `cls`.

    The class's own ``__init__`` (if it is a plain Python function) is
    probed for named arguments.  If that ``__init__`` accepts a
    \**kwargs catch-all, or the class defines no inspectable ``__init__``,
    unrecognized keywords are presumed to flow to the base classes and each
    base in ``__bases__`` is probed recursively.  The name ``self`` is
    removed from the result.

    The argument probe avoids full ``inspect.getfullargspec()`` for speed,
    since the Core typing system uses this when copying type objects.

    No anonymous tuple arguments please !

    :param _set: internal accumulator used during recursion.
    :return: the accumulated set, or ``None`` from a recursive call whose
     class has an ``__init__`` without \**kwargs (stopping the walk).
    """
    is_toplevel = _set is None
    collected = set() if is_toplevel else _set

    init = cls.__dict__.get("__init__", False)
    inspectable = bool(init) and (
        isinstance(init, types.FunctionType)
        and isinstance(init.__code__, types.CodeType)
    )

    descend = True
    if inspectable:
        names, accepts_kw = _inspect_func_args(init)
        collected.update(names)
        if not accepts_kw:
            if not is_toplevel:
                return None
            descend = False

    if descend:
        for base in cls.__bases__:
            if get_cls_kwargs(base, collected) is None:
                break

    collected.discard("self")
    return collected
352
353
def get_func_kwargs(func):
    """Return the positional argument names accepted by ``func``.

    Delegates to ``compat.inspect_getfullargspec`` and returns its ``args``
    field (a list), so it works for functions, methods, etc.  Keyword-only
    arguments and the ``*args``/``**kw`` names are not included.
    """
    spec = compat.inspect_getfullargspec(func)
    return spec[0]
363
364
def _argspec_without_first_arg(spec):
    """Return a copy of FullArgSpec ``spec`` minus its first positional arg."""
    return compat.FullArgSpec(
        spec.args[1:],
        spec.varargs,
        spec.varkw,
        spec.defaults,
        spec.kwonlyargs,
        spec.kwonlydefaults,
        spec.annotations,
    )


def get_callable_argspec(fn, no_self=False, _is_init=False):
    """Return the argument signature for any callable.

    All pure-Python callables are accepted, including
    functions, methods, classes, objects with __call__;
    builtins and other edge cases like functools.partial() objects
    raise a TypeError.

    :param fn: the callable to inspect.
    :param no_self: if True, omit the leading ``self`` for bound methods
     (including a bound ``__call__``) and for ``__init__`` when ``fn`` is a
     class.
    :param _is_init: internal; set when recursing into a class's
     ``__init__``.
    :return: a ``compat.FullArgSpec``.
    :raises TypeError: for builtins and callables that cannot be inspected.

    """
    if inspect.isbuiltin(fn):
        raise TypeError("Can't inspect builtin: %s" % fn)
    elif inspect.isfunction(fn):
        spec = compat.inspect_getfullargspec(fn)
        if _is_init and no_self:
            return _argspec_without_first_arg(spec)
        return spec
    elif inspect.ismethod(fn):
        spec = compat.inspect_getfullargspec(fn.__func__)
        # Compare with None instead of relying on truthiness: an instance
        # whose __bool__/__len__ reports False is still a bound ``self``.
        if no_self and (_is_init or fn.__self__ is not None):
            return _argspec_without_first_arg(spec)
        return spec
    elif inspect.isclass(fn):
        return get_callable_argspec(
            fn.__init__, no_self=no_self, _is_init=True
        )
    elif hasattr(fn, "__func__"):
        return compat.inspect_getfullargspec(fn.__func__)
    elif hasattr(fn, "__call__"):
        if inspect.ismethod(fn.__call__):
            return get_callable_argspec(fn.__call__, no_self=no_self)
        else:
            raise TypeError("Can't inspect callable: %s" % fn)
    else:
        raise TypeError("Can't inspect callable: %s" % fn)
417
418
def format_argspec_plus(fn, grouped=True):
    """Returns a dictionary of formatted, introspected function arguments.

    A enhanced variant of inspect.formatargspec to support code generation.

    fn
       An inspectable callable, or an already-computed argspec.  A
       precomputed spec must have the seven-field layout produced by
       ``compat.inspect_getfullargspec()`` (args, varargs, varkw, defaults,
       kwonlyargs, ...): index 4 (kwonlyargs) is read, so a 4-tuple from
       ``getargspec()`` is not sufficient.
    grouped
      Defaults to True; include (parens, around, argument) lists.  When
      False, the first and last character of each formatted string is
      stripped.

    Returns:

    args
      Full inspect.formatargspec for fn
    self_arg
      The name of the first positional argument, varargs[0], or None
      if the function defines no positional arguments.
    apply_pos
      args, re-written in calling rather than receiving syntax.  Arguments are
      passed positionally; default values are omitted.
    apply_kw
      Like apply_pos, except keyword-ish args are passed as keywords:
      positional arguments that have defaults, and all keyword-only
      arguments, are rendered as ``name=name``.

    Example::

      >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
      {'args': '(self, a, b, c=3, **d)',
       'self_arg': 'self',
       'apply_kw': '(self, a, b, c=c, **d)',
       'apply_pos': '(self, a, b, c, **d)'}

    """
    # ``fn`` may be a callable or a precomputed argspec.
    if compat.callable(fn):
        spec = compat.inspect_getfullargspec(fn)
    else:
        spec = fn

    args = compat.inspect_formatargspec(*spec)
    if spec[0]:
        self_arg = spec[0][0]
    elif spec[1]:
        self_arg = "%s[0]" % spec[1]
    else:
        self_arg = None

    apply_pos = compat.inspect_formatargspec(
        spec[0], spec[1], spec[2], None, spec[4]
    )
    # Count names to pass as keywords: defaulted positionals plus every
    # keyword-only argument.
    num_defaults = 0
    if spec[3]:
        num_defaults += len(spec[3])
    if spec[4]:
        num_defaults += len(spec[4])
    name_args = spec[0] + spec[4]

    if num_defaults:
        # the trailing ``num_defaults`` names, rendered below as name=name
        defaulted_vals = name_args[0 - num_defaults :]
    else:
        defaulted_vals = ()

    apply_kw = compat.inspect_formatargspec(
        name_args,
        spec[1],
        spec[2],
        defaulted_vals,
        formatvalue=lambda x: "=" + x,
    )
    if grouped:
        return dict(
            args=args,
            self_arg=self_arg,
            apply_pos=apply_pos,
            apply_kw=apply_kw,
        )
    else:
        return dict(
            args=args[1:-1],
            self_arg=self_arg,
            apply_pos=apply_pos[1:-1],
            apply_kw=apply_kw[1:-1],
        )
500
501
def format_argspec_init(method, grouped=True):
    """format_argspec_plus with considerations for typical __init__ methods

    Wraps format_argspec_plus with error handling strategies for typical
    __init__ cases::

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    For those fallback cases every formatted entry of the result is the
    same string and ``self_arg`` is ``"self"``.
    """
    if method is not object.__init__:
        try:
            return format_argspec_plus(method, grouped=grouped)
        except TypeError:
            if grouped:
                args = "(self, *args, **kwargs)"
            else:
                args = "self, *args, **kwargs"
    else:
        args = "(self)" if grouped else "self"
    return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args)
524
525
def getargspec_init(method):
    """inspect.getfullargspec with considerations for typical __init__ methods

    Wraps ``compat.inspect_getfullargspec`` with error handling for
    typical ``__init__`` cases that cannot be introspected::

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    The fallback values are seven-field ``FullArgSpec`` objects, matching
    the shape of the successful path.  This means they can be passed to
    :func:`format_argspec_plus`, which reads index 4 (``kwonlyargs``).

    """
    try:
        return compat.inspect_getfullargspec(method)
    except TypeError:
        if method is object.__init__:
            return compat.FullArgSpec(
                ["self"], None, None, None, [], None, {}
            )
        else:
            return compat.FullArgSpec(
                ["self"], "args", "kwargs", None, [], None, {}
            )
542
543
def unbound_method_to_callable(func_or_cls):
    """Adjust the incoming callable such that a 'self' argument is not
    required.

    If ``func_or_cls`` is a method object whose ``__self__`` is ``None``
    (an unbound method, as found on Python 2), its underlying function is
    returned.  Anything else, including bound methods, is returned
    unchanged.

    """
    # Test identity with None rather than truthiness: a bound method whose
    # instance is "falsy" (its __bool__/__len__ reports False/0) is still
    # bound and must not be unwrapped.
    if (
        isinstance(func_or_cls, types.MethodType)
        and func_or_cls.__self__ is None
    ):
        return func_or_cls.__func__
    return func_or_cls
554
555
def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
    """Produce a __repr__() based on direct association of the __init__()
    specification vs. same-named attributes present.

    :param obj: the object to represent.
    :param additional_kw: sequence of ``(name, default)`` pairs rendered as
     keyword arguments after those discovered from ``__init__``.
    :param to_inspect: object, or list of objects/classes, whose
     ``__init__`` signatures are consulted; defaults to ``[obj]``.  The first
     entry supplies the positional arguments and the ``*args`` name; later
     entries contribute keyword arguments only.
    :param omit_kwarg: names of keyword arguments never to render.
    :return: a string of the form ``ClassName(pos, ..., key=value, ...)``.
     Keyword values equal to their default are omitted, as are attributes
     that are missing or whose comparison raises.
    """
    if to_inspect is None:
        to_inspect = [obj]
    else:
        to_inspect = _collections.to_list(to_inspect)

    missing = object()

    pos_args = []
    kw_args = _collections.OrderedDict()
    vargs = None
    for i, insp in enumerate(to_inspect):
        try:
            spec = compat.inspect_getfullargspec(insp.__init__)
        except TypeError:
            continue
        else:
            default_len = len(spec.defaults) if spec.defaults else 0
            # Arguments after ``self`` that have no default.  Use an explicit
            # end index: ``args[1:-default_len]`` would be empty when
            # there are no defaults, silently dropping every argument.
            no_default_args = spec.args[1 : len(spec.args) - default_len]
            if i == 0:
                if spec.varargs:
                    vargs = spec.varargs
                pos_args.extend(no_default_args)
            else:
                kw_args.update([(arg, missing) for arg in no_default_args])

            if default_len:
                kw_args.update(
                    [
                        (arg, default)
                        for arg, default in zip(
                            spec.args[-default_len:], spec.defaults
                        )
                    ]
                )
    output = []

    output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)

    if vargs is not None and hasattr(obj, vargs):
        output.extend([repr(val) for val in getattr(obj, vargs)])

    for arg, defval in kw_args.items():
        if arg in omit_kwarg:
            continue
        try:
            val = getattr(obj, arg, missing)
            if val is not missing and val != defval:
                output.append("%s=%r" % (arg, val))
        except Exception:
            # best-effort: an attribute whose access/comparison fails is
            # simply left out of the repr
            pass

    if additional_kw:
        for arg, defval in additional_kw:
            try:
                val = getattr(obj, arg, missing)
                if val is not missing and val != defval:
                    output.append("%s=%r" % (arg, val))
            except Exception:
                pass

    return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
626
627
class portable_instancemethod(object):
    """Turn an instancemethod into a (parent, name) pair
    to produce a serializable callable.

    The bound method is stored as its ``__self__`` plus ``__name__`` and is
    looked up again on every call.  ``kwargs`` are merged into each call and
    take precedence over keywords supplied by the caller.

    """

    __slots__ = "target", "name", "kwargs", "__weakref__"

    def __init__(self, meth, kwargs=()):
        self.target = meth.__self__
        self.name = meth.__name__
        self.kwargs = kwargs

    def __getstate__(self):
        return dict(target=self.target, name=self.name, kwargs=self.kwargs)

    def __setstate__(self, state):
        self.target = state["target"]
        self.name = state["name"]
        # state written by older versions presumably lacks "kwargs"
        self.kwargs = state.get("kwargs", ())

    def __call__(self, *arg, **kw):
        kw.update(self.kwargs)
        method = getattr(self.target, self.name)
        return method(*arg, **kw)
656
657
def class_hierarchy(cls):
    """Return an unordered sequence of all classes related to cls.

    Traverses diamond hierarchies.

    Fibs slightly: subclasses of builtin types are not returned.  Thus
    class_hierarchy(class A(object)) returns (A, object), not A plus every
    class systemwide that derives from object.

    Old-style classes are discarded and hierarchies rooted on them
    will not be descended.

    The result contains ``cls`` plus, transitively, the bases and the
    currently existing subclasses of every class reached.  Subclasses are
    not collected from classes whose ``__module__`` is the builtins module.
    The result is a ``list`` built from a set, so its order is undefined.

    """
    if compat.py2k:
        # Python 2 old-style classes have no __mro__; nothing to traverse.
        if isinstance(cls, types.ClassType):
            return list()

    hier = {cls}
    # Worklist of classes still to visit, seeded with the MRO; pop() from
    # the end makes the traversal depth-first.
    process = list(cls.__mro__)
    while process:
        c = process.pop()
        if compat.py2k:
            if isinstance(c, types.ClassType):
                continue
            bases = (
                _
                for _ in c.__bases__
                if _ not in hier and not isinstance(_, types.ClassType)
            )
        else:
            bases = (_ for _ in c.__bases__ if _ not in hier)

        # Record each newly seen base and queue it for its own traversal.
        for b in bases:
            process.append(b)
            hier.add(b)

        # Skip subclass enumeration for builtins (e.g. ``object``), which
        # would pull in every class in the process.
        if compat.py3k:
            if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"):
                continue
        else:
            if c.__module__ == "__builtin__" or not hasattr(
                c, "__subclasses__"
            ):
                continue

        for s in [_ for _ in c.__subclasses__() if _ not in hier]:
            process.append(s)
            hier.add(s)
    return list(hier)
707
708
def iterate_attributes(cls):
    """Yield ``(name, raw_value)`` for every name in ``dir(cls)``.

    Each value is read from the ``__dict__`` of the first class in
    ``cls.__mro__`` that defines the name, rather than via getattr().  This
    means class-sensitive descriptors (e.g. ``property.__get__()``) are not
    invoked.  Names not found in any MRO ``__dict__`` are skipped.
    """
    for name in dir(cls):
        owner = next(
            (klass for klass in cls.__mro__ if name in klass.__dict__), None
        )
        if owner is not None:
            yield (name, owner.__dict__[name])
723
724
def monkeypatch_proxied_specials(
    into_cls,
    from_cls,
    skip=None,
    only=None,
    name="self.proxy",
    from_instance=None,
):
    """Automates delegation of __specials__ for a proxying type.

    For each selected dunder method of ``from_cls``, a forwarding function
    ``def __x__(<args>): return <name>.__x__(<args minus first>)`` is
    generated with ``exec`` and set on ``into_cls``.

    :param into_cls: class that receives the generated methods.
    :param from_cls: class whose dunder methods are mirrored.
    :param skip: dunder names to leave alone.  When None, a fixed list is
     used (``__slots__``, ``__del__``, ``__getattribute__``,
     ``__metaclass__``, ``__getstate__``, ``__setstate__``).  Ignored when
     ``only`` is given.
    :param only: explicit list of method names to delegate; bypasses the
     automatic selection (dunder names of ``from_cls`` that ``into_cls``
     does not already have).
    :param name: expression, evaluated inside each generated method, that
     yields the delegate; the default reads ``self.proxy``.
    :param from_instance: if not None, placed in the exec namespace under
     the key ``name``.  Presumably only useful when ``name`` is a plain
     identifier.
    """

    if only:
        dunders = only
    else:
        if skip is None:
            skip = (
                "__slots__",
                "__del__",
                "__getattribute__",
                "__metaclass__",
                "__getstate__",
                "__setstate__",
            )
        dunders = [
            m
            for m in dir(from_cls)
            if (
                m.startswith("__")
                and m.endswith("__")
                and not hasattr(into_cls, m)
                and m not in skip
            )
        ]

    for method in dunders:
        try:
            fn = getattr(from_cls, method)
            if not hasattr(fn, "__call__"):
                continue
            # Python 2: unwrap an unbound method to its function
            fn = getattr(fn, "im_func", fn)
        except AttributeError:
            continue
        try:
            spec = compat.inspect_getfullargspec(fn)
            # Only positional argument names are reproduced; the delegate
            # call omits the first one (the proxy instance itself).
            fn_args = compat.inspect_formatargspec(spec[0])
            d_args = compat.inspect_formatargspec(spec[0][1:])
        except TypeError:
            # Not introspectable (likely a C-level slot wrapper): fall back
            # to a generic pass-through signature.
            fn_args = "(self, *args, **kw)"
            d_args = "(*args, **kw)"

        # ``% locals()`` pulls method, fn_args, name and d_args from here.
        py = (
            "def %(method)s%(fn_args)s: "
            "return %(name)s.%(method)s%(d_args)s" % locals()
        )

        env = from_instance is not None and {name: from_instance} or {}
        compat.exec_(py, env)
        try:
            env[method].__defaults__ = fn.__defaults__
        except AttributeError:
            pass
        setattr(into_cls, method, env[method])
786
787
def methods_equivalent(meth1, meth2):
    """Return True if the two methods are the same implementation.

    Method objects are compared by their underlying ``__func__``; any other
    callable is compared as-is, by identity.
    """

    def _implementation(meth):
        return getattr(meth, "__func__", meth)

    return _implementation(meth1) is _implementation(meth2)
794
795
def as_interface(obj, cls=None, methods=None, required=None):
    """Ensure basic interface compliance for an instance or dict of callables.

    The interface is the public (non-underscore) names of ``cls``, or the
    names in ``methods`` when given.  ``obj`` complies if it implements every
    ``required`` name.  If ``required`` is omitted, implementing at least one
    interface name suffices.  Extra members on ``obj`` are ignored.

    If ``obj`` is a dict that does not comply as an object, its items are
    treated as ``name -> callable`` implementations.  Unknown names or
    non-callables raise TypeError.  On success a new anonymous class holding
    the callables as staticmethods is returned.

    In all other passing cases ``obj`` is returned unchanged.

    obj
      A type, instance, or dictionary of callables.
    cls
      Optional, a type.  All public methods of cls are considered the
      interface.  An ``obj`` instance of cls always passes, regardless of
      ``required``.
    methods
      Optional, a sequence of method names to consider as the interface.
    required
      Optional, a sequence of mandatory implementations.  If it is a type,
      every name of the interface is required.

    Raises TypeError if ``obj`` does not meet the interface criteria.

    """
    if not (cls or methods):
        raise TypeError("a class or collection of method names are required")

    if isinstance(cls, type) and isinstance(obj, cls):
        return obj

    if methods:
        interface = set(methods)
    else:
        interface = {name for name in dir(cls) if not name.startswith("_")}
    implemented = set(dir(obj))

    if isinstance(required, type):
        required, complies = interface, operator.ge
    elif required:
        required, complies = set(required), operator.ge
    else:
        # "at least one": a strict superset of the empty set
        required, complies = set(), operator.gt

    if complies(implemented & interface, required):
        return obj

    # No dict duck typing here.
    if not isinstance(obj, dict):
        qualifier = "any of" if complies is operator.gt else "all of"
        raise TypeError(
            "%r does not implement %s: %s"
            % (obj, qualifier, ", ".join(interface))
        )

    class AnonymousInterface(object):
        """A callable-holding shell."""

    if cls:
        AnonymousInterface.__name__ = "Anonymous" + cls.__name__

    found = set()
    for method, impl in dictlike_iteritems(obj):
        if method not in interface:
            raise TypeError("%r: unknown in this interface" % method)
        if not compat.callable(impl):
            raise TypeError("%r=%r is not callable" % (method, impl))
        setattr(AnonymousInterface, method, staticmethod(impl))
        found.add(method)

    if not complies(found, required):
        raise TypeError(
            "dictionary does not contain required keys %s"
            % ", ".join(required - found)
        )
    return AnonymousInterface
880
881
class memoized_property(object):
    """A read-only @property that is only evaluated once.

    On first access from an instance the getter runs and its result is
    written into the instance ``__dict__`` under the same name.  Because
    this is a non-data descriptor, later lookups hit that stored value
    directly.
    """

    def __init__(self, fget, doc=None):
        self.fget = fget
        self.__doc__ = doc if doc else fget.__doc__
        self.__name__ = fget.__name__

    def __get__(self, obj, cls):
        if obj is not None:
            value = self.fget(obj)
            obj.__dict__[self.__name__] = value
            return value
        # class-level access returns the descriptor itself
        return self

    def _reset(self, obj):
        memoized_property.reset(obj, self.__name__)

    @classmethod
    def reset(cls, obj, name):
        """Discard any memoized value ``name`` stored on ``obj``."""
        obj.__dict__.pop(name, None)
902
903
def memoized_instancemethod(fn):
    """Decorate a method memoize its return value.

    The first call on an instance runs ``fn`` and then stores, in that
    instance's ``__dict__`` under the method's name, a stand-in function
    that returns the same result for any arguments.  Best applied to no-arg
    methods, since later calls ignore their arguments entirely.

    """

    def oneshot(self, *args, **kw):
        value = fn(self, *args, **kw)

        def memo(*ignored_args, **ignored_kw):
            return value

        name = fn.__name__
        memo.__name__, memo.__doc__ = name, fn.__doc__
        # Shadow the class-level method on this instance from now on.
        self.__dict__[name] = memo
        return value

    return update_wrapper(oneshot, fn)
925
926
class group_expirable_memoized_property(object):
    """A family of @memoized_properties that can be expired in tandem.

    Use the instance as a decorator (or its :meth:`method` for memoized
    methods).  Each decorated name is recorded in ``attributes``, so that
    :meth:`expire_instance` can clear all of them at once.
    """

    def __init__(self, attributes=()):
        self.attributes = list(attributes) if attributes else []

    def expire_instance(self, instance):
        """Expire all memoized properties for *instance*."""
        instance_dict = instance.__dict__
        for name in self.attributes:
            instance_dict.pop(name, None)

    def __call__(self, fn):
        self.attributes.append(fn.__name__)
        return memoized_property(fn)

    def method(self, fn):
        self.attributes.append(fn.__name__)
        return memoized_instancemethod(fn)
948
949
class MemoizedSlots(object):
    """Apply memoized items to an object using a __getattr__ scheme.

    This allows the functionality of memoized_property and
    memoized_instancemethod to be available to a class using __slots__.

    A subclass defines ``_memoized_attr_<name>()`` to compute attribute
    ``<name>`` once, or ``_memoized_method_<name>(...)`` for a method whose
    first result is cached.  The cached value is stored with ``setattr``,
    so ``<name>`` must be assignable on the instance (e.g. have a slot).

    """

    __slots__ = ()

    def _fallback_getattr(self, key):
        # Hook for subclasses; by default an unknown attribute is an error.
        raise AttributeError(key)

    def __getattr__(self, key):
        # The hasattr() probes below re-enter __getattr__ for missing
        # "_memoized..." names; stop that recursion immediately.
        if key.startswith("_memoized"):
            raise AttributeError(key)

        attr_loader = "_memoized_attr_%s" % key
        if hasattr(self, attr_loader):
            computed = getattr(self, attr_loader)()
            setattr(self, key, computed)
            return computed

        method_loader = "_memoized_method_%s" % key
        if not hasattr(self, method_loader):
            return self._fallback_getattr(key)

        fn = getattr(self, method_loader)

        def oneshot(*args, **kw):
            outcome = fn(*args, **kw)

            def memo(*a, **kw):
                return outcome

            memo.__name__ = fn.__name__
            memo.__doc__ = fn.__doc__
            setattr(self, key, memo)
            return outcome

        oneshot.__doc__ = fn.__doc__
        return oneshot
988
989
def dependency_for(modulename, add_to_all=False):
    """Return a decorator that installs the decorated object in a module.

    The object is set, under its own ``__name__``, as an attribute of the
    module named by the dotted path ``modulename``.  If ``add_to_all`` is
    true and that module defines ``__all__``, the name is appended to it.
    The object itself is returned unchanged.
    """

    def decorate(obj):
        package_path, _, leaf = modulename.rpartition(".")
        parent = compat.import_(package_path, globals(), locals(), [leaf])
        target_module = getattr(parent, leaf)
        setattr(target_module, obj.__name__, obj)
        if add_to_all and hasattr(target_module, "__all__"):
            target_module.__all__.append(obj.__name__)
        return obj

    return decorate
1003
1004
def asbool(obj):
    """Coerce *obj* to a boolean.

    Strings are stripped and lower-cased, then matched against
    "true"/"yes"/"on"/"y"/"t"/"1" and "false"/"no"/"off"/"n"/"f"/"0";
    any other string raises ``ValueError``.  Non-strings go through
    ``bool()``.
    """
    if not isinstance(obj, compat.string_types):
        return bool(obj)

    normalized = obj.strip().lower()
    if normalized in ("true", "yes", "on", "y", "t", "1"):
        return True
    if normalized in ("false", "no", "off", "n", "f", "0"):
        return False
    raise ValueError("String is not true/false: %r" % normalized)
1015
1016
def bool_or_str(*text):
    """Return a callable that will evaluate a string as
    boolean, or one of a set of "alternate" string values.

    Values equal to one of *text* are returned unchanged; everything else
    is passed to :func:`asbool`.
    """

    def bool_or_value(obj):
        if obj not in text:
            return asbool(obj)
        return obj

    return bool_or_value
1030
1031
def asint(value):
    """Coerce *value* to an integer, passing ``None`` through unchanged."""
    return value if value is None else int(value)
1038
1039
def coerce_kw_type(kw, key, type_, flexi_bool=True, dest=None):
    r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if
    necessary.  If 'flexi_bool' is True, booleans are coerced with
    :func:`asbool`, so strings such as '0' or 'false' are considered false.

    Missing keys and ``None`` values are left alone, as are values that
    are already instances of *type_* (when *type_* is a class).  The
    coerced value is written to *dest*, which defaults to *kw* itself.
    """
    target = kw if dest is None else dest

    if key not in kw:
        return
    current = kw[key]
    if current is None:
        return
    if isinstance(type_, type) and isinstance(current, type_):
        return

    if type_ is bool and flexi_bool:
        target[key] = asbool(current)
    else:
        target[key] = type_(current)
1058
1059
def constructor_copy(obj, cls, *args, **kw):
    """Instantiate cls using the __dict__ of obj as constructor arguments.

    Uses inspect to match the named arguments of ``cls``.  Keyword
    arguments passed explicitly take precedence over values found in
    ``obj.__dict__``.

    """
    source = obj.__dict__
    for name in get_cls_kwargs(cls).difference(kw):
        if name in source:
            kw[name] = source[name]
    return cls(*args, **kw)
1072
1073
def counter():
    """Return a threadsafe counter function.

    Each call of the returned function produces the next integer of the
    sequence 1, 2, 3, ...; a lock serializes access to the shared iterator.
    """
    lock = compat.threading.Lock()
    sequence = itertools.count(1)

    # named _next to avoid the 2to3 "next" transformation
    def _next():
        with lock:
            return next(sequence)

    return _next
1089
1090
def duck_type_collection(specimen, default=None):
    """Given an instance or class, guess if it is or is acting as one of
    the basic collection types: list, set and dict.

    Resolution order:

    1. If *specimen* has an ``__emulates__`` attribute, that value is
       returned, canonicalized to the builtin ``set`` when it is a
       subclass of ``set``.
    2. If *specimen* is (or is an instance of) a subclass of ``list``,
       ``set`` or ``dict``, that builtin is returned.
    3. Otherwise duck-typing on the presence of an ``append`` (list),
       ``add`` (set) or ``set`` (dict) attribute is used.
    4. If nothing matches, *default* is returned.
    """
    if hasattr(specimen, "__emulates__"):
        emulates = specimen.__emulates__
        # canonicalize set vs sets.Set to a standard: the builtin set
        if emulates is not None and issubclass(emulates, set):
            return set
        return emulates

    # a class is tested with issubclass(), an instance with isinstance()
    isa = issubclass if isinstance(specimen, type) else isinstance
    for builtin in (list, set, dict):
        if isa(specimen, builtin):
            return builtin

    duck_methods = (("append", list), ("add", set), ("set", dict))
    for method_name, builtin in duck_methods:
        if hasattr(specimen, method_name):
            return builtin
    return default
1122
1123
def assert_arg_type(arg, argtype, name):
    """Return *arg* unchanged if it is an instance of *argtype*.

    :param argtype: a type or a tuple of types, as accepted by
     ``isinstance()``.
    :param name: argument name used in the error message.
    :raises: :class:`.exc.ArgumentError` describing the expected type(s)
     and the actual type of *arg*.
    """
    if not isinstance(arg, argtype):
        if isinstance(argtype, tuple):
            expected = " or ".join("'%s'" % a for a in argtype)
            raise exc.ArgumentError(
                "Argument '%s' is expected to be one of type %s, got '%s'"
                % (name, expected, type(arg))
            )
        raise exc.ArgumentError(
            "Argument '%s' is expected to be of type '%s', got '%s'"
            % (name, argtype, type(arg))
        )
    return arg
1138
1139
def dictlike_iteritems(dictlike):
    """Return a (key, value) iterable for almost any dict-like object.

    Objects providing ``items()`` (or, on Python 2, ``iteritems()``) are
    used directly; on Python 3 the result is a list.  Otherwise the object
    must offer ``__getitem__`` or ``get`` plus ``iterkeys()`` or ``keys()``,
    and a lazy iterator of pairs is returned.

    :raises TypeError: if *dictlike* cannot be iterated this way.
    """

    if compat.py3k:
        if hasattr(dictlike, "items"):
            return list(dictlike.items())
    else:
        if hasattr(dictlike, "iteritems"):
            return dictlike.iteritems()
        elif hasattr(dictlike, "items"):
            return iter(dictlike.items())

    getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None))
    if getter is None:
        # 1-tuple: a bare tuple argument would be unpacked by "%"
        raise TypeError("Object '%r' is not dict-like" % (dictlike,))

    if hasattr(dictlike, "iterkeys"):

        def iterator():
            for key in dictlike.iterkeys():
                yield key, getter(key)

        return iterator()
    elif hasattr(dictlike, "keys"):
        return ((key, getter(key)) for key in dictlike.keys())
    else:
        # e.g. a tuple has __getitem__ but no keys(); wrap it for "%"
        raise TypeError("Object '%r' is not dict-like" % (dictlike,))
1167
1168
class classproperty(property):
    """A decorator that behaves like @property except that operates
    on classes rather than instances.

    The getter always receives the owner class, whether the attribute is
    read from the class or from one of its instances.

    The decorator is currently special when using the declarative
    module, but note that the
    :class:`~.sqlalchemy.ext.declarative.declared_attr`
    decorator should be used for this purpose with declarative.

    """

    def __init__(self, fget, *arg, **kw):
        super(classproperty, self).__init__(fget, *arg, **kw)
        # expose the getter's docstring, overriding any ``doc`` argument
        self.__doc__ = fget.__doc__

    def __get__(desc, self, cls):
        getter = desc.fget
        return getter(cls)
1186
1187
class hybridproperty(object):
    """Descriptor that calls *func* with the instance on instance access,
    or with the owner class on class access.

    On class access the docstring of *func* is copied onto the returned
    value.
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        if instance is not None:
            return self.func(instance)
        result = self.func(owner)
        result.__doc__ = self.func.__doc__
        return result
1199
1200
class hybridmethod(object):
    """Decorate a function as cls- or instance- level.

    Accessed on an instance, the function is bound to that instance;
    accessed on the class, it is bound to the class itself.
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        if instance is not None:
            return self.func.__get__(instance, owner)
        return self.func.__get__(owner, owner.__class__)
1212
1213
class _symbol(int):
    def __new__(cls, name, doc=None, canonical=None):
        """Construct a new named symbol."""
        assert isinstance(name, compat.string_types)
        value = hash(name) if canonical is None else canonical
        sym = int.__new__(_symbol, value)
        sym.name = name
        if doc:
            sym.__doc__ = doc
        return sym

    def __reduce__(self):
        # unpickle through ``symbol`` so an already-registered instance is
        # reused; "x" is passed as ``doc`` and only matters if the name is
        # not registered yet
        return symbol, (self.name, "x", int(self))

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return "symbol(%r)" % self.name


_symbol.__name__ = "symbol"
1237
1238
class symbol(object):
    """A constant symbol.

    >>> symbol('foo') is symbol('foo')
    True
    >>> symbol('foo')
    symbol('foo')

    A slight refinement of the MAGICCOOKIE=object() pattern.  The primary
    advantage of symbol() is its repr().  They are also singletons.

    Repeated calls of symbol('name') will all return the same instance.

    The optional ``doc`` argument assigns to ``__doc__``.  This
    is strictly so that Sphinx autoattr picks up the docstring we want
    (it doesn't appear to pick up the in-module docstring if the datamember
    is in a different module - autoattribute also blows up completely).
    If Sphinx fixes/improves this then we would no longer need
    ``doc`` here.

    """

    symbols = {}
    _lock = compat.threading.Lock()

    def __new__(cls, name, doc=None, canonical=None):
        # acquire and release the *same* lock object, even in a subclass
        # that overrides ``_lock``
        with cls._lock:
            sym = cls.symbols.get(name)
            if sym is None:
                cls.symbols[name] = sym = _symbol(name, doc, canonical)
            return sym

    @classmethod
    def parse_user_argument(
        cls, arg, choices, name, resolve_symbol_names=False
    ):
        """Given a user parameter, parse the parameter into a chosen symbol.

        The user argument can be a string name that matches the name of a
        symbol, or the symbol object itself, or any number of alternate choices
        such as True/False/ None etc.

        :param arg: the user argument.
        :param choices: dictionary of symbol object to list of possible
         entries.
        :param name: name of the argument.   Used in an :class:`.ArgumentError`
         that is raised if the parameter doesn't match any available argument.
        :param resolve_symbol_names: include the name of each symbol as a valid
         entry.

        """
        # note using hash lookup is tricky here because symbol's `__hash__`
        # is its int value which we don't want included in the lookup
        # explicitly, so we iterate and compare each.
        for sym, choice in choices.items():
            if arg is sym:
                return sym
            elif resolve_symbol_names and arg == sym.name:
                return sym
            elif arg in choice:
                return sym

        if arg is None:
            return None

        raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
1308
1309
# next sequence number handed out by set_creation_order()
_creation_order = 1


def set_creation_order(instance):
    """Assign a '_creation_order' sequence to the given instance.

    This allows multiple instances to be sorted in order of creation
    (typically within a single thread; the counter is not particularly
    threadsafe).

    """
    global _creation_order
    # targets are assigned left to right: if setting the instance attribute
    # fails, the module counter is left untouched
    instance._creation_order, _creation_order = (
        _creation_order,
        _creation_order + 1,
    )
1324
1325
def warn_exception(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return its result.

    Any :class:`Exception` raised is converted into a warning via
    :func:`warn`, and ``None`` is returned instead.

    """
    try:
        return func(*args, **kwargs)
    except Exception as err:
        warn("%s('%s') ignored" % (type(err), err))
1335
1336
def ellipses_string(value, len_=25):
    """Truncate *value* to *len_* items, appending ``"..."``.

    Values no longer than *len_* are returned unchanged.  Values that do
    not support ``len()`` or slicing (numbers, sets, dicts, ...) are
    returned unchanged as well.

    :param value: typically a string; any sliceable sequence works.
    :param len_: maximum number of leading items to keep.
    :return: the original value, or a truncated ``str`` ending in ``...``.
    """
    try:
        if len(value) > len_:
            # 1-tuple so that a tuple slice is formatted rather than
            # unpacked by "%" (which used to raise and skip truncation)
            return "%s..." % (value[0:len_],)
        else:
            return value
    except TypeError:
        return value
1345
1346
class _hash_limit_string(compat.text_type):
    """A string subclass that can only be hashed on a maximum amount
    of unique values.

    This is used for warnings so that we can send out parameterized warnings
    without the __warningregistry__ of the module, or the non-overridable
    "once" registry within warnings.py, overloading memory.

    The hash combines the un-interpolated template with the interpolated
    text's hash modulo *num*, so a single template yields at most *num*
    distinct hash values.  Equality is defined by hash alone.

    """

    def __new__(cls, value, num, args):
        message = value % args
        message += (
            " (this warning may be suppressed after %d occurrences)" % num
        )
        self = super(_hash_limit_string, cls).__new__(cls, message)
        bucket = hash(message) % num
        self._hash = hash("%s_%d" % (value, bucket))
        return self

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        return hash(self) == hash(other)
1371
1372
def warn(msg):
    """Issue a warning.

    If msg is a string, :class:`.exc.SAWarning` is used as
    the category.  ``stacklevel=2`` attributes the warning to the caller
    of this function.

    """
    warnings.warn(msg, category=exc.SAWarning, stacklevel=2)
1381
1382
def warn_limited(msg, args):
    """Issue a warning with a parameterized string, limiting the number
    of registrations.

    When *args* is non-empty, *msg* is interpolated through
    :class:`._hash_limit_string` (limit of 10 distinct hashes) so that
    the warnings registries stay bounded; otherwise *msg* is used as is.

    """
    message = _hash_limit_string(msg, 10, args) if args else msg
    warnings.warn(message, category=exc.SAWarning, stacklevel=2)
1391
1392
def only_once(fn, retry_on_exception):
    """Decorate the given function to be a no-op after it is called exactly
    once.

    :param fn: the function to call on the first invocation.
    :param retry_on_exception: if true and the call raises, *fn* stays
     armed so the next invocation tries again; the exception propagates
     either way.
    :return: a wrapper returning *fn*'s result on the first successful
     call and ``None`` afterwards.
    """
    pending = [fn]

    def go(*arg, **kw):
        # strong reference fn so that it isn't garbage collected,
        # which interferes with the event system's expectations
        strong_fn = fn  # noqa
        if not pending:
            return None
        target = pending.pop()
        try:
            return target(*arg, **kw)
        except:  # noqa: E722 - re-raised below after optional re-arming
            if retry_on_exception:
                pending.insert(0, target)
            raise

    return go
1413
1414
# frames inside the sqlalchemy package (up to two subpackage levels)
_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
# frames from the unittest / unittest2 machinery
_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")


def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE):
    """Chop extraneous lines off beginning and end of a traceback.

    :param tb:
      a list of traceback lines as returned by ``traceback.format_stack()``

    :param exclude_prefix:
      a regular expression object matching lines to skip at beginning of
      ``tb``

    :param exclude_suffix:
      a regular expression object matching lines to skip at end of ``tb``
    """
    first, last = 0, len(tb) - 1
    while first <= last and exclude_prefix.search(tb[first]):
        first += 1
    while last >= first and exclude_suffix.search(tb[last]):
        last -= 1
    return tb[first : last + 1]
1439
1440
# the class of the ``None`` singleton (``types.NoneType`` exists only on
# Python 3.10+)
NoneType = None.__class__
1442
1443
def attrsetter(attrname):
    """Return a function ``set(obj, value)`` that assigns ``obj.<attrname>``.

    The setter is generated from source text with ``exec``, so
    *attrname* is compiled directly into it and must be a trusted
    identifier.
    """
    source = "def set(obj, value):    obj.%s = value" % attrname
    namespace = {"attrname": attrname}
    exec(source, namespace)
    return namespace["set"]
1449
1450
class EnsureKWArgType(type):
    r"""Apply translation of functions to accept \**kw arguments if they
    don't already.

    A class using this metaclass sets ``ensure_kwarg`` to a regular
    expression (or a false value to disable).  Each name in the class body
    that matches it via ``re.match`` and whose function does not already
    declare ``**kw`` is replaced by a wrapper that accepts and discards
    extra keyword arguments.

    """

    def __init__(cls, clsname, bases, clsdict):
        pattern = cls.ensure_kwarg
        if pattern:
            for name in clsdict:
                if not re.match(pattern, name):
                    continue
                fn = clsdict[name]
                if compat.inspect_getfullargspec(fn).varkw:
                    continue
                wrapped = cls._wrap_w_kw(fn)
                clsdict[name] = wrapped
                setattr(cls, name, wrapped)
        super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict)

    def _wrap_w_kw(self, fn):
        def wrap(*arg, **kw):
            # keyword arguments are accepted but deliberately not forwarded
            return fn(*arg)

        return update_wrapper(wrap, fn)
1475
1476
def wrap_callable(wrapper, fn):
    """Augment functools.update_wrapper() to work with objects with
    a ``__call__()`` method.

    :param fn:
      object with __call__ method

    Objects with a ``__name__`` go through ``update_wrapper`` unchanged.
    Otherwise *wrapper* takes its name from ``fn``'s class, its
    ``__module__`` from *fn* when present, and its docstring from
    ``fn.__call__`` (falling back to ``fn.__doc__``).  *wrapper* is
    modified in place and returned.

    """
    if hasattr(fn, "__name__"):
        return update_wrapper(wrapper, fn)

    wrapper.__name__ = fn.__class__.__name__
    if hasattr(fn, "__module__"):
        wrapper.__module__ = fn.__module__

    call_doc = getattr(fn.__call__, "__doc__", None)
    if call_doc:
        wrapper.__doc__ = call_doc
    elif fn.__doc__:
        wrapper.__doc__ = fn.__doc__
    return wrapper
1499
1500
def quoted_token_parser(value):
    """Parse a dotted identifier with accommodation for quoted names.

    Dots inside double quotes do not split tokens and the surrounding
    quote characters are dropped.  Inside a quoted name, a doubled double
    quote (``""``) is the SQL-style escape for one literal ``"``.

    E.g.::

        >>> quoted_token_parser("name")
        ['name']
        >>> quoted_token_parser("schema.name")
        ['schema', 'name']
        >>> quoted_token_parser('"Schema"."Name"')
        ['Schema', 'Name']
        >>> quoted_token_parser('"Schema"."Name""Foo"')
        ['Schema', 'Name"Foo']

    """

    if '"' not in value:
        return value.split(".")

    in_quotes = False
    tokens = [[]]
    idx = 0
    length = len(value)
    while idx < length:
        char = value[idx]
        if char == '"':
            if in_quotes and idx + 1 < length and value[idx + 1] == '"':
                # escaped quote: emit one literal '"' and skip its twin
                tokens[-1].append('"')
                idx += 1
            else:
                in_quotes = not in_quotes
        elif char == "." and not in_quotes:
            tokens.append([])
        else:
            tokens[-1].append(char)
        idx += 1

    return ["".join(chars) for chars in tokens]
1543
1544
def add_parameter_text(params, text):
    """Return a decorator that injects *text* into the ``:param:`` entries
    named by *params* in the decorated function's docstring.

    *params* is normalized through ``_collections.to_list`` (presumably a
    single name or a list of names).  A function without a docstring ends
    up with an empty ``__doc__``; the function itself is returned.
    """
    param_names = _collections.to_list(params)

    def decorate(fn):
        doc = fn.__doc__ or ""
        if doc:
            doc = inject_param_text(doc, dict.fromkeys(param_names, text))
        fn.__doc__ = doc
        return fn

    return decorate
1556
1557
1558def _dedent_docstring(text):
1559    split_text = text.split("\n", 1)
1560    if len(split_text) == 1:
1561        return text
1562    else:
1563        firstline, remaining = split_text
1564    if not firstline.startswith(" "):
1565        return firstline + "\n" + textwrap.dedent(remaining)
1566    else:
1567        return textwrap.dedent(text)
1568
1569
def inject_docstring_text(doctext, injecttext, pos):
    """Insert *injecttext* into *doctext* at a paragraph boundary.

    *doctext* (``None`` is treated as empty) is dedented and split into
    lines.  The candidate insertion points are line 0 followed by the index
    of every blank line; *pos* selects one of them, clamped to the last.
    The dedented *injecttext* gets a leading blank line unless it already
    starts with one.
    """
    lines = _dedent_docstring(doctext or "").split("\n")
    if len(lines) == 1:
        lines.append("")

    inserted = textwrap.dedent(injecttext).split("\n")
    if inserted[0]:
        inserted.insert(0, "")

    boundaries = [0] + [
        idx for idx, line in enumerate(lines) if not line.strip()
    ]
    where = boundaries[min(pos, len(boundaries) - 1)]

    return "\n".join(lines[:where] + inserted + lines[where:])
1586
1587
def inject_param_text(doctext, inject_params):
    r"""Append extra text to selected ``:param:`` entries of a docstring.

    :param doctext: the docstring to modify.
    :param inject_params: dict mapping a parameter name to the text to add
     to that parameter's ``:param <name>:`` entry.  Entries written with a
     leading ``\*`` or ``\**`` escape are matched too.
    :return: the modified docstring.

    The text is indented like the line following the ``:param`` line when
    that line is indented text, otherwise one column deeper than the
    ``:param`` marker.  It is emitted when the next ``:param`` line or
    blank line is reached, or at the end of the docstring if the entry is
    the last thing in it.
    """
    doclines = doctext.splitlines()
    lines = []

    to_inject = None
    while doclines:
        line = doclines.pop(0)
        if to_inject is not None:
            if line.lstrip().startswith(":param "):
                lines.append("\n")
                lines.append(to_inject)
                lines.append("\n")
                to_inject = None
            elif not line.rstrip():
                lines.append(line)
                lines.append(to_inject)
                lines.append("\n")
                to_inject = None
            elif line.endswith("::"):
                # TODO: this still wont cover if the code example itself has
                # blank lines in it, need to detect those via indentation.
                lines.append(line)
                if doclines:
                    # the blank line following a code example
                    lines.append(doclines.pop(0))
                continue

        if to_inject is None:
            # also reached by a :param line that just flushed a pending
            # injection, so adjacent parameters are both handled
            m = re.match(r"(\s+):param (?:\\\*\*?)?(.+?):", line)
            if m and m.group(2) in inject_params:
                # default indent to that of :param: plus one
                indent = " " * len(m.group(1)) + " "

                # but if the next line has text, use that line's
                # indentation
                if doclines:
                    m2 = re.match(r"(\s+)\S", doclines[0])
                    if m2:
                        indent = " " * len(m2.group(1))

                to_inject = indent + inject_params[m.group(2)]
        lines.append(line)

    if to_inject is not None:
        # the entry ran to the end of the docstring; don't drop the text
        lines.append(to_inject)

    return "\n".join(lines)
1632
1633
def repr_tuple_names(names):
    """Render up to four of *names* as a comma-separated string.

    With more than four names, the first three and the last one are shown
    with ``...`` between them.  Names longer than 11 characters are cut to
    their first 11 characters followed by ``..``.  Returns ``None`` for an
    empty sequence.
    """
    if not len(names):
        return None

    def _trim(name):
        return name if len(name) <= 11 else "%s.." % name[:11]

    if len(names) <= 4:
        return ", ".join(_trim(name) for name in names[0:4])
    head = ", ".join(_trim(name) for name in names[0:3])
    return "%s, ..., %s" % (head, _trim(names[-1]))
1646