1# util/langhelpers.py
2# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Routines to help with the creation, loading and introspection of
9modules, classes, hierarchies, attributes, functions, and methods.
10
11"""
12import itertools
13import inspect
14import operator
15import re
16import sys
17import types
18import warnings
19from functools import update_wrapper
20from .. import exc
21import hashlib
22from . import compat
23from . import _collections
24
25
def md5_hex(x):
    """Return the hexadecimal MD5 digest of ``x``.

    Under Python 3 the value is first encoded to UTF-8 bytes; otherwise
    it is hashed exactly as given.

    """
    payload = x.encode('utf-8') if compat.py3k else x
    return hashlib.md5(payload).hexdigest()
32
33
class safe_reraise(object):
    """Context manager which re-raises an exception once some handler
    code has been invoked.

    The exception info that is current on entry is stored, so that it is
    maintained across a potential coroutine context switch.

    e.g.::

        try:
            sess.commit()
        except:
            with safe_reraise():
                sess.rollback()

    """

    def __enter__(self):
        # capture whatever exception is being handled right now
        self._exc_info = sys.exc_info()

    def __exit__(self, type_, value, traceback):
        # see #2703 for notes
        if type_ is not None:
            # the handler body raised as well; that exception propagates
            if not compat.py3k and self._exc_info and self._exc_info[1]:
                # emulate Py3K's behavior of telling us when an exception
                # occurs in an exception handler.
                warn(
                    "An exception has occurred during handling of a "
                    "previous exception.  The previous exception "
                    "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1]))
            self._exc_info = None   # remove potential circular references
            compat.reraise(type_, value, traceback)
        else:
            # the handler body completed; re-raise the stored exception
            exc_type, exc_value, exc_tb = self._exc_info
            self._exc_info = None   # remove potential circular references
            compat.reraise(exc_type, exc_value, exc_tb)
71
72
def decode_slice(slc):
    """Return ``(start, stop, step)`` for a slice sent to __getitem__.

    Any component that provides ``__index__()`` is replaced by the result
    of calling it; other components (including ``None``) pass through
    unchanged.

    """
    def _as_index(part):
        return part.__index__() if hasattr(part, '__index__') else part

    return tuple(
        _as_index(part) for part in (slc.start, slc.stop, slc.step))
85
86
87def _unique_symbols(used, *bases):
88    used = set(used)
89    for base in bases:
90        pool = itertools.chain((base,),
91                               compat.itertools_imap(lambda i: base + str(i),
92                                                     range(1000)))
93        for sym in pool:
94            if sym not in used:
95                used.add(sym)
96                yield sym
97                break
98        else:
99            raise NameError("exhausted namespace for symbol base %s" % base)
100
101
def map_bits(fn, n):
    """Yield ``fn(bit)`` for each nonzero bit of ``n``, lowest bit first."""

    remaining = n
    while remaining:
        # isolate the lowest set bit (two's complement identity)
        lowest = remaining & -remaining
        yield fn(lowest)
        remaining ^= lowest
109
110
def decorator(target):
    """A signature-matching decorator factory.

    The returned ``decorate(fn)`` generates a new function whose name and
    argument list are taken from ``fn`` and whose body calls
    ``target(fn, <arguments>)``.

    """

    def decorate(fn):
        if not inspect.isfunction(fn):
            raise Exception("not a decoratable function")

        spec = compat.inspect_getfullargspec(fn)

        # names that the generated code must not collide with
        reserved = tuple(spec[0]) + spec[1:3] + (fn.__name__,)
        targ_name, fn_name = _unique_symbols(reserved, 'target', 'fn')

        metadata = dict(
            format_argspec_plus(spec, grouped=False),
            target=targ_name,
            fn=fn_name,
            name=fn.__name__)
        code = (
            "def %(name)s(%(args)s):\n"
            "    return %(target)s(%(fn)s, %(apply_kw)s)\n"
        ) % metadata

        namespace = {targ_name: target, fn_name: fn}
        decorated = _exec_code_in_env(code, namespace, fn.__name__)
        decorated.__defaults__ = getattr(fn, 'im_func', fn).__defaults__
        decorated.__wrapped__ = fn
        return update_wrapper(decorated, fn)

    return update_wrapper(decorate, target)
135
136
137def _exec_code_in_env(code, env, fn_name):
138    exec(code, env)
139    return env[fn_name]
140
141
def public_factory(target, location):
    """Produce a wrapping function for the given cls or classmethod.

    Rationale here is so that the __init__ method of the
    class can serve as documentation for the function.

    :param target: a class, or another callable.  For a class, the
     signature and docstring are taken from its ``__init__``.
    :param location: dotted path string; its last token becomes the name
     of the generated function, and everything before the last dot is
     appended to ``"sqlalchemy"`` to form the function's ``__module__``.
    :return: the generated function, which calls ``target`` with the
     arguments it receives.

    Side effect: the docstring of the inspected function is replaced with
    a short pointer to ``location``; its original docstring is moved onto
    the generated function.

    """
    if isinstance(target, type):
        fn = target.__init__
        callable_ = target
        doc = "Construct a new :class:`.%s` object. \n\n"\
            "This constructor is mirrored as a public API function; "\
            "see :func:`~%s` "\
            "for a full usage and argument description." % (
                target.__name__, location, )
    else:
        fn = callable_ = target
        doc = "This function is mirrored; see :func:`~%s` "\
            "for a description of arguments." % location

    location_name = location.split(".")[-1]
    spec = compat.inspect_getfullargspec(fn)
    # drop the first positional argument from the spec (in place)
    del spec[0][0]
    metadata = format_argspec_plus(spec, grouped=False)
    metadata['name'] = location_name
    code = """\
def %(name)s(%(args)s):
    return cls(%(apply_kw)s)
""" % metadata
    # namespace for the generated code: ``cls`` is the real callable.
    # NOTE(review): ``symbol`` is presumably exposed so that default values
    # rendered into the signature can be evaluated - confirm.
    env = {'cls': callable_, 'symbol': symbol}
    exec(code, env)
    decorated = env[location_name]
    # copy the original docstring first; fn's docstring is overwritten below
    decorated.__doc__ = fn.__doc__
    decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0]
    if compat.py2k or hasattr(fn, '__func__'):
        fn.__func__.__doc__ = doc
    else:
        fn.__doc__ = doc
    return decorated
181
182
class PluginLoader(object):
    """Locate and load plugin objects by name for one entry point group.

    Loader callables are cached in ``impls`` keyed by plugin name.

    """

    def __init__(self, group, auto_fn=None):
        self.group = group
        self.impls = {}
        self.auto_fn = auto_fn

    def load(self, name):
        """Return the plugin registered under ``name``.

        Lookup order: an already known loader in ``impls``, then
        ``auto_fn(name)`` if an ``auto_fn`` was given, then the
        ``pkg_resources`` entry points of this group (when
        ``pkg_resources`` is importable).

        Raises ``exc.NoSuchModuleError`` when nothing matches.

        """
        if name in self.impls:
            return self.impls[name]()

        if self.auto_fn:
            factory = self.auto_fn(name)
            if factory:
                self.impls[name] = factory
                return factory()

        try:
            import pkg_resources
        except ImportError:
            entry_points = ()
        else:
            entry_points = pkg_resources.iter_entry_points(
                self.group, name)

        # only the first matching entry point is used
        for entry_point in entry_points:
            self.impls[name] = entry_point.load
            return entry_point.load()

        raise exc.NoSuchModuleError(
            "Can't load plugin: %s:%s" %
            (self.group, name))

    def register(self, name, modulepath, objname):
        """Register a lazy loader for ``objname`` within ``modulepath``.

        Nothing is imported until the loader is invoked.

        """
        def load():
            module = compat.import_(modulepath)
            # walk from the imported module down the remaining tokens
            for attr in modulepath.split(".")[1:]:
                module = getattr(module, attr)
            return getattr(module, objname)

        self.impls[name] = load
221
222
def get_cls_kwargs(cls, _set=None):
    r"""Return the full set of inherited kwargs for the given `cls`.

    Probes a class's __init__ method, collecting all named arguments.  If the
    __init__ defines a \**kwargs catch-all, then the constructor is presumed
    to pass along unrecognized keywords to its base classes, and the
    collection process is repeated recursively on each of the bases.

    Uses a subset of inspect.getargspec() to cut down on method overhead.
    No anonymous tuple arguments please !

    :param cls: the class to inspect.
    :param _set: internal accumulator used by the recursion; callers
     should leave it as ``None``.
    :return: a set of argument names with ``'self'`` removed.  Recursive
     calls return ``None`` when a base's __init__ has no \**kwargs
     catch-all, which tells the caller to stop walking further bases.

    """
    toplevel = _set is None
    if toplevel:
        _set = set()

    ctr = cls.__dict__.get('__init__', False)

    # only a plain Python function defined directly on cls is inspected
    has_init = ctr and isinstance(ctr, types.FunctionType) and \
        isinstance(ctr.__code__, types.CodeType)

    # explicit default so the name is always bound
    has_kw = False
    if has_init:
        names, has_kw = inspect_func_args(ctr)
        _set.update(names)

        if not has_kw and not toplevel:
            return None

    if not has_init or has_kw:
        for base in cls.__bases__:
            if get_cls_kwargs(base, _set) is None:
                break

    _set.discard('self')
    return _set
258
259
try:
    # TODO: who doesn't have this constant?
    from inspect import CO_VARKEYWORDS
except ImportError:
    def inspect_func_args(fn):
        """Return ``(argument names, has_kw)`` for ``fn`` via getargspec."""
        arg_names, _, varkw, _ = inspect.getargspec(fn)
        return arg_names, bool(varkw)
else:
    def inspect_func_args(fn):
        """Return ``(argument names, has_kw)`` for ``fn``.

        Reads the code object directly: the names are the first
        ``co_argcount`` entries of ``co_varnames``, and ``has_kw`` reports
        whether the ``CO_VARKEYWORDS`` flag is set.

        """
        code = fn.__code__
        arg_names = list(code.co_varnames[:code.co_argcount])
        return arg_names, bool(code.co_flags & CO_VARKEYWORDS)
276
277
def get_func_kwargs(func):
    """Return the legal argument names for the given `func`.

    The names are the first element of the getargspec result, so this is
    safe to call for methods, functions, etc.

    """
    argspec = compat.inspect_getargspec(func)
    return argspec[0]
287
288
def get_callable_argspec(fn, no_self=False, _is_init=False):
    """Return the argument signature for any callable.

    All pure-Python callables are accepted, including
    functions, methods, classes, objects with __call__;
    builtins and other edge cases like functools.partial() objects
    raise a TypeError.

    When ``no_self`` is set, the leading argument is removed for bound
    methods and for ``__init__`` reached through a class.

    """
    def _without_first(spec):
        return compat.ArgSpec(spec.args[1:], spec.varargs,
                              spec.keywords, spec.defaults)

    if inspect.isbuiltin(fn):
        raise TypeError("Can't inspect builtin: %s" % fn)

    if inspect.isfunction(fn):
        spec = compat.inspect_getargspec(fn)
        return _without_first(spec) if _is_init and no_self else spec

    if inspect.ismethod(fn):
        strip_self = no_self and (_is_init or fn.__self__)
        spec = compat.inspect_getargspec(fn.__func__)
        return _without_first(spec) if strip_self else spec

    if inspect.isclass(fn):
        # a class is described by its constructor
        return get_callable_argspec(
            fn.__init__, no_self=no_self, _is_init=True)

    if hasattr(fn, '__func__'):
        return compat.inspect_getargspec(fn.__func__)

    if hasattr(fn, '__call__') and inspect.ismethod(fn.__call__):
        return get_callable_argspec(fn.__call__, no_self=no_self)

    raise TypeError("Can't inspect callable: %s" % fn)
326
327
def format_argspec_plus(fn, grouped=True):
    """Returns a dictionary of formatted, introspected function arguments.

    An enhanced variant of inspect.formatargspec to support code generation.

    fn
       An inspectable callable or tuple of inspect getargspec() results.
    grouped
      Defaults to True; include (parens, around, argument) lists

    Returns:

    args
      Full inspect.formatargspec for fn
    self_arg
      The name of the first positional argument, varargs[0], or None
      if the function defines no positional arguments.
    apply_pos
      args, re-written in calling rather than receiving syntax.  Arguments are
      passed positionally.
    apply_kw
      Like apply_pos, except keyword-ish args are passed as keywords.

    Example::

      >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
      {'args': '(self, a, b, c=3, **d)',
       'self_arg': 'self',
       'apply_kw': '(self, a, b, c=c, **d)',
       'apply_pos': '(self, a, b, c, **d)'}

    """
    # anything that is not callable is taken to be an existing argspec
    spec = compat.inspect_getfullargspec(fn) if compat.callable(fn) else fn

    args = inspect.formatargspec(*spec)
    positional, varargs, varkw = spec[0], spec[1], spec[2]

    if positional:
        self_arg = positional[0]
    elif varargs:
        self_arg = '%s[0]' % varargs
    else:
        self_arg = None

    num_defaults = len(spec[3]) if spec[3] else 0
    if compat.py3k:
        # keyword-only names are all counted among the defaulted arguments
        kwonly = spec[4]
        apply_pos = inspect.formatargspec(
            positional, varargs, varkw, None, kwonly)
        if kwonly:
            num_defaults += len(kwonly)
        name_args = positional + kwonly
    else:
        apply_pos = inspect.formatargspec(positional, varargs, varkw)
        name_args = positional

    defaulted_vals = name_args[-num_defaults:] if num_defaults else ()

    # render each defaulted argument as ``name=name``
    apply_kw = inspect.formatargspec(name_args, varargs, varkw,
                                     defaulted_vals,
                                     formatvalue=lambda x: '=' + x)

    formatted = dict(args=args, self_arg=self_arg,
                     apply_pos=apply_pos, apply_kw=apply_kw)
    if not grouped:
        # strip the enclosing parenthesis from each rendered list
        for key in ('args', 'apply_pos', 'apply_kw'):
            formatted[key] = formatted[key][1:-1]
    return formatted
403
404
def format_argspec_init(method, grouped=True):
    """format_argspec_plus with considerations for typical __init__ methods

    Delegates to format_argspec_plus, substituting fixed signatures for
    the typical __init__ cases it cannot handle::

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    """
    if method is not object.__init__:
        try:
            return format_argspec_plus(method, grouped=grouped)
        except TypeError:
            signature = 'self, *args, **kwargs'
    else:
        signature = 'self'

    if grouped:
        signature = '(%s)' % signature

    return {
        'self_arg': 'self',
        'args': signature,
        'apply_pos': signature,
        'apply_kw': signature,
    }
424
425
def getargspec_init(method):
    """inspect.getargspec with considerations for typical __init__ methods

    When the method cannot be inspected (TypeError), a stand-in 4-tuple
    is returned instead::

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    """
    try:
        return compat.inspect_getargspec(method)
    except TypeError:
        is_object_init = method is object.__init__
        varargs = None if is_object_init else 'args'
        varkw = None if is_object_init else 'kwargs'
        return (['self'], varargs, varkw, None)
442
443
def unbound_method_to_callable(func_or_cls):
    """Adjust the incoming callable such that a 'self' argument is not
    required.

    A method object whose ``__self__`` is falsy is replaced by its
    underlying ``__func__``; anything else is returned unchanged.

    """
    if not isinstance(func_or_cls, types.MethodType):
        return func_or_cls
    if func_or_cls.__self__:
        return func_or_cls
    return func_or_cls.__func__
454
455
def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
    """Produce a __repr__() based on direct association of the __init__()
    specification vs. same-named attributes present.

    :param obj: the object to represent; attribute values are read from it.
    :param additional_kw: sequence of ``(name, default)`` pairs rendered as
     extra keyword arguments when the attribute differs from ``default``.
    :param to_inspect: object or list of objects whose ``__init__``
     signatures are consulted; defaults to ``obj`` itself.  Non-defaulted
     arguments of the first entry render positionally; those of later
     entries render as keywords whenever the attribute is present.
    :param omit_kwarg: names of ``__init__`` keyword arguments to leave out.
    :return: a string of the form ``ClassName(arg, ..., key=value, ...)``.

    """
    if to_inspect is None:
        to_inspect = [obj]
    else:
        to_inspect = _collections.to_list(to_inspect)

    missing = object()

    pos_args = []
    kw_args = _collections.OrderedDict()
    vargs = None
    for i, insp in enumerate(to_inspect):
        try:
            (_args, _vargs, _vkw, defaults) = \
                compat.inspect_getargspec(insp.__init__)
        except TypeError:
            # __init__ is not inspectable; contribute nothing
            continue

        default_len = len(defaults) if defaults else 0

        # Arguments without defaults, minus the leading 'self'.  With no
        # defaults the slice must be open-ended: ``_args[1:-0]`` is
        # ``_args[1:0]`` and would always be empty.
        if default_len:
            required = _args[1:-default_len]
        else:
            required = _args[1:]

        if i == 0:
            if _vargs:
                vargs = _vargs
            pos_args.extend(required)
        else:
            kw_args.update([(arg, missing) for arg in required])

        if default_len:
            kw_args.update([
                (arg, default)
                for arg, default
                in zip(_args[-default_len:], defaults)
            ])

    output = [repr(getattr(obj, arg, None)) for arg in pos_args]

    if vargs is not None and hasattr(obj, vargs):
        output.extend([repr(val) for val in getattr(obj, vargs)])

    def _append_if_changed(arg, defval):
        # Best effort: attribute access, comparison or repr may raise for
        # arbitrary objects; such an argument is left out.
        try:
            val = getattr(obj, arg, missing)
            if val is not missing and val != defval:
                output.append('%s=%r' % (arg, val))
        except Exception:
            pass

    for arg, defval in kw_args.items():
        if arg in omit_kwarg:
            continue
        _append_if_changed(arg, defval)

    if additional_kw:
        for arg, defval in additional_kw:
            _append_if_changed(arg, defval)

    return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
524
525
class portable_instancemethod(object):
    """Turn an instancemethod into a (parent, name) pair
    to produce a serializable callable.

    The method is looked up again by name on the stored target each time
    the object is called.

    """

    __slots__ = ('target', 'name', '__weakref__')

    def __init__(self, meth):
        self.target = meth.__self__
        self.name = meth.__name__

    def __call__(self, *arg, **kw):
        bound = getattr(self.target, self.name)
        return bound(*arg, **kw)

    def __getstate__(self):
        return dict(target=self.target, name=self.name)

    def __setstate__(self, state):
        self.target = state['target']
        self.name = state['name']
547
548
def class_hierarchy(cls):
    """Return an unordered sequence of all classes related to cls.

    Traverses diamond hierarchies.

    Fibs slightly: subclasses of builtin types are not returned.  Thus
    class_hierarchy(class A(object)) returns (A, object), not A plus every
    class systemwide that derives from object.

    Old-style classes are discarded and hierarchies rooted on them
    will not be descended.

    """
    py2k = compat.py2k
    if py2k and isinstance(cls, types.ClassType):
        return list()

    # module name under which builtin types are reported
    builtin_module = 'builtins' if compat.py3k else '__builtin__'

    seen = set([cls])
    pending = list(cls.__mro__)
    while pending:
        current = pending.pop()
        if py2k and isinstance(current, types.ClassType):
            continue

        # walk upwards through not-yet-seen bases
        for base in current.__bases__:
            if base in seen:
                continue
            if py2k and isinstance(base, types.ClassType):
                continue
            pending.append(base)
            seen.add(base)

        # never descend into the subclasses of builtin types
        if current.__module__ == builtin_module or \
                not hasattr(current, '__subclasses__'):
            continue

        unseen = [sub for sub in current.__subclasses__()
                  if sub not in seen]
        for sub in unseen:
            pending.append(sub)
            seen.add(sub)
    return list(seen)
594
595
def iterate_attributes(cls):
    """Iterate ``(key, attribute)`` pairs for everything ``dir(cls)``
    reports, without using getattr().

    Each value is read from the ``__dict__`` of the first class in the
    MRO that defines the key, so that class-sensitive descriptors
    (i.e. property.__get__()) are not called.

    """
    for key in dir(cls):
        for klass in cls.__mro__:
            namespace = klass.__dict__
            if key in namespace:
                yield (key, namespace[key])
                break
610
611
def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None,
                                 name='self.proxy', from_instance=None):
    """Automates delegation of __specials__ for a proxying type.

    For each selected double-underscore method of ``from_cls``, a small
    forwarding function is generated and set on ``into_cls``; its body
    returns ``<name>.<method>(<args>)``.

    :param into_cls: class that receives the generated methods.
    :param from_cls: class whose special methods are mirrored.
    :param skip: names to leave out when ``only`` is not given; defaults
     to a fixed tuple of names defined below.
    :param only: when given, exactly these names are used; the ``skip``
     and "already present on into_cls" filters are not applied.
    :param name: expression text placed before ``.<method>`` in the
     generated function body.
    :param from_instance: when not None, it is made available to the
     generated code under the key ``name``.

    """

    if only:
        dunders = only
    else:
        if skip is None:
            skip = ('__slots__', '__del__', '__getattribute__',
                    '__metaclass__', '__getstate__', '__setstate__')
        # every __dunder__ of from_cls which into_cls does not already have
        dunders = [m for m in dir(from_cls)
                   if (m.startswith('__') and m.endswith('__') and
                       not hasattr(into_cls, m) and m not in skip)]

    for method in dunders:
        try:
            fn = getattr(from_cls, method)
            if not hasattr(fn, '__call__'):
                continue
            # use the underlying function when ``im_func`` is present
            fn = getattr(fn, 'im_func', fn)
        except AttributeError:
            continue
        try:
            spec = compat.inspect_getargspec(fn)
            # only positional names are used; d_args omits the first one
            fn_args = inspect.formatargspec(spec[0])
            d_args = inspect.formatargspec(spec[0][1:])
        except TypeError:
            # not inspectable: fall back to a catch-all signature
            fn_args = '(self, *args, **kw)'
            d_args = '(*args, **kw)'

        # NOTE: the template is filled from locals(), so the local names
        # ``method``, ``fn_args``, ``name`` and ``d_args`` are significant.
        py = ("def %(method)s%(fn_args)s: "
              "return %(name)s.%(method)s%(d_args)s" % locals())

        env = from_instance is not None and {name: from_instance} or {}
        compat.exec_(py, env)
        try:
            env[method].__defaults__ = fn.__defaults__
        except AttributeError:
            # fn has no __defaults__ to copy
            pass
        setattr(into_cls, method, env[method])
652
653
def methods_equivalent(meth1, meth2):
    """Return True if the two methods are the same implementation.

    Each argument is reduced to its ``__func__`` when it has one, and the
    results are compared by identity.

    """
    impl1 = getattr(meth1, '__func__', meth1)
    impl2 = getattr(meth2, '__func__', meth2)
    return impl1 is impl2
659
660
def as_interface(obj, cls=None, methods=None, required=None):
    """Ensure basic interface compliance for an instance or dict of callables.

    Checks that ``obj`` implements public methods of ``cls`` or has members
    listed in ``methods``. If ``required`` is not supplied, implementing at
    least one interface method is sufficient. Methods present on ``obj`` that
    are not in the interface are ignored.

    If ``obj`` is a dict and ``dict`` does not meet the interface
    requirements, the keys of the dictionary are inspected. Keys present in
    ``obj`` that are not in the interface will raise TypeErrors.

    Raises TypeError if ``obj`` does not meet the interface criteria.

    In all passing cases, an object with callable members is returned.  In the
    simple case, ``obj`` is returned as-is; if dict processing kicks in then
    an anonymous class is returned.

    obj
      A type, instance, or dictionary of callables.
    cls
      Optional, a type.  All public methods of cls are considered the
      interface.  An ``obj`` instance of cls will always pass, ignoring
      ``required``.
    methods
      Optional, a sequence of method names to consider as the interface.
    required
      Optional, a sequence of mandatory implementations. If omitted, an
      ``obj`` that provides at least one interface method is considered
      sufficient.  As a convenience, required may be a type, in which case
      all public methods of the type are required.

    """
    if not cls and not methods:
        raise TypeError('a class or collection of method names are required')

    if isinstance(cls, type) and isinstance(obj, cls):
        return obj

    interface = set(methods or [m for m in dir(cls) if not m.startswith('_')])
    implemented = set(dir(obj))

    # ``ge``: everything required must be present.
    # ``gt`` against the empty set: at least one member must be present.
    if isinstance(required, type):
        required, complies = interface, operator.ge
    elif required:
        required, complies = set(required), operator.ge
    else:
        required, complies = set(), operator.gt

    if complies(implemented.intersection(interface), required):
        return obj

    # No dict duck typing here.
    if not isinstance(obj, dict):
        if complies is operator.gt:
            qualifier = 'any of'
        else:
            qualifier = 'all of'
        raise TypeError("%r does not implement %s: %s" % (
            obj, qualifier, ', '.join(interface)))

    class AnonymousInterface(object):
        """A callable-holding shell."""

    if cls:
        AnonymousInterface.__name__ = 'Anonymous' + cls.__name__

    found = set()
    for method, impl in dictlike_iteritems(obj):
        if method not in interface:
            raise TypeError("%r: unknown in this interface" % method)
        if not compat.callable(impl):
            raise TypeError("%r=%r is not callable" % (method, impl))
        setattr(AnonymousInterface, method, staticmethod(impl))
        found.add(method)

    if not complies(found, required):
        raise TypeError("dictionary does not contain required keys %s" %
                        ', '.join(required - found))
    return AnonymousInterface
741
742
class memoized_property(object):
    """A read-only @property that is only evaluated once.

    The computed value is stored in the instance ``__dict__`` under the
    name of the wrapped function.

    """

    def __init__(self, fget, doc=None):
        self.fget = fget
        self.__doc__ = doc or fget.__doc__
        self.__name__ = fget.__name__

    def __get__(self, obj, cls):
        # class-level access returns the descriptor itself
        if obj is None:
            return self
        value = self.fget(obj)
        obj.__dict__[self.__name__] = value
        return value

    def _reset(self, obj):
        """Discard the memoized value of this property on ``obj``."""
        memoized_property.reset(obj, self.__name__)

    @classmethod
    def reset(cls, obj, name):
        """Discard the memoized value stored under ``name`` on ``obj``."""
        obj.__dict__.pop(name, None)
763
764
def memoized_instancemethod(fn):
    """Decorate a method so that its return value is memoized.

    Best applied to no-arg methods: memoization is not sensitive to
    argument values, and will always return the same value even when
    called with different arguments.

    """

    def first_call(self, *args, **kw):
        result = fn(self, *args, **kw)

        # later calls on this instance find this stand-in in the instance
        # __dict__ and get the stored result back, whatever they pass
        memo = lambda *a, **kw: result
        memo.__name__ = fn.__name__
        memo.__doc__ = fn.__doc__
        self.__dict__[fn.__name__] = memo
        return result

    return update_wrapper(first_call, fn)
782
783
class group_expirable_memoized_property(object):
    """A family of @memoized_properties that can be expired in tandem.

    Used as a decorator; the name of every decorated function is recorded
    in ``attributes``.

    """

    def __init__(self, attributes=()):
        self.attributes = list(attributes or ())

    def expire_instance(self, instance):
        """Expire all memoized properties for *instance*."""
        memoized = instance.__dict__
        for name in self.attributes:
            memoized.pop(name, None)

    def __call__(self, fn):
        """Record ``fn`` and return it wrapped as a memoized_property."""
        self.attributes.append(fn.__name__)
        return memoized_property(fn)

    def method(self, fn):
        """Record ``fn`` and return it wrapped as a memoized method."""
        self.attributes.append(fn.__name__)
        return memoized_instancemethod(fn)
805
806
class MemoizedSlots(object):
    """Apply memoized items to an object using a __getattr__ scheme.

    This allows the functionality of memoized_property and
    memoized_instancemethod to be available to a class using __slots__.

    A missing attribute ``x`` is computed by calling
    ``_memoized_attr_x()`` when the object has one; otherwise a
    ``_memoized_method_x`` is wrapped so that its first result is
    memoized.  In both cases the outcome is stored with ``setattr``.

    """

    __slots__ = ()

    def _fallback_getattr(self, key):
        raise AttributeError(key)

    def __getattr__(self, key):
        # the hasattr() probes below re-enter this method with
        # '_memoized...' names; fail those immediately
        if key.startswith('_memoized'):
            raise AttributeError(key)

        attr_producer = '_memoized_attr_%s' % key
        if hasattr(self, attr_producer):
            value = getattr(self, attr_producer)()
            setattr(self, key, value)
            return value

        method_producer = '_memoized_method_%s' % key
        if not hasattr(self, method_producer):
            return self._fallback_getattr(key)

        fn = getattr(self, method_producer)

        def oneshot(*args, **kw):
            result = fn(*args, **kw)
            memo = lambda *a, **kw: result
            memo.__name__ = fn.__name__
            memo.__doc__ = fn.__doc__
            setattr(self, key, memo)
            return result

        oneshot.__doc__ = fn.__doc__
        return oneshot
841
842
def dependency_for(modulename):
    """Return a decorator that sets the decorated object as an attribute,
    under its own ``__name__``, on the module named by ``modulename``.

    """
    def decorate(obj):
        # TODO: would be nice to improve on this import silliness,
        # unfortunately importlib doesn't work that great either
        package_path, _, leaf = modulename.rpartition(".")
        package = compat.import_(package_path, globals(), locals(), leaf)
        target_module = getattr(package, leaf)
        setattr(target_module, obj.__name__, obj)
        return obj
    return decorate
854
855
class dependencies(object):
    """Apply imported dependencies as arguments to a function.

    E.g.::

        @util.dependencies(
            "sqlalchemy.sql.widget",
            "sqlalchemy.engine.default"
        )
        def some_func(self, widget, default, arg1, arg2, **kw):
            # ...

    Rationale is so that the impact of a dependency cycle can be
    associated directly with the few functions that cause the cycle,
    and not pollute the module-level namespace.

    The decorated function is replaced by a generated lambda whose
    signature omits the dependency arguments; the lambda calls the
    original function, passing the deferred-import objects in those
    positions.

    """

    def __init__(self, *deps):
        # one deferred import per dotted path, split into
        # (everything before the last dot, last token)
        self.import_deps = []
        for dep in deps:
            tokens = dep.split(".")
            self.import_deps.append(
                dependencies._importlater(
                    ".".join(tokens[0:-1]),
                    tokens[-1]
                )
            )

    def __call__(self, fn):
        # NOTE: the local names ``import_deps`` and ``fn`` are referenced
        # by the generated code, which is evaluated against locals() below.
        import_deps = self.import_deps
        spec = compat.inspect_getfullargspec(fn)

        # keep a copy of the original argument names; spec[0] itself is
        # modified in place twice below
        spec_zero = list(spec[0])
        hasself = spec_zero[0] in ('self', 'cls')

        # inner call: dependency positions become "import_deps[<i>]"
        for i in range(len(import_deps)):
            spec[0][i + (1 if hasself else 0)] = "import_deps[%r]" % i

        inner_spec = format_argspec_plus(spec, grouped=False)

        # outer signature: the dependency arguments are removed entirely
        for impname in import_deps:
            del spec_zero[1 if hasself else 0]
        spec[0][:] = spec_zero

        outer_spec = format_argspec_plus(spec, grouped=False)

        code = 'lambda %(args)s: fn(%(apply_kw)s)' % {
            "args": outer_spec['args'],
            "apply_kw": inner_spec['apply_kw']
        }

        decorated = eval(code, locals())
        decorated.__defaults__ = getattr(fn, 'im_func', fn).__defaults__
        return update_wrapper(decorated, fn)

    @classmethod
    def resolve_all(cls, path):
        """Resolve every pending import whose full dotted path starts with
        ``path``."""
        for m in list(dependencies._unresolved):
            if m._full_path.startswith(path):
                m._resolve()

    # registries shared by all _importlater objects: those not yet
    # resolved, and every instance keyed by full dotted path
    _unresolved = set()
    _by_key = {}

    class _importlater(object):
        """A deferred import of attribute ``addtl`` from module ``path``."""

        # NOTE: the methods below use ``dependencies._unresolved`` and
        # ``dependencies._by_key``; these two same-named attributes are
        # not referenced anywhere in this class body.
        _unresolved = set()

        _by_key = {}

        def __new__(cls, path, addtl):
            # one instance per full dotted path
            key = path + "." + addtl
            if key in dependencies._by_key:
                return dependencies._by_key[key]
            else:
                dependencies._by_key[key] = imp = object.__new__(cls)
                return imp

        def __init__(self, path, addtl):
            self._il_path = path
            self._il_addtl = addtl
            dependencies._unresolved.add(self)

        @property
        def _full_path(self):
            return self._il_path + "." + self._il_addtl

        @memoized_property
        def module(self):
            if self in dependencies._unresolved:
                raise ImportError(
                    "importlater.resolve_all() hasn't "
                    "been called (this is %s %s)"
                    % (self._il_path, self._il_addtl))

            return getattr(self._initial_import, self._il_addtl)

        def _resolve(self):
            # perform the real import and mark this entry as resolved
            dependencies._unresolved.discard(self)
            self._initial_import = compat.import_(
                self._il_path, globals(), locals(),
                [self._il_addtl])

        def __getattr__(self, key):
            # Only invoked when regular attribute lookup has failed, so
            # arriving here for 'module' means the property could not
            # produce a value.
            if key == 'module':
                raise ImportError("Could not resolve module %s"
                                  % self._full_path)
            try:
                attr = getattr(self.module, key)
            except AttributeError:
                raise AttributeError(
                    "Module %s has no attribute '%s'" %
                    (self._full_path, key)
                )
            # cache on the instance so __getattr__ is skipped next time
            self.__dict__[key] = attr
            return attr
972
973
974# from paste.deploy.converters
def asbool(obj):
    """Interpret ``obj`` as a boolean.

    Strings are stripped and lower-cased, then matched against the
    accepted true words ('true', 'yes', 'on', 'y', 't', '1') and false
    words ('false', 'no', 'off', 'n', 'f', '0'); any other string raises
    ``ValueError``.  Non-string objects are passed through ``bool()``.

    """
    if not isinstance(obj, compat.string_types):
        return bool(obj)

    obj = obj.strip().lower()
    if obj in ('true', 'yes', 'on', 'y', 't', '1'):
        return True
    if obj in ('false', 'no', 'off', 'n', 'f', '0'):
        return False
    raise ValueError("String is not true/false: %r" % obj)
985
986
def bool_or_str(*text):
    """Build a converter accepting booleans or specific literal values.

    The returned callable hands back its argument unchanged when it is
    one of the ``text`` values, and otherwise evaluates it with
    ``asbool()``.

    """
    def bool_or_value(obj):
        return obj if obj in text else asbool(obj)
    return bool_or_value
998
999
def asint(value):
    """Coerce ``value`` to an integer, passing ``None`` through as-is."""

    return None if value is None else int(value)
1006
1007
def coerce_kw_type(kw, key, type_, flexi_bool=True):
    r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if
    necessary.  If 'flexi_bool' is True, the string '0' is considered false
    when coercing to boolean.

    The dictionary is modified in place and nothing is returned.  The
    value is left untouched when the key is absent, when the value is
    ``None``, or when it already is an instance of ``type_``.

    :param kw: dictionary that may contain ``key``.
    :param key: the key whose value is to be coerced.
    :param type_: target type; called with the current value to convert.
    :param flexi_bool: when true and ``type_`` is ``bool``, the value is
      converted with ``asbool()`` instead of ``bool()``.
    """

    if key not in kw:
        return

    current = kw[key]
    if current is None or isinstance(current, type_):
        return

    if type_ is bool and flexi_bool:
        kw[key] = asbool(current)
    else:
        kw[key] = type_(current)
1019
1020
def constructor_copy(obj, cls, *args, **kw):
    """Build a ``cls`` instance from the matching state of ``obj``.

    Each name reported by ``get_cls_kwargs(cls)`` that is also present in
    ``obj.__dict__`` is passed to ``cls`` as a keyword argument, replacing
    any same-named entry given in ``kw``.  ``args`` are passed through
    positionally.

    """

    state = obj.__dict__
    for name in get_cls_kwargs(cls):
        if name in state:
            kw[name] = state[name]
    return cls(*args, **kw)
1031
1032
def counter():
    """Return a threadsafe counter function.

    Each call of the returned function yields the next integer, starting
    at 1; the increment happens while holding a lock.

    """

    guard = compat.threading.Lock()
    sequence = itertools.count(1)

    # avoid the 2to3 "next" transformation...
    def _next():
        with guard:
            return next(sequence)

    return _next
1048
1049
def duck_type_collection(specimen, default=None):
    """Guess which basic collection type ``specimen`` is or imitates.

    ``specimen`` may be an instance or a class.  An ``__emulates__``
    attribute, when present, decides the answer (normalized to ``set``
    for any set subclass).  Otherwise the type hierarchy is consulted
    for list, set and dict, and finally the presence of an ``append``,
    ``add`` or ``set`` attribute.  ``default`` is returned when nothing
    matches.
    """

    if hasattr(specimen, '__emulates__'):
        emulated = specimen.__emulates__
        # canonicalize set vs sets.Set to a standard: the builtin set
        if emulated is not None and issubclass(emulated, set):
            return set
        return emulated

    # classes are tested with issubclass, instances with isinstance
    check = issubclass if isinstance(specimen, type) else isinstance
    for basic in (list, set, dict):
        if check(specimen, basic):
            return basic

    for marker, basic in (('append', list), ('add', set), ('set', dict)):
        if hasattr(specimen, marker):
            return basic
    return default
1080
1081
def assert_arg_type(arg, argtype, name):
    """Return ``arg`` if it is an instance of ``argtype``.

    ``argtype`` is a type or a tuple of types.  On a mismatch an
    ``exc.ArgumentError`` naming the argument ``name``, the expected
    type(s) and the actual type is raised.

    """
    if isinstance(arg, argtype):
        return arg

    if isinstance(argtype, tuple):
        expected = ' or '.join("'%s'" % a for a in argtype)
        message = (
            "Argument '%s' is expected to be one of type %s, got '%s'" %
            (name, expected, type(arg)))
    else:
        message = (
            "Argument '%s' is expected to be of type '%s', got '%s'" %
            (name, argtype, type(arg)))
    raise exc.ArgumentError(message)
1094
1095
def dictlike_iteritems(dictlike):
    """Return the (key, value) pairs of almost any dict-like object.

    Strategies are tried in this order:

    * on Python 3, an ``items()`` method: its result is returned as a
      list;
    * on Python 2, ``iteritems()``, else an iterator over ``items()``;
    * otherwise pairs are generated from ``iterkeys()`` or ``keys()``,
      with each value fetched through ``__getitem__`` or, failing that,
      ``get``.

    :raises TypeError: if the object offers neither a value getter nor
      a way to list its keys.

    """

    if compat.py3k:
        if hasattr(dictlike, 'items'):
            return list(dictlike.items())
    else:
        if hasattr(dictlike, 'iteritems'):
            return dictlike.iteritems()
        elif hasattr(dictlike, 'items'):
            return iter(dictlike.items())

    getter = getattr(dictlike, '__getitem__', getattr(dictlike, 'get', None))
    if getter is None:
        # wrap in a 1-tuple so that tuple objects are rendered as a whole
        # rather than being unpacked as the formatting arguments
        raise TypeError(
            "Object '%r' is not dict-like" % (dictlike,))

    if hasattr(dictlike, 'iterkeys'):
        def iterator():
            for key in dictlike.iterkeys():
                yield key, getter(key)
        return iterator()
    elif hasattr(dictlike, 'keys'):
        return iter((key, getter(key)) for key in dictlike.keys())
    else:
        raise TypeError(
            "Object '%r' is not dict-like" % (dictlike,))
1123
1124
class classproperty(property):
    """A ``@property`` work-alike whose getter receives the class.

    The getter is invoked with the owning class whether the attribute
    is reached through the class or through one of its instances.

    The decorator is currently special when using the declarative
    module, but note that the
    :class:`~.sqlalchemy.ext.declarative.declared_attr`
    decorator should be used for this purpose with declarative.

    """

    def __init__(self, fget, *arg, **kw):
        super(classproperty, self).__init__(fget, *arg, **kw)
        # expose the getter's docstring on the descriptor itself
        self.__doc__ = fget.__doc__

    def __get__(desc, self, cls):
        # the instance (``self``) is deliberately ignored
        getter = desc.fget
        return getter(cls)
1142
1143
class hybridproperty(object):
    """Descriptor evaluating one function at class or instance level.

    Accessed on an instance, the function is called with the instance.
    Accessed on the class, it is called with the class and the
    function's docstring is assigned to the returned object.

    """

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        if instance is not None:
            return self.func(instance)

        clsval = self.func(owner)
        clsval.__doc__ = self.func.__doc__
        return clsval
1155
1156
class hybridmethod(object):
    """Decorate a function as cls- or instance- level.

    The function is bound to the instance when reached through one,
    and to the class itself when reached through the class.

    """

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        if instance is not None:
            return self.func.__get__(instance, owner)
        # class-level access: the class takes the place of "self"
        return self.func.__get__(owner, owner.__class__)
1168
1169
class _symbol(int):
    """Integer subclass carrying a ``name``; the value behind ``symbol``.

    The integer value is ``canonical`` when given, else ``hash(name)``.

    """

    def __new__(self, name, doc=None, canonical=None):
        """Construct a new named symbol."""
        assert isinstance(name, compat.string_types)
        number = hash(name) if canonical is None else canonical
        sym = int.__new__(_symbol, number)
        sym.name = name
        if doc:
            sym.__doc__ = doc
        return sym

    def __reduce__(self):
        # re-create through symbol() with the name, a placeholder doc
        # value and the integer value
        args = (self.name, "x", int(self))
        return symbol, args

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return "symbol(%r)" % self.name

# present the class under the name "symbol"
_symbol.__name__ = 'symbol'
1192
1193
class symbol(object):
    """A constant symbol.

    >>> symbol('foo') is symbol('foo')
    True
    >>> symbol('foo')
    symbol('foo')

    A slight refinement of the MAGICCOOKIE=object() pattern.  The primary
    advantage of symbol() is its repr().  They are also singletons.

    Repeated calls of symbol('name') will all return the same instance;
    the ``doc`` and ``canonical`` arguments only take effect on the call
    that first creates a given name.

    The optional ``doc`` argument assigns to ``__doc__``.  This
    is strictly so that Sphinx autoattr picks up the docstring we want
    (it doesn't appear to pick up the in-module docstring if the datamember
    is in a different module - autoattribute also blows up completely).
    If Sphinx fixes/improves this then we would no longer need
    ``doc`` here.

    """
    # registry of created symbols, keyed by name
    symbols = {}
    # guards lookup-or-create in __new__
    _lock = compat.threading.Lock()

    def __new__(cls, name, doc=None, canonical=None):
        # acquire and release through the same reference; the lock is
        # released on every exit path
        with cls._lock:
            sym = cls.symbols.get(name)
            if sym is None:
                cls.symbols[name] = sym = _symbol(name, doc, canonical)
            return sym
1227
1228
# next value to be handed out by set_creation_order(); starts at 1
_creation_order = 1


def set_creation_order(instance):
    """Assign a '_creation_order' sequence to the given instance.

    This allows multiple instances to be sorted in order of creation
    (typically within a single thread; the counter is not particularly
    threadsafe).

    The instance receives the current value of the module-level counter,
    which is then incremented by one.

    """
    global _creation_order
    instance._creation_order = _creation_order
    _creation_order += 1
1243
1244
def warn_exception(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, turning any ``Exception`` into
    a warning.

    The function's result is returned on success.  When it raises, a
    warning naming the exception type and value is issued via ``warn()``
    and ``None`` is returned.

    """
    try:
        return func(*args, **kwargs)
    except Exception:
        exc_type, exc_value = sys.exc_info()[:2]
        warn("%s('%s') ignored" % (exc_type, exc_value))
1254
1255
def ellipses_string(value, len_=25):
    """Shorten ``value`` to ``len_`` items followed by ``"..."``.

    Values no longer than ``len_`` are returned unchanged, as is any
    value for which measuring or formatting raises ``TypeError``.

    """
    try:
        too_long = len(value) > len_
        if too_long:
            return "%s..." % value[0:len_]
    except TypeError:
        pass
    return value
1264
1265
class _hash_limit_string(compat.text_type):
    """A string subclass that can only be hashed on a maximum amount
    of unique values.

    This is used for warnings so that we can send out parameterized warnings
    without the __warningregistry__ of the module, or the non-overridable
    "once" registry within warnings.py, overloading memory.

    The text is ``value % args`` plus a note mentioning ``num``; the hash
    is derived from the un-interpolated ``value`` and one of ``num``
    buckets, and equality is defined as equality of hashes.

    """
    def __new__(cls, value, num, args):
        message = value % args
        message += \
            " (this warning may be suppressed after %d occurrences)" % num
        self = super(_hash_limit_string, cls).__new__(cls, message)
        # at most ``num`` distinct buckets per message template
        bucket = hash(message) % num
        self._hash = hash("%s_%d" % (value, bucket))
        return self

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        return hash(self) == hash(other)
1288
1289
def warn(msg):
    """Issue a warning.

    ``msg`` is handed to ``warnings.warn`` with :class:`.exc.SAWarning`
    as the category and a stack level of 2, i.e. the warning is
    attributed to the caller of this function.

    """
    warnings.warn(msg, category=exc.SAWarning, stacklevel=2)
1298
1299
def warn_limited(msg, args):
    """Issue a warning with a parameterized string, limiting the number
    of registrations.

    When ``args`` is non-empty, ``msg`` is interpolated with it through
    ``_hash_limit_string`` using a limit of 10; otherwise ``msg`` is
    emitted as given.

    """
    message = _hash_limit_string(msg, 10, args) if args else msg
    warnings.warn(message, exc.SAWarning, stacklevel=2)
1308
1309
def only_once(fn):
    """Wrap ``fn`` so that only the first call reaches it.

    The first call returns whatever ``fn`` returns; every later call
    does nothing and returns ``None``.

    """

    remaining = [fn]

    def go(*arg, **kw):
        if not remaining:
            return None
        # popping consumes the single slot before fn runs
        return remaining.pop()(*arg, **kw)

    return go
1322
1323
# matches paths of modules up to two packages deep below "sqlalchemy/"
_SQLA_RE = re.compile(r'sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py')
# matches "unit2", "unittest/" and "unittest2/"
_UNITTEST_RE = re.compile(r'unit(?:2|test2?/)')


def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE):
    """Chop extraneous lines off beginning and end of a traceback.

    Leading lines are dropped for as long as they match
    ``exclude_prefix``; then trailing lines are dropped for as long as
    they match ``exclude_suffix``.  The remaining slice is returned.

    :param tb:
      a list of traceback lines as returned by ``traceback.format_stack()``

    :param exclude_prefix:
      a regular expression object matching lines to skip at beginning of
      ``tb``

    :param exclude_suffix:
      a regular expression object matching lines to skip at end of ``tb``
    """
    # half-open window [lo, hi) over the lines that are kept
    lo, hi = 0, len(tb)
    while lo < hi and exclude_prefix.search(tb[lo]):
        lo += 1
    while lo < hi and exclude_suffix.search(tb[hi - 1]):
        hi -= 1
    return tb[lo:hi]
1348
# the type of the ``None`` singleton, bound to a module-level name
NoneType = type(None)
1350
1351
def attrsetter(attrname):
    """Return a function ``set(obj, value)`` assigning ``obj.<attrname>``.

    The setter is generated from source text, so ``attrname`` is
    interpolated into code verbatim and must be a valid attribute
    expression.

    """
    code = "def set(obj, value):    obj.%s = value" % attrname
    # the generated function lives in a copy of this frame's locals
    env = locals().copy()
    exec(code, env)
    return env['set']
1359
1360
class EnsureKWArgType(type):
    """Apply translation of functions to accept **kw arguments if they
    don't already.

    At class creation, every entry of the class dictionary whose name
    matches the regular expression in the class's ``ensure_kwarg``
    attribute, and whose argument spec reports no ``**`` keywords, is
    replaced by a wrapper that accepts and discards keyword arguments.

    """
    def __init__(cls, clsname, bases, clsdict):
        pattern = cls.ensure_kwarg
        if pattern:
            for attrname in clsdict:
                if not re.match(pattern, attrname):
                    continue
                fn = clsdict[attrname]
                if compat.inspect_getargspec(fn).keywords:
                    # already accepts **kw; leave as is
                    continue
                wrapped = cls._wrap_w_kw(fn)
                clsdict[attrname] = wrapped
                setattr(cls, attrname, wrapped)
        super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict)

    def _wrap_w_kw(self, fn):
        """Return a wrapper of ``fn`` that drops all keyword arguments."""

        def wrap(*arg, **kw):
            # keyword arguments are accepted but not forwarded
            return fn(*arg)
        return update_wrapper(wrap, fn)
1384
1385