1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Acknowledgement: This file originates from incubator-tvm"""
19
20from libcpp.vector cimport vector
21from cpython.version cimport PY_MAJOR_VERSION
22from cpython cimport pycapsule
23from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t, uint32_t
24import ctypes
25from ...base import get_last_ffi_error
26
# Type codes tagging values that cross the C FFI boundary.
# NOTE(review): these are presumably the values placed in the ``type_codes``
# / ``ret_type_code`` arguments of MXNetFuncCall, and presumably mirror the
# type-code enum in mxnet/runtime/c_runtime_api.h — they must stay in sync
# with the C side; confirm against the header.
cdef enum MXNetTypeCode:
    kInt = 0
    kUInt = 1
    kFloat = 2
    kHandle = 3
    kNull = 4
    kMXNetType = 5
    kMXNetContext = 6
    kArrayHandle = 7
    kObjectHandle = 8
    kModuleHandle = 9
    kFuncHandle = 10
    kStr = 11
    kBytes = 12
    kNDArrayContainer = 13
    kNDArrayHandle = 14
    # presumably the first code available for extension types — confirm
    kExtBegin = 15
44
cdef extern from "mxnet/runtime/c_runtime_api.h":
    # Cython-side declaration of the C ``MXNetValue`` type from the header.
    # Cython only needs the members it accesses to be listed here; the C
    # definition may contain more.
    # NOTE(review): with one member per payload kind this is presumably a C
    # union whose active member is selected by an MXNetTypeCode — confirm.
    ctypedef struct MXNetValue:
        int64_t v_int64     # 64-bit signed integer member
        double v_float64    # double-precision floating-point member
        void* v_handle      # untyped pointer member
        const char* v_str   # C string member
51
# Opaque handle types used with the C runtime API. Cython treats each of
# them as an untyped ``void*``; the pointee is never inspected from Cython.
ctypedef void* MXNetRetValueHandle
ctypedef void* MXNetFunctionHandle
ctypedef void* ObjectHandle
55
56
cdef extern from "mxnet/runtime/c_runtime_api.h":
    # Presumably calls the function referred to by ``func``. Returns an int
    # status; the CALL() helper in this file treats any non-zero status as
    # an error.
    # NOTE(review): arg_values and type_codes presumably hold ``num_args``
    # parallel entries, and the result is presumably written to ret_val /
    # ret_type_code — confirm against the header.
    int MXNetFuncCall(MXNetFunctionHandle func,
                      MXNetValue* arg_values,
                      int* type_codes,
                      int num_args,
                      MXNetValue* ret_val,
                      int* ret_type_code)
    # Presumably releases ``func``; also returns an int status code.
    int MXNetFuncFree(MXNetFunctionHandle func)
65
66
cdef inline py_str(const char* x):
    """Turn a C string into the interpreter's native ``str`` type.

    On Python 3 the bytes are decoded as UTF-8; on Python 2 the raw byte
    string is returned as-is, since that already is the native ``str``.
    """
    if PY_MAJOR_VERSION >= 3:
        return x.decode("utf-8")
    return x
72
73
cdef inline c_str(pystr):
    """Encode a Python string as UTF-8 bytes for passing to the C API.

    Parameters
    ----------
    pystr : str or bytes
        Text to encode. A ``bytes`` object is treated as already encoded
        and is returned unchanged.

    Returns
    -------
    bytes
        UTF-8 encoded bytes, suitable for Cython's coercion to
        ``const char*``.
    """
    if isinstance(pystr, bytes):
        # Already encoded. On Python 3 ``bytes`` has no ``encode`` method,
        # and on Python 2 ``str.encode`` would first implicitly decode as
        # ASCII and fail on non-ASCII input, so pass it through untouched.
        return pystr
    return pystr.encode("utf-8")
87
88
cdef inline CALL(int ret):
    """Check the status code returned by a C API call.

    Parameters
    ----------
    ret : int
        Status returned by the C function; zero means success.

    Raises
    ------
    The exception produced by ``get_last_ffi_error()`` when ``ret`` is
    non-zero.
    """
    if ret == 0:
        return
    raise get_last_ffi_error()
92
93
cdef inline object ctypes_handle(void* chandle):
    """Wrap a raw C pointer in a ``ctypes.c_void_p``.

    The pointer is first turned into an integer address, which ctypes then
    reinterprets as a ``c_void_p``.
    """
    cdef unsigned long long address = <unsigned long long>chandle
    return ctypes.cast(address, ctypes.c_void_p)
97
98
cdef inline void* c_handle(object handle):
    """Cast a ctypes handle to a raw C pointer.

    Parameters
    ----------
    handle : ctypes.c_void_p
        ctypes pointer object; its ``.value`` is either an integer address
        or ``None`` for a NULL pointer.

    Returns
    -------
    void*
        The raw pointer, or ``NULL`` when ``handle.value`` is ``None``.
    """
    cdef unsigned long long v_ptr
    addr = handle.value
    # ctypes reports a NULL c_void_p as ``value is None``, which cannot be
    # converted to an integer. Map it back to NULL explicitly so that
    # c_handle(ctypes_handle(NULL)) round-trips.
    if addr is None:
        return NULL
    v_ptr = addr
    return <void*>(v_ptr)
104