# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import ctypes
import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray

# NOTE(review): this file uses names it does not define (CALL, c_str, py_str,
# c_handle, ctypes_handle, c_make_array, make_ret_object, ObjectBase,
# NDArrayBase, _TVM_COMPATS, _TVM_EXT_RET, TVMValue, DLContext, vector, the
# TVM C API functions and the k* type codes). Presumably they are provided by
# sibling .pxi/.pxd files compiled together with this one -- confirm.


cdef void tvm_callback_finalize(void* fhandle) with gil:
    """Drop the reference to the Python callable taken in convert_to_tvm_func.

    Passed to TVMFuncCreateFromCFunc as the finalizer; fhandle is the
    callable that convert_to_tvm_func Py_INCREF'd. ``with gil`` is needed
    because nothing guarantees the C/C++ side invokes the finalizer while
    holding the GIL, and Py_DECREF must never run without it (acquiring the
    GIL when it is already held is harmless).
    """
    local_pyfunc = <object>(fhandle)
    Py_DECREF(local_pyfunc)


cdef int tvm_callback(TVMValue* args,
                      int* type_codes,
                      int num_args,
                      TVMRetValueHandle ret,
                      void* fhandle) with gil:
    """C entry point that forwards a packed-function call to a Python callable.

    fhandle is the Python callable registered by convert_to_tvm_func. The
    num_args (value, type code) pairs are converted to Python objects, the
    callable is invoked with them, and a non-None result is packed into
    ``ret`` via TVMCFuncSetReturn.

    Returns 0 on success. On any Python exception -- while converting the
    arguments, inside the callable, or while packing the return value -- the
    formatted traceback is stored with TVMAPISetLastError and -1 is returned.
    The whole body is inside the ``try`` because this cdef function has no
    ``except`` clause: an exception escaping it would merely be printed as
    "ignored" by Cython and 0 (success) would reach the C caller.
    """
    cdef list pyargs
    cdef TVMValue value
    cdef int tcode
    local_pyfunc = <object>(fhandle)
    try:
        pyargs = []
        for i in range(num_args):
            value = args[i]
            tcode = type_codes[i]
            # Handle-like arguments go through TVMCbArgToReturn before being
            # wrapped by make_ret; presumably this transfers ownership of the
            # handle to the Python wrapper -- confirm against the C API.
            if (tcode == kObjectHandle or
                    tcode == kFuncHandle or
                    tcode == kModuleHandle or
                    tcode > kExtBegin):
                CALL(TVMCbArgToReturn(&value, tcode))

            if tcode != kArrayHandle:
                pyargs.append(make_ret(value, tcode))
            else:
                pyargs.append(c_make_array(value.v_handle, True, False))

        rv = local_pyfunc(*pyargs)
        if rv is not None:
            if isinstance(rv, tuple):
                raise ValueError("PackedFunction can only support one return value")
            # temp_args keeps temporaries created by make_arg (e.g. encoded
            # strings that v_str points into) alive until
            # TVMCFuncSetReturn has consumed the packed value.
            temp_args = []
            make_arg(rv, &value, &tcode, temp_args)
            CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1))
    except Exception:
        msg = traceback.format_exc()
        msg = py2cerror(msg)
        TVMAPISetLastError(c_str(msg))
        return -1
    return 0


def convert_to_tvm_func(object pyfunc):
    """Convert a python function to TVM function

    Parameters
    ----------
    pyfunc : python function
        The python function to be converted.

    Returns
    -------
    tvmfunc: tvm.Function
        The converted tvm function.
    """
    cdef TVMFunctionHandle chandle
    # The C side stores a borrowed pointer to pyfunc; this reference is
    # released again by tvm_callback_finalize.
    Py_INCREF(pyfunc)
    CALL(TVMFuncCreateFromCFunc(tvm_callback,
                                <void*>(pyfunc),
                                tvm_callback_finalize,
                                &chandle))
    ret = _CLASS_FUNCTION(None, False)
    (<FunctionBase>ret).chandle = chandle
    return ret


cdef inline int make_arg(object arg,
                         TVMValue* value,
                         int* tcode,
                         list temp_args) except -1:
    """Pack one Python argument into a (TVMValue, type code) pair.

    Parameters
    ----------
    arg : object
        The Python value to pack.
    value, tcode : out pointers
        Receive the packed value and its type code.
    temp_args : list
        Receives every temporary Python object the packed value points into
        (encoded strings, byte-array descriptors, converted nodes and
        functions); the caller must keep it alive until the C call returns.

    Returns 0 on success; raises TypeError for unsupported types.
    """
    cdef unsigned long long ptr
    if isinstance(arg, ObjectBase):
        value[0].v_handle = (<ObjectBase>arg).chandle
        tcode[0] = kObjectHandle
    elif isinstance(arg, NDArrayBase):
        value[0].v_handle = (<NDArrayBase>arg).chandle
        # Views are passed as raw array handles, owning arrays as containers.
        tcode[0] = (kNDArrayContainer if
                    not (<NDArrayBase>arg).c_is_view else kArrayHandle)
    elif isinstance(arg, _TVM_COMPATS):
        ptr = arg._tvm_handle
        value[0].v_handle = (<void*>ptr)
        tcode[0] = arg.__class__._tvm_tcode
    elif isinstance(arg, (int, long)):
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, float):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, str):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif arg is None:
        value[0].v_handle = NULL
        tcode[0] = kNull
    elif isinstance(arg, Integral):
        # Integral types that are not builtin ints (e.g. numpy integer
        # scalars, which register with numbers.Integral) must be passed as
        # integers; otherwise the Number branch below would send them as
        # floats.
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, Number):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, TVMType):
        tstr = c_str(str(arg))
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif isinstance(arg, TVMContext):
        # Copy the ctypes structure's bytes into the DLContext slot.
        value[0].v_ctx = (<DLContext*>(
            <unsigned long long>ctypes.addressof(arg)))[0]
        tcode[0] = kTVMContext
    elif isinstance(arg, bytearray):
        # Describe the bytearray's buffer with a TVMByteArray (no copy).
        arr = TVMByteArray()
        arr.data = ctypes.cast(
            (ctypes.c_byte * len(arg)).from_buffer(arg),
            ctypes.POINTER(ctypes.c_byte))
        arr.size = len(arg)
        value[0].v_handle = <void*>(
            <unsigned long long>ctypes.addressof(arr))
        tcode[0] = kBytes
        temp_args.append(arr)
    elif isinstance(arg, string_types):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
        arg = convert_to_node(arg)
        value[0].v_handle = (<ObjectBase>arg).chandle
        tcode[0] = kObjectHandle
        temp_args.append(arg)
    elif isinstance(arg, _CLASS_MODULE):
        value[0].v_handle = c_handle(arg.handle)
        tcode[0] = kModuleHandle
    elif isinstance(arg, FunctionBase):
        value[0].v_handle = (<FunctionBase>arg).chandle
        tcode[0] = kFuncHandle
    elif isinstance(arg, ctypes.c_void_p):
        value[0].v_handle = c_handle(arg)
        tcode[0] = kHandle
    elif callable(arg):
        arg = convert_to_tvm_func(arg)
        value[0].v_handle = (<FunctionBase>arg).chandle
        tcode[0] = kFuncHandle
        temp_args.append(arg)
    else:
        raise TypeError("Don't know how to handle type %s" % type(arg))
    return 0


cdef inline bytearray make_ret_bytes(void* chandle):
    """Copy the TVMByteArray pointed to by chandle into a new bytearray."""
    handle = ctypes_handle(chandle)
    arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
    size = arr.size
    res = bytearray(size)
    # Nothing to copy for an empty array; the source pointer need not be
    # valid in that case.
    if size:
        rptr = (ctypes.c_byte * size).from_buffer(res)
        if not ctypes.memmove(rptr, arr.data, size):
            raise RuntimeError('memmove failed')
    return res


cdef inline object make_ret(TVMValue value, int tcode):
    """Convert a (TVMValue, type code) pair into a Python object.

    Raises ValueError for type codes that are neither built in here nor
    registered in _TVM_EXT_RET.
    """
    if tcode == kObjectHandle:
        return make_ret_object(value.v_handle)
    elif tcode == kNull:
        return None
    elif tcode == kInt:
        return value.v_int64
    elif tcode == kFloat:
        return value.v_float64
    elif tcode == kNDArrayContainer:
        return c_make_array(value.v_handle, False, True)
    elif tcode == kStr:
        return py_str(value.v_str)
    elif tcode == kBytes:
        return make_ret_bytes(value.v_handle)
    elif tcode == kHandle:
        return ctypes_handle(value.v_handle)
    elif tcode == kTVMContext:
        return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
    elif tcode == kModuleHandle:
        return _CLASS_MODULE(ctypes_handle(value.v_handle))
    elif tcode == kFuncHandle:
        fobj = _CLASS_FUNCTION(None, False)
        (<FunctionBase>fobj).chandle = value.v_handle
        return fobj
    elif tcode in _TVM_EXT_RET:
        return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))

    raise ValueError("Unhandled type code %d" % tcode)


cdef inline int FuncCall3(void* chandle,
                          tuple args,
                          int nargs,
                          TVMValue* ret_val,
                          int* ret_tcode) except -1:
    """Call chandle with at most 3 arguments using fixed-size stack buffers."""
    cdef TVMValue[3] values
    cdef int[3] tcodes
    # Recompute from args and bound-check: the buffers above hold 3 entries.
    nargs = len(args)
    if nargs > 3:
        raise ValueError("FuncCall3 supports at most 3 arguments, got %d" % nargs)
    temp_args = []
    for i in range(nargs):
        make_arg(args[i], &values[i], &tcodes[i], temp_args)
    CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
                     nargs, ret_val, ret_tcode))
    return 0


cdef inline int FuncCall(void* chandle,
                         tuple args,
                         TVMValue* ret_val,
                         int* ret_tcode) except -1:
    """Pack args and call the packed function chandle.

    The result is written to ret_val / ret_tcode. Calls with up to 3
    arguments go through FuncCall3 (stack buffers); larger calls use
    vectors sized to the argument count.
    """
    cdef int nargs
    nargs = len(args)
    if nargs <= 3:
        FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
        return 0

    cdef vector[TVMValue] values
    cdef vector[int] tcodes
    values.resize(max(nargs, 1))
    tcodes.resize(max(nargs, 1))
    temp_args = []
    for i in range(nargs):
        make_arg(args[i], &values[i], &tcodes[i], temp_args)
    CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
                     nargs, ret_val, ret_tcode))
    return 0


cdef inline int ConstructorCall(void* constructor_handle,
                                int type_code,
                                tuple args,
                                void** handle) except -1:
    """Call contructor of a handle function

    Invokes constructor_handle with args and stores the returned handle in
    handle[0]. Raises ValueError if the result's type code differs from
    type_code.
    """
    cdef TVMValue ret_val
    cdef int ret_tcode
    FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
    # Explicit check rather than ``assert``: assertions can be compiled or
    # optimized out, which would store a handle of the wrong kind silently.
    if ret_tcode != type_code:
        raise ValueError("Constructor returned type code %d, expected %d"
                         % (ret_tcode, type_code))
    handle[0] = ret_val.v_handle
    return 0


cdef class FunctionBase:
    """Base class wrapping a TVM packed-function handle.

    Calls are forwarded to the C handle. The handle is released with
    TVMFuncFree on deallocation unless ``is_global`` is set.
    """
    cdef TVMFunctionHandle chandle
    # Backing storage for the is_global property. Declared as a C field so
    # __dealloc__ can read it safely (an instance __dict__ may already be
    # cleared at that point).
    cdef int c_is_global

    cdef inline _set_handle(self, handle):
        if handle is None:
            self.chandle = NULL
        else:
            self.chandle = c_handle(handle)

    property is_global:
        def __get__(self):
            return self.c_is_global != 0

        def __set__(self, value):
            self.c_is_global = value

    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)

        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_global):
        self._set_handle(handle)
        self.c_is_global = is_global

    def __dealloc__(self):
        # Free only handles this object owns: skip global functions and
        # never-initialized (NULL) handles.
        if self.c_is_global == 0 and self.chandle != NULL:
            CALL(TVMFuncFree(self.chandle))

    def __call__(self, *args):
        cdef TVMValue ret_val
        cdef int ret_tcode
        FuncCall(self.chandle, args, &ret_val, &ret_tcode)
        return make_ret(ret_val, ret_tcode)


# Python-level classes injected by the package at import time through the
# _set_class_* hooks below; used when wrapping returned handles.
_CLASS_FUNCTION = None
_CLASS_MODULE = None
_CLASS_OBJECT = None
_CLASS_NODE = None


def _set_class_module(module_class):
    """Initialize the module."""
    global _CLASS_MODULE
    _CLASS_MODULE = module_class


def _set_class_function(func_class):
    """Register the class used to wrap returned function handles."""
    global _CLASS_FUNCTION
    _CLASS_FUNCTION = func_class


def _set_class_object(obj_class):
    """Register the object class (stored in _CLASS_OBJECT)."""
    global _CLASS_OBJECT
    _CLASS_OBJECT = obj_class


def _set_class_node(node_class):
    """Register the node class (stored in _CLASS_NODE)."""
    global _CLASS_NODE
    _CLASS_NODE = node_class