1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file c_api.h
22  * \brief C API of mxnet
23  */
24 #ifndef MXNET_C_API_H_
25 #define MXNET_C_API_H_
26 
27 /*! \brief Inhibit C++ name-mangling for MXNet functions. */
28 #ifdef __cplusplus
29 extern "C" {
30 #endif  // __cplusplus
31 
/*!
 * \brief Default-argument helper for prototypes.
 *  In C++ `DEFAULT(x)` expands to `= x`, giving the parameter a default
 *  value; in C it expands to nothing, because C has no default arguments.
 */
#ifndef __cplusplus
#define DEFAULT(x)
#else
#define DEFAULT(x) = x
#endif  // __cplusplus
38 
39 #include <stdint.h>
40 
41 #include <stdint.h>
42 #include <stddef.h>
43 #include <stdbool.h>
44 
/*!
 * \brief Symbol-visibility prefix for every exported C API function.
 *  On Windows it becomes __declspec(dllexport) while the library itself is
 *  built (MXNET_EXPORTS defined) and __declspec(dllimport) for consumers.
 *  On all other platforms it expands to nothing.
 */
#if defined(_WIN32) && defined(MXNET_EXPORTS)
#define MXNET_DLL __declspec(dllexport)
#elif defined(_WIN32)
#define MXNET_DLL __declspec(dllimport)
#else
#define MXNET_DLL
#endif
55 
/*! \brief alias for a 32-bit unsigned integer (uint32_t) */
typedef uint32_t mx_uint;
/*! \brief alias for float */
typedef float mx_float;
/*! \brief signed 64-bit data type used to store a dimension size */
typedef int64_t dim_t;
// All the handles below are plain (void *) pointers. They are cast
// internally to the specific pointer types; these typedefs exist mainly
// for readability.
/*! \brief handle to NDArray */
typedef void *NDArrayHandle;
/*! \brief handle to an mxnet NDArray function that changes NDArray */
typedef const void *FunctionHandle;
/*! \brief handle to a function that takes param and creates symbol */
typedef void *AtomicSymbolCreator;
/*! \brief handle to cached operator */
typedef void *CachedOpHandle;
/*! \brief handle to a symbol that can be bound as operator */
typedef void *SymbolHandle;
/*! \brief handle to an AtomicSymbol */
typedef void *AtomicSymbolHandle;
/*! \brief handle to an Executor */
typedef void *ExecutorHandle;
/*! \brief handle to a data iterator creator */
typedef void *DataIterCreator;
/*! \brief handle to a DataIterator */
typedef void *DataIterHandle;
/*! \brief handle to KVStore */
typedef void *KVStoreHandle;
/*! \brief handle to RecordIO */
typedef void *RecordIOHandle;
/*! \brief handle to MXRtc */
typedef void *RtcHandle;
/*! \brief handle to rtc cuda module */
typedef void *CudaModuleHandle;
/*! \brief handle to rtc cuda kernel */
typedef void *CudaKernelHandle;
/*! \brief handle to a Profile object (domain, duration, counter, etc.) */
typedef void *ProfileHandle;
/*! \brief handle to DLManagedTensor */
typedef void *DLManagedTensorHandle;
/*! \brief handle to Context (read-only) */
typedef const void *ContextHandle;
/*! \brief handle to Engine FnProperty (read-only) */
typedef const void *EngineFnPropertyHandle;
/*! \brief handle to Engine VarHandle */
typedef void *EngineVarHandle;

/*! \brief Engine asynchronous operation */
typedef void (*EngineAsyncFunc)(void*, void*, void*);
/*! \brief Engine synchronous operation */
typedef void (*EngineSyncFunc)(void*, void*);
/*! \brief Callback to free the param for EngineAsyncFunc/EngineSyncFunc */
typedef void (*EngineFuncParamDeleter)(void*);
/*!
 * \brief Monitor callback at executor level. It receives a C string
 *  (presumably the monitored output's name), an NDArray handle and an opaque
 *  void* (presumably user data registered with the callback; confirm).
 */
typedef void (*ExecutorMonitorCallback)(const char*,
                                        NDArrayHandle,
                                        void*);
/*! \brief Monitor callback called at operator level for cached op */
typedef void (*CachedOpMonitorCallback)(const char*,
                                        const char*,
                                        NDArrayHandle);
117 
118 
/*!
 * \brief Table of callbacks describing a native operator: forward,
 *  backward, shape inference, and output/argument listing.
 *  Every callback takes a trailing void* argument. The p_* members hold the
 *  payload pointers passed through it.
 *  NOTE: field order and types define this struct's memory layout, which is
 *  shared across the C API boundary, so do not reorder or retype fields.
 */
struct NativeOpInfo {
  void (*forward)(int, float**, int*, unsigned**, int*, void*);
  void (*backward)(int, float**, int*, unsigned**, int*, void*);
  void (*infer_shape)(int, int*, unsigned**, void*);
  void (*list_outputs)(char***, void*);
  void (*list_arguments)(char***, void*);
  // all functions also pass a payload void* pointer;
  // presumably p_<name> is the payload handed to callback <name>
  void* p_forward;
  void* p_backward;
  void* p_infer_shape;
  void* p_list_outputs;
  void* p_list_arguments;
};
132 
/*!
 * \brief Table of callbacks describing an NDArray-based operator: forward,
 *  backward, shape inference, output/argument listing and backward
 *  dependency declaration. Each callback returns bool.
 *  Every callback takes a trailing void* argument. The p_* members hold the
 *  payload pointers passed through it.
 *  NOTE: field order and types define this struct's memory layout, which is
 *  shared across the C API boundary, so do not reorder or retype fields.
 */
struct NDArrayOpInfo {
  bool (*forward)(int, void**, int*, void*);
  bool (*backward)(int, void**, int*, void*);
  bool (*infer_shape)(int, int*, unsigned**, void*);
  bool (*list_outputs)(char***, void*);
  bool (*list_arguments)(char***, void*);
  bool (*declare_backward_dependency)(const int*, const int*, const int*,
                                      int*, int**, void*);
  // all functions also pass a payload void* pointer;
  // presumably p_<name> is the payload handed to callback <name>
  void* p_forward;
  void* p_backward;
  void* p_infer_shape;
  void* p_list_outputs;
  void* p_list_arguments;
  void* p_declare_backward_dependency;
};
149 
/*! \brief Generic callback signature: no arguments, returns int. */
typedef int (*MXGenericCallback)(void);

/*!
 * \brief A list of callbacks together with opaque context pointers.
 */
struct MXCallbackList {
  /*! \brief number of entries in callbacks */
  int num_callbacks;
  /*! \brief array of callback function pointers */
  MXGenericCallback *callbacks;
  /*! \brief array of opaque context pointers (presumably one per callback) */
  void **contexts;
};
157 
/*! \brief A named runtime feature and whether it is enabled (see MXLibInfoFeatures). */
struct LibFeature {
  /*! \brief feature name */
  const char* name;
  /*! \brief true when the feature is enabled */
  bool enabled;
};
162 
/*!
 * \brief Identifiers for the callbacks of a custom operator.
 *  The values are spelled out because they are visible to C API callers
 *  and must stay stable.
 */
enum CustomOpCallbacks {
  kCustomOpDelete = 0,
  kCustomOpForward = 1,
  kCustomOpBackward = 2
};
168 
/*!
 * \brief Identifiers for the callbacks of a custom operator property.
 *  The values are spelled out because they are visible to C API callers
 *  and must stay stable.
 */
enum CustomOpPropCallbacks {
  kCustomOpPropDelete = 0,
  kCustomOpPropListArguments = 1,
  kCustomOpPropListOutputs = 2,
  kCustomOpPropListAuxiliaryStates = 3,
  kCustomOpPropInferShape = 4,
  kCustomOpPropDeclareBackwardDependency = 5,
  kCustomOpPropCreateOperator = 6,
  kCustomOpPropInferType = 7,
  kCustomOpPropInferStorageType = 8,
  kCustomOpPropBackwardInferStorageType = 9
};
181 
182 
/*!
 * \brief Custom-op compute callback, presumably used for both forward and
 *  backward (per the "FB" suffix). Parameter meanings follow the inline
 *  comments. Parameter names are kept as comments rather than identifiers,
 *  which avoids clashes with user macros in this public header.
 */
typedef int (*CustomOpFBFunc)(int /*size*/, void** /*ptrs*/, int* /*tags*/,
                              const int* /*reqs*/, const int /*is_train*/,
                              void* /*state*/);
/*! \brief Custom-op callback taking only the opaque state pointer (presumably the delete hook). */
typedef int (*CustomOpDelFunc)(void* /*state*/);
/*! \brief Custom-op listing callback; presumably returns a list of names through args. */
typedef int (*CustomOpListFunc)(char*** /*args*/, void* /*state*/);
/*! \brief Custom-op shape-inference callback over num_input entries. */
typedef int (*CustomOpInferShapeFunc)(int /*num_input*/, int* /*ndims*/,
                                      int** /*shapes*/, void* /*state*/);
/*! \brief Custom-op storage-type inference callback. */
typedef int (*CustomOpInferStorageTypeFunc)(int /*num_input*/, int* /*stypes*/, void* /*state*/);
/*! \brief Custom-op storage-type inference callback for the backward pass. */
typedef int (*CustomOpBackwardInferStorageTypeFunc)(int /*num_input*/,
                                                    int * /*stypes*/,
                                                    int * /*tags*/,
                                                    void * /*state*/);
/*! \brief Custom-op data-type inference callback. */
typedef int (*CustomOpInferTypeFunc)(int /*num_input*/, int* /*types*/, void* /*state*/);
/*! \brief Custom-op callback declaring backward dependencies (count in num_deps, list in rdeps). */
typedef int (*CustomOpBwdDepFunc)(const int* /*out_grad*/, const int* /*in_data*/,
                                  const int* /*out_data*/, int* /*num_deps*/,
                                  int** /*rdeps*/, void* /*state*/);
/*!
 * \brief Custom-op operator-creation callback: receives the context string and
 *  input shapes/dtypes, and presumably fills ret with the operator's callbacks.
 */
typedef int (*CustomOpCreateFunc)(const char* /*ctx*/, int /*num_inputs*/,
                                  unsigned** /*shapes*/, const int* /*ndims*/,
                                  const int* /*dtypes*/, struct MXCallbackList* /*ret*/,
                                  void* /*state*/);
/*!
 * \brief Creator of a custom-op property from its op_type and keyword
 *  arguments (keys/values); presumably fills ret with the property's callbacks.
 */
typedef int (*CustomOpPropCreator)(const char* /*op_type*/, const int /*num_kwargs*/,
                                   const char** /*keys*/, const char** /*values*/,
                                   struct MXCallbackList* /*ret*/);
206 
207 
/*!
 * \brief Identifiers for the callbacks of a custom function.
 *  The values are spelled out because they are visible to C API callers
 *  and must stay stable.
 */
enum CustomFunctionCallbacks {
  kCustomFunctionBackward = 0,
  kCustomFunctionDelete = 1
};
212 
/*!
 * \brief Custom-function backward callback (per its name). It receives the
 *  counts of output and input gradients plus the pointer array described by
 *  the inline comments.
 */
typedef int (*CustomFunctionBwdFunc)(int /*num_ograds*/, int /*num_igrads*/, void** /*ptrs*/,
                                     const int* /*reqs*/, const int /*is_train*/,
                                     void* /*state*/);
/*! \brief Custom-function callback taking only the opaque state pointer (presumably the delete hook). */
typedef int (*CustomFunctionDelFunc)(void* /*state*/);
217 
218 /*!
219  * \brief return str message of the last error
220  *  all function in this file will return 0 when success
221  *  and -1 when an error occured,
222  *  MXGetLastError can be called to retrieve the error
223  *
224  *  this function is threadsafe and can be called by different thread
225  *  \return error info
226  */
227 MXNET_DLL const char *MXGetLastError();
228 
229 //-------------------------------------
230 // Part 0: Global State setups
231 //-------------------------------------
232 
233 /*!
234  * \brief Load library dynamically
235  * \param path to the library .so file
236  * \param verbose 0 for quiet, 1 for verbose
237  * \return 0 when success, -1 when failure happens.
238  */
239 MXNET_DLL int MXLoadLib(const char *path, unsigned verbose);
240 
241 /*!
242  * \brief Get list of features supported on the runtime
243  * \param libFeature pointer to array of LibFeature
244  * \param size of the array
245  * \return 0 when success, -1 when failure happens.
246  */
247 MXNET_DLL int MXLibInfoFeatures(const struct LibFeature **libFeature, size_t *size);
248 
249 /*!
250  * \brief Seed all global random number generators in mxnet.
251  * \param seed the random number seed.
252  * \return 0 when success, -1 when failure happens.
253  */
254 MXNET_DLL int MXRandomSeed(int seed);
255 
256 /*!
257  * \brief Seed the global random number generator of the given device.
258  * \param seed the random number seed.
259  * \return 0 when success, -1 when failure happens.
260  */
261 MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id);
262 
263 /*!
264  * \brief Notify the engine about a shutdown,
265  *  This can help engine to print less messages into display.
266  *
267  *  User do not have to call this function.
268  * \return 0 when success, -1 when failure happens.
269  */
270 MXNET_DLL int MXNotifyShutdown();
271 
272 /*!
273  * \brief Set up configuration of profiler for the process passed as profile_process in keys
274  * \param num_params Number of parameters
275  * \param keys array of parameter keys
276  * \param vals array of parameter values
277  * \param kvstoreHandle handle to kvstore
278  * \return 0 when success, -1 when failure happens.
279  */
280 MXNET_DLL int MXSetProcessProfilerConfig(int num_params, const char* const* keys,
281                                          const char* const* vals,
282                                          KVStoreHandle kvstoreHandle);
283 
284 /*!
285  * \brief Set up configuration of profiler for worker/current process
286  * \param num_params Number of parameters
287  * \param keys array of parameter keys
288  * \param vals array of parameter values
289  * \return 0 when success, -1 when failure happens.
290  */
291 MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals);
292 
293 /*!
294  * \brief Set up state of profiler for either worker or server process
295  * \param state indicate the working state of profiler,
296  *  profiler not running when state == 0,
297  *  profiler running when state == 1
298  * \param profile_process an int,
299  * when 0 command is for worker/current process,
300  * when 1 command is for server process
301  * \param kvstoreHandle handle to kvstore, needed for server process profiling
302  * \return 0 when success, -1 when failure happens.
303  */
304 MXNET_DLL int MXSetProcessProfilerState(int state, int profile_process,
305                                         KVStoreHandle kvStoreHandle);
306 
307 /*!
308  * \brief Set up state of profiler for current process
309  * \param state indicate the working state of profiler,
310  *  profiler not running when state == 0,
311  *  profiler running when state == 1
312  * \return 0 when success, -1 when failure happens.
313  */
314 MXNET_DLL int MXSetProfilerState(int state);
315 
316 /*!
317  * \brief Save profile and stop profiler
318  * \param finished true if stat output should stop after this point
319  * \param profile_process an int,
320  * when 0 command is for worker/current process,
321  * when 1 command is for server process
322  * \param kvstoreHandle handle to kvstore
323  * \return 0 when success, -1 when failure happens.
324  */
325 MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle);
326 
327 
328 /*!
329  * \brief Save profile and stop profiler for worker/current process
330  * \param finished true if stat output should stop after this point
331  * \return 0 when success, -1 when failure happens.
332  */
333 MXNET_DLL int MXDumpProfile(int finished);
334 
335 
336 /*!
337  * \brief Deprecated, use MXAggregateProfileStatsPrintEx instead.
338  * \param out_str Will receive a pointer to the output string
339  * \param reset Clear the aggregate stats after printing
340  * \return 0 when success, -1 when failure happens.
341  * \note
342  */
343 MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);
344 
345 /*!
346  * \brief Print sorted aggregate stats to the a string
347  *        How aggregate stats are stored will not change
348  * \param out_str will receive a pointer to the output string
349  * \param reset clear the aggregate stats after printing
350  * \param format whether to return in tabular or json format
351  * \param sort_by sort by total, avg, min, max, or count
352  * \param ascending whether to sort ascendingly
353  * \return 0 when success, -1 when failure happens.
354  * \note
355  */
356 MXNET_DLL int MXAggregateProfileStatsPrintEx(const char **out_str, int reset, int format,
357                                             int sort_by, int ascending);
358 
359 /*!
360  * \brief Pause profiler tuning collection
361  * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
362  * \param profile_process integer which denotes whether to process worker or server process
363  * \param kvstoreHandle handle to kvstore
364  * \return 0 when success, -1 when failure happens.
365  * \note pausing and resuming is global and not recursive
366  */
367 MXNET_DLL int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle);
368 
369 /*!
370  * \brief Pause profiler tuning collection for worker/current process
371  * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
372  * \return 0 when success, -1 when failure happens.
373  * \note pausing and resuming is global and not recursive
374  */
375 MXNET_DLL int MXProfilePause(int paused);
376 
377 /*!
378  * \brief Create profiling domain
379  * \param domain String representing the domain name to create
380  * \param out Return domain object
381  * \return 0 when success, -1 when failure happens.
382  */
383 MXNET_DLL int MXProfileCreateDomain(const char *domain, ProfileHandle *out);
384 
385 /*!
386  * \brief Create profile task
387  * \param name Name of the task
388  * \param domain Domain of the task
389  * \param out Output handle
390  * \return 0 when success, -1 when failure happens.
391  */
392 MXNET_DLL int MXProfileCreateTask(ProfileHandle domain,
393                                   const char *task_name,
394                                   ProfileHandle *out);
395 
396 /*!
397  * \brief Create profile frame
398  * \param name Name of the frame
399  * \param domain Domain of the frame
400  * \param out Output handle
401  * \return 0 when success, -1 when failure happens.
402  */
403 MXNET_DLL int MXProfileCreateFrame(ProfileHandle domain,
404                                    const char *frame_name,
405                                    ProfileHandle *out);
406 
407 /*!
408  * \brief Create profile event
409  * \param name Name of the event
410  * \param out Output handle
411  * \return 0 when success, -1 when failure happens.
412  */
413 MXNET_DLL int MXProfileCreateEvent(const char *event_name, ProfileHandle *out);
414 
415 /*!
416  * \brief Create profile counter
417  * \param name Name of the counter
418  * \param domain Domain of the counter
419  * \param out Output handle
420  * \return 0 when success, -1 when failure happens.
421  */
422 MXNET_DLL int MXProfileCreateCounter(ProfileHandle domain,
423                                      const char *counter_name,
424                                      ProfileHandle *out);
425 
426 /*!
427  * \brief Destroy a frame
428  * \param frame_handle Handle to frame to destroy
429  * \return 0 when success, -1 when failure happens.
430  */
431 MXNET_DLL int MXProfileDestroyHandle(ProfileHandle frame_handle);
432 
433 /*!
434  * \brief Start timing the duration of a profile duration object such as an event, task or frame
435  * \param duration_handle handle to the duration object
436  * \return 0 when success, -1 when failure happens.
437  */
438 MXNET_DLL int MXProfileDurationStart(ProfileHandle duration_handle);
439 
440 /*!
441  * \brief Stop timing the duration of a profile duration object such as an event, task or frame
442  * \param duration_handle handle to the duration object
443  * \return 0 when success, -1 when failure happens.
444  */
445 MXNET_DLL int MXProfileDurationStop(ProfileHandle duration_handle);
446 
447 /*!
448  * \brief Set a counter, given its handle
449  * \param counter_handle Handle to counter to set
450  * \param value Value to set the counter to (64-bit unsigned integer)
451  * \return 0 when success, -1 when failure happens.
452  */
453 MXNET_DLL int MXProfileSetCounter(ProfileHandle counter_handle, uint64_t value);
454 
455 /*!
456  * \brief Adjust a counter by the given amount, given its handle
457  * \param counter_handle Handle to counter to adjust
458  * \param value Value to adjust the counter by (64-bit signed integer)
459  * \return 0 when success, -1 when failure happens.
460  */
461 MXNET_DLL int MXProfileAdjustCounter(ProfileHandle counter_handle, int64_t value);
462 
463 /*!
464  * \brief Mark a single instant in time
465  * \param domain Domain of the marker
466  * \param instant_marker_name Name of the marker
467  * \param scope Scope of marker ('global', 'process', 'thread', 'task', 'marker')
468  * \return 0 when success, -1 when failure happens.
469  */
470 MXNET_DLL int MXProfileSetMarker(ProfileHandle domain,
471                                  const char *instant_marker_name,
472                                  const char *scope);
473 
474 /*!
475  * \brief Set the number of OMP threads to use
476  * \param thread_num Number of OMP threads desired
477  * \return 0 when success, -1 when failure happens.
478  */
479 MXNET_DLL int MXSetNumOMPThreads(int thread_num);
480 
481 /*!
482  * \brief set bulk execution limit
483  * \param bulk_size new bulk_size
484  * \param prev_bulk_size previous bulk_size
485  */
486 MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size);
487 
488 /*!
489  * \brief Get the number of GPUs.
490  * \param pointer to int that will hold the number of GPUs available.
491  * \return 0 when success, -1 when failure happens.
492  */
493 MXNET_DLL int MXGetGPUCount(int* out);
494 
495 /*!
496  * \brief get the free and total available memory on a GPU
497  *  Note: Deprecated, use MXGetGPUMemoryInformation64 instead.
498  * \param dev the GPU number to query
499  * \param free_mem pointer to the integer holding free GPU memory
500  * \param total_mem pointer to the integer holding total GPU memory
501  * \return 0 when success, -1 when failure happens
502  */
503 MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem);
504 
505 /*!
506  * \brief get the free and total available memory on a GPU
507  * \param dev the GPU number to query
508  * \param free_mem pointer to the uint64_t holding free GPU memory
509  * \param total_mem pointer to the uint64_t holding total GPU memory
510  * \return 0 when success, -1 when failure happens
511  */
512 MXNET_DLL int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem, uint64_t *total_mem);
513 
514 /*!
515  * \brief get the MXNet library version as an integer
516  * \param pointer to the integer holding the version number
517  * \return 0 when success, -1 when failure happens
518  */
519 MXNET_DLL int MXGetVersion(int *out);
520 
521 /*!
522  * \brief Load TVM operator from the binary library
523  * \param libpath TVM operators lib file
524  * \return 0 when success, -1 when failure happens
525  */
526 #if MXNET_USE_TVM_OP
527 MXNET_DLL int MXLoadTVMOp(const char *libpath);
528 
529 struct OtherOptionEntity {
530   int val;
531 };
532 
533 struct OtherOptionSpace {
534   OtherOptionEntity* entities;
535   int entities_size;
536 };
537 
538 struct ConfigSpace {
539   int entity_map_size;
540   char** entity_map_key;
541   OtherOptionEntity* entity_map_val;
542   int space_map_size;
543   char** space_map_key;
544   OtherOptionSpace* space_map_val;
545 };
546 
547 typedef struct ConfigSpaces {
548   int spaces_size;
549   char** spaces_key;
550   ConfigSpace* spaces_val;
551 } ConfigSpaces;
552 
553 MXNET_DLL int MXLoadTVMConfig(ConfigSpaces config);
554 #endif  // MXNET_USE_TVM_OP
555 
556 
557 //-------------------------------------
558 // Part 1: NDArray creation and deletion
559 //-------------------------------------
560 /*!
561  * \brief create a NDArray handle that is not initialized
562  *  can be used to pass in as mutate variables
563  *  to hold the result of NDArray
564  * \param out the returning handle
565  * \return 0 when success, -1 when failure happens
566  */
567 MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out);
568 /*!
569  * \brief create a NDArray with specified shape
570  * \param shape the pointer to the shape
571  * \param ndim the dimension of the shape
572  * \param dev_type device type, specify device we want to take
573  * \param dev_id the device id of the specific device
574  * \param delay_alloc whether to delay allocation until
575  *    the narray is first mutated
576  * \param out the returning handle
577  * \return 0 when success, -1 when failure happens
578  */
579 MXNET_DLL int MXNDArrayCreate(const uint32_t *shape,
580                               uint32_t ndim,
581                               int dev_type,
582                               int dev_id,
583                               int delay_alloc,
584                               NDArrayHandle *out);
585 
586 /*!
587  * \brief create a NDArray with specified shape and data type
588  *  This api is available when MXNet is built with flag
589  *  USE_INT64_TENSOR_SIZE=0 (by default)
590  * \param shape the pointer to the shape
591  * \param ndim the dimension of the shape
592  * \param dev_type device type, specify device we want to take
593  * \param dev_id the device id of the specific device
594  * \param delay_alloc whether to delay allocation until
595  *    the narray is first mutated
596  * \param dtype data type of created array
597  * \param out the returning handle
598  * \return 0 when success, -1 when failure happens
599  */
600 MXNET_DLL int MXNDArrayCreateEx(const uint32_t *shape,
601                                 uint32_t ndim,
602                                 int dev_type,
603                                 int dev_id,
604                                 int delay_alloc,
605                                 int dtype,
606                                 NDArrayHandle *out);
607 
608 /*!
609  * \brief create a NDArray with specified shape and data type
610  *  This api is available when MXNet is built with flag
611  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
612  * \param shape the pointer to int64_t shape
613  * \param ndim the dimension of the shape
614  * \param dev_type device type, specify device we want to take
615  * \param dev_id the device id of the specific device
616  * \param delay_alloc whether to delay allocation until
617  *    the narray is first mutated
618  * \param dtype data type of created array
619  * \param out the returning handle
620  * \return 0 when success, -1 when failure happens
621  */
622 MXNET_DLL int MXNDArrayCreateEx64(const int64_t *shape,
623                                   int ndim,
624                                   int dev_type,
625                                   int dev_id,
626                                   int delay_alloc,
627                                   int dtype,
628                                   NDArrayHandle *out);
629 
630 /*!
631  * \brief create an empty sparse NDArray with specified shape and data type
632  *  This api is available when MXNet is built with flag
633  *  USE_INT64_TENSOR_SIZE=0 (by default)
634  * \param storage_type the storage type of the ndarray
635  * \param shape the pointer to the shape
636  * \param ndim the dimension of the shape
637  * \param dev_type device type, specify device we want to take
638  * \param dev_id the device id of the specific device
639  * \param delay_alloc whether to delay allocation until
640  *        the narray is first mutated
641  * \param dtype data type of created array
642  * \param num_aux the number of aux data to support this ndarray
643  * \param aux_type data type of the aux data for the created array
644  * \param aux_ndims the dimension of the shapes of aux data
645  * \param aux_shape the shapes of aux data
646  * \param out the returning handle
647  * \return 0 when success, -1 when failure happens
648  */
649 MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
650                                       const uint32_t *shape,
651                                       uint32_t ndim,
652                                       int dev_type,
653                                       int dev_id,
654                                       int delay_alloc,
655                                       int dtype,
656                                       uint32_t num_aux,
657                                       int *aux_type,
658                                       uint32_t *aux_ndims,
659                                       const uint32_t *aux_shape,
660                                       NDArrayHandle *out);
661 
662 /*!
663  * \brief create an empty sparse NDArray with specified shape and data type
664  *  This api is available when MXNet is built with flag
665  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
666  * \param storage_type the storage type of the ndarray
667  * \param shape the pointer to the shape
668  * \param ndim the dimension of the shape
669  * \param dev_type device type, specify device we want to take
670  * \param dev_id the device id of the specific device
671  * \param delay_alloc whether to delay allocation until
672  *        the narray is first mutated
673  * \param dtype data type of created array
674  * \param num_aux the number of aux data to support this ndarray
675  * \param aux_type data type of the aux data for the created array
676  * \param aux_ndims the dimension of the shapes of aux data
677  * \param aux_shape the shapes of aux data
678  * \param out the returning handle
679  * \return 0 when success, -1 when failure happens
680  */
681 MXNET_DLL int MXNDArrayCreateSparseEx64(int storage_type,
682                                         const int64_t *shape,
683                                         int ndim,
684                                         int dev_type,
685                                         int dev_id,
686                                         int delay_alloc,
687                                         int dtype,
688                                         uint32_t num_aux,
689                                         int *aux_type,
690                                         int *aux_ndims,
691                                         const int64_t *aux_shape,
692                                         NDArrayHandle *out);
693 
694 /*!
695  * \brief create a NDArray handle that is loaded from raw bytes.
696  * \param buf the head of the raw bytes
697  * \param size size of the raw bytes
698  * \param out the returning handle
699  * \return 0 when success, -1 when failure happens
700  */
701 MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf,
702                                         size_t size,
703                                         NDArrayHandle *out);
704 /*!
705  * \brief save the NDArray into raw bytes.
706  * \param handle the NDArray handle
707  * \param out_size size of the raw bytes
708  * \param out_buf the head of returning memory bytes.
709  * \return 0 when success, -1 when failure happens
710  */
711 MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle,
712                                     size_t *out_size,
713                                     const char **out_buf);
714 /*!
715  * \brief Save list of narray into the file.
716  * \param fname name of the file.
717  * \param num_args number of arguments to save.
718  * \param args the array of NDArrayHandles to be saved.
719  * \param keys the name of the NDArray, optional, can be NULL
720  * \return 0 when success, -1 when failure happens
721  */
722 MXNET_DLL int MXNDArraySave(const char* fname,
723                             uint32_t num_args,
724                             NDArrayHandle* args,
725                             const char** keys);
726 /*!
727  * \brief Load list of narray from the file.
728  * \param fname name of the file.
729  * \param out_size number of narray loaded.
730  * \param out_arr head of the returning narray handles.
731  * \param out_name_size size of output name arrray.
732  * \param out_names the names of returning NDArrays, can be NULL
733  * \return 0 when success, -1 when failure happens
734  */
735 MXNET_DLL int MXNDArrayLoad(const char* fname,
736                             uint32_t *out_size,
737                             NDArrayHandle** out_arr,
738                             uint32_t *out_name_size,
739                             const char*** out_names);
740 
741 /*!
742  * \brief Load list / dictionary of narrays from file content loaded into memory.
743  * This will load a list of ndarrays in a similar
744  * manner to MXNDArrayLoad, however, it loads from
745  * buffer containing the contents of a file, rather than
746  * from a specified file.
747  * \param ndarray_buffer pointer to the start of the ndarray file content
748  * \param size size of the file
749  * \param out_size number of narray loaded.
750  * \param out_arr head of the returning narray handles.
751  * \param out_name_size size of output name arrray.
752  * \param out_names the names of returning NDArrays, can be NULL
753  * \return 0 when success, -1 when failure happens
754  */
755 MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
756                                       size_t size,
757                                       uint32_t *out_size,
758                                       NDArrayHandle** out_arr,
759                                       uint32_t *out_name_size,
760                                       const char*** out_names);
761 
762 /*!
763  * \brief Perform a synchronize copy from a contiguous CPU memory region.
764  *
765  *  This function will call WaitToWrite before the copy is performed.
766  *  This is useful to copy data from existing memory region that are
767  *  not wrapped by NDArray(thus dependency not being tracked).
768  *
769  * \param handle the NDArray handle
770  * \param data the data source to copy from.
771  * \param size the memory size we want to copy from.
772  */
773 MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle,
774                                        const void *data,
775                                        size_t size);
776 /*!
777  * \brief Perform a synchronize copyto a contiguous CPU memory region.
778  *
779  *  This function will call WaitToRead before the copy is performed.
780  *  This is useful to copy data from existing memory region that are
781  *  not wrapped by NDArray(thus dependency not being tracked).
782  *
783  * \param handle the NDArray handle
784  * \param data the data source to copy into.
785  * \param size the memory size we want to copy into.
786  */
787 MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle,
788                                      void *data,
789                                      size_t size);
790 
791 /*!
792  * \brief Copy src.data() to dst.data() if i = -1, else dst.aux_data(i) if i >= 0
793  * This function blocks. Do not use it in performance critical code.
794  * \param handle_dst handle of a dst ndarray whose data/aux_data has been allocated
795  * \param handle_src handle of a src ndarray which has default storage type
796  * \param i dst data blob indicator
797  */
798 MXNET_DLL int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst,
799                                            const NDArrayHandle handle_src,
800                                            const int i);
801 
802 /*!
803  * \brief check whether the NDArray format is valid
804  * \param full_check if `True`, rigorous check, O(N) operations
805  *    Otherwise basic check, O(1) operations
806  */
807 MXNET_DLL int MXNDArraySyncCheckFormat(NDArrayHandle handle, const bool full_check);
808 
809 /*!
810  * \brief Wait until all the pending writes with respect NDArray are finished.
811  *  Always call this before read data out synchronizely.
812  * \param handle the NDArray handle
813  * \return 0 when success, -1 when failure happens
814  */
815 MXNET_DLL int MXNDArrayWaitToRead(NDArrayHandle handle);
816 
817 /*!
818  * \brief Wait until all the pending read/write with respect NDArray are finished.
819  *  Always call this before write data into NDArray synchronizely.
820  * \param handle the NDArray handle
821  * \return 0 when success, -1 when failure happens
822  */
823 MXNET_DLL int MXNDArrayWaitToWrite(NDArrayHandle handle);
824 
825 /*!
826  * \brief wait until all delayed operations in
827  *   the system is completed
828  * \return 0 when success, -1 when failure happens
829  */
830 MXNET_DLL int MXNDArrayWaitAll();
831 
832 /*!
833  * \brief free the narray handle
834  * \param handle the handle to be freed
835  * \return 0 when success, -1 when failure happens
836  */
837 MXNET_DLL int MXNDArrayFree(NDArrayHandle handle);
838 
839 /*!
840  * \brief Slice the NDArray along axis 0.
841  *  This api is available when MXNet is built with flag
842  *  USE_INT64_TENSOR_SIZE=0 (by default)
843  * \param handle the handle to the NDArray
844  * \param slice_begin The beginning index of slice
845  * \param slice_end The ending index of slice
846  * \param out The NDArrayHandle of sliced NDArray
847  * \return 0 when success, -1 when failure happens
848  */
849 MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
850                              uint32_t slice_begin,
851                              uint32_t slice_end,
852                              NDArrayHandle *out);
853 
854 /*!
855  * \brief Slice the NDArray along axis 0.
856  *  This api is available when MXNet is built with flag
857  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
858  * \param handle the handle to the NDArray
859  * \param slice_begin The beginning index of slice
860  * \param slice_end The ending index of slice
861  * \param out The NDArrayHandle of sliced NDArray
862  * \return 0 when success, -1 when failure happens
863  */
864 MXNET_DLL int MXNDArraySlice64(NDArrayHandle handle,
865                                int64_t slice_begin,
866                                int64_t slice_end,
867                                NDArrayHandle *out);
868 
869 /*!
870  * \brief Index the NDArray along axis 0.
871  *  This api is available when MXNet is built with flag
872  *  USE_INT64_TENSOR_SIZE=0 (by default)
873  * \param handle the handle to the NDArray
874  * \param idx the index
875  * \param out The NDArrayHandle of output NDArray
876  * \return 0 when success, -1 when failure happens
877  */
878 MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
879                           uint32_t idx,
880                           NDArrayHandle *out);
881 
882 /*!
883  * \brief Index the NDArray along axis 0.
884  *  This api is available when MXNet is built with flag
885  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
886  * \param handle the handle to the NDArray
887  * \param idx the index
888  * \param out The NDArrayHandle of output NDArray
889  * \return 0 when success, -1 when failure happens
890  */
891 MXNET_DLL int MXNDArrayAt64(NDArrayHandle handle,
892                             int64_t idx,
893                             NDArrayHandle *out);
894 
895 /*!
896  * \brief get the storage type of the array
897  */
898 MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle,
899                                       int *out_storage_type);
900 
901 /*!
902  * \brief Reshape the NDArray.
903  * \param handle the handle to the narray
904  * \param ndim number of dimensions of new shape
905  * \param dims new shape
906  * \param out the NDArrayHandle of reshaped NDArray
907  * \return 0 when success, -1 when failure happens
908  */
909 MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
910                                int ndim,
911                                int *dims,
912                                NDArrayHandle *out);
913 
914 /*!
915  * \brief Reshape the NDArray.
916  * \param handle the handle to the narray
917  * \param ndim number of dimensions of new shape
918  * \param dims new shape
919  * \param out the NDArrayHandle of reshaped NDArray
920  * \return 0 when success, -1 when failure happens
921  */
922 MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
923                                  int ndim,
924                                  dim_t *dims,
925                                  bool reverse,
926                                  NDArrayHandle *out);
927 /*!
928  * \brief DEPRECATED. Use MXNDArrayGetShapeEx instead.
929  * get the shape of the array
930  * \param handle the handle to the narray
931  * \param out_dim the output dimension
932  * \param out_pdata pointer holder to get data pointer of the shape
933  * \return 0 when success, -1 when failure happens
934  */
935 MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
936                                 uint32_t *out_dim,
937                                 const uint32_t **out_pdata);
938 
939 /*!
940  * \brief get the shape of the array
941  *  This api is available when MXNet is built with flag
942  *  USE_INT64_TENSOR_SIZE=0 (by default)
943  * \param handle the handle to the narray
944  * \param out_dim the output dimension
945  * \param out_pdata pointer holder to get data pointer of the shape
946  * \return 0 when success, -1 when failure happens
947  */
948 MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle,
949                                   int *out_dim,
950                                   const int **out_pdata);
951 
952 /*!
953  * \brief get the shape of the array
954  *  This api is available when MXNet is built with flag
955  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
956  * \param handle the handle to the narray
957  * \param out_dim the output dimension
958  * \param out_pdata pointer holder to get data pointer of the shape
959  * \return 0 when success, -1 when failure happens
960  */
961 MXNET_DLL int MXNDArrayGetShapeEx64(NDArrayHandle handle,
962                                     int *out_dim,
963                                     const int64_t **out_pdata);
964 
965 /*!
966  * \brief get the content of the data in NDArray
967  * \param handle the handle to the ndarray
968  * \param out_pdata pointer holder to get pointer of data
969  * \return 0 when success, -1 when failure happens
970  */
971 MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
972                                void **out_pdata);
973 /*!
974 * \brief Create a reference view of NDArray that
975 *  represents as DLManagedTensor
976 *  Notice: MXNet uses asynchronous execution. Please call MXNDArrayWaitToRead or
977 *          MXNDArrayWaitToWrite before calling MXNDArrayToDLPack.
978 * \param handle the handle to the ndarray
979 * \param out_dlpack pointer holder to get pointer of DLManagedTensor
980 * \return 0 when success, -1 when failure happens
981 */
982 MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle,
983                                        DLManagedTensorHandle *out_dlpack);
984 
985 /*!
986 * \brief DEPRECATED. Use MXNDArrayFromDLPackEx instead.
987 
988 *
989 * This allows us to create a NDArray using the memory
990 * allocated by an external deep learning framework
991 * that is DLPack compatible.
992 *
993 * The memory is retained until the NDArray went out of scope.
994 *
995 * \param dlpack the pointer of the input DLManagedTensor
996 * \param transient_handle whether the handle will be destructed before calling the deleter
997 * \param out_handle pointer holder to get pointer of NDArray
998 * \return 0 when success, -1 when failure happens
999 */
1000 MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
1001                                   NDArrayHandle *out_handle);
1002 
1003 /*!
1004 * \brief Create a NDArray backed by a dlpack tensor.
1005 *
1006 * This allows us to create a NDArray using the memory
1007 * allocated by an external deep learning framework
1008 * that is DLPack compatible.
1009 *
1010 * The memory is retained until the NDArray went out of scope.
1011 *
1012 * \param dlpack the pointer of the input DLManagedTensor
1013 * \param transient_handle whether the handle will be destructed before calling the deleter
1014 * \param out_handle pointer holder to get pointer of NDArray
1015 * \return 0 when success, -1 when failure happens
1016 */
1017 MXNET_DLL int MXNDArrayFromDLPackEx(DLManagedTensorHandle dlpack,
1018                                     const bool transient_handle,
1019                                     NDArrayHandle *out_handle);
1020 
1021 /*!
1022  * \brief Delete a dlpack tensor
1023  * \param dlpack the pointer of the input DLManagedTensor
1024  * \return 0 when success, -1 when failure happens
1025  */
1026 MXNET_DLL int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack);
1027 
1028 /*!
1029  * \brief get the type of the data in NDArray
1030  * \param handle the handle to the narray
1031  * \param out_dtype pointer holder to get type of data
1032  * \return 0 when success, -1 when failure happens
1033  */
1034 MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle,
1035                                int *out_dtype);
1036 
1037 /*!
1038  * \brief get the type of the ith aux data in NDArray
1039  *  This api is available when MXNet is built with flag
1040  *  USE_INT64_TENSOR_SIZE=0 (by default)
1041  * \param handle the handle to the narray
1042  * \param i the index of the aux data
1043  * \param out_type pointer holder to get type of aux data
1044  * \return 0 when success, -1 when failure happens
1045  */
1046 MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
1047                                   uint32_t i,
1048                                   int *out_type);
1049 
1050 /*!
1051  * \brief get the type of the ith aux data in NDArray
1052  *  This api is available when MXNet is built with flag
1053  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
1054  * \param handle the handle to the narray
1055  * \param i the index of the aux data
1056  * \param out_type pointer holder to get type of aux data
1057  * \return 0 when success, -1 when failure happens
1058  */
1059 MXNET_DLL int MXNDArrayGetAuxType64(NDArrayHandle handle,
1060                                     int64_t i,
1061                                     int *out_type);
1062 
1063 /*!
1064  * \brief Get a deep copy of the ith aux data blob
1065  *  This api is available when MXNet is built with flag
1066  *  USE_INT64_TENSOR_SIZE=0 (by default)
1067  * in the form of an NDArray of default storage type.
1068  * This function blocks. Do not use it in performance critical code.
1069  */
1070 MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
1071                                      uint32_t i,
1072                                      NDArrayHandle *out);
1073 
1074 /*!
1075  * \brief Get a deep copy of the ith aux data blob
1076  *  This api is available when MXNet is built with flag
1077  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
1078  * in the form of an NDArray of default storage type.
1079  * This function blocks. Do not use it in performance critical code.
1080  */
1081 MXNET_DLL int MXNDArrayGetAuxNDArray64(NDArrayHandle handle,
1082                                        int64_t i,
1083                                        NDArrayHandle *out);
1084 
1085 /*!
1086  * \brief Get a deep copy of the data blob
1087  * in the form of an NDArray of default storage type.
1088  * This function blocks. Do not use it in performance critical code.
1089  */
1090 MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle,
1091                                       NDArrayHandle *out);
1092 /*!
1093  * \brief get the context of the NDArray
1094  * \param handle the handle to the narray
1095  * \param out_dev_type the output device type
1096  * \param out_dev_id the output device id
1097  * \return 0 when success, -1 when failure happens
1098  */
1099 MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle,
1100                                   int *out_dev_type,
1101                                   int *out_dev_id);
1102 /*!
1103  * \brief return gradient buffer attached to this NDArray
1104  * \param handle NDArray handle
1105  * \return 0 when success, -1 when failure happens
1106  */
1107 MXNET_DLL int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out);
1108 /*!
1109  * \brief detach and ndarray from computation graph by clearing entry_
1110  * \param handle NDArray handle
1111  * \return 0 when success, -1 when failure happens
1112  */
1113 MXNET_DLL int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out);
1114 /*!
1115  * \brief set the flag for gradient array state.
1116  * \param handle NDArray handle
1117  * \param state the new state.
1118  * \return 0 when success, -1 when failure happens
1119  */
1120 MXNET_DLL int MXNDArraySetGradState(NDArrayHandle handle, int state);
1121 /*!
1122  * \brief set the flag for gradient array state.
1123  * \param handle NDArray handle
1124  * \param state the new state.
1125  * \return 0 when success, -1 when failure happens
1126  */
1127 MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
1128 //--------------------------------
1129 // Part 2: functions on NDArray
1130 //--------------------------------
1131 /*!
1132  * \brief list all the available functions handles
1133  *   most user can use it to list all the needed functions
1134  * \param out_size the size of returned array
1135  * \param out_array the output function array
1136  * \return 0 when success, -1 when failure happens
1137  */
1138 MXNET_DLL int MXListFunctions(uint32_t *out_size,
1139                               FunctionHandle **out_array);
1140 
1141 /*!
1142  * \brief get the function handle by name
1143  * \param name the name of the function
1144  * \param out the corresponding function handle
1145  * \return 0 when success, -1 when failure happens
1146  */
1147 MXNET_DLL int MXGetFunction(const char *name,
1148                             FunctionHandle *out);
1149 /*!
1150  * \brief Get the information of the function handle.
1151  * \param fun The function handle.
1152  * \param name The returned name of the function.
1153  * \param description The returned description of the function.
1154  * \param num_args Number of arguments.
1155  * \param arg_names Name of the arguments.
1156  * \param arg_type_infos Type information about the arguments.
1157  * \param arg_descriptions Description information about the arguments.
1158  * \param return_type Return type of the function.
1159  * \return 0 when success, -1 when failure happens
1160  */
1161 MXNET_DLL int MXFuncGetInfo(FunctionHandle fun,
1162                             const char **name,
1163                             const char **description,
1164                             uint32_t *num_args,
1165                             const char ***arg_names,
1166                             const char ***arg_type_infos,
1167                             const char ***arg_descriptions,
1168                             const char **return_type DEFAULT(NULL));
1169 /*!
1170  * \brief get the argument requirements of the function
1171  * \param fun input function handle
1172  * \param num_use_vars how many NDArrays to be passed in as used_vars
1173  * \param num_scalars scalar variable is needed
1174  * \param num_mutate_vars how many NDArrays to be passed in as mutate_vars
1175  * \param type_mask the type mask of this function
1176  * \return 0 when success, -1 when failure happens
1177  * \sa MXFuncInvoke
1178  */
1179 MXNET_DLL int MXFuncDescribe(FunctionHandle fun,
1180                              uint32_t *num_use_vars,
1181                              uint32_t *num_scalars,
1182                              uint32_t *num_mutate_vars,
1183                              int *type_mask);
1184 /*!
1185  * \brief invoke a function, the array size of passed in arguments
1186  *   must match the values in the
1187  * \param fun the function
1188  * \param use_vars the normal arguments passed to function
1189  * \param scalar_args the scalar qarguments
1190  * \param mutate_vars the mutate arguments
1191  * \return 0 when success, -1 when failure happens
1192  * \sa MXFuncDescribeArgs
1193  */
1194 MXNET_DLL int MXFuncInvoke(FunctionHandle fun,
1195                            NDArrayHandle *use_vars,
1196                            float *scalar_args,
1197                            NDArrayHandle *mutate_vars);
1198 /*!
1199  * \brief invoke a function, the array size of passed in arguments
1200  *   must match the values in the
1201  * \param fun the function
1202  * \param use_vars the normal arguments passed to function
1203  * \param scalar_args the scalar qarguments
1204  * \param mutate_vars the mutate arguments
1205  * \param num_params number of keyword parameters
1206  * \param param_keys keys for keyword parameters
1207  * \param param_vals values for keyword parameters
1208  * \return 0 when success, -1 when failure happens
1209  * \sa MXFuncDescribeArgs
1210  */
1211 MXNET_DLL int MXFuncInvokeEx(FunctionHandle fun,
1212                              NDArrayHandle *use_vars,
1213                              float *scalar_args,
1214                              NDArrayHandle *mutate_vars,
1215                              int num_params,
1216                              char **param_keys,
1217                              char **param_vals);
1218 /*!
1219  * \brief invoke a nnvm op and imperative function
1220  * \param creator the op
1221  * \param num_inputs number of input NDArrays
1222  * \param inputs input NDArrays
1223  * \param num_outputs number of output NDArrays
1224  * \param outputs output NDArrays
1225  * \param num_params number of keyword parameters
1226  * \param param_keys keys for keyword parameters
1227  * \param param_vals values for keyword parameters
1228  * \return 0 when success, -1 when failure happens
1229  */
1230 MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator,
1231                                  int num_inputs,
1232                                  NDArrayHandle *inputs,
1233                                  int *num_outputs,
1234                                  NDArrayHandle **outputs,
1235                                  int num_params,
1236                                  const char **param_keys,
1237                                  const char **param_vals);
1238 /*!
1239  * \brief invoke a nnvm op and imperative function
1240  * \param creator the op
1241  * \param num_inputs number of input NDArrays
1242  * \param inputs input NDArrays
1243  * \param num_outputs number of output NDArrays
1244  * \param outputs output NDArrays
1245  * \param num_params number of keyword parameters
1246  * \param param_keys keys for keyword parameters
1247  * \param param_vals values for keyword parameters
1248  * \param out_stypes output ndarrays' stypes
1249  * \return 0 when success, -1 when failure happens
1250  */
1251 MXNET_DLL int MXImperativeInvokeEx(AtomicSymbolCreator creator,
1252                                    int num_inputs,
1253                                    NDArrayHandle *inputs,
1254                                    int *num_outputs,
1255                                    NDArrayHandle **outputs,
1256                                    int num_params,
1257                                    const char **param_keys,
1258                                    const char **param_vals,
1259                                    const int **out_stypes);
1260 /*!
1261  * \brief set whether to record operator for autograd
1262  * \param is_recording 1 when recording, 0 when not recording.
1263  * \param prev returns the previous status before this set.
1264  * \return 0 when success, -1 when failure happens
1265  */
1266 MXNET_DLL int MXAutogradSetIsRecording(int is_recording, int* prev);
1267 /*!
1268  * \brief set whether to record operator for autograd
1269  * \param is_training 1 when training, 0 when testing
1270  * \param prev returns the previous status before this set.
1271  * \return 0 when success, -1 when failure happens
1272  */
1273 MXNET_DLL int MXAutogradSetIsTraining(int is_training, int* prev);
1274 /*!
1275  * \brief get whether autograd recording is on
1276  * \param curr returns the current status.
1277  * \return 0 when success, -1 when failure happens
1278  */
1279 MXNET_DLL int MXAutogradIsRecording(bool* curr);
1280 /*!
1281  * \brief get whether training mode is on
1282  * \param curr returns the current status.
1283  * \return 0 when success, -1 when failure happens
1284  */
1285 MXNET_DLL int MXAutogradIsTraining(bool* curr);
1286 /*!
1287  * \brief get whether numpy compatibility is on
1288  * \param curr returns the current status
1289  * \return 0 when success, -1 when failure happens
1290  */
1291 MXNET_DLL int MXIsNumpyShape(int* curr);
1292 /*!
1293  * \brief set numpy compatibility switch
1294  * \param is_np_shape 1 when numpy shape semantics is thread local on,
1295  *        2 when numpy shape semantics is global on and 0 when off
1296  * \param prev returns the previous status before this set
1297  * \return 0 when success, -1 when failure happens
1298  */
1299 MXNET_DLL int MXSetIsNumpyShape(int is_np_shape, int* prev);
1300 /*!
1301  * \brief mark NDArrays as variables to compute gradient for autograd
1302  * \param num_var number of variable NDArrays
1303  * \param var_handles variable NDArrays
1304  * \return 0 when success, -1 when failure happens
1305  */
1306 MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
1307                                       NDArrayHandle *var_handles,
1308                                       uint32_t *reqs_array,
1309                                       NDArrayHandle *grad_handles);
1310 /*!
1311  * \brief compute the gradient of outputs w.r.t variabels
1312  * \param num_output number of output NDArray
1313  * \param output_handles output NDArrays
1314  * \return 0 when success, -1 when failure happens
1315  */
1316 MXNET_DLL int MXAutogradComputeGradient(uint32_t num_output,
1317                                         NDArrayHandle* output_handles);
1318 /*!
1319  * \brief compute the gradient of outputs w.r.t variabels
1320  * \param num_output number of output NDArray
1321  * \param output_handles output NDArrays
1322  * \param ograd_handles head gradient for NDArrays
1323  * \param retain_graph whether to keep the graph after backward
1324  * \return 0 when success, -1 when failure happens
1325  */
1326 MXNET_DLL int MXAutogradBackward(uint32_t num_output,
1327                                  NDArrayHandle* output_handles,
1328                                  NDArrayHandle* ograd_handles,
1329                                  int retain_graph);
1330 /*!
1331  * \brief compute the gradient of outputs w.r.t variabels
1332  * \param num_output number of output NDArray
1333  * \param output_handles output NDArrays
1334  * \param ograd_handles head gradient for NDArrays
1335  * \param num_variables number of variables
1336  * \param
1337  * \param retain_graph whether to keep the graph after backward
1338  * \param is_train whether to do backward for training or inference
1339  * \return 0 when success, -1 when failure happens
1340  */
1341 MXNET_DLL int MXAutogradBackwardEx(uint32_t num_output,
1342                                    NDArrayHandle *output_handles,
1343                                    NDArrayHandle *ograd_handles,
1344                                    uint32_t num_variables,
1345                                    NDArrayHandle *var_handles,
1346                                    int retain_graph,
1347                                    int create_graph,
1348                                    int is_train,
1349                                    NDArrayHandle **grad_handles,
1350                                    int **grad_stypes);
1351 /*
1352  * \brief get the graph constructed by autograd.
1353  * \param handle ndarray handle
1354  * \param out output symbol handle
1355  */
1356 MXNET_DLL int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out);
1357 /*!
1358  * \brief create cached operator
1359  */
1360 MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out);
1361 /*!
1362  * \brief create cached operator
1363  */
1364 MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
1365                                  int num_flags,
1366                                  const char** keys,
1367                                  const char** vals,
1368                                  CachedOpHandle *out);
1369 
1370 /*!
1371  * \brief create cached operator, allows to choose thread_safe version
1372  * of cachedop
1373  */
1374 MXNET_DLL int MXCreateCachedOpEX(SymbolHandle handle,
1375                                  int num_flags,
1376                                  const char** keys,
1377                                  const char** vals,
1378                                  CachedOpHandle *out,
1379                                  bool thread_safe DEFAULT(false));
1380 
1381 /*!
1382  * \brief free cached operator
1383  */
1384 MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle);
1385 
1386 /*!
1387  * \brief invoke cached operator
1388  */
1389 MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle,
1390                                int num_inputs,
1391                                NDArrayHandle *inputs,
1392                                int *num_outputs,
1393                                NDArrayHandle **outputs);
1394 
1395 /*!
1396  * \brief invoke a cached op
1397  * \param handle the handle to the cached op
1398  * \param num_inputs number of input NDArrays
1399  * \param inputs input NDArrays
1400  * \param num_outputs number of output NDArrays
1401  * \param outputs output NDArrays
1402  * \param out_stypes output ndarrays' stypes
1403  * \return 0 when success, -1 when failure happens
1404  */
1405 MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
1406                                  int num_inputs,
1407                                  NDArrayHandle *inputs,
1408                                  int *num_outputs,
1409                                  NDArrayHandle **outputs,
1410                                  const int** out_stypes);
1411 
1412 /*!
1413  * \brief cached op set monitor callback
1414  */
1415 MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle,
1416                                        CachedOpMonitorCallback callback,
1417                                        bool monitor_all);
1418 
1419 //--------------------------------------------
1420 // Part 3: symbolic configuration generation
1421 //--------------------------------------------
1422 /*!
1423  * \brief list all the available operator names, include entries.
1424  * \param out_size the size of returned array
1425  * \param out_array the output operator name array.
1426  * \return 0 when success, -1 when failure happens
1427  */
1428 MXNET_DLL int MXListAllOpNames(uint32_t *out_size,
1429                                const char ***out_array);
1430 
1431 /*!
1432  * \brief list all the available AtomicSymbolEntry
1433  * \param out_size the size of returned array
1434  * \param out_array the output AtomicSymbolCreator array
1435  * \return 0 when success, -1 when failure happens
1436  */
1437 MXNET_DLL int MXSymbolListAtomicSymbolCreators(uint32_t *out_size,
1438                                                AtomicSymbolCreator **out_array);
1439 
1440 /*!
1441  * \brief Get the name of an atomic symbol.
1442  * \param creator the AtomicSymbolCreator.
1443  * \param name The returned name of the creator.
1444  */
1445 MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
1446                                           const char **name);
1447 
1448 /*!
1449  * \brief Get the input symbols of the graph.
1450  * \param sym The graph.
1451  * \param inputs The input symbols of the graph.
1452  * \param input_size the number of input symbols returned.
1453  */
1454 MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs,
1455                                       int *input_size);
1456 
1457 /*!
1458  * \brief Cut a subgraph whose nodes are marked with a subgraph attribute.
1459  * The input graph will be modified. A variable node will be created for each
1460  * edge that connects to nodes outside the subgraph. The outside nodes that
1461  * connect to the subgraph will be returned.
1462  * \param sym The graph.
1463  * \param inputs The nodes that connect to the subgraph.
1464  * \param input_size The number of such nodes.
1465  */
1466 MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs,
1467                                   int *input_size);
1468 
1469 /*!
1470  * \brief Get the detailed information about atomic symbol.
1471  * \param creator the AtomicSymbolCreator.
1472  * \param name The returned name of the creator.
1473  * \param description The returned description of the symbol.
1474  * \param num_args Number of arguments.
1475  * \param arg_names Name of the arguments.
1476  * \param arg_type_infos Type informations about the arguments.
1477  * \param arg_descriptions Description information about the arguments.
1478  * \param key_var_num_args The keyword argument for specifying variable number of arguments.
1479  *            When this parameter has non-zero length, the function allows variable number
1480  *            of positional arguments, and will need the caller to pass it in in
1481  *            MXSymbolCreateAtomicSymbol,
1482  *            With key = key_var_num_args, and value = number of positional arguments.
1483  * \param return_type Return type of the function, can be Symbol or Symbol[]
1484  * \return 0 when success, -1 when failure happens
1485  */
1486 MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
1487                                           const char **name,
1488                                           const char **description,
1489                                           uint32_t *num_args,
1490                                           const char ***arg_names,
1491                                           const char ***arg_type_infos,
1492                                           const char ***arg_descriptions,
1493                                           const char **key_var_num_args,
1494                                           const char **return_type DEFAULT(NULL));
1495 /*!
1496  * \brief Create an AtomicSymbol.
1497  * \param creator the AtomicSymbolCreator
1498  * \param num_param the number of parameters
1499  * \param keys the keys to the params
1500  * \param vals the vals of the params
1501  * \param out pointer to the created symbol handle
1502  * \return 0 when success, -1 when failure happens
1503  */
1504 MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
1505                                          uint32_t num_param,
1506                                          const char **keys,
1507                                          const char **vals,
1508                                          SymbolHandle *out);
1509 /*!
1510  * \brief Create a Variable Symbol.
1511  * \param name name of the variable
1512  * \param out pointer to the created symbol handle
1513  * \return 0 when success, -1 when failure happens
1514  */
1515 MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out);
1516 /*!
1517  * \brief Create a Symbol by grouping list of symbols together
1518  * \param num_symbols number of symbols to be grouped
1519  * \param symbols array of symbol handles
1520  * \param out pointer to the created symbol handle
1521  * \return 0 when success, -1 when failure happens
1522  */
1523 MXNET_DLL int MXSymbolCreateGroup(uint32_t num_symbols,
1524                                   SymbolHandle *symbols,
1525                                   SymbolHandle *out);
1526 /*!
1527  * \brief Load a symbol from a json file.
1528  * \param fname the file name.
1529  * \param out the output symbol.
1530  * \return 0 when success, -1 when failure happens
1531  */
1532 MXNET_DLL int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out);
1533 /*!
1534  * \brief Load a symbol from a json string.
1535  * \param json the json string.
1536  * \param out the output symbol.
1537  * \return 0 when success, -1 when failure happens
1538  */
1539 MXNET_DLL int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out);
1540 /*!
1541  * \brief Remove the operators amp_cast and amp_multicast
1542  * \param sym_handle the input symbol.
1543  * \param ret_sym_handle the output symbol.
1544  * \return 0 when success, -1 when failure happens
1545  */
1546 MXNET_DLL int MXSymbolRemoveAmpCast(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle);
1547 /*!
1548  * \brief Save a symbol into a json file.
1549  * \param symbol the input symbol.
1550  * \param fname the file name.
1551  * \return 0 when success, -1 when failure happens
1552  */
1553 MXNET_DLL int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname);
1554 /*!
1555  * \brief Save a symbol into a json string
1556  * \param symbol the input symbol.
1557  * \param out_json output json string.
1558  * \return 0 when success, -1 when failure happens
1559  */
1560 MXNET_DLL int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json);
1561 /*!
1562  * \brief Free the symbol handle.
1563  * \param symbol the symbol
1564  * \return 0 when success, -1 when failure happens
1565  */
1566 MXNET_DLL int MXSymbolFree(SymbolHandle symbol);
1567 /*!
1568  * \brief Copy the symbol to another handle
1569  * \param symbol the source symbol
1570  * \param out used to hold the result of copy
1571  * \return 0 when success, -1 when failure happens
1572  */
1573 MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
1574 /*!
1575  * \brief Print the content of symbol, used for debug.
1576  * \param symbol the symbol
1577  * \param out_str pointer to hold the output string of the printing.
1578  * \return 0 when success, -1 when failure happens
1579  */
1580 MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str);
1581 /*!
1582  * \brief Get string name from symbol
1583  * \param symbol the source symbol
1584  * \param out The result name.
1585  * \param success Whether the result is contained in out.
1586  * \return 0 when success, -1 when failure happens
1587  */
1588 MXNET_DLL int MXSymbolGetName(SymbolHandle symbol,
1589                               const char** out,
1590                               int *success);
1591 /*!
1592  * \brief Get string attribute from symbol
1593  * \param symbol the source symbol
1594  * \param key The key of the symbol.
1595  * \param out The result attribute, can be NULL if the attribute do not exist.
1596  * \param success Whether the result is contained in out.
1597  * \return 0 when success, -1 when failure happens
1598  */
1599 MXNET_DLL int MXSymbolGetAttr(SymbolHandle symbol,
1600                               const char* key,
1601                               const char** out,
1602                               int *success);
1603 /*!
1604  * \brief Set string attribute from symbol.
1605  *  NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph.
1606  *
1607  *  Safe recommendaton: use  immutable graph
1608  *  - Only allow set attributes during creation of new symbol as optional parameter
1609  *
1610  *  Mutable graph (be careful about the semantics):
1611  *  - Allow set attr at any point.
1612  *  - Mutating an attribute of some common node of two graphs can cause confusion from user.
1613  *
1614  * \param symbol the source symbol
1615  * \param key The key of the symbol.
1616  * \param value The value to be saved.
1617  * \return 0 when success, -1 when failure happens
1618  */
1619 MXNET_DLL int MXSymbolSetAttr(SymbolHandle symbol,
1620                               const char* key,
1621                               const char* value);
1622 /*!
1623  * \brief Get all attributes from symbol, including all descendents.
1624  * \param symbol the source symbol
1625  * \param out_size The number of output attributes
1626  * \param out 2*out_size strings representing key value pairs.
1627  * \return 0 when success, -1 when failure happens
1628  */
1629 MXNET_DLL int MXSymbolListAttr(SymbolHandle symbol,
1630                                uint32_t *out_size,
1631                                const char*** out);
1632 /*!
1633  * \brief Get all attributes from symbol, excluding descendents.
1634  * \param symbol the source symbol
1635  * \param out_size The number of output attributes
1636  * \param out 2*out_size strings representing key value pairs.
1637  * \return 0 when success, -1 when failure happens
1638  */
1639 MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
1640                                       uint32_t *out_size,
1641                                       const char*** out);
1642 /*!
1643  * \brief List arguments in the symbol.
1644  * \param symbol the symbol
1645  * \param out_size output size
1646  * \param out_str_array pointer to hold the output string array
1647  * \return 0 when success, -1 when failure happens
1648  */
1649 MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
1650                                     uint32_t *out_size,
1651                                     const char ***out_str_array);
1652 
1653 /*!
1654  * \brief List returns in the symbol.
1655  * \param symbol the symbol
1656  * \param out_size output size
1657  * \param out_str_array pointer to hold the output string array
1658  * \return 0 when success, -1 when failure happens
1659  */
1660 MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
1661                                   uint32_t *out_size,
1662                                   const char ***out_str_array);
1663 
1664 /*!
1665  * \brief Get number of outputs of the symbol.
1666  * \param symbol The symbol
1667  * \param out_size number of outputs
1668  * \return 0 when success, -1 when failure happens
1669  */
1670 MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
1671                                     uint32_t *output_count);
1672 
1673 /*!
1674  * \brief Get a symbol that contains all the internals.
1675  * \param symbol The symbol
1676  * \param out The output symbol whose outputs are all the internals.
1677  * \return 0 when success, -1 when failure happens
1678  */
1679 MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol,
1680                                    SymbolHandle *out);
1681 /*!
1682  * \brief Get a symbol that contains only direct children.
1683  * \param symbol The symbol
1684  * \param out The output symbol whose outputs are the direct children.
1685  * \return 0 when success, -1 when failure happens
1686  */
1687 MXNET_DLL int MXSymbolGetChildren(SymbolHandle symbol,
1688                                   SymbolHandle *out);
1689 /*!
1690  * \brief Get index-th outputs of the symbol.
1691  * \param symbol The symbol
1692  * \param index the Index of the output.
1693  * \param out The output symbol whose outputs are the index-th symbol.
1694  * \return 0 when success, -1 when failure happens
1695  */
1696 MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
1697                                 uint32_t index,
1698                                 SymbolHandle *out);
1699 
1700 /*!
1701  * \brief List auxiliary states in the symbol.
1702  * \param symbol the symbol
1703  * \param out_size output size
1704  * \param out_str_array pointer to hold the output string array
1705  * \return 0 when success, -1 when failure happens
1706  */
1707 MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
1708                                           uint32_t *out_size,
1709                                           const char ***out_str_array);
1710 
1711 /*!
1712  * \brief Compose the symbol on other symbols.
1713  *
1714  *  This function will change the sym hanlde.
1715  *  To achieve function apply behavior, copy the symbol first
1716  *  before apply.
1717  *
1718  * \param sym the symbol to apply
1719  * \param name the name of symbol
1720  * \param num_args number of arguments
1721  * \param keys the key of keyword args (optional)
1722  * \param args arguments to sym
1723  * \return 0 when success, -1 when failure happens
1724  */
1725 MXNET_DLL int MXSymbolCompose(SymbolHandle sym,
1726                               const char *name,
1727                               uint32_t num_args,
1728                               const char** keys,
1729                               SymbolHandle* args);
1730 /*!
1731  * \brief Get the gradient graph of the symbol
1732  *
1733  * \param sym the symbol to get gradient
1734  * \param num_wrt number of arguments to get gradient
1735  * \param wrt the name of the arguments to get gradient
1736  * \param out the returned symbol that has gradient
1737  * \return 0 when success, -1 when failure happens
1738  */
1739 MXNET_DLL int MXSymbolGrad(SymbolHandle sym,
1740                            uint32_t num_wrt,
1741                            const char** wrt,
1742                            SymbolHandle* out);
1743 /*!
1744  * \brief DEPRECATED. Use MXSymbolInferShapeEx instead.
1745  * infer shape of unknown input shapes given the known one.
1746  *  The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
1747  *  The call will be treated as a kwargs call if key != NULL or num_args==0, otherwise it is positional.
1748  *
1749  * \param sym symbol handle
1750  * \param num_args numbe of input arguments.
1751  * \param keys the key of keyword args (optional)
1752  * \param arg_ind_ptr the head pointer of the rows in CSR
1753  * \param arg_shape_data the content of the CSR
1754  * \param in_shape_size sizeof the returning array of in_shapes
1755  * \param in_shape_ndim returning array of shape dimensions of each input shape.
1756  * \param in_shape_data returning array of pointers to head of the input shape.
1757  * \param out_shape_size sizeof the returning array of out_shapes
1758  * \param out_shape_ndim returning array of shape dimensions of each output shape.
1759  * \param out_shape_data returning array of pointers to head of the output shape.
1760  * \param aux_shape_size sizeof the returning array of aux_shapes
1761  * \param aux_shape_ndim returning array of shape dimensions of each auxiliary shape.
1762  * \param aux_shape_data returning array of pointers to head of the auxiliary shape.
1763  * \param complete whether infer shape completes or more information is needed.
1764  * \return 0 when success, -1 when failure happens
1765  */
1766 MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
1767                                  uint32_t num_args,
1768                                  const char** keys,
1769                                  const uint32_t *arg_ind_ptr,
1770                                  const uint32_t *arg_shape_data,
1771                                  uint32_t *in_shape_size,
1772                                  const uint32_t **in_shape_ndim,
1773                                  const uint32_t ***in_shape_data,
1774                                  uint32_t *out_shape_size,
1775                                  const uint32_t **out_shape_ndim,
1776                                  const uint32_t ***out_shape_data,
1777                                  uint32_t *aux_shape_size,
1778                                  const uint32_t **aux_shape_ndim,
1779                                  const uint32_t ***aux_shape_data,
1780                                  int *complete);
1781 
1782 /*!
1783  * \brief infer shape of unknown input shapes given the known one.
1784  *  The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
1785  *  The call will be treated as a kwargs call if key != NULL or num_args==0, otherwise it is positional.
1786  *  This api is available when MXNet is built with flag
1787  *  USE_INT64_TENSOR_SIZE=0 (by default)
1788  * \param sym symbol handle
1789  * \param num_args number of input arguments.
1790  * \param keys the key of keyword args (optional)
1791  * \param arg_ind_ptr the head pointer of the rows in CSR
1792  * \param arg_shape_data the content of the CSR
1793  * \param in_shape_size sizeof the returning array of in_shapes
1794  * \param in_shape_ndim returning array of shape dimensions of eachs input shape.
1795  * \param in_shape_data returning array of pointers to head of the input shape.
1796  * \param out_shape_size sizeof the returning array of out_shapes
1797  * \param out_shape_ndim returning array of shape dimensions of each output shape.
1798  * \param out_shape_data returning array of pointers to head of the output shape.
1799  * \param aux_shape_size sizeof the returning array of aux_shapes
1800  * \param aux_shape_ndim returning array of shape dimensions of each auxiliary shape.
1801  * \param aux_shape_data returning array of pointers to head of the auxiliary shape.
1802  * \param complete whether infer shape completes or more information is needed.
1803  * \return 0 when success, -1 when failure happens
1804  */
1805 MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
1806                                    uint32_t num_args,
1807                                    const char** keys,
1808                                    const uint32_t *arg_ind_ptr,
1809                                    const int *arg_shape_data,
1810                                    uint32_t *in_shape_size,
1811                                    const int **in_shape_ndim,
1812                                    const int ***in_shape_data,
1813                                    uint32_t *out_shape_size,
1814                                    const int **out_shape_ndim,
1815                                    const int ***out_shape_data,
1816                                    uint32_t *aux_shape_size,
1817                                    const int **aux_shape_ndim,
1818                                    const int ***aux_shape_data,
1819                                    int *complete);
1820 
1821 /*!
1822  * \brief infer shape of unknown input shapes given the known one.
1823  *  The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
1824  *  The call will be treated as a kwargs call if key != NULL or num_args==0, otherwise it is positional.
1825  *  This api is available when MXNet is built with flag
1826  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
1827  * \param sym symbol handle
1828  * \param num_args number of input arguments.
1829  * \param keys the key of keyword args (optional)
1830  * \param arg_ind_ptr the head pointer of the rows in CSR
1831  * \param arg_shape_data the content of the CSR
1832  * \param in_shape_size sizeof the returning array of in_shapes
1833  * \param in_shape_ndim returning array of shape dimensions of each input shape.
1834  * \param in_shape_data returning array of pointers to head of the input shape.
1835  * \param out_shape_size sizeof the returning array of out_shapes
1836  * \param out_shape_ndim returning array of shape dimensions of each output shape.
1837  * \param out_shape_data returning array of pointers to head of the output shape.
1838  * \param aux_shape_size sizeof the returning array of aux_shapes
1839  * \param aux_shape_ndim returning array of shape dimensions of each auxiliary shape.
1840  * \param aux_shape_data returning array of pointers to head of the auxiliary shape.
1841  * \param complete whether infer shape completes or more information is needed.
1842  * \return 0 when success, -1 when failure happens
1843  */
1844 MXNET_DLL int MXSymbolInferShapeEx64(SymbolHandle sym,
1845                                      uint32_t num_args,
1846                                      const char** keys,
1847                                      const int64_t *arg_ind_ptr,
1848                                      const int64_t *arg_shape_data,
1849                                      size_t *in_shape_size,
1850                                      const int **in_shape_ndim,
1851                                      const int64_t ***in_shape_data,
1852                                      size_t *out_shape_size,
1853                                      const int **out_shape_ndim,
1854                                      const int64_t ***out_shape_data,
1855                                      size_t *aux_shape_size,
1856                                      const int **aux_shape_ndim,
1857                                      const int64_t ***aux_shape_data,
1858                                      int *complete);
1859 
1860 /*!
1861  * \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
1862  * partially infer shape of unknown input shapes given the known one.
1863  *
1864  *  Return partially inferred results if not all shapes could be inferred.
1865  *  The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
1866  *  The call will be treated as a kwargs call if key != NULL or num_args==0, otherwise it is positional.
1867  *
1868  * \param sym symbol handle
1869  * \param num_args numbe of input arguments.
1870  * \param keys the key of keyword args (optional)
1871  * \param arg_ind_ptr the head pointer of the rows in CSR
1872  * \param arg_shape_data the content of the CSR
1873  * \param in_shape_size sizeof the returning array of in_shapes
1874  * \param in_shape_ndim returning array of shape dimensions of each input shape.
1875  * \param in_shape_data returning array of pointers to head of the input shape.
1876  * \param out_shape_size sizeof the returning array of out_shapes
1877  * \param out_shape_ndim returning array of shape dimensions of each output shape.
1878  * \param out_shape_data returning array of pointers to head of the output shape.
1879  * \param aux_shape_size sizeof the returning array of aux_shapes
1880  * \param aux_shape_ndim returning array of shape dimensions of each auxiliary shape.
1881  * \param aux_shape_data returning array of pointers to head of the auxiliary shape.
1882  * \param complete whether infer shape completes or more information is needed.
1883  * \return 0 when success, -1 when failure happens
1884  */
1885 MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
1886                                         uint32_t num_args,
1887                                         const char** keys,
1888                                         const uint32_t *arg_ind_ptr,
1889                                         const uint32_t *arg_shape_data,
1890                                         uint32_t *in_shape_size,
1891                                         const uint32_t **in_shape_ndim,
1892                                         const uint32_t ***in_shape_data,
1893                                         uint32_t *out_shape_size,
1894                                         const uint32_t **out_shape_ndim,
1895                                         const uint32_t ***out_shape_data,
1896                                         uint32_t *aux_shape_size,
1897                                         const uint32_t **aux_shape_ndim,
1898                                         const uint32_t ***aux_shape_data,
1899                                         int *complete);
1900 
1901 /*!
1902  * \brief partially infer shape of unknown input shapes given the known one.
1903  *
1904  *  Return partially inferred results if not all shapes could be inferred.
1905  *  The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
1906  *  The call will be treated as a kwargs call if key != NULL or num_args==0, otherwise it is positional.
1907  *  This api is available when MXNet is built with flag
1908  *  USE_INT64_TENSOR_SIZE=0 (by default)
1909  *
1910  * \param sym symbol handle
1911  * \param num_args number of input arguments.
1912  * \param keys the key of keyword args (optional)
1913  * \param arg_ind_ptr the head pointer of the rows in CSR
1914  * \param arg_shape_data the content of the CSR
1915  * \param in_shape_size sizeof the returning array of in_shapes
1916  * \param in_shape_ndim returning array of shape dimensions of each input shape.
1917  * \param in_shape_data returning array of pointers to head of the input shape.
1918  * \param out_shape_size sizeof the returning array of out_shapes
1919  * \param out_shape_ndim returning array of shape dimensions of each output shape.
1920  * \param out_shape_data returning array of pointers to head of the output shape.
1921  * \param aux_shape_size sizeof the returning array of aux_shapes
1922  * \param aux_shape_ndim returning array of shape dimensions of each auxiliary shape.
1923  * \param aux_shape_data returning array of pointers to head of the auxiliary shape.
1924  * \param complete whether infer shape completes or more information is needed.
1925  * \return 0 when success, -1 when failure happens
1926  */
1927 MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
1928                                           uint32_t num_args,
1929                                           const char** keys,
1930                                           const uint32_t *arg_ind_ptr,
1931                                           const int *arg_shape_data,
1932                                           uint32_t *in_shape_size,
1933                                           const int **in_shape_ndim,
1934                                           const int ***in_shape_data,
1935                                           uint32_t *out_shape_size,
1936                                           const int **out_shape_ndim,
1937                                           const int ***out_shape_data,
1938                                           uint32_t *aux_shape_size,
1939                                           const int **aux_shape_ndim,
1940                                           const int ***aux_shape_data,
1941                                           int *complete);
1942 
1943 /*!
1944  * \brief partially infer shape of unknown input shapes given the known one.
1945  *
1946  *  Return partially inferred results if not all shapes could be inferred.
1947  *  The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
1948  *  The call will be treated as a kwargs call if key != NULL or num_args==0, otherwise it is positional.
1949  *  This api is available when MXNet is built with flag
1950  *  USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
1951  *
1952  * \param sym symbol handle
1953  * \param num_args number of input arguments.
1954  * \param keys the key of keyword args (optional)
1955  * \param arg_ind_ptr the head pointer of the rows in CSR
1956  * \param arg_shape_data the content of the CSR
1957  * \param in_shape_size sizeof the returning array of in_shapes
1958  * \param in_shape_ndim returning array of shape dimensions of each input shape.
1959  * \param in_shape_data returning array of pointers to head of the input shape.
1960  * \param out_shape_size sizeof the returning array of out_shapes
1961  * \param out_shape_ndim returning array of shape dimensions of each output shape.
1962  * \param out_shape_data returning array of pointers to head of the output shape.
1963  * \param aux_shape_size sizeof the returning array of aux_shapes
1964  * \param aux_shape_ndim returning array of shape dimensions of each auxiliary shape.
1965  * \param aux_shape_data returning array of pointers to head of the auxiliary shape.
1966  * \param complete whether infer shape completes or more information is needed.
1967  * \return 0 when success, -1 when failure happens
1968  */
1969 MXNET_DLL int MXSymbolInferShapePartialEx64(SymbolHandle sym,
1970                                             uint32_t num_args,
1971                                             const char** keys,
1972                                             const int64_t *arg_ind_ptr,
1973                                             const int64_t *arg_shape_data,
1974                                             size_t *in_shape_size,
1975                                             const int **in_shape_ndim,
1976                                             const int64_t ***in_shape_data,
1977                                             size_t *out_shape_size,
1978                                             const int **out_shape_ndim,
1979                                             const int64_t ***out_shape_data,
1980                                             size_t *aux_shape_size,
1981                                             const int **aux_shape_ndim,
1982                                             const int64_t ***aux_shape_data,
1983                                             int *complete);
1984 
1985 /*!
1986  * \brief infer type of unknown input types given the known one.
1987  *  The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
1988  *  The call will be treated as a kwargs call if key != NULL or num_args==0, otherwise it is positional.
1989  *
1990  * \param sym symbol handle
1991  * \param num_args numbe of input arguments.
1992  * \param keys the key of keyword args (optional)
1993  * \param arg_type_data the content of the CSR
1994  * \param in_type_size sizeof the returning array of in_types
1995  * \param in_type_data returning array of pointers to head of the input type.
1996  * \param out_type_size sizeof the returning array of out_types
1997  * \param out_type_data returning array of pointers to head of the output type.
1998  * \param aux_type_size sizeof the returning array of aux_types
1999  * \param aux_type_data returning array of pointers to head of the auxiliary type.
2000  * \param complete whether infer type completes or more information is needed.
2001  * \return 0 when success, -1 when failure happens
2002  */
2003 MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
2004                                 uint32_t num_args,
2005                                 const char** keys,
2006                                 const int *arg_type_data,
2007                                 uint32_t *in_type_size,
2008                                 const int **in_type_data,
2009                                 uint32_t *out_type_size,
2010                                 const int **out_type_data,
2011                                 uint32_t *aux_type_size,
2012                                 const int **aux_type_data,
2013                                 int *complete);
2014 
2015 /*!
2016  * \brief partially infer type of unknown input types given the known one.
2017  *
2018  *  Return partially inferred results if not all types could be inferred.
2019  *  The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
2020  *  The call will be treated as a kwargs call if key != NULL or num_args==0, otherwise it is positional.
2021  *
2022  * \param sym symbol handle
2023  * \param num_args numbe of input arguments.
2024  * \param keys the key of keyword args (optional)
2025  * \param arg_type_data the content of the CSR
2026  * \param in_type_size sizeof the returning array of in_types
2027  * \param in_type_data returning array of pointers to head of the input type.
2028  * \param out_type_size sizeof the returning array of out_types
2029  * \param out_type_data returning array of pointers to head of the output type.
2030  * \param aux_type_size sizeof the returning array of aux_types
2031  * \param aux_type_data returning array of pointers to head of the auxiliary type.
2032  * \param complete whether infer type completes or more information is needed.
2033  * \return 0 when success, -1 when failure happens
2034  */
2035 MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym,
2036                                        uint32_t num_args,
2037                                        const char** keys,
2038                                        const int *arg_type_data,
2039                                        uint32_t *in_type_size,
2040                                        const int **in_type_data,
2041                                        uint32_t *out_type_size,
2042                                        const int **out_type_data,
2043                                        uint32_t *aux_type_size,
2044                                        const int **aux_type_data,
2045                                        int *complete);
2046 
2047 /*!
2048  * \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8
2049  * \param sym_handle symbol to be converted
2050  * \param ret_sym_handle quantized symbol result
2051  * \param dev_type device type
2052  * \param num_excluded_sym_names number of layers excluded from being quantized in the input symbol
2053  * \param excluded_sym_names node names to be excluded from being quantized
2054  * \param num_excluded_op_names number of operators excluded from being quantized in the input symbol
2055  * \param excluded_op_names operator names to be excluded from being quantized
2056  * \param num_offline number of parameters that are quantized offline
2057  * \param offline_params array of c strings representing the names of params quantized offline
2058  * \param quantized_dtype the quantized destination type for input data
2059  * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could
2060  * \param quantize_mode quantize mode to be used in quantize pass
2061  * \param quantize_granularity quantize granularity, tensor-wise or channel-wise
2062  * \param out_num_calib_names return the number of nodes to be calibrated
2063  * \param out_calib_names return the node names to be calibrated
2064  */
2065 MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle,
2066                                SymbolHandle *ret_sym_handle,
2067                                const int* dev_type,
2068                                const uint32_t num_excluded_sym_names,
2069                                const char **excluded_sym_names,
2070                                const uint32_t num_excluded_op_names,
2071                                const char **excluded_op_names,
2072                                const uint32_t num_offline, const char **offline_params,
2073                                const char *quantized_dtype, const bool calib_quantize,
2074                                const char *quantize_mode, const char *quantize_granularity,
2075                                uint32_t* out_num_calib_names, const char ***out_calib_names);
2076 
2077 /*!
2078  * \brief Convert a symbol into a mixed precision symbol with cast operators for target dtype casting
2079  * \param sym_handle symbol to be converted
2080  * \param ret_sym_handle mixed precision symbol result
2081  * \param num_args number of arguments for known dtypes
2082  * \param arg_type_data arg types of the arguments
2083  * \param target_dtype target_dtype for mixed precision symbol
2084  * \param cast_optional_params whether to cast optional params to target_dtype
2085  * \param num_target_dtype_op_names number of ops to be casted to target_dtype
2086  * \param num_fp32_op_names number of ops to be casted to FP32
2087  * \param num_widest_dtype_op_names number of ops to be casted to widest dtype
2088  * \param num_conditional_fp32_op_names number of ops to be casted to FP32 based on a condition
2089  * \param num_excluded_symbols number of symbols to be excluded from casting
2090  * \param num_model_params number of model parameters
2091  * \param num_widest_dtype_op_names number of ops to be casted to the widest dtype
2092  * \param num_conditional_fp32_op_names number of ops to be cast to fp32 based on precision
2093  * \param target_dtype_op_names op names to be casted to target_dtype
2094  * \param fp32_op_names op names to be casted to fp32
2095  * \param widest_dtype_op_names names to be casted to widest dtype
2096  * \param conditional_fp32_op_names names to be casted to FP32 conditionally
2097  * \param excluded_symbols symbol names to be excluded from casting
2098  * \param param_names param names for conditional FP32 casting
2099  * \param param_values param values for conditional FP32 casting
2100  * \param arg_names argument names for which type information is provided
2101  * \param model_param_names names for model parameters
2102  */
2103 MXNET_DLL int MXReducePrecisionSymbol(SymbolHandle sym_handle,
2104                                       SymbolHandle *ret_sym_handle,
2105                                       uint32_t num_args,
2106                                       const int* arg_type_data,
2107                                       uint32_t num_ind_ptr,
2108                                       const int* ind_ptr,
2109                                       const int* target_dtype,
2110                                       const int cast_optional_params,
2111                                       const uint32_t num_target_dtype_op_names,
2112                                       const uint32_t num_fp32_op_names,
2113                                       const uint32_t num_widest_dtype_op_names,
2114                                       const uint32_t num_conditional_fp32_op_names,
2115                                       const uint32_t num_excluded_symbols,
2116                                       const uint32_t num_model_params,
2117                                       const char **target_dtype_op_names,
2118                                       const char **fp32_op_names,
2119                                       const char **widest_dtype_op_names,
2120                                       const char **conditional_fp32_op_names,
2121                                       const char **excluded_symbols,
2122                                       const char **conditional_param_names,
2123                                       const char **conditional_param_vals,
2124                                       const char **model_param_names,
2125                                       const char **arg_names);
2126 /*!
2127  * \brief Set calibration table to node attributes in the sym
2128  * \param sym_handle symbol whose node attributes are to be set by calibration table
2129  * \param num_layers number of layers in the calibration table
2130  * \param layer names stored as keys in the calibration table
2131  * \param low_quantiles low quantiles of layers stored in the calibration table
2132  * \param high_quantiles high quantiles of layers stored in the calibration table
2133  * \param ret_sym_handle returned symbol
2134  */
2135 MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
2136                                                const uint32_t num_layers,
2137                                                const char** layer_names,
2138                                                const float* low_quantiles,
2139                                                const float* high_quantiles,
2140                                                SymbolHandle* ret_sym_handle);
2141 
2142 /*!
2143  * \brief Run subgraph pass based on the backend provided
2144  * \param sym_handle symbol to be converted
2145  * \param backend backend names for subgraph pass
2146  * \param ret_sym_handle returned symbol
2147  */
2148 MXNET_DLL int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend,
2149                                    SymbolHandle *ret_sym_handle);
2150 
2151 /*!
2152  * \brief Generate atomic symbol (able to be composed) from a source symbol
2153  * \param sym_handle source symbol
2154  * \param ret_sym_handle returned atomic symbol
2155  */
2156 MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle);
2157 /*!
2158  * \brief Partitions symbol for given backend, potentially creating subgraphs
2159  * \param sym_handle symbol to be partitioned
2160  * \param dev_type context device type
2161  * \param backend_name backend name
2162  * \param ret_sym_handle partitioned symbol returned
2163  * \param len number of args
2164  * \param in_args_handle args array
2165  * \param num_options number of key value pairs
2166  * \param keys keys for options
2167  * \param vals values corresponding to keys
2168  * \param num_input_shapes number of input shapes
2169  * \param input_shape_names names of the input shapes
2170  * \param input_shape_data pointer to the contiguous data shapes
2171  * \param input_shape_idx array of per shape starting idx, the shape length for the i-th input shape
2172  * is calculate as input_shape_idx[i+1] - input_shape_idx[i]
2173  * \param num_input_dtypes number of input data types
2174  * \param input_dtype_names array of names of the input data types
2175  * \param input_dtypes array of values of the input data types
2176  * \param num_input_stypesnumber of input storage types
2177  * \param input_stype_names array of names of the input storage types
2178  * \param input_stypes array of values of input storage types
2179  * \param skip_infer if the optimization should skip the attribute inferences
2180  * (to use if the backend does not require shape inference)
2181  * \param new_args_cnt pointer a number to store the number of new args
2182  * \param new_args_handle pointer on array to store the new args handles
2183  * \param new_arg_names_handle pointer on array to store the new args names
2184  * \param new_aux_cnt pointer a number to store the number of new aux
2185  * \param new_aux_handle pointer on array to store the new aux handles
2186  * \param new_aux_names_handle pointer on array to store the new aux names
2187  */
2188 MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
2189                                    const char* backend_name,
2190                                    const int dev_type,
2191                                    SymbolHandle* ret_sym_handle,
2192                                    const mx_uint args_len,
2193                                    NDArrayHandle* in_args_handle,
2194                                    const mx_uint aux_len,
2195                                    NDArrayHandle* in_aux_handle,
2196                                    const mx_uint num_options,
2197                                    const char** keys,
2198                                    const char** vals,
2199                                    const uint32_t num_input_shapes,
2200                                    const char** input_shape_names,
2201                                    const int64_t* input_shape_data,
2202                                    const uint32_t* input_shape_idx,
2203                                    const uint32_t num_input_dtypes,
2204                                    const char** input_dtype_names,
2205                                    const int* input_dtypes,
2206                                    const uint32_t num_input_stypes,
2207                                    const char** input_stype_names,
2208                                    const int* input_stypes,
2209                                    bool skip_infer,
2210                                    int* new_args_cnt,
2211                                    NDArrayHandle** new_args_handle,
2212                                    char*** new_arg_names_handle,
2213                                    int* new_aux_cnt,
2214                                    NDArrayHandle** new_aux_handle,
2215                                    char*** new_aux_names_handle);
2216 
2217 
2218 //--------------------------------------------
2219 // Part 4: Executor interface
2220 //--------------------------------------------
2221 /*!
2222  * \brief Delete the executor
2223  * \param handle the executor.
2224  * \return 0 when success, -1 when failure happens
2225  */
2226 MXNET_DLL int MXExecutorFree(ExecutorHandle handle);
2227 /*!
2228  * \brief Print the content of execution plan, used for debug.
2229  * \param handle the executor.
2230  * \param out_str pointer to hold the output string of the printing.
2231  * \return 0 when success, -1 when failure happens
2232  */
2233 MXNET_DLL int MXExecutorPrint(ExecutorHandle handle, const char **out_str);
2234 /*!
2235  * \brief Executor forward method
2236  *
2237  * \param handle executor handle
2238  * \param is_train int value to indicate whether the forward pass is for evaluation
2239  * \return 0 when success, -1 when failure happens
2240  */
2241 MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train);
2242 /*!
2243  * \brief Excecutor run backward
2244  *
2245  * \param handle execute handle
2246  * \param len lenth
2247  * \param head_grads NDArray handle for heads' gradient
2248  *
2249  * \return 0 when success, -1 when failure happens
2250  */
2251 MXNET_DLL int MXExecutorBackward(ExecutorHandle handle,
2252                                  uint32_t len,
2253                                  NDArrayHandle *head_grads);
2254 /*!
2255  * \brief Excecutor run backward
2256  *
2257  * \param handle execute handle
2258  * \param len lenth
2259  * \param head_grads NDArray handle for heads' gradient
2260  * \param is_train int value to indicate whether the backward pass is for evaluation
2261  *
2262  * \return 0 when success, -1 when failure happens
2263  */
2264 MXNET_DLL int MXExecutorBackwardEx(ExecutorHandle handle,
2265                                    uint32_t len,
2266                                    NDArrayHandle *head_grads,
2267                                    int is_train);
2268 /*!
2269  * \brief Get executor's head NDArray
2270  *
2271  * \param handle executor handle
2272  * \param out_size output narray vector size
2273  * \param out out put narray handles
2274  * \return 0 when success, -1 when failure happens
2275  */
2276 MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle,
2277                                 uint32_t *out_size,
2278                                 NDArrayHandle **out);
2279 
2280 /*!
2281  * \brief Generate Executor from symbol
2282  *
2283  * \param symbol_handle symbol handle
2284  * \param dev_type device type
2285  * \param dev_id device id
2286  * \param len length
2287  * \param in_args in args array
2288  * \param arg_grad_store arg grads handle array
2289  * \param grad_req_type grad req array
2290  * \param aux_states_len length of auxiliary states
2291  * \param aux_states auxiliary states array
2292  * \param out output executor handle
2293  * \return 0 when success, -1 when failure happens
2294  */
2295 MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle,
2296                              int dev_type,
2297                              int dev_id,
2298                              uint32_t len,
2299                              NDArrayHandle *in_args,
2300                              NDArrayHandle *arg_grad_store,
2301                              uint32_t *grad_req_type,
2302                              uint32_t aux_states_len,
2303                              NDArrayHandle *aux_states,
2304                              ExecutorHandle *out);
2305 /*!
2306  * \brief Generate Executor from symbol,
2307  *  This is advanced function, allow specify group2ctx map.
2308  *  The user can annotate "ctx_group" attribute to name each group.
2309  *
2310  * \param symbol_handle symbol handle
2311  * \param dev_type device type of default context
2312  * \param dev_id device id of default context
2313  * \param num_map_keys size of group2ctx map
2314  * \param map_keys keys of group2ctx map
2315  * \param map_dev_types device type of group2ctx map
2316  * \param map_dev_ids device id of group2ctx map
2317  * \param len length
2318  * \param in_args in args array
2319  * \param arg_grad_store arg grads handle array
2320  * \param grad_req_type grad req array
2321  * \param aux_states_len length of auxiliary states
2322  * \param aux_states auxiliary states array
2323  * \param out output executor handle
2324  * \return 0 when success, -1 when failure happens
2325  */
2326 MXNET_DLL int MXExecutorBindX(SymbolHandle symbol_handle,
2327                               int dev_type,
2328                               int dev_id,
2329                               uint32_t num_map_keys,
2330                               const char** map_keys,
2331                               const int* map_dev_types,
2332                               const int* map_dev_ids,
2333                               uint32_t len,
2334                               NDArrayHandle *in_args,
2335                               NDArrayHandle *arg_grad_store,
2336                               uint32_t *grad_req_type,
2337                               uint32_t aux_states_len,
2338                               NDArrayHandle *aux_states,
2339                               ExecutorHandle *out);
2340 /*!
2341  * \brief Generate Executor from symbol,
2342  *  This is advanced function, allow specify group2ctx map.
2343  *  The user can annotate "ctx_group" attribute to name each group.
2344  *
2345  * \param symbol_handle symbol handle
2346  * \param dev_type device type of default context
2347  * \param dev_id device id of default context
2348  * \param num_map_keys size of group2ctx map
2349  * \param map_keys keys of group2ctx map
2350  * \param map_dev_types device type of group2ctx map
2351  * \param map_dev_ids device id of group2ctx map
2352  * \param len length
2353  * \param in_args in args array
2354  * \param arg_grad_store arg grads handle array
2355  * \param grad_req_type grad req array
2356  * \param aux_states_len length of auxiliary states
2357  * \param aux_states auxiliary states array
2358  * \param shared_exec input executor handle for memory sharing
2359  * \param out output executor handle
2360  * \return 0 when success, -1 when failure happens
2361  */
2362 MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle,
2363                                int dev_type,
2364                                int dev_id,
2365                                uint32_t num_map_keys,
2366                                const char** map_keys,
2367                                const int* map_dev_types,
2368                                const int* map_dev_ids,
2369                                uint32_t len,
2370                                NDArrayHandle *in_args,
2371                                NDArrayHandle *arg_grad_store,
2372                                uint32_t *grad_req_type,
2373                                uint32_t aux_states_len,
2374                                NDArrayHandle *aux_states,
2375                                ExecutorHandle shared_exec,
2376                                ExecutorHandle *out);
/*!
 * \brief DEPRECATED. Use MXExecutorSimpleBindEx instead.
 *
 * Takes the same parameters, in the same order, as MXExecutorSimpleBindEx. The only
 * difference is that provided_arg_shape_data is `const uint32_t*` here rather than
 * `const int*`.
 *
 * \return 0 when success, -1 when failure happens (assumed from the convention used by the
 *         rest of this API)
 */
MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                                   int dev_type,
                                   int dev_id,
                                   const uint32_t num_g2c_keys,
                                   const char** g2c_keys,
                                   const int* g2c_dev_types,
                                   const int* g2c_dev_ids,
                                   const uint32_t provided_grad_req_list_len,
                                   const char** provided_grad_req_names,
                                   const char** provided_grad_req_types,
                                   const uint32_t num_provided_arg_shapes,
                                   const char** provided_arg_shape_names,
                                   const uint32_t* provided_arg_shape_data,
                                   const uint32_t* provided_arg_shape_idx,
                                   const uint32_t num_provided_arg_dtypes,
                                   const char** provided_arg_dtype_names,
                                   const int* provided_arg_dtypes,
                                   const uint32_t num_provided_arg_stypes,
                                   const char** provided_arg_stype_names,
                                   const int* provided_arg_stypes,
                                   const uint32_t num_shared_arg_names,
                                   const char** shared_arg_name_list,
                                   int* shared_buffer_len,
                                   const char** shared_buffer_name_list,
                                   NDArrayHandle* shared_buffer_handle_list,
                                   const char*** updated_shared_buffer_name_list,
                                   NDArrayHandle** updated_shared_buffer_handle_list,
                                   uint32_t* num_in_args,
                                   NDArrayHandle** in_args,
                                   NDArrayHandle** arg_grads,
                                   uint32_t* num_aux_states,
                                   NDArrayHandle** aux_states,
                                   ExecutorHandle shared_exec_handle,
                                   ExecutorHandle* out);
2413 
2414 
/*!
 * \brief Generate an executor from a symbol, given only partial information (grad
 *  requirements, argument shapes, dtypes and storage types) instead of fully allocated arrays.
 *  The argument, gradient and auxiliary-state arrays are handed back through output pointers.
 *  (Summary inferred from the signature; confirm against the implementation.)
 *
 * \param symbol_handle symbol to bind
 * \param dev_type device type of default context
 * \param dev_id device id of default context
 * \param num_g2c_keys size of the group2ctx map (entries in g2c_keys, g2c_dev_types, g2c_dev_ids)
 * \param g2c_keys keys ("ctx_group" names) of the group2ctx map
 * \param g2c_dev_types device types of the group2ctx map
 * \param g2c_dev_ids device ids of the group2ctx map
 * \param provided_grad_req_list_len number of entries in provided_grad_req_names and
 *        provided_grad_req_types
 * \param provided_grad_req_names argument names the grad requirements apply to
 * \param provided_grad_req_types grad requirement strings, one per name
 * \param num_provided_arg_shapes number of provided argument shapes
 * \param provided_arg_shape_names names of the arguments whose shapes are provided
 * \param provided_arg_shape_data concatenated dimensions of all provided shapes
 * \param provided_arg_shape_idx per-shape starting indices into provided_arg_shape_data
 *        (presumably shape i spans [idx[i], idx[i+1]); confirm)
 * \param num_provided_arg_dtypes number of provided argument dtypes
 * \param provided_arg_dtype_names names of the arguments whose dtypes are provided
 * \param provided_arg_dtypes the provided dtypes, one per name
 * \param num_provided_arg_stypes number of provided argument storage types
 * \param provided_arg_stype_names names of the arguments whose storage types are provided
 * \param provided_arg_stypes the provided storage types, one per name
 * \param num_shared_arg_names number of entries in shared_arg_name_list
 * \param shared_arg_name_list names of arguments whose arrays are shared
 *        (presumably with shared_exec_handle; confirm)
 * \param shared_buffer_len pointer to the number of entries in shared_buffer_name_list and
 *        shared_buffer_handle_list (passed by pointer, presumably so it can be updated; confirm)
 * \param shared_buffer_name_list names of the shared buffer arrays
 * \param shared_buffer_handle_list NDArray handles of the shared buffer arrays
 * \param updated_shared_buffer_name_list receives the updated shared buffer names
 * \param updated_shared_buffer_handle_list receives the updated shared buffer handles
 * \param num_in_args receives the number of argument arrays
 * \param in_args receives the argument array handles
 * \param arg_grads receives the argument gradient array handles
 * \param num_aux_states receives the number of auxiliary state arrays
 * \param aux_states receives the auxiliary state array handles
 * \param shared_exec_handle executor to share memory with
 * \param out output executor handle
 * \return 0 when success, -1 when failure happens (assumed from the convention used by the
 *         rest of this API)
 */
MXNET_DLL int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
                                     int dev_type,
                                     int dev_id,
                                     const uint32_t num_g2c_keys,
                                     const char** g2c_keys,
                                     const int* g2c_dev_types,
                                     const int* g2c_dev_ids,
                                     const uint32_t provided_grad_req_list_len,
                                     const char** provided_grad_req_names,
                                     const char** provided_grad_req_types,
                                     const uint32_t num_provided_arg_shapes,
                                     const char** provided_arg_shape_names,
                                     const int* provided_arg_shape_data,
                                     const uint32_t* provided_arg_shape_idx,
                                     const uint32_t num_provided_arg_dtypes,
                                     const char** provided_arg_dtype_names,
                                     const int* provided_arg_dtypes,
                                     const uint32_t num_provided_arg_stypes,
                                     const char** provided_arg_stype_names,
                                     const int* provided_arg_stypes,
                                     const uint32_t num_shared_arg_names,
                                     const char** shared_arg_name_list,
                                     int* shared_buffer_len,
                                     const char** shared_buffer_name_list,
                                     NDArrayHandle* shared_buffer_handle_list,
                                     const char*** updated_shared_buffer_name_list,
                                     NDArrayHandle** updated_shared_buffer_handle_list,
                                     uint32_t* num_in_args,
                                     NDArrayHandle** in_args,
                                     NDArrayHandle** arg_grads,
                                     uint32_t* num_aux_states,
                                     NDArrayHandle** aux_states,
                                     ExecutorHandle shared_exec_handle,
                                     ExecutorHandle* out);
2449 
2450 
/*!
 * \brief 64-bit-shape variant of MXExecutorSimpleBindEx.
 *
 * Takes the same parameters, in the same order, as MXExecutorSimpleBindEx. The only
 * difference is that provided_arg_shape_data is `const int64_t*` rather than `const int*`,
 * presumably so that shape dimensions beyond the 32-bit range can be passed.
 *
 * \param provided_arg_shape_data concatenated dimensions of all provided shapes, as int64_t
 * \param provided_arg_shape_idx per-shape starting indices into provided_arg_shape_data
 *        (presumably shape i spans [idx[i], idx[i+1]); confirm)
 * \param out output executor handle
 * \return 0 when success, -1 when failure happens (assumed from the convention used by the
 *         rest of this API)
 */
MXNET_DLL int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle,
                                     int dev_type,
                                     int dev_id,
                                     const uint32_t num_g2c_keys,
                                     const char** g2c_keys,
                                     const int* g2c_dev_types,
                                     const int* g2c_dev_ids,
                                     const uint32_t provided_grad_req_list_len,
                                     const char** provided_grad_req_names,
                                     const char** provided_grad_req_types,
                                     const uint32_t num_provided_arg_shapes,
                                     const char** provided_arg_shape_names,
                                     const int64_t* provided_arg_shape_data,
                                     const uint32_t* provided_arg_shape_idx,
                                     const uint32_t num_provided_arg_dtypes,
                                     const char** provided_arg_dtype_names,
                                     const int* provided_arg_dtypes,
                                     const uint32_t num_provided_arg_stypes,
                                     const char** provided_arg_stype_names,
                                     const int* provided_arg_stypes,
                                     const uint32_t num_shared_arg_names,
                                     const char** shared_arg_name_list,
                                     int* shared_buffer_len,
                                     const char** shared_buffer_name_list,
                                     NDArrayHandle* shared_buffer_handle_list,
                                     const char*** updated_shared_buffer_name_list,
                                     NDArrayHandle** updated_shared_buffer_handle_list,
                                     uint32_t* num_in_args,
                                     NDArrayHandle** in_args,
                                     NDArrayHandle** arg_grads,
                                     uint32_t* num_aux_states,
                                     NDArrayHandle** aux_states,
                                     ExecutorHandle shared_exec_handle,
                                     ExecutorHandle* out);
2485 
2486 
2487 /*!
2488  * \brief DEPRECATED. Use MXExecutorReshapeEx instead.
2489  * Return a new executor with the same symbol and shared memory,
2490  * but different input/output shapes.
2491  *
2492  * \param partial_shaping Whether to allow changing the shape of unspecified arguments.
2493  * \param allow_up_sizing Whether to allow allocating new ndarrays that's larger than the original.
2494  * \param dev_type device type of default context
2495  * \param dev_id device id of default context
2496  * \param num_map_keys size of group2ctx map
2497  * \param map_keys keys of group2ctx map
2498  * \param map_dev_types device type of group2ctx map
2499  * \param map_dev_ids device id of group2ctx map
2500  * \param num_in_args length of in_args
2501  * \param in_args in args array
2502  * \param arg_grads arg grads handle array
2503  * \param num_aux_states length of auxiliary states
2504  * \param aux_states auxiliary states array
2505  * \param shared_exec input executor handle for memory sharing
2506  * \param out output executor handle
2507  * \return a new executor
2508  */
2509 MXNET_DLL int MXExecutorReshape(int partial_shaping,
2510                                 int allow_up_sizing,
2511                                 int dev_type,
2512                                 int dev_id,
2513                                 uint32_t num_map_keys,
2514                                 const char** map_keys,
2515                                 const int* map_dev_types,
2516                                 const int* map_dev_ids,
2517                                 const uint32_t num_provided_arg_shapes,
2518                                 const char** provided_arg_shape_names,
2519                                 const uint32_t* provided_arg_shape_data,
2520                                 const uint32_t* provided_arg_shape_idx,
2521                                 uint32_t* num_in_args,
2522                                 NDArrayHandle** in_args,
2523                                 NDArrayHandle** arg_grads,
2524                                 uint32_t* num_aux_states,
2525                                 NDArrayHandle** aux_states,
2526                                 ExecutorHandle shared_exec,
2527                                 ExecutorHandle *out);
2528 /*!
2529  * \brief Return a new executor with the same symbol and shared memory,
2530  * but different input/output shapes.
2531  *
2532  * \param partial_shaping Whether to allow changing the shape of unspecified arguments.
2533  * \param allow_up_sizing Whether to allow allocating new ndarrays that's larger than the original.
2534  * \param dev_type device type of default context
2535  * \param dev_id device id of default context
2536  * \param num_map_keys size of group2ctx map
2537  * \param map_keys keys of group2ctx map
2538  * \param map_dev_types device type of group2ctx map
2539  * \param map_dev_ids device id of group2ctx map
2540  * \param num_in_args length of in_args
2541  * \param in_args in args array
2542  * \param arg_grads arg grads handle array
2543  * \param num_aux_states length of auxiliary states
2544  * \param aux_states auxiliary states array
2545  * \param shared_exec input executor handle for memory sharing
2546  * \param out output executor handle
2547  * \return a new executor
2548  */
2549 MXNET_DLL int MXExecutorReshapeEx(int partial_shaping,
2550                                   int allow_up_sizing,
2551                                   int dev_type,
2552                                   int dev_id,
2553                                   uint32_t num_map_keys,
2554                                   const char** map_keys,
2555                                   const int* map_dev_types,
2556                                   const int* map_dev_ids,
2557                                   const uint32_t num_provided_arg_shapes,
2558                                   const char** provided_arg_shape_names,
2559                                   const int* provided_arg_shape_data,
2560                                   const uint32_t* provided_arg_shape_idx,
2561                                   uint32_t* num_in_args,
2562                                   NDArrayHandle** in_args,
2563                                   NDArrayHandle** arg_grads,
2564                                   uint32_t* num_aux_states,
2565                                   NDArrayHandle** aux_states,
2566                                   ExecutorHandle shared_exec,
2567                                   ExecutorHandle *out);
2568 
2569 /*!
2570  * \brief get optimized graph from graph executor
2571  */
2572 MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
2573                                            SymbolHandle *out);
2574 /*!
2575  * \brief set a call back to notify the completion of operation
2576  */
2577 MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle,
2578                                            ExecutorMonitorCallback callback,
2579                                            void* callback_handle);
2580 
2581 /*!
2582  * \brief set a call back to notify the completion of operation
2583  * \param monitor_all If true, monitor both input and output, otherwise monitor output only.
2584  */
2585 MXNET_DLL int MXExecutorSetMonitorCallbackEX(ExecutorHandle handle,
2586                                              ExecutorMonitorCallback callback,
2587                                              void *callback_handle, bool monitor_all);
2588 //--------------------------------------------
2589 // Part 5: IO Interface
2590 //--------------------------------------------
2591 /*!
2592  * \brief List all the available iterator entries
2593  * \param out_size the size of returned iterators
2594  * \param out_array the output iteratos entries
2595  * \return 0 when success, -1 when failure happens
2596  */
2597 MXNET_DLL int MXListDataIters(uint32_t *out_size,
2598                               DataIterCreator **out_array);
2599 /*!
2600  * \brief Init an iterator, init with parameters
2601  * the array size of passed in arguments
2602  * \param handle of the iterator creator
2603  * \param num_param number of parameter
2604  * \param keys parameter keys
2605  * \param vals parameter values
2606  * \param out resulting iterator
2607  * \return 0 when success, -1 when failure happens
2608  */
2609 MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle,
2610                                    uint32_t num_param,
2611                                    const char **keys,
2612                                    const char **vals,
2613                                    DataIterHandle *out);
2614 /*!
2615  * \brief Get the detailed information about data iterator.
2616  * \param creator the DataIterCreator.
2617  * \param name The returned name of the creator.
2618  * \param description The returned description of the symbol.
2619  * \param num_args Number of arguments.
2620  * \param arg_names Name of the arguments.
2621  * \param arg_type_infos Type informations about the arguments.
2622  * \param arg_descriptions Description information about the arguments.
2623  * \return 0 when success, -1 when failure happens
2624  */
2625 MXNET_DLL int MXDataIterGetIterInfo(DataIterCreator creator,
2626                                     const char **name,
2627                                     const char **description,
2628                                     uint32_t *num_args,
2629                                     const char ***arg_names,
2630                                     const char ***arg_type_infos,
2631                                     const char ***arg_descriptions);
2632 /*!
2633  * \brief Free the handle to the IO module
2634  * \param handle the handle pointer to the data iterator
2635  * \return 0 when success, -1 when failure happens
2636  */
2637 MXNET_DLL int MXDataIterFree(DataIterHandle handle);
2638 /*!
2639  * \brief Move iterator to next position
2640  * \param handle the handle to iterator
2641  * \param out return value of next
2642  * \return 0 when success, -1 when failure happens
2643  */
2644 MXNET_DLL int MXDataIterNext(DataIterHandle handle,
2645                              int *out);
2646 /*!
2647  * \brief Call iterator.Reset
2648  * \param handle the handle to iterator
2649  * \return 0 when success, -1 when failure happens
2650  */
2651 MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle);
2652 
2653 /*!
2654  * \brief Get the handle to the NDArray of underlying data
2655  * \param handle the handle pointer to the data iterator
2656  * \param out handle to underlying data NDArray
2657  * \return 0 when success, -1 when failure happens
2658  */
2659 MXNET_DLL int MXDataIterGetData(DataIterHandle handle,
2660                                 NDArrayHandle *out);
2661 /*!
2662  * \brief Get the image index by array.
2663  * \param handle the handle pointer to the data iterator
2664  * \param out_index output index of the array.
2665  * \param out_size output size of the array.
2666  * \return 0 when success, -1 when failure happens
2667  */
2668 MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle,
2669                                  uint64_t **out_index,
2670                                  uint64_t *out_size);
2671 /*!
2672  * \brief Get the padding number in current data batch
2673  * \param handle the handle pointer to the data iterator
2674  * \param pad pad number ptr
2675  * \return 0 when success, -1 when failure happens
2676  */
2677 MXNET_DLL int MXDataIterGetPadNum(DataIterHandle handle,
2678                                   int *pad);
2679 
2680 /*!
2681  * \brief Get the handle to the NDArray of underlying label
2682  * \param handle the handle pointer to the data iterator
2683  * \param out the handle to underlying label NDArray
2684  * \return 0 when success, -1 when failure happens
2685  */
2686 MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle,
2687                                  NDArrayHandle *out);
2688 //--------------------------------------------
2689 // Part 6: basic KVStore interface
2690 //--------------------------------------------
2691 /*!
2692  * \brief Initialized ps-lite environment variables
2693  * \param num_vars number of variables to initialize
2694  * \param keys environment keys
2695  * \param vals environment values
2696  */
2697 MXNET_DLL int MXInitPSEnv(uint32_t num_vars,
2698                           const char **keys,
2699                           const char **vals);
2700 
2701 
2702 /*!
2703  * \brief Create a kvstore
2704  * \param type the type of KVStore
2705  * \param out The output type of KVStore
2706  * \return 0 when success, -1 when failure happens
2707  */
2708 MXNET_DLL int MXKVStoreCreate(const char *type,
2709                               KVStoreHandle *out);
2710 
2711 /*!
2712  * \brief Set parameters to use low-bit compressed gradients
2713  * \param handle handle to the kvstore
2714  * \param keys keys for compression parameters
2715  * \param vals values for compression parameters
2716  * \return 0 when success, -1 when failure happens
2717  */
2718 MXNET_DLL int MXKVStoreSetGradientCompression(KVStoreHandle handle,
2719                                               uint32_t num_params,
2720                                               const char** keys,
2721                                               const char** vals);
2722 
2723 /*!
2724  * \brief Delete a KVStore handle.
2725  * \param handle handle to the kvstore
2726  * \return 0 when success, -1 when failure happens
2727  */
2728 MXNET_DLL int MXKVStoreFree(KVStoreHandle handle);
2729 /*!
2730  * \brief Init a list of (key,value) pairs in kvstore
2731  * \param handle handle to the kvstore
2732  * \param num the number of key-value pairs
2733  * \param keys the list of keys
2734  * \param vals the list of values
2735  * \return 0 when success, -1 when failure happens
2736  */
2737 MXNET_DLL int MXKVStoreInit(KVStoreHandle handle,
2738                             uint32_t num,
2739                             const int* keys,
2740                             NDArrayHandle* vals);
2741 
2742 /*!
2743  * \brief Init a list of (key,value) pairs in kvstore, where each key is a string
2744  * \param handle handle to the kvstore
2745  * \param num the number of key-value pairs
2746  * \param keys the list of keys
2747  * \param vals the list of values
2748  * \return 0 when success, -1 when failure happens
2749  */
2750 MXNET_DLL int MXKVStoreInitEx(KVStoreHandle handle,
2751                               uint32_t num,
2752                               const char** keys,
2753                               NDArrayHandle* vals);
2754 
2755 /*!
2756  * \brief Push a list of (key,value) pairs to kvstore
2757  * \param handle handle to the kvstore
2758  * \param num the number of key-value pairs
2759  * \param keys the list of keys
2760  * \param vals the list of values
2761  * \param priority the priority of the action
2762  * \return 0 when success, -1 when failure happens
2763  */
2764 MXNET_DLL int MXKVStorePush(KVStoreHandle handle,
2765                             uint32_t num,
2766                             const int* keys,
2767                             NDArrayHandle* vals,
2768                             int priority);
2769 /*!
2770  * \brief Push a list of (key,value) pairs to kvstore, where each key is a string
2771  * \param handle handle to the kvstore
2772  * \param num the number of key-value pairs
2773  * \param keys the list of keys
2774  * \param vals the list of values
2775  * \param priority the priority of the action
2776  * \return 0 when success, -1 when failure happens
2777  */
2778 MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle,
2779                               uint32_t num,
2780                               const char** keys,
2781                               NDArrayHandle* vals,
2782                               int priority);
2783 /*!
2784  * \brief pull a list of (key, value) pairs from the kvstore
2785  * \param handle handle to the kvstore
2786  * \param num the number of key-value pairs
2787  * \param keys the list of keys
2788  * \param vals the list of values
2789  * \param priority the priority of the action
2790  * \param ignore_sparse whether to ignore sparse arrays in the request
2791  * \return 0 when success, -1 when failure happens
2792  */
2793 MXNET_DLL int MXKVStorePullWithSparse(KVStoreHandle handle,
2794                                       uint32_t num,
2795                                       const int* keys,
2796                                       NDArrayHandle* vals,
2797                                       int priority,
2798                                       bool ignore_sparse);
2799 /*!
2800  * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string
2801  * \param handle handle to the kvstore
2802  * \param num the number of key-value pairs
2803  * \param keys the list of keys
2804  * \param vals the list of values
2805  * \param priority the priority of the action
2806  * \param ignore_sparse whether to ignore sparse arrays in the request
2807  * \return 0 when success, -1 when failure happens
2808  */
2809 MXNET_DLL int MXKVStorePullWithSparseEx(KVStoreHandle handle,
2810                                         uint32_t num,
2811                                         const char** keys,
2812                                         NDArrayHandle* vals,
2813                                         int priority,
2814                                         bool ignore_sparse);
2815 /*!
2816  * \brief pull a list of (key, value) pairs from the kvstore
2817  * \param handle handle to the kvstore
2818  * \param num the number of key-value pairs
2819  * \param keys the list of keys
2820  * \param vals the list of values
2821  * \param priority the priority of the action
2822  * \return 0 when success, -1 when failure happens
2823  */
2824 MXNET_DLL int MXKVStorePull(KVStoreHandle handle,
2825                             uint32_t num,
2826                             const int* keys,
2827                             NDArrayHandle* vals,
2828                             int priority);
2829 /*!
2830  * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string
2831  * \param handle handle to the kvstore
2832  * \param num the number of key-value pairs
2833  * \param keys the list of keys
2834  * \param vals the list of values
2835  * \param priority the priority of the action
2836  * \return 0 when success, -1 when failure happens
2837  */
2838 MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
2839                               uint32_t num,
2840                               const char** keys,
2841                               NDArrayHandle* vals,
2842                               int priority);
2843 
2844 /*!
2845  * \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer.
2846  *        The NDArray pulled back will be in row_sparse storage with only the specified
2847  *        row_ids present based row_ids (others rows are zeros).
2848  * \param handle handle to the kvstore
2849  * \param num the number of key-value pairs
2850  * \param keys the list of keys
2851  * \param vals the list of values
2852  * \param row_ids the list of row_id NDArrays
2853  * \param priority the priority of the action
2854  * \return 0 when success, -1 when failure happens
2855  */
2856 MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle,
2857                                      uint32_t num,
2858                                      const int* keys,
2859                                      NDArrayHandle* vals,
2860                                      const NDArrayHandle* row_ids,
2861                                      int priority);
2862 /*!
2863  * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string.
2864  *        The NDArray pulled back will be in row_sparse storage with only the specified
2865  *        row_ids present based row_ids (others rows are zeros).
2866  * \param handle handle to the kvstore
2867  * \param num the number of key-value pairs
2868  * \param keys the list of keys
2869  * \param vals the list of values
2870  * \param row_ids the list of row_id NDArrays
2871  * \param priority the priority of the action
2872  * \return 0 when success, -1 when failure happens
2873  */
2874 MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle,
2875                                        uint32_t num,
2876                                        const char** keys,
2877                                        NDArrayHandle* vals,
2878                                        const NDArrayHandle* row_ids,
2879                                        int priority);
2880 
2881 /*!
2882  * \brief broadcast a list of (key, value) pairs from the kvstore
2883  * \param handle handle to the kvstore
2884  * \param vnum the number of key-value pairs corresponding to vkeys
2885  * \param vkeys the list of keys for the values to be pushed
2886  * \param onum the number of key-value pairs corresponding to okeys
2887  * \param okeys the list of keys for the values to be pulled
2888  * \param vals the list of values
2889  * \param outs the list of outputs
2890  * \param priority the priority of the action
2891  * \return 0 when success, -1 when failure happens
2892  */
2893 MXNET_DLL int MXKVStoreBroadcast(KVStoreHandle handle,
2894                                  mx_uint vnum,
2895                                  const int* vkeys,
2896                                  mx_uint onum,
2897                                  const int* okeys,
2898                                  NDArrayHandle* vals,
2899                                  NDArrayHandle* outs,
2900                                  int priority);
2901 /*!
2902  * \brief broadcast a list of (key, value) pairs from the kvstore,
2903  * where each key is a string
2904  * \param handle handle to the kvstore
2905  * \param vnum the number of key-value pairs corresponding to vkeys
2906  * \param vkeys the list of keys for the values to be pushed
2907  * \param onum the number of key-value pairs corresponding to okeys
2908  * \param okeys the list of keys for the values to be pulled
2909  * \param vals the list of values
2910  * \param outs the list of outputs
2911  * \param priority the priority of the action
2912  * \return 0 when success, -1 when failure happens
2913  */
2914 MXNET_DLL int MXKVStoreBroadcastEx(KVStoreHandle handle,
2915                                    mx_uint vnum,
2916                                    const char** vkeys,
2917                                    mx_uint onum,
2918                                    const char** okeys,
2919                                    NDArrayHandle* vals,
2920                                    NDArrayHandle* outs,
2921                                    int priority);
2922 
2923 /*!
2924  * \brief push and pull a list of (key, value) pairs from the kvstore
2925  * \param handle handle to the kvstore
2926  * \param vnum the number of key-value pairs corresponding to vkeys
2927  * \param vkeys the list of keys for the values to be pushed
2928  * \param onum the number of key-value pairs corresponding to okeys
2929  * \param okeys the list of keys for the values to be pulled
2930  * \param vals the list of values
2931  * \param outs the list of outputs
2932  * \param priority the priority of the action
2933  * \return 0 when success, -1 when failure happens
2934  */
2935 MXNET_DLL int MXKVStorePushPull(KVStoreHandle handle,
2936                                 mx_uint vnum,
2937                                 const int* vkeys,
2938                                 mx_uint onum,
2939                                 const int* okeys,
2940                                 NDArrayHandle* vals,
2941                                 NDArrayHandle* outs,
2942                                 int priority);
2943 /*!
2944  * \brief push and pull a list of (key, value) pairs from the kvstore,
2945  * where each key is a string
2946  * \param handle handle to the kvstore
2947  * \param vnum the number of key-value pairs corresponding to vkeys
2948  * \param vkeys the list of keys for the values to be pushed
2949  * \param onum the number of key-value pairs corresponding to okeys
2950  * \param okeys the list of keys for the values to be pulled
2951  * \param vals the list of values
2952  * \param outs the list of outputs
2953  * \param priority the priority of the action
2954  * \return 0 when success, -1 when failure happens
2955  */
2956 MXNET_DLL int MXKVStorePushPullEx(KVStoreHandle handle,
2957                                   mx_uint vnum,
2958                                   const char** vkeys,
2959                                   mx_uint onum,
2960                                   const char** okeys,
2961                                   NDArrayHandle* vals,
2962                                   NDArrayHandle* outs,
2963                                   int priority);
2964 
2965 /*!
2966  * \brief user-defined updater for the kvstore
2967  * It's this updater's responsibility to delete \a recv and \a local
2968  * \param the key
2969  * \param recv the pushed value on this key
2970  * \param local the value stored on local on this key
2971  * \param handle The additional handle to the updater
2972  */
2973 typedef void (MXKVStoreUpdater)(int key,
2974                                 NDArrayHandle recv,
2975                                 NDArrayHandle local,
2976                                 void *handle);
2977 /*!
2978  * \brief user-defined updater for the kvstore with string keys
2979  * It's this updater's responsibility to delete \a recv and \a local
2980  * \param the key
2981  * \param recv the pushed value on this key
2982  * \param local the value stored on local on this key
2983  * \param handle The additional handle to the updater
2984  */
2985 typedef void (MXKVStoreStrUpdater)(const char* key,
2986                                    NDArrayHandle recv,
2987                                    NDArrayHandle local,
2988                                    void *handle);
2989 /*!
2990  * \brief register a push updater
2991  * \param handle handle to the KVStore
2992  * \param updater udpater function
2993  * \param updater_handle The additional handle used to invoke the updater
2994  * \return 0 when success, -1 when failure happens
2995  */
2996 MXNET_DLL int MXKVStoreSetUpdater(KVStoreHandle handle,
2997                                   MXKVStoreUpdater updater,
2998                                   void *updater_handle);
2999 /*!
3000  * \brief register a push updater with int keys and one with string keys
3001  * \param handle handle to the KVStore
3002  * \param updater updater function with int keys
3003  * \param str_updater updater function with string keys
3004  * \param updater_handle The additional handle used to invoke the updater
3005  * \return 0 when success, -1 when failure happens
3006  */
3007 MXNET_DLL int MXKVStoreSetUpdaterEx(KVStoreHandle handle,
3008                                     MXKVStoreUpdater updater,
3009                                     MXKVStoreStrUpdater str_updater,
3010                                     void *updater_handle);
3011 /*!
3012  * \brief get the type of the kvstore
3013  * \param handle handle to the KVStore
3014  * \param type a string type
3015  * \return 0 when success, -1 when failure happens
3016  */
3017 MXNET_DLL int MXKVStoreGetType(KVStoreHandle handle,
3018                                const char** type);
3019 //--------------------------------------------
3020 // Part 6: advanced KVStore for multi-machines
3021 //--------------------------------------------
3022 
3023 /**
3024  * \brief return The rank of this node in its group, which is in [0, GroupSize).
3025  *
3026  * \param handle handle to the KVStore
3027  * \param ret the node rank
3028  * \return 0 when success, -1 when failure happens
3029  */
3030 MXNET_DLL int MXKVStoreGetRank(KVStoreHandle handle,
3031                                int *ret);
3032 
3033 /**
3034  * \brief return The number of nodes in this group, which is
3035  * - number of workers if if `IsWorkerNode() == true`,
3036  * - number of servers if if `IsServerNode() == true`,
3037  * - 1 if `IsSchedulerNode() == true`,
3038  * \param handle handle to the KVStore
3039  * \param ret the group size
3040  * \return 0 when success, -1 when failure happens
3041  */
3042 MXNET_DLL int MXKVStoreGetGroupSize(KVStoreHandle handle,
3043                                     int *ret);
3044 
3045 /**
3046  * \brief return whether or not this process is a worker node.
3047  * \param ret 1 for yes, 0 for no
3048  * \return 0 when success, -1 when failure happens
3049  */
3050 MXNET_DLL int MXKVStoreIsWorkerNode(int *ret);
3051 
3052 
3053 /**
3054  * \brief return whether or not this process is a server node.
3055  * \param ret 1 for yes, 0 for no
3056  * \return 0 when success, -1 when failure happens
3057  */
3058 MXNET_DLL int MXKVStoreIsServerNode(int *ret);
3059 
3060 
3061 /**
3062  * \brief return whether or not this process is a scheduler node.
3063  * \param ret 1 for yes, 0 for no
3064  * \return 0 when success, -1 when failure happens
3065  */
3066 MXNET_DLL int MXKVStoreIsSchedulerNode(int *ret);
3067 
3068 /**
3069  * \brief global barrier among all worker machines
3070  *
3071  * \param handle handle to the KVStore
3072  * \return 0 when success, -1 when failure happens
3073  */
3074 MXNET_DLL int MXKVStoreBarrier(KVStoreHandle handle);
3075 
3076 /**
3077  * \brief whether to do barrier when finalize
3078  *
3079  * \param handle handle to the KVStore
3080  * \param barrier_before_exit whether to do barrier when kvstore finalize
3081  * \return 0 when success, -1 when failure happens
3082  */
3083 MXNET_DLL int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle,
3084                                             const int barrier_before_exit);
3085 
3086 /**
3087  * \brief the prototype of a server controller
3088  * \param head the head of the command
3089  * \param body the body of the command
3090  * \param controller_handle helper handle for implementing controller
3091  */
3092 typedef void (MXKVStoreServerController)(int head,
3093                                          const char *body,
3094                                          void *controller_handle);
3095 
3096 /**
3097  * \brief Run as server (or scheduler)
3098  * \param handle handle to the KVStore
3099  * \param controller the user-defined server controller
3100  * \param controller_handle helper handle for implementing controller
3101  * \return 0 when success, -1 when failure happens
3102  */
3103 MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle,
3104                                  MXKVStoreServerController controller,
3105                                  void *controller_handle);
3106 
3107 /**
3108  * \brief Send a command to all server nodes
3109  * \param handle handle to the KVStore
3110  * \param cmd_id the head of the command
3111  * \param cmd_body the body of the command
3112  * \return 0 when success, -1 when failure happens
3113  */
3114 MXNET_DLL int MXKVStoreSendCommmandToServers(KVStoreHandle handle,
3115                                              int cmd_id,
3116                                              const char* cmd_body);
3117 
3118 /**
3119  * \brief Get the number of ps dead node(s) specified by {node_id}
3120  *
3121  * \param handle handle to the KVStore
3122  * \param node_id Can be a node group or a single node.
3123  *                kScheduler = 1, kServerGroup = 2, kWorkerGroup = 4
3124  * \param number Ouptut number of dead nodes
3125  * \param timeout_sec A node fails to send heartbeart in {timeout_sec} seconds
3126  *                    will be presumed as 'dead'
3127  */
3128 MXNET_DLL int MXKVStoreGetNumDeadNode(KVStoreHandle handle,
3129                                       const int node_id,
3130                                       int *number,
3131                                       const int timeout_sec DEFAULT(60));
3132 
3133 /**
3134  * \brief Create a RecordIO writer object
3135  * \param uri path to file
3136  * \param out handle pointer to the created object
3137  * \return 0 when success, -1 when failure happens
3138 */
3139 MXNET_DLL int MXRecordIOWriterCreate(const char *uri, RecordIOHandle *out);
3140 
3141 /**
3142  * \brief Delete a RecordIO writer object
3143  * \param handle handle to RecordIO object
3144  * \return 0 when success, -1 when failure happens
3145 */
3146 MXNET_DLL int MXRecordIOWriterFree(RecordIOHandle handle);
3147 
3148 /**
3149  * \brief Write a record to a RecordIO object
3150  * \param handle handle to RecordIO object
3151  * \param buf buffer to write
3152  * \param size size of buffer
3153  * \return 0 when success, -1 when failure happens
3154 */
3155 MXNET_DLL int MXRecordIOWriterWriteRecord(RecordIOHandle handle,
3156                                           const char *buf, size_t size);
3157 
3158 /**
3159  * \brief Get the current writer pointer position
3160  * \param handle handle to RecordIO object
3161  * \param pos handle to output position
3162  * \return 0 when success, -1 when failure happens
3163 */
3164 MXNET_DLL int MXRecordIOWriterTell(RecordIOHandle handle, size_t *pos);
3165 
3166 /**
3167  * \brief Create a RecordIO reader object
3168  * \param uri path to file
3169  * \param out handle pointer to the created object
3170  * \return 0 when success, -1 when failure happens
3171 */
3172 MXNET_DLL int MXRecordIOReaderCreate(const char *uri, RecordIOHandle *out);
3173 
3174 /**
3175  * \brief Delete a RecordIO reader object
3176  * \param handle handle to RecordIO object
3177  * \return 0 when success, -1 when failure happens
3178 */
3179 MXNET_DLL int MXRecordIOReaderFree(RecordIOHandle handle);
3180 
3181 /**
3182  * \brief Write a record to a RecordIO object
3183  * \param handle handle to RecordIO object
3184  * \param buf pointer to return buffer
3185  * \param size point to size of buffer
3186  * \return 0 when success, -1 when failure happens
3187 */
3188 MXNET_DLL int MXRecordIOReaderReadRecord(RecordIOHandle handle,
3189                                         char const **buf, size_t *size);
3190 
3191 /**
3192  * \brief Set the current reader pointer position
3193  * \param handle handle to RecordIO object
3194  * \param pos target position
3195  * \return 0 when success, -1 when failure happens
3196 */
3197 MXNET_DLL int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos);
3198 
3199 /**
3200  * \brief Get the current writer pointer position
3201  * \param handle handle to RecordIO object
3202  * \param pos handle to output position
3203  * \return 0 when success, -1 when failure happens
3204 */
3205 MXNET_DLL int MXRecordIOReaderTell(RecordIOHandle handle, size_t *pos);
3206 
3207 /**
3208  * \brief Create a MXRtc object
3209 */
3210 MXNET_DLL int MXRtcCreate(char* name, uint32_t num_input, uint32_t num_output,
3211                           char** input_names, char** output_names,
3212                           NDArrayHandle* inputs, NDArrayHandle* outputs,
3213                           char* kernel, RtcHandle *out);
3214 
3215 /**
3216  * \brief Run cuda kernel
3217 */
3218 MXNET_DLL int MXRtcPush(RtcHandle handle, uint32_t num_input, uint32_t num_output,
3219                         NDArrayHandle* inputs, NDArrayHandle* outputs,
3220                         uint32_t gridDimX,
3221                         uint32_t gridDimY,
3222                         uint32_t gridDimZ,
3223                         uint32_t blockDimX,
3224                         uint32_t blockDimY,
3225                         uint32_t blockDimZ);
3226 
3227 /**
3228  * \brief Delete a MXRtc object
3229 */
3230 MXNET_DLL int MXRtcFree(RtcHandle handle);
3231 /*
3232  * \brief register custom operators from frontend.
3233  * \param op_type name of custom op
3234  * \param creator
3235  */
3236 MXNET_DLL int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator);
3237 /*
3238  * \brief record custom function for backward later.
3239  * \param num_inputs number of input NDArrays.
3240  * \param inputs handle to input NDArrays.
3241  * \param num_outputs number of output NDArrays.
3242  * \param outputs handle to output NDArrays.
3243  * \param callbacks callbacks for backward function.
3244  */
3245 MXNET_DLL int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs,
3246                                      int num_outputs, NDArrayHandle *outputs,
3247                                      struct MXCallbackList *callbacks);
3248 /*
3249  * \brief create cuda rtc module
3250  * \param source cuda source code
3251  * \param num_options number of compiler flags
3252  * \param options compiler flags
3253  * \param num_exports number of exported function names
3254  * \param exported function names
3255  * \param out handle to created module
3256  */
3257 MXNET_DLL int MXRtcCudaModuleCreate(const char* source, int num_options,
3258                                     const char** options, int num_exports,
3259                                     const char** exports, CudaModuleHandle *out);
3260 /*
3261  * \brief delete cuda rtc module
3262  * \param handle handle to cuda module
3263  */
3264 MXNET_DLL int MXRtcCudaModuleFree(CudaModuleHandle handle);
3265 /*
3266  * \brief get kernel from module
3267  * \param handle handle to cuda module
3268  * \param name name of kernel function
3269  * \param num_args number of arguments
3270  * \param is_ndarray whether argument is ndarray
3271  * \param is_const whether argument is constant
3272  * \param arg_types data type of arguments
3273  * \param out created kernel
3274  */
3275 MXNET_DLL int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name,
3276                                     int num_args, int* is_ndarray, int* is_const,
3277                                     int* arg_types, CudaKernelHandle *out);
3278 /*
3279  * \brief delete kernel
3280  * \param handle handle to previously created kernel
3281  */
3282 MXNET_DLL int MXRtcCudaKernelFree(CudaKernelHandle handle);
3283 /*
3284  * \brief launch cuda kernel
3285  * \param handle handle to kernel
3286  * \param dev_id (GPU) device id
3287  * \param args pointer to arguments
3288  * \param grid_dim_x grid dimension x
3289  * \param grid_dim_y grid dimension y
3290  * \param grid_dim_z grid dimension z
3291  * \param block_dim_x block dimension x
3292  * \param block_dim_y block dimension y
3293  * \param block_dim_z block dimension z
3294  * \param shared_mem size of dynamically allocated shared memory
3295  */
3296 MXNET_DLL int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args,
3297                                   uint32_t grid_dim_x, uint32_t grid_dim_y,
3298                                   uint32_t grid_dim_z, uint32_t block_dim_x,
3299                                   uint32_t block_dim_y, uint32_t block_dim_z,
3300                                   uint32_t shared_mem);
3301 /*!
3302  * \brief Get shared memory handle from NDArray
3303  * \param handle NDArray handle.
3304  * \param shared_pid output PID
3305  * \param shared_id output shared memory id.
3306  */
3307 MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid,
3308                                           int* shared_id);
3309 /*!
3310  * \brief DEPRECATED. Use MXNDArrayCreateFromSharedMemEx instead.
3311  * Reconstruct NDArray from shared memory handle
3312  * \param shared_pid shared PID
3313  * \param shared_id shared memory id
3314  * \param shape pointer to NDArray dimensions
3315  * \param ndim number of NDArray dimensions
3316  * \param dtype data type of NDArray
3317  * \param out constructed NDArray
3318  */
3319 MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const uint32_t *shape,
3320                                            uint32_t ndim, int dtype, NDArrayHandle *out);
3321 
3322 /*!
3323  * \brief Release all unreferenced memory from the devices storage managers memory pool
3324  * \param dev_type device type, specify device we want to take
3325  * \param dev_id the device id of the specific device
3326  */
3327 MXNET_DLL int MXStorageEmptyCache(int dev_type, int dev_id);
3328 
3329 /*!
3330  * \brief Reconstruct NDArray from shared memory handle
3331  * \param shared_pid shared PID
3332  * \param shared_id shared memory id
3333  * \param shape pointer to NDArray dimensions
3334  * \param ndim number of NDArray dimensions
3335  * \param dtype data type of NDArray
3336  * \param out constructed NDArray
3337  */
3338 MXNET_DLL int MXNDArrayCreateFromSharedMemEx(int shared_pid, int shared_id, const int *shape,
3339                                              int ndim, int dtype, NDArrayHandle *out);
3340 
3341 /*!
3342   * \brief Push an asynchronous operation to the engine.
3343   * \param async_func Execution function whici takes a parameter on_complete
3344   *                   that must be called when the execution ompletes.
3345   * \param func_param The parameter set on calling async_func, can be NULL.
3346   * \param deleter The callback to free func_param, can be NULL.
3347   * \param ctx_handle Execution context.
3348   * \param const_vars_handle The variables that current operation will use
3349   *                          but not mutate.
3350   * \param num_const_vars The number of const_vars_handle.
3351   * \param mutable_vars_handle The variables that current operation will mutate.
3352   * \param num_mutable_vars The number of mutable_vars_handle.
3353   * \param prop_handle Property of the function.
3354   * \param priority Priority of the action, as hint to the engine.
3355   * \param opr_name The operation name.
3356   * \param wait Whether this is a WaitForVar operation.
3357   */
3358 MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
3359                                 EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
3360                                 EngineVarHandle const_vars_handle, int num_const_vars,
3361                                 EngineVarHandle mutable_vars_handle, int num_mutable_vars,
3362                                 EngineFnPropertyHandle prop_handle DEFAULT(NULL),
3363                                 int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
3364                                 bool wait DEFAULT(false));
3365 
3366 /*!
3367   * \brief Push a synchronous operation to the engine.
3368   * \param sync_func Execution function that executes the operation.
3369   * \param func_param The parameter set on calling sync_func, can be NULL.
3370   * \param deleter The callback to free func_param, can be NULL.
3371   * \param ctx_handle Execution context.
3372   * \param const_vars_handle The variables that current operation will use
3373   *                          but not mutate.
3374   * \param num_const_vars The number of const_vars_handle.
3375   * \param mutable_vars_handle The variables that current operation will mutate.
3376   * \param num_mutable_vars The number of mutable_vars_handle.
3377   * \param prop_handle Property of the function.
3378   * \param priority Priority of the action, as hint to the engine.
3379   * \param opr_name The operation name.
3380   */
3381 MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
3382                                EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
3383                                EngineVarHandle const_vars_handle, int num_const_vars,
3384                                EngineVarHandle mutable_vars_handle, int num_mutable_vars,
3385                                EngineFnPropertyHandle prop_handle DEFAULT(NULL),
3386                                int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
3387 /*!
3388  * \brief Create an NDArray from source sharing the same data chunk.
3389  * \param src source NDArray
3390  * \param out new NDArray sharing the same data chunck with src
3391  */
3392 MXNET_DLL int MXShallowCopyNDArray(NDArrayHandle src, NDArrayHandle* out);
3393 /*!
3394  * \brief Create an Symbol from source sharing the same graph structure.
3395  * \param src source Symbol
3396  * \param out new Symbol sharing the same graph structure with src
3397  */
3398 MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out);
3399 
3400 /*!
3401   * \brief Push an asynchronous operation to the engine.
3402   * \param async_func Execution function whici takes a parameter on_complete
3403   *                   that must be called when the execution ompletes.
3404   * \param func_param The parameter set on calling async_func, can be NULL.
3405   * \param deleter The callback to free func_param, can be NULL.
3406   * \param ctx_handle Execution context.
3407   * \param const_nds_handle The NDArrays that current operation will use
3408   *                          but not mutate.
3409   * \param num_const_nds The number of const_nds_handle.
3410   * \param mutable_nds_handle The NDArrays that current operation will mutate.
3411   * \param num_mutable_nds The number of mutable_nds_handle.
3412   * \param prop_handle Property of the function.
3413   * \param priority Priority of the action, as hint to the engine.
3414   * \param opr_name The operation name.
3415   * \param wait Whether this is a WaitForVar operation.
3416   */
3417 MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
3418                                   EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
3419                                   NDArrayHandle* const_nds_handle, int num_const_nds,
3420                                   NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
3421                                   EngineFnPropertyHandle prop_handle DEFAULT(NULL),
3422                                   int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
3423                                   bool wait DEFAULT(false));
3424 
3425 /*!
3426   * \brief Push a synchronous operation to the engine.
3427   * \param sync_func Execution function that executes the operation.
3428   * \param func_param The parameter set on calling sync_func, can be NULL.
3429   * \param deleter The callback to free func_param, can be NULL.
3430   * \param ctx_handle Execution context.
3431   * \param const_nds_handle The NDArrays that current operation will use
3432   *                          but not mutate.
3433   * \param num_const_nds The number of const_nds_handle.
3434   * \param mutable_nds_handle The NDArrays that current operation will mutate.
3435   * \param num_mutable_nds The number of mutable_nds_handle.
3436   * \param prop_handle Property of the function.
3437   * \param priority Priority of the action, as hint to the engine.
3438   * \param opr_name The operation name.
3439   */
3440 MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
3441                                  EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
3442                                  NDArrayHandle* const_nds_handle, int num_const_nds,
3443                                  NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
3444                                  EngineFnPropertyHandle prop_handle DEFAULT(NULL),
3445                                  int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
3446 
3447 #ifdef __cplusplus
3448 }
3449 #endif  // __cplusplus
3450 
3451 #endif  // MXNET_C_API_H_
3452