/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file l2_normalization_op-inl.h
 * \brief L2 normalization operator
*/
#ifndef MXNET_OPERATOR_L2_NORMALIZATION_INL_H_
#define MXNET_OPERATOR_L2_NORMALIZATION_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "./operator_common.h"
#include "./mshadow_op.h"
#include "./mxnet_op.h"

namespace mxnet {
namespace op {
// Enumerations of the inputs, outputs, normalization modes and backward
// resources, declared to make the code more readable.
// These enums are only visible within this header.
namespace l2_normalization {
enum L2NormalizationOpInputs {kData};
enum L2NormalizationOpOutputs {kOut, kNorm};
enum L2NormalizationOpType {kInstance, kChannel, kSpatial};
enum L2NormalizationBackResource {kTempSpace};
}  // namespace l2_normalization

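// The mode parameter decides which axes the squared-sum reduction runs over
// (see Forward below):
//   instance - one norm per example, over all remaining dimensions;
//   channel  - one norm per (example, spatial position), over axis 1;
//   spatial  - one norm per (example, channel), over the trailing spatial axes.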
struct L2NormalizationParam : public dmlc::Parameter<L2NormalizationParam> {
  float eps;
  int mode;
  DMLC_DECLARE_PARAMETER(L2NormalizationParam) {
    DMLC_DECLARE_FIELD(eps).set_default(1e-10f)
    .describe("A small constant for numerical stability.");
    DMLC_DECLARE_FIELD(mode)
    .add_enum("instance", l2_normalization::kInstance)
    .add_enum("spatial", l2_normalization::kSpatial)
    .add_enum("channel", l2_normalization::kChannel)
    .set_default(l2_normalization::kInstance)
    .describe("Specify the dimension along which to compute L2 norm.");
  }
};

/**
 * \brief This is the implementation of the L2 normalization operator.
 * \tparam xpu The device that the op will be executed on.
 * \tparam DType The data type used for computation.
 */
template<typename xpu, typename DType>
class L2NormalizationOp : public Operator {
 public:
  explicit L2NormalizationOp(L2NormalizationParam p) {
    this->param_ = p;
  }

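  // Forward pass: out = data / sqrt(sum(data^2) + eps), where the sum runs over
  // the axes selected by param_.mode; the same denominator is also written to
  // the "norm" output so Backward can reuse it.
  // Example (instance mode, eps ~ 0): for a row x = [3, 4] the stored norm is
  // sqrt(3^2 + 4^2) = 5 and the normalized row is [0.6, 0.8].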
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    if (req[l2_normalization::kOut] == kNullOp) return;
    CHECK_EQ(req[l2_normalization::kOut], kWriteTo);
    CHECK_EQ(in_data.size(), 1U);
    CHECK_EQ(out_data.size(), 2U);
    Stream<xpu> *s = ctx.get_stream<xpu>();
    mxnet::TShape orig_shape = in_data[l2_normalization::kData].shape_;
    if (param_.mode == l2_normalization::kInstance) {
      Shape<2> dshape = Shape2(orig_shape[0],
        orig_shape.ProdShape(1, orig_shape.ndim()));
      Tensor<xpu, 2, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<xpu, 2, DType>(dshape, s);
      Tensor<xpu, 2, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 2, DType>(dshape, s);
      Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
      norm = sumall_except_dim<0>(F<mxnet::op::mshadow_op::square>(data));
      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
          s, norm.size(0), norm.dptr_, norm.dptr_, DType(param_.eps));
      });
      norm = F<mxnet::op::mshadow_op::square_root>(norm);
      out = data / broadcast<0>(norm, out.shape_);
    } else if (param_.mode == l2_normalization::kChannel) {
      CHECK_GE(orig_shape.ndim(), 3);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
        orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<xpu, 2, DType>(norm_shape, s);
      norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 1);
      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
      });
      norm = F<mxnet::op::mshadow_op::square_root>(norm);
      out = data / broadcast_with_axis(norm, 0, orig_shape[1]);
    } else if (param_.mode == l2_normalization::kSpatial) {
      CHECK_GE(orig_shape.ndim(), 3);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
        orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<xpu, 2, DType>(norm_shape, s);
      norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 2);
      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
      });
      norm = F<mxnet::op::mshadow_op::square_root>(norm);
      out = data / broadcast_with_axis(norm, 1, dshape[2]);
    } else {
      LOG(FATAL) << "Unexpected mode in l2 normalization";
    }
  }

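  // Backward pass: with y = out and n = norm from the forward pass,
  //   grad_data = (grad_out - y * sum(grad_out * y)) / n,
  // where the sum runs over the same axes Forward normalized over and is
  // staged in the requested temporary workspace.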
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(out_grad.size(), 1U);
    CHECK(in_data.size() == 1U && in_grad.size() == 1U);
    CHECK_EQ(req.size(), 1U);

    Stream<xpu> *s = ctx.get_stream<xpu>();
    mxnet::TShape orig_shape = out_data[l2_normalization::kOut].shape_;
    if (param_.mode == l2_normalization::kInstance) {
      Shape<2> dshape = Shape2(orig_shape[0],
        orig_shape.ProdShape(1, orig_shape.ndim()));
      Tensor<xpu, 2, DType> data = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 2, DType>(dshape, s);
      Tensor<xpu, 2, DType> grad_in = in_grad[l2_normalization::kData]
        .get_with_shape<xpu, 2, DType>(dshape, s);
      Tensor<xpu, 2, DType> grad_out = out_grad[l2_normalization::kOut]
        .get_with_shape<xpu, 2, DType>(dshape, s);
      Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
      Tensor<xpu, 1, DType> temp = ctx.requested[l2_normalization::kTempSpace]
        .get_space_typed<xpu, 1, DType>(mshadow::Shape1(data.shape_[0]), s);
      temp = sumall_except_dim<0>(grad_out * data);
      Assign(grad_in, req[l2_normalization::kData],
        (grad_out - data * broadcast<0>(temp, data.shape_)) /
        broadcast<0>(norm, data.shape_));
    } else if (param_.mode == l2_normalization::kChannel) {
      CHECK_GE(orig_shape.ndim(), 3);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
        orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<xpu, 2, DType>(norm_shape, s);
      Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
        .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s);
      temp = reduce_with_axis<red::sum, false>(grad_out * data, 1);
      Assign(grad_in, req[l2_normalization::kData],
        (grad_out - data * broadcast_with_axis(temp, 0, orig_shape[1])) /
        broadcast_with_axis(norm, 0, orig_shape[1]));
    } else if (param_.mode == l2_normalization::kSpatial) {
      CHECK_GE(orig_shape.ndim(), 3);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
        orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<xpu, 2, DType>(norm_shape, s);
      Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
        .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s);
      temp = reduce_with_axis<red::sum, false>(grad_out * data, 2);
      Assign(grad_in, req[l2_normalization::kData],
        (grad_out - data * broadcast_with_axis(temp, 1, dshape[2])) /
        broadcast_with_axis(norm, 1, dshape[2]));
    } else {
      LOG(FATAL) << "Unexpected mode in l2 normalization";
    }
  }

 protected:
  L2NormalizationParam param_;
};  // class L2NormalizationOp

// Declare factory function, used for dispatch specialization
template<typename xpu>
Operator* CreateOp(L2NormalizationParam param, int dtype);

#if DMLC_USE_CXX11
class L2NormalizationProp : public OperatorProperty {
 public:
  std::vector<std::string> ListArguments() const override {
    return {"data"};
  }

  std::vector<std::string> ListOutputs() const override {
    return {"output", "norm"};
  }

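  // Only the normalized output is exposed to the user; the "norm" output is
  // kept internally so Backward can reuse it (see DeclareBackwardDependency).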
  int NumVisibleOutputs() const override {
    return 1;
  }

  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    param_.Init(kwargs);
  }

  std::map<std::string, std::string> GetParams() const override {
    return param_.__DICT__();
  }

  bool InferType(std::vector<int> *in_type,
                 std::vector<int> *out_type,
                 std::vector<int> *aux_type) const override {
    int dtype = (*in_type)[0];
    type_assign(&dtype, (*out_type)[0]);
    type_assign(&dtype, (*out_type)[1]);

    TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
    TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
    TYPE_ASSIGN_CHECK(*out_type, 1, dtype);
    return dtype != -1;
  }

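  // The first output has the same shape as the data; the second ("norm") output
  // depends on the mode: instance -> (batch,), channel -> data shape with axis 1
  // set to 1, spatial -> (batch, channels).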
  bool InferShape(mxnet::ShapeVector *in_shape,
                  mxnet::ShapeVector *out_shape,
                  mxnet::ShapeVector *aux_shape) const override {
    using namespace mshadow;
    CHECK_EQ(in_shape->size(), 1U) << "L2Normalization layer only accepts data as input";
    const mxnet::TShape &dshape = (*in_shape)[l2_normalization::kData];
    // require data to be known
    if ((*in_shape)[l2_normalization::kData].ndim() == 0) return false;
    out_shape->clear();
    out_shape->push_back(dshape);
    if (param_.mode == l2_normalization::kInstance) {
      out_shape->push_back(Shape1(dshape[0]));
    } else if (param_.mode == l2_normalization::kChannel) {
      CHECK_GE(dshape.ndim(), 3) << "At least 3 dimensions required in channel mode";
      mxnet::TShape norm_shape = dshape;
      norm_shape[1] = 1;
      out_shape->push_back(norm_shape);
    } else if (param_.mode == l2_normalization::kSpatial) {
      CHECK_GE(dshape.ndim(), 3) << "At least 3 dimensions required in spatial mode";
      out_shape->push_back(Shape2(dshape[0], dshape[1]));
    } else {
      return false;
    }
    return true;
  }

  OperatorProperty* Copy() const override {
    L2NormalizationProp* norm_sym = new L2NormalizationProp();
    norm_sym->param_ = this->param_;
    return norm_sym;
  }

  std::string TypeString() const override {
    return "L2Normalization";
  }

  // Declare backward dependencies and in-place optimization options.
  std::vector<int> DeclareBackwardDependency(
    const std::vector<int> &out_grad,
    const std::vector<int> &in_data,
    const std::vector<int> &out_data) const override {
    return {out_grad[l2_normalization::kOut],
      out_data[l2_normalization::kOut],
      out_data[l2_normalization::kNorm]};
  }

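  // The input gradient may overwrite the output-gradient buffer in place.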
  std::vector<std::pair<int, void*> > BackwardInplaceOption(
    const std::vector<int> &out_grad,
    const std::vector<int> &in_data,
    const std::vector<int> &out_data,
    const std::vector<void*> &in_grad) const override {
    return {{out_grad[l2_normalization::kOut], in_grad[l2_normalization::kData]}};
  }

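  // Backward stages its reduction in temporary workspace, so request temp space
  // (this corresponds to l2_normalization::kTempSpace above).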
  std::vector<ResourceRequest> BackwardResource(
      const mxnet::ShapeVector &in_shape) const override {
    return {ResourceRequest::kTempSpace};
  }

  Operator* CreateOperator(Context ctx) const override {
    LOG(FATAL) << "Not Implemented.";
    return nullptr;
  }

  Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
                             std::vector<int> *in_type) const override;

 private:
  L2NormalizationParam param_;
};  // class L2NormalizationProp
#endif  // DMLC_USE_CXX11
}  // namespace op
}  // namespace mxnet
#endif  // MXNET_OPERATOR_L2_NORMALIZATION_INL_H_
336