1import os
2import uuid
3import weakref
4import collections
5
6import numba
7from numba.core import types, errors, utils, config
8
9# Exported symbols
10from numba.core.typing.typeof import typeof_impl  # noqa: F401
11from numba.core.typing.templates import infer, infer_getattr  # noqa: F401
12from numba.core.imputils import (  # noqa: F401
13    lower_builtin, lower_getattr, lower_getattr_generic,  # noqa: F401
14    lower_setattr, lower_setattr_generic, lower_cast)  # noqa: F401
15from numba.core.datamodel import models   # noqa: F401
16from numba.core.datamodel import register_default as register_model  # noqa: F401, E501
17from numba.core.pythonapi import box, unbox, reflect, NativeValue  # noqa: F401
18from numba._helperlib import _import_cython_function  # noqa: F401
19from numba.core.serialize import ReduceMixin
20
21
def type_callable(func):
    """
    Decorator registering the decorated function as the typing logic for
    the callable *func*.

    *func* may be a callable object (typically a global) or a string naming
    a built-in operation (for example 'getitem' or '__array_wrap__').  The
    decorated function receives the typing context and must return the
    typer to use.  The decorated function itself is returned unchanged.

    Raises TypeError if *func* is neither callable nor a string.
    """
    from numba.core.typing.templates import (CallableTemplate, infer,
                                             infer_global)
    if not (callable(func) or isinstance(func, str)):
        raise TypeError("`func` should be a function or string")

    # Strings (built-in operation names) have no __name__; fall back to str().
    try:
        func_name = func.__name__
    except AttributeError:
        func_name = str(func)

    def decorate(typing_func):
        def generic(self):
            # Delegate to the user's typing function with the typing context.
            return typing_func(self.context)

        # Build a CallableTemplate subclass on the fly, keyed on *func*.
        template = type(
            "%s_CallableTemplate" % (func_name,),
            (CallableTemplate,),
            {'key': func, 'generic': generic},
        )
        infer(template)
        # Only real callables can be registered as typed globals.
        if callable(func):
            infer_global(func, types.Function(template))
        return typing_func

    return decorate
51
52
53# By default, an *overload* does not have a cpython wrapper because it is not
54# callable from python.
55_overload_default_jit_options = {'no_cpython_wrapper': True}
56
57
def overload(func, jit_options=None, strict=True, inline='never',
             prefer_literal=False):
    """
    A decorator marking the decorated function as typing and implementing
    *func* in nopython mode.

    The decorated function will have the same formal parameters as *func*
    and be passed the Numba types of those parameters.  It should return
    a function implementing *func* for the given types.

    Here is an example implementing len() for tuple types::

        @overload(len)
        def tuple_len(seq):
            if isinstance(seq, types.BaseTuple):
                n = len(seq)
                def len_impl(seq):
                    return n
                return len_impl

    Compiler options can be passed as a dictionary using the **jit_options**
    argument; they override the overload defaults (which disable the
    CPython wrapper).

    Overloading strictness (that the typing and implementing signatures match)
    is enforced by the **strict** keyword argument, it is recommended that this
    is set to True (default).

    To handle a function that accepts imprecise types, an overload
    definition can return a 2-tuple of ``(signature, impl_function)``, where
    the ``signature`` is a ``typing.Signature`` specifying the precise
    signature to be used; and ``impl_function`` is the same implementation
    function as in the simple case.

    The *inline* keyword argument determines whether the overload is inlined
    into the calling function and can be one of three values:

    * 'never' (default) - the overload is never inlined.
    * 'always' - the overload is always inlined.
    * a function that takes two arguments, both of which are instances of a
      namedtuple with fields:

        * func_ir
        * typemap
        * calltypes
        * signature

      The first argument holds the information from the caller, the second
      holds the information from the callee. The function should return a
      truthy value to inline the overload, which permits custom inlining
      rules (typical use might be cost models).

    The *prefer_literal* option allows users to control if literal types should
    be tried first or last. The default (`False`) is to use non-literal types.
    Implementations that can specialize based on literal values should set the
    option to `True`. Note, this option may be expanded in the near future to
    allow for more control (e.g. disabling non-literal types).
    """
    from numba.core.typing.templates import (make_overload_template,
                                             infer_global)

    # Start from a private copy of the defaults, then let user options
    # override them.  ``None`` replaces the former shared mutable ``{}``
    # default argument.
    opts = _overload_default_jit_options.copy()
    if jit_options is not None:
        opts.update(jit_options)

    def decorate(overload_func):
        template = make_overload_template(func, overload_func, opts, strict,
                                          inline, prefer_literal)
        infer(template)
        # String keys (built-in operations) cannot be typed as globals.
        if callable(func):
            infer_global(func, types.Function(template))
        return overload_func

    return decorate
127
128
def register_jitable(*args, **kwargs):
    """
    Register a regular python function that can be executed by the python
    interpreter and can be compiled into a nopython function when referenced
    by other jit'ed functions.  Can be used as::

        @register_jitable
        def foo(x, y):
            return x + y

    Or, with compiler options::

        @register_jitable(_nrt=False) # disable runtime allocation
        def foo(x, y):
            return x + y

    The ``inline`` keyword is forwarded to ``@overload`` as its *inline*
    argument; every other keyword argument is passed as ``jit_options``.
    The decorated function is returned unchanged.
    """
    # Split ``inline`` from the compiler options once, up front.  Popping it
    # inside ``wrap`` (as before) mutated shared state, so a decorator built
    # with ``inline=...`` applied it only to the first function it decorated.
    jit_options = dict(kwargs)
    inline = jit_options.pop('inline', 'never')

    def wrap(fn):
        # It is just a wrapper for @overload; ``overload`` copies
        # ``jit_options``, so sharing it between decorated functions is safe.
        @overload(fn, jit_options=jit_options, inline=inline, strict=False)
        def ov_wrap(*args, **kwargs):
            return fn
        return fn

    if kwargs or not args:
        # ``@register_jitable(...)`` or ``@register_jitable()``: hand back
        # the decorator, which will receive the function later.
        return wrap
    # Bare ``@register_jitable``: the function is the positional argument.
    return wrap(*args)
159
160
def overload_attribute(typ, attr, **kwargs):
    """
    Decorator that types and implements the read-only attribute *attr* of
    the Numba type *typ* in nopython mode.

    Any *kwargs* are forwarded to the underlying `@overload` call; the
    ``inline`` entry (default ``'never'``) is also given to the attribute
    template.

    Example, implementing ``.nbytes`` for array types::

        @overload_attribute(types.Array, 'nbytes')
        def array_nbytes(arr):
            def get(arr):
                return arr.size * arr.itemsize
            return get
    """
    # TODO implement setters
    from numba.core.typing.templates import make_overload_attribute_template

    inline_option = kwargs.get('inline', 'never')

    def decorate(overload_func):
        attr_template = make_overload_attribute_template(
            typ, attr, overload_func, inline=inline_option,
        )
        infer_getattr(attr_template)
        # Also register the getter itself as an overload so it can be
        # compiled.
        overload(overload_func, **kwargs)(overload_func)
        return overload_func

    return decorate
189
190
def overload_method(typ, attr, **kwargs):
    """
    Decorator that types and implements the method *attr* of the Numba type
    *typ* in nopython mode.

    Any *kwargs* are forwarded to the underlying `@overload` call; the
    ``inline`` (default ``'never'``) and ``prefer_literal`` (default
    ``False``) entries are also given to the method template.

    Example, implementing ``.take()`` for array types (``np`` being numpy)::

        @overload_method(types.Array, 'take')
        def array_take(arr, indices):
            if isinstance(indices, types.Array):
                def take_impl(arr, indices):
                    n = indices.shape[0]
                    res = np.empty(n, arr.dtype)
                    for i in range(n):
                        res[i] = arr[indices[i]]
                    return res
                return take_impl
    """
    from numba.core.typing.templates import make_overload_method_template

    inline_option = kwargs.get('inline', 'never')
    literal_first = kwargs.get('prefer_literal', False)

    def decorate(overload_func):
        method_template = make_overload_method_template(
            typ, attr, overload_func,
            inline=inline_option,
            prefer_literal=literal_first,
        )
        infer_getattr(method_template)
        # Also register the method implementation itself as an overload.
        overload(overload_func, **kwargs)(overload_func)
        return overload_func

    return decorate
224
225
def make_attribute_wrapper(typeclass, struct_attr, python_attr):
    """
    Make an automatic attribute wrapper exposing member named *struct_attr*
    as a read-only attribute named *python_attr*.
    The given *typeclass*'s model must be a StructModel subclass.

    This registers both the typing (an attribute template keyed on
    *typeclass*) and the lowering (a getattr implementation) for the
    attribute.  Nothing is returned.

    Raises TypeError if *typeclass* is not a subclass of ``types.Type``.
    Resolving the attribute for a type whose model is not a StructModel
    also raises TypeError.
    """
    from numba.core.typing.templates import AttributeTemplate
    from numba.core.datamodel import default_manager
    from numba.core.datamodel.models import StructModel
    from numba.core.imputils import impl_ret_borrowed
    from numba.core import cgutils

    if not isinstance(typeclass, type) or not issubclass(typeclass, types.Type):
        raise TypeError("typeclass should be a Type subclass, got %s"
                        % (typeclass,))

    def get_attr_fe_type(typ):
        """
        Get the Numba type of member *struct_attr* in *typ*.
        """
        model = default_manager.lookup(typ)
        if not isinstance(model, StructModel):
            # Name the actual public entry point in the message (it used to
            # mention a non-existent make_struct_attribute_wrapper()).
            raise TypeError("make_attribute_wrapper() needs a type "
                            "with a StructModel, but got %s" % (model,))
        return model.get_member_fe_type(struct_attr)

    @infer_getattr
    class StructAttribute(AttributeTemplate):
        key = typeclass

        def generic_resolve(self, typ, attr):
            # Only *python_attr* is resolved; other names fall through (None).
            if attr == python_attr:
                return get_attr_fe_type(typ)

    @lower_getattr(typeclass, python_attr)
    def struct_getattr_impl(context, builder, typ, val):
        val = cgutils.create_struct_proxy(typ)(context, builder, value=val)
        attrty = get_attr_fe_type(typ)
        attrval = getattr(val, struct_attr)
        # The member is borrowed from the struct, not a new reference.
        return impl_ret_borrowed(context, builder, attrty, attrval)
266
267
class _Intrinsic(ReduceMixin):
    """
    Placeholder callable object produced by the ``@intrinsic`` decorator.

    It only has meaning inside nopython-compiled code; calling it from the
    interpreter raises NotImplementedError.  Instances are picklable through
    the ReduceMixin protocol and are de-duplicated on unpickling by UUID.
    """
    # Weak registry of instances keyed by UUID string, consulted by _rebuild.
    _memo = weakref.WeakValueDictionary()
    # Strong references to the most recently UUID-tagged instances (bounded by
    # config.FUNCTION_CACHE_SIZE), keeping their _memo entries alive even when
    # nothing else references them.
    _recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE)

    __uuid = None

    def __init__(self, name, defn):
        self._name = name
        self._defn = defn

    @property
    def _uuid(self):
        """
        Per-instance UUID string used to avoid deserializing one instance
        more than once.

        Generated on first access only, for performance reasons.
        """
        if self.__uuid is None:
            self._set_uuid(str(uuid.uuid1()))
        return self.__uuid

    def _set_uuid(self, u):
        # A UUID may only be assigned once per instance.
        assert self.__uuid is None
        self.__uuid = u
        self._memo[u] = self
        self._recent.append(self)

    def _register(self):
        from numba.core.typing.templates import (make_intrinsic_template,
                                                 infer_global)

        intrinsic_template = make_intrinsic_template(self, self._defn,
                                                     self._name)
        infer(intrinsic_template)
        infer_global(self, types.Function(intrinsic_template))

    def __call__(self, *args, **kwargs):
        """
        Present only so the object looks callable to CPython; always raises.
        """
        raise NotImplementedError(
            '{0} is not usable in pure-python'.format(self))

    def __repr__(self):
        return "<intrinsic {0}>".format(self._name)

    def __deepcopy__(self, memo):
        # NOTE: Intrinsic are immutable and we don't need to copy.
        #       This is triggered from deepcopy of statements.
        return self

    def _reduce_states(self):
        """
        NOTE: part of ReduceMixin protocol
        """
        return dict(uuid=self._uuid, name=self._name, defn=self._defn)

    @classmethod
    def _rebuild(cls, uuid, name, defn):
        """
        NOTE: part of ReduceMixin protocol

        Reuse the live instance registered under *uuid* if there is one;
        otherwise create, register and tag a fresh instance.
        """
        existing = cls._memo.get(uuid)
        if existing is not None:
            return existing
        fresh = cls(name=name, defn=defn)
        fresh._register()
        fresh._set_uuid(uuid)
        return fresh
344
345
def intrinsic(*args, **kwargs):
    """
    A decorator marking the decorated function as typing and implementing
    *func* in nopython mode using the llvmlite IRBuilder API.  This is an escape
    hatch for expert users to build custom LLVM IR that will be inlined to
    the caller.

    The first argument to *func* is the typing context.  The remaining
    arguments correspond to the types of the arguments of the decorated
    function, and also serve as its formal parameters.  If *func* has the
    signature ``foo(typing_context, arg0, arg1)``, the decorated function
    will have the signature ``foo(arg0, arg1)``.

    *func* should return a 2-tuple of the expected type signature and a
    code-generation function that will be passed to ``lower_builtin``.
    For an unsupported operation, return None.

    Here is an example implementing a ``cast_int_to_byte_ptr`` that casts
    any integer to a byte pointer::

        @intrinsic
        def cast_int_to_byte_ptr(typingctx, src):
            # check for accepted types
            if isinstance(src, types.Integer):
                # create the expected type signature
                result_type = types.CPointer(types.uint8)
                sig = result_type(types.uintp)
                # defines the custom code generation
                def codegen(context, builder, signature, args):
                    # llvm IRBuilder code here
                    [src] = args
                    rtype = signature.return_type
                    llrtype = context.get_value_type(rtype)
                    return builder.inttoptr(src, llrtype)
                return sig, codegen
    """
    def _intrinsic(func):
        # Wrap the definition in an _Intrinsic and register its typing.
        llc = _Intrinsic(getattr(func, '__name__', str(func)), func, **kwargs)
        llc._register()
        return llc

    if kwargs:
        # ``@intrinsic(option=...)``: return a decorator that will receive
        # the definition function.
        # NOTE(review): _Intrinsic.__init__ accepts only (name, defn), so any
        # option passed here currently ends in a TypeError when applied.
        return _intrinsic
    # Bare ``@intrinsic``: the definition function is the positional arg.
    return _intrinsic(*args)
398
399
def get_cython_function_address(module_name, function_name):
    """
    Get the address of a Cython function.

    Parameters
    ----------
    module_name
        Name of the Cython module.
    function_name
        Name of the Cython function.

    Returns
    -------
    A Python int containing the address of the function.
    """
    address = _import_cython_function(module_name, function_name)
    return address
417
418
def include_path():
    """Return the C include directory path.

    This is the absolute path of the directory that contains the ``numba``
    package directory.
    """
    package_dir = os.path.dirname(numba.__file__)
    return os.path.abspath(os.path.dirname(package_dir))
425
426
def sentry_literal_args(pysig, literal_args, args, kwargs):
    """Ensure that the argument types in *args* and *kwargs* are literally
    typed for every parameter named in *literal_args*, given the python
    signature *pysig* of the function.

    If any requested parameter is not a ``types.Literal``, a
    ``ForceLiteralArg`` error naming the positions of all requested
    parameters is raised, carrying a folding function for the signature.
    Otherwise None is returned.

    This is equivalent to::

        SentryLiteralArgs(literal_args).for_pysig(pysig).bind(*args, **kwargs)
    """
    bound_arguments = pysig.bind(*args, **kwargs).arguments

    # Positions (in bound-argument order) of all parameters that must be
    # literal, satisfied or not.
    requested = {
        position
        for position, name in enumerate(bound_arguments)
        if name in literal_args
    }
    all_literal = all(
        isinstance(bound_arguments[name], types.Literal)
        for name in bound_arguments
        if name in literal_args
    )
    if all_literal:
        return

    # Some required literal arguments are missing.
    def folded(args, kwargs):
        # Fold call arguments into signature order.
        return tuple(pysig.bind(*args, **kwargs).arguments.values())

    raise errors.ForceLiteralArg(requested).bind_fold_arguments(folded)
456
457
class SentryLiteralArgs(collections.namedtuple(
        '_SentryLiteralArgs', 'literal_args')):
    """
    Parameters
    ----------
    literal_args : Sequence[str]
        A sequence of names for literal arguments

    Examples
    --------

    The following line::

        SentryLiteralArgs(literal_args).for_pysig(pysig).bind(*args, **kwargs)

    is equivalent to::

        sentry_literal_args(pysig, literal_args, args, kwargs)
    """
    def for_function(self, func):
        """Bind the sentry to the signature of the python function *func*.

        Parameters
        ----------
        func : Function
            A python function.

        Returns
        -------
        obj : BoundLiteralArgs
        """
        pysig = utils.pysignature(func)
        return self.for_pysig(pysig)

    def for_pysig(self, pysig):
        """Bind the sentry to the signature *pysig*.

        Parameters
        ----------
        pysig : inspect.Signature

        Returns
        -------
        obj : BoundLiteralArgs
        """
        return BoundLiteralArgs(pysig, self.literal_args)
507
508
class BoundLiteralArgs(collections.namedtuple(
        'BoundLiteralArgs', 'pysig literal_args')):
    """
    A literal-argument sentry bound to a python signature.

    Normally obtained from ``SentryLiteralArgs.for_pysig`` or
    ``SentryLiteralArgs.for_function``.
    """
    def bind(self, *args, **kwargs):
        """Check the given argument types against the bound signature.
        """
        return sentry_literal_args(self.pysig, self.literal_args,
                                   args, kwargs)
523
524
def is_jitted(function):
    """Return True if *function* is wrapped by one of the Numba @jit
    decorators (for example numba.jit or numba.njit), i.e. it is a
    ``Dispatcher`` instance.

    Useful for checking whether a function is already JIT decorated.
    """
    # Imported locally so that Dispatcher is not exported from this module.
    from numba.core.dispatcher import Dispatcher
    jitted = isinstance(function, Dispatcher)
    return jitted
536