1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21 * \file op_map.h
22 * \brief definition of OpMap
23 * \author Chuntao Hong
24 */
25 
26 #ifndef MXNET_CPP_OP_MAP_H_
27 #define MXNET_CPP_OP_MAP_H_
28 
29 #include <map>
30 #include <string>
31 #include "mxnet-cpp/base.h"
32 #include "dmlc/logging.h"
33 
34 namespace mxnet {
35 namespace cpp {
36 
37 /*!
38 * \brief OpMap instance holds a map of all the symbol creators so we can
39 *  get symbol creators by name.
40 *  This is used internally by Symbol and Operator.
41 */
42 class OpMap {
43  public:
44   /*!
45   * \brief Create an Mxnet instance
46   */
OpMap()47   inline OpMap() {
48     mx_uint num_symbol_creators = 0;
49     AtomicSymbolCreator *symbol_creators = nullptr;
50     int r =
51       MXSymbolListAtomicSymbolCreators(&num_symbol_creators, &symbol_creators);
52     CHECK_EQ(r, 0);
53     for (mx_uint i = 0; i < num_symbol_creators; i++) {
54       const char *name;
55       const char *description;
56       mx_uint num_args;
57       const char **arg_names;
58       const char **arg_type_infos;
59       const char **arg_descriptions;
60       const char *key_var_num_args;
61       r = MXSymbolGetAtomicSymbolInfo(symbol_creators[i], &name, &description,
62         &num_args, &arg_names, &arg_type_infos,
63         &arg_descriptions, &key_var_num_args);
64       CHECK_EQ(r, 0);
65       symbol_creators_[name] = symbol_creators[i];
66     }
67 
68     nn_uint num_ops;
69     const char **op_names;
70     r = NNListAllOpNames(&num_ops, &op_names);
71     CHECK_EQ(r, 0);
72     for (nn_uint i = 0; i < num_ops; i++) {
73       OpHandle handle;
74       r = NNGetOpHandle(op_names[i], &handle);
75       CHECK_EQ(r, 0);
76       op_handles_[op_names[i]] = handle;
77     }
78   }
79 
80   /*!
81   * \brief Get a symbol creator with its name.
82   *
83   * \param name name of the symbol creator
84   * \return handle to the symbol creator
85   */
GetSymbolCreator(const std::string & name)86   inline AtomicSymbolCreator GetSymbolCreator(const std::string &name) {
87     if (symbol_creators_.count(name) == 0)
88       return GetOpHandle(name);
89     return symbol_creators_[name];
90   }
91 
92   /*!
93   * \brief Get an op handle with its name.
94   *
95   * \param name name of the op
96   * \return handle to the op
97   */
GetOpHandle(const std::string & name)98   inline OpHandle GetOpHandle(const std::string &name) {
99     return op_handles_[name];
100   }
101 
102  private:
103   std::map<std::string, AtomicSymbolCreator> symbol_creators_;
104   std::map<std::string, OpHandle> op_handles_;
105 };
106 
107 }  // namespace cpp
108 }  // namespace mxnet
109 
110 #endif  // MXNET_CPP_OP_MAP_H_
111