1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21 * \file operator.h
22 * \brief definition of operator
23 * \author Chuntao Hong, Zhang Chen
24 */
25 
26 #ifndef MXNET_CPP_OPERATOR_H_
27 #define MXNET_CPP_OPERATOR_H_
28 
#include <cstddef>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/op_map.h"
#include "mxnet-cpp/symbol.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 class Mxnet;
39 /*!
40 * \brief Operator interface
41 */
42 class Operator {
43  public:
44   /*!
45   * \brief Operator constructor
46   * \param operator_name type of the operator
47   */
48   explicit Operator(const std::string &operator_name);
49   Operator &operator=(const Operator &rhs);
50   /*!
51   * \brief set config parameters
52   * \param name name of the config parameter
53   * \param value value of the config parameter
54   * \return reference of self
55   */
56   template <typename T>
SetParam(const std::string & name,const T & value)57   Operator &SetParam(const std::string &name, const T &value) {
58     std::string value_str;
59     std::stringstream ss;
60     ss << value;
61     ss >> value_str;
62 
63     params_[name] = value_str;
64     return *this;
65   }
66   /*!
67   * \brief set config parameters from positional inputs
68   * \param pos the position of parameter
69   * \param value value of the config parameter
70   * \return reference of self
71   */
72   template <typename T>
SetParam(int pos,const T & value)73   Operator &SetParam(int pos, const T &value) {
74     std::string value_str;
75     std::stringstream ss;
76     ss << value;
77     ss >> value_str;
78 
79     params_[arg_names_[pos]] = value_str;
80     return *this;
81   }
82   /*!
83   * \brief add an input symbol
84   * \param name name of the input symbol
85   * \param symbol the input symbol
86   * \return reference of self
87   */
88   Operator &SetInput(const std::string &name, const Symbol &symbol);
89   /*!
90   * \brief add an input symbol
91   * \param symbol the input symbol
92   */
93   template<int N = 0>
PushInput(const Symbol & symbol)94   void PushInput(const Symbol &symbol) {
95     input_symbols_.push_back(symbol.GetHandle());
96   }
97   /*!
98   * \brief add input symbols
99   * \return reference of self
100   */
operator()101   Operator &operator()() { return *this; }
102   /*!
103   * \brief add input symbols
104   * \param symbol the input symbol
105   * \return reference of self
106   */
operator()107   Operator &operator()(const Symbol &symbol) {
108     input_symbols_.push_back(symbol.GetHandle());
109     return *this;
110   }
111   /*!
112   * \brief add a list of input symbols
113   * \param symbols the vector of the input symbols
114   * \return reference of self
115   */
operator()116   Operator &operator()(const std::vector<Symbol> &symbols) {
117     for (auto &s : symbols) {
118       input_symbols_.push_back(s.GetHandle());
119     }
120     return *this;
121   }
122   /*!
123   * \brief create a Symbol from the current operator
124   * \param name the name of the operator
125   * \return the operator Symbol
126   */
127   Symbol CreateSymbol(const std::string &name = "");
128 
129   /*!
130   * \brief add an input ndarray
131   * \param name name of the input ndarray
132   * \param ndarray the input ndarray
133   * \return reference of self
134   */
135   Operator &SetInput(const std::string &name, const NDArray &ndarray);
136   /*!
137   * \brief add an input ndarray
138   * \param ndarray the input ndarray
139   */
140   template<int N = 0>
PushInput(const NDArray & ndarray)141   Operator &PushInput(const NDArray &ndarray) {
142     input_ndarrays_.push_back(ndarray.GetHandle());
143     return *this;
144   }
145   /*!
146   * \brief add positional inputs
147   */
148   template <class T, class... Args, int N = 0>
PushInput(const T & t,Args...args)149   Operator &PushInput(const T &t, Args... args) {
150     SetParam(N, t);
151     PushInput<Args..., N+1>(args...);
152     return *this;
153   }
154   /*!
155   * \brief add the last positional input
156   */
157   template <class T, int N = 0>
PushInput(const T & t)158   Operator &PushInput(const T &t) {
159     SetParam(N, t);
160     return *this;
161   }
162   /*!
163   * \brief add input ndarrays
164   * \param ndarray the input ndarray
165   * \return reference of self
166   */
operator()167   Operator &operator()(const NDArray &ndarray) {
168     input_ndarrays_.push_back(ndarray.GetHandle());
169     return *this;
170   }
171   /*!
172   * \brief add a list of input ndarrays
173   * \param ndarrays the vector of the input ndarrays
174   * \return reference of self
175   */
operator()176   Operator &operator()(const std::vector<NDArray> &ndarrays) {
177     for (auto &s : ndarrays) {
178       input_ndarrays_.push_back(s.GetHandle());
179     }
180     return *this;
181   }
182   /*!
183   * \brief add input ndarrays
184   * \return reference of self
185   */
186   template <typename... Args>
operator()187   Operator &operator()(Args... args) {
188     PushInput(args...);
189     return *this;
190   }
191   std::vector<NDArray> Invoke();
192   void Invoke(NDArray &output);
193   void Invoke(std::vector<NDArray> &outputs);
194 
195  private:
196   std::map<std::string, std::string> params_desc_;
197   bool variable_params_ = false;
198   std::map<std::string, std::string> params_;
199   std::vector<SymbolHandle> input_symbols_;
200   std::vector<NDArrayHandle> input_ndarrays_;
201   std::vector<std::string> input_keys_;
202   std::vector<std::string> arg_names_;
203   AtomicSymbolCreator handle_;
204   static OpMap*& op_map();
205 };
206 }  // namespace cpp
207 }  // namespace mxnet
208 
209 #endif  // MXNET_CPP_OPERATOR_H_
210