# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Acknowledgement: This file originates from incubator-tvm"""

from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t, uint32_t
import ctypes
from ...base import get_last_ffi_error


# Type codes presumably used in the `type_codes` / `ret_type_code` arrays of
# MXNetFuncCall to tag each MXNetValue.
# NOTE(review): these numeric values are expected to mirror the type-code enum
# in mxnet/runtime/c_runtime_api.h -- keep the two in sync.
cdef enum MXNetTypeCode:
    kInt = 0
    kUInt = 1
    kFloat = 2
    kHandle = 3
    kNull = 4
    kMXNetType = 5
    kMXNetContext = 6
    kArrayHandle = 7
    kObjectHandle = 8
    kModuleHandle = 9
    kFuncHandle = 10
    kStr = 11
    kBytes = 12
    kNDArrayContainer = 13
    kNDArrayHandle = 14
    kExtBegin = 15


cdef extern from "mxnet/runtime/c_runtime_api.h":
    # Only the members accessed from Cython need to be declared here; the
    # exact C layout comes from the header itself.
    ctypedef struct MXNetValue:
        int64_t v_int64
        double v_float64
        void* v_handle
        const char* v_str


# Opaque handle aliases used by the FFI layer.
ctypedef void* MXNetRetValueHandle
ctypedef void* MXNetFunctionHandle
ctypedef void* ObjectHandle


cdef extern from "mxnet/runtime/c_runtime_api.h":
    int MXNetFuncCall(MXNetFunctionHandle func,
                      MXNetValue* arg_values,
                      int* type_codes,
                      int num_args,
                      MXNetValue* ret_val,
                      int* ret_type_code)
    int MXNetFuncFree(MXNetFunctionHandle func)


cdef inline py_str(const char* x):
    """Convert a NUL-terminated C string to a Python string.

    Parameters
    ----------
    x : const char*
        C string to convert.

    Returns
    -------
    str
        On Python 2 the raw bytes are returned unchanged; on Python 3 the
        bytes are decoded as UTF-8.
    """
    if PY_MAJOR_VERSION < 3:
        return x
    else:
        return x.decode("utf-8")


cdef inline c_str(pystr):
    """Encode a Python string as UTF-8 bytes suitable for the C API.

    Parameters
    ----------
    pystr : str
        Python string to encode.

    Returns
    -------
    bytes
        The UTF-8 encoded bytes. A ``const char*`` taken from this object is
        only valid while the returned object is kept alive by the caller.
    """
    return pystr.encode("utf-8")


cdef inline CALL(int ret):
    """Raise the last FFI error if a C API call reported failure.

    Parameters
    ----------
    ret : int
        Return code of the C API call; 0 means success.

    Raises
    ------
    Exception
        Whatever ``get_last_ffi_error()`` returns for the failed call.
    """
    if ret != 0:
        raise get_last_ffi_error()


cdef inline object ctypes_handle(void* chandle):
    """Wrap a raw C pointer in a ``ctypes.c_void_p``.

    Parameters
    ----------
    chandle : void*
        Raw pointer to wrap (may be NULL).

    Returns
    -------
    ctypes.c_void_p
        A ctypes handle holding the same address.
    """
    # Pass the address as an integer so ctypes can rebuild the pointer.
    return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p)


cdef inline void* c_handle(object handle):
    """Unwrap a ctypes handle into a raw C pointer.

    Parameters
    ----------
    handle : ctypes.c_void_p
        The ctypes handle to unwrap.

    Returns
    -------
    void*
        The address stored in ``handle``, or NULL when ``handle.value`` is
        None (how ctypes represents a NULL ``c_void_p``).
    """
    cdef object value = handle.value
    cdef unsigned long long v_ptr
    if value is None:
        # A NULL c_void_p exposes ``value`` as None; converting None to an
        # unsigned long long would raise TypeError instead of yielding NULL.
        return NULL
    v_ptr = value
    return <void*>(v_ptr)