1from collections import OrderedDict
2from collections.abc import Sequence
3import types as pytypes
4import inspect
5import operator
6
7from llvmlite import ir as llvmir
8
9from numba.core import types, utils, errors, cgutils, imputils
10from numba.core.registry import cpu_target
11from numba import njit
12from numba.core.typing import templates
13from numba.core.datamodel import default_manager, models
14from numba.experimental.jitclass import _box
15
16
17##############################################################################
18# Data model
19
20
class InstanceModel(models.StructModel):
    """
    Data model of a jitclass instance: a struct of two members, an NRT
    ``meminfo`` pointer and a ``data`` pointer to the instance payload
    (a ``ClassDataType`` struct).
    """
    def __init__(self, dmm, fe_typ):
        payload_ty = types.ClassDataType(fe_typ)
        # MemInfoPointer uses its `dtype` attribute to traverse nested NRT
        # MemInfo.  Nested MemInfo is handled by jitclass itself, so an
        # opaque dtype is given to MemInfoPointer instead; that way it does
        # not raise an exception for nested meminfo.
        opaque_dtype = types.Opaque('Opaque.' + str(payload_ty))
        super(InstanceModel, self).__init__(dmm, fe_typ, [
            ('meminfo', types.MemInfoPointer(opaque_dtype)),
            ('data', types.CPointer(payload_ty)),
        ])
34
35
class InstanceDataModel(models.StructModel):
    """
    Data model of a jitclass instance payload: one struct member per field
    of the class type's ``struct``, in order, with mangled member names.
    """
    def __init__(self, dmm, fe_typ):
        fields = fe_typ.class_type.struct
        members = []
        for field_name, field_ty in fields.items():
            members.append((_mangle_attr(field_name), field_ty))
        super(InstanceDataModel, self).__init__(dmm, fe_typ, members)
41
42
# Attach data models to the jitclass front-end types: an instance is a
# (meminfo, data pointer) struct, the instance payload is a struct of the
# spec'd fields, and the class object itself is represented opaquely.
default_manager.register(types.ClassInstanceType, InstanceModel)
default_manager.register(types.ClassDataType, InstanceDataModel)
default_manager.register(types.ClassType, models.OpaqueModel)
46
47
48def _mangle_attr(name):
49    """
50    Mangle attributes.
51    The resulting name does not startswith an underscore '_'.
52    """
53    return 'm_' + name
54
55
56##############################################################################
57# Class object
58
59_ctor_template = """
60def ctor({args}):
61    return __numba_cls_({args})
62"""
63
64
65def _getargs(fn_sig):
66    """
67    Returns list of positional and keyword argument names in order.
68    """
69    params = fn_sig.parameters
70    args = []
71    for k, v in params.items():
72        if (v.kind & v.POSITIONAL_OR_KEYWORD) == v.POSITIONAL_OR_KEYWORD:
73            args.append(k)
74        else:
75            msg = "%s argument type unsupported in jitclass" % v.kind
76            raise errors.UnsupportedError(msg)
77    return args
78
79
class JitClassType(type):
    """
    The type (metaclass) of any jitclass.

    A jitclass is created as a direct subclass of the user's original class
    and must carry a ``class_type`` attribute.  Calling it from Python goes
    through a jitted constructor wrapper generated by ``_set_init``.
    """
    def __new__(cls, name, bases, dct):
        if len(bases) != 1:
            raise TypeError("must have exactly one base class")
        base = bases[0]
        if isinstance(base, JitClassType):
            raise TypeError("cannot subclass from a jitclass")
        assert 'class_type' in dct, 'missing "class_type" attr'
        new_cls = type.__new__(cls, name, bases, dct)
        new_cls._set_init()
        return new_cls

    def _set_init(cls):
        """
        Generate a jitted wrapper (``cls._ctor``) for calling the constructor
        from pure Python, and store ``__init__``'s signature as
        ``cls._ctor_sig``.

        The wrapper takes ``__init__``'s parameters minus ``self``.
        ``__call__`` binds Python positional/keyword arguments against
        ``_ctor_sig`` (applying defaults) before invoking it.
        """
        init_func = cls.class_type.instance_type.methods['__init__']
        init_sig = utils.pysignature(init_func)
        # Parameter names in order, dropping the leading `self`.
        param_names = _getargs(init_sig)[1:]
        cls._ctor_sig = init_sig
        namespace = {"__numba_cls_": cls}
        exec(_ctor_template.format(args=', '.join(param_names)), namespace)
        cls._ctor = njit(namespace['ctor'])

    def __instancecheck__(cls, instance):
        # Only boxed jitclass instances can match, and only those whose
        # class type is exactly this jitclass's class type.
        if not isinstance(instance, _box.Box):
            return False
        return instance._numba_type_.class_type is cls.class_type

    def __call__(cls, *args, **kwargs):
        # `_ctor_sig` is __init__'s signature, whose first parameter is
        # `self`: bind None to it, then drop it when calling the wrapper.
        bound = cls._ctor_sig.bind(None, *args, **kwargs)
        bound.apply_defaults()
        ctor_args = bound.args[1:]
        return cls._ctor(*ctor_args, **bound.kwargs)
123
124
125##############################################################################
126# Registration utils
127
128def _validate_spec(spec):
129    for k, v in spec.items():
130        if not isinstance(k, str):
131            raise TypeError("spec keys should be strings, got %r" % (k,))
132        if not isinstance(v, types.Type):
133            raise TypeError("spec values should be Numba type instances, got %r"
134                            % (v,))
135
136
137def _fix_up_private_attr(clsname, spec):
138    """
139    Apply the same changes to dunder names as CPython would.
140    """
141    out = OrderedDict()
142    for k, v in spec.items():
143        if k.startswith('__') and not k.endswith('__'):
144            k = '_' + clsname + k
145        out[k] = v
146    return out
147
148
149def _add_linking_libs(context, call):
150    """
151    Add the required libs for the callable to allow inlining.
152    """
153    libs = getattr(call, "libs", ())
154    if libs:
155        context.add_linking_libs(libs)
156
157
def register_class_type(cls, spec, class_ctor, builder):
    """
    Internal function to create a jitclass.

    Args
    ----
    cls: the original class object (used as the prototype)
    spec: the structural specification contains the field types; either a
        mapping or a sequence of ``(name, type)`` pairs.  ``None`` is treated
        as an empty specification.
    class_ctor: the numba type to represent the jitclass
    builder: the internal jitclass builder

    Returns
    -------
    A new ``JitClassType`` subclass of ``cls``, registered with the CPU
    typing and target contexts.

    Raises
    ------
    TypeError: invalid spec, unsupported class members, or a property with a
        deleter.
    NameError: a spec field has the same name as a method or property.
    """
    # Normalize spec
    if spec is None:
        spec = OrderedDict()
    elif isinstance(spec, Sequence):
        spec = OrderedDict(spec)
    _validate_spec(spec)

    # Fix up private attribute names
    spec = _fix_up_private_attr(cls.__name__, spec)

    # Copy methods from base classes; walking the MRO in reverse lets
    # definitions in subclasses override those of their bases.
    clsdct = {}
    for basecls in reversed(inspect.getmro(cls)):
        clsdct.update(basecls.__dict__)

    methods, props, static_methods, others = {}, {}, {}, {}
    for k, v in clsdct.items():
        if isinstance(v, pytypes.FunctionType):
            methods[k] = v
        elif isinstance(v, property):
            props[k] = v
        elif isinstance(v, staticmethod):
            static_methods[k] = v
        else:
            others[k] = v

    # Check for name shadowing (sorted so the message is deterministic)
    shadowed = (set(methods) | set(props) | set(static_methods)) & set(spec)
    if shadowed:
        raise NameError(
            "name shadowing: {0}".format(', '.join(sorted(shadowed))))

    docstring = others.pop('__doc__', "")
    _drop_ignored_attrs(others)
    if others:
        msg = "class members are not yet supported: {0}"
        members = ', '.join(others.keys())
        raise TypeError(msg.format(members))

    for k, v in props.items():
        if v.fdel is not None:
            raise TypeError("deleter is not supported: {0}".format(k))

    jit_methods = {k: njit(v) for k, v in methods.items()}

    # Each property maps to a dict holding only the accessors it defines:
    # 'get' and/or 'set'.
    jit_props = {}
    for k, v in props.items():
        dct = {}
        if v.fget:
            dct['get'] = njit(v.fget)
        if v.fset:
            dct['set'] = njit(v.fset)
        jit_props[k] = dct

    jit_static_methods = {
        k: njit(v.__func__) for k, v in static_methods.items()}

    # Instantiate class type
    class_type = class_ctor(
        cls,
        ConstructorTemplate,
        spec,
        jit_methods,
        jit_props,
        jit_static_methods)

    jit_class_dct = dict(class_type=class_type, __doc__=docstring)
    jit_class_dct.update(jit_static_methods)
    cls = JitClassType(cls.__name__, (cls,), jit_class_dct)

    # Register resolution of the class object
    typingctx = cpu_target.typing_context
    typingctx.insert_global(cls, class_type)

    # Register class
    targetctx = cpu_target.target_context
    builder(class_type, typingctx, targetctx).register()

    return cls
245
246
class ConstructorTemplate(templates.AbstractTemplate):
    """
    Base class for jitclass constructor templates.

    Typing of a constructor call is delegated to the jitted ``__init__``,
    with a reference to the instance type prepended as ``self``.  The
    resulting signature returns the instance type instead of ``None``.
    """

    def generic(self, args, kws):
        inst_ty = self.key.instance_type
        init_func = inst_ty.jit_methods['__init__']
        self_and_args = (inst_ty.get_reference_type(),) + args
        init_sig = types.Dispatcher(init_func).get_call_type(
            self.context, self_and_args, kws)

        init_ret = init_sig.return_type
        if not isinstance(init_ret, types.NoneType):
            raise TypeError(
                f"__init__() should return None, not '{init_ret}'")

        # Drop `self`; the constructor itself yields the instance value.
        return templates.signature(inst_ty, *init_sig.args[1:])
267
268
269def _drop_ignored_attrs(dct):
270    # ignore anything defined by object
271    drop = set(['__weakref__',
272                '__module__',
273                '__dict__'])
274
275    if '__annotations__' in dct:
276        drop.add('__annotations__')
277
278    for k, v in dct.items():
279        if isinstance(v, (pytypes.BuiltinFunctionType,
280                          pytypes.BuiltinMethodType)):
281            drop.add(k)
282        elif getattr(v, '__objclass__', None) is object:
283            drop.add(k)
284
285    for k in drop:
286        del dct[k]
287
288
class ClassBuilder(object):
    """
    A jitclass builder for a mutable jitclass.  This will register
    typing and implementation hooks to the given typing and target contexts.
    """
    # Shared by every jitclass: one lowering registry, plus the set of method
    # names that already have a generic lowering registered.
    class_impl_registry = imputils.Registry()
    implemented_methods = set()

    def __init__(self, class_type, typingctx, targetctx):
        self.class_type = class_type
        self.typingctx = typingctx
        self.targetctx = targetctx

    def register(self):
        """
        Register to the frontend and backend.
        """
        # Register generic implementations for all jitclasses
        self._register_methods(self.class_impl_registry,
                               self.class_type.instance_type)
        # NOTE other registrations are done at the top-level
        # (see ctor_impl and attr_impl below)
        self.targetctx.install_registry(self.class_impl_registry)

    def _register_methods(self, registry, instance_type):
        """
        Register method implementations.
        This simply registers that the method names are valid methods.  Inside
        of imp() below we retrieve the actual method to run from the type of
        the receiver argument (i.e. self).
        """
        to_register = list(instance_type.jit_methods) + \
            list(instance_type.jit_static_methods)
        for meth in to_register:

            # There's no way to retrieve the particular method name
            # inside the implementation function, so we have to register a
            # specific closure for each different name
            if meth not in self.implemented_methods:
                self._implement_method(registry, meth)
                self.implemented_methods.add(meth)

    def _implement_method(self, registry, attr):
        # create a separate instance of imp method to avoid closure clashing
        def get_imp():
            def imp(context, builder, sig, args):
                instance_type = sig.args[0]

                if attr in instance_type.jit_methods:
                    method = instance_type.jit_methods[attr]
                elif attr in instance_type.jit_static_methods:
                    method = instance_type.jit_static_methods[attr]
                    # imp gets called as a method, where the first argument is
                    # self.  We drop this for a static method.
                    sig = sig.replace(args=sig.args[1:])
                    args = args[1:]
                else:
                    # The lowering is keyed on the method name only, so a
                    # class lacking this method must fail loudly here rather
                    # than with an UnboundLocalError below.
                    raise NotImplementedError(
                        'method {0!r} not implemented for {1}'.format(
                            attr, instance_type))

                disp_type = types.Dispatcher(method)
                call = context.get_function(disp_type, sig)
                out = call(builder, args)
                _add_linking_libs(context, call)
                return imputils.impl_ret_new_ref(context, builder,
                                                 sig.return_type, out)
            return imp

        def _getsetitem_gen(getset):
            _dunder_meth = "__%s__" % getset
            op = getattr(operator, getset)

            @templates.infer_global(op)
            class GetSetItem(templates.AbstractTemplate):
                def generic(self, args, kws):
                    instance = args[0]
                    if isinstance(instance, types.ClassInstanceType) and \
                            _dunder_meth in instance.jit_methods:
                        meth = instance.jit_methods[_dunder_meth]
                        disp_type = types.Dispatcher(meth)
                        sig = disp_type.get_call_type(self.context, args, kws)
                        return sig

            # lower both {g,s}etitem and __{g,s}etitem__ to catch the calls
            # from python and numba
            imputils.lower_builtin((types.ClassInstanceType, _dunder_meth),
                                   types.ClassInstanceType,
                                   types.VarArg(types.Any))(get_imp())
            imputils.lower_builtin(op,
                                   types.ClassInstanceType,
                                   types.VarArg(types.Any))(get_imp())

        # Only the exact dunders get the operator path.  Stripping '_' would
        # also catch ordinary methods such as `getitem` or `_setitem_`, which
        # would then never get their own (ClassInstanceType, name) lowering.
        if attr in ("__getitem__", "__setitem__"):
            _getsetitem_gen(attr[2:-2])
        else:
            registry.lower((types.ClassInstanceType, attr),
                           types.ClassInstanceType,
                           types.VarArg(types.Any))(get_imp())
385
386
@templates.infer_getattr
class ClassAttribute(templates.AttributeTemplate):
    """
    Attribute typing for jitclass instances.

    Checks struct fields, jitted methods, jitted static methods and jitted
    property getters, in that order.  Falling through (returning None) means
    this template does not resolve the attribute.
    """
    key = types.ClassInstanceType

    def generic_resolve(self, instance, attr):
        if attr in instance.struct:
            # It's a struct field => the type is well-known
            return instance.struct[attr]

        elif attr in instance.jit_methods:
            # It's a jitted method => typeinfer it
            meth = instance.jit_methods[attr]
            disp_type = types.Dispatcher(meth)

            class MethodTemplate(templates.AbstractTemplate):
                key = (self.key, attr)

                def generic(self, args, kws):
                    args = (instance,) + tuple(args)
                    sig = disp_type.get_call_type(self.context, args, kws)
                    return sig.as_method()

            return types.BoundFunction(MethodTemplate, instance)

        elif attr in instance.jit_static_methods:
            # It's a jitted static method => typeinfer it
            meth = instance.jit_static_methods[attr]
            disp_type = types.Dispatcher(meth)

            class StaticMethodTemplate(templates.AbstractTemplate):
                key = (self.key, attr)

                def generic(self, args, kws):
                    # Don't add instance as the first argument for a static
                    # method.
                    sig = disp_type.get_call_type(self.context, args, kws)
                    # sig itself does not include ClassInstanceType as it's
                    # first argument, so instead of calling sig.as_method()
                    # we insert the recvr. This is equivalent to
                    # sig.replace(args=(instance,) + sig.args).as_method().
                    return sig.replace(recvr=instance)

            return types.BoundFunction(StaticMethodTemplate, instance)

        elif attr in instance.jit_props:
            # It's a jitted property => typeinfer its getter.  A property
            # defined without a getter has no 'get' entry; leave it
            # unresolved instead of failing with a KeyError.
            impdct = instance.jit_props[attr]
            getter = impdct.get('get')
            if getter is None:
                return None
            disp_type = types.Dispatcher(getter)
            sig = disp_type.get_call_type(self.context, (instance,), {})
            return sig.return_type
438
439
@ClassBuilder.class_impl_registry.lower_getattr_generic(types.ClassInstanceType)
def get_attr_impl(context, builder, typ, value, attr):
    """
    Generic getattr() for @jitclass instances.

    A struct field is loaded from the instance's data payload and returned
    as a borrowed reference. A property calls its jitted getter with the
    instance and returns the result as a new reference.

    Raises NotImplementedError for any other attribute.
    """
    if attr in typ.struct:
        # It's a struct field
        inst = context.make_helper(builder, typ, value=value)
        data_pointer = inst.data
        data = context.make_data_helper(builder, typ.get_data_type(),
                                        ref=data_pointer)
        return imputils.impl_ret_borrowed(context, builder,
                                          typ.struct[attr],
                                          getattr(data, _mangle_attr(attr)))
    elif attr in typ.jit_props:
        # It's a jitted property: type and call its getter on the instance
        getter = typ.jit_props[attr]['get']
        dispatcher = types.Dispatcher(getter)
        sig = dispatcher.get_call_type(context.typing_context, [typ], {})
        call = context.get_function(dispatcher, sig)
        out = call(builder, [value])
        _add_linking_libs(context, call)
        return imputils.impl_ret_new_ref(context, builder, sig.return_type, out)

    raise NotImplementedError('attribute {0!r} not implemented'.format(attr))
466
467
@ClassBuilder.class_impl_registry.lower_setattr_generic(types.ClassInstanceType)
def set_attr_impl(context, builder, sig, args, attr):
    """
    Generic setattr() for @jitclass instances.

    A struct field is overwritten in the instance's data payload. The new
    value is incref'd and the previous value decref'd. A property calls its
    jitted setter with the instance and the value.

    Raises NotImplementedError for unknown attributes and for properties
    defined without a setter.
    """
    typ, valty = sig.args
    target, val = args

    if attr in typ.struct:
        # It's a struct member
        inst = context.make_helper(builder, typ, value=target)
        data_ptr = inst.data
        data = context.make_data_helper(builder, typ.get_data_type(),
                                        ref=data_ptr)

        # Get old value
        attr_type = typ.struct[attr]
        oldvalue = getattr(data, _mangle_attr(attr))

        # Store the new value and take a reference to it
        setattr(data, _mangle_attr(attr), val)
        context.nrt.incref(builder, attr_type, val)

        # Release the reference held by the old value
        context.nrt.decref(builder, attr_type, oldvalue)

    elif attr in typ.jit_props:
        # It's a jitted property; a read-only property has no 'set' entry.
        setter = typ.jit_props[attr].get('set')
        if setter is None:
            raise NotImplementedError(
                'attribute {0!r} is a property without a setter'.format(attr))
        disp_type = types.Dispatcher(setter)
        setter_sig = disp_type.get_call_type(context.typing_context,
                                             (typ, valty), {})
        call = context.get_function(disp_type, setter_sig)
        call(builder, (target, val))
        _add_linking_libs(context, call)
    else:
        raise NotImplementedError(
            'attribute {0!r} not implemented'.format(attr))
506
507
def imp_dtor(context, module, instance_type):
    """
    Return the destructor function ``_Dtor.<instance_type.name>`` in
    ``module``, defining its body on first request.

    The function has the signature ``void(void*, size_t, void*)``. It casts
    its first argument to a pointer to the instance data struct and decrefs
    the whole struct value.
    """
    voidptr_ty = context.get_value_type(types.voidptr)
    size_ty = context.get_value_type(types.uintp)
    fn_ty = llvmir.FunctionType(llvmir.VoidType(),
                                [voidptr_ty, size_ty, voidptr_ty])

    dtor_fn = module.get_or_insert_function(
        fn_ty, name="_Dtor.{0}".format(instance_type.name))
    if not dtor_fn.is_declaration:
        # Body already emitted into this module; reuse it.
        return dtor_fn

    irb = llvmir.IRBuilder(dtor_fn.append_basic_block())

    payload_fe_type = instance_type.get_data_type()
    payload_ll_type = context.get_value_type(payload_fe_type)

    payload_ptr = irb.bitcast(dtor_fn.args[0], payload_ll_type.as_pointer())
    payload = context.make_helper(irb, payload_fe_type, ref=payload_ptr)

    context.nrt.decref(irb, payload_fe_type, payload._getvalue())

    irb.ret_void()
    return dtor_fn
532
533
@ClassBuilder.class_impl_registry.lower(types.ClassType,
                                        types.VarArg(types.Any))
def ctor_impl(context, builder, sig, args):
    """
    Generic constructor (__new__) for jitclasses.

    The steps are:
    - Allocate an NRT meminfo sized for the instance's data struct, with
      ``imp_dtor`` as its destructor, and zero-fill the data.
    - Build the instance struct from the meminfo and the data pointer.
    - Call the jitted ``__init__`` with the instance followed by ``args``.
    - Return the instance as a new reference.
    """
    # Allocate the instance
    inst_typ = sig.return_type
    alloc_type = context.get_data_type(inst_typ.get_data_type())
    alloc_size = context.get_abi_sizeof(alloc_type)

    meminfo = context.nrt.meminfo_alloc_dtor(
        builder,
        context.get_constant(types.uintp, alloc_size),
        imp_dtor(context, builder.module, inst_typ),
    )
    data_pointer = context.nrt.meminfo_data(builder, meminfo)
    data_pointer = builder.bitcast(data_pointer,
                                   alloc_type.as_pointer())

    # Nullify all data, so every field starts as a null value before
    # __init__ assigns it.
    builder.store(cgutils.get_null_value(alloc_type),
                  data_pointer)

    inst_struct = context.make_helper(builder, inst_typ)
    inst_struct.meminfo = meminfo
    inst_struct.data = data_pointer

    # Call the jitted __init__
    # TODO: extract the following into a common util
    # __init__ receives the new instance as `self` followed by the
    # constructor arguments; its signature is forced to return void.
    init_sig = (sig.return_type,) + sig.args

    init = inst_typ.jit_methods['__init__']
    disp_type = types.Dispatcher(init)
    call = context.get_function(disp_type, types.void(*init_sig))
    _add_linking_libs(context, call)
    realargs = [inst_struct._getvalue()] + list(args)
    call(builder, realargs)

    # Prepare return value
    ret = inst_struct._getvalue()

    return imputils.impl_ret_new_ref(context, builder, inst_typ, ret)
577