1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import ctypes
19import traceback
20from cpython cimport Py_INCREF, Py_DECREF
21from numbers import Number, Integral
22from ..base import string_types, py2cerror
23from ..runtime_ctypes import DataType, TVMContext, TVMByteArray, ObjectRValueRef
24
25
cdef void tvm_callback_finalize(void* fhandle) with gil:
    """Finalizer handed to TVMFuncCreateFromCFunc by convert_to_tvm_func.

    Drops the reference to the wrapped Python callable that
    convert_to_tvm_func took with Py_INCREF.
    """
    pyfunc_ref = <object>fhandle
    Py_DECREF(pyfunc_ref)
29
cdef int tvm_callback(TVMValue* args,
                      int* type_codes,
                      int num_args,
                      TVMRetValueHandle ret,
                      void* fhandle) with gil:
    """C entry point through which TVM runs a Python function as a PackedFunc.

    Converts the ``num_args`` incoming C arguments into Python objects, calls
    the Python callable stored in ``fhandle`` and, if it returns something
    other than None, packs that value into ``ret``.

    Returns 0 on success.  TVM invokes this through a C function pointer and
    only looks at the int result, so every Python-side failure (argument
    conversion, the user function itself, an unsupported or tuple return
    value, or a failing TVM API call) is caught here, its formatted traceback
    is recorded with TVMAPISetLastError, and -1 is returned.  Letting an
    exception escape would not report the error back to TVM.
    """
    cdef list pyargs
    cdef TVMValue value
    cdef int tcode
    local_pyfunc = <object>(fhandle)
    try:
        pyargs = []
        for i in range(num_args):
            value = args[i]
            tcode = type_codes[i]
            # Handle-like type codes are translated in place by
            # TVMCbArgToReturn before being wrapped (presumably so the Python
            # wrapper owns its own reference -- TODO confirm).
            if (tcode == kTVMObjectHandle or
                tcode == kTVMPackedFuncHandle or
                tcode == kTVMModuleHandle or
                tcode == kTVMNDArrayHandle or
                tcode == kTVMObjectRefArg or
                tcode > kTVMExtBegin):
                CALL(TVMCbArgToReturn(&value, &tcode))

            if tcode != kTVMDLTensorHandle:
                pyargs.append(make_ret(value, tcode))
            else:
                pyargs.append(c_make_array(value.v_handle, True, False))

        rv = local_pyfunc(*pyargs)

        if rv is not None:
            if isinstance(rv, tuple):
                raise ValueError("PackedFunction can only support one return value")
            # temp_args keeps any temporaries created by make_arg alive until
            # TVMCFuncSetReturn has consumed the packed value.
            temp_args = []
            make_arg(rv, &value, &tcode, temp_args)
            CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1))
    except Exception:
        msg = traceback.format_exc()
        msg = py2cerror(msg)
        TVMAPISetLastError(c_str(msg))
        return -1
    return 0
69
70
cdef object make_packed_func(TVMPackedFuncHandle chandle, int is_global):
    """Wrap a raw packed-function handle in a new ``_CLASS_PACKED_FUNC``.

    The instance is created with ``__new__`` so its ``__init__`` does not
    run; the handle and the global flag are assigned directly instead.
    """
    cdef PackedFuncBase func
    func = <PackedFuncBase>_CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
    func.chandle = chandle
    func.is_global = is_global
    return func
76
77
def convert_to_tvm_func(object pyfunc):
    """Wrap a Python callable as a TVM packed function.

    Parameters
    ----------
    pyfunc : python function
        The Python callable to expose to TVM.

    Returns
    -------
    tvmfunc: tvm.Function
        A packed function whose calls are dispatched to ``pyfunc`` through
        tvm_callback.
    """
    cdef TVMPackedFuncHandle func_handle
    cdef void* resource = <void*>pyfunc
    # The packed function stores pyfunc as its resource handle, so take a
    # reference here; tvm_callback_finalize, registered as the finalizer,
    # releases it.
    Py_INCREF(pyfunc)
    CALL(TVMFuncCreateFromCFunc(tvm_callback, resource,
                                tvm_callback_finalize, &func_handle))
    return make_packed_func(func_handle, False)
98
99
cdef inline int make_arg(object arg,
                         TVMValue* value,
                         int* tcode,
                         list temp_args) except -1:
    """Pack one Python value into a TVMValue and its TVM type code.

    Parameters
    ----------
    arg : object
        The Python value to convert.
    value : TVMValue*
        Output slot that receives the packed value.
    tcode : int*
        Output slot that receives the matching type code.
    temp_args : list
        Temporary Python objects created during conversion (encoded strings,
        byte buffers, converted containers and callables) are appended here
        so the caller can keep them alive while ``value`` points into them.

    Returns
    -------
    0 on success.

    Raises
    ------
    TypeError
        If ``arg`` has a type that has no conversion below.

    Note: the branch order matters.  For example, ``bool`` is an Integral and
    is therefore packed as kInt, and ``callable`` is only tried after every
    more specific type.
    """
    cdef unsigned long long ptr
    if isinstance(arg, ObjectBase):
        value[0].v_handle = (<ObjectBase>arg).chandle
        tcode[0] = kTVMObjectHandle
    elif isinstance(arg, NDArrayBase):
        # View arrays are passed as plain DLTensor handles, owning arrays as
        # NDArray handles.
        value[0].v_handle = (<NDArrayBase>arg).chandle
        tcode[0] = (kTVMNDArrayHandle if
                    not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
    elif isinstance(arg, PyNativeObject):
        # Native Python subclasses carry their TVM object in __tvm_object__.
        value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle
        tcode[0] = kTVMObjectHandle
    elif isinstance(arg, _TVM_COMPATS):
        # Extension types expose an integer handle and a class-level tcode.
        ptr = arg._tvm_handle
        value[0].v_handle = (<void*>ptr)
        tcode[0] = arg.__class__._tvm_tcode
    elif isinstance(arg, Integral):
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, float):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, str):
        # The encoded string backs v_str, so keep it in temp_args.
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kTVMStr
        temp_args.append(tstr)
    elif arg is None:
        value[0].v_handle = NULL
        tcode[0] = kTVMNullptr
    elif isinstance(arg, Number):
        # Any other numbers.Number is passed as a float64.
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, DataType):
        # DataType is passed as its string form.
        tstr = c_str(str(arg))
        value[0].v_str = tstr
        tcode[0] = kTVMStr
        temp_args.append(tstr)
    elif isinstance(arg, TVMContext):
        # Copy the DLContext out of the ctypes structure's memory.
        value[0].v_ctx = (<DLContext*>(
            <unsigned long long>ctypes.addressof(arg)))[0]
        tcode[0] = kTVMContext
    elif isinstance(arg, (bytes, bytearray)):
        # from_buffer only takes in bytearray.
        if isinstance(arg, bytes):
            byte_arr = bytearray(arg)
            temp_args.append(byte_arr)
            arg = byte_arr

        # Describe the buffer with a TVMByteArray and pass its address; the
        # struct is kept in temp_args so the address stays valid.
        arr = TVMByteArray()
        arr.data = ctypes.cast(
            (ctypes.c_byte * len(arg)).from_buffer(arg),
            ctypes.POINTER(ctypes.c_byte))
        arr.size = len(arg)
        value[0].v_handle = <void*>(
            <unsigned long long>ctypes.addressof(arr))
        tcode[0] = kTVMBytes
        temp_args.append(arr)
    elif isinstance(arg, string_types):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kTVMStr
        temp_args.append(tstr)
    elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
        # Containers are first converted to a TVM object; the new object is
        # kept in temp_args because only its raw handle is stored in value.
        arg = _FUNC_CONVERT_TO_OBJECT(arg)
        value[0].v_handle = (<ObjectBase>arg).chandle
        tcode[0] = kTVMObjectHandle
        temp_args.append(arg)
    elif isinstance(arg, _CLASS_MODULE):
        value[0].v_handle = c_handle(arg.handle)
        tcode[0] = kTVMModuleHandle
    elif isinstance(arg, PackedFuncBase):
        value[0].v_handle = (<PackedFuncBase>arg).chandle
        tcode[0] = kTVMPackedFuncHandle
    elif isinstance(arg, ctypes.c_void_p):
        value[0].v_handle = c_handle(arg)
        tcode[0] = kTVMOpaqueHandle
    elif isinstance(arg, ObjectRValueRef):
        # Passes the ADDRESS of the wrapped object's chandle field, not the
        # handle itself -- presumably so the callee can move out of it
        # (TODO confirm against the C++ side).
        value[0].v_handle = &((<ObjectBase>(arg.obj)).chandle)
        tcode[0] = kTVMObjectRefArg
    elif callable(arg):
        # Plain Python callables are wrapped as packed functions; the wrapper
        # is kept in temp_args so its handle stays valid during the call.
        arg = convert_to_tvm_func(arg)
        value[0].v_handle = (<PackedFuncBase>arg).chandle
        tcode[0] = kTVMPackedFuncHandle
        temp_args.append(arg)
    else:
        raise TypeError("Don't know how to handle type %s" % type(arg))
    return 0
192
193
cdef inline bytearray make_ret_bytes(void* chandle):
    """Copy the TVMByteArray at ``chandle`` into a new Python bytearray."""
    byte_array = ctypes.cast(ctypes_handle(chandle),
                             ctypes.POINTER(TVMByteArray))[0]
    num_bytes = byte_array.size
    result = bytearray(num_bytes)
    # A ctypes view over ``result`` lets memmove write straight into it.
    dst = (ctypes.c_byte * num_bytes).from_buffer(result)
    copied = ctypes.memmove(dst, byte_array.data, num_bytes)
    if not copied:
        raise RuntimeError('memmove failed')
    return result
203
204
cdef inline object make_ret(TVMValue value, int tcode):
    """Convert a TVM return value and its type code into a Python object.

    Raises ValueError when ``tcode`` is neither one of the built-in codes
    handled here nor a key of ``_TVM_EXT_RET``.
    """
    cdef void* handle = value.v_handle

    # Every branch returns, so plain ``if`` statements form the dispatch.
    if tcode == kTVMObjectHandle:
        return make_ret_object(handle)
    if tcode == kTVMNullptr:
        return None
    if tcode == kInt:
        return value.v_int64
    if tcode == kFloat:
        return value.v_float64
    if tcode == kTVMNDArrayHandle:
        return c_make_array(handle, False, True)
    if tcode == kTVMStr:
        return py_str(value.v_str)
    if tcode == kTVMBytes:
        return make_ret_bytes(handle)
    if tcode == kTVMOpaqueHandle:
        return ctypes_handle(handle)
    if tcode == kTVMContext:
        return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
    if tcode == kTVMModuleHandle:
        return _CLASS_MODULE(ctypes_handle(handle))
    if tcode == kTVMPackedFuncHandle:
        return make_packed_func(handle, False)
    # Remaining codes are looked up in _TVM_EXT_RET, whose entries are
    # called with the handle wrapped by ctypes_handle.
    if tcode in _TVM_EXT_RET:
        return _TVM_EXT_RET[tcode](ctypes_handle(handle))
    raise ValueError("Unhandled type code %d" % tcode)
233
234
cdef inline int FuncCall3(void* chandle,
                          tuple args,
                          int nargs,
                          TVMValue* ret_val,
                          int* ret_tcode) except -1:
    """Call a packed function with at most three arguments.

    The arguments are packed into fixed-size stack arrays, which avoids the
    heap allocation FuncCall uses for longer argument lists.

    Parameters
    ----------
    chandle : void*
        Handle of the packed function to call.
    args : tuple
        Python arguments, each packed with make_arg.
    nargs : int
        Argument count supplied by the caller.  The count actually used is
        recomputed from ``len(args)`` so the two can never disagree.
    ret_val, ret_tcode :
        Output slots for the return value and its type code.

    Raises
    ------
    ValueError
        If more than three arguments are given; they would not fit in the
        stack arrays and would overflow them.
    """
    cdef TVMValue[3] values
    cdef int[3] tcodes
    nargs = len(args)
    if nargs > 3:
        raise ValueError(
            "FuncCall3 supports at most 3 arguments, got %d" % nargs)
    # make_arg appends temporaries here; holding the list until the call
    # returns keeps the memory behind the packed pointers valid.
    temp_args = []
    for i in range(nargs):
        make_arg(args[i], &values[i], &tcodes[i], temp_args)
    CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
                     nargs, ret_val, ret_tcode))
    return 0
249
cdef inline int FuncCall(void* chandle,
                         tuple args,
                         TVMValue* ret_val,
                         int* ret_tcode) except -1:
    """Call a packed function with any number of Python arguments.

    Calls with three or fewer arguments are delegated to FuncCall3, which
    uses stack arrays.  Longer argument lists are packed into std::vector
    buffers sized to the argument count.  The result is written to
    ``ret_val`` / ``ret_tcode``.
    """
    cdef int num_args = len(args)
    cdef vector[TVMValue] packed_values
    cdef vector[int] packed_tcodes
    cdef int idx

    if num_args <= 3:
        FuncCall3(chandle, args, num_args, ret_val, ret_tcode)
        return 0

    # num_args > 3 on this path, so the buffers are never empty.
    packed_values.resize(num_args)
    packed_tcodes.resize(num_args)
    # Holds temporaries created by make_arg until the call has returned.
    keep_alive = []
    for idx in range(num_args):
        make_arg(args[idx], &packed_values[idx], &packed_tcodes[idx],
                 keep_alive)
    CALL(TVMFuncCall(chandle, &packed_values[0], &packed_tcodes[0],
                     num_args, ret_val, ret_tcode))
    return 0
270
271
cdef inline int ConstructorCall(void* constructor_handle,
                                int type_code,
                                tuple args,
                                void** handle) except -1:
    """Call a constructor packed function and extract the handle it returns.

    Parameters
    ----------
    constructor_handle : void*
        Handle of the packed function that constructs the object.
    type_code : int
        Type code the constructor is expected to return.
    args : tuple
        Arguments forwarded to the constructor.
    handle : void**
        Output slot that receives the returned handle.

    Raises
    ------
    TypeError
        If the constructor's return type code differs from ``type_code``.
    """
    cdef TVMValue ret_val
    cdef int ret_tcode
    FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
    # Validated explicitly instead of with ``assert`` so the check still runs
    # when assertions are disabled; otherwise a wrongly-typed handle would be
    # handed out silently.
    if ret_tcode != type_code:
        raise TypeError(
            "Constructor returned type code %d, expected %d"
            % (ret_tcode, type_code))
    handle[0] = ret_val.v_handle
    return 0
283
284
cdef class PackedFuncBase:
    """Cython base class for Python wrappers around a TVM packed function.

    Stores the raw TVMPackedFuncHandle plus a flag marking it as global.
    Handles that are not global are released with TVMFuncFree when the
    wrapper is deallocated.
    """
    # Raw handle of the wrapped packed function (NULL when unset).
    cdef TVMPackedFuncHandle chandle
    # C storage behind the ``is_global`` property.  It must not share the
    # property's name: the accessors below read and write ``c_is_global``,
    # and a cdef attribute named ``is_global`` clashes with the property.
    cdef int c_is_global

    cdef inline _set_handle(self, handle):
        """Store ``handle`` (a ctypes handle, or None for NULL) as chandle."""
        if handle is None:
            self.chandle = NULL
        else:
            self.chandle = c_handle(handle)

    property is_global:
        def __get__(self):
            return self.c_is_global != 0

        def __set__(self, value):
            self.c_is_global = value

    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_global):
        self._set_handle(handle)
        self.c_is_global = is_global

    def __dealloc__(self):
        # Global handles are never freed here.  A NULL handle (e.g. an
        # instance made with __new__ whose handle was never set) has nothing
        # to release.
        if self.c_is_global == 0 and self.chandle != NULL:
            CALL(TVMFuncFree(self.chandle))

    def __call__(self, *args):
        """Call the packed function with ``args`` and convert its result."""
        cdef TVMValue ret_val
        cdef int ret_tcode
        FuncCall(self.chandle, args, &ret_val, &ret_tcode)
        return make_ret(ret_val, ret_tcode)
324
325
def _get_global_func(name, allow_missing):
    """Look up a registered global packed function by name.

    Returns a wrapper marked as global when ``name`` is found.  Otherwise
    returns None if ``allow_missing`` is true, and raises ValueError if not.
    """
    cdef TVMPackedFuncHandle found
    CALL(TVMFuncGetGlobal(c_str(name), &found))
    if found == NULL:
        if allow_missing:
            return None
        raise ValueError("Cannot find global function %s" % name)
    return make_packed_func(found, True)
336
337
# Late-bound classes and helpers used by the conversion code in this file.
# They start as None and are filled in by calling the ``_set_class_*``
# functions that follow.
# Class instantiated (via __new__) by make_packed_func.
_CLASS_PACKED_FUNC = None
# Class used to recognise module arguments and to wrap returned module handles.
_CLASS_MODULE = None
# Object class registered by _set_class_object; not read in the code shown
# here -- presumably used elsewhere (verify).
_CLASS_OBJECT = None
# Extra type that make_arg converts with _FUNC_CONVERT_TO_OBJECT, alongside
# list/tuple/dict.
_CLASS_OBJECT_GENERIC = None
# Callable that make_arg uses to turn containers into a TVM object.
_FUNC_CONVERT_TO_OBJECT = None
343
def _set_class_module(module_class):
    """Register the class that wraps returned TVM module handles."""
    global _CLASS_MODULE
    _CLASS_MODULE = module_class

def _set_class_packed_func(func_class):
    """Register the class that make_packed_func instantiates."""
    global _CLASS_PACKED_FUNC
    _CLASS_PACKED_FUNC = func_class

def _set_class_object(obj_class):
    """Register the object class, stored in ``_CLASS_OBJECT``."""
    global _CLASS_OBJECT
    _CLASS_OBJECT = obj_class

def _set_class_object_generic(object_generic_class, func_convert_to_object):
    """Register the generic-object type and the function that converts it.

    make_arg passes instances of ``object_generic_class`` (as well as lists,
    tuples and dicts) through ``func_convert_to_object``.
    """
    global _CLASS_OBJECT_GENERIC, _FUNC_CONVERT_TO_OBJECT
    _CLASS_OBJECT_GENERIC, _FUNC_CONVERT_TO_OBJECT = (
        object_generic_class, func_convert_to_object)
362