# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import ctypes
import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..runtime_ctypes import DataType, TVMContext, TVMByteArray, ObjectRValueRef


cdef void tvm_callback_finalize(void* fhandle) with gil:
    """Finalizer registered by convert_to_tvm_func.

    Drops the reference to the Python callable that convert_to_tvm_func took
    with Py_INCREF before handing the pointer to TVM.
    """
    local_pyfunc = <object>(fhandle)
    Py_DECREF(local_pyfunc)


cdef int tvm_callback(TVMValue* args,
                      int* type_codes,
                      int num_args,
                      TVMRetValueHandle ret,
                      void* fhandle) with gil:
    """C callback through which TVM invokes a Python-backed PackedFunc.

    ``fhandle`` is the Python callable registered in convert_to_tvm_func.
    Each C argument is converted to a Python value, the callable is invoked,
    and a non-None result is converted back and stored into ``ret`` with
    TVMCFuncSetReturn.

    Returns 0 on success and -1 on failure, after storing the formatted
    Python traceback with TVMAPISetLastError.
    """
    # This function has no ``except`` clause, so an exception escaping it is
    # not propagated to the C caller: Cython reports it as unraisable and TVM
    # would see a successful call. Every step that can raise (argument
    # conversion, the user call, and return-value conversion) therefore sits
    # inside the try so failures are reported through TVMAPISetLastError.
    cdef list pyargs
    cdef TVMValue value
    cdef int tcode
    local_pyfunc = <object>(fhandle)
    try:
        pyargs = []
        for i in range(num_args):
            value = args[i]
            tcode = type_codes[i]
            if (tcode == kTVMObjectHandle or
                    tcode == kTVMPackedFuncHandle or
                    tcode == kTVMModuleHandle or
                    tcode == kTVMNDArrayHandle or
                    tcode == kTVMObjectRefArg or
                    tcode > kTVMExtBegin):
                # NOTE(review): handle-typed arguments are passed through
                # TVMCbArgToReturn (which may rewrite both value and tcode)
                # before make_ret wraps them; presumably this gives the Python
                # wrapper its own reference -- confirm against the C API.
                CALL(TVMCbArgToReturn(&value, &tcode))

            if tcode != kTVMDLTensorHandle:
                pyargs.append(make_ret(value, tcode))
            else:
                # Raw DLTensor handles bypass make_ret and are wrapped
                # directly with c_make_array.
                pyargs.append(c_make_array(value.v_handle, True, False))

        rv = local_pyfunc(*pyargs)
        if rv is not None:
            if isinstance(rv, tuple):
                raise ValueError("PackedFunction can only support one return value")
            # temp_args keeps any temporaries referenced by ``value`` alive
            # until TVMCFuncSetReturn has consumed them.
            temp_args = []
            make_arg(rv, &value, &tcode, temp_args)
            CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1))
    except Exception:
        msg = traceback.format_exc()
        msg = py2cerror(msg)
        TVMAPISetLastError(c_str(msg))
        return -1
    return 0


cdef object make_packed_func(TVMPackedFuncHandle chandle, int is_global):
    """Wrap a raw PackedFunc handle in the registered Python class.

    The instance is created with ``__new__`` (``__init__`` is not run) and
    its cdef fields ``chandle`` and ``is_global`` are set directly.
    """
    obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
    (<PackedFuncBase>obj).chandle = chandle
    (<PackedFuncBase>obj).is_global = is_global
    return obj


def convert_to_tvm_func(object pyfunc):
    """Convert a python function to TVM function

    Parameters
    ----------
    pyfunc : python function
        The python function to be converted.

    Returns
    -------
    tvmfunc: tvm.Function
        The converted tvm function.
    """
    cdef TVMPackedFuncHandle chandle
    # The reference taken here is released by tvm_callback_finalize.
    Py_INCREF(pyfunc)
    CALL(TVMFuncCreateFromCFunc(tvm_callback,
                                <void*>(pyfunc),
                                tvm_callback_finalize,
                                &chandle))
    return make_packed_func(chandle, False)


cdef inline int make_arg(object arg,
                         TVMValue* value,
                         int* tcode,
                         list temp_args) except -1:
    """Pack one Python value into a TVMValue and its type code.

    Parameters
    ----------
    arg : object
        The Python value to convert.
    value : TVMValue*
        Output slot receiving the packed value.
    tcode : int*
        Output slot receiving the TVM type code.
    temp_args : list
        Receives Python objects (encoded strings, converted objects, byte
        buffers, wrapped callables) that ``value`` may point into; the
        caller must keep this list alive until the C call has completed.

    Returns 0 on success; raises TypeError for unsupported types.
    """
    cdef unsigned long long ptr
    if isinstance(arg, ObjectBase):
        value[0].v_handle = (<ObjectBase>arg).chandle
        tcode[0] = kTVMObjectHandle
    elif isinstance(arg, NDArrayBase):
        value[0].v_handle = (<NDArrayBase>arg).chandle
        # View arrays are passed as a DLTensor handle, others as NDArray.
        tcode[0] = (kTVMNDArrayHandle if
                    not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
    elif isinstance(arg, PyNativeObject):
        value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle
        tcode[0] = kTVMObjectHandle
    elif isinstance(arg, _TVM_COMPATS):
        ptr = arg._tvm_handle
        value[0].v_handle = (<void*>ptr)
        tcode[0] = arg.__class__._tvm_tcode
    elif isinstance(arg, Integral):
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, float):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, str):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kTVMStr
        temp_args.append(tstr)
    elif arg is None:
        value[0].v_handle = NULL
        tcode[0] = kTVMNullptr
    elif isinstance(arg, Number):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, DataType):
        tstr = c_str(str(arg))
        value[0].v_str = tstr
        tcode[0] = kTVMStr
        temp_args.append(tstr)
    elif isinstance(arg, TVMContext):
        value[0].v_ctx = (<DLContext*>(
            <unsigned long long>ctypes.addressof(arg)))[0]
        tcode[0] = kTVMContext
    elif isinstance(arg, (bytes, bytearray)):
        # from_buffer only takes in bytearray.
        if isinstance(arg, bytes):
            byte_arr = bytearray(arg)
            temp_args.append(byte_arr)
            arg = byte_arr

        arr = TVMByteArray()
        arr.data = ctypes.cast(
            (ctypes.c_byte * len(arg)).from_buffer(arg),
            ctypes.POINTER(ctypes.c_byte))
        arr.size = len(arg)
        value[0].v_handle = <void*>(
            <unsigned long long>ctypes.addressof(arr))
        tcode[0] = kTVMBytes
        temp_args.append(arr)
    elif isinstance(arg, string_types):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kTVMStr
        temp_args.append(tstr)
    elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
        arg = _FUNC_CONVERT_TO_OBJECT(arg)
        value[0].v_handle = (<ObjectBase>arg).chandle
        tcode[0] = kTVMObjectHandle
        temp_args.append(arg)
    elif isinstance(arg, _CLASS_MODULE):
        value[0].v_handle = c_handle(arg.handle)
        tcode[0] = kTVMModuleHandle
    elif isinstance(arg, PackedFuncBase):
        value[0].v_handle = (<PackedFuncBase>arg).chandle
        tcode[0] = kTVMPackedFuncHandle
    elif isinstance(arg, ctypes.c_void_p):
        value[0].v_handle = c_handle(arg)
        tcode[0] = kTVMOpaqueHandle
    elif isinstance(arg, ObjectRValueRef):
        # Pass the address of the handle field itself, not the handle.
        value[0].v_handle = &((<ObjectBase>(arg.obj)).chandle)
        tcode[0] = kTVMObjectRefArg
    elif callable(arg):
        arg = convert_to_tvm_func(arg)
        value[0].v_handle = (<PackedFuncBase>arg).chandle
        tcode[0] = kTVMPackedFuncHandle
        temp_args.append(arg)
    else:
        raise TypeError("Don't know how to handle type %s" % type(arg))
    return 0


cdef inline bytearray make_ret_bytes(void* chandle):
    """Copy the contents of a TVMByteArray* handle into a new bytearray."""
    handle = ctypes_handle(chandle)
    arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
    size = arr.size
    res = bytearray(size)
    rptr = (ctypes.c_byte * size).from_buffer(res)
    if not ctypes.memmove(rptr, arr.data, size):
        raise RuntimeError('memmove failed')
    return res


cdef inline object make_ret(TVMValue value, int tcode):
    """convert result to return value.

    Maps a (TVMValue, type code) pair to the matching Python object and
    raises ValueError for type codes it does not recognize.
    """
    if tcode == kTVMObjectHandle:
        return make_ret_object(value.v_handle)
    elif tcode == kTVMNullptr:
        return None
    elif tcode == kInt:
        return value.v_int64
    elif tcode == kFloat:
        return value.v_float64
    elif tcode == kTVMNDArrayHandle:
        return c_make_array(value.v_handle, False, True)
    elif tcode == kTVMStr:
        return py_str(value.v_str)
    elif tcode == kTVMBytes:
        return make_ret_bytes(value.v_handle)
    elif tcode == kTVMOpaqueHandle:
        return ctypes_handle(value.v_handle)
    elif tcode == kTVMContext:
        return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
    elif tcode == kTVMModuleHandle:
        return _CLASS_MODULE(ctypes_handle(value.v_handle))
    elif tcode == kTVMPackedFuncHandle:
        return make_packed_func(value.v_handle, False)
    elif tcode in _TVM_EXT_RET:
        return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))

    raise ValueError("Unhandled type code %d" % tcode)


cdef inline int FuncCall3(void* chandle,
                          tuple args,
                          int nargs,
                          TVMValue* ret_val,
                          int* ret_tcode) except -1:
    """Call a PackedFunc with at most three arguments using stack arrays."""
    cdef TVMValue[3] values
    cdef int[3] tcodes
    nargs = len(args)
    temp_args = []
    for i in range(nargs):
        make_arg(args[i], &values[i], &tcodes[i], temp_args)
    CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
                     nargs, ret_val, ret_tcode))
    return 0


cdef inline int FuncCall(void* chandle,
                         tuple args,
                         TVMValue* ret_val,
                         int* ret_tcode) except -1:
    """Call a PackedFunc, writing the result into ret_val / ret_tcode.

    Calls with three or fewer arguments take the FuncCall3 fast path;
    longer argument lists are packed into heap-backed vectors.
    """
    cdef int nargs
    nargs = len(args)
    if nargs <= 3:
        FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
        return 0

    cdef vector[TVMValue] values
    cdef vector[int] tcodes
    values.resize(max(nargs, 1))
    tcodes.resize(max(nargs, 1))
    temp_args = []
    for i in range(nargs):
        make_arg(args[i], &values[i], &tcodes[i], temp_args)
    CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
                     nargs, ret_val, ret_tcode))
    return 0


cdef inline int ConstructorCall(void* constructor_handle,
                                int type_code,
                                tuple args,
                                void** handle) except -1:
    """Call constructor of a handle function.

    Invokes ``constructor_handle`` with ``args``, asserts that the returned
    type code equals ``type_code`` and stores the returned handle in
    ``handle[0]``.
    """
    cdef TVMValue ret_val
    cdef int ret_tcode
    FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
    assert ret_tcode == type_code
    handle[0] = ret_val.v_handle
    return 0


cdef class PackedFuncBase:
    """Cython base class holding a TVM PackedFunc handle."""
    cdef TVMPackedFuncHandle chandle
    cdef int is_global

    cdef inline _set_handle(self, handle):
        if handle is None:
            self.chandle = NULL
        else:
            self.chandle = c_handle(handle)

    # NOTE(review): this property reads/writes ``c_is_global``, which is not
    # a cdef field of this class, whereas make_packed_func and __dealloc__
    # use the cdef field ``is_global``. Confirm the two are meant to be the
    # same flag.
    property is_global:
        def __get__(self):
            return self.c_is_global != 0

        def __set__(self, value):
            self.c_is_global = value

    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)

        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_global):
        self._set_handle(handle)
        self.c_is_global = is_global

    def __dealloc__(self):
        # Non-global handles are released here; global ones are not freed.
        if self.is_global == 0:
            CALL(TVMFuncFree(self.chandle))

    def __call__(self, *args):
        cdef TVMValue ret_val
        cdef int ret_tcode
        FuncCall(self.chandle, args, &ret_val, &ret_tcode)
        return make_ret(ret_val, ret_tcode)


def _get_global_func(name, allow_missing):
    """Look up a registered global PackedFunc by name.

    Returns None when the function is missing and ``allow_missing`` is
    true; otherwise a missing function raises ValueError.
    """
    cdef TVMPackedFuncHandle chandle
    CALL(TVMFuncGetGlobal(c_str(name), &chandle))
    if chandle != NULL:
        return make_packed_func(chandle, True)

    if allow_missing:
        return None

    raise ValueError("Cannot find global function %s" % name)


# Python-level classes and converter injected through the _set_* hooks below.
_CLASS_PACKED_FUNC = None
_CLASS_MODULE = None
_CLASS_OBJECT = None
_CLASS_OBJECT_GENERIC = None
_FUNC_CONVERT_TO_OBJECT = None


def _set_class_module(module_class):
    """Register the Python class used to wrap module handles."""
    global _CLASS_MODULE
    _CLASS_MODULE = module_class


def _set_class_packed_func(func_class):
    """Register the Python class instantiated by make_packed_func."""
    global _CLASS_PACKED_FUNC
    _CLASS_PACKED_FUNC = func_class


def _set_class_object(obj_class):
    """Register the Python object base class."""
    global _CLASS_OBJECT
    _CLASS_OBJECT = obj_class


def _set_class_object_generic(object_generic_class, func_convert_to_object):
    """Register the generic-object class and the converter make_arg uses."""
    global _CLASS_OBJECT_GENERIC
    global _FUNC_CONVERT_TO_OBJECT
    _CLASS_OBJECT_GENERIC = object_generic_class
    _FUNC_CONVERT_TO_OBJECT = func_convert_to_object