1# -*- coding: utf-8 -*-
2# Copyright (c) 2018 the Pockets team, see AUTHORS.
3# Licensed under the BSD License, see LICENSE for details.
4
5"""A pocket full of useful reflection functions!"""
6
7from __future__ import absolute_import, print_function
8
9import inspect
10import functools
11from os.path import basename
12from pkgutil import iter_modules
13
14import six
15from six import string_types
16
17from pockets.collections import listify
18from pockets.string import splitify
19
20
21__all__ = [
22    "collect_subclasses",
23    "collect_superclasses",
24    "collect_superclass_attr_names",
25    "hoist_submodules",
26    "import_star",
27    "import_submodules",
28    "is_data",
29    "resolve",
30    "unwrap",
31]
32
33
def collect_subclasses(cls):
    """
    Gathers every class that descends, directly or indirectly, from `cls`.

    Note:
        `cls` itself is never part of the result.

    Args:
        cls (class): The class whose descendants should be gathered.

    Returns:
        list: The `class` objects that inherit from `cls`, each appearing
            exactly once and in no particular order. `cls` itself is not
            included.
    """
    descendants = set()
    for child in cls.__subclasses__():
        # A class reachable through several inheritance paths is deduplicated
        # by the set.
        descendants.update([child], collect_subclasses(child))
    return list(descendants)
54
55
def collect_superclasses(cls, terminal_class=None, modules=None):
    """
    Gathers `cls` together with every ancestor class found by walking up its
    inheritance hierarchy.

    Note:
        Includes `cls` itself. Will not include `terminal_class`.

    Args:
        cls (class): The class object from which the collection should begin.
        terminal_class (class or list): Ascending the tree stops as soon as
            one of these classes is reached. `terminal_class` is never part
            of the returned list.
        modules (string, module, or list): When given, only classes defined
            in these modules are returned, and ascending stops at the first
            class defined elsewhere. Useful for excluding base classes that
            come from external libraries.

    Returns:
        list: The `class` objects from which `cls` inherits, in no particular
            order. This list includes `cls` itself, unless `cls` is excluded
            by `terminal_class` or `modules`.
    """
    terminal_class = listify(terminal_class)
    if modules is not None:
        # Normalize to module names so they can be compared against
        # `cls.__module__`.
        modules = [
            m if isinstance(m, six.string_types) else m.__name__
            for m in listify(modules)
        ]

    # A class outside the allowed modules ends this branch of the walk.
    if modules is not None and cls.__module__ not in modules:
        return []
    # So does reaching a terminal class.
    if cls in terminal_class:
        return []

    found = set([cls])
    for parent in cls.__bases__:
        found.update(collect_superclasses(parent, terminal_class, modules))
    return list(found)
98
99
def collect_superclass_attr_names(cls, terminal_class=None, modules=None):
    """
    Gathers the attribute names defined by `cls` and by every ancestor class
    in its inheritance hierarchy.

    Note:
        Includes `cls` itself. Will not include `terminal_class`.

    Args:
        cls (class): The class object from which the collection should begin.
        terminal_class (class or list): Ascending the tree stops as soon as
            one of these classes is reached. Attributes from `terminal_class`
            are not part of the returned list.
        modules (string, module, or list): When given, only classes defined
            in these modules are considered. Useful for excluding base
            classes that come from external libraries.

    Returns:
        list: The `str` attribute names found in the `__dict__` of every
            collected `class`, each name appearing once and in no particular
            order.
    """
    names = {
        attr_name
        for klass in collect_superclasses(cls, terminal_class, modules)
        for attr_name in klass.__dict__
    }
    return list(names)
126
127
def hoist_submodules(package, extend_all=True):
    """
    Copies everything exported via `__all__` by the submodules of `package`
    onto `package` itself.

    Note:
        Only attributes listed in a submodule's `__all__` are considered. A
        submodule that does not define `__all__` contributes nothing.

    Effectively does::

        from package.* import *

    Args:
        package (str or module): The parent package into which submodule
            exports should be hoisted.
        extend_all (bool): If True, `package.__all__` is extended with the
            hoisted attribute names (and created if it is missing or None).
            Defaults to True.

    Returns:
        list: List of all hoisted attribute names.

    """
    target = resolve(package)
    hoisted = []
    for submodule in import_submodules(target):
        exports = import_star(submodule)
        for name in exports:
            hoisted.append(name)
            setattr(target, name, exports[name])

    if not extend_all:
        return hoisted

    existing_all = getattr(target, "__all__", None)
    if existing_all is None:
        # Give the package its own list so later changes to the returned
        # list do not leak into `__all__`.
        target.__all__ = list(hoisted)
    else:
        existing_all.extend(hoisted)

    return hoisted
164
165
def import_star(module):
    """
    Collects every attribute exported by `module` into a `dict`.

    Note:
        Only attributes listed in `__all__` are considered. If `module` does
        not define `__all__`, the returned dict is empty.

    Effectively does::

        from module import *

    Args:
        module (str or module): The module from which a wildcard import
            should be done.

    Returns:
        dict: Map of exported attribute name to attribute value.

    """
    source = resolve(module)
    exported_names = getattr(source, "__all__", [])
    return {name: getattr(source, name) for name in exported_names}
189
190
def import_submodules(package):
    """
    Imports all submodules of `package`.

    Effectively does::

        __import__(package.*)

    Note:
        Submodules are only looked for when `package` was loaded from an
        ``__init__.py`` file (or a compiled variant such as
        ``__init__.pyc``). Plain modules, and modules whose `__file__` is
        missing or None (e.g. built-in modules or namespace packages), yield
        nothing.

    Args:
        package (str or module): The parent package from which submodules
            should be imported.

    Yields:
        module: The next submodule of `package`.

    """
    module = resolve(package)
    # `__file__` is optional: it may be absent or None, and neither case
    # should blow up inside `basename`.
    module_file = getattr(module, "__file__", None)
    if not module_file:
        return
    # `startswith` rather than `==` so "__init__.pyc" / "__init__.pyo" match.
    if not basename(module_file).startswith("__init__.py"):
        return
    for _, submodule_name, _ in iter_modules(module.__path__):
        yield resolve(submodule_name, module)
211
212
def is_data(obj):
    """
    Tells whether `obj` is a "data like" object rather than a routine.

    Strongly inspired by `inspect.classify_class_attrs`. This is handy when
    deciding whether an attribute can carry a meaningful docstring: in
    general a routine can, whereas a non-routine cannot.

    See Also:
        * `inspect.classify_class_attrs`
        * `inspect.isroutine`

    Args:
        obj (object): The object in question.

    Returns:
        bool: False if `obj` is a `staticmethod`, `classmethod`, `property`,
            or anything `inspect.isroutine` accepts; True otherwise.
    """
    # These wrappers are checked explicitly, ahead of the generic routine
    # test.
    if isinstance(obj, (staticmethod, classmethod, property)):
        return False
    return not inspect.isroutine(obj)
238
239
def resolve(name, modules=None):
    """
    Resolve a dotted name to an object (usually class, module, or function).

    If `name` is a string, attempt to resolve it according to Python
    dot notation, e.g. "path.to.MyClass". If `name` is anything other than a
    string, return it immediately:

    >>> resolve("calendar.TextCalendar")
    <class 'calendar.TextCalendar'>
    >>> resolve(object())
    <object object at 0x...>

    If `modules` is specified, then resolution of `name` is restricted
    to the given modules. Leading dots are allowed in `name`, but they are
    ignored. Resolution **will not** traverse up the module path if `modules`
    is specified.

    If `modules` is not specified and `name` has leading dots, then resolution
    is first attempted relative to the calling function's module, and then
    absolutely. Resolution **will** traverse up the module path. If `name` has
    no leading dots, resolution is first attempted absolutely and then
    relative to the calling module.

    Pass an empty string for `modules` to only use absolute resolution.

    Warning:
        Do not resolve strings supplied by an end user without specifying
        `modules`. Instantiating an arbitrary object specified by an end user
        can introduce a potential security risk.

        To avoid this, restrict the search path by explicitly specifying
        `modules`.

    Restricting `name` resolution to a set of `modules`:

    >>> resolve("pockets.camel")
    <function camel at 0x...>
    >>> resolve("pockets.camel", modules=["re", "six"])
    Traceback (most recent call last):
    ValueError: Unable to resolve 'pockets.camel' in modules: ['re', 'six']
      ...

    Args:
        name (str or object): A dotted name.

        modules (str, module, or list, optional): A module or list of modules,
            under which to search for `name`.

    Returns:
        object: The object specified by `name`.

    Raises:
        ValueError: If `name` can't be resolved.

    """
    # Anything that is not a string is treated as already resolved.
    if not isinstance(name, string_types):
        return name

    # Empty segments are requested explicitly, presumably so that leading
    # dots show up as empty strings at the front of `obj_path`; the checks on
    # `obj_path[0]` below rely on that.
    obj_path = splitify(name, ".", include_empty=True)
    # Candidate paths, tried in order by the loop further down.
    search_paths = []
    if modules is not None:
        # Restricted search: leading dots are dropped and `name` is looked up
        # under each given module only.
        # NOTE(review): if every segment is empty (e.g. a name made only of
        # dots), `obj_path[0]` looks like it would raise IndexError once the
        # list is exhausted rather than the documented ValueError -- confirm.
        while not obj_path[0].strip():
            obj_path.pop(0)
        for module_path in listify(modules):
            search_paths.append(splitify(module_path, ".") + obj_path)
    else:
        # Globals of the frame one level above this one, i.e. whoever called
        # resolve(). The index is tied to the call depth, so this lookup must
        # stay directly in this function body.
        caller = inspect.getouterframes(inspect.currentframe())[1][0].f_globals
        module_path = caller["__name__"].split(".")
        if not obj_path[0]:
            # Relative name: the first leading dot refers to the caller's own
            # module path ...
            obj_path.pop(0)
            # ... and every further leading dot drops one trailing component
            # from that path.
            # NOTE(review): same IndexError concern as above for names that
            # consist only of dots -- confirm.
            while not obj_path[0]:
                obj_path.pop(0)
                if module_path:
                    module_path.pop()

            # Relative lookup first, absolute lookup as the fallback.
            search_paths.append(module_path + obj_path)
            search_paths.append(obj_path)
        else:
            # No leading dots: absolute lookup first, then relative to the
            # caller's module.
            search_paths.append(obj_path)
            search_paths.append(module_path + obj_path)

    # Errors from every failed attempt are kept so the final ValueError can
    # report all of them.
    exceptions = []
    for path in search_paths:
        # Import the most deeply nested module available
        module = None
        module_path = []
        obj_path = list(path)
        while obj_path:
            module_name = obj_path.pop(0)
            # Skip empty segments left inside the path.
            while not module_name:
                module_name = obj_path.pop(0)
            if isinstance(module_name, string_types):
                package = ".".join(module_path + [module_name])
                try:
                    # A non-empty `fromlist` makes __import__ return the
                    # module named by `package` itself rather than the
                    # top-level package.
                    module = __import__(package, fromlist=module_name)
                except ImportError as ex:
                    # This segment is not importable: put it back so it is
                    # looked up as an attribute of the last module that was
                    # imported.
                    exceptions.append(ex)
                    obj_path = [module_name] + obj_path
                    break
                else:
                    module_path.append(module_name)
            else:
                # Non-string entry; presumably a module object supplied via
                # `modules`, used directly as the starting module.
                module = module_name
                module_path.append(module.__name__)

        if module:
            if obj_path:
                try:
                    # Walk the remaining segments as a chain of attributes.
                    return functools.reduce(getattr, obj_path, module)
                except AttributeError as ex:
                    exceptions.append(ex)
            else:
                # The whole path was importable, so the module is the result.
                return module

    # Every search path failed; build the error message.
    if modules:
        msg = "Unable to resolve '{0}' in modules: {1}".format(name, modules)
    else:
        msg = "Unable to resolve '{0}'".format(name)

    if exceptions:
        # Append one indented line per recorded error.
        msgs = ["{0}: {1}".format(type(e).__name__, e) for e in exceptions]
        raise ValueError("\n    ".join([msg] + msgs))
    else:
        raise ValueError(msg)
365
366
def unwrap(func):
    """
    Follows the chain of `__wrapped__` attributes down to the innermost
    function that was wrapped using `functools.wraps`.

    Note:
        This function relies on the existence of the `__wrapped__` attribute,
        which was not automatically added until Python 3.2. If you are using
        an older version of Python, you'll have to manually add the
        `__wrapped__` attribute in order to use `unwrap`::

            def my_decorator(func):
                @wraps(func)
                def with_my_decorator(*args, **kwargs):
                    return func(*args, **kwargs)

                if not hasattr(with_my_decorator, '__wrapped__'):
                    with_my_decorator.__wrapped__ = func

                return with_my_decorator

    Args:
        func (function): A function that may or may not have been wrapped
            using `functools.wraps`.

    Returns:
        function: The original function before it was wrapped using
            `functools.wraps`. `func` is returned directly if it has no
            `__wrapped__` attribute.
    """
    if not hasattr(func, "__wrapped__"):
        return func
    # Peel off one layer and keep going until nothing is left to unwrap.
    return unwrap(func.__wrapped__)
397