1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file c_api.h 22 * \brief C API of mxnet 23 */ 24 #ifndef MXNET_C_API_H_ 25 #define MXNET_C_API_H_ 26 27 /*! \brief Inhibit C++ name-mangling for MXNet functions. */ 28 #ifdef __cplusplus 29 extern "C" { 30 #endif // __cplusplus 31 32 /*! \brief Keep the default value in C++ */ 33 #ifdef __cplusplus 34 #define DEFAULT(x) = x 35 #else 36 #define DEFAULT(x) 37 #endif // __cplusplus 38 39 #include <stdint.h> 40 41 #include <stdint.h> 42 #include <stddef.h> 43 #include <stdbool.h> 44 45 /*! \brief MXNET_DLL prefix for windows */ 46 #ifdef _WIN32 47 #ifdef MXNET_EXPORTS 48 #define MXNET_DLL __declspec(dllexport) 49 #else 50 #define MXNET_DLL __declspec(dllimport) 51 #endif 52 #else 53 #define MXNET_DLL 54 #endif 55 56 /*! \brief manually define unsigned int */ 57 typedef uint32_t mx_uint; 58 /*! \brief manually define float */ 59 typedef float mx_float; 60 /*! \brief data type to store dim size */ 61 typedef int64_t dim_t; 62 // all the handles are simply void * 63 // will be casted internally to specific pointers types 64 // these typedefs are mainly used for readablity reasons 65 /*! 
\brief handle to NDArray */ 66 typedef void *NDArrayHandle; 67 /*! \brief handle to a mxnet narray function that changes NDArray */ 68 typedef const void *FunctionHandle; 69 /*! \brief handle to a function that takes param and creates symbol */ 70 typedef void *AtomicSymbolCreator; 71 /*! \brief handle to cached operator */ 72 typedef void *CachedOpHandle; 73 /*! \brief handle to a symbol that can be bind as operator */ 74 typedef void *SymbolHandle; 75 /*! \brief handle to a AtomicSymbol */ 76 typedef void *AtomicSymbolHandle; 77 /*! \brief handle to an Executor */ 78 typedef void *ExecutorHandle; 79 /*! \brief handle a dataiter creator */ 80 typedef void *DataIterCreator; 81 /*! \brief handle to a DataIterator */ 82 typedef void *DataIterHandle; 83 /*! \brief handle to KVStore */ 84 typedef void *KVStoreHandle; 85 /*! \brief handle to RecordIO */ 86 typedef void *RecordIOHandle; 87 /*! \brief handle to MXRtc*/ 88 typedef void *RtcHandle; 89 /*! \brief handle to rtc cuda module*/ 90 typedef void *CudaModuleHandle; 91 /*! \brief handle to rtc cuda kernel*/ 92 typedef void *CudaKernelHandle; 93 /*! \brief handle to a Profile object (domain, duration, counter, etc.) */ 94 typedef void *ProfileHandle; 95 /*! \brief handle to DLManagedTensor*/ 96 typedef void *DLManagedTensorHandle; 97 /*! \brief handle to Context */ 98 typedef const void *ContextHandle; 99 /*! \brief handle to Engine FnProperty */ 100 typedef const void *EngineFnPropertyHandle; 101 /*! \brief handle to Engine VarHandle */ 102 typedef void *EngineVarHandle; 103 104 /*! \brief Engine asynchronous operation */ 105 typedef void (*EngineAsyncFunc)(void*, void*, void*); 106 /*! \brief Engine synchronous operation */ 107 typedef void (*EngineSyncFunc)(void*, void*); 108 /*! \brief Callback to free the param for EngineAsyncFunc/EngineSyncFunc */ 109 typedef void (*EngineFuncParamDeleter)(void*); 110 typedef void (*ExecutorMonitorCallback)(const char*, 111 NDArrayHandle, 112 void*); 113 /*! 
\brief Monitor callback called at operator level for cached op */ 114 typedef void (*CachedOpMonitorCallback)(const char*, 115 const char*, 116 NDArrayHandle); 117 118 119 struct NativeOpInfo { 120 void (*forward)(int, float**, int*, unsigned**, int*, void*); 121 void (*backward)(int, float**, int*, unsigned**, int*, void*); 122 void (*infer_shape)(int, int*, unsigned**, void*); 123 void (*list_outputs)(char***, void*); 124 void (*list_arguments)(char***, void*); 125 // all functions also pass a payload void* pointer 126 void* p_forward; 127 void* p_backward; 128 void* p_infer_shape; 129 void* p_list_outputs; 130 void* p_list_arguments; 131 }; 132 133 struct NDArrayOpInfo { 134 bool (*forward)(int, void**, int*, void*); 135 bool (*backward)(int, void**, int*, void*); 136 bool (*infer_shape)(int, int*, unsigned**, void*); 137 bool (*list_outputs)(char***, void*); 138 bool (*list_arguments)(char***, void*); 139 bool (*declare_backward_dependency)(const int*, const int*, const int*, 140 int*, int**, void*); 141 // all functions also pass a payload void* pointer 142 void* p_forward; 143 void* p_backward; 144 void* p_infer_shape; 145 void* p_list_outputs; 146 void* p_list_arguments; 147 void* p_declare_backward_dependency; 148 }; 149 150 typedef int (*MXGenericCallback)(void); 151 152 struct MXCallbackList { 153 int num_callbacks; 154 int (**callbacks)(void); 155 void **contexts; 156 }; 157 158 struct LibFeature { 159 const char* name; 160 bool enabled; 161 }; 162 163 enum CustomOpCallbacks { 164 kCustomOpDelete, 165 kCustomOpForward, 166 kCustomOpBackward 167 }; 168 169 enum CustomOpPropCallbacks { 170 kCustomOpPropDelete, 171 kCustomOpPropListArguments, 172 kCustomOpPropListOutputs, 173 kCustomOpPropListAuxiliaryStates, 174 kCustomOpPropInferShape, 175 kCustomOpPropDeclareBackwardDependency, 176 kCustomOpPropCreateOperator, 177 kCustomOpPropInferType, 178 kCustomOpPropInferStorageType, 179 kCustomOpPropBackwardInferStorageType 180 }; 181 182 183 typedef int 
(*CustomOpFBFunc)(int /*size*/, void** /*ptrs*/, int* /*tags*/, 184 const int* /*reqs*/, const int /*is_train*/, 185 void* /*state*/); 186 typedef int (*CustomOpDelFunc)(void* /*state*/); 187 typedef int (*CustomOpListFunc)(char*** /*args*/, void* /*state*/); 188 typedef int (*CustomOpInferShapeFunc)(int /*num_input*/, int* /*ndims*/, 189 int** /*shapes*/, void* /*state*/); 190 typedef int (*CustomOpInferStorageTypeFunc)(int /*num_input*/, int* /*stypes*/, void* /*state*/); 191 typedef int (*CustomOpBackwardInferStorageTypeFunc)(int /*num_input*/, 192 int * /*stypes*/, 193 int * /*tags*/, 194 void * /*state*/); 195 typedef int (*CustomOpInferTypeFunc)(int /*num_input*/, int* /*types*/, void* /*state*/); 196 typedef int (*CustomOpBwdDepFunc)(const int* /*out_grad*/, const int* /*in_data*/, 197 const int* /*out_data*/, int* /*num_deps*/, 198 int** /*rdeps*/, void* /*state*/); 199 typedef int (*CustomOpCreateFunc)(const char* /*ctx*/, int /*num_inputs*/, 200 unsigned** /*shapes*/, const int* /*ndims*/, 201 const int* /*dtypes*/, struct MXCallbackList* /*ret*/, 202 void* /*state*/); 203 typedef int (*CustomOpPropCreator)(const char* /*op_type*/, const int /*num_kwargs*/, 204 const char** /*keys*/, const char** /*values*/, 205 struct MXCallbackList* /*ret*/); 206 207 208 enum CustomFunctionCallbacks { 209 kCustomFunctionBackward, 210 kCustomFunctionDelete 211 }; 212 213 typedef int (*CustomFunctionBwdFunc)(int /*num_ograds*/, int /*num_igrads*/, void** /*ptrs*/, 214 const int* /*reqs*/, const int /*is_train*/, 215 void* /*state*/); 216 typedef int (*CustomFunctionDelFunc)(void* /*state*/); 217 218 /*! 
219 * \brief return str message of the last error 220 * all function in this file will return 0 when success 221 * and -1 when an error occured, 222 * MXGetLastError can be called to retrieve the error 223 * 224 * this function is threadsafe and can be called by different thread 225 * \return error info 226 */ 227 MXNET_DLL const char *MXGetLastError(); 228 229 //------------------------------------- 230 // Part 0: Global State setups 231 //------------------------------------- 232 233 /*! 234 * \brief Load library dynamically 235 * \param path to the library .so file 236 * \param verbose 0 for quiet, 1 for verbose 237 * \return 0 when success, -1 when failure happens. 238 */ 239 MXNET_DLL int MXLoadLib(const char *path, unsigned verbose); 240 241 /*! 242 * \brief Get list of features supported on the runtime 243 * \param libFeature pointer to array of LibFeature 244 * \param size of the array 245 * \return 0 when success, -1 when failure happens. 246 */ 247 MXNET_DLL int MXLibInfoFeatures(const struct LibFeature **libFeature, size_t *size); 248 249 /*! 250 * \brief Seed all global random number generators in mxnet. 251 * \param seed the random number seed. 252 * \return 0 when success, -1 when failure happens. 253 */ 254 MXNET_DLL int MXRandomSeed(int seed); 255 256 /*! 257 * \brief Seed the global random number generator of the given device. 258 * \param seed the random number seed. 259 * \return 0 when success, -1 when failure happens. 260 */ 261 MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id); 262 263 /*! 264 * \brief Notify the engine about a shutdown, 265 * This can help engine to print less messages into display. 266 * 267 * User do not have to call this function. 268 * \return 0 when success, -1 when failure happens. 269 */ 270 MXNET_DLL int MXNotifyShutdown(); 271 272 /*! 
273 * \brief Set up configuration of profiler for the process passed as profile_process in keys 274 * \param num_params Number of parameters 275 * \param keys array of parameter keys 276 * \param vals array of parameter values 277 * \param kvstoreHandle handle to kvstore 278 * \return 0 when success, -1 when failure happens. 279 */ 280 MXNET_DLL int MXSetProcessProfilerConfig(int num_params, const char* const* keys, 281 const char* const* vals, 282 KVStoreHandle kvstoreHandle); 283 284 /*! 285 * \brief Set up configuration of profiler for worker/current process 286 * \param num_params Number of parameters 287 * \param keys array of parameter keys 288 * \param vals array of parameter values 289 * \return 0 when success, -1 when failure happens. 290 */ 291 MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals); 292 293 /*! 294 * \brief Set up state of profiler for either worker or server process 295 * \param state indicate the working state of profiler, 296 * profiler not running when state == 0, 297 * profiler running when state == 1 298 * \param profile_process an int, 299 * when 0 command is for worker/current process, 300 * when 1 command is for server process 301 * \param kvstoreHandle handle to kvstore, needed for server process profiling 302 * \return 0 when success, -1 when failure happens. 303 */ 304 MXNET_DLL int MXSetProcessProfilerState(int state, int profile_process, 305 KVStoreHandle kvStoreHandle); 306 307 /*! 308 * \brief Set up state of profiler for current process 309 * \param state indicate the working state of profiler, 310 * profiler not running when state == 0, 311 * profiler running when state == 1 312 * \return 0 when success, -1 when failure happens. 313 */ 314 MXNET_DLL int MXSetProfilerState(int state); 315 316 /*! 
317 * \brief Save profile and stop profiler 318 * \param finished true if stat output should stop after this point 319 * \param profile_process an int, 320 * when 0 command is for worker/current process, 321 * when 1 command is for server process 322 * \param kvstoreHandle handle to kvstore 323 * \return 0 when success, -1 when failure happens. 324 */ 325 MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle); 326 327 328 /*! 329 * \brief Save profile and stop profiler for worker/current process 330 * \param finished true if stat output should stop after this point 331 * \return 0 when success, -1 when failure happens. 332 */ 333 MXNET_DLL int MXDumpProfile(int finished); 334 335 336 /*! 337 * \brief Deprecated, use MXAggregateProfileStatsPrintEx instead. 338 * \param out_str Will receive a pointer to the output string 339 * \param reset Clear the aggregate stats after printing 340 * \return 0 when success, -1 when failure happens. 341 * \note 342 */ 343 MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset); 344 345 /*! 346 * \brief Print sorted aggregate stats to the a string 347 * How aggregate stats are stored will not change 348 * \param out_str will receive a pointer to the output string 349 * \param reset clear the aggregate stats after printing 350 * \param format whether to return in tabular or json format 351 * \param sort_by sort by total, avg, min, max, or count 352 * \param ascending whether to sort ascendingly 353 * \return 0 when success, -1 when failure happens. 354 * \note 355 */ 356 MXNET_DLL int MXAggregateProfileStatsPrintEx(const char **out_str, int reset, int format, 357 int sort_by, int ascending); 358 359 /*! 360 * \brief Pause profiler tuning collection 361 * \param paused If nonzero, profiling pauses. 
Otherwise, profiling resumes/continues 362 * \param profile_process integer which denotes whether to process worker or server process 363 * \param kvstoreHandle handle to kvstore 364 * \return 0 when success, -1 when failure happens. 365 * \note pausing and resuming is global and not recursive 366 */ 367 MXNET_DLL int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle); 368 369 /*! 370 * \brief Pause profiler tuning collection for worker/current process 371 * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues 372 * \return 0 when success, -1 when failure happens. 373 * \note pausing and resuming is global and not recursive 374 */ 375 MXNET_DLL int MXProfilePause(int paused); 376 377 /*! 378 * \brief Create profiling domain 379 * \param domain String representing the domain name to create 380 * \param out Return domain object 381 * \return 0 when success, -1 when failure happens. 382 */ 383 MXNET_DLL int MXProfileCreateDomain(const char *domain, ProfileHandle *out); 384 385 /*! 386 * \brief Create profile task 387 * \param name Name of the task 388 * \param domain Domain of the task 389 * \param out Output handle 390 * \return 0 when success, -1 when failure happens. 391 */ 392 MXNET_DLL int MXProfileCreateTask(ProfileHandle domain, 393 const char *task_name, 394 ProfileHandle *out); 395 396 /*! 397 * \brief Create profile frame 398 * \param name Name of the frame 399 * \param domain Domain of the frame 400 * \param out Output handle 401 * \return 0 when success, -1 when failure happens. 402 */ 403 MXNET_DLL int MXProfileCreateFrame(ProfileHandle domain, 404 const char *frame_name, 405 ProfileHandle *out); 406 407 /*! 408 * \brief Create profile event 409 * \param name Name of the event 410 * \param out Output handle 411 * \return 0 when success, -1 when failure happens. 412 */ 413 MXNET_DLL int MXProfileCreateEvent(const char *event_name, ProfileHandle *out); 414 415 /*! 
416 * \brief Create profile counter 417 * \param name Name of the counter 418 * \param domain Domain of the counter 419 * \param out Output handle 420 * \return 0 when success, -1 when failure happens. 421 */ 422 MXNET_DLL int MXProfileCreateCounter(ProfileHandle domain, 423 const char *counter_name, 424 ProfileHandle *out); 425 426 /*! 427 * \brief Destroy a frame 428 * \param frame_handle Handle to frame to destroy 429 * \return 0 when success, -1 when failure happens. 430 */ 431 MXNET_DLL int MXProfileDestroyHandle(ProfileHandle frame_handle); 432 433 /*! 434 * \brief Start timing the duration of a profile duration object such as an event, task or frame 435 * \param duration_handle handle to the duration object 436 * \return 0 when success, -1 when failure happens. 437 */ 438 MXNET_DLL int MXProfileDurationStart(ProfileHandle duration_handle); 439 440 /*! 441 * \brief Stop timing the duration of a profile duration object such as an event, task or frame 442 * \param duration_handle handle to the duration object 443 * \return 0 when success, -1 when failure happens. 444 */ 445 MXNET_DLL int MXProfileDurationStop(ProfileHandle duration_handle); 446 447 /*! 448 * \brief Set a counter, given its handle 449 * \param counter_handle Handle to counter to set 450 * \param value Value to set the counter to (64-bit unsigned integer) 451 * \return 0 when success, -1 when failure happens. 452 */ 453 MXNET_DLL int MXProfileSetCounter(ProfileHandle counter_handle, uint64_t value); 454 455 /*! 456 * \brief Adjust a counter by the given amount, given its handle 457 * \param counter_handle Handle to counter to adjust 458 * \param value Value to adjust the counter by (64-bit signed integer) 459 * \return 0 when success, -1 when failure happens. 460 */ 461 MXNET_DLL int MXProfileAdjustCounter(ProfileHandle counter_handle, int64_t value); 462 463 /*! 
464 * \brief Mark a single instant in time 465 * \param domain Domain of the marker 466 * \param instant_marker_name Name of the marker 467 * \param scope Scope of marker ('global', 'process', 'thread', 'task', 'marker') 468 * \return 0 when success, -1 when failure happens. 469 */ 470 MXNET_DLL int MXProfileSetMarker(ProfileHandle domain, 471 const char *instant_marker_name, 472 const char *scope); 473 474 /*! 475 * \brief Set the number of OMP threads to use 476 * \param thread_num Number of OMP threads desired 477 * \return 0 when success, -1 when failure happens. 478 */ 479 MXNET_DLL int MXSetNumOMPThreads(int thread_num); 480 481 /*! 482 * \brief set bulk execution limit 483 * \param bulk_size new bulk_size 484 * \param prev_bulk_size previous bulk_size 485 */ 486 MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size); 487 488 /*! 489 * \brief Get the number of GPUs. 490 * \param pointer to int that will hold the number of GPUs available. 491 * \return 0 when success, -1 when failure happens. 492 */ 493 MXNET_DLL int MXGetGPUCount(int* out); 494 495 /*! 496 * \brief get the free and total available memory on a GPU 497 * Note: Deprecated, use MXGetGPUMemoryInformation64 instead. 498 * \param dev the GPU number to query 499 * \param free_mem pointer to the integer holding free GPU memory 500 * \param total_mem pointer to the integer holding total GPU memory 501 * \return 0 when success, -1 when failure happens 502 */ 503 MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem); 504 505 /*! 506 * \brief get the free and total available memory on a GPU 507 * \param dev the GPU number to query 508 * \param free_mem pointer to the uint64_t holding free GPU memory 509 * \param total_mem pointer to the uint64_t holding total GPU memory 510 * \return 0 when success, -1 when failure happens 511 */ 512 MXNET_DLL int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem, uint64_t *total_mem); 513 514 /*! 
515 * \brief get the MXNet library version as an integer 516 * \param pointer to the integer holding the version number 517 * \return 0 when success, -1 when failure happens 518 */ 519 MXNET_DLL int MXGetVersion(int *out); 520 521 /*! 522 * \brief Load TVM operator from the binary library 523 * \param libpath TVM operators lib file 524 * \return 0 when success, -1 when failure happens 525 */ 526 #if MXNET_USE_TVM_OP 527 MXNET_DLL int MXLoadTVMOp(const char *libpath); 528 529 struct OtherOptionEntity { 530 int val; 531 }; 532 533 struct OtherOptionSpace { 534 OtherOptionEntity* entities; 535 int entities_size; 536 }; 537 538 struct ConfigSpace { 539 int entity_map_size; 540 char** entity_map_key; 541 OtherOptionEntity* entity_map_val; 542 int space_map_size; 543 char** space_map_key; 544 OtherOptionSpace* space_map_val; 545 }; 546 547 typedef struct ConfigSpaces { 548 int spaces_size; 549 char** spaces_key; 550 ConfigSpace* spaces_val; 551 } ConfigSpaces; 552 553 MXNET_DLL int MXLoadTVMConfig(ConfigSpaces config); 554 #endif // MXNET_USE_TVM_OP 555 556 557 //------------------------------------- 558 // Part 1: NDArray creation and deletion 559 //------------------------------------- 560 /*! 561 * \brief create a NDArray handle that is not initialized 562 * can be used to pass in as mutate variables 563 * to hold the result of NDArray 564 * \param out the returning handle 565 * \return 0 when success, -1 when failure happens 566 */ 567 MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out); 568 /*! 
569 * \brief create a NDArray with specified shape 570 * \param shape the pointer to the shape 571 * \param ndim the dimension of the shape 572 * \param dev_type device type, specify device we want to take 573 * \param dev_id the device id of the specific device 574 * \param delay_alloc whether to delay allocation until 575 * the narray is first mutated 576 * \param out the returning handle 577 * \return 0 when success, -1 when failure happens 578 */ 579 MXNET_DLL int MXNDArrayCreate(const uint32_t *shape, 580 uint32_t ndim, 581 int dev_type, 582 int dev_id, 583 int delay_alloc, 584 NDArrayHandle *out); 585 586 /*! 587 * \brief create a NDArray with specified shape and data type 588 * This api is available when MXNet is built with flag 589 * USE_INT64_TENSOR_SIZE=0 (by default) 590 * \param shape the pointer to the shape 591 * \param ndim the dimension of the shape 592 * \param dev_type device type, specify device we want to take 593 * \param dev_id the device id of the specific device 594 * \param delay_alloc whether to delay allocation until 595 * the narray is first mutated 596 * \param dtype data type of created array 597 * \param out the returning handle 598 * \return 0 when success, -1 when failure happens 599 */ 600 MXNET_DLL int MXNDArrayCreateEx(const uint32_t *shape, 601 uint32_t ndim, 602 int dev_type, 603 int dev_id, 604 int delay_alloc, 605 int dtype, 606 NDArrayHandle *out); 607 608 /*! 609 * \brief create a NDArray with specified shape and data type 610 * This api is available when MXNet is built with flag 611 * USE_INT64_TENSOR_SIZE=1 (not default) i.e. 
Large Tensor Support 612 * \param shape the pointer to int64_t shape 613 * \param ndim the dimension of the shape 614 * \param dev_type device type, specify device we want to take 615 * \param dev_id the device id of the specific device 616 * \param delay_alloc whether to delay allocation until 617 * the narray is first mutated 618 * \param dtype data type of created array 619 * \param out the returning handle 620 * \return 0 when success, -1 when failure happens 621 */ 622 MXNET_DLL int MXNDArrayCreateEx64(const int64_t *shape, 623 int ndim, 624 int dev_type, 625 int dev_id, 626 int delay_alloc, 627 int dtype, 628 NDArrayHandle *out); 629 630 /*! 631 * \brief create an empty sparse NDArray with specified shape and data type 632 * This api is available when MXNet is built with flag 633 * USE_INT64_TENSOR_SIZE=0 (by default) 634 * \param storage_type the storage type of the ndarray 635 * \param shape the pointer to the shape 636 * \param ndim the dimension of the shape 637 * \param dev_type device type, specify device we want to take 638 * \param dev_id the device id of the specific device 639 * \param delay_alloc whether to delay allocation until 640 * the narray is first mutated 641 * \param dtype data type of created array 642 * \param num_aux the number of aux data to support this ndarray 643 * \param aux_type data type of the aux data for the created array 644 * \param aux_ndims the dimension of the shapes of aux data 645 * \param aux_shape the shapes of aux data 646 * \param out the returning handle 647 * \return 0 when success, -1 when failure happens 648 */ 649 MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, 650 const uint32_t *shape, 651 uint32_t ndim, 652 int dev_type, 653 int dev_id, 654 int delay_alloc, 655 int dtype, 656 uint32_t num_aux, 657 int *aux_type, 658 uint32_t *aux_ndims, 659 const uint32_t *aux_shape, 660 NDArrayHandle *out); 661 662 /*! 
663 * \brief create an empty sparse NDArray with specified shape and data type 664 * This api is available when MXNet is built with flag 665 * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support 666 * \param storage_type the storage type of the ndarray 667 * \param shape the pointer to the shape 668 * \param ndim the dimension of the shape 669 * \param dev_type device type, specify device we want to take 670 * \param dev_id the device id of the specific device 671 * \param delay_alloc whether to delay allocation until 672 * the narray is first mutated 673 * \param dtype data type of created array 674 * \param num_aux the number of aux data to support this ndarray 675 * \param aux_type data type of the aux data for the created array 676 * \param aux_ndims the dimension of the shapes of aux data 677 * \param aux_shape the shapes of aux data 678 * \param out the returning handle 679 * \return 0 when success, -1 when failure happens 680 */ 681 MXNET_DLL int MXNDArrayCreateSparseEx64(int storage_type, 682 const int64_t *shape, 683 int ndim, 684 int dev_type, 685 int dev_id, 686 int delay_alloc, 687 int dtype, 688 uint32_t num_aux, 689 int *aux_type, 690 int *aux_ndims, 691 const int64_t *aux_shape, 692 NDArrayHandle *out); 693 694 /*! 695 * \brief create a NDArray handle that is loaded from raw bytes. 696 * \param buf the head of the raw bytes 697 * \param size size of the raw bytes 698 * \param out the returning handle 699 * \return 0 when success, -1 when failure happens 700 */ 701 MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf, 702 size_t size, 703 NDArrayHandle *out); 704 /*! 705 * \brief save the NDArray into raw bytes. 706 * \param handle the NDArray handle 707 * \param out_size size of the raw bytes 708 * \param out_buf the head of returning memory bytes. 709 * \return 0 when success, -1 when failure happens 710 */ 711 MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle, 712 size_t *out_size, 713 const char **out_buf); 714 /*! 
715 * \brief Save list of narray into the file. 716 * \param fname name of the file. 717 * \param num_args number of arguments to save. 718 * \param args the array of NDArrayHandles to be saved. 719 * \param keys the name of the NDArray, optional, can be NULL 720 * \return 0 when success, -1 when failure happens 721 */ 722 MXNET_DLL int MXNDArraySave(const char* fname, 723 uint32_t num_args, 724 NDArrayHandle* args, 725 const char** keys); 726 /*! 727 * \brief Load list of narray from the file. 728 * \param fname name of the file. 729 * \param out_size number of narray loaded. 730 * \param out_arr head of the returning narray handles. 731 * \param out_name_size size of output name arrray. 732 * \param out_names the names of returning NDArrays, can be NULL 733 * \return 0 when success, -1 when failure happens 734 */ 735 MXNET_DLL int MXNDArrayLoad(const char* fname, 736 uint32_t *out_size, 737 NDArrayHandle** out_arr, 738 uint32_t *out_name_size, 739 const char*** out_names); 740 741 /*! 742 * \brief Load list / dictionary of narrays from file content loaded into memory. 743 * This will load a list of ndarrays in a similar 744 * manner to MXNDArrayLoad, however, it loads from 745 * buffer containing the contents of a file, rather than 746 * from a specified file. 747 * \param ndarray_buffer pointer to the start of the ndarray file content 748 * \param size size of the file 749 * \param out_size number of narray loaded. 750 * \param out_arr head of the returning narray handles. 751 * \param out_name_size size of output name arrray. 752 * \param out_names the names of returning NDArrays, can be NULL 753 * \return 0 when success, -1 when failure happens 754 */ 755 MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer, 756 size_t size, 757 uint32_t *out_size, 758 NDArrayHandle** out_arr, 759 uint32_t *out_name_size, 760 const char*** out_names); 761 762 /*! 763 * \brief Perform a synchronize copy from a contiguous CPU memory region. 
764 * 765 * This function will call WaitToWrite before the copy is performed. 766 * This is useful to copy data from existing memory region that are 767 * not wrapped by NDArray(thus dependency not being tracked). 768 * 769 * \param handle the NDArray handle 770 * \param data the data source to copy from. 771 * \param size the memory size we want to copy from. 772 */ 773 MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, 774 const void *data, 775 size_t size); 776 /*! 777 * \brief Perform a synchronize copyto a contiguous CPU memory region. 778 * 779 * This function will call WaitToRead before the copy is performed. 780 * This is useful to copy data from existing memory region that are 781 * not wrapped by NDArray(thus dependency not being tracked). 782 * 783 * \param handle the NDArray handle 784 * \param data the data source to copy into. 785 * \param size the memory size we want to copy into. 786 */ 787 MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle, 788 void *data, 789 size_t size); 790 791 /*! 792 * \brief Copy src.data() to dst.data() if i = -1, else dst.aux_data(i) if i >= 0 793 * This function blocks. Do not use it in performance critical code. 794 * \param handle_dst handle of a dst ndarray whose data/aux_data has been allocated 795 * \param handle_src handle of a src ndarray which has default storage type 796 * \param i dst data blob indicator 797 */ 798 MXNET_DLL int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst, 799 const NDArrayHandle handle_src, 800 const int i); 801 802 /*! 803 * \brief check whether the NDArray format is valid 804 * \param full_check if `True`, rigorous check, O(N) operations 805 * Otherwise basic check, O(1) operations 806 */ 807 MXNET_DLL int MXNDArraySyncCheckFormat(NDArrayHandle handle, const bool full_check); 808 809 /*! 810 * \brief Wait until all the pending writes with respect NDArray are finished. 811 * Always call this before read data out synchronizely. 
812 * \param handle the NDArray handle 813 * \return 0 when success, -1 when failure happens 814 */ 815 MXNET_DLL int MXNDArrayWaitToRead(NDArrayHandle handle); 816 817 /*! 818 * \brief Wait until all the pending read/write with respect NDArray are finished. 819 * Always call this before write data into NDArray synchronizely. 820 * \param handle the NDArray handle 821 * \return 0 when success, -1 when failure happens 822 */ 823 MXNET_DLL int MXNDArrayWaitToWrite(NDArrayHandle handle); 824 825 /*! 826 * \brief wait until all delayed operations in 827 * the system is completed 828 * \return 0 when success, -1 when failure happens 829 */ 830 MXNET_DLL int MXNDArrayWaitAll(); 831 832 /*! 833 * \brief free the narray handle 834 * \param handle the handle to be freed 835 * \return 0 when success, -1 when failure happens 836 */ 837 MXNET_DLL int MXNDArrayFree(NDArrayHandle handle); 838 839 /*! 840 * \brief Slice the NDArray along axis 0. 841 * This api is available when MXNet is built with flag 842 * USE_INT64_TENSOR_SIZE=0 (by default) 843 * \param handle the handle to the NDArray 844 * \param slice_begin The beginning index of slice 845 * \param slice_end The ending index of slice 846 * \param out The NDArrayHandle of sliced NDArray 847 * \return 0 when success, -1 when failure happens 848 */ 849 MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, 850 uint32_t slice_begin, 851 uint32_t slice_end, 852 NDArrayHandle *out); 853 854 /*! 855 * \brief Slice the NDArray along axis 0. 856 * This api is available when MXNet is built with flag 857 * USE_INT64_TENSOR_SIZE=1 (not default) i.e. 
Large Tensor Support 858 * \param handle the handle to the NDArray 859 * \param slice_begin The beginning index of slice 860 * \param slice_end The ending index of slice 861 * \param out The NDArrayHandle of sliced NDArray 862 * \return 0 when success, -1 when failure happens 863 */ 864 MXNET_DLL int MXNDArraySlice64(NDArrayHandle handle, 865 int64_t slice_begin, 866 int64_t slice_end, 867 NDArrayHandle *out); 868 869 /*! 870 * \brief Index the NDArray along axis 0. 871 * This api is available when MXNet is built with flag 872 * USE_INT64_TENSOR_SIZE=0 (by default) 873 * \param handle the handle to the NDArray 874 * \param idx the index 875 * \param out The NDArrayHandle of output NDArray 876 * \return 0 when success, -1 when failure happens 877 */ 878 MXNET_DLL int MXNDArrayAt(NDArrayHandle handle, 879 uint32_t idx, 880 NDArrayHandle *out); 881 882 /*! 883 * \brief Index the NDArray along axis 0. 884 * This api is available when MXNet is built with flag 885 * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support 886 * \param handle the handle to the NDArray 887 * \param idx the index 888 * \param out The NDArrayHandle of output NDArray 889 * \return 0 when success, -1 when failure happens 890 */ 891 MXNET_DLL int MXNDArrayAt64(NDArrayHandle handle, 892 int64_t idx, 893 NDArrayHandle *out); 894 895 /*! 896 * \brief get the storage type of the array 897 */ 898 MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle, 899 int *out_storage_type); 900 901 /*! 902 * \brief Reshape the NDArray. 903 * \param handle the handle to the narray 904 * \param ndim number of dimensions of new shape 905 * \param dims new shape 906 * \param out the NDArrayHandle of reshaped NDArray 907 * \return 0 when success, -1 when failure happens 908 */ 909 MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle, 910 int ndim, 911 int *dims, 912 NDArrayHandle *out); 913 914 /*! 915 * \brief Reshape the NDArray. 
916 * \param handle the handle to the narray 917 * \param ndim number of dimensions of new shape 918 * \param dims new shape 919 * \param out the NDArrayHandle of reshaped NDArray 920 * \return 0 when success, -1 when failure happens 921 */ 922 MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle, 923 int ndim, 924 dim_t *dims, 925 bool reverse, 926 NDArrayHandle *out); 927 /*! 928 * \brief DEPRECATED. Use MXNDArrayGetShapeEx instead. 929 * get the shape of the array 930 * \param handle the handle to the narray 931 * \param out_dim the output dimension 932 * \param out_pdata pointer holder to get data pointer of the shape 933 * \return 0 when success, -1 when failure happens 934 */ 935 MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, 936 uint32_t *out_dim, 937 const uint32_t **out_pdata); 938 939 /*! 940 * \brief get the shape of the array 941 * This api is available when MXNet is built with flag 942 * USE_INT64_TENSOR_SIZE=0 (by default) 943 * \param handle the handle to the narray 944 * \param out_dim the output dimension 945 * \param out_pdata pointer holder to get data pointer of the shape 946 * \return 0 when success, -1 when failure happens 947 */ 948 MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle, 949 int *out_dim, 950 const int **out_pdata); 951 952 /*! 953 * \brief get the shape of the array 954 * This api is available when MXNet is built with flag 955 * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support 956 * \param handle the handle to the narray 957 * \param out_dim the output dimension 958 * \param out_pdata pointer holder to get data pointer of the shape 959 * \return 0 when success, -1 when failure happens 960 */ 961 MXNET_DLL int MXNDArrayGetShapeEx64(NDArrayHandle handle, 962 int *out_dim, 963 const int64_t **out_pdata); 964 965 /*! 
966 * \brief get the content of the data in NDArray 967 * \param handle the handle to the ndarray 968 * \param out_pdata pointer holder to get pointer of data 969 * \return 0 when success, -1 when failure happens 970 */ 971 MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle, 972 void **out_pdata); 973 /*! 974 * \brief Create a reference view of NDArray that 975 * represents as DLManagedTensor 976 * Notice: MXNet uses asynchronous execution. Please call MXNDArrayWaitToRead or 977 * MXNDArrayWaitToWrite before calling MXNDArrayToDLPack. 978 * \param handle the handle to the ndarray 979 * \param out_dlpack pointer holder to get pointer of DLManagedTensor 980 * \return 0 when success, -1 when failure happens 981 */ 982 MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle, 983 DLManagedTensorHandle *out_dlpack); 984 985 /*! 986 * \brief DEPRECATED. Use MXNDArrayFromDLPackEx instead. 987 988 * 989 * This allows us to create a NDArray using the memory 990 * allocated by an external deep learning framework 991 * that is DLPack compatible. 992 * 993 * The memory is retained until the NDArray went out of scope. 994 * 995 * \param dlpack the pointer of the input DLManagedTensor 996 * \param transient_handle whether the handle will be destructed before calling the deleter 997 * \param out_handle pointer holder to get pointer of NDArray 998 * \return 0 when success, -1 when failure happens 999 */ 1000 MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack, 1001 NDArrayHandle *out_handle); 1002 1003 /*! 1004 * \brief Create a NDArray backed by a dlpack tensor. 1005 * 1006 * This allows us to create a NDArray using the memory 1007 * allocated by an external deep learning framework 1008 * that is DLPack compatible. 1009 * 1010 * The memory is retained until the NDArray went out of scope. 
1011 * 1012 * \param dlpack the pointer of the input DLManagedTensor 1013 * \param transient_handle whether the handle will be destructed before calling the deleter 1014 * \param out_handle pointer holder to get pointer of NDArray 1015 * \return 0 when success, -1 when failure happens 1016 */ 1017 MXNET_DLL int MXNDArrayFromDLPackEx(DLManagedTensorHandle dlpack, 1018 const bool transient_handle, 1019 NDArrayHandle *out_handle); 1020 1021 /*! 1022 * \brief Delete a dlpack tensor 1023 * \param dlpack the pointer of the input DLManagedTensor 1024 * \return 0 when success, -1 when failure happens 1025 */ 1026 MXNET_DLL int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack); 1027 1028 /*! 1029 * \brief get the type of the data in NDArray 1030 * \param handle the handle to the narray 1031 * \param out_dtype pointer holder to get type of data 1032 * \return 0 when success, -1 when failure happens 1033 */ 1034 MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle, 1035 int *out_dtype); 1036 1037 /*! 1038 * \brief get the type of the ith aux data in NDArray 1039 * This api is available when MXNet is built with flag 1040 * USE_INT64_TENSOR_SIZE=0 (by default) 1041 * \param handle the handle to the narray 1042 * \param i the index of the aux data 1043 * \param out_type pointer holder to get type of aux data 1044 * \return 0 when success, -1 when failure happens 1045 */ 1046 MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle, 1047 uint32_t i, 1048 int *out_type); 1049 1050 /*! 1051 * \brief get the type of the ith aux data in NDArray 1052 * This api is available when MXNet is built with flag 1053 * USE_INT64_TENSOR_SIZE=1 (not default) i.e. 
Large Tensor Support 1054 * \param handle the handle to the narray 1055 * \param i the index of the aux data 1056 * \param out_type pointer holder to get type of aux data 1057 * \return 0 when success, -1 when failure happens 1058 */ 1059 MXNET_DLL int MXNDArrayGetAuxType64(NDArrayHandle handle, 1060 int64_t i, 1061 int *out_type); 1062 1063 /*! 1064 * \brief Get a deep copy of the ith aux data blob 1065 * This api is available when MXNet is built with flag 1066 * USE_INT64_TENSOR_SIZE=0 (by default) 1067 * in the form of an NDArray of default storage type. 1068 * This function blocks. Do not use it in performance critical code. 1069 */ 1070 MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle, 1071 uint32_t i, 1072 NDArrayHandle *out); 1073 1074 /*! 1075 * \brief Get a deep copy of the ith aux data blob 1076 * This api is available when MXNet is built with flag 1077 * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support 1078 * in the form of an NDArray of default storage type. 1079 * This function blocks. Do not use it in performance critical code. 1080 */ 1081 MXNET_DLL int MXNDArrayGetAuxNDArray64(NDArrayHandle handle, 1082 int64_t i, 1083 NDArrayHandle *out); 1084 1085 /*! 1086 * \brief Get a deep copy of the data blob 1087 * in the form of an NDArray of default storage type. 1088 * This function blocks. Do not use it in performance critical code. 1089 */ 1090 MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle, 1091 NDArrayHandle *out); 1092 /*! 1093 * \brief get the context of the NDArray 1094 * \param handle the handle to the narray 1095 * \param out_dev_type the output device type 1096 * \param out_dev_id the output device id 1097 * \return 0 when success, -1 when failure happens 1098 */ 1099 MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle, 1100 int *out_dev_type, 1101 int *out_dev_id); 1102 /*! 
1103 * \brief return gradient buffer attached to this NDArray 1104 * \param handle NDArray handle 1105 * \return 0 when success, -1 when failure happens 1106 */ 1107 MXNET_DLL int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out); 1108 /*! 1109 * \brief detach and ndarray from computation graph by clearing entry_ 1110 * \param handle NDArray handle 1111 * \return 0 when success, -1 when failure happens 1112 */ 1113 MXNET_DLL int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out); 1114 /*! 1115 * \brief set the flag for gradient array state. 1116 * \param handle NDArray handle 1117 * \param state the new state. 1118 * \return 0 when success, -1 when failure happens 1119 */ 1120 MXNET_DLL int MXNDArraySetGradState(NDArrayHandle handle, int state); 1121 /*! 1122 * \brief set the flag for gradient array state. 1123 * \param handle NDArray handle 1124 * \param state the new state. 1125 * \return 0 when success, -1 when failure happens 1126 */ 1127 MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out); 1128 //-------------------------------- 1129 // Part 2: functions on NDArray 1130 //-------------------------------- 1131 /*! 1132 * \brief list all the available functions handles 1133 * most user can use it to list all the needed functions 1134 * \param out_size the size of returned array 1135 * \param out_array the output function array 1136 * \return 0 when success, -1 when failure happens 1137 */ 1138 MXNET_DLL int MXListFunctions(uint32_t *out_size, 1139 FunctionHandle **out_array); 1140 1141 /*! 1142 * \brief get the function handle by name 1143 * \param name the name of the function 1144 * \param out the corresponding function handle 1145 * \return 0 when success, -1 when failure happens 1146 */ 1147 MXNET_DLL int MXGetFunction(const char *name, 1148 FunctionHandle *out); 1149 /*! 1150 * \brief Get the information of the function handle. 1151 * \param fun The function handle. 1152 * \param name The returned name of the function. 
1153 * \param description The returned description of the function. 1154 * \param num_args Number of arguments. 1155 * \param arg_names Name of the arguments. 1156 * \param arg_type_infos Type information about the arguments. 1157 * \param arg_descriptions Description information about the arguments. 1158 * \param return_type Return type of the function. 1159 * \return 0 when success, -1 when failure happens 1160 */ 1161 MXNET_DLL int MXFuncGetInfo(FunctionHandle fun, 1162 const char **name, 1163 const char **description, 1164 uint32_t *num_args, 1165 const char ***arg_names, 1166 const char ***arg_type_infos, 1167 const char ***arg_descriptions, 1168 const char **return_type DEFAULT(NULL)); 1169 /*! 1170 * \brief get the argument requirements of the function 1171 * \param fun input function handle 1172 * \param num_use_vars how many NDArrays to be passed in as used_vars 1173 * \param num_scalars scalar variable is needed 1174 * \param num_mutate_vars how many NDArrays to be passed in as mutate_vars 1175 * \param type_mask the type mask of this function 1176 * \return 0 when success, -1 when failure happens 1177 * \sa MXFuncInvoke 1178 */ 1179 MXNET_DLL int MXFuncDescribe(FunctionHandle fun, 1180 uint32_t *num_use_vars, 1181 uint32_t *num_scalars, 1182 uint32_t *num_mutate_vars, 1183 int *type_mask); 1184 /*! 1185 * \brief invoke a function, the array size of passed in arguments 1186 * must match the values in the 1187 * \param fun the function 1188 * \param use_vars the normal arguments passed to function 1189 * \param scalar_args the scalar qarguments 1190 * \param mutate_vars the mutate arguments 1191 * \return 0 when success, -1 when failure happens 1192 * \sa MXFuncDescribeArgs 1193 */ 1194 MXNET_DLL int MXFuncInvoke(FunctionHandle fun, 1195 NDArrayHandle *use_vars, 1196 float *scalar_args, 1197 NDArrayHandle *mutate_vars); 1198 /*! 
1199 * \brief invoke a function, the array size of passed in arguments 1200 * must match the values in the 1201 * \param fun the function 1202 * \param use_vars the normal arguments passed to function 1203 * \param scalar_args the scalar qarguments 1204 * \param mutate_vars the mutate arguments 1205 * \param num_params number of keyword parameters 1206 * \param param_keys keys for keyword parameters 1207 * \param param_vals values for keyword parameters 1208 * \return 0 when success, -1 when failure happens 1209 * \sa MXFuncDescribeArgs 1210 */ 1211 MXNET_DLL int MXFuncInvokeEx(FunctionHandle fun, 1212 NDArrayHandle *use_vars, 1213 float *scalar_args, 1214 NDArrayHandle *mutate_vars, 1215 int num_params, 1216 char **param_keys, 1217 char **param_vals); 1218 /*! 1219 * \brief invoke a nnvm op and imperative function 1220 * \param creator the op 1221 * \param num_inputs number of input NDArrays 1222 * \param inputs input NDArrays 1223 * \param num_outputs number of output NDArrays 1224 * \param outputs output NDArrays 1225 * \param num_params number of keyword parameters 1226 * \param param_keys keys for keyword parameters 1227 * \param param_vals values for keyword parameters 1228 * \return 0 when success, -1 when failure happens 1229 */ 1230 MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator, 1231 int num_inputs, 1232 NDArrayHandle *inputs, 1233 int *num_outputs, 1234 NDArrayHandle **outputs, 1235 int num_params, 1236 const char **param_keys, 1237 const char **param_vals); 1238 /*! 
1239 * \brief invoke a nnvm op and imperative function 1240 * \param creator the op 1241 * \param num_inputs number of input NDArrays 1242 * \param inputs input NDArrays 1243 * \param num_outputs number of output NDArrays 1244 * \param outputs output NDArrays 1245 * \param num_params number of keyword parameters 1246 * \param param_keys keys for keyword parameters 1247 * \param param_vals values for keyword parameters 1248 * \param out_stypes output ndarrays' stypes 1249 * \return 0 when success, -1 when failure happens 1250 */ 1251 MXNET_DLL int MXImperativeInvokeEx(AtomicSymbolCreator creator, 1252 int num_inputs, 1253 NDArrayHandle *inputs, 1254 int *num_outputs, 1255 NDArrayHandle **outputs, 1256 int num_params, 1257 const char **param_keys, 1258 const char **param_vals, 1259 const int **out_stypes); 1260 /*! 1261 * \brief set whether to record operator for autograd 1262 * \param is_recording 1 when recording, 0 when not recording. 1263 * \param prev returns the previous status before this set. 1264 * \return 0 when success, -1 when failure happens 1265 */ 1266 MXNET_DLL int MXAutogradSetIsRecording(int is_recording, int* prev); 1267 /*! 1268 * \brief set whether to record operator for autograd 1269 * \param is_training 1 when training, 0 when testing 1270 * \param prev returns the previous status before this set. 1271 * \return 0 when success, -1 when failure happens 1272 */ 1273 MXNET_DLL int MXAutogradSetIsTraining(int is_training, int* prev); 1274 /*! 1275 * \brief get whether autograd recording is on 1276 * \param curr returns the current status. 1277 * \return 0 when success, -1 when failure happens 1278 */ 1279 MXNET_DLL int MXAutogradIsRecording(bool* curr); 1280 /*! 1281 * \brief get whether training mode is on 1282 * \param curr returns the current status. 1283 * \return 0 when success, -1 when failure happens 1284 */ 1285 MXNET_DLL int MXAutogradIsTraining(bool* curr); 1286 /*! 
1287 * \brief get whether numpy compatibility is on 1288 * \param curr returns the current status 1289 * \return 0 when success, -1 when failure happens 1290 */ 1291 MXNET_DLL int MXIsNumpyShape(int* curr); 1292 /*! 1293 * \brief set numpy compatibility switch 1294 * \param is_np_shape 1 when numpy shape semantics is thread local on, 1295 * 2 when numpy shape semantics is global on and 0 when off 1296 * \param prev returns the previous status before this set 1297 * \return 0 when success, -1 when failure happens 1298 */ 1299 MXNET_DLL int MXSetIsNumpyShape(int is_np_shape, int* prev); 1300 /*! 1301 * \brief mark NDArrays as variables to compute gradient for autograd 1302 * \param num_var number of variable NDArrays 1303 * \param var_handles variable NDArrays 1304 * \return 0 when success, -1 when failure happens 1305 */ 1306 MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var, 1307 NDArrayHandle *var_handles, 1308 uint32_t *reqs_array, 1309 NDArrayHandle *grad_handles); 1310 /*! 1311 * \brief compute the gradient of outputs w.r.t variabels 1312 * \param num_output number of output NDArray 1313 * \param output_handles output NDArrays 1314 * \return 0 when success, -1 when failure happens 1315 */ 1316 MXNET_DLL int MXAutogradComputeGradient(uint32_t num_output, 1317 NDArrayHandle* output_handles); 1318 /*! 1319 * \brief compute the gradient of outputs w.r.t variabels 1320 * \param num_output number of output NDArray 1321 * \param output_handles output NDArrays 1322 * \param ograd_handles head gradient for NDArrays 1323 * \param retain_graph whether to keep the graph after backward 1324 * \return 0 when success, -1 when failure happens 1325 */ 1326 MXNET_DLL int MXAutogradBackward(uint32_t num_output, 1327 NDArrayHandle* output_handles, 1328 NDArrayHandle* ograd_handles, 1329 int retain_graph); 1330 /*! 
1331 * \brief compute the gradient of outputs w.r.t variabels 1332 * \param num_output number of output NDArray 1333 * \param output_handles output NDArrays 1334 * \param ograd_handles head gradient for NDArrays 1335 * \param num_variables number of variables 1336 * \param 1337 * \param retain_graph whether to keep the graph after backward 1338 * \param is_train whether to do backward for training or inference 1339 * \return 0 when success, -1 when failure happens 1340 */ 1341 MXNET_DLL int MXAutogradBackwardEx(uint32_t num_output, 1342 NDArrayHandle *output_handles, 1343 NDArrayHandle *ograd_handles, 1344 uint32_t num_variables, 1345 NDArrayHandle *var_handles, 1346 int retain_graph, 1347 int create_graph, 1348 int is_train, 1349 NDArrayHandle **grad_handles, 1350 int **grad_stypes); 1351 /* 1352 * \brief get the graph constructed by autograd. 1353 * \param handle ndarray handle 1354 * \param out output symbol handle 1355 */ 1356 MXNET_DLL int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out); 1357 /*! 1358 * \brief create cached operator 1359 */ 1360 MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out); 1361 /*! 1362 * \brief create cached operator 1363 */ 1364 MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle, 1365 int num_flags, 1366 const char** keys, 1367 const char** vals, 1368 CachedOpHandle *out); 1369 1370 /*! 1371 * \brief create cached operator, allows to choose thread_safe version 1372 * of cachedop 1373 */ 1374 MXNET_DLL int MXCreateCachedOpEX(SymbolHandle handle, 1375 int num_flags, 1376 const char** keys, 1377 const char** vals, 1378 CachedOpHandle *out, 1379 bool thread_safe DEFAULT(false)); 1380 1381 /*! 1382 * \brief free cached operator 1383 */ 1384 MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle); 1385 1386 /*! 
1387 * \brief invoke cached operator 1388 */ 1389 MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle, 1390 int num_inputs, 1391 NDArrayHandle *inputs, 1392 int *num_outputs, 1393 NDArrayHandle **outputs); 1394 1395 /*! 1396 * \brief invoke a cached op 1397 * \param handle the handle to the cached op 1398 * \param num_inputs number of input NDArrays 1399 * \param inputs input NDArrays 1400 * \param num_outputs number of output NDArrays 1401 * \param outputs output NDArrays 1402 * \param out_stypes output ndarrays' stypes 1403 * \return 0 when success, -1 when failure happens 1404 */ 1405 MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle, 1406 int num_inputs, 1407 NDArrayHandle *inputs, 1408 int *num_outputs, 1409 NDArrayHandle **outputs, 1410 const int** out_stypes); 1411 1412 /*! 1413 * \brief cached op set monitor callback 1414 */ 1415 MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle, 1416 CachedOpMonitorCallback callback, 1417 bool monitor_all); 1418 1419 //-------------------------------------------- 1420 // Part 3: symbolic configuration generation 1421 //-------------------------------------------- 1422 /*! 1423 * \brief list all the available operator names, include entries. 1424 * \param out_size the size of returned array 1425 * \param out_array the output operator name array. 1426 * \return 0 when success, -1 when failure happens 1427 */ 1428 MXNET_DLL int MXListAllOpNames(uint32_t *out_size, 1429 const char ***out_array); 1430 1431 /*! 1432 * \brief list all the available AtomicSymbolEntry 1433 * \param out_size the size of returned array 1434 * \param out_array the output AtomicSymbolCreator array 1435 * \return 0 when success, -1 when failure happens 1436 */ 1437 MXNET_DLL int MXSymbolListAtomicSymbolCreators(uint32_t *out_size, 1438 AtomicSymbolCreator **out_array); 1439 1440 /*! 1441 * \brief Get the name of an atomic symbol. 1442 * \param creator the AtomicSymbolCreator. 1443 * \param name The returned name of the creator. 
1444 */ 1445 MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, 1446 const char **name); 1447 1448 /*! 1449 * \brief Get the input symbols of the graph. 1450 * \param sym The graph. 1451 * \param inputs The input symbols of the graph. 1452 * \param input_size the number of input symbols returned. 1453 */ 1454 MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs, 1455 int *input_size); 1456 1457 /*! 1458 * \brief Cut a subgraph whose nodes are marked with a subgraph attribute. 1459 * The input graph will be modified. A variable node will be created for each 1460 * edge that connects to nodes outside the subgraph. The outside nodes that 1461 * connect to the subgraph will be returned. 1462 * \param sym The graph. 1463 * \param inputs The nodes that connect to the subgraph. 1464 * \param input_size The number of such nodes. 1465 */ 1466 MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs, 1467 int *input_size); 1468 1469 /*! 1470 * \brief Get the detailed information about atomic symbol. 1471 * \param creator the AtomicSymbolCreator. 1472 * \param name The returned name of the creator. 1473 * \param description The returned description of the symbol. 1474 * \param num_args Number of arguments. 1475 * \param arg_names Name of the arguments. 1476 * \param arg_type_infos Type informations about the arguments. 1477 * \param arg_descriptions Description information about the arguments. 1478 * \param key_var_num_args The keyword argument for specifying variable number of arguments. 1479 * When this parameter has non-zero length, the function allows variable number 1480 * of positional arguments, and will need the caller to pass it in in 1481 * MXSymbolCreateAtomicSymbol, 1482 * With key = key_var_num_args, and value = number of positional arguments. 
1483 * \param return_type Return type of the function, can be Symbol or Symbol[] 1484 * \return 0 when success, -1 when failure happens 1485 */ 1486 MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, 1487 const char **name, 1488 const char **description, 1489 uint32_t *num_args, 1490 const char ***arg_names, 1491 const char ***arg_type_infos, 1492 const char ***arg_descriptions, 1493 const char **key_var_num_args, 1494 const char **return_type DEFAULT(NULL)); 1495 /*! 1496 * \brief Create an AtomicSymbol. 1497 * \param creator the AtomicSymbolCreator 1498 * \param num_param the number of parameters 1499 * \param keys the keys to the params 1500 * \param vals the vals of the params 1501 * \param out pointer to the created symbol handle 1502 * \return 0 when success, -1 when failure happens 1503 */ 1504 MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, 1505 uint32_t num_param, 1506 const char **keys, 1507 const char **vals, 1508 SymbolHandle *out); 1509 /*! 1510 * \brief Create a Variable Symbol. 1511 * \param name name of the variable 1512 * \param out pointer to the created symbol handle 1513 * \return 0 when success, -1 when failure happens 1514 */ 1515 MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out); 1516 /*! 1517 * \brief Create a Symbol by grouping list of symbols together 1518 * \param num_symbols number of symbols to be grouped 1519 * \param symbols array of symbol handles 1520 * \param out pointer to the created symbol handle 1521 * \return 0 when success, -1 when failure happens 1522 */ 1523 MXNET_DLL int MXSymbolCreateGroup(uint32_t num_symbols, 1524 SymbolHandle *symbols, 1525 SymbolHandle *out); 1526 /*! 1527 * \brief Load a symbol from a json file. 1528 * \param fname the file name. 1529 * \param out the output symbol. 1530 * \return 0 when success, -1 when failure happens 1531 */ 1532 MXNET_DLL int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out); 1533 /*! 
1534 * \brief Load a symbol from a json string. 1535 * \param json the json string. 1536 * \param out the output symbol. 1537 * \return 0 when success, -1 when failure happens 1538 */ 1539 MXNET_DLL int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out); 1540 /*! 1541 * \brief Remove the operators amp_cast and amp_multicast 1542 * \param sym_handle the input symbol. 1543 * \param ret_sym_handle the output symbol. 1544 * \return 0 when success, -1 when failure happens 1545 */ 1546 MXNET_DLL int MXSymbolRemoveAmpCast(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle); 1547 /*! 1548 * \brief Save a symbol into a json file. 1549 * \param symbol the input symbol. 1550 * \param fname the file name. 1551 * \return 0 when success, -1 when failure happens 1552 */ 1553 MXNET_DLL int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname); 1554 /*! 1555 * \brief Save a symbol into a json string 1556 * \param symbol the input symbol. 1557 * \param out_json output json string. 1558 * \return 0 when success, -1 when failure happens 1559 */ 1560 MXNET_DLL int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json); 1561 /*! 1562 * \brief Free the symbol handle. 1563 * \param symbol the symbol 1564 * \return 0 when success, -1 when failure happens 1565 */ 1566 MXNET_DLL int MXSymbolFree(SymbolHandle symbol); 1567 /*! 1568 * \brief Copy the symbol to another handle 1569 * \param symbol the source symbol 1570 * \param out used to hold the result of copy 1571 * \return 0 when success, -1 when failure happens 1572 */ 1573 MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out); 1574 /*! 1575 * \brief Print the content of symbol, used for debug. 1576 * \param symbol the symbol 1577 * \param out_str pointer to hold the output string of the printing. 1578 * \return 0 when success, -1 when failure happens 1579 */ 1580 MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str); 1581 /*! 
1582 * \brief Get string name from symbol 1583 * \param symbol the source symbol 1584 * \param out The result name. 1585 * \param success Whether the result is contained in out. 1586 * \return 0 when success, -1 when failure happens 1587 */ 1588 MXNET_DLL int MXSymbolGetName(SymbolHandle symbol, 1589 const char** out, 1590 int *success); 1591 /*! 1592 * \brief Get string attribute from symbol 1593 * \param symbol the source symbol 1594 * \param key The key of the symbol. 1595 * \param out The result attribute, can be NULL if the attribute do not exist. 1596 * \param success Whether the result is contained in out. 1597 * \return 0 when success, -1 when failure happens 1598 */ 1599 MXNET_DLL int MXSymbolGetAttr(SymbolHandle symbol, 1600 const char* key, 1601 const char** out, 1602 int *success); 1603 /*! 1604 * \brief Set string attribute from symbol. 1605 * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph. 1606 * 1607 * Safe recommendaton: use immutable graph 1608 * - Only allow set attributes during creation of new symbol as optional parameter 1609 * 1610 * Mutable graph (be careful about the semantics): 1611 * - Allow set attr at any point. 1612 * - Mutating an attribute of some common node of two graphs can cause confusion from user. 1613 * 1614 * \param symbol the source symbol 1615 * \param key The key of the symbol. 1616 * \param value The value to be saved. 1617 * \return 0 when success, -1 when failure happens 1618 */ 1619 MXNET_DLL int MXSymbolSetAttr(SymbolHandle symbol, 1620 const char* key, 1621 const char* value); 1622 /*! 1623 * \brief Get all attributes from symbol, including all descendents. 1624 * \param symbol the source symbol 1625 * \param out_size The number of output attributes 1626 * \param out 2*out_size strings representing key value pairs. 
1627 * \return 0 when success, -1 when failure happens 1628 */ 1629 MXNET_DLL int MXSymbolListAttr(SymbolHandle symbol, 1630 uint32_t *out_size, 1631 const char*** out); 1632 /*! 1633 * \brief Get all attributes from symbol, excluding descendents. 1634 * \param symbol the source symbol 1635 * \param out_size The number of output attributes 1636 * \param out 2*out_size strings representing key value pairs. 1637 * \return 0 when success, -1 when failure happens 1638 */ 1639 MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol, 1640 uint32_t *out_size, 1641 const char*** out); 1642 /*! 1643 * \brief List arguments in the symbol. 1644 * \param symbol the symbol 1645 * \param out_size output size 1646 * \param out_str_array pointer to hold the output string array 1647 * \return 0 when success, -1 when failure happens 1648 */ 1649 MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, 1650 uint32_t *out_size, 1651 const char ***out_str_array); 1652 1653 /*! 1654 * \brief List returns in the symbol. 1655 * \param symbol the symbol 1656 * \param out_size output size 1657 * \param out_str_array pointer to hold the output string array 1658 * \return 0 when success, -1 when failure happens 1659 */ 1660 MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, 1661 uint32_t *out_size, 1662 const char ***out_str_array); 1663 1664 /*! 1665 * \brief Get number of outputs of the symbol. 1666 * \param symbol The symbol 1667 * \param out_size number of outputs 1668 * \return 0 when success, -1 when failure happens 1669 */ 1670 MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol, 1671 uint32_t *output_count); 1672 1673 /*! 1674 * \brief Get a symbol that contains all the internals. 1675 * \param symbol The symbol 1676 * \param out The output symbol whose outputs are all the internals. 1677 * \return 0 when success, -1 when failure happens 1678 */ 1679 MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol, 1680 SymbolHandle *out); 1681 /*! 
1682 * \brief Get a symbol that contains only direct children. 1683 * \param symbol The symbol 1684 * \param out The output symbol whose outputs are the direct children. 1685 * \return 0 when success, -1 when failure happens 1686 */ 1687 MXNET_DLL int MXSymbolGetChildren(SymbolHandle symbol, 1688 SymbolHandle *out); 1689 /*! 1690 * \brief Get index-th outputs of the symbol. 1691 * \param symbol The symbol 1692 * \param index the Index of the output. 1693 * \param out The output symbol whose outputs are the index-th symbol. 1694 * \return 0 when success, -1 when failure happens 1695 */ 1696 MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol, 1697 uint32_t index, 1698 SymbolHandle *out); 1699 1700 /*! 1701 * \brief List auxiliary states in the symbol. 1702 * \param symbol the symbol 1703 * \param out_size output size 1704 * \param out_str_array pointer to hold the output string array 1705 * \return 0 when success, -1 when failure happens 1706 */ 1707 MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol, 1708 uint32_t *out_size, 1709 const char ***out_str_array); 1710 1711 /*! 1712 * \brief Compose the symbol on other symbols. 1713 * 1714 * This function will change the sym hanlde. 1715 * To achieve function apply behavior, copy the symbol first 1716 * before apply. 1717 * 1718 * \param sym the symbol to apply 1719 * \param name the name of symbol 1720 * \param num_args number of arguments 1721 * \param keys the key of keyword args (optional) 1722 * \param args arguments to sym 1723 * \return 0 when success, -1 when failure happens 1724 */ 1725 MXNET_DLL int MXSymbolCompose(SymbolHandle sym, 1726 const char *name, 1727 uint32_t num_args, 1728 const char** keys, 1729 SymbolHandle* args); 1730 /*! 
* \brief Get the gradient graph of the symbol.
 *
 * \param sym the symbol to take the gradient of
 * \param num_wrt number of arguments to take the gradient with respect to
 * \param wrt names of the arguments to take the gradient with respect to
 * \param out the returned symbol that holds the gradient
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSymbolGrad(SymbolHandle sym,
                           uint32_t num_wrt,
                           const char** wrt,
                           SymbolHandle* out);
/*!
 * \brief DEPRECATED. Use MXSymbolInferShapeEx instead.
 * Infer the shapes of unknown inputs given the known ones.
 * The known shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data.
 * The call is treated as a kwargs call if keys != NULL or num_args == 0; otherwise it is positional.
 *
 * \param sym symbol handle
 * \param num_args number of input arguments
 * \param keys the keys of keyword args (optional)
 * \param arg_ind_ptr the head pointer of the rows in the CSR
 * \param arg_shape_data the content of the CSR
 * \param in_shape_size size of the returned array of input shapes
 * \param in_shape_ndim returned array of the number of dimensions of each input shape
 * \param in_shape_data returned array of pointers to the head of each input shape
 * \param out_shape_size size of the returned array of output shapes
 * \param out_shape_ndim returned array of the number of dimensions of each output shape
 * \param out_shape_data returned array of pointers to the head of each output shape
 * \param aux_shape_size size of the returned array of auxiliary shapes
 * \param aux_shape_ndim returned array of the number of dimensions of each auxiliary shape
 * \param aux_shape_data returned array of pointers to the head of each auxiliary shape
 * \param complete whether shape inference completed or more information is needed
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
                                 uint32_t num_args,
                                 const char** keys,
                                 const uint32_t *arg_ind_ptr,
                                 const uint32_t *arg_shape_data,
                                 uint32_t *in_shape_size,
                                 const uint32_t **in_shape_ndim,
                                 const uint32_t ***in_shape_data,
                                 uint32_t *out_shape_size,
                                 const uint32_t **out_shape_ndim,
                                 const uint32_t ***out_shape_data,
                                 uint32_t *aux_shape_size,
                                 const uint32_t **aux_shape_ndim,
                                 const uint32_t ***aux_shape_data,
                                 int *complete);

/*!
 * \brief Infer the shapes of unknown inputs given the known ones.
 * The known shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data.
 * The call is treated as a kwargs call if keys != NULL or num_args == 0; otherwise it is positional.
 * This API is available when MXNet is built with the flag
 * USE_INT64_TENSOR_SIZE=0 (the default).
 *
 * \param sym symbol handle
 * \param num_args number of input arguments
 * \param keys the keys of keyword args (optional)
 * \param arg_ind_ptr the head pointer of the rows in the CSR
 * \param arg_shape_data the content of the CSR
 * \param in_shape_size size of the returned array of input shapes
 * \param in_shape_ndim returned array of the number of dimensions of each input shape
 * \param in_shape_data returned array of pointers to the head of each input shape
 * \param out_shape_size size of the returned array of output shapes
 * \param out_shape_ndim returned array of the number of dimensions of each output shape
 * \param out_shape_data returned array of pointers to the head of each output shape
 * \param aux_shape_size size of the returned array of auxiliary shapes
 * \param aux_shape_ndim returned array of the number of dimensions of each auxiliary shape
 * \param aux_shape_data returned array of pointers to the head of each auxiliary shape
 * \param complete whether shape inference completed or more information is needed
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
                                   uint32_t num_args,
                                   const char** keys,
                                   const uint32_t *arg_ind_ptr,
                                   const int *arg_shape_data,
                                   uint32_t *in_shape_size,
                                   const int **in_shape_ndim,
                                   const int ***in_shape_data,
                                   uint32_t *out_shape_size,
                                   const int **out_shape_ndim,
                                   const int ***out_shape_data,
                                   uint32_t *aux_shape_size,
                                   const int **aux_shape_ndim,
                                   const int ***aux_shape_data,
                                   int *complete);

/*!
 * \brief Infer the shapes of unknown inputs given the known ones.
 * The known shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data.
 * The call is treated as a kwargs call if keys != NULL or num_args == 0; otherwise it is positional.
 * This API is available when MXNet is built with the flag
 * USE_INT64_TENSOR_SIZE=1 (not the default), i.e. Large Tensor Support.
 *
 * \param sym symbol handle
 * \param num_args number of input arguments
 * \param keys the keys of keyword args (optional)
 * \param arg_ind_ptr the head pointer of the rows in the CSR
 * \param arg_shape_data the content of the CSR
 * \param in_shape_size size of the returned array of input shapes
 * \param in_shape_ndim returned array of the number of dimensions of each input shape
 * \param in_shape_data returned array of pointers to the head of each input shape
 * \param out_shape_size size of the returned array of output shapes
 * \param out_shape_ndim returned array of the number of dimensions of each output shape
 * \param out_shape_data returned array of pointers to the head of each output shape
 * \param aux_shape_size size of the returned array of auxiliary shapes
 * \param aux_shape_ndim returned array of the number of dimensions of each auxiliary shape
 * \param aux_shape_data returned array of pointers to the head of each auxiliary shape
 * \param complete whether shape inference completed or more information is needed
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSymbolInferShapeEx64(SymbolHandle sym,
                                     uint32_t num_args,
                                     const char** keys,
                                     const int64_t *arg_ind_ptr,
                                     const int64_t *arg_shape_data,
                                     size_t *in_shape_size,
                                     const int **in_shape_ndim,
                                     const int64_t ***in_shape_data,
                                     size_t *out_shape_size,
                                     const int **out_shape_ndim,
                                     const int64_t ***out_shape_data,
                                     size_t *aux_shape_size,
                                     const int **aux_shape_ndim,
                                     const int64_t ***aux_shape_data,
                                     int *complete);

/*!
 * \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
 * Partially infer the shapes of unknown inputs given the known ones.
 *
 * Returns partially inferred results if not all shapes could be inferred.
 * The known shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data.
 * The call is treated as a kwargs call if keys != NULL or num_args == 0; otherwise it is positional.
 *
 * \param sym symbol handle
 * \param num_args number of input arguments
 * \param keys the keys of keyword args (optional)
 * \param arg_ind_ptr the head pointer of the rows in the CSR
 * \param arg_shape_data the content of the CSR
 * \param in_shape_size size of the returned array of input shapes
 * \param in_shape_ndim returned array of the number of dimensions of each input shape
 * \param in_shape_data returned array of pointers to the head of each input shape
 * \param out_shape_size size of the returned array of output shapes
 * \param out_shape_ndim returned array of the number of dimensions of each output shape
 * \param out_shape_data returned array of pointers to the head of each output shape
 * \param aux_shape_size size of the returned array of auxiliary shapes
 * \param aux_shape_ndim returned array of the number of dimensions of each auxiliary shape
 * \param aux_shape_data returned array of pointers to the head of each auxiliary shape
 * \param complete whether shape inference completed or more information is needed
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
                                        uint32_t num_args,
                                        const char** keys,
                                        const uint32_t *arg_ind_ptr,
                                        const uint32_t *arg_shape_data,
                                        uint32_t *in_shape_size,
                                        const uint32_t **in_shape_ndim,
                                        const uint32_t ***in_shape_data,
                                        uint32_t *out_shape_size,
                                        const uint32_t **out_shape_ndim,
                                        const uint32_t ***out_shape_data,
                                        uint32_t *aux_shape_size,
                                        const uint32_t **aux_shape_ndim,
                                        const uint32_t ***aux_shape_data,
                                        int *complete);

/*!
 * \brief Partially infer the shapes of unknown inputs given the known ones.
 *
 * Returns partially inferred results if not all shapes could be inferred.
 * The known shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data.
 * The call is treated as a kwargs call if keys != NULL or num_args == 0; otherwise it is positional.
 * This API is available when MXNet is built with the flag
 * USE_INT64_TENSOR_SIZE=0 (the default).
 *
 * \param sym symbol handle
 * \param num_args number of input arguments
 * \param keys the keys of keyword args (optional)
 * \param arg_ind_ptr the head pointer of the rows in the CSR
 * \param arg_shape_data the content of the CSR
 * \param in_shape_size size of the returned array of input shapes
 * \param in_shape_ndim returned array of the number of dimensions of each input shape
 * \param in_shape_data returned array of pointers to the head of each input shape
 * \param out_shape_size size of the returned array of output shapes
 * \param out_shape_ndim returned array of the number of dimensions of each output shape
 * \param out_shape_data returned array of pointers to the head of each output shape
 * \param aux_shape_size size of the returned array of auxiliary shapes
 * \param aux_shape_ndim returned array of the number of dimensions of each auxiliary shape
 * \param aux_shape_data returned array of pointers to the head of each auxiliary shape
 * \param complete whether shape inference completed or more information is needed
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
                                          uint32_t num_args,
                                          const char** keys,
                                          const uint32_t *arg_ind_ptr,
                                          const int *arg_shape_data,
                                          uint32_t *in_shape_size,
                                          const int **in_shape_ndim,
                                          const int ***in_shape_data,
                                          uint32_t *out_shape_size,
                                          const int **out_shape_ndim,
                                          const int ***out_shape_data,
                                          uint32_t *aux_shape_size,
                                          const int **aux_shape_ndim,
                                          const int ***aux_shape_data,
                                          int *complete);

/*!
 * \brief Partially infer the shapes of unknown inputs given the known ones.
 *
 * Returns partially inferred results if not all shapes could be inferred.
 * The known shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data.
 * The call is treated as a kwargs call if keys != NULL or num_args == 0; otherwise it is positional.
 * This API is available when MXNet is built with the flag
 * USE_INT64_TENSOR_SIZE=1 (not the default), i.e. Large Tensor Support.
 *
 * \param sym symbol handle
 * \param num_args number of input arguments
 * \param keys the keys of keyword args (optional)
 * \param arg_ind_ptr the head pointer of the rows in the CSR
 * \param arg_shape_data the content of the CSR
 * \param in_shape_size size of the returned array of input shapes
 * \param in_shape_ndim returned array of the number of dimensions of each input shape
 * \param in_shape_data returned array of pointers to the head of each input shape
 * \param out_shape_size size of the returned array of output shapes
 * \param out_shape_ndim returned array of the number of dimensions of each output shape
 * \param out_shape_data returned array of pointers to the head of each output shape
 * \param aux_shape_size size of the returned array of auxiliary shapes
 * \param aux_shape_ndim returned array of the number of dimensions of each auxiliary shape
 * \param aux_shape_data returned array of pointers to the head of each auxiliary shape
 * \param complete whether shape inference completed or more information is needed
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSymbolInferShapePartialEx64(SymbolHandle sym,
                                            uint32_t num_args,
                                            const char** keys,
                                            const int64_t *arg_ind_ptr,
                                            const int64_t *arg_shape_data,
                                            size_t *in_shape_size,
                                            const int **in_shape_ndim,
                                            const int64_t ***in_shape_data,
                                            size_t *out_shape_size,
                                            const int **out_shape_ndim,
                                            const int64_t ***out_shape_data,
                                            size_t *aux_shape_size,
                                            const int **aux_shape_ndim,
                                            const int64_t ***aux_shape_data,
                                            int *complete);

/*!
 * \brief Infer the types of unknown inputs given the known ones.
 * The known types are passed in arg_type_data. Unlike the shape-inference
 * functions, no CSR row-pointer array is taken.
 * The call is treated as a kwargs call if keys != NULL or num_args == 0; otherwise it is positional.
 *
 * \param sym symbol handle
 * \param num_args number of input arguments
 * \param keys the keys of keyword args (optional)
 * \param arg_type_data the known argument types
 * \param in_type_size size of the returned array of input types
 * \param in_type_data receives the array of input types
 * \param out_type_size size of the returned array of output types
 * \param out_type_data receives the array of output types
 * \param aux_type_size size of the returned array of auxiliary types
 * \param aux_type_data receives the array of auxiliary types
 * \param complete whether type inference completed or more information is needed
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
                                uint32_t num_args,
                                const char** keys,
                                const int *arg_type_data,
                                uint32_t *in_type_size,
                                const int **in_type_data,
                                uint32_t *out_type_size,
                                const int **out_type_data,
                                uint32_t *aux_type_size,
                                const int **aux_type_data,
                                int *complete);

/*!
 * \brief Partially infer the types of unknown inputs given the known ones.
 *
 * Returns partially inferred results if not all types could be inferred.
 * The known types are passed in arg_type_data. Unlike the shape-inference
 * functions, no CSR row-pointer array is taken.
 * The call is treated as a kwargs call if keys != NULL or num_args == 0; otherwise it is positional.
 *
 * \param sym symbol handle
 * \param num_args number of input arguments
 * \param keys the keys of keyword args (optional)
 * \param arg_type_data the known argument types
 * \param in_type_size size of the returned array of input types
 * \param in_type_data receives the array of input types
 * \param out_type_size size of the returned array of output types
 * \param out_type_data receives the array of output types
 * \param aux_type_size size of the returned array of auxiliary types
 * \param aux_type_data receives the array of auxiliary types
 * \param complete whether type inference completed or more information is needed
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym,
                                       uint32_t num_args,
                                       const char** keys,
                                       const int *arg_type_data,
                                       uint32_t *in_type_size,
                                       const int **in_type_data,
                                       uint32_t *out_type_size,
                                       const int **out_type_data,
                                       uint32_t *aux_type_size,
                                       const int **aux_type_data,
                                       int *complete);

/*!
 * \brief Convert a symbol into a quantized symbol in which FP32 operators are replaced with INT8.
 * \param sym_handle symbol to be converted
 * \param ret_sym_handle the quantized symbol result
 * \param dev_type pointer to the device type
 * \param num_excluded_sym_names number of layers excluded from quantization in the input symbol
 * \param excluded_sym_names names of the nodes to exclude from quantization
 * \param num_excluded_op_names number of operators excluded from quantization in the input symbol
 * \param excluded_op_names names of the operators to exclude from quantization
 * \param num_offline number of parameters that are quantized offline
 * \param offline_params array of C strings naming the parameters quantized offline
 * \param quantized_dtype the quantized destination type for input data
 * \param calib_quantize **Deprecated**. The quantize op is always calibrated when possible.
 * \param quantize_mode quantize mode to use in the quantize pass
 * \param quantize_granularity quantize granularity, tensor-wise or channel-wise
 * \param out_num_calib_names receives the number of nodes to be calibrated
 * \param out_calib_names receives the names of the nodes to be calibrated
 */
MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle,
                               SymbolHandle *ret_sym_handle,
                               const int* dev_type,
                               const uint32_t num_excluded_sym_names,
                               const char **excluded_sym_names,
                               const uint32_t num_excluded_op_names,
                               const char **excluded_op_names,
                               const uint32_t num_offline, const char **offline_params,
                               const char *quantized_dtype, const bool calib_quantize,
                               const char *quantize_mode, const char *quantize_granularity,
                               uint32_t* out_num_calib_names, const char ***out_calib_names);

/*!
 * \brief Convert a symbol into a mixed-precision symbol, inserting cast operators for target dtype casting.
 * \param sym_handle symbol to be converted
 * \param ret_sym_handle the mixed-precision symbol result
 * \param num_args number of arguments with known dtypes
 * \param arg_type_data dtypes of those arguments
 * \param num_ind_ptr number of entries in ind_ptr (TODO confirm)
 * \param ind_ptr index-pointer array; presumably delimits, for each conditional FP32 op,
 *        its range in conditional_param_names / conditional_param_vals (TODO confirm)
 * \param target_dtype pointer to the target dtype of the mixed-precision symbol
 * \param cast_optional_params whether to cast optional params to target_dtype
 * \param num_target_dtype_op_names number of ops to be cast to target_dtype
 * \param num_fp32_op_names number of ops to be cast to FP32
 * \param num_widest_dtype_op_names number of ops to be cast to the widest dtype
 * \param num_conditional_fp32_op_names number of ops to be cast to FP32 based on a condition
 * \param num_excluded_symbols number of symbols to be excluded from casting
 * \param num_model_params number of model parameters
 * \param target_dtype_op_names names of ops to be cast to target_dtype
 * \param fp32_op_names names of ops to be cast to FP32
 * \param widest_dtype_op_names names of ops to be cast to the widest dtype
 * \param conditional_fp32_op_names names of ops to be cast to FP32 conditionally
 * \param excluded_symbols names of symbols to be excluded from casting
 * \param conditional_param_names param names for conditional FP32 casting
 * \param conditional_param_vals param values for conditional FP32 casting
 * \param model_param_names names of the model parameters
 * \param arg_names names of the arguments for which type information is provided
 */
MXNET_DLL int MXReducePrecisionSymbol(SymbolHandle sym_handle,
                                      SymbolHandle *ret_sym_handle,
                                      uint32_t num_args,
                                      const int* arg_type_data,
                                      uint32_t num_ind_ptr,
                                      const int* ind_ptr,
                                      const int* target_dtype,
                                      const int cast_optional_params,
                                      const uint32_t num_target_dtype_op_names,
                                      const uint32_t num_fp32_op_names,
                                      const uint32_t num_widest_dtype_op_names,
                                      const uint32_t num_conditional_fp32_op_names,
                                      const uint32_t num_excluded_symbols,
                                      const uint32_t num_model_params,
                                      const char **target_dtype_op_names,
                                      const char **fp32_op_names,
                                      const char **widest_dtype_op_names,
                                      const char **conditional_fp32_op_names,
                                      const char **excluded_symbols,
                                      const char **conditional_param_names,
                                      const char **conditional_param_vals,
                                      const char **model_param_names,
                                      const char **arg_names);
/*!
* \brief Write a calibration table into the node attributes of a quantized symbol.
 * \param qsym_handle symbol whose node attributes are to be set from the calibration table
 * \param num_layers number of layers in the calibration table
 * \param layer_names layer names, stored as keys in the calibration table
 * \param low_quantiles low quantiles of the layers stored in the calibration table
 * \param high_quantiles high quantiles of the layers stored in the calibration table
 * \param ret_sym_handle the returned symbol
 */
MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
                                               const uint32_t num_layers,
                                               const char** layer_names,
                                               const float* low_quantiles,
                                               const float* high_quantiles,
                                               SymbolHandle* ret_sym_handle);

/*!
 * \brief Run the subgraph pass for the given backend.
 * \param sym_handle symbol to be converted
 * \param backend name of the backend to use for the subgraph pass
 * \param ret_sym_handle the returned symbol
 */
MXNET_DLL int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend,
                                   SymbolHandle *ret_sym_handle);

/*!
 * \brief Generate an atomic symbol (which can be composed) from a source symbol.
 * \param sym_handle the source symbol
 * \param ret_sym_handle the returned atomic symbol
 */
MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle);
/*!
* \brief Partition a symbol for the given backend, potentially creating subgraphs.
 * \param sym_handle symbol to be partitioned
 * \param backend_name backend name
 * \param dev_type context device type
 * \param ret_sym_handle receives the partitioned symbol
 * \param args_len number of args in in_args_handle
 * \param in_args_handle args array
 * \param aux_len number of aux states in in_aux_handle
 * \param in_aux_handle aux states array
 * \param num_options number of key/value option pairs
 * \param keys keys of the options
 * \param vals values corresponding to keys
 * \param num_input_shapes number of input shapes
 * \param input_shape_names names of the input shapes
 * \param input_shape_data pointer to the contiguous shape data
 * \param input_shape_idx array of per-shape starting indices; the length of the i-th input shape
 *        is calculated as input_shape_idx[i+1] - input_shape_idx[i]
 * \param num_input_dtypes number of input data types
 * \param input_dtype_names array of names of the input data types
 * \param input_dtypes array of values of the input data types
 * \param num_input_stypes number of input storage types
 * \param input_stype_names array of names of the input storage types
 * \param input_stypes array of values of the input storage types
 * \param skip_infer whether the optimization should skip attribute inference
 *        (use it if the backend does not require shape inference)
 * \param new_args_cnt pointer to an int that receives the number of new args
 * \param new_args_handle pointer to an array that receives the new args handles
 * \param new_arg_names_handle pointer to an array that receives the new args names
 * \param new_aux_cnt pointer to an int that receives the number of new aux
 * \param new_aux_handle pointer to an array that receives the new aux handles
 * \param new_aux_names_handle pointer to an array that receives the new aux names
 */
MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
                                   const char* backend_name,
                                   const int dev_type,
                                   SymbolHandle* ret_sym_handle,
                                   const mx_uint args_len,
                                   NDArrayHandle* in_args_handle,
                                   const mx_uint aux_len,
                                   NDArrayHandle* in_aux_handle,
                                   const mx_uint num_options,
                                   const char** keys,
                                   const char** vals,
                                   const uint32_t num_input_shapes,
                                   const char** input_shape_names,
                                   const int64_t* input_shape_data,
                                   const uint32_t* input_shape_idx,
                                   const uint32_t num_input_dtypes,
                                   const char** input_dtype_names,
                                   const int* input_dtypes,
                                   const uint32_t num_input_stypes,
                                   const char** input_stype_names,
                                   const int* input_stypes,
                                   bool skip_infer,
                                   int* new_args_cnt,
                                   NDArrayHandle** new_args_handle,
                                   char*** new_arg_names_handle,
                                   int* new_aux_cnt,
                                   NDArrayHandle** new_aux_handle,
                                   char*** new_aux_names_handle);


//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
/*!
 * \brief Delete the executor.
 * \param handle the executor
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXExecutorFree(ExecutorHandle handle);
/*!
 * \brief Print the content of the execution plan; used for debugging.
 * \param handle the executor
 * \param out_str pointer that receives the printed output string
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXExecutorPrint(ExecutorHandle handle, const char **out_str);
/*!
 * \brief Run the forward pass of the executor.
 *
 * \param handle executor handle
 * \param is_train int flag indicating whether the forward pass runs in training mode
 *        (as opposed to evaluation)
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train);
/*!
* \brief Run the backward pass of the executor.
 *
 * \param handle executor handle
 * \param len number of entries in head_grads
 * \param head_grads NDArray handles holding the gradients of the heads
 *
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXExecutorBackward(ExecutorHandle handle,
                                 uint32_t len,
                                 NDArrayHandle *head_grads);
/*!
 * \brief Run the backward pass of the executor.
 *
 * \param handle executor handle
 * \param len number of entries in head_grads
 * \param head_grads NDArray handles holding the gradients of the heads
 * \param is_train int flag indicating whether the backward pass runs in training mode
 *        (as opposed to evaluation)
 *
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXExecutorBackwardEx(ExecutorHandle handle,
                                   uint32_t len,
                                   NDArrayHandle *head_grads,
                                   int is_train);
/*!
 * \brief Get the executor's head NDArrays.
 *
 * \param handle executor handle
 * \param out_size receives the number of output NDArray handles
 * \param out receives the output NDArray handles
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle,
                                uint32_t *out_size,
                                NDArrayHandle **out);

/*!
* \brief Generate an executor from a symbol.
 *
 * \param symbol_handle symbol handle
 * \param dev_type device type
 * \param dev_id device id
 * \param len length of the per-argument arrays (presumably in_args, arg_grad_store and grad_req_type)
 * \param in_args input argument array
 * \param arg_grad_store array of argument gradient handles
 * \param grad_req_type gradient request type array
 * \param aux_states_len length of the auxiliary states array
 * \param aux_states auxiliary states array
 * \param out output executor handle
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle,
                             int dev_type,
                             int dev_id,
                             uint32_t len,
                             NDArrayHandle *in_args,
                             NDArrayHandle *arg_grad_store,
                             uint32_t *grad_req_type,
                             uint32_t aux_states_len,
                             NDArrayHandle *aux_states,
                             ExecutorHandle *out);
/*!
 * \brief Generate an executor from a symbol.
 * This is an advanced function that allows specifying a group2ctx map.
 * The user can annotate the "ctx_group" attribute to name each group.
 *
 * \param symbol_handle symbol handle
 * \param dev_type device type of the default context
 * \param dev_id device id of the default context
 * \param num_map_keys size of the group2ctx map
 * \param map_keys keys of the group2ctx map
 * \param map_dev_types device types of the group2ctx map
 * \param map_dev_ids device ids of the group2ctx map
 * \param len length of the per-argument arrays (presumably in_args, arg_grad_store and grad_req_type)
 * \param in_args input argument array
 * \param arg_grad_store array of argument gradient handles
 * \param grad_req_type gradient request type array
 * \param aux_states_len length of the auxiliary states array
 * \param aux_states auxiliary states array
 * \param out output executor handle
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXExecutorBindX(SymbolHandle symbol_handle,
                              int dev_type,
                              int dev_id,
                              uint32_t num_map_keys,
                              const char** map_keys,
                              const int* map_dev_types,
                              const int* map_dev_ids,
                              uint32_t len,
                              NDArrayHandle *in_args,
                              NDArrayHandle *arg_grad_store,
                              uint32_t *grad_req_type,
                              uint32_t aux_states_len,
                              NDArrayHandle *aux_states,
                              ExecutorHandle *out);
/*!
 * \brief Generate an executor from a symbol.
 * This is an advanced function that allows specifying a group2ctx map.
 * The user can annotate the "ctx_group" attribute to name each group.
 *
 * \param symbol_handle symbol handle
 * \param dev_type device type of the default context
 * \param dev_id device id of the default context
 * \param num_map_keys size of the group2ctx map
 * \param map_keys keys of the group2ctx map
 * \param map_dev_types device types of the group2ctx map
 * \param map_dev_ids device ids of the group2ctx map
 * \param len length of the per-argument arrays (presumably in_args, arg_grad_store and grad_req_type)
 * \param in_args input argument array
 * \param arg_grad_store array of argument gradient handles
 * \param grad_req_type gradient request type array
 * \param aux_states_len length of the auxiliary states array
 * \param aux_states auxiliary states array
 * \param shared_exec input executor handle for memory sharing
 * \param out output executor handle
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle,
                               int dev_type,
                               int dev_id,
                               uint32_t num_map_keys,
                               const char** map_keys,
                               const int* map_dev_types,
                               const int* map_dev_ids,
                               uint32_t len,
                               NDArrayHandle *in_args,
                               NDArrayHandle *arg_grad_store,
                               uint32_t *grad_req_type,
                               uint32_t aux_states_len,
                               NDArrayHandle *aux_states,
                               ExecutorHandle shared_exec,
                               ExecutorHandle *out);
/*! \brief DEPRECATED. Use MXExecutorSimpleBindEx instead.
 */
MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                                   int dev_type,
                                   int dev_id,
                                   const uint32_t num_g2c_keys,
                                   const char** g2c_keys,
                                   const int* g2c_dev_types,
                                   const int* g2c_dev_ids,
                                   const uint32_t provided_grad_req_list_len,
                                   const char** provided_grad_req_names,
                                   const char** provided_grad_req_types,
                                   const uint32_t num_provided_arg_shapes,
                                   const char** provided_arg_shape_names,
                                   const uint32_t* provided_arg_shape_data,
                                   const uint32_t* provided_arg_shape_idx,
                                   const uint32_t num_provided_arg_dtypes,
                                   const char** provided_arg_dtype_names,
                                   const int* provided_arg_dtypes,
                                   const uint32_t num_provided_arg_stypes,
                                   const char** provided_arg_stype_names,
                                   const int* provided_arg_stypes,
                                   const uint32_t num_shared_arg_names,
                                   const char** shared_arg_name_list,
                                   int* shared_buffer_len,
                                   const char** shared_buffer_name_list,
                                   NDArrayHandle* shared_buffer_handle_list,
                                   const char*** updated_shared_buffer_name_list,
                                   NDArrayHandle** updated_shared_buffer_handle_list,
                                   uint32_t* num_in_args,
                                   NDArrayHandle** in_args,
                                   NDArrayHandle** arg_grads,
                                   uint32_t* num_aux_states,
                                   NDArrayHandle** aux_states,
                                   ExecutorHandle shared_exec_handle,
                                   ExecutorHandle* out);


/*!
 * \brief Replacement for the deprecated MXExecutorSimpleBind.
 * Same parameter list, except that provided_arg_shape_data is passed as int
 * instead of uint32_t.
 */
MXNET_DLL int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
                                     int dev_type,
                                     int dev_id,
                                     const uint32_t num_g2c_keys,
                                     const char** g2c_keys,
                                     const int* g2c_dev_types,
                                     const int* g2c_dev_ids,
                                     const uint32_t provided_grad_req_list_len,
                                     const char** provided_grad_req_names,
                                     const char** provided_grad_req_types,
                                     const uint32_t num_provided_arg_shapes,
                                     const char** provided_arg_shape_names,
                                     const int* provided_arg_shape_data,
                                     const uint32_t* provided_arg_shape_idx,
                                     const uint32_t num_provided_arg_dtypes,
                                     const char** provided_arg_dtype_names,
                                     const int* provided_arg_dtypes,
                                     const uint32_t num_provided_arg_stypes,
                                     const char** provided_arg_stype_names,
                                     const int* provided_arg_stypes,
                                     const uint32_t num_shared_arg_names,
                                     const char** shared_arg_name_list,
                                     int* shared_buffer_len,
                                     const char** shared_buffer_name_list,
                                     NDArrayHandle* shared_buffer_handle_list,
                                     const char*** updated_shared_buffer_name_list,
                                     NDArrayHandle** updated_shared_buffer_handle_list,
                                     uint32_t* num_in_args,
                                     NDArrayHandle** in_args,
                                     NDArrayHandle** arg_grads,
                                     uint32_t* num_aux_states,
                                     NDArrayHandle** aux_states,
                                     ExecutorHandle shared_exec_handle,
                                     ExecutorHandle* out);


/*!
 * \brief Variant of MXExecutorSimpleBindEx that takes provided_arg_shape_data as int64_t.
 * Presumably the counterpart for builds with USE_INT64_TENSOR_SIZE=1, like
 * MXSymbolInferShapeEx64 (TODO confirm).
 */
MXNET_DLL int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle,
                                       int dev_type,
                                       int dev_id,
                                       const uint32_t num_g2c_keys,
                                       const char** g2c_keys,
                                       const int* g2c_dev_types,
                                       const int* g2c_dev_ids,
                                       const uint32_t provided_grad_req_list_len,
                                       const char** provided_grad_req_names,
                                       const char** provided_grad_req_types,
                                       const uint32_t num_provided_arg_shapes,
                                       const char** provided_arg_shape_names,
                                       const int64_t* provided_arg_shape_data,
                                       const uint32_t* provided_arg_shape_idx,
                                       const uint32_t num_provided_arg_dtypes,
                                       const char** provided_arg_dtype_names,
                                       const int* provided_arg_dtypes,
                                       const uint32_t num_provided_arg_stypes,
                                       const char** provided_arg_stype_names,
                                       const int* provided_arg_stypes,
                                       const uint32_t num_shared_arg_names,
                                       const char** shared_arg_name_list,
                                       int* shared_buffer_len,
                                       const char** shared_buffer_name_list,
                                       NDArrayHandle* shared_buffer_handle_list,
                                       const char*** updated_shared_buffer_name_list,
                                       NDArrayHandle** updated_shared_buffer_handle_list,
                                       uint32_t* num_in_args,
                                       NDArrayHandle** in_args,
                                       NDArrayHandle** arg_grads,
                                       uint32_t* num_aux_states,
                                       NDArrayHandle** aux_states,
                                       ExecutorHandle shared_exec_handle,
                                       ExecutorHandle* out);


/*!
 * \brief DEPRECATED. Use MXExecutorReshapeEx instead.
 * Return a new executor with the same symbol and shared memory,
 * but different input/output shapes.
 *
 * \param partial_shaping whether to allow changing the shape of unspecified arguments
 * \param allow_up_sizing whether to allow allocating new ndarrays that are larger than the original
 * \param dev_type device type of the default context
 * \param dev_id device id of the default context
 * \param num_map_keys size of the group2ctx map
 * \param map_keys keys of the group2ctx map
 * \param map_dev_types device types of the group2ctx map
 * \param map_dev_ids device ids of the group2ctx map
 * \param num_provided_arg_shapes number of provided argument shapes
 * \param provided_arg_shape_names names of the arguments whose shapes are provided
 * \param provided_arg_shape_data provided shape data (presumably concatenated and
 *        indexed by provided_arg_shape_idx; confirm)
 * \param provided_arg_shape_idx per-shape start indices into provided_arg_shape_data
 *        (presumably the same layout as input_shape_idx in MXOptimizeForBackend; confirm)
 * \param num_in_args pointer to the length of in_args
 * \param in_args pointer to the in args array
 * \param arg_grads pointer to the arg grads handle array
 * \param num_aux_states pointer to the length of aux_states
 * \param aux_states pointer to the auxiliary states array
 * \param shared_exec input executor handle for memory sharing
 * \param out output executor handle; receives the new executor
 * \return status code (0 when success, -1 when failure, per this API's convention);
 *         the new executor is returned through \p out
 */
MXNET_DLL int MXExecutorReshape(int partial_shaping,
                                int allow_up_sizing,
                                int dev_type,
                                int dev_id,
                                uint32_t num_map_keys,
                                const char** map_keys,
                                const int* map_dev_types,
                                const int* map_dev_ids,
                                const uint32_t num_provided_arg_shapes,
                                const char** provided_arg_shape_names,
                                const uint32_t* provided_arg_shape_data,
                                const uint32_t* provided_arg_shape_idx,
                                uint32_t* num_in_args,
                                NDArrayHandle** in_args,
                                NDArrayHandle** arg_grads,
                                uint32_t* num_aux_states,
                                NDArrayHandle** aux_states,
                                ExecutorHandle shared_exec,
                                ExecutorHandle *out);
/*!
 * \brief Return a new executor with the same symbol and shared memory,
 * but different input/output shapes.
 *
 * \param partial_shaping whether to allow changing the shape of unspecified arguments
 * \param allow_up_sizing whether to allow allocating new ndarrays that are larger than the original
 * \param dev_type device type of the default context
 * \param dev_id device id of the default context
 * \param num_map_keys size of the group2ctx map
 * \param map_keys keys of the group2ctx map
 * \param map_dev_types device types of the group2ctx map
 * \param map_dev_ids device ids of the group2ctx map
 * \param num_provided_arg_shapes number of provided argument shapes
 * \param provided_arg_shape_names names of the arguments whose shapes are provided
 * \param provided_arg_shape_data provided shape data (presumably concatenated and
 *        indexed by provided_arg_shape_idx; confirm)
 * \param provided_arg_shape_idx per-shape start indices into provided_arg_shape_data
 *        (presumably the same layout as input_shape_idx in MXOptimizeForBackend; confirm)
 * \param num_in_args pointer to the length of in_args
 * \param in_args pointer to the in args array
 * \param arg_grads pointer to the arg grads handle array
 * \param num_aux_states pointer to the length of aux_states
 * \param aux_states pointer to the auxiliary states array
 * \param shared_exec input executor handle for memory sharing
 * \param out output executor handle; receives the new executor
 * \return status code (0 when success, -1 when failure, per this API's convention);
 *         the new executor is returned through \p out
 */
MXNET_DLL int MXExecutorReshapeEx(int partial_shaping,
                                  int allow_up_sizing,
                                  int dev_type,
                                  int dev_id,
                                  uint32_t num_map_keys,
                                  const char** map_keys,
                                  const int* map_dev_types,
                                  const int* map_dev_ids,
                                  const uint32_t num_provided_arg_shapes,
                                  const char** provided_arg_shape_names,
                                  const int* provided_arg_shape_data,
                                  const uint32_t* provided_arg_shape_idx,
                                  uint32_t* num_in_args,
                                  NDArrayHandle** in_args,
                                  NDArrayHandle** arg_grads,
                                  uint32_t* num_aux_states,
                                  NDArrayHandle** aux_states,
                                  ExecutorHandle shared_exec,
                                  ExecutorHandle *out);

/*!
 * \brief Get the optimized graph from a graph executor.
 * \param handle the executor
 * \param out receives the optimized symbol
 */
MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
                                           SymbolHandle *out);
/*!
 * \brief Set a callback that is notified when operations complete.
 * \param handle the executor
 * \param callback the callback to invoke (see ExecutorMonitorCallback)
 * \param callback_handle opaque user data, presumably passed back to \p callback
 *        as its void* argument (confirm)
 */
MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                           ExecutorMonitorCallback callback,
                                           void* callback_handle);

/*!
 * \brief set a call back to notify the completion of operation
 * \param monitor_all If true, monitor both input and output, otherwise monitor output only.
2584 */ 2585 MXNET_DLL int MXExecutorSetMonitorCallbackEX(ExecutorHandle handle, 2586 ExecutorMonitorCallback callback, 2587 void *callback_handle, bool monitor_all); 2588 //-------------------------------------------- 2589 // Part 5: IO Interface 2590 //-------------------------------------------- 2591 /*! 2592 * \brief List all the available iterator entries 2593 * \param out_size the size of returned iterators 2594 * \param out_array the output iteratos entries 2595 * \return 0 when success, -1 when failure happens 2596 */ 2597 MXNET_DLL int MXListDataIters(uint32_t *out_size, 2598 DataIterCreator **out_array); 2599 /*! 2600 * \brief Init an iterator, init with parameters 2601 * the array size of passed in arguments 2602 * \param handle of the iterator creator 2603 * \param num_param number of parameter 2604 * \param keys parameter keys 2605 * \param vals parameter values 2606 * \param out resulting iterator 2607 * \return 0 when success, -1 when failure happens 2608 */ 2609 MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle, 2610 uint32_t num_param, 2611 const char **keys, 2612 const char **vals, 2613 DataIterHandle *out); 2614 /*! 2615 * \brief Get the detailed information about data iterator. 2616 * \param creator the DataIterCreator. 2617 * \param name The returned name of the creator. 2618 * \param description The returned description of the symbol. 2619 * \param num_args Number of arguments. 2620 * \param arg_names Name of the arguments. 2621 * \param arg_type_infos Type informations about the arguments. 2622 * \param arg_descriptions Description information about the arguments. 2623 * \return 0 when success, -1 when failure happens 2624 */ 2625 MXNET_DLL int MXDataIterGetIterInfo(DataIterCreator creator, 2626 const char **name, 2627 const char **description, 2628 uint32_t *num_args, 2629 const char ***arg_names, 2630 const char ***arg_type_infos, 2631 const char ***arg_descriptions); 2632 /*! 
2633 * \brief Free the handle to the IO module 2634 * \param handle the handle pointer to the data iterator 2635 * \return 0 when success, -1 when failure happens 2636 */ 2637 MXNET_DLL int MXDataIterFree(DataIterHandle handle); 2638 /*! 2639 * \brief Move iterator to next position 2640 * \param handle the handle to iterator 2641 * \param out return value of next 2642 * \return 0 when success, -1 when failure happens 2643 */ 2644 MXNET_DLL int MXDataIterNext(DataIterHandle handle, 2645 int *out); 2646 /*! 2647 * \brief Call iterator.Reset 2648 * \param handle the handle to iterator 2649 * \return 0 when success, -1 when failure happens 2650 */ 2651 MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle); 2652 2653 /*! 2654 * \brief Get the handle to the NDArray of underlying data 2655 * \param handle the handle pointer to the data iterator 2656 * \param out handle to underlying data NDArray 2657 * \return 0 when success, -1 when failure happens 2658 */ 2659 MXNET_DLL int MXDataIterGetData(DataIterHandle handle, 2660 NDArrayHandle *out); 2661 /*! 2662 * \brief Get the image index by array. 2663 * \param handle the handle pointer to the data iterator 2664 * \param out_index output index of the array. 2665 * \param out_size output size of the array. 2666 * \return 0 when success, -1 when failure happens 2667 */ 2668 MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, 2669 uint64_t **out_index, 2670 uint64_t *out_size); 2671 /*! 2672 * \brief Get the padding number in current data batch 2673 * \param handle the handle pointer to the data iterator 2674 * \param pad pad number ptr 2675 * \return 0 when success, -1 when failure happens 2676 */ 2677 MXNET_DLL int MXDataIterGetPadNum(DataIterHandle handle, 2678 int *pad); 2679 2680 /*! 
2681 * \brief Get the handle to the NDArray of underlying label 2682 * \param handle the handle pointer to the data iterator 2683 * \param out the handle to underlying label NDArray 2684 * \return 0 when success, -1 when failure happens 2685 */ 2686 MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle, 2687 NDArrayHandle *out); 2688 //-------------------------------------------- 2689 // Part 6: basic KVStore interface 2690 //-------------------------------------------- 2691 /*! 2692 * \brief Initialized ps-lite environment variables 2693 * \param num_vars number of variables to initialize 2694 * \param keys environment keys 2695 * \param vals environment values 2696 */ 2697 MXNET_DLL int MXInitPSEnv(uint32_t num_vars, 2698 const char **keys, 2699 const char **vals); 2700 2701 2702 /*! 2703 * \brief Create a kvstore 2704 * \param type the type of KVStore 2705 * \param out The output type of KVStore 2706 * \return 0 when success, -1 when failure happens 2707 */ 2708 MXNET_DLL int MXKVStoreCreate(const char *type, 2709 KVStoreHandle *out); 2710 2711 /*! 2712 * \brief Set parameters to use low-bit compressed gradients 2713 * \param handle handle to the kvstore 2714 * \param keys keys for compression parameters 2715 * \param vals values for compression parameters 2716 * \return 0 when success, -1 when failure happens 2717 */ 2718 MXNET_DLL int MXKVStoreSetGradientCompression(KVStoreHandle handle, 2719 uint32_t num_params, 2720 const char** keys, 2721 const char** vals); 2722 2723 /*! 2724 * \brief Delete a KVStore handle. 2725 * \param handle handle to the kvstore 2726 * \return 0 when success, -1 when failure happens 2727 */ 2728 MXNET_DLL int MXKVStoreFree(KVStoreHandle handle); 2729 /*! 
2730 * \brief Init a list of (key,value) pairs in kvstore 2731 * \param handle handle to the kvstore 2732 * \param num the number of key-value pairs 2733 * \param keys the list of keys 2734 * \param vals the list of values 2735 * \return 0 when success, -1 when failure happens 2736 */ 2737 MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, 2738 uint32_t num, 2739 const int* keys, 2740 NDArrayHandle* vals); 2741 2742 /*! 2743 * \brief Init a list of (key,value) pairs in kvstore, where each key is a string 2744 * \param handle handle to the kvstore 2745 * \param num the number of key-value pairs 2746 * \param keys the list of keys 2747 * \param vals the list of values 2748 * \return 0 when success, -1 when failure happens 2749 */ 2750 MXNET_DLL int MXKVStoreInitEx(KVStoreHandle handle, 2751 uint32_t num, 2752 const char** keys, 2753 NDArrayHandle* vals); 2754 2755 /*! 2756 * \brief Push a list of (key,value) pairs to kvstore 2757 * \param handle handle to the kvstore 2758 * \param num the number of key-value pairs 2759 * \param keys the list of keys 2760 * \param vals the list of values 2761 * \param priority the priority of the action 2762 * \return 0 when success, -1 when failure happens 2763 */ 2764 MXNET_DLL int MXKVStorePush(KVStoreHandle handle, 2765 uint32_t num, 2766 const int* keys, 2767 NDArrayHandle* vals, 2768 int priority); 2769 /*! 2770 * \brief Push a list of (key,value) pairs to kvstore, where each key is a string 2771 * \param handle handle to the kvstore 2772 * \param num the number of key-value pairs 2773 * \param keys the list of keys 2774 * \param vals the list of values 2775 * \param priority the priority of the action 2776 * \return 0 when success, -1 when failure happens 2777 */ 2778 MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle, 2779 uint32_t num, 2780 const char** keys, 2781 NDArrayHandle* vals, 2782 int priority); 2783 /*! 
2784 * \brief pull a list of (key, value) pairs from the kvstore 2785 * \param handle handle to the kvstore 2786 * \param num the number of key-value pairs 2787 * \param keys the list of keys 2788 * \param vals the list of values 2789 * \param priority the priority of the action 2790 * \param ignore_sparse whether to ignore sparse arrays in the request 2791 * \return 0 when success, -1 when failure happens 2792 */ 2793 MXNET_DLL int MXKVStorePullWithSparse(KVStoreHandle handle, 2794 uint32_t num, 2795 const int* keys, 2796 NDArrayHandle* vals, 2797 int priority, 2798 bool ignore_sparse); 2799 /*! 2800 * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string 2801 * \param handle handle to the kvstore 2802 * \param num the number of key-value pairs 2803 * \param keys the list of keys 2804 * \param vals the list of values 2805 * \param priority the priority of the action 2806 * \param ignore_sparse whether to ignore sparse arrays in the request 2807 * \return 0 when success, -1 when failure happens 2808 */ 2809 MXNET_DLL int MXKVStorePullWithSparseEx(KVStoreHandle handle, 2810 uint32_t num, 2811 const char** keys, 2812 NDArrayHandle* vals, 2813 int priority, 2814 bool ignore_sparse); 2815 /*! 2816 * \brief pull a list of (key, value) pairs from the kvstore 2817 * \param handle handle to the kvstore 2818 * \param num the number of key-value pairs 2819 * \param keys the list of keys 2820 * \param vals the list of values 2821 * \param priority the priority of the action 2822 * \return 0 when success, -1 when failure happens 2823 */ 2824 MXNET_DLL int MXKVStorePull(KVStoreHandle handle, 2825 uint32_t num, 2826 const int* keys, 2827 NDArrayHandle* vals, 2828 int priority); 2829 /*! 
2830 * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string 2831 * \param handle handle to the kvstore 2832 * \param num the number of key-value pairs 2833 * \param keys the list of keys 2834 * \param vals the list of values 2835 * \param priority the priority of the action 2836 * \return 0 when success, -1 when failure happens 2837 */ 2838 MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle, 2839 uint32_t num, 2840 const char** keys, 2841 NDArrayHandle* vals, 2842 int priority); 2843 2844 /*! 2845 * \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer. 2846 * The NDArray pulled back will be in row_sparse storage with only the specified 2847 * row_ids present based row_ids (others rows are zeros). 2848 * \param handle handle to the kvstore 2849 * \param num the number of key-value pairs 2850 * \param keys the list of keys 2851 * \param vals the list of values 2852 * \param row_ids the list of row_id NDArrays 2853 * \param priority the priority of the action 2854 * \return 0 when success, -1 when failure happens 2855 */ 2856 MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle, 2857 uint32_t num, 2858 const int* keys, 2859 NDArrayHandle* vals, 2860 const NDArrayHandle* row_ids, 2861 int priority); 2862 /*! 2863 * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string. 2864 * The NDArray pulled back will be in row_sparse storage with only the specified 2865 * row_ids present based row_ids (others rows are zeros). 
2866 * \param handle handle to the kvstore 2867 * \param num the number of key-value pairs 2868 * \param keys the list of keys 2869 * \param vals the list of values 2870 * \param row_ids the list of row_id NDArrays 2871 * \param priority the priority of the action 2872 * \return 0 when success, -1 when failure happens 2873 */ 2874 MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle, 2875 uint32_t num, 2876 const char** keys, 2877 NDArrayHandle* vals, 2878 const NDArrayHandle* row_ids, 2879 int priority); 2880 2881 /*! 2882 * \brief broadcast a list of (key, value) pairs from the kvstore 2883 * \param handle handle to the kvstore 2884 * \param vnum the number of key-value pairs corresponding to vkeys 2885 * \param vkeys the list of keys for the values to be pushed 2886 * \param onum the number of key-value pairs corresponding to okeys 2887 * \param okeys the list of keys for the values to be pulled 2888 * \param vals the list of values 2889 * \param outs the list of outputs 2890 * \param priority the priority of the action 2891 * \return 0 when success, -1 when failure happens 2892 */ 2893 MXNET_DLL int MXKVStoreBroadcast(KVStoreHandle handle, 2894 mx_uint vnum, 2895 const int* vkeys, 2896 mx_uint onum, 2897 const int* okeys, 2898 NDArrayHandle* vals, 2899 NDArrayHandle* outs, 2900 int priority); 2901 /*! 
2902 * \brief broadcast a list of (key, value) pairs from the kvstore, 2903 * where each key is a string 2904 * \param handle handle to the kvstore 2905 * \param vnum the number of key-value pairs corresponding to vkeys 2906 * \param vkeys the list of keys for the values to be pushed 2907 * \param onum the number of key-value pairs corresponding to okeys 2908 * \param okeys the list of keys for the values to be pulled 2909 * \param vals the list of values 2910 * \param outs the list of outputs 2911 * \param priority the priority of the action 2912 * \return 0 when success, -1 when failure happens 2913 */ 2914 MXNET_DLL int MXKVStoreBroadcastEx(KVStoreHandle handle, 2915 mx_uint vnum, 2916 const char** vkeys, 2917 mx_uint onum, 2918 const char** okeys, 2919 NDArrayHandle* vals, 2920 NDArrayHandle* outs, 2921 int priority); 2922 2923 /*! 2924 * \brief push and pull a list of (key, value) pairs from the kvstore 2925 * \param handle handle to the kvstore 2926 * \param vnum the number of key-value pairs corresponding to vkeys 2927 * \param vkeys the list of keys for the values to be pushed 2928 * \param onum the number of key-value pairs corresponding to okeys 2929 * \param okeys the list of keys for the values to be pulled 2930 * \param vals the list of values 2931 * \param outs the list of outputs 2932 * \param priority the priority of the action 2933 * \return 0 when success, -1 when failure happens 2934 */ 2935 MXNET_DLL int MXKVStorePushPull(KVStoreHandle handle, 2936 mx_uint vnum, 2937 const int* vkeys, 2938 mx_uint onum, 2939 const int* okeys, 2940 NDArrayHandle* vals, 2941 NDArrayHandle* outs, 2942 int priority); 2943 /*! 
2944 * \brief push and pull a list of (key, value) pairs from the kvstore, 2945 * where each key is a string 2946 * \param handle handle to the kvstore 2947 * \param vnum the number of key-value pairs corresponding to vkeys 2948 * \param vkeys the list of keys for the values to be pushed 2949 * \param onum the number of key-value pairs corresponding to okeys 2950 * \param okeys the list of keys for the values to be pulled 2951 * \param vals the list of values 2952 * \param outs the list of outputs 2953 * \param priority the priority of the action 2954 * \return 0 when success, -1 when failure happens 2955 */ 2956 MXNET_DLL int MXKVStorePushPullEx(KVStoreHandle handle, 2957 mx_uint vnum, 2958 const char** vkeys, 2959 mx_uint onum, 2960 const char** okeys, 2961 NDArrayHandle* vals, 2962 NDArrayHandle* outs, 2963 int priority); 2964 2965 /*! 2966 * \brief user-defined updater for the kvstore 2967 * It's this updater's responsibility to delete \a recv and \a local 2968 * \param the key 2969 * \param recv the pushed value on this key 2970 * \param local the value stored on local on this key 2971 * \param handle The additional handle to the updater 2972 */ 2973 typedef void (MXKVStoreUpdater)(int key, 2974 NDArrayHandle recv, 2975 NDArrayHandle local, 2976 void *handle); 2977 /*! 2978 * \brief user-defined updater for the kvstore with string keys 2979 * It's this updater's responsibility to delete \a recv and \a local 2980 * \param the key 2981 * \param recv the pushed value on this key 2982 * \param local the value stored on local on this key 2983 * \param handle The additional handle to the updater 2984 */ 2985 typedef void (MXKVStoreStrUpdater)(const char* key, 2986 NDArrayHandle recv, 2987 NDArrayHandle local, 2988 void *handle); 2989 /*! 
2990 * \brief register a push updater 2991 * \param handle handle to the KVStore 2992 * \param updater udpater function 2993 * \param updater_handle The additional handle used to invoke the updater 2994 * \return 0 when success, -1 when failure happens 2995 */ 2996 MXNET_DLL int MXKVStoreSetUpdater(KVStoreHandle handle, 2997 MXKVStoreUpdater updater, 2998 void *updater_handle); 2999 /*! 3000 * \brief register a push updater with int keys and one with string keys 3001 * \param handle handle to the KVStore 3002 * \param updater updater function with int keys 3003 * \param str_updater updater function with string keys 3004 * \param updater_handle The additional handle used to invoke the updater 3005 * \return 0 when success, -1 when failure happens 3006 */ 3007 MXNET_DLL int MXKVStoreSetUpdaterEx(KVStoreHandle handle, 3008 MXKVStoreUpdater updater, 3009 MXKVStoreStrUpdater str_updater, 3010 void *updater_handle); 3011 /*! 3012 * \brief get the type of the kvstore 3013 * \param handle handle to the KVStore 3014 * \param type a string type 3015 * \return 0 when success, -1 when failure happens 3016 */ 3017 MXNET_DLL int MXKVStoreGetType(KVStoreHandle handle, 3018 const char** type); 3019 //-------------------------------------------- 3020 // Part 6: advanced KVStore for multi-machines 3021 //-------------------------------------------- 3022 3023 /** 3024 * \brief return The rank of this node in its group, which is in [0, GroupSize). 
3025 * 3026 * \param handle handle to the KVStore 3027 * \param ret the node rank 3028 * \return 0 when success, -1 when failure happens 3029 */ 3030 MXNET_DLL int MXKVStoreGetRank(KVStoreHandle handle, 3031 int *ret); 3032 3033 /** 3034 * \brief return The number of nodes in this group, which is 3035 * - number of workers if if `IsWorkerNode() == true`, 3036 * - number of servers if if `IsServerNode() == true`, 3037 * - 1 if `IsSchedulerNode() == true`, 3038 * \param handle handle to the KVStore 3039 * \param ret the group size 3040 * \return 0 when success, -1 when failure happens 3041 */ 3042 MXNET_DLL int MXKVStoreGetGroupSize(KVStoreHandle handle, 3043 int *ret); 3044 3045 /** 3046 * \brief return whether or not this process is a worker node. 3047 * \param ret 1 for yes, 0 for no 3048 * \return 0 when success, -1 when failure happens 3049 */ 3050 MXNET_DLL int MXKVStoreIsWorkerNode(int *ret); 3051 3052 3053 /** 3054 * \brief return whether or not this process is a server node. 3055 * \param ret 1 for yes, 0 for no 3056 * \return 0 when success, -1 when failure happens 3057 */ 3058 MXNET_DLL int MXKVStoreIsServerNode(int *ret); 3059 3060 3061 /** 3062 * \brief return whether or not this process is a scheduler node. 
3063 * \param ret 1 for yes, 0 for no 3064 * \return 0 when success, -1 when failure happens 3065 */ 3066 MXNET_DLL int MXKVStoreIsSchedulerNode(int *ret); 3067 3068 /** 3069 * \brief global barrier among all worker machines 3070 * 3071 * \param handle handle to the KVStore 3072 * \return 0 when success, -1 when failure happens 3073 */ 3074 MXNET_DLL int MXKVStoreBarrier(KVStoreHandle handle); 3075 3076 /** 3077 * \brief whether to do barrier when finalize 3078 * 3079 * \param handle handle to the KVStore 3080 * \param barrier_before_exit whether to do barrier when kvstore finalize 3081 * \return 0 when success, -1 when failure happens 3082 */ 3083 MXNET_DLL int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle, 3084 const int barrier_before_exit); 3085 3086 /** 3087 * \brief the prototype of a server controller 3088 * \param head the head of the command 3089 * \param body the body of the command 3090 * \param controller_handle helper handle for implementing controller 3091 */ 3092 typedef void (MXKVStoreServerController)(int head, 3093 const char *body, 3094 void *controller_handle); 3095 3096 /** 3097 * \brief Run as server (or scheduler) 3098 * \param handle handle to the KVStore 3099 * \param controller the user-defined server controller 3100 * \param controller_handle helper handle for implementing controller 3101 * \return 0 when success, -1 when failure happens 3102 */ 3103 MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle, 3104 MXKVStoreServerController controller, 3105 void *controller_handle); 3106 3107 /** 3108 * \brief Send a command to all server nodes 3109 * \param handle handle to the KVStore 3110 * \param cmd_id the head of the command 3111 * \param cmd_body the body of the command 3112 * \return 0 when success, -1 when failure happens 3113 */ 3114 MXNET_DLL int MXKVStoreSendCommmandToServers(KVStoreHandle handle, 3115 int cmd_id, 3116 const char* cmd_body); 3117 3118 /** 3119 * \brief Get the number of ps dead node(s) specified by {node_id} 
3120 * 3121 * \param handle handle to the KVStore 3122 * \param node_id Can be a node group or a single node. 3123 * kScheduler = 1, kServerGroup = 2, kWorkerGroup = 4 3124 * \param number Ouptut number of dead nodes 3125 * \param timeout_sec A node fails to send heartbeart in {timeout_sec} seconds 3126 * will be presumed as 'dead' 3127 */ 3128 MXNET_DLL int MXKVStoreGetNumDeadNode(KVStoreHandle handle, 3129 const int node_id, 3130 int *number, 3131 const int timeout_sec DEFAULT(60)); 3132 3133 /** 3134 * \brief Create a RecordIO writer object 3135 * \param uri path to file 3136 * \param out handle pointer to the created object 3137 * \return 0 when success, -1 when failure happens 3138 */ 3139 MXNET_DLL int MXRecordIOWriterCreate(const char *uri, RecordIOHandle *out); 3140 3141 /** 3142 * \brief Delete a RecordIO writer object 3143 * \param handle handle to RecordIO object 3144 * \return 0 when success, -1 when failure happens 3145 */ 3146 MXNET_DLL int MXRecordIOWriterFree(RecordIOHandle handle); 3147 3148 /** 3149 * \brief Write a record to a RecordIO object 3150 * \param handle handle to RecordIO object 3151 * \param buf buffer to write 3152 * \param size size of buffer 3153 * \return 0 when success, -1 when failure happens 3154 */ 3155 MXNET_DLL int MXRecordIOWriterWriteRecord(RecordIOHandle handle, 3156 const char *buf, size_t size); 3157 3158 /** 3159 * \brief Get the current writer pointer position 3160 * \param handle handle to RecordIO object 3161 * \param pos handle to output position 3162 * \return 0 when success, -1 when failure happens 3163 */ 3164 MXNET_DLL int MXRecordIOWriterTell(RecordIOHandle handle, size_t *pos); 3165 3166 /** 3167 * \brief Create a RecordIO reader object 3168 * \param uri path to file 3169 * \param out handle pointer to the created object 3170 * \return 0 when success, -1 when failure happens 3171 */ 3172 MXNET_DLL int MXRecordIOReaderCreate(const char *uri, RecordIOHandle *out); 3173 3174 /** 3175 * \brief Delete a RecordIO 
reader object 3176 * \param handle handle to RecordIO object 3177 * \return 0 when success, -1 when failure happens 3178 */ 3179 MXNET_DLL int MXRecordIOReaderFree(RecordIOHandle handle); 3180 3181 /** 3182 * \brief Write a record to a RecordIO object 3183 * \param handle handle to RecordIO object 3184 * \param buf pointer to return buffer 3185 * \param size point to size of buffer 3186 * \return 0 when success, -1 when failure happens 3187 */ 3188 MXNET_DLL int MXRecordIOReaderReadRecord(RecordIOHandle handle, 3189 char const **buf, size_t *size); 3190 3191 /** 3192 * \brief Set the current reader pointer position 3193 * \param handle handle to RecordIO object 3194 * \param pos target position 3195 * \return 0 when success, -1 when failure happens 3196 */ 3197 MXNET_DLL int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos); 3198 3199 /** 3200 * \brief Get the current writer pointer position 3201 * \param handle handle to RecordIO object 3202 * \param pos handle to output position 3203 * \return 0 when success, -1 when failure happens 3204 */ 3205 MXNET_DLL int MXRecordIOReaderTell(RecordIOHandle handle, size_t *pos); 3206 3207 /** 3208 * \brief Create a MXRtc object 3209 */ 3210 MXNET_DLL int MXRtcCreate(char* name, uint32_t num_input, uint32_t num_output, 3211 char** input_names, char** output_names, 3212 NDArrayHandle* inputs, NDArrayHandle* outputs, 3213 char* kernel, RtcHandle *out); 3214 3215 /** 3216 * \brief Run cuda kernel 3217 */ 3218 MXNET_DLL int MXRtcPush(RtcHandle handle, uint32_t num_input, uint32_t num_output, 3219 NDArrayHandle* inputs, NDArrayHandle* outputs, 3220 uint32_t gridDimX, 3221 uint32_t gridDimY, 3222 uint32_t gridDimZ, 3223 uint32_t blockDimX, 3224 uint32_t blockDimY, 3225 uint32_t blockDimZ); 3226 3227 /** 3228 * \brief Delete a MXRtc object 3229 */ 3230 MXNET_DLL int MXRtcFree(RtcHandle handle); 3231 /* 3232 * \brief register custom operators from frontend. 
3233 * \param op_type name of custom op 3234 * \param creator 3235 */ 3236 MXNET_DLL int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator); 3237 /* 3238 * \brief record custom function for backward later. 3239 * \param num_inputs number of input NDArrays. 3240 * \param inputs handle to input NDArrays. 3241 * \param num_outputs number of output NDArrays. 3242 * \param outputs handle to output NDArrays. 3243 * \param callbacks callbacks for backward function. 3244 */ 3245 MXNET_DLL int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs, 3246 int num_outputs, NDArrayHandle *outputs, 3247 struct MXCallbackList *callbacks); 3248 /* 3249 * \brief create cuda rtc module 3250 * \param source cuda source code 3251 * \param num_options number of compiler flags 3252 * \param options compiler flags 3253 * \param num_exports number of exported function names 3254 * \param exported function names 3255 * \param out handle to created module 3256 */ 3257 MXNET_DLL int MXRtcCudaModuleCreate(const char* source, int num_options, 3258 const char** options, int num_exports, 3259 const char** exports, CudaModuleHandle *out); 3260 /* 3261 * \brief delete cuda rtc module 3262 * \param handle handle to cuda module 3263 */ 3264 MXNET_DLL int MXRtcCudaModuleFree(CudaModuleHandle handle); 3265 /* 3266 * \brief get kernel from module 3267 * \param handle handle to cuda module 3268 * \param name name of kernel function 3269 * \param num_args number of arguments 3270 * \param is_ndarray whether argument is ndarray 3271 * \param is_const whether argument is constant 3272 * \param arg_types data type of arguments 3273 * \param out created kernel 3274 */ 3275 MXNET_DLL int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name, 3276 int num_args, int* is_ndarray, int* is_const, 3277 int* arg_types, CudaKernelHandle *out); 3278 /* 3279 * \brief delete kernel 3280 * \param handle handle to previously created kernel 3281 */ 3282 MXNET_DLL int 
MXRtcCudaKernelFree(CudaKernelHandle handle); 3283 /* 3284 * \brief launch cuda kernel 3285 * \param handle handle to kernel 3286 * \param dev_id (GPU) device id 3287 * \param args pointer to arguments 3288 * \param grid_dim_x grid dimension x 3289 * \param grid_dim_y grid dimension y 3290 * \param grid_dim_z grid dimension z 3291 * \param block_dim_x block dimension x 3292 * \param block_dim_y block dimension y 3293 * \param block_dim_z block dimension z 3294 * \param shared_mem size of dynamically allocated shared memory 3295 */ 3296 MXNET_DLL int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args, 3297 uint32_t grid_dim_x, uint32_t grid_dim_y, 3298 uint32_t grid_dim_z, uint32_t block_dim_x, 3299 uint32_t block_dim_y, uint32_t block_dim_z, 3300 uint32_t shared_mem); 3301 /*! 3302 * \brief Get shared memory handle from NDArray 3303 * \param handle NDArray handle. 3304 * \param shared_pid output PID 3305 * \param shared_id output shared memory id. 3306 */ 3307 MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, 3308 int* shared_id); 3309 /*! 3310 * \brief DEPRECATED. Use MXNDArrayCreateFromSharedMemEx instead. 3311 * Reconstruct NDArray from shared memory handle 3312 * \param shared_pid shared PID 3313 * \param shared_id shared memory id 3314 * \param shape pointer to NDArray dimensions 3315 * \param ndim number of NDArray dimensions 3316 * \param dtype data type of NDArray 3317 * \param out constructed NDArray 3318 */ 3319 MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const uint32_t *shape, 3320 uint32_t ndim, int dtype, NDArrayHandle *out); 3321 3322 /*! 3323 * \brief Release all unreferenced memory from the devices storage managers memory pool 3324 * \param dev_type device type, specify device we want to take 3325 * \param dev_id the device id of the specific device 3326 */ 3327 MXNET_DLL int MXStorageEmptyCache(int dev_type, int dev_id); 3328 3329 /*! 
3330 * \brief Reconstruct NDArray from shared memory handle 3331 * \param shared_pid shared PID 3332 * \param shared_id shared memory id 3333 * \param shape pointer to NDArray dimensions 3334 * \param ndim number of NDArray dimensions 3335 * \param dtype data type of NDArray 3336 * \param out constructed NDArray 3337 */ 3338 MXNET_DLL int MXNDArrayCreateFromSharedMemEx(int shared_pid, int shared_id, const int *shape, 3339 int ndim, int dtype, NDArrayHandle *out); 3340 3341 /*! 3342 * \brief Push an asynchronous operation to the engine. 3343 * \param async_func Execution function whici takes a parameter on_complete 3344 * that must be called when the execution ompletes. 3345 * \param func_param The parameter set on calling async_func, can be NULL. 3346 * \param deleter The callback to free func_param, can be NULL. 3347 * \param ctx_handle Execution context. 3348 * \param const_vars_handle The variables that current operation will use 3349 * but not mutate. 3350 * \param num_const_vars The number of const_vars_handle. 3351 * \param mutable_vars_handle The variables that current operation will mutate. 3352 * \param num_mutable_vars The number of mutable_vars_handle. 3353 * \param prop_handle Property of the function. 3354 * \param priority Priority of the action, as hint to the engine. 3355 * \param opr_name The operation name. 3356 * \param wait Whether this is a WaitForVar operation. 3357 */ 3358 MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param, 3359 EngineFuncParamDeleter deleter, ContextHandle ctx_handle, 3360 EngineVarHandle const_vars_handle, int num_const_vars, 3361 EngineVarHandle mutable_vars_handle, int num_mutable_vars, 3362 EngineFnPropertyHandle prop_handle DEFAULT(NULL), 3363 int priority DEFAULT(0), const char* opr_name DEFAULT(NULL), 3364 bool wait DEFAULT(false)); 3365 3366 /*! 3367 * \brief Push a synchronous operation to the engine. 3368 * \param sync_func Execution function that executes the operation. 
3369 * \param func_param The parameter set on calling sync_func, can be NULL. 3370 * \param deleter The callback to free func_param, can be NULL. 3371 * \param ctx_handle Execution context. 3372 * \param const_vars_handle The variables that current operation will use 3373 * but not mutate. 3374 * \param num_const_vars The number of const_vars_handle. 3375 * \param mutable_vars_handle The variables that current operation will mutate. 3376 * \param num_mutable_vars The number of mutable_vars_handle. 3377 * \param prop_handle Property of the function. 3378 * \param priority Priority of the action, as hint to the engine. 3379 * \param opr_name The operation name. 3380 */ 3381 MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, 3382 EngineFuncParamDeleter deleter, ContextHandle ctx_handle, 3383 EngineVarHandle const_vars_handle, int num_const_vars, 3384 EngineVarHandle mutable_vars_handle, int num_mutable_vars, 3385 EngineFnPropertyHandle prop_handle DEFAULT(NULL), 3386 int priority DEFAULT(0), const char* opr_name DEFAULT(NULL)); 3387 /*! 3388 * \brief Create an NDArray from source sharing the same data chunk. 3389 * \param src source NDArray 3390 * \param out new NDArray sharing the same data chunck with src 3391 */ 3392 MXNET_DLL int MXShallowCopyNDArray(NDArrayHandle src, NDArrayHandle* out); 3393 /*! 3394 * \brief Create an Symbol from source sharing the same graph structure. 3395 * \param src source Symbol 3396 * \param out new Symbol sharing the same graph structure with src 3397 */ 3398 MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out); 3399 3400 /*! 3401 * \brief Push an asynchronous operation to the engine. 3402 * \param async_func Execution function whici takes a parameter on_complete 3403 * that must be called when the execution ompletes. 3404 * \param func_param The parameter set on calling async_func, can be NULL. 3405 * \param deleter The callback to free func_param, can be NULL. 
3406 * \param ctx_handle Execution context. 3407 * \param const_nds_handle The NDArrays that current operation will use 3408 * but not mutate. 3409 * \param num_const_nds The number of const_nds_handle. 3410 * \param mutable_nds_handle The NDArrays that current operation will mutate. 3411 * \param num_mutable_nds The number of mutable_nds_handle. 3412 * \param prop_handle Property of the function. 3413 * \param priority Priority of the action, as hint to the engine. 3414 * \param opr_name The operation name. 3415 * \param wait Whether this is a WaitForVar operation. 3416 */ 3417 MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param, 3418 EngineFuncParamDeleter deleter, ContextHandle ctx_handle, 3419 NDArrayHandle* const_nds_handle, int num_const_nds, 3420 NDArrayHandle* mutable_nds_handle, int num_mutable_nds, 3421 EngineFnPropertyHandle prop_handle DEFAULT(NULL), 3422 int priority DEFAULT(0), const char* opr_name DEFAULT(NULL), 3423 bool wait DEFAULT(false)); 3424 3425 /*! 3426 * \brief Push a synchronous operation to the engine. 3427 * \param sync_func Execution function that executes the operation. 3428 * \param func_param The parameter set on calling sync_func, can be NULL. 3429 * \param deleter The callback to free func_param, can be NULL. 3430 * \param ctx_handle Execution context. 3431 * \param const_nds_handle The NDArrays that current operation will use 3432 * but not mutate. 3433 * \param num_const_nds The number of const_nds_handle. 3434 * \param mutable_nds_handle The NDArrays that current operation will mutate. 3435 * \param num_mutable_nds The number of mutable_nds_handle. 3436 * \param prop_handle Property of the function. 3437 * \param priority Priority of the action, as hint to the engine. 3438 * \param opr_name The operation name. 
3439 */ 3440 MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param, 3441 EngineFuncParamDeleter deleter, ContextHandle ctx_handle, 3442 NDArrayHandle* const_nds_handle, int num_const_nds, 3443 NDArrayHandle* mutable_nds_handle, int num_mutable_nds, 3444 EngineFnPropertyHandle prop_handle DEFAULT(NULL), 3445 int priority DEFAULT(0), const char* opr_name DEFAULT(NULL)); 3446 3447 #ifdef __cplusplus 3448 } 3449 #endif // __cplusplus 3450 3451 #endif // MXNET_C_API_H_ 3452