1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21 * \file op_suppl.h
22 * \brief A supplement and amendment of the operators from op.h
23 * \author Zhang Chen, zhubuntu, Xin Li
24 */
25 
26 #ifndef MXNET_CPP_OP_SUPPL_H_
27 #define MXNET_CPP_OP_SUPPL_H_
28 
29 #include <cassert>
30 #include <string>
31 #include <vector>
32 #include "mxnet-cpp/base.h"
33 #include "mxnet-cpp/shape.h"
34 #include "mxnet-cpp/operator.h"
35 #include "mxnet-cpp/MxNetCpp.h"
36 
37 namespace mxnet {
38 namespace cpp {
39 
_Plus(Symbol lhs,Symbol rhs)40 inline Symbol _Plus(Symbol lhs, Symbol rhs) {
41   return Operator("_Plus")(lhs, rhs)
42            .CreateSymbol();
43 }
_Mul(Symbol lhs,Symbol rhs)44 inline Symbol _Mul(Symbol lhs, Symbol rhs) {
45   return Operator("_Mul")(lhs, rhs)
46            .CreateSymbol();
47 }
_Minus(Symbol lhs,Symbol rhs)48 inline Symbol _Minus(Symbol lhs, Symbol rhs) {
49   return Operator("_Minus")(lhs, rhs)
50            .CreateSymbol();
51 }
_Div(Symbol lhs,Symbol rhs)52 inline Symbol _Div(Symbol lhs, Symbol rhs) {
53   return Operator("_Div")(lhs, rhs)
54            .CreateSymbol();
55 }
_Mod(Symbol lhs,Symbol rhs)56 inline Symbol _Mod(Symbol lhs, Symbol rhs) {
57   return Operator("_Mod")(lhs, rhs)
58            .CreateSymbol();
59 }
_Power(Symbol lhs,Symbol rhs)60 inline Symbol _Power(Symbol lhs, Symbol rhs) {
61   return Operator("_Power")(lhs, rhs)
62            .CreateSymbol();
63 }
_Maximum(Symbol lhs,Symbol rhs)64 inline Symbol _Maximum(Symbol lhs, Symbol rhs) {
65   return Operator("_Maximum")(lhs, rhs)
66            .CreateSymbol();
67 }
_Minimum(Symbol lhs,Symbol rhs)68 inline Symbol _Minimum(Symbol lhs, Symbol rhs) {
69   return Operator("_Minimum")(lhs, rhs)
70            .CreateSymbol();
71 }
_PlusScalar(Symbol lhs,mx_float scalar)72 inline Symbol _PlusScalar(Symbol lhs, mx_float scalar) {
73   return Operator("_PlusScalar")(lhs)
74            .SetParam("scalar", scalar)
75            .CreateSymbol();
76 }
_MinusScalar(Symbol lhs,mx_float scalar)77 inline Symbol _MinusScalar(Symbol lhs, mx_float scalar) {
78   return Operator("_MinusScalar")(lhs)
79            .SetParam("scalar", scalar)
80            .CreateSymbol();
81 }
_RMinusScalar(mx_float scalar,Symbol rhs)82 inline Symbol _RMinusScalar(mx_float scalar, Symbol rhs) {
83   return Operator("_RMinusScalar")(rhs)
84            .SetParam("scalar", scalar)
85            .CreateSymbol();
86 }
_MulScalar(Symbol lhs,mx_float scalar)87 inline Symbol _MulScalar(Symbol lhs, mx_float scalar) {
88   return Operator("_MulScalar")(lhs)
89            .SetParam("scalar", scalar)
90            .CreateSymbol();
91 }
_DivScalar(Symbol lhs,mx_float scalar)92 inline Symbol _DivScalar(Symbol lhs, mx_float scalar) {
93   return Operator("_DivScalar")(lhs)
94            .SetParam("scalar", scalar)
95            .CreateSymbol();
96 }
_RDivScalar(mx_float scalar,Symbol rhs)97 inline Symbol _RDivScalar(mx_float scalar, Symbol rhs) {
98   return Operator("_RDivScalar")(rhs)
99            .SetParam("scalar", scalar)
100            .CreateSymbol();
101 }
_ModScalar(Symbol lhs,mx_float scalar)102 inline Symbol _ModScalar(Symbol lhs, mx_float scalar) {
103   return Operator("_ModScalar")(lhs)
104            .SetParam("scalar", scalar)
105            .CreateSymbol();
106 }
_RModScalar(mx_float scalar,Symbol rhs)107 inline Symbol _RModScalar(mx_float scalar, Symbol rhs) {
108   return Operator("_RModScalar")(rhs)
109            .SetParam("scalar", scalar)
110            .CreateSymbol();
111 }
_PowerScalar(Symbol lhs,mx_float scalar)112 inline Symbol _PowerScalar(Symbol lhs, mx_float scalar) {
113   return Operator("_PowerScalar")(lhs)
114            .SetParam("scalar", scalar)
115            .CreateSymbol();
116 }
_RPowerScalar(mx_float scalar,Symbol rhs)117 inline Symbol _RPowerScalar(mx_float scalar, Symbol rhs) {
118   return Operator("_RPowerScalar")(rhs)
119            .SetParam("scalar", scalar)
120            .CreateSymbol();
121 }
_MaximumScalar(Symbol lhs,mx_float scalar)122 inline Symbol _MaximumScalar(Symbol lhs, mx_float scalar) {
123   return Operator("_MaximumScalar")(lhs)
124            .SetParam("scalar", scalar)
125            .CreateSymbol();
126 }
_MinimumScalar(Symbol lhs,mx_float scalar)127 inline Symbol _MinimumScalar(Symbol lhs, mx_float scalar) {
128   return Operator("_MinimumScalar")(lhs)
129            .SetParam("scalar", scalar)
130            .CreateSymbol();
131 }
132 // TODO(zhangcheng-qinyinghua)
133 //  make crop function run in op.h
134 //  This function is due to [zhubuntu](https://github.com/zhubuntu)
135 inline Symbol Crop(const std::string& symbol_name,
136     int num_args,
137     Symbol data,
138     Symbol crop_like,
139     Shape offset = Shape(0, 0),
140     Shape h_w = Shape(0, 0),
141     bool center_crop = false) {
142   return Operator("Crop")
143     .SetParam("num_args", num_args)
144     .SetParam("offset", offset)
145     .SetParam("h_w", h_w)
146     .SetParam("center_crop", center_crop)
147     .SetInput("arg0", data)
148     .SetInput("arg1", crop_like)
149     .CreateSymbol(symbol_name);
150 }
151 
152 
153 /*!
154  * \brief Apply activation function to input.
155  *        Softmax Activation is only available with CUDNN on GPUand will be
156  *        computed at each location across channel if input is 4D.
157  * \param symbol_name name of the resulting symbol.
158  * \param data Input data to activation function.
159  * \param act_type Activation function to be applied.
160  * \return new symbol
161  */
Activation(const std::string & symbol_name,Symbol data,const std::string & act_type)162 inline Symbol Activation(const std::string& symbol_name,
163                          Symbol data,
164                          const std::string& act_type) {
165   assert(act_type == "relu" ||
166          act_type == "sigmoid" ||
167          act_type == "softrelu" ||
168          act_type == "tanh");
169   return Operator("Activation")
170            .SetParam("act_type", act_type.c_str())
171            .SetInput("data", data)
172            .CreateSymbol(symbol_name);
173 }
174 
175 }  // namespace cpp
176 }  // namespace mxnet
177 
178 #endif  // MXNET_CPP_OP_SUPPL_H_
179 
180