1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
/*!
 * \file c_predict_api.h
 * \brief C predict API of mxnet, contains a minimum API to run prediction.
 *  This file is self-contained and does not depend on any other MXNet files.
 */
25 #ifndef MXNET_C_PREDICT_API_H_
26 #define MXNET_C_PREDICT_API_H_

/* Standard headers supplying uint32_t and (for C callers) bool. */
#include <stdint.h>
#ifndef __cplusplus
#include <stdbool.h>
#endif

28 /*! \brief Inhibit C++ name-mangling for MXNet functions. */
29 #ifdef __cplusplus
30 extern "C" {
31 #endif  // __cplusplus
32 
/*!
 * \brief Decoration placed in front of every public declaration.
 *  On Windows it expands to __declspec(dllexport) when MXNET_EXPORTS is
 *  defined and to __declspec(dllimport) otherwise; everywhere else it
 *  expands to nothing.
 */
#if !defined(_WIN32)
#define MXNET_DLL
#elif defined(MXNET_EXPORTS)
#define MXNET_DLL __declspec(dllexport)
#else
#define MXNET_DLL __declspec(dllimport)
#endif
42 
/*! \brief unsigned integer type used by the API (an alias of uint32_t) */
typedef uint32_t mx_uint;

/*! \brief floating point type used by the API (an alias of float) */
typedef float mx_float;

/*! \brief opaque handle to a Predictor */
typedef void* PredictorHandle;

/*! \brief opaque handle to an NDArray list */
typedef void* NDListHandle;

/*! \brief opaque handle to an NDArray */
typedef void* NDArrayHandle;

/*!
 * \brief callback type used to add monitoring to nodes in the graph.
 *  NOTE(review): the meaning of the three arguments is not documented in
 *  this header; presumably a name, the monitored array and the user pointer
 *  registered together with the callback - confirm against the implementation.
 */
typedef void (*PredMonitorCallback)(const char*, NDArrayHandle, void*);
57 
58 /*!
59  * \brief Get the last error happeneed.
60  * \return The last error happened at the predictor.
61  */
62 MXNET_DLL const char* MXGetLastError();
63 
64 /*!
65  * \brief create a predictor
66  * \param symbol_json_str The JSON string of the symbol.
67  * \param param_bytes The in-memory raw bytes of parameter ndarray file.
68  * \param param_size The size of parameter ndarray file.
69  * \param dev_type The device type, 1: cpu, 2:gpu
70  * \param dev_id The device id of the predictor.
71  * \param num_input_nodes Number of input nodes to the net,
72  *    For feedforward net, this is 1.
73  * \param input_keys The name of input argument.
74  *    For feedforward net, this is {"data"}
75  * \param input_shape_indptr Index pointer of shapes of each input node.
76  *    The length of this array = num_input_nodes + 1.
77  *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
78  * \param input_shape_data A flattened data of shapes of each input node.
79  *    For feedforward net that takes 4 dimensional input, this is the shape data.
80  * \param out The created predictor handle.
81  * \return 0 when success, -1 when failure.
82  */
83 MXNET_DLL int MXPredCreate(const char* symbol_json_str,
84                            const void* param_bytes,
85                            int param_size,
86                            int dev_type, int dev_id,
87                            uint32_t num_input_nodes,
88                            const char** input_keys,
89                            const uint32_t* input_shape_indptr,
90                            const uint32_t* input_shape_data,
91                            PredictorHandle* out);
92 
93 /*!
94  * \brief create a predictor
95  * \param symbol_json_str The JSON string of the symbol.
96  * \param param_bytes The in-memory raw bytes of parameter ndarray file.
97  * \param param_size The size of parameter ndarray file.
98  * \param dev_type The device type, 1: cpu, 2: gpu
99  * \param dev_id The device id of the predictor.
100  * \param num_input_nodes Number of input nodes to the net.
101  *    For feedforward net, this is 1.
102  * \param input_keys The name of the input argument.
103  *    For feedforward net, this is {"data"}
104  * \param input_shape_indptr Index pointer of shapes of each input node.
105  *    The length of this array = num_input_nodes + 1.
106  *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
107  * \param input_shape_data A flattened data of shapes of each input node.
108  *    For feedforward net that takes 4 dimensional input, this is the shape data.
109  * \param num_provided_arg_dtypes
110  *    The length of provided_arg_dtypes.
111  * \param provided_arg_dtype_names
112  *    The provided_arg_dtype_names the names of args for which dtypes are provided.
113  * \param provided_arg_dtypes
114  *    The provided_arg_dtypes the dtype provided
115  * \param out The created predictor handle.
116  * \return 0 when success, -1 when failure.
117  */
118 MXNET_DLL int MXPredCreateEx(const char* symbol_json_str,
119                              const void* param_bytes,
120                              int param_size,
121                              int dev_type, int dev_id,
122                              const uint32_t num_input_nodes,
123                              const char** input_keys,
124                              const uint32_t* input_shape_indptr,
125                              const uint32_t* input_shape_data,
126                              const uint32_t num_provided_arg_dtypes,
127                              const char** provided_arg_dtype_names,
128                              const int* provided_arg_dtypes,
129                              PredictorHandle* out);
130 
131 /*!
132  * \brief create a predictor wich customized outputs
133  * \param symbol_json_str The JSON string of the symbol.
134  * \param param_bytes The in-memory raw bytes of parameter ndarray file.
135  * \param param_size The size of parameter ndarray file.
136  * \param dev_type The device type, 1: cpu, 2:gpu
137  * \param dev_id The device id of the predictor.
138  * \param num_input_nodes Number of input nodes to the net,
139  *    For feedforward net, this is 1.
140  * \param input_keys The name of input argument.
141  *    For feedforward net, this is {"data"}
142  * \param input_shape_indptr Index pointer of shapes of each input node.
143  *    The length of this array = num_input_nodes + 1.
144  *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
145  * \param input_shape_data A flattened data of shapes of each input node.
146  *    For feedforward net that takes 4 dimensional input, this is the shape data.
147  * \param num_output_nodes Number of output nodes to the net,
148  * \param output_keys The name of output argument.
149  *    For example {"global_pool"}
150  * \param out The created predictor handle.
151  * \return 0 when success, -1 when failure.
152  */
153 
154 MXNET_DLL int MXPredCreatePartialOut(const char* symbol_json_str,
155                                      const void* param_bytes,
156                                      int param_size,
157                                      int dev_type, int dev_id,
158                                      uint32_t num_input_nodes,
159                                      const char** input_keys,
160                                      const uint32_t* input_shape_indptr,
161                                      const uint32_t* input_shape_data,
162                                      uint32_t num_output_nodes,
163                                      const char** output_keys,
164                                      PredictorHandle* out);
165 
166 /*!
167  * \brief create predictors for multiple threads. One predictor for a thread.
168  * \param symbol_json_str The JSON string of the symbol.
169  * \param param_bytes The in-memory raw bytes of parameter ndarray file.
170  * \param param_size The size of parameter ndarray file.
171  * \param dev_type The device type, 1: cpu, 2:gpu
172  * \param dev_id The device id of the predictor.
173  * \param num_input_nodes Number of input nodes to the net,
174  *    For feedforward net, this is 1.
175  * \param input_keys The name of input argument.
176  *    For feedforward net, this is {"data"}
177  * \param input_shape_indptr Index pointer of shapes of each input node.
178  *    The length of this array = num_input_nodes + 1.
179  *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
180  * \param input_shape_data A flattened data of shapes of each input node.
181  *    For feedforward net that takes 4 dimensional input, this is the shape data.
182  * \param num_threads The number of threads that we'll run the predictors.
183  * \param out An array of created predictor handles. The array has to be large
184  *   enough to keep `num_threads` predictors.
185  * \return 0 when success, -1 when failure.
186  */
187 MXNET_DLL int MXPredCreateMultiThread(const char* symbol_json_str,
188                                       const void* param_bytes,
189                                       int param_size,
190                                       int dev_type, int dev_id,
191                                       uint32_t num_input_nodes,
192                                       const char** input_keys,
193                                       const uint32_t* input_shape_indptr,
194                                       const uint32_t* input_shape_data,
195                                       int num_threads,
196                                       PredictorHandle* out);
197 
198 /*!
199  * \brief Change the input shape of an existing predictor.
200  * \param num_input_nodes Number of input nodes to the net,
201  *    For feedforward net, this is 1.
202  * \param input_keys The name of input argument.
203  *    For feedforward net, this is {"data"}
204  * \param input_shape_indptr Index pointer of shapes of each input node.
205  *    The length of this array = num_input_nodes + 1.
206  *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
207  * \param input_shape_data A flattened data of shapes of each input node.
208  *    For feedforward net that takes 4 dimensional input, this is the shape data.
209  * \param handle The original predictor handle.
210  * \param out The reshaped predictor handle.
211  * \return 0 when success, -1 when failure.
212  */
213 MXNET_DLL int MXPredReshape(uint32_t num_input_nodes,
214                   const char** input_keys,
215                   const uint32_t* input_shape_indptr,
216                   const uint32_t* input_shape_data,
217                   PredictorHandle handle,
218                   PredictorHandle* out);
219 /*!
220  * \brief Get the shape of output node.
221  *  The returned shape_data and shape_ndim is only valid before next call to MXPred function.
222  * \param handle The handle of the predictor.
223  * \param index The index of output node, set to 0 if there is only one output.
224  * \param shape_data Used to hold pointer to the shape data
225  * \param shape_ndim Used to hold shape dimension.
226  * \return 0 when success, -1 when failure.
227  */
228 MXNET_DLL int MXPredGetOutputShape(PredictorHandle handle,
229                                    uint32_t index,
230                                    uint32_t** shape_data,
231                                    uint32_t* shape_ndim);
232 
233 /*!
234  * \brief Get the dtype of output node.
235  * The returned data type is only valid before next call to MXPred function.
236  * \param handle The handle of the predictor.
237  * \param out_index The index of the output node, set to 0 if there is only one output.
238  * \param out_dtype The dtype of the output node
239  */
240 MXNET_DLL int MXPredGetOutputType(PredictorHandle handle,
241                                   uint32_t out_index,
242                                   int* out_dtype);
243 
244 /*!
245  * \brief Set the input data of predictor.
246  * \param handle The predictor handle.
247  * \param key The name of input node to set.
248  *     For feedforward net, this is "data".
249  * \param data The pointer to the data to be set, with the shape specified in MXPredCreate.
250  * \param size The size of data array, used for safety check.
251  * \return 0 when success, -1 when failure.
252  */
253 MXNET_DLL int MXPredSetInput(PredictorHandle handle,
254                              const char* key,
255                              const float* data,
256                              uint32_t size);
257 /*!
258  * \brief Run a forward pass to get the output.
259  * \param handle The handle of the predictor.
260  * \return 0 when success, -1 when failure.
261  */
262 MXNET_DLL int MXPredForward(PredictorHandle handle);
263 /*!
264  * \brief Run a interactive forward pass to get the output.
265  *  This is helpful for displaying progress of prediction which can be slow.
266  *  User must call PartialForward from step=0, keep increasing it until step_left=0.
267  * \code
268  * int step_left = 1;
269  * for (int step = 0; step_left != 0; ++step) {
270  *    MXPredPartialForward(handle, step, &step_left);
271  *    printf("Current progress [%d/%d]\n", step, step + step_left + 1);
272  * }
273  * \endcode
274  * \param handle The handle of the predictor.
275  * \param step The current step to run forward on.
276  * \param step_left The number of steps left
277  * \return 0 when success, -1 when failure.
278  */
279 MXNET_DLL int MXPredPartialForward(PredictorHandle handle, int step, int* step_left);
280 /*!
281  * \brief Get the output value of prediction.
282  * \param handle The handle of the predictor.
283  * \param index The index of output node, set to 0 if there is only one output.
284  * \param data User allocated data to hold the output.
285  * \param size The size of data array, used for safe checking.
286  * \return 0 when success, -1 when failure.
287  */
288 MXNET_DLL int MXPredGetOutput(PredictorHandle handle,
289                               uint32_t index,
290                               float* data,
291                               uint32_t size);
292 /*!
293  * \brief Free a predictor handle.
294  * \param handle The handle of the predictor.
295  * \return 0 when success, -1 when failure.
296  */
297 MXNET_DLL int MXPredFree(PredictorHandle handle);
298 /*!
299  * \brief Create a NDArray List by loading from ndarray file.
300  *     This can be used to load mean image file.
301  * \param nd_file_bytes The byte contents of nd file to be loaded.
302  * \param nd_file_size The size of the nd file to be loaded.
303  * \param out The out put NDListHandle
304  * \param out_length Length of the list.
305  * \return 0 when success, -1 when failure.
306  */
307 MXNET_DLL int MXNDListCreate(const char* nd_file_bytes,
308                              int nd_file_size,
309                              NDListHandle *out,
310                              uint32_t* out_length);
311 /*!
312  * \brief Get an element from list
313  * \param handle The handle to the NDArray
314  * \param index The index in the list
315  * \param out_key The output key of the item
316  * \param out_data The data region of the item
317  * \param out_shape The shape of the item.
318  * \param out_ndim The number of dimension in the shape.
319  * \return 0 when success, -1 when failure.
320  */
321 MXNET_DLL int MXNDListGet(NDListHandle handle,
322                           uint32_t index,
323                           const char** out_key,
324                           const float** out_data,
325                           const uint32_t** out_shape,
326                           uint32_t* out_ndim);
327 
328 /*!
329  * \brief set a call back to notify the completion of operation and allow for
330  * additional monitoring
331  */
332 MXNET_DLL int MXPredSetMonitorCallback(PredictorHandle handle,
333                                        PredMonitorCallback callback,
334                                        void* callback_handle,
335                                        bool monitor_all);
336 /*!
337  * \brief Free a MXAPINDList
338  * \param handle The handle of the MXAPINDList.
339  * \return 0 when success, -1 when failure.
340  */
341 MXNET_DLL int MXNDListFree(NDListHandle handle);
342 
343 #ifdef __cplusplus
344 }
345 #endif  // __cplusplus
346 
347 #endif  // MXNET_C_PREDICT_API_H_
348