/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file c_predict_api.h
 * \brief C predict API of mxnet, contains a minimum API to run prediction.
 *  This file is self-contained: it depends only on the standard headers
 *  <stdint.h> and (when compiled as C) <stdbool.h>.
 */
#ifndef MXNET_C_PREDICT_API_H_
#define MXNET_C_PREDICT_API_H_

/* uint32_t is used throughout this API; include it here so the header
 * compiles on its own instead of relying on what the includer pulled in. */
#include <stdint.h>
#ifndef __cplusplus
/* `bool` (see MXPredSetMonitorCallback) is built into C++, but C before C23
 * needs <stdbool.h>. Included outside the extern "C" block on purpose. */
#include <stdbool.h>
#endif  // __cplusplus

/*! \brief Inhibit C++ name-mangling for MXNet functions. */
#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

/* On Windows, export the symbols when building the library (MXNET_EXPORTS)
 * and import them when consuming it; other platforms need no decoration. */
#ifdef _WIN32
#ifdef MXNET_EXPORTS
#define MXNET_DLL __declspec(dllexport)
#else
#define MXNET_DLL __declspec(dllimport)
#endif
#else
#define MXNET_DLL
#endif

/*! \brief unsigned 32-bit integer type used by this API */
typedef uint32_t mx_uint;
/*! \brief single-precision floating-point type used by this API */
typedef float mx_float;
/*! \brief handle to Predictor */
typedef void *PredictorHandle;
/*! \brief handle to NDArray list */
typedef void *NDListHandle;
/*! \brief handle to NDArray */
typedef void *NDArrayHandle;
/*!
 * \brief callback used to add monitoring to nodes in the graph.
 * NOTE(review): the arguments appear to be (node name, NDArray handle,
 * user data pointer) -- confirm against the implementation.
 */
typedef void (*PredMonitorCallback)(const char*,
                                    NDArrayHandle,
                                    void*);

/*!
 * \brief Get the last error that happened.
 * \return The last error that happened at the predictor.
 */
MXNET_DLL const char* MXGetLastError(void);

/*!
 * \brief create a predictor
 * \param symbol_json_str The JSON string of the symbol.
 * \param param_bytes The in-memory raw bytes of parameter ndarray file.
 * \param param_size The size of parameter ndarray file.
 * \param dev_type The device type, 1: cpu, 2: gpu
 * \param dev_id The device id of the predictor.
 * \param num_input_nodes Number of input nodes to the net.
 *    For feedforward net, this is 1.
 * \param input_keys The name of input argument.
 *    For feedforward net, this is {"data"}
 * \param input_shape_indptr Index pointer of shapes of each input node.
 *    The length of this array = num_input_nodes + 1.
 *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
 * \param input_shape_data A flattened data of shapes of each input node.
 *    For feedforward net that takes 4 dimensional input, this is the shape data.
 * \param out The created predictor handle.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredCreate(const char* symbol_json_str,
                           const void* param_bytes,
                           int param_size,
                           int dev_type, int dev_id,
                           uint32_t num_input_nodes,
                           const char** input_keys,
                           const uint32_t* input_shape_indptr,
                           const uint32_t* input_shape_data,
                           PredictorHandle* out);

/*!
 * \brief create a predictor, optionally specifying the dtypes of some arguments
 * \param symbol_json_str The JSON string of the symbol.
 * \param param_bytes The in-memory raw bytes of parameter ndarray file.
 * \param param_size The size of parameter ndarray file.
 * \param dev_type The device type, 1: cpu, 2: gpu
 * \param dev_id The device id of the predictor.
 * \param num_input_nodes Number of input nodes to the net.
 *    For feedforward net, this is 1.
 * \param input_keys The name of the input argument.
 *    For feedforward net, this is {"data"}
 * \param input_shape_indptr Index pointer of shapes of each input node.
 *    The length of this array = num_input_nodes + 1.
 *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
 * \param input_shape_data A flattened data of shapes of each input node.
 *    For feedforward net that takes 4 dimensional input, this is the shape data.
 * \param num_provided_arg_dtypes
 *    The length of provided_arg_dtype_names and provided_arg_dtypes.
 * \param provided_arg_dtype_names
 *    The names of the arguments for which dtypes are provided.
 * \param provided_arg_dtypes
 *    The dtypes for the arguments named in provided_arg_dtype_names.
 * \param out The created predictor handle.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredCreateEx(const char* symbol_json_str,
                             const void* param_bytes,
                             int param_size,
                             int dev_type, int dev_id,
                             uint32_t num_input_nodes,
                             const char** input_keys,
                             const uint32_t* input_shape_indptr,
                             const uint32_t* input_shape_data,
                             uint32_t num_provided_arg_dtypes,
                             const char** provided_arg_dtype_names,
                             const int* provided_arg_dtypes,
                             PredictorHandle* out);

/*!
 * \brief create a predictor with customized outputs
 * \param symbol_json_str The JSON string of the symbol.
 * \param param_bytes The in-memory raw bytes of parameter ndarray file.
 * \param param_size The size of parameter ndarray file.
 * \param dev_type The device type, 1: cpu, 2: gpu
 * \param dev_id The device id of the predictor.
 * \param num_input_nodes Number of input nodes to the net.
 *    For feedforward net, this is 1.
 * \param input_keys The name of input argument.
 *    For feedforward net, this is {"data"}
 * \param input_shape_indptr Index pointer of shapes of each input node.
 *    The length of this array = num_input_nodes + 1.
 *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
 * \param input_shape_data A flattened data of shapes of each input node.
 *    For feedforward net that takes 4 dimensional input, this is the shape data.
 * \param num_output_nodes Number of output nodes to the net.
 * \param output_keys The name of output argument.
 *    For example {"global_pool"}
 * \param out The created predictor handle.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredCreatePartialOut(const char* symbol_json_str,
                                     const void* param_bytes,
                                     int param_size,
                                     int dev_type, int dev_id,
                                     uint32_t num_input_nodes,
                                     const char** input_keys,
                                     const uint32_t* input_shape_indptr,
                                     const uint32_t* input_shape_data,
                                     uint32_t num_output_nodes,
                                     const char** output_keys,
                                     PredictorHandle* out);

/*!
 * \brief create predictors for multiple threads. One predictor for a thread.
 * \param symbol_json_str The JSON string of the symbol.
 * \param param_bytes The in-memory raw bytes of parameter ndarray file.
 * \param param_size The size of parameter ndarray file.
 * \param dev_type The device type, 1: cpu, 2: gpu
 * \param dev_id The device id of the predictor.
 * \param num_input_nodes Number of input nodes to the net.
 *    For feedforward net, this is 1.
 * \param input_keys The name of input argument.
 *    For feedforward net, this is {"data"}
 * \param input_shape_indptr Index pointer of shapes of each input node.
 *    The length of this array = num_input_nodes + 1.
 *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
 * \param input_shape_data A flattened data of shapes of each input node.
 *    For feedforward net that takes 4 dimensional input, this is the shape data.
 * \param num_threads The number of threads that we'll run the predictors.
 * \param out An array of created predictor handles. The array has to be large
 *    enough to keep `num_threads` predictors.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredCreateMultiThread(const char* symbol_json_str,
                                      const void* param_bytes,
                                      int param_size,
                                      int dev_type, int dev_id,
                                      uint32_t num_input_nodes,
                                      const char** input_keys,
                                      const uint32_t* input_shape_indptr,
                                      const uint32_t* input_shape_data,
                                      int num_threads,
                                      PredictorHandle* out);

/*!
 * \brief Change the input shape of an existing predictor.
 * \param num_input_nodes Number of input nodes to the net.
 *    For feedforward net, this is 1.
 * \param input_keys The name of input argument.
 *    For feedforward net, this is {"data"}
 * \param input_shape_indptr Index pointer of shapes of each input node.
 *    The length of this array = num_input_nodes + 1.
 *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
 * \param input_shape_data A flattened data of shapes of each input node.
 *    For feedforward net that takes 4 dimensional input, this is the shape data.
 * \param handle The original predictor handle.
 * \param out The reshaped predictor handle.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredReshape(uint32_t num_input_nodes,
                            const char** input_keys,
                            const uint32_t* input_shape_indptr,
                            const uint32_t* input_shape_data,
                            PredictorHandle handle,
                            PredictorHandle* out);

/*!
 * \brief Get the shape of output node.
 *  The returned shape_data and shape_ndim are only valid until the next call
 *  to an MXPred function.
 * \param handle The handle of the predictor.
 * \param index The index of output node, set to 0 if there is only one output.
 * \param shape_data Used to hold pointer to the shape data
 * \param shape_ndim Used to hold shape dimension.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredGetOutputShape(PredictorHandle handle,
                                   uint32_t index,
                                   uint32_t** shape_data,
                                   uint32_t* shape_ndim);

/*!
 * \brief Get the dtype of output node.
 *  The returned data type is only valid until the next call to an MXPred function.
 * \param handle The handle of the predictor.
 * \param out_index The index of the output node, set to 0 if there is only one output.
 * \param out_dtype The dtype of the output node
 * \return 0 when success, -1 when failure (the convention used throughout this API).
 */
MXNET_DLL int MXPredGetOutputType(PredictorHandle handle,
                                  uint32_t out_index,
                                  int* out_dtype);

/*!
 * \brief Set the input data of predictor.
 * \param handle The predictor handle.
 * \param key The name of input node to set.
 *     For feedforward net, this is "data".
 * \param data The pointer to the data to be set, with the shape specified in MXPredCreate.
 * \param size The size of data array, used for safety check.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredSetInput(PredictorHandle handle,
                             const char* key,
                             const float* data,
                             uint32_t size);

/*!
 * \brief Run a forward pass to get the output.
 * \param handle The handle of the predictor.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredForward(PredictorHandle handle);

/*!
 * \brief Run an interactive forward pass to get the output.
 *  This is helpful for displaying progress of prediction which can be slow.
 *  User must call PartialForward from step=0, keep increasing it until step_left=0.
 * \code
 * int step_left = 1;
 * for (int step = 0; step_left != 0; ++step) {
 *    MXPredPartialForward(handle, step, &step_left);
 *    printf("Current progress [%d/%d]\n", step, step + step_left + 1);
 * }
 * \endcode
 * \param handle The handle of the predictor.
 * \param step The current step to run forward on.
 * \param step_left The number of steps left
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredPartialForward(PredictorHandle handle, int step, int* step_left);

/*!
 * \brief Get the output value of prediction.
 * \param handle The handle of the predictor.
 * \param index The index of output node, set to 0 if there is only one output.
 * \param data User allocated data to hold the output.
 * \param size The size of data array, used for safe checking.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredGetOutput(PredictorHandle handle,
                              uint32_t index,
                              float* data,
                              uint32_t size);

/*!
 * \brief Free a predictor handle.
 * \param handle The handle of the predictor.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredFree(PredictorHandle handle);

/*!
 * \brief Create a NDArray List by loading from ndarray file.
 *  This can be used to load mean image file.
 * \param nd_file_bytes The byte contents of nd file to be loaded.
 * \param nd_file_size The size of the nd file to be loaded.
 * \param out The output NDListHandle
 * \param out_length Length of the list.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXNDListCreate(const char* nd_file_bytes,
                             int nd_file_size,
                             NDListHandle *out,
                             uint32_t* out_length);

/*!
 * \brief Get an element from list
 * \param handle The handle to the NDArray list
 * \param index The index in the list
 * \param out_key The output key of the item
 * \param out_data The data region of the item
 * \param out_shape The shape of the item.
 * \param out_ndim The number of dimension in the shape.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXNDListGet(NDListHandle handle,
                          uint32_t index,
                          const char** out_key,
                          const float** out_data,
                          const uint32_t** out_shape,
                          uint32_t* out_ndim);

/*!
 * \brief set a callback to notify the completion of operation and allow for
 *  additional monitoring
 * \param handle The handle of the predictor.
 * \param callback The callback to invoke (see PredMonitorCallback).
 * \param callback_handle User data pointer; presumably passed back as the
 *    callback's void* argument -- NOTE(review): confirm.
 * \param monitor_all Whether to monitor all tensors rather than a default
 *    subset -- NOTE(review): confirm exact semantics against the implementation.
 * \return 0 when success, -1 when failure (the convention used throughout this API).
 */
MXNET_DLL int MXPredSetMonitorCallback(PredictorHandle handle,
                                       PredMonitorCallback callback,
                                       void* callback_handle,
                                       bool monitor_all);

/*!
 * \brief Free an NDArray list
 * \param handle The handle of the NDArray list.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXNDListFree(NDListHandle handle);

#ifdef __cplusplus
}
#endif  // __cplusplus

#endif  // MXNET_C_PREDICT_API_H_