1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8
9#   http://www.apache.org/licenses/LICENSE-2.0
10
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from ..base import get_last_ffi_error
19
20from libcpp.vector cimport vector
21from libcpp.string cimport string
22from libcpp cimport bool as _bool
23from cpython.version cimport PY_MAJOR_VERSION
24
# Opaque handle types shared with the MXNet / NNVM C APIs. On the Cython side
# they are all plain ``void*``, so the distinct names only document intent;
# the compiler does not stop one kind of handle being passed for another.
ctypedef void* SymbolHandle
ctypedef void* NDArrayHandle
ctypedef void* OpHandle
ctypedef void* CachedOpHandle
ctypedef void* MonitorCallbackHandle
# Unsigned integer type used for counts/sizes in the C API declarations
# (e.g. number of outputs, number of arguments).
ctypedef unsigned nn_uint
# Function-pointer type for the hook passed to MXCachedOpRegisterOpHook.
# It receives two C strings and an NDArray handle.
# NOTE(review): presumably (operator name, array name, array handle) --
# confirm the argument meaning against mxnet/c_api.h.
ctypedef void (*CachedOpMonitorCallback)(const char*,
                                         const char*,
                                         NDArrayHandle)
34
cdef py_str(const char* x):
    """Convert a C string returned by the C API into a native Python string.

    On Python 3 the bytes are decoded as UTF-8 and a ``str`` is returned.
    On Python 2 the raw byte string (the native ``str`` type there) is
    returned unchanged.
    """
    if PY_MAJOR_VERSION >= 3:
        return x.decode("utf-8")
    return x
40
cdef c_str(pystr):
    """Encode a Python string as UTF-8 bytes for passing to the C API.

    Parameters
    ----------
    pystr : str or bytes
        Text to convert. A ``bytes`` object is assumed to be already encoded
        and is returned unchanged (previously it raised on Python 3).

    Returns
    -------
    bytes
        UTF-8 encoded bytes. Cython can coerce this object to a
        ``const char*`` pointing into its buffer; that pointer is only valid
        while the returned object is kept alive by the caller.
    """
    if isinstance(pystr, bytes):
        # Already encoded (Py3 ``bytes`` / Py2 ``str``): calling .encode()
        # would fail on Py3 and attempt an ASCII decode on Py2.
        return pystr
    return pystr.encode("utf-8")
54
55
cdef CALL(int ret):
    """Check the status code returned by a C API call.

    A status of 0 means success and nothing happens; any other value raises
    the exception produced by ``get_last_ffi_error()``.
    """
    if ret == 0:
        return
    raise get_last_ffi_error()
59
60
cdef const char** CBeginPtr(vector[const char*]& vec):
    """Return a pointer to the first element of ``vec``, or NULL if empty.

    Taking ``&vec[0]`` on an empty vector is invalid, hence the guard. The
    returned pointer is only valid while ``vec`` is alive and not resized.
    """
    if vec.empty():
        return NULL
    return &vec[0]
66
cdef vector[const char*] SVec2Ptr(vector[string]& vec):
    """Build a vector of C-string pointers viewing the strings in ``vec``.

    Each entry is ``vec[i].c_str()``, so the returned pointers borrow from
    ``vec``. They are only valid while ``vec`` is alive and its strings are
    not modified. ``vec`` is taken by reference so no string copies are made.
    """
    # Typed C index/bound keep this a plain C loop rather than relying on
    # Cython's type inference for an untyped loop variable.
    cdef size_t i
    cdef size_t n = vec.size()
    cdef vector[const char*] svec
    svec.resize(n)
    # Index into ``vec`` directly: iterating by value (``for s in vec``)
    # would copy each string into a temporary, and ``c_str()`` of that
    # temporary would dangle.
    for i in range(n):
        svec[i] = vec[i].c_str()
    return svec
73
74
# Declarations from the NNVM C API header. These must mirror nnvm/c_api.h.
# Apart from NNGetLastError, every function returns an int status code where
# a non-zero value signals failure; CALL() turns such a status into an
# exception. Pointer parameters such as ``OpHandle *handle`` or
# ``const char **name`` appear to be out-parameters filled in by the callee
# -- confirm against the header.
cdef extern from "nnvm/c_api.h":
    # Presumably the message of the most recent NNVM error.
    const char* NNGetLastError();
    # Look up an operator handle by its name.
    int NNGetOpHandle(const char *op_name,
                      OpHandle *handle);
    # Query an operator's name, description and per-argument documentation.
    int NNGetOpInfo(OpHandle op,
                    const char **name,
                    const char **description,
                    nn_uint *num_doc_args,
                    const char ***arg_names,
                    const char ***arg_type_infos,
                    const char ***arg_descriptions,
                    const char **return_type);
    # Release a symbol handle.
    int NNSymbolFree(SymbolHandle symbol);
    # Number of outputs of a symbol.
    int NNSymbolGetNumOutputs(SymbolHandle sym,
                              nn_uint* output_count);
    # Compose ``sym`` with ``num_args`` argument symbols, optionally keyed.
    int NNSymbolCompose(SymbolHandle sym,
                        const char* name,
                        nn_uint num_args,
                        const char** keys,
                        SymbolHandle* args);
95
96
# Declarations from the MXNet C API header. These must mirror mxnet/c_api.h.
# Each function returns an int status code where non-zero signals failure;
# wrap calls with CALL() to raise the corresponding Python exception.
cdef extern from "mxnet/c_api.h":
    # List the names of all registered operators.
    int MXListAllOpNames(nn_uint *out_size,
                         const char ***out_array);
    # Like NNGetOpInfo, plus ``key_var_args`` for variadic-input operators.
    int MXSymbolGetAtomicSymbolInfo(OpHandle creator,
                                    const char **name,
                                    const char **description,
                                    nn_uint *num_doc_args,
                                    const char ***arg_names,
                                    const char ***arg_type_infos,
                                    const char ***arg_descriptions,
                                    const char **key_var_args,
                                    const char **return_type);
    # Create a single-operator symbol from ``num_param`` key/value strings.
    int MXSymbolCreateAtomicSymbol(OpHandle op,
                                   nn_uint num_param,
                                   const char **keys,
                                   const char **vals,
                                   SymbolHandle *out);
    # Set a string attribute on a symbol.
    int MXSymbolSetAttr(SymbolHandle symbol,
                        const char* key,
                        const char* value);
    # Run an operator imperatively on NDArray inputs. ``num_outputs`` /
    # ``outputs`` / ``out_stypes`` are pointers; presumably in/out
    # parameters -- confirm against the header.
    int MXImperativeInvokeEx(OpHandle creator,
                             int num_inputs,
                             NDArrayHandle *inputs,
                             int *num_outputs,
                             NDArrayHandle **outputs,
                             int num_params,
                             const char **param_keys,
                             const char **param_vals,
                             const int **out_stypes);
    # Release an NDArray handle.
    int MXNDArrayFree(NDArrayHandle handle);
    # Create a cached op from a symbol and ``num_flags`` key/value flags.
    int MXCreateCachedOpEx(SymbolHandle handle,
                            int num_flags,
                            const char** keys,
                            const char** vals,
                            CachedOpHandle *out);
    # Release a cached op handle.
    int MXFreeCachedOp(CachedOpHandle handle);
    # Invoke a cached op on NDArray inputs.
    int MXInvokeCachedOpEx(CachedOpHandle handle,
                           int num_inputs,
                           NDArrayHandle *inputs,
                           int *num_outputs,
                           NDArrayHandle **outputs,
                           const int **out_stypes);
    # Register a monitor callback on a cached op.
    # NOTE(review): ``handle`` is declared as NDArrayHandle but presumably
    # is a CachedOpHandle; both are void* here, so the mismatch has no
    # compile-time effect -- confirm against mxnet/c_api.h.
    int MXCachedOpRegisterOpHook(NDArrayHandle handle,
                                 CachedOpMonitorCallback callback,
                                 _bool monitor_all);
142