1"""Implementation of __array_function__ overrides from NEP-18."""
2import collections
3import functools
4import os
5import textwrap
6
7from numpy.core._multiarray_umath import (
8    add_docstring, implement_array_function, _get_implementing_args)
9from numpy.compat._inspect import getargspec
10
11
# Global switch for __array_function__ dispatch. The environment variable
# NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is parsed with int(): "0" turns dispatch
# off, any other integer (or leaving it unset) turns it on. A non-integer
# value makes int() raise ValueError at import time.
ARRAY_FUNCTION_ENABLED = int(
    os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1)) != 0
14
# Docstring fragment describing the ``like`` keyword (NEP 35). It is
# substituted for the ``${ARRAY_FUNCTION_LIKE}`` placeholder by
# ``set_array_function_like_doc``. Its indentation is part of the text that
# gets spliced into docstrings, so keep it exactly as written.
array_function_like_doc = (
    """like : array_like
        Reference object to allow the creation of arrays which are not
        NumPy arrays. If an array-like passed in as ``like`` supports
        the ``__array_function__`` protocol, the result will be defined
        by it. In this case, it ensures the creation of an array object
        compatible with that passed in via this argument.

        .. note::
            The ``like`` keyword is an experimental feature pending on
            acceptance of :ref:`NEP 35 <NEP35>`."""
)
27
def set_array_function_like_doc(public_api):
    """Expand the ``${ARRAY_FUNCTION_LIKE}`` placeholder in a docstring.

    Every occurrence of the placeholder in ``public_api.__doc__`` is replaced
    with ``array_function_like_doc``. Objects without a docstring (``__doc__``
    is ``None``) are left untouched.

    Parameters
    ----------
    public_api : object
        Object whose ``__doc__`` should be rewritten in place.

    Returns
    -------
    The same ``public_api`` object, so this can be used as a decorator.
    """
    doc = public_api.__doc__
    if doc is None:
        return public_api
    public_api.__doc__ = doc.replace(
        "${ARRAY_FUNCTION_LIKE}", array_function_like_doc)
    return public_api
35
36
# Attach a NumPy-style docstring to ``implement_array_function``, which is
# imported from the C extension ``_multiarray_umath`` and so is given its
# documentation here via ``add_docstring``.
add_docstring(
    implement_array_function,
    """
    Implement a function with checks for __array_function__ overrides.

    All arguments are required, and can only be passed by position.

    Parameters
    ----------
    implementation : function
        Function that implements the operation on NumPy array without
        overrides when called like ``implementation(*args, **kwargs)``.
    public_api : function
        Function exposed by NumPy's public API originally called like
        ``public_api(*args, **kwargs)`` on which arguments are now being
        checked.
    relevant_args : iterable
        Iterable of arguments to check for __array_function__ methods.
    args : tuple
        Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : dict
        Arbitrary keyword arguments originally passed into ``public_api``.

    Returns
    -------
    Result from calling ``implementation()`` or an ``__array_function__``
    method, as appropriate.

    Raises
    ------
    TypeError : if no implementation is found.
    """)
69
70
# Attach a docstring to ``_get_implementing_args`` (imported from the C
# extension). It is exposed for testing purposes and is used internally by
# implement_array_function.
add_docstring(
    _get_implementing_args,
    """
    Collect arguments on which to call __array_function__.

    Parameters
    ----------
    relevant_args : iterable of array-like
        Iterable of possibly array-like arguments to check for
        __array_function__ methods.

    Returns
    -------
    Sequence of arguments with __array_function__ methods, in the order in
    which they should be called.
    """)
88
89
# Named view of the 4-tuple returned by getargspec, so that signature
# comparisons can refer to fields by name instead of by index.
ArgSpec = collections.namedtuple(
    'ArgSpec', ['args', 'varargs', 'keywords', 'defaults'])
91
92
def verify_matching_signatures(implementation, dispatcher):
    """Verify that a dispatcher function has the right signature.

    The two functions are compared via ``getargspec``. They must agree on:

    * the positional argument names;
    * the variadic positional and keyword names;
    * whether they have defaults, and how many.

    In addition, every default value of the dispatcher must be ``None``.

    Raises
    ------
    RuntimeError
        If the signatures differ, or if the dispatcher uses a default other
        than ``None``.
    """
    impl = ArgSpec(*getargspec(implementation))
    disp = ArgSpec(*getargspec(dispatcher))

    # Short-circuit order matters: the len() comparison is only reached once
    # both functions are known to have defaults.
    mismatch = (
        impl.args != disp.args
        or impl.varargs != disp.varargs
        or impl.keywords != disp.keywords
        or bool(impl.defaults) != bool(disp.defaults)
        or (impl.defaults is not None
            and len(impl.defaults) != len(disp.defaults))
    )
    if mismatch:
        raise RuntimeError('implementation and dispatcher for %s have '
                           'different function signatures' % implementation)

    if impl.defaults is None:
        return

    if disp.defaults != (None,) * len(disp.defaults):
        raise RuntimeError('dispatcher functions can only use None for '
                           'default argument values')
113
114
def set_module(module):
    """Decorator for overriding __module__ on a function or class.

    When ``module`` is ``None`` the decorated object is returned unchanged.

    Example usage::

        @set_module('numpy')
        def example():
            pass

        assert example.__module__ == 'numpy'
    """
    def decorator(func):
        if module is None:
            return func
        func.__module__ = module
        return func
    return decorator
131
132
133
# Source template for the dispatching wrapper that array_function_dispatch
# compiles and runs with exec(). ``{name}`` is filled in with str.format. The
# generated code looks up ``functools``, ``implementation``, ``dispatcher``
# and ``implement_array_function`` as globals, so the exec scope must
# provide them.
# Call textwrap.dedent here instead of in the function so as to avoid
# calling dedent multiple times on the same text
_wrapped_func_source = textwrap.dedent("""
    @functools.wraps(implementation)
    def {name}(*args, **kwargs):
        relevant_args = dispatcher(*args, **kwargs)
        return implement_array_function(
            implementation, {name}, relevant_args, args, kwargs)
    """)
143
144
def array_function_dispatch(dispatcher, module=None, verify=True,
                            docs_from_dispatcher=False):
    """Decorator for adding dispatch with the __array_function__ protocol.

    See NEP-18 for example usage.

    Parameters
    ----------
    dispatcher : callable
        Function that when called like ``dispatcher(*args, **kwargs)`` with
        arguments from the NumPy function call returns an iterable of
        array-like arguments to check for ``__array_function__``.
    module : str, optional
        __module__ attribute to set on new function, e.g., ``module='numpy'``.
        By default, module is copied from the decorated function.
    verify : bool, optional
        If True, verify that the signatures of the dispatcher and the
        decorated function match exactly: all required and optional arguments
        should appear in order with the same names, but the default values for
        all optional arguments should be ``None``. Only disable verification
        if the dispatcher's signature needs to deviate for some particular
        reason, e.g., because the function has a signature like
        ``func(*args, **kwargs)``.
    docs_from_dispatcher : bool, optional
        If True, copy docs from the dispatcher function onto the dispatched
        function, rather than from the implementation. This is useful for
        functions defined in C, which otherwise don't have docstrings.

    Returns
    -------
    Function suitable for decorating the implementation of a NumPy function.
    When ``ARRAY_FUNCTION_ENABLED`` is false, that decorator returns the
    implementation itself (with docs/module adjusted) instead of a
    dispatching wrapper.
    """

    if not ARRAY_FUNCTION_ENABLED:
        # Dispatch is switched off: only apply the docstring and __module__
        # adjustments and hand back the undecorated implementation.
        def decorator(implementation):
            if docs_from_dispatcher:
                add_docstring(implementation, dispatcher.__doc__)
            if module is not None:
                implementation.__module__ = module
            return implementation
        return decorator

    def decorator(implementation):
        # Local import: only needed to validate the generated wrapper name.
        import keyword

        if verify:
            verify_matching_signatures(implementation, dispatcher)

        if docs_from_dispatcher:
            add_docstring(implementation, dispatcher.__doc__)

        # Globals of the generated wrapper. Its body looks these names up at
        # call time, so nothing may rebind them.
        scope = {
            'implementation': implementation,
            'dispatcher': dispatcher,
            'functools': functools,
            'implement_array_function': implement_array_function,
        }

        # Equivalently, we could define this function directly instead of using
        # exec. This version has the advantage of giving the helper function a
        # more interpretable name. Otherwise, the original function does not
        # show up at all in many cases, e.g., if it's written in C or if the
        # dispatcher gets an invalid keyword argument.
        name = implementation.__name__
        if (not name.isidentifier() or keyword.iskeyword(name)
                or name in scope):
            # The name is spliced verbatim into generated source, so it must
            # be a legal identifier. It must also not collide with a scope
            # global: e.g. an implementation named ``dispatcher`` would
            # rebind that global to the wrapper, which would then call itself
            # forever. functools.wraps still copies the real __name__ and
            # __qualname__, so only the code object's name changes.
            name = '_array_function_public_api'
        source = _wrapped_func_source.format(name=name)

        source_object = compile(
            source, filename='<__array_function__ internals>', mode='exec')
        exec(source_object, scope)

        public_api = scope[name]

        if module is not None:
            public_api.__module__ = module

        # Keep a handle on the undecorated implementation.
        public_api._implementation = implementation

        return public_api

    return decorator
221
222
def array_function_from_dispatcher(
        implementation, module=None, verify=True, docs_from_dispatcher=True):
    """Like ``array_function_dispatch``, but with function arguments flipped.

    The implementation is supplied up front, and the returned decorator is
    applied to the dispatcher. Unlike ``array_function_dispatch``,
    ``docs_from_dispatcher`` defaults to True here.
    """
    def decorator(dispatcher):
        wrap_implementation = array_function_dispatch(
            dispatcher, module, verify=verify,
            docs_from_dispatcher=docs_from_dispatcher)
        return wrap_implementation(implementation)
    return decorator
232