1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file op_map.h 22 * \brief definition of OpMap 23 * \author Chuntao Hong 24 */ 25 26 #ifndef MXNET_CPP_OP_MAP_H_ 27 #define MXNET_CPP_OP_MAP_H_ 28 29 #include <map> 30 #include <string> 31 #include "mxnet-cpp/base.h" 32 #include "dmlc/logging.h" 33 34 namespace mxnet { 35 namespace cpp { 36 37 /*! 38 * \brief OpMap instance holds a map of all the symbol creators so we can 39 * get symbol creators by name. 40 * This is used internally by Symbol and Operator. 41 */ 42 class OpMap { 43 public: 44 /*! 45 * \brief Create an Mxnet instance 46 */ OpMap()47 inline OpMap() { 48 mx_uint num_symbol_creators = 0; 49 AtomicSymbolCreator *symbol_creators = nullptr; 50 int r = 51 MXSymbolListAtomicSymbolCreators(&num_symbol_creators, &symbol_creators); 52 CHECK_EQ(r, 0); 53 for (mx_uint i = 0; i < num_symbol_creators; i++) { 54 const char *name; 55 const char *description; 56 mx_uint num_args; 57 const char **arg_names; 58 const char **arg_type_infos; 59 const char **arg_descriptions; 60 const char *key_var_num_args; 61 r = MXSymbolGetAtomicSymbolInfo(symbol_creators[i], &name, &description, 62 &num_args, &arg_names, &arg_type_infos, 63 &arg_descriptions, &key_var_num_args); 64 CHECK_EQ(r, 0); 65 symbol_creators_[name] = symbol_creators[i]; 66 } 67 68 nn_uint num_ops; 69 const char **op_names; 70 r = NNListAllOpNames(&num_ops, &op_names); 71 CHECK_EQ(r, 0); 72 for (nn_uint i = 0; i < num_ops; i++) { 73 OpHandle handle; 74 r = NNGetOpHandle(op_names[i], &handle); 75 CHECK_EQ(r, 0); 76 op_handles_[op_names[i]] = handle; 77 } 78 } 79 80 /*! 81 * \brief Get a symbol creator with its name. 82 * 83 * \param name name of the symbol creator 84 * \return handle to the symbol creator 85 */ GetSymbolCreator(const std::string & name)86 inline AtomicSymbolCreator GetSymbolCreator(const std::string &name) { 87 if (symbol_creators_.count(name) == 0) 88 return GetOpHandle(name); 89 return symbol_creators_[name]; 90 } 91 92 /*! 93 * \brief Get an op handle with its name. 94 * 95 * \param name name of the op 96 * \return handle to the op 97 */ GetOpHandle(const std::string & name)98 inline OpHandle GetOpHandle(const std::string &name) { 99 return op_handles_[name]; 100 } 101 102 private: 103 std::map<std::string, AtomicSymbolCreator> symbol_creators_; 104 std::map<std::string, OpHandle> op_handles_; 105 }; 106 107 } // namespace cpp 108 } // namespace mxnet 109 110 #endif // MXNET_CPP_OP_MAP_H_ 111