1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file operator.h 22 * \brief definition of operator 23 * \author Chuntao Hong, Zhang Chen 24 */ 25 26 #ifndef MXNET_CPP_OPERATOR_H_ 27 #define MXNET_CPP_OPERATOR_H_ 28 29 #include <map> 30 #include <string> 31 #include <vector> 32 #include "mxnet-cpp/base.h" 33 #include "mxnet-cpp/op_map.h" 34 #include "mxnet-cpp/symbol.h" 35 36 namespace mxnet { 37 namespace cpp { 38 class Mxnet; 39 /*! 40 * \brief Operator interface 41 */ 42 class Operator { 43 public: 44 /*! 45 * \brief Operator constructor 46 * \param operator_name type of the operator 47 */ 48 explicit Operator(const std::string &operator_name); 49 Operator &operator=(const Operator &rhs); 50 /*! 51 * \brief set config parameters 52 * \param name name of the config parameter 53 * \param value value of the config parameter 54 * \return reference of self 55 */ 56 template <typename T> SetParam(const std::string & name,const T & value)57 Operator &SetParam(const std::string &name, const T &value) { 58 std::string value_str; 59 std::stringstream ss; 60 ss << value; 61 ss >> value_str; 62 63 params_[name] = value_str; 64 return *this; 65 } 66 /*! 67 * \brief set config parameters from positional inputs 68 * \param pos the position of parameter 69 * \param value value of the config parameter 70 * \return reference of self 71 */ 72 template <typename T> SetParam(int pos,const T & value)73 Operator &SetParam(int pos, const T &value) { 74 std::string value_str; 75 std::stringstream ss; 76 ss << value; 77 ss >> value_str; 78 79 params_[arg_names_[pos]] = value_str; 80 return *this; 81 } 82 /*! 83 * \brief add an input symbol 84 * \param name name of the input symbol 85 * \param symbol the input symbol 86 * \return reference of self 87 */ 88 Operator &SetInput(const std::string &name, const Symbol &symbol); 89 /*! 90 * \brief add an input symbol 91 * \param symbol the input symbol 92 */ 93 template<int N = 0> PushInput(const Symbol & symbol)94 void PushInput(const Symbol &symbol) { 95 input_symbols_.push_back(symbol.GetHandle()); 96 } 97 /*! 98 * \brief add input symbols 99 * \return reference of self 100 */ operator()101 Operator &operator()() { return *this; } 102 /*! 103 * \brief add input symbols 104 * \param symbol the input symbol 105 * \return reference of self 106 */ operator()107 Operator &operator()(const Symbol &symbol) { 108 input_symbols_.push_back(symbol.GetHandle()); 109 return *this; 110 } 111 /*! 112 * \brief add a list of input symbols 113 * \param symbols the vector of the input symbols 114 * \return reference of self 115 */ operator()116 Operator &operator()(const std::vector<Symbol> &symbols) { 117 for (auto &s : symbols) { 118 input_symbols_.push_back(s.GetHandle()); 119 } 120 return *this; 121 } 122 /*! 123 * \brief create a Symbol from the current operator 124 * \param name the name of the operator 125 * \return the operator Symbol 126 */ 127 Symbol CreateSymbol(const std::string &name = ""); 128 129 /*! 130 * \brief add an input ndarray 131 * \param name name of the input ndarray 132 * \param ndarray the input ndarray 133 * \return reference of self 134 */ 135 Operator &SetInput(const std::string &name, const NDArray &ndarray); 136 /*! 137 * \brief add an input ndarray 138 * \param ndarray the input ndarray 139 */ 140 template<int N = 0> PushInput(const NDArray & ndarray)141 Operator &PushInput(const NDArray &ndarray) { 142 input_ndarrays_.push_back(ndarray.GetHandle()); 143 return *this; 144 } 145 /*! 146 * \brief add positional inputs 147 */ 148 template <class T, class... Args, int N = 0> PushInput(const T & t,Args...args)149 Operator &PushInput(const T &t, Args... args) { 150 SetParam(N, t); 151 PushInput<Args..., N+1>(args...); 152 return *this; 153 } 154 /*! 155 * \brief add the last positional input 156 */ 157 template <class T, int N = 0> PushInput(const T & t)158 Operator &PushInput(const T &t) { 159 SetParam(N, t); 160 return *this; 161 } 162 /*! 163 * \brief add input ndarrays 164 * \param ndarray the input ndarray 165 * \return reference of self 166 */ operator()167 Operator &operator()(const NDArray &ndarray) { 168 input_ndarrays_.push_back(ndarray.GetHandle()); 169 return *this; 170 } 171 /*! 172 * \brief add a list of input ndarrays 173 * \param ndarrays the vector of the input ndarrays 174 * \return reference of self 175 */ operator()176 Operator &operator()(const std::vector<NDArray> &ndarrays) { 177 for (auto &s : ndarrays) { 178 input_ndarrays_.push_back(s.GetHandle()); 179 } 180 return *this; 181 } 182 /*! 183 * \brief add input ndarrays 184 * \return reference of self 185 */ 186 template <typename... Args> operator()187 Operator &operator()(Args... args) { 188 PushInput(args...); 189 return *this; 190 } 191 std::vector<NDArray> Invoke(); 192 void Invoke(NDArray &output); 193 void Invoke(std::vector<NDArray> &outputs); 194 195 private: 196 std::map<std::string, std::string> params_desc_; 197 bool variable_params_ = false; 198 std::map<std::string, std::string> params_; 199 std::vector<SymbolHandle> input_symbols_; 200 std::vector<NDArrayHandle> input_ndarrays_; 201 std::vector<std::string> input_keys_; 202 std::vector<std::string> arg_names_; 203 AtomicSymbolCreator handle_; 204 static OpMap*& op_map(); 205 }; 206 } // namespace cpp 207 } // namespace mxnet 208 209 #endif // MXNET_CPP_OPERATOR_H_ 210