# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
import ctypes
from ..base import _LIB, check_call
from .types import ArgTypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .ndarray import _register_ndarray, NDArrayBase


ObjectHandle = ctypes.c_void_p

# FFI helper used by ObjectBase.__init_handle_by_constructor__ to create a
# handle from a constructor function. It is None here; presumably it is
# assigned by another module at import time (not set anywhere in this file).
__init_by_constructor__ = None

# Maps a runtime type index (as reported by TVMObjectGetTypeIndex) to the
# Python class registered for it via _register_object.
OBJECT_TYPE = {}

# Fallback Python class for type indices with no registered class; also used
# as the wrapper type for PyNativeObject instances. Installed by
# _set_class_object.
_CLASS_OBJECT = None


def _set_class_object(object_class):
    """Install the default object class stored in ``_CLASS_OBJECT``.

    Parameters
    ----------
    object_class : type
        Class used as the fallback in ``_return_object`` and as the
        handle-holding wrapper for ``PyNativeObject`` instances.
    """
    global _CLASS_OBJECT
    _CLASS_OBJECT = object_class


def _register_object(index, cls):
    """Register ``cls`` as the Python class for runtime type ``index``.

    Subclasses of ``NDArrayBase`` are delegated to ``_register_ndarray``
    and are not recorded in ``OBJECT_TYPE``.

    Parameters
    ----------
    index : int
        Runtime type index.

    cls : type
        Python class to instantiate for objects of that type.
    """
    if issubclass(cls, NDArrayBase):
        _register_ndarray(index, cls)
        return
    OBJECT_TYPE[index] = cls


def _return_object(x):
    """Convert an FFI value carrying an object handle into a Python object.

    Parameters
    ----------
    x : ctypes value
        Only its ``v_handle`` field is read (presumably a TVMValue union;
        confirm against callers in .types).

    Returns
    -------
    obj : object
        An instance of the class registered for the handle's runtime type
        index, or of ``_CLASS_OBJECT`` if none is registered. For
        ``PyNativeObject`` subclasses, the result of
        ``cls.__from_tvm_object__`` applied to a ``_CLASS_OBJECT`` wrapper.
    """
    handle = x.v_handle
    if not isinstance(handle, ObjectHandle):
        handle = ObjectHandle(handle)
    tindex = ctypes.c_uint()
    check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
    cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
    if issubclass(cls, PyNativeObject):
        # PyNativeObject subclasses do not store ``handle`` themselves; they
        # keep it inside a separate _CLASS_OBJECT wrapper, so build that
        # wrapper first and let the class construct itself from it.
        obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
        obj.handle = handle
        return cls.__from_tvm_object__(cls, obj)
    # Avoid calling __init__ of cls, instead directly call __new__
    # This allows child class to implement their own __init__
    obj = cls.__new__(cls)
    obj.handle = handle
    return obj


# Route object handles returned from / passed back by C functions through
# _return_object.
RETURN_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _wrap_arg_func(
    _return_object, ArgTypeCode.OBJECT_HANDLE
)

C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
    _return_object, ArgTypeCode.OBJECT_RVALUE_REF_ARG
)


class PyNativeObject:
    """Base class of all TVM objects that also subclass python's builtin types."""

    __slots__ = []

    def __init_tvm_object_by_constructor__(self, fconstructor, *args):
        """Initialize the internal tvm_object by calling constructor function.

        Parameters
        ----------
        fconstructor : Function
            Constructor function.

        args: list of objects
            The arguments to the constructor

        Note
        ----
        We have a special calling convention to call constructor functions.
        So the return object is directly set into the object
        """
        # pylint: disable=assigning-non-slot
        obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
        obj.__init_handle_by_constructor__(fconstructor, *args)
        # The handle lives on the wrapper ``obj``; subclasses are expected
        # to declare a ``__tvm_object__`` slot to hold it.
        self.__tvm_object__ = obj


class ObjectBase(object):
    """Base object for all object types"""

    __slots__ = ["handle"]

    def __del__(self):
        # The ``handle`` slot may never have been assigned (e.g. an instance
        # created via __new__ whose handle assignment did not happen), so read
        # it defensively instead of raising AttributeError from __del__.
        handle = getattr(self, "handle", None)
        # _LIB may be None when the module is being torn down (presumably at
        # interpreter shutdown); a None handle owns nothing to free.
        if _LIB is not None and handle is not None:
            check_call(_LIB.TVMObjectFree(handle))

    def __init_handle_by_constructor__(self, fconstructor, *args):
        """Initialize the handle by calling constructor function.

        Parameters
        ----------
        fconstructor : Function
            Constructor function.

        args: list of objects
            The arguments to the constructor

        Note
        ----
        We have a special calling convention to call constructor functions.
        So the return handle is directly set into the Node object
        instead of creating a new Node.
        """
        # assign handle first to avoid error raising
        # pylint: disable=not-callable
        self.handle = None
        handle = __init_by_constructor__(fconstructor, args)
        if not isinstance(handle, ObjectHandle):
            handle = ObjectHandle(handle)
        self.handle = handle

    def same_as(self, other):
        """Check object identity.

        Two objects are identical when they refer to the same underlying
        handle address, or when both have no handle.

        Parameters
        ----------
        other : object
            The other object to compare against.

        Returns
        -------
        result : bool
            The comparison result. False for non-ObjectBase values.
        """
        if not isinstance(other, ObjectBase):
            return False
        if self.handle is None or other.handle is None:
            # Only two handle-less objects are identical; this also avoids
            # dereferencing ``.value`` on a None handle.
            return self.handle is None and other.handle is None
        return self.handle.value == other.handle.value