1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file op_suppl.h
22 * \brief A supplement and amendment of the operators from op.h
23 * \author Zhang Chen, zhubuntu, Xin Li
24 */
25
26 #ifndef MXNET_CPP_OP_SUPPL_H_
27 #define MXNET_CPP_OP_SUPPL_H_
28
29 #include <cassert>
30 #include <string>
31 #include <vector>
32 #include "mxnet-cpp/base.h"
33 #include "mxnet-cpp/shape.h"
34 #include "mxnet-cpp/operator.h"
35 #include "mxnet-cpp/MxNetCpp.h"
36
37 namespace mxnet {
38 namespace cpp {
39
_Plus(Symbol lhs,Symbol rhs)40 inline Symbol _Plus(Symbol lhs, Symbol rhs) {
41 return Operator("_Plus")(lhs, rhs)
42 .CreateSymbol();
43 }
_Mul(Symbol lhs,Symbol rhs)44 inline Symbol _Mul(Symbol lhs, Symbol rhs) {
45 return Operator("_Mul")(lhs, rhs)
46 .CreateSymbol();
47 }
_Minus(Symbol lhs,Symbol rhs)48 inline Symbol _Minus(Symbol lhs, Symbol rhs) {
49 return Operator("_Minus")(lhs, rhs)
50 .CreateSymbol();
51 }
_Div(Symbol lhs,Symbol rhs)52 inline Symbol _Div(Symbol lhs, Symbol rhs) {
53 return Operator("_Div")(lhs, rhs)
54 .CreateSymbol();
55 }
_Mod(Symbol lhs,Symbol rhs)56 inline Symbol _Mod(Symbol lhs, Symbol rhs) {
57 return Operator("_Mod")(lhs, rhs)
58 .CreateSymbol();
59 }
_Power(Symbol lhs,Symbol rhs)60 inline Symbol _Power(Symbol lhs, Symbol rhs) {
61 return Operator("_Power")(lhs, rhs)
62 .CreateSymbol();
63 }
_Maximum(Symbol lhs,Symbol rhs)64 inline Symbol _Maximum(Symbol lhs, Symbol rhs) {
65 return Operator("_Maximum")(lhs, rhs)
66 .CreateSymbol();
67 }
_Minimum(Symbol lhs,Symbol rhs)68 inline Symbol _Minimum(Symbol lhs, Symbol rhs) {
69 return Operator("_Minimum")(lhs, rhs)
70 .CreateSymbol();
71 }
_PlusScalar(Symbol lhs,mx_float scalar)72 inline Symbol _PlusScalar(Symbol lhs, mx_float scalar) {
73 return Operator("_PlusScalar")(lhs)
74 .SetParam("scalar", scalar)
75 .CreateSymbol();
76 }
_MinusScalar(Symbol lhs,mx_float scalar)77 inline Symbol _MinusScalar(Symbol lhs, mx_float scalar) {
78 return Operator("_MinusScalar")(lhs)
79 .SetParam("scalar", scalar)
80 .CreateSymbol();
81 }
_RMinusScalar(mx_float scalar,Symbol rhs)82 inline Symbol _RMinusScalar(mx_float scalar, Symbol rhs) {
83 return Operator("_RMinusScalar")(rhs)
84 .SetParam("scalar", scalar)
85 .CreateSymbol();
86 }
_MulScalar(Symbol lhs,mx_float scalar)87 inline Symbol _MulScalar(Symbol lhs, mx_float scalar) {
88 return Operator("_MulScalar")(lhs)
89 .SetParam("scalar", scalar)
90 .CreateSymbol();
91 }
_DivScalar(Symbol lhs,mx_float scalar)92 inline Symbol _DivScalar(Symbol lhs, mx_float scalar) {
93 return Operator("_DivScalar")(lhs)
94 .SetParam("scalar", scalar)
95 .CreateSymbol();
96 }
_RDivScalar(mx_float scalar,Symbol rhs)97 inline Symbol _RDivScalar(mx_float scalar, Symbol rhs) {
98 return Operator("_RDivScalar")(rhs)
99 .SetParam("scalar", scalar)
100 .CreateSymbol();
101 }
_ModScalar(Symbol lhs,mx_float scalar)102 inline Symbol _ModScalar(Symbol lhs, mx_float scalar) {
103 return Operator("_ModScalar")(lhs)
104 .SetParam("scalar", scalar)
105 .CreateSymbol();
106 }
_RModScalar(mx_float scalar,Symbol rhs)107 inline Symbol _RModScalar(mx_float scalar, Symbol rhs) {
108 return Operator("_RModScalar")(rhs)
109 .SetParam("scalar", scalar)
110 .CreateSymbol();
111 }
_PowerScalar(Symbol lhs,mx_float scalar)112 inline Symbol _PowerScalar(Symbol lhs, mx_float scalar) {
113 return Operator("_PowerScalar")(lhs)
114 .SetParam("scalar", scalar)
115 .CreateSymbol();
116 }
_RPowerScalar(mx_float scalar,Symbol rhs)117 inline Symbol _RPowerScalar(mx_float scalar, Symbol rhs) {
118 return Operator("_RPowerScalar")(rhs)
119 .SetParam("scalar", scalar)
120 .CreateSymbol();
121 }
_MaximumScalar(Symbol lhs,mx_float scalar)122 inline Symbol _MaximumScalar(Symbol lhs, mx_float scalar) {
123 return Operator("_MaximumScalar")(lhs)
124 .SetParam("scalar", scalar)
125 .CreateSymbol();
126 }
_MinimumScalar(Symbol lhs,mx_float scalar)127 inline Symbol _MinimumScalar(Symbol lhs, mx_float scalar) {
128 return Operator("_MinimumScalar")(lhs)
129 .SetParam("scalar", scalar)
130 .CreateSymbol();
131 }
132 // TODO(zhangcheng-qinyinghua)
133 // make crop function run in op.h
134 // This function is due to [zhubuntu](https://github.com/zhubuntu)
135 inline Symbol Crop(const std::string& symbol_name,
136 int num_args,
137 Symbol data,
138 Symbol crop_like,
139 Shape offset = Shape(0, 0),
140 Shape h_w = Shape(0, 0),
141 bool center_crop = false) {
142 return Operator("Crop")
143 .SetParam("num_args", num_args)
144 .SetParam("offset", offset)
145 .SetParam("h_w", h_w)
146 .SetParam("center_crop", center_crop)
147 .SetInput("arg0", data)
148 .SetInput("arg1", crop_like)
149 .CreateSymbol(symbol_name);
150 }
151
152
153 /*!
154 * \brief Apply activation function to input.
155 * Softmax Activation is only available with CUDNN on GPUand will be
156 * computed at each location across channel if input is 4D.
157 * \param symbol_name name of the resulting symbol.
158 * \param data Input data to activation function.
159 * \param act_type Activation function to be applied.
160 * \return new symbol
161 */
Activation(const std::string & symbol_name,Symbol data,const std::string & act_type)162 inline Symbol Activation(const std::string& symbol_name,
163 Symbol data,
164 const std::string& act_type) {
165 assert(act_type == "relu" ||
166 act_type == "sigmoid" ||
167 act_type == "softrelu" ||
168 act_type == "tanh");
169 return Operator("Activation")
170 .SetParam("act_type", act_type.c_str())
171 .SetInput("data", data)
172 .CreateSymbol(symbol_name);
173 }
174
175 } // namespace cpp
176 } // namespace mxnet
177
178 #endif // MXNET_CPP_OP_SUPPL_H_
179
180