1import typing
2import inspect
3import functools
4from . import _uarray  # type: ignore
5import copyreg  # type: ignore
6import atexit
7import pickle
8
# An argument extractor accepts arbitrary arguments and returns a tuple of
# ``Dispatchable`` instances.
ArgumentExtractorType = typing.Callable[..., typing.Tuple["Dispatchable", ...]]
# An argument replacer maps ``(args, kwargs, dispatchables)`` to a new
# ``(args, kwargs)`` pair.
ArgumentReplacerType = typing.Callable[
    [typing.Tuple, typing.Dict, typing.Tuple], typing.Tuple[typing.Tuple, typing.Dict]
]
13
14from ._uarray import (  # type: ignore
15    BackendNotImplementedError,
16    _Function,
17    _SkipBackendContext,
18    _SetBackendContext,
19)
20
# Names exported by a star-import of this module.
__all__ = [
    "set_backend",
    "set_global_backend",
    "skip_backend",
    "register_backend",
    "clear_backends",
    "create_multimethod",
    "generate_multimethod",
    "_Function",
    "BackendNotImplementedError",
    "Dispatchable",
    "wrap_single_convertor",
    "all_of_type",
    "mark_as",
]
36
37
def unpickle_function(mod_name, qname):
    """
    Resolve an object from its defining module and its qualified name.

    This is the reconstruction half of the by-reference pickling scheme:
    the object is looked up again by name instead of being rebuilt from
    serialized state.

    Parameters
    ----------
    mod_name : str
        Importable name of the module in which the object lives.
    qname : str
        Qualified name of the object inside that module. Dotted names such
        as ``"Class.method"`` are resolved one attribute at a time.

    Returns
    -------
    object
        The object found under ``qname`` in the module ``mod_name``.

    Raises
    ------
    pickle.UnpicklingError
        If the module cannot be imported, if the name cannot be resolved,
        or if either argument is not a usable name.
    """
    import importlib
    from pickle import UnpicklingError

    try:
        obj = importlib.import_module(mod_name)
        # ``__qualname__`` of a nested object is dotted, so a single
        # ``getattr`` on the module is not enough: walk the attribute chain.
        for part in qname.split("."):
            obj = getattr(obj, part)
    except (ImportError, AttributeError, TypeError, ValueError) as e:
        # TypeError/ValueError cover relative or empty module names, which
        # ``import_module`` rejects without raising ImportError.
        raise UnpicklingError(
            "Can't resolve {!r} in module {!r}".format(qname, mod_name)
        ) from e

    return obj
49
50
def pickle_function(func):
    """
    Reduce ``func`` to a by-reference pickle representation.

    The function is only picklable when looking up its ``__module__`` and
    ``__qualname__`` through :obj:`unpickle_function` yields the very same
    object.

    Parameters
    ----------
    func
        The object to reduce.

    Returns
    -------
    tuple
        ``(unpickle_function, (mod_name, qname))``, the callable and the
        arguments to call it with.

    Raises
    ------
    pickle.PicklingError
        If the name lookup fails or resolves to a different object.
    """
    location = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )

    try:
        resolved = unpickle_function(*location)
    except pickle.UnpicklingError:
        # A failed lookup is reported below as "not the same object as None".
        resolved = None

    if resolved is func:
        return unpickle_function, location

    raise pickle.PicklingError(
        "Can't pickle {}: it's not the same object as {}".format(func, resolved)
    )
66
67
# Register ``pickle_function`` as the reduction function for ``_Function``
# objects, so they are pickled as a (module, qualified name) reference.
copyreg.pickle(_Function, pickle_function)
# Call ``_uarray.clear_all_globals`` when the interpreter exits.
# NOTE(review): presumably this releases global state held by the extension
# module before teardown -- confirm against the ``_uarray`` implementation.
atexit.register(_uarray.clear_all_globals)
70
71
def create_multimethod(*args, **kwargs):
    """
    Build a decorator that turns an argument extractor into a multimethod.

    The decorated function is used as the argument extractor; every
    argument given here is forwarded after it to
    :obj:`generate_multimethod`.

    See Also
    --------
    generate_multimethod : Generates a multimethod.
    """

    def decorator(argument_extractor):
        # The extractor always goes first, followed by the captured arguments.
        return generate_multimethod(argument_extractor, *args, **kwargs)

    return decorator
90
91
def generate_multimethod(
    argument_extractor: ArgumentExtractorType,
    argument_replacer: ArgumentReplacerType,
    domain: str,
    default: typing.Optional[typing.Callable] = None
):
    """
    Generates a multimethod.

    Parameters
    ----------
    argument_extractor : ArgumentExtractorType
        A callable with the same signature as the desired multimethod. It
        returns the arguments to dispatch over, each marked with the
        :obj:`Dispatchable` class.
    argument_replacer : ArgumentReplacerType
        A callable with the signature (args, kwargs, dispatchables). It returns
        an (args, kwargs) pair in which the dispatchables have been replaced.
    domain : str
        A string value indicating the domain of this multimethod.
    default : Optional[Callable], optional
        The default implementation of this multimethod. ``None`` (the default)
        means there is no default implementation.

    Examples
    --------
    Here ``a`` is the argument to dispatch over, so it is returned marked as an ``int``.
    The trailing comma is required because the result has to be an iterable.

    >>> def override_me(a, b):
    ...   return Dispatchable(a, int),

    The argument replacer puts the supplied dispatchables back into args/kwargs.

    >>> def override_replacer(args, kwargs, dispatchables):
    ...     return (dispatchables[0], args[1]), {}

    With both in place, the multimethod can be created.

    >>> overridden_me = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples"
    ... )

    Unless a default implementation is supplied, there is none.

    >>> overridden_me(1, "a")
    Traceback (most recent call last):
        ...
    uarray.backend.BackendNotImplementedError: ...
    >>> overridden_me2 = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
    ... )
    >>> overridden_me2(1, "a")
    (1, 'a')

    See Also
    --------
    uarray :
        See the module documentation for how to override the method by creating backends.
    """
    # The third element returned by ``get_defaults`` is not needed here.
    keyword_defaults, positional_defaults, _ = get_defaults(argument_extractor)

    multimethod = _Function(
        argument_extractor,
        argument_replacer,
        domain,
        positional_defaults,
        keyword_defaults,
        default,
    )

    # Copy name, docstring and related metadata from the extractor.
    return functools.update_wrapper(multimethod, argument_extractor)
164
165
def set_backend(backend, coerce=False, only=False):
    """
    A context manager that sets the preferred backend.

    The context object is cached on the backend in ``__ua_cache__``, keyed
    by ``("set", coerce, only)``, so repeated calls with the same arguments
    return the same object.

    Parameters
    ----------
    backend
        The backend to set.
    coerce
        Whether or not to coerce to a specific backend's types. Implies ``only``.
    only
        Whether or not this should be the last backend to try.

    See Also
    --------
    skip_backend : A context manager that allows skipping of backends.
    set_global_backend : Set a single, global backend for a domain.
    """
    cache_key = ("set", coerce, only)

    try:
        cached = backend.__ua_cache__[cache_key]
    except AttributeError:
        # First use of this backend: give it an empty cache.
        backend.__ua_cache__ = {}
    except KeyError:
        # The cache exists but has no entry for this combination yet.
        pass
    else:
        return cached

    context = _SetBackendContext(backend, coerce, only)
    backend.__ua_cache__[cache_key] = context
    return context
194
195
def skip_backend(backend):
    """
    A context manager that skips a given backend from processing entirely.
    This allows one to use another backend's code in a library that is also
    a consumer of the same backend.

    The context object is cached on the backend in ``__ua_cache__`` under
    the key ``"skip"``, so repeated calls return the same object.

    Parameters
    ----------
    backend
        The backend to skip.

    See Also
    --------
    set_backend : A context manager that allows setting of backends.
    set_global_backend : Set a single, global backend for a domain.
    """
    cache_key = "skip"

    try:
        cached = backend.__ua_cache__[cache_key]
    except AttributeError:
        # First use of this backend: give it an empty cache.
        backend.__ua_cache__ = {}
    except KeyError:
        # The cache exists but holds no skip context yet.
        pass
    else:
        return cached

    context = _SkipBackendContext(backend)
    backend.__ua_cache__[cache_key] = context
    return context
222
223
def get_defaults(f):
    """
    Collect default-value information from the signature of ``f``.

    Parameters
    ----------
    f : callable
        The callable whose signature is inspected.

    Returns
    -------
    kw_defaults : dict
        Maps the name of every parameter that has a default to that default.
    arg_defaults : tuple
        One entry per positional-only or positional-or-keyword parameter,
        in signature order. Parameters without a default contribute
        ``inspect.Parameter.empty``.
    opts : set
        The names of all parameters.
    """
    parameters = inspect.signature(f).parameters
    no_default = inspect.Parameter.empty
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    kw_defaults = {
        name: param.default
        for name, param in parameters.items()
        if param.default is not no_default
    }
    # Positional slots keep the ``empty`` marker when they have no default.
    arg_defaults = tuple(
        param.default
        for param in parameters.values()
        if param.kind in positional_kinds
    )
    opts = set(parameters)

    return kw_defaults, arg_defaults, opts
240
241
def set_global_backend(backend, coerce=False, only=False):
    """
    This utility method replaces the default backend for permanent use. It
    will be tried in the list of backends automatically, unless the
    ``only`` flag is set on a backend. This will be the first tried
    backend outside the :obj:`set_backend` context manager.

    Note that this method is not thread-safe.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by users themselves, or by a reference
        implementation, if one exists.

    Parameters
    ----------
    backend
        The backend to register.
    coerce
        Forwarded unchanged to ``_uarray.set_global_backend``. Presumably
        has the same meaning as in :obj:`set_backend` (whether or not to
        coerce to the backend's types) -- confirm against ``_uarray``.
    only
        Forwarded unchanged to ``_uarray.set_global_backend``. Presumably
        has the same meaning as in :obj:`set_backend` (whether or not this
        should be the last backend to try) -- confirm against ``_uarray``.

    See Also
    --------
    set_backend : A context manager that allows setting of backends.
    skip_backend : A context manager that allows skipping of backends.
    """
    # All arguments are passed positionally to the extension module.
    _uarray.set_global_backend(backend, coerce, only)
268
269
def register_backend(backend):
    """
    This utility method registers a backend for permanent use. It
    will be tried in the list of backends automatically, unless the
    ``only`` flag is set on a backend.

    Note that this method is not thread-safe.

    Parameters
    ----------
    backend
        The backend to register.
    """
    # Registration itself is done by the extension module.
    _uarray.register_backend(backend)
284
285
def clear_backends(domain, registered=True, globals=False):
    """
    This utility method clears registered backends.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by the users themselves.

    .. warning::
        Do NOT use this method inside a multimethod call, or the
        program is likely to crash.

    Parameters
    ----------
    domain : Optional[str]
        The domain for which to de-register backends. ``None`` means
        de-register for all domains.
    registered : bool
        Whether or not to clear registered backends. See :obj:`register_backend`.
    globals : bool
        Whether or not to clear global backends. See :obj:`set_global_backend`.

    See Also
    --------
    register_backend : Register a backend globally.
    set_global_backend : Set a global backend.
    """
    # The ``globals`` parameter shadows the builtin of the same name inside
    # this function; it is only forwarded, positionally, to the extension.
    _uarray.clear_backends(domain, registered, globals)
315
316
class Dispatchable:
    """
    A utility class which marks an argument with a specific dispatch type.

    Indexing an instance behaves like indexing the tuple ``(type, value)``.

    Attributes
    ----------
    value
        The value of the Dispatchable.

    type
        The type of the Dispatchable.

    coercible
        Flag stored from the constructor (``True`` by default).
        :obj:`wrap_single_convertor` combines it with its own ``coerce``
        argument before calling the single-element convertor.

    Examples
    --------
    >>> x = Dispatchable(1, str)
    >>> x
    <Dispatchable: type=<class 'str'>, value=1>

    See Also
    --------
    all_of_type
        Marks all unmarked parameters of a function.

    mark_as
        Allows one to create a utility function to mark as a given type.
    """

    def __init__(self, value, dispatch_type, coercible=True):
        self.value, self.type, self.coercible = value, dispatch_type, coercible

    def __getitem__(self, index):
        # Note the order: the dispatch type comes first, then the value.
        as_pair = (self.type, self.value)
        return as_pair[index]

    def __str__(self):
        template = "<{0}: type={1!r}, value={2!r}>"
        return template.format(type(self).__name__, self.type, self.value)

    # ``repr`` and ``str`` share one representation.
    __repr__ = __str__
359
360
def mark_as(dispatch_type):
    """
    Build a helper that wraps a value in a :obj:`Dispatchable` of the given type.

    Parameters
    ----------
    dispatch_type
        The dispatch type every marked value will carry.

    Returns
    -------
    functools.partial
        ``Dispatchable`` with ``dispatch_type`` pre-bound as a keyword argument.

    Examples
    --------
    >>> mark_int = mark_as(int)
    >>> mark_int(1)
    <Dispatchable: type=<class 'int'>, value=1>
    """
    marker = functools.partial(Dispatchable, dispatch_type=dispatch_type)
    return marker
372
373
def all_of_type(arg_type):
    """
    Decorate an argument extractor so every unmarked result gets ``arg_type``.

    Items that the extractor already wrapped in a :obj:`Dispatchable` are
    passed through unchanged; the result is always a tuple.

    Examples
    --------
    >>> @all_of_type(str)
    ... def f(a, b):
    ...     return a, Dispatchable(b, int)
    >>> f('a', 1)
    (<Dispatchable: type=<class 'str'>, value='a'>, <Dispatchable: type=<class 'int'>, value=1>)
    """

    def decorator(func):
        @functools.wraps(func)
        def marked(*args, **kwargs):
            def ensure_marked(item):
                if isinstance(item, Dispatchable):
                    return item
                return Dispatchable(item, arg_type)

            return tuple(map(ensure_marked, func(*args, **kwargs)))

        return marked

    return decorator
401
402
def wrap_single_convertor(convert_single):
    """
    Lift a single-element ``__ua_convert__`` so that it handles all elements.

    ``convert_single`` accepts a signature of (value, type, coerce). It is
    called once per dispatchable, in order. As soon as one call returns
    ``NotImplemented`` the whole operation is taken to be undefined and
    ``NotImplemented`` is returned without converting the rest.

    Returns
    -------
    callable
        A function ``(dispatchables, coerce)`` that returns the list of
        converted values, or ``NotImplemented``.
    """

    @functools.wraps(convert_single)
    def __ua_convert__(dispatchables, coerce):
        # Lazy generator: each element is converted only when the loop
        # reaches it, so the early exit below skips the remaining ones.
        outcomes = (
            convert_single(item.value, item.type, coerce and item.coercible)
            for item in dispatchables
        )

        results = []
        for outcome in outcomes:
            if outcome is NotImplemented:
                return NotImplemented
            results.append(outcome)

        return results

    return __ua_convert__
426