# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from ..base import get_last_ffi_error

from libcpp.vector cimport vector
from libcpp.string cimport string
from libcpp cimport bool as _bool
from cpython.version cimport PY_MAJOR_VERSION

# Opaque handle types passed to / returned from the NNVM and MXNet C APIs
# declared below. They are all plain void*, so the C side treats them as
# interchangeable pointers; the distinct names only document intent.
ctypedef void* SymbolHandle
ctypedef void* NDArrayHandle
ctypedef void* OpHandle
ctypedef void* CachedOpHandle
ctypedef void* MonitorCallbackHandle
ctypedef unsigned nn_uint

# Signature of the callback accepted by MXCachedOpRegisterOpHook: two C
# strings followed by an NDArray handle.
ctypedef void (*CachedOpMonitorCallback)(const char*,
                                         const char*,
                                         NDArrayHandle)


cdef py_str(const char* x):
    """Convert a C string returned by the C API into a Python string.

    On Python 2 the raw byte string is returned unchanged; on Python 3 the
    bytes are decoded as UTF-8 into ``str``.
    """
    if PY_MAJOR_VERSION < 3:
        return x
    else:
        return x.decode("utf-8")


cdef c_str(pystr):
    """Encode a Python string as UTF-8 bytes for passing to the C API.

    Parameters
    ----------
    pystr : str
        Python string to encode.

    Returns
    -------
    bytes
        The UTF-8 encoding of ``pystr``. A ``const char*`` obtained from
        this object points into its buffer, so the caller must keep the
        returned object alive for as long as that pointer is used.
    """
    return pystr.encode("utf-8")


cdef CALL(int ret):
    """Check the return code of a C API call.

    Raises the exception produced by ``get_last_ffi_error()`` when ``ret``
    is non-zero; otherwise does nothing.
    """
    if ret != 0:
        raise get_last_ffi_error()


cdef const char** CBeginPtr(vector[const char*]& vec):
    """Return a pointer to the first element of ``vec``, or NULL if empty.

    Taking ``&vec[0]`` of an empty vector is not valid, so the empty case
    is handled explicitly and maps to NULL for the C API.
    """
    if vec.empty():
        return NULL
    return &vec[0]


cdef vector[const char*] SVec2Ptr(vector[string]& vec):
    """Build a vector of C-string pointers into the strings held by ``vec``.

    Each entry is ``vec[i].c_str()``, so the returned pointers borrow from
    ``vec``: they are only valid while ``vec`` and its strings are alive and
    unmodified.
    """
    cdef vector[const char*] svec
    # Explicit C index type keeps this a plain C loop over the vector.
    cdef size_t i
    svec.resize(vec.size())
    for i in range(vec.size()):
        svec[i] = vec[i].c_str()
    return svec


# Declarations mirrored from the NNVM C API header.
cdef extern from "nnvm/c_api.h":
    const char* NNGetLastError();
    int NNGetOpHandle(const char *op_name,
                      OpHandle *handle);
    int NNGetOpInfo(OpHandle op,
                    const char **name,
                    const char **description,
                    nn_uint *num_doc_args,
                    const char ***arg_names,
                    const char ***arg_type_infos,
                    const char ***arg_descriptions,
                    const char **return_type);
    int NNSymbolFree(SymbolHandle symbol);
    int NNSymbolGetNumOutputs(SymbolHandle sym,
                              nn_uint* output_count);
    int NNSymbolCompose(SymbolHandle sym,
                        const char* name,
                        nn_uint num_args,
                        const char** keys,
                        SymbolHandle* args);


# Declarations mirrored from the MXNet C API header.
cdef extern from "mxnet/c_api.h":
    int MXListAllOpNames(nn_uint *out_size,
                         const char ***out_array);
    int MXSymbolGetAtomicSymbolInfo(OpHandle creator,
                                    const char **name,
                                    const char **description,
                                    nn_uint *num_doc_args,
                                    const char ***arg_names,
                                    const char ***arg_type_infos,
                                    const char ***arg_descriptions,
                                    const char **key_var_args,
                                    const char **return_type);
    int MXSymbolCreateAtomicSymbol(OpHandle op,
                                   nn_uint num_param,
                                   const char **keys,
                                   const char **vals,
                                   SymbolHandle *out);
    int MXSymbolSetAttr(SymbolHandle symbol,
                        const char* key,
                        const char* value);
    int MXImperativeInvokeEx(OpHandle creator,
                             int num_inputs,
                             NDArrayHandle *inputs,
                             int *num_outputs,
                             NDArrayHandle **outputs,
                             int num_params,
                             const char **param_keys,
                             const char **param_vals,
                             const int **out_stypes);
    int MXNDArrayFree(NDArrayHandle handle);
    int MXCreateCachedOpEx(SymbolHandle handle,
                           int num_flags,
                           const char** keys,
                           const char** vals,
                           CachedOpHandle *out);
    int MXFreeCachedOp(CachedOpHandle handle);
    int MXInvokeCachedOpEx(CachedOpHandle handle,
                           int num_inputs,
                           NDArrayHandle *inputs,
                           int *num_outputs,
                           NDArrayHandle **outputs,
                           const int **out_stypes);
    # NOTE(review): the handle here is declared as NDArrayHandle although it
    # presumably refers to a cached op; both are void*, so a CachedOpHandle
    # can be passed unchanged. Confirm against mxnet/c_api.h before retyping.
    int MXCachedOpRegisterOpHook(NDArrayHandle handle,
                                 CachedOpMonitorCallback callback,
                                 _bool monitor_all);