from collections import OrderedDict
from collections.abc import Sequence
import types as pytypes
import inspect
import operator

from llvmlite import ir as llvmir

from numba.core import types, utils, errors, cgutils, imputils
from numba.core.registry import cpu_target
from numba import njit
from numba.core.typing import templates
from numba.core.datamodel import default_manager, models
from numba.experimental.jitclass import _box


##############################################################################
# Data model


class InstanceModel(models.StructModel):
    """
    Data model of a jitclass instance (``types.ClassInstanceType``): a struct
    holding the NRT meminfo that owns the instance data and a typed pointer
    to that data.
    """

    def __init__(self, dmm, fe_typ):
        cls_data_ty = types.ClassDataType(fe_typ)
        # MemInfoPointer uses the `dtype` attribute to traverse for nested
        # NRT MemInfo. Since we handle nested NRT MemInfo ourselves (see
        # imp_dtor below), we give MemInfoPointer an opaque dtype instead so
        # that it does not raise an exception for nested meminfo.
        dtype = types.Opaque('Opaque.' + str(cls_data_ty))
        members = [
            ('meminfo', types.MemInfoPointer(dtype)),
            ('data', types.CPointer(cls_data_ty)),
        ]
        super(InstanceModel, self).__init__(dmm, fe_typ, members)


class InstanceDataModel(models.StructModel):
    """
    Data model of the instance payload (``types.ClassDataType``): one struct
    member per entry of the class type's ``struct`` mapping, with member
    names mangled by ``_mangle_attr``.
    """

    def __init__(self, dmm, fe_typ):
        clsty = fe_typ.class_type
        members = [(_mangle_attr(k), v) for k, v in clsty.struct.items()]
        super(InstanceDataModel, self).__init__(dmm, fe_typ, members)


default_manager.register(types.ClassInstanceType, InstanceModel)
default_manager.register(types.ClassDataType, InstanceDataModel)
default_manager.register(types.ClassType, models.OpaqueModel)


def _mangle_attr(name):
    """
    Mangle attributes.
    The resulting name does not start with an underscore '_'.
    """
    return 'm_' + name


##############################################################################
# Class object

_ctor_template = """
def ctor({args}):
    return __numba_cls_({args})
"""


def _getargs(fn_sig):
    """
    Returns list of positional-or-keyword argument names in order.

    Raises
    ------
    errors.UnsupportedError
        If any parameter is not of kind ``POSITIONAL_OR_KEYWORD`` (i.e. it is
        positional-only, keyword-only, ``*args`` or ``**kwargs``).
    """
    params = fn_sig.parameters
    args = []
    for k, v in params.items():
        # ``Parameter.kind`` is an enumeration, not a bit flag, so compare for
        # equality: a bitwise test against POSITIONAL_OR_KEYWORD (== 1) would
        # wrongly accept KEYWORD_ONLY (== 3), which the positional-only ctor
        # wrapper cannot forward.
        if v.kind == v.POSITIONAL_OR_KEYWORD:
            args.append(k)
        else:
            msg = "%s argument type unsupported in jitclass" % v.kind
            raise errors.UnsupportedError(msg)
    return args


class JitClassType(type):
    """
    The type of any jitclass.
    """
    def __new__(cls, name, bases, dct):
        if len(bases) != 1:
            raise TypeError("must have exactly one base class")
        [base] = bases
        if isinstance(base, JitClassType):
            raise TypeError("cannot subclass from a jitclass")
        assert 'class_type' in dct, 'missing "class_type" attr'
        outcls = type.__new__(cls, name, bases, dct)
        outcls._set_init()
        return outcls

    def _set_init(cls):
        """
        Generate a wrapper for calling the constructor from pure Python.
        The generated wrapper is jitted and forwards every argument
        positionally to the class constructor.
        """
        init = cls.class_type.instance_type.methods['__init__']
        init_sig = utils.pysignature(init)
        # get positional and keyword arguments
        # offset by one to exclude the `self` arg
        args = _getargs(init_sig)[1:]
        cls._ctor_sig = init_sig
        ctor_source = _ctor_template.format(args=', '.join(args))
        glbls = {"__numba_cls_": cls}
        exec(ctor_source, glbls)
        ctor = glbls['ctor']
        cls._ctor = njit(ctor)

    def __instancecheck__(cls, instance):
        """
        Return True if *instance* is a boxed jitclass instance whose class
        type is this jitclass's class type.
        """
        if isinstance(instance, _box.Box):
            return instance._numba_type_.class_type is cls.class_type
        return False

    def __call__(cls, *args, **kwargs):
        """
        Construct an instance from Python: bind *args*/*kwargs* against the
        ``__init__`` signature (filling in defaults) and call the jitted
        constructor wrapper.
        """
        # The first parameter of _ctor_sig is `__init__`'s `self`, which
        # here is bound to None and then skipped when invoking the
        # constructor.
        bind = cls._ctor_sig.bind(None, *args, **kwargs)
        bind.apply_defaults()
        return cls._ctor(*bind.args[1:], **bind.kwargs)


##############################################################################
# Registration utils

def _validate_spec(spec):
    """
    Check that every key of *spec* is a ``str`` and every value a Numba
    ``types.Type`` instance; raise TypeError otherwise.
    """
    for k, v in spec.items():
        if not isinstance(k, str):
            raise TypeError("spec keys should be strings, got %r" % (k,))
        if not isinstance(v, types.Type):
            raise TypeError("spec values should be Numba type instances, got %r"
                            % (v,))


def _fix_up_private_attr(clsname, spec):
    """
    Apply the same changes to dunder names as CPython would.

    A private name ``__x`` (two leading underscores, not ending in two
    underscores) becomes ``_<clsname>__x`` where leading underscores are
    first stripped from the class name. If the class name consists only of
    underscores, no mangling is performed.
    """
    # CPython strips leading underscores from the class name when mangling,
    # so ``self.__x`` inside ``class _Foo`` compiles to ``self._Foo__x``.
    stripped = clsname.lstrip('_')
    out = OrderedDict()
    for k, v in spec.items():
        if stripped and k.startswith('__') and not k.endswith('__'):
            k = '_' + stripped + k
        out[k] = v
    return out


def _add_linking_libs(context, call):
    """
    Add the required libs for the callable to allow inlining.
    """
    libs = getattr(call, "libs", ())
    if libs:
        context.add_linking_libs(libs)


def register_class_type(cls, spec, class_ctor, builder):
    """
    Internal function to create a jitclass.

    Args
    ----
    cls: the original class object (used as the prototype)
    spec: the structural specification containing the field types; either a
          mapping or a sequence of (name, numba type) pairs. ``None`` means
          no fields.
    class_ctor: the numba type to represent the jitclass
    builder: the internal jitclass builder

    Returns
    -------
    The new ``JitClassType`` class object (a subclass of *cls*).

    Raises
    ------
    TypeError for invalid specs, unsupported class members or property
    deleters; NameError when a field shadows a method or property.
    """
    # Normalize spec
    if spec is None:
        spec = OrderedDict()
    elif isinstance(spec, Sequence):
        spec = OrderedDict(spec)
    _validate_spec(spec)

    # Fix up private attribute names
    spec = _fix_up_private_attr(cls.__name__, spec)

    # Copy methods from base classes; walking the MRO in reverse lets
    # subclasses override definitions inherited from their bases.
    clsdct = {}
    for basecls in reversed(inspect.getmro(cls)):
        clsdct.update(basecls.__dict__)

    methods, props, static_methods, others = {}, {}, {}, {}
    for k, v in clsdct.items():
        if isinstance(v, pytypes.FunctionType):
            methods[k] = v
        elif isinstance(v, property):
            props[k] = v
        elif isinstance(v, staticmethod):
            static_methods[k] = v
        else:
            others[k] = v

    # Check for name shadowing
    shadowed = (set(methods) | set(props) | set(static_methods)) & set(spec)
    if shadowed:
        # sorted() keeps the error message deterministic
        raise NameError(
            "name shadowing: {0}".format(', '.join(sorted(shadowed))))

    docstring = others.pop('__doc__', "")
    _drop_ignored_attrs(others)
    if others:
        msg = "class members are not yet supported: {0}"
        members = ', '.join(others.keys())
        raise TypeError(msg.format(members))

    for k, v in props.items():
        if v.fdel is not None:
            raise TypeError("deleter is not supported: {0}".format(k))

    jit_methods = {k: njit(v) for k, v in methods.items()}

    jit_props = {}
    for k, v in props.items():
        dct = {}
        if v.fget:
            dct['get'] = njit(v.fget)
        if v.fset:
            dct['set'] = njit(v.fset)
        jit_props[k] = dct

    jit_static_methods = {
        k: njit(v.__func__) for k, v in static_methods.items()}

    # Instantiate class type
    class_type = class_ctor(
        cls,
        ConstructorTemplate,
        spec,
        jit_methods,
        jit_props,
        jit_static_methods)

    jit_class_dct = dict(class_type=class_type, __doc__=docstring)
    jit_class_dct.update(jit_static_methods)
    cls = JitClassType(cls.__name__, (cls,), jit_class_dct)

    # Register resolution of the class object
    typingctx = cpu_target.typing_context
    typingctx.insert_global(cls, class_type)

    # Register class
    targetctx = cpu_target.target_context
    builder(class_type, typingctx, targetctx).register()

    return cls


class ConstructorTemplate(templates.AbstractTemplate):
    """
    Base class for jitclass constructor templates.

    Typing a call to the class object is redirected to the jitted
    ``__init__``; the resulting signature returns the instance type and
    omits the ``self`` argument.
    """

    def generic(self, args, kws):
        # Redirect resolution to __init__
        instance_type = self.key.instance_type
        ctor = instance_type.jit_methods['__init__']
        boundargs = (instance_type.get_reference_type(),) + args
        disp_type = types.Dispatcher(ctor)
        sig = disp_type.get_call_type(self.context, boundargs, kws)

        if not isinstance(sig.return_type, types.NoneType):
            raise TypeError(
                f"__init__() should return None, not '{sig.return_type}'")

        # Actual constructor returns an instance value (not None)
        out = templates.signature(instance_type, *sig.args[1:])
        return out


def _drop_ignored_attrs(dct):
    """
    Remove, in place, the class-dict entries that jitclass ignores: implicit
    attributes CPython creates for classes, builtin callables, anything
    defined by ``object`` and an implicit ``__hash__ = None``.
    """
    # ignore anything defined by object
    drop = set(['__weakref__',
                '__module__',
                '__dict__',
                '__annotations__',
                # implicit class attributes added by CPython 3.13+
                '__static_attributes__',
                '__firstlineno__'])

    for k, v in dct.items():
        if isinstance(v, (pytypes.BuiltinFunctionType,
                          pytypes.BuiltinMethodType)):
            drop.add(k)
        elif getattr(v, '__objclass__', None) is object:
            drop.add(k)

    # A class that defines __eq__ but not __hash__ has __hash__ implicitly
    # set to None; that is not a user class member, so ignore it.
    if '__hash__' in dct and dct['__hash__'] is None:
        drop.add('__hash__')

    for k in drop:
        # Not every name is guaranteed to be present (e.g. classes using
        # __slots__ have no __dict__/__weakref__ entries).
        dct.pop(k, None)


class ClassBuilder(object):
    """
    A jitclass builder for a mutable jitclass.  This will register
    typing and implementation hooks to the given typing and target contexts.
    """
    # Shared by all builders; the generic getattr/setattr/constructor
    # implementations below are registered into it at module level.
    class_impl_registry = imputils.Registry()
    # Method names that already have a lowering closure registered; shared
    # across builders so each name is implemented only once.
    implemented_methods = set()

    def __init__(self, class_type, typingctx, targetctx):
        self.class_type = class_type
        self.typingctx = typingctx
        self.targetctx = targetctx

    def register(self):
        """
        Register to the frontend and backend.
        """
        # Register generic implementations for all jitclasses
        self._register_methods(self.class_impl_registry,
                               self.class_type.instance_type)
        # NOTE other registrations are done at the top-level
        # (see ctor_impl and attr_impl below)
        self.targetctx.install_registry(self.class_impl_registry)

    def _register_methods(self, registry, instance_type):
        """
        Register method implementations.
        This simply registers that the method names are valid methods.  Inside
        of imp() below we retrieve the actual method to run from the type of
        the receiver argument (i.e. self).
        """
        to_register = list(instance_type.jit_methods) + \
            list(instance_type.jit_static_methods)
        for meth in to_register:

            # There's no way to retrieve the particular method name
            # inside the implementation function, so we have to register a
            # specific closure for each different name
            if meth not in self.implemented_methods:
                self._implement_method(registry, meth)
                self.implemented_methods.add(meth)

    def _implement_method(self, registry, attr):
        """
        Register a lowering for method name *attr* on any jitclass instance.
        ``__getitem__``/``__setitem__`` additionally get typing and lowering
        for the corresponding ``operator`` function.
        """
        # create a separate instance of imp method to avoid closure clashing
        def get_imp():
            def imp(context, builder, sig, args):
                instance_type = sig.args[0]

                if attr in instance_type.jit_methods:
                    method = instance_type.jit_methods[attr]
                elif attr in instance_type.jit_static_methods:
                    method = instance_type.jit_static_methods[attr]
                    # imp gets called as a method, where the first argument is
                    # self.  We drop this for a static method.
                    sig = sig.replace(args=sig.args[1:])
                    args = args[1:]

                disp_type = types.Dispatcher(method)
                call = context.get_function(disp_type, sig)
                out = call(builder, args)
                _add_linking_libs(context, call)
                return imputils.impl_ret_new_ref(context, builder,
                                                 sig.return_type, out)
            return imp

        def _getsetitem_gen(getset):
            _dunder_meth = "__%s__" % getset
            op = getattr(operator, getset)

            @templates.infer_global(op)
            class GetSetItem(templates.AbstractTemplate):
                def generic(self, args, kws):
                    instance = args[0]
                    if isinstance(instance, types.ClassInstanceType) and \
                            _dunder_meth in instance.jit_methods:
                        meth = instance.jit_methods[_dunder_meth]
                        disp_type = types.Dispatcher(meth)
                        sig = disp_type.get_call_type(self.context, args, kws)
                        return sig

            # lower both {g,s}etitem and __{g,s}etitem__ to catch the calls
            # from python and numba
            imputils.lower_builtin((types.ClassInstanceType, _dunder_meth),
                                   types.ClassInstanceType,
                                   types.VarArg(types.Any))(get_imp())
            imputils.lower_builtin(op,
                                   types.ClassInstanceType,
                                   types.VarArg(types.Any))(get_imp())

        dunder_stripped = attr.strip('_')
        if dunder_stripped in ("getitem", "setitem"):
            _getsetitem_gen(dunder_stripped)
        else:
            registry.lower((types.ClassInstanceType, attr),
                           types.ClassInstanceType,
                           types.VarArg(types.Any))(get_imp())


@templates.infer_getattr
class ClassAttribute(templates.AttributeTemplate):
    """
    Attribute typing for jitclass instances: struct fields, jitted methods,
    jitted static methods and jitted property getters.
    """
    key = types.ClassInstanceType

    def generic_resolve(self, instance, attr):
        if attr in instance.struct:
            # It's a struct field => the type is well-known
            return instance.struct[attr]

        elif attr in instance.jit_methods:
            # It's a jitted method => typeinfer it
            meth = instance.jit_methods[attr]
            disp_type = types.Dispatcher(meth)

            class MethodTemplate(templates.AbstractTemplate):
                key = (self.key, attr)

                def generic(self, args, kws):
                    args = (instance,) + tuple(args)
                    sig = disp_type.get_call_type(self.context, args, kws)
                    return sig.as_method()

            return types.BoundFunction(MethodTemplate, instance)

        elif attr in instance.jit_static_methods:
            # It's a jitted static method => typeinfer it
            meth = instance.jit_static_methods[attr]
            disp_type = types.Dispatcher(meth)

            class StaticMethodTemplate(templates.AbstractTemplate):
                key = (self.key, attr)

                def generic(self, args, kws):
                    # Don't add instance as the first argument for a static
                    # method.
                    sig = disp_type.get_call_type(self.context, args, kws)
                    # sig itself does not include ClassInstanceType as its
                    # first argument, so instead of calling sig.as_method()
                    # we insert the recvr. This is equivalent to
                    # sig.replace(args=(instance,) + sig.args).as_method().
                    return sig.replace(recvr=instance)

            return types.BoundFunction(StaticMethodTemplate, instance)

        elif attr in instance.jit_props:
            # It's a jitted property => typeinfer its getter
            impdct = instance.jit_props[attr]
            getter = impdct['get']
            disp_type = types.Dispatcher(getter)
            sig = disp_type.get_call_type(self.context, (instance,), {})
            return sig.return_type


@ClassBuilder.class_impl_registry.lower_getattr_generic(types.ClassInstanceType)
def get_attr_impl(context, builder, typ, value, attr):
    """
    Generic getattr() for @jitclass instances.

    Struct fields are loaded from the instance data and returned as a
    borrowed reference; properties call the jitted getter and return a new
    reference. Any other attribute raises NotImplementedError.
    """
    if attr in typ.struct:
        # It's a struct field
        inst = context.make_helper(builder, typ, value=value)
        data_pointer = inst.data
        data = context.make_data_helper(builder, typ.get_data_type(),
                                        ref=data_pointer)
        return imputils.impl_ret_borrowed(context, builder,
                                          typ.struct[attr],
                                          getattr(data, _mangle_attr(attr)))
    elif attr in typ.jit_props:
        # It's a jitted property
        getter = typ.jit_props[attr]['get']
        dispatcher = types.Dispatcher(getter)
        sig = dispatcher.get_call_type(context.typing_context, [typ], {})
        call = context.get_function(dispatcher, sig)
        out = call(builder, [value])
        _add_linking_libs(context, call)
        return imputils.impl_ret_new_ref(context, builder, sig.return_type, out)

    raise NotImplementedError('attribute {0!r} not implemented'.format(attr))


@ClassBuilder.class_impl_registry.lower_setattr_generic(types.ClassInstanceType)
def set_attr_impl(context, builder, sig, args, attr):
    """
    Generic setattr() for @jitclass instances.

    Struct fields are overwritten in the instance data (the new value is
    increfed, the old one decrefed); properties call the jitted setter. Any
    other attribute raises NotImplementedError.
    """
    typ, valty = sig.args
    target, val = args

    if attr in typ.struct:
        # It's a struct member
        inst = context.make_helper(builder, typ, value=target)
        data_ptr = inst.data
        data = context.make_data_helper(builder, typ.get_data_type(),
                                        ref=data_ptr)

        # Get old value
        attr_type = typ.struct[attr]
        oldvalue = getattr(data, _mangle_attr(attr))

        # Store new value; incref it before dropping the old one so that
        # assigning a field to itself stays safe.
        setattr(data, _mangle_attr(attr), val)
        context.nrt.incref(builder, attr_type, val)

        # Delete old value
        context.nrt.decref(builder, attr_type, oldvalue)

    elif attr in typ.jit_props:
        # It's a jitted property
        setter = typ.jit_props[attr]['set']
        disp_type = types.Dispatcher(setter)
        sig = disp_type.get_call_type(context.typing_context,
                                      (typ, valty), {})
        call = context.get_function(disp_type, sig)
        call(builder, (target, val))
        _add_linking_libs(context, call)
    else:
        raise NotImplementedError(
            'attribute {0!r} not implemented'.format(attr))


def imp_dtor(context, module, instance_type):
    """
    Get, or define on first use, the NRT destructor for *instance_type*.

    The LLVM function ``_Dtor.<instance type name>`` with signature
    ``void(void*, size_t, void*)`` is looked up in *module*; if it is only a
    declaration, a body is emitted that casts the first argument to a
    pointer to the instance data struct and calls ``context.nrt.decref`` on
    the loaded struct value.
    """
    llvoidptr = context.get_value_type(types.voidptr)
    llsize = context.get_value_type(types.uintp)
    dtor_ftype = llvmir.FunctionType(llvmir.VoidType(),
                                     [llvoidptr, llsize, llvoidptr])

    fname = "_Dtor.{0}".format(instance_type.name)
    dtor_fn = module.get_or_insert_function(dtor_ftype,
                                            name=fname)
    if dtor_fn.is_declaration:
        # Define
        builder = llvmir.IRBuilder(dtor_fn.append_basic_block())

        alloc_fe_type = instance_type.get_data_type()
        alloc_type = context.get_value_type(alloc_fe_type)

        ptr = builder.bitcast(dtor_fn.args[0], alloc_type.as_pointer())
        data = context.make_helper(builder, alloc_fe_type, ref=ptr)

        context.nrt.decref(builder, alloc_fe_type, data._getvalue())

        builder.ret_void()

    return dtor_fn


@ClassBuilder.class_impl_registry.lower(types.ClassType,
                                        types.VarArg(types.Any))
def ctor_impl(context, builder, sig, args):
    """
    Generic constructor (__new__) for jitclasses.

    Allocates the instance data through an NRT meminfo whose destructor is
    ``imp_dtor``, zero-fills it, runs the jitted ``__init__`` on the new
    instance and returns it as a new reference.
    """
    # Allocate the instance
    inst_typ = sig.return_type
    alloc_type = context.get_data_type(inst_typ.get_data_type())
    alloc_size = context.get_abi_sizeof(alloc_type)

    meminfo = context.nrt.meminfo_alloc_dtor(
        builder,
        context.get_constant(types.uintp, alloc_size),
        imp_dtor(context, builder.module, inst_typ),
    )
    data_pointer = context.nrt.meminfo_data(builder, meminfo)
    data_pointer = builder.bitcast(data_pointer,
                                   alloc_type.as_pointer())

    # Nullify all data
    builder.store(cgutils.get_null_value(alloc_type),
                  data_pointer)

    inst_struct = context.make_helper(builder, inst_typ)
    inst_struct.meminfo = meminfo
    inst_struct.data = data_pointer

    # Call the jitted __init__
    # TODO: extract the following into a common util
    init_sig = (sig.return_type,) + sig.args

    init = inst_typ.jit_methods['__init__']
    disp_type = types.Dispatcher(init)
    call = context.get_function(disp_type, types.void(*init_sig))
    _add_linking_libs(context, call)
    realargs = [inst_struct._getvalue()] + list(args)
    call(builder, realargs)

    # Prepare return value
    ret = inst_struct._getvalue()

    return imputils.impl_ret_new_ref(context, builder, inst_typ, ret)