1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import ctypes
19import traceback
20from cpython cimport Py_INCREF, Py_DECREF
21from numbers import Number, Integral
22from ..base import string_types, py2cerror
23from ..node_generic import convert_to_node, NodeGeneric
24from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
25
26
cdef void tvm_callback_finalize(void* fhandle):
    # Finalizer handed to TVMFuncCreateFromCFunc alongside tvm_callback.
    # fhandle is the Python callable that convert_to_tvm_func pinned with
    # Py_INCREF; give that reference back here.
    Py_DECREF(<object>fhandle)
30
cdef int tvm_callback(TVMValue* args,
                      int* type_codes,
                      int num_args,
                      TVMRetValueHandle ret,
                      void* fhandle) with gil:
    """C entry point that forwards a TVM packed call to a Python callable.

    Parameters
    ----------
    args, type_codes : arrays of length ``num_args``
        Packed argument values and their type codes.
    num_args : int
        Number of packed arguments.
    ret : TVMRetValueHandle
        Handle passed to TVMCFuncSetReturn to publish the return value.
    fhandle : void*
        The Python callable registered by convert_to_tvm_func.

    Returns
    -------
    0 on success; -1 after recording the formatted traceback with
    TVMAPISetLastError when anything in the call sequence raises.
    """
    cdef list pyargs
    cdef TVMValue value
    cdef int tcode
    cdef int i
    local_pyfunc = <object>(fhandle)
    pyargs = []
    # This function has no ``except`` clause, so an exception escaping it
    # cannot be reported to the C caller.  Keep argument conversion, the
    # user call and the return-value packing all under one handler so that
    # every failure is turned into an error code plus a recorded message.
    try:
        for i in range(num_args):
            value = args[i]
            tcode = type_codes[i]
            # Handle-like arguments are passed through TVMCbArgToReturn
            # before being wrapped by make_ret.
            if (tcode == kObjectHandle or
                tcode == kFuncHandle or
                tcode == kModuleHandle or
                tcode > kExtBegin):
                CALL(TVMCbArgToReturn(&value, tcode))

            if tcode != kArrayHandle:
                pyargs.append(make_ret(value, tcode))
            else:
                pyargs.append(c_make_array(value.v_handle, True, False))

        rv = local_pyfunc(*pyargs)

        if rv is not None:
            if isinstance(rv, tuple):
                raise ValueError("PackedFunction can only support one return value")
            # temp_args keeps any temporaries created by make_arg alive
            # until TVMCFuncSetReturn has consumed the packed value.
            temp_args = []
            make_arg(rv, &value, &tcode, temp_args)
            CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1))
    except Exception:
        msg = traceback.format_exc()
        msg = py2cerror(msg)
        TVMAPISetLastError(c_str(msg))
        return -1
    return 0
68
69
def convert_to_tvm_func(object pyfunc):
    """Wrap a Python callable as a TVM function.

    Parameters
    ----------
    pyfunc : python function
        The callable to expose to TVM.

    Returns
    -------
    tvmfunc: tvm.Function
        A function object (built with the registered function class) whose
        handle dispatches to ``pyfunc`` through ``tvm_callback``.
    """
    cdef TVMFunctionHandle func_handle
    # Pin the callable for as long as the TVM side holds it; the matching
    # Py_DECREF lives in tvm_callback_finalize.
    Py_INCREF(pyfunc)
    CALL(TVMFuncCreateFromCFunc(tvm_callback,
                                <void*>(pyfunc),
                                tvm_callback_finalize,
                                &func_handle))
    # Create an empty wrapper first, then attach the fresh handle directly.
    wrapped = _CLASS_FUNCTION(None, False)
    (<FunctionBase>wrapped).chandle = func_handle
    return wrapped
92
93
cdef inline int make_arg(object arg,
                         TVMValue* value,
                         int* tcode,
                         list temp_args) except -1:
    """Pack one Python argument into a TVMValue / type-code pair.

    Parameters
    ----------
    arg : object
        The Python value to pack.
    value : TVMValue*
        Output slot receiving the packed value.
    tcode : int*
        Output slot receiving the matching type code.
    temp_args : list
        Receives every temporary object created while packing (encoded
        strings, converted nodes, byte-array descriptors, wrapped
        callables).  The caller must keep this list alive until the packed
        value has been consumed.

    Returns
    -------
    0 on success.

    Raises
    ------
    TypeError
        If ``arg`` has a type that none of the branches below handles.
    """
    cdef unsigned long long ptr
    if isinstance(arg, ObjectBase):
        value[0].v_handle = (<ObjectBase>arg).chandle
        tcode[0] = kObjectHandle
    elif isinstance(arg, NDArrayBase):
        value[0].v_handle = (<NDArrayBase>arg).chandle
        # Views are passed as plain array handles, everything else as
        # NDArray containers.
        tcode[0] = (kNDArrayContainer if
                    not (<NDArrayBase>arg).c_is_view else kArrayHandle)
    elif isinstance(arg, _TVM_COMPATS):
        ptr = arg._tvm_handle
        value[0].v_handle = (<void*>ptr)
        tcode[0] = arg.__class__._tvm_tcode
    elif isinstance(arg, (int, long)):
        # Fast path for builtin integers (this also covers bool).
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, float):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, str):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif arg is None:
        value[0].v_handle = NULL
        tcode[0] = kNull
    elif isinstance(arg, Integral):
        # Integral types that are not builtin ints must stay integers;
        # without this branch they would fall through to the generic
        # Number case below and be packed as floats.
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, Number):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, TVMType):
        # Types travel as their string form.
        tstr = c_str(str(arg))
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif isinstance(arg, TVMContext):
        # Copy the context struct out of the ctypes object's memory.
        value[0].v_ctx = (<DLContext*>(
            <unsigned long long>ctypes.addressof(arg)))[0]
        tcode[0] = kTVMContext
    elif isinstance(arg, bytearray):
        # Build a TVMByteArray descriptor pointing at the bytearray's own
        # buffer; the descriptor is kept alive through temp_args.
        arr = TVMByteArray()
        arr.data = ctypes.cast(
            (ctypes.c_byte * len(arg)).from_buffer(arg),
            ctypes.POINTER(ctypes.c_byte))
        arr.size = len(arg)
        value[0].v_handle = <void*>(
            <unsigned long long>ctypes.addressof(arr))
        tcode[0] = kBytes
        temp_args.append(arr)
    elif isinstance(arg, string_types):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
        # Containers and generic nodes are converted to node objects first.
        arg = convert_to_node(arg)
        value[0].v_handle = (<ObjectBase>arg).chandle
        tcode[0] = kObjectHandle
        temp_args.append(arg)
    elif isinstance(arg, _CLASS_MODULE):
        value[0].v_handle = c_handle(arg.handle)
        tcode[0] = kModuleHandle
    elif isinstance(arg, FunctionBase):
        value[0].v_handle = (<FunctionBase>arg).chandle
        tcode[0] = kFuncHandle
    elif isinstance(arg, ctypes.c_void_p):
        value[0].v_handle = c_handle(arg)
        tcode[0] = kHandle
    elif callable(arg):
        # Arbitrary callables are wrapped into a TVM function on the fly.
        arg = convert_to_tvm_func(arg)
        value[0].v_handle = (<FunctionBase>arg).chandle
        tcode[0] = kFuncHandle
        temp_args.append(arg)
    else:
        raise TypeError("Don't know how to handle type %s" % type(arg))
    return 0
174
cdef inline bytearray make_ret_bytes(void* chandle):
    """Copy the TVMByteArray that chandle points to into a new bytearray."""
    byte_arr = ctypes.cast(ctypes_handle(chandle),
                           ctypes.POINTER(TVMByteArray))[0]
    nbytes = byte_arr.size
    out = bytearray(nbytes)
    # Expose the bytearray's buffer to ctypes so memmove can fill it.
    dst = (ctypes.c_byte * nbytes).from_buffer(out)
    moved = ctypes.memmove(dst, byte_arr.data, nbytes)
    if not moved:
        raise RuntimeError('memmove failed')
    return out
184
cdef inline object make_ret(TVMValue value, int tcode):
    """Turn a packed TVMValue into the Python object for its type code.

    Raises ValueError when the type code matches none of the known cases
    and is not a key of _TVM_EXT_RET.
    """
    if tcode == kObjectHandle:
        return make_ret_object(value.v_handle)
    if tcode == kNull:
        return None
    if tcode == kInt:
        return value.v_int64
    if tcode == kFloat:
        return value.v_float64
    if tcode == kNDArrayContainer:
        return c_make_array(value.v_handle, False, True)
    if tcode == kStr:
        return py_str(value.v_str)
    if tcode == kBytes:
        return make_ret_bytes(value.v_handle)
    if tcode == kHandle:
        return ctypes_handle(value.v_handle)
    if tcode == kTVMContext:
        return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
    if tcode == kModuleHandle:
        return _CLASS_MODULE(ctypes_handle(value.v_handle))
    if tcode == kFuncHandle:
        # Build an empty function wrapper, then attach the returned handle.
        func_obj = _CLASS_FUNCTION(None, False)
        (<FunctionBase>func_obj).chandle = value.v_handle
        return func_obj
    if tcode in _TVM_EXT_RET:
        # Extension types supply their own constructor keyed by type code.
        return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))

    raise ValueError("Unhandled type code %d" % tcode)
215
216
cdef inline int FuncCall3(void* chandle,
                          tuple args,
                          int nargs,
                          TVMValue* ret_val,
                          int* ret_tcode) except -1:
    """Call a TVM function using fixed three-slot argument buffers.

    The packed result is written to ``ret_val`` / ``ret_tcode``.
    Returns 0 on success.
    """
    cdef TVMValue[3] packed_values
    cdef int[3] packed_tcodes
    # The incoming count is replaced by the actual tuple length.
    nargs = len(args)
    # Holds temporaries produced by make_arg for the duration of the call.
    keepalive = []
    for slot in range(nargs):
        make_arg(args[slot], &packed_values[slot], &packed_tcodes[slot],
                 keepalive)
    CALL(TVMFuncCall(chandle, &packed_values[0], &packed_tcodes[0],
                     nargs, ret_val, ret_tcode))
    return 0
231
cdef inline int FuncCall(void* chandle,
                         tuple args,
                         TVMValue* ret_val,
                         int* ret_tcode) except -1:
    """Pack ``args`` and invoke the TVM function behind ``chandle``.

    Calls with at most three arguments are delegated to FuncCall3; longer
    argument lists use vector-backed buffers.  The packed result is
    written to ``ret_val`` / ``ret_tcode``.  Returns 0 on success.
    """
    cdef int nargs
    cdef vector[TVMValue] packed_values
    cdef vector[int] packed_tcodes
    nargs = len(args)
    if nargs <= 3:
        FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
        return 0

    packed_values.resize(max(nargs, 1))
    packed_tcodes.resize(max(nargs, 1))
    # Holds temporaries produced by make_arg for the duration of the call.
    keepalive = []
    for slot in range(nargs):
        make_arg(args[slot], &packed_values[slot], &packed_tcodes[slot],
                 keepalive)
    CALL(TVMFuncCall(chandle, &packed_values[0], &packed_tcodes[0],
                     nargs, ret_val, ret_tcode))
    return 0
252
253
cdef inline int ConstructorCall(void* constructor_handle,
                                int type_code,
                                tuple args,
                                void** handle) except -1:
    """Invoke a constructor function and store the handle it returns.

    The returned type code is asserted to equal ``type_code``; the
    resulting handle is written to ``handle[0]``.  Returns 0 on success.
    """
    cdef TVMValue result
    cdef int result_tcode
    FuncCall(constructor_handle, args, &result, &result_tcode)
    assert result_tcode == type_code
    handle[0] = result.v_handle
    return 0
265
266
cdef class FunctionBase:
    """Extension type holding a TVM function handle and making it callable.

    Calling an instance packs the positional arguments, invokes the
    underlying TVM function and converts the packed result back to a
    Python object.
    """
    # Raw TVM function handle; NULL when no handle is attached.
    cdef TVMFunctionHandle chandle
    # C-level flag read by __dealloc__.
    # NOTE(review): nothing in this class assigns this field; the property
    # below and __init__ use ``self.c_is_global`` instead, which is not
    # declared here.  Confirm whether both names were meant to be one field.
    cdef int is_global

    cdef inline _set_handle(self, handle):
        # None clears the handle; anything else is converted with c_handle.
        if handle is None:
            self.chandle = NULL
        else:
            self.chandle = c_handle(handle)

    # Python-visible flag, stored under the attribute name ``c_is_global``.
    property is_global:
        def __get__(self):
            return self.c_is_global != 0

        def __set__(self, value):
            self.c_is_global = value

    # The handle exposed as a ctypes.c_void_p, or None when it is NULL.
    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_global):
        # handle may be None (see _set_handle); is_global is stored as given.
        self._set_handle(handle)
        self.c_is_global = is_global

    def __dealloc__(self):
        # Frees the handle whenever the C-level is_global field is zero.
        if self.is_global == 0:
            CALL(TVMFuncFree(self.chandle))

    def __call__(self, *args):
        # Pack args, call the TVM function, and unpack its return value.
        cdef TVMValue ret_val
        cdef int ret_tcode
        FuncCall(self.chandle, args, &ret_val, &ret_tcode)
        return make_ret(ret_val, ret_tcode)
306
307
# Class registry filled in at runtime through the _set_class_* functions
# below.  In this section _CLASS_FUNCTION is used to build function
# wrappers and _CLASS_MODULE to recognise and build module objects;
# _CLASS_OBJECT and _CLASS_NODE are not referenced in this section.
_CLASS_FUNCTION = None
_CLASS_MODULE = None
_CLASS_OBJECT = None
_CLASS_NODE = None
312
def _set_class_module(module_class):
    """Register the class stored in the module-level _CLASS_MODULE."""
    global _CLASS_MODULE
    _CLASS_MODULE = module_class
317
def _set_class_function(func_class):
    # Register the class stored in the module-level _CLASS_FUNCTION.
    global _CLASS_FUNCTION
    _CLASS_FUNCTION = func_class
321
def _set_class_object(obj_class):
    # Register the class stored in the module-level _CLASS_OBJECT.
    global _CLASS_OBJECT
    _CLASS_OBJECT = obj_class
325
def _set_class_node(node_class):
    # Register the class stored in the module-level _CLASS_NODE.
    global _CLASS_NODE
    _CLASS_NODE = node_class
329