/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file l2_normalization.cc
 * \brief l2 normalization operator
 */
#include "./l2_normalization-inl.h"

/* Visual Studio only supports OpenMP 2.0, which predates the collapse()
 * clause (introduced in OpenMP 3.0), so compile it away under MSVC. */
#ifdef _MSC_VER
#define collapse(x)
#endif

namespace mxnet {
namespace op {

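// CPU specialization of L2NormalizationOp: overrides Forward with explicit
// OpenMP-parallelized loops over the three normalization modes.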
template<typename DType>
class L2NormalizationOpCPU : public L2NormalizationOp<cpu, DType> {
 public:
  explicit L2NormalizationOpCPU(L2NormalizationParam p)
    : L2NormalizationOp<cpu, DType>(p) {}
  void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
               const std::vector<OpReqType> &req,
               const std::vector<TBlob> &out_data,
               const std::vector<TBlob> &aux_args) override {
    using namespace mshadow;
    using namespace mshadow::expr;
    if (req[l2_normalization::kOut] == kNullOp) return;
    CHECK_EQ(req[l2_normalization::kOut], kWriteTo);
    CHECK_EQ(in_data.size(), 1U);
    CHECK_EQ(out_data.size(), 2U);
    Stream<cpu> *s = ctx.get_stream<cpu>();
    mxnet::TShape orig_shape = in_data[l2_normalization::kData].shape_;
    auto omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
    if (this->param_.mode == l2_normalization::kInstance) {
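      // Instance mode: flatten to (batch, features) and normalize each row by
      // norm[i] = sqrt(eps + sum_j data[i][j]^2). Rows are independent, so the
      // outer loop parallelizes cleanly.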
      Shape<2> dshape = Shape2(orig_shape[0],
        orig_shape.ProdShape(1, orig_shape.ndim()));
      Tensor<cpu, 2, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<cpu, 2, DType>(dshape, s);
      Tensor<cpu, 2, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<cpu, 2, DType>(dshape, s);
      Tensor<cpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<cpu, 1, DType>(s);
      #pragma omp parallel for num_threads(omp_threads)
      for (int shape0 = 0; shape0 < static_cast<int>(dshape[0]); shape0++) {
        norm[shape0] = DType(this->param_.eps);
        for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
          norm[shape0] += data[shape0][shape1] * data[shape0][shape1];
        }
        norm[shape0] = std::sqrt(norm[shape0]);
        for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
          out[shape0][shape1] = data[shape0][shape1] / norm[shape0];
        }
      }
    } else if (this->param_.mode == l2_normalization::kChannel) {
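      // Channel mode: flatten to (batch, channels, spatial) and normalize the
      // channel vector at every (instance, spatial position) pair, i.e.
      // norm[i][k] = sqrt(eps + sum_c data[i][c][k]^2).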
      CHECK_GE(orig_shape.ndim(), 3);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
        orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<cpu, 3, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<cpu, 3, DType>(dshape, s);
      Tensor<cpu, 3, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<cpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
      Tensor<cpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<cpu, 2, DType>(norm_shape, s);
      #pragma omp parallel for num_threads(omp_threads) collapse(2)
      for (int shape0 = 0; shape0 < static_cast<int>(dshape[0]); shape0++) {
        for (int shape2 = 0; shape2 < static_cast<int>(dshape[2]); shape2++) {
          norm[shape0][shape2] = DType(this->param_.eps);
          for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
            norm[shape0][shape2] += data[shape0][shape1][shape2] * data[shape0][shape1][shape2];
          }
          norm[shape0][shape2] = std::sqrt(norm[shape0][shape2]);
          for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
            out[shape0][shape1][shape2] = data[shape0][shape1][shape2] / norm[shape0][shape2];
          }
        }
      }
    } else if (this->param_.mode == l2_normalization::kSpatial) {
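      // Spatial mode: flatten to (batch, channels, spatial) and normalize each
      // (instance, channel) slice over its spatial positions, i.e.
      // norm[i][c] = sqrt(eps + sum_k data[i][c][k]^2).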
      CHECK_GE(orig_shape.ndim(), 3);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
        orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<cpu, 3, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<cpu, 3, DType>(dshape, s);
      Tensor<cpu, 3, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<cpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
      Tensor<cpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<cpu, 2, DType>(norm_shape, s);
      #pragma omp parallel for num_threads(omp_threads) collapse(2)
      for (int shape0 = 0; shape0 < static_cast<int>(dshape[0]); shape0++) {
        for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
          norm[shape0][shape1] = DType(this->param_.eps);
          for (int shape2 = 0; shape2 < static_cast<int>(dshape[2]); shape2++) {
            norm[shape0][shape1] += data[shape0][shape1][shape2] * data[shape0][shape1][shape2];
          }
          norm[shape0][shape1] = std::sqrt(norm[shape0][shape1]);
          for (int shape2 = 0; shape2 < static_cast<int>(dshape[2]); shape2++) {
            out[shape0][shape1][shape2] = data[shape0][shape1][shape2] / norm[shape0][shape1];
          }
        }
      }
    } else {
      LOG(FATAL) << "Unexpected mode in l2 normalization";
    }
  }
};

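// Dtype dispatch: instantiate the CPU operator for the runtime dtype via
// MSHADOW_REAL_TYPE_SWITCH (float32/float64/float16).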
template<>
Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
  Operator* op = nullptr;
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    op = new L2NormalizationOpCPU<DType>(param);
  });
  return op;
}

// DO_BIND_DISPATCH comes from static_operator_common.h
Operator* L2NormalizationProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
                                                std::vector<int> *in_type) const {
  DO_BIND_DISPATCH(CreateOp, this->param_, in_type->at(0));
}

DMLC_REGISTER_PARAMETER(L2NormalizationParam);

MXNET_REGISTER_OP_PROPERTY(L2Normalization, L2NormalizationProp)
.describe(R"code(Normalize the input array using the L2 norm.

For 1-D NDArray, it computes::

  out = data / sqrt(sum(data ** 2) + eps)

For an N-D NDArray with shape ``(n, c, d_1, ..., d_k)``:

with ``mode`` = ``instance``, it normalizes each of the ``n`` instances by the
L2 norm taken over all of its remaining axes::

  for i in 0...n
    out[i,:,:,...,:] = data[i,:,:,...,:] / sqrt(sum(data[i,:,:,...,:] ** 2) + eps)

The ``channel`` and ``spatial`` modes require at least 3 dimensions; the
trailing axes are flattened into a single spatial axis of size
``s = d_1 * ... * d_k``, giving a ``(n, c, s)`` view of the data.

with ``mode`` = ``channel``, it normalizes the vector of channel values at each
spatial position, i.e. the L2 norm is taken along the channel axis::

  for i in 0...n
    for j in 0...s
      out[i,:,j] = data[i,:,j] / sqrt(sum(data[i,:,j] ** 2) + eps)

with ``mode`` = ``spatial``, it normalizes each channel of each instance by the
L2 norm taken over its spatial positions::

  for i in 0...n
    for j in 0...c
      out[i,j,:] = data[i,j,:] / sqrt(sum(data[i,j,:] ** 2) + eps)

Example::

  x = [[[1,2],
        [3,4]],
       [[2,2],
        [5,6]]]

  L2Normalization(x, mode='instance')
  =[[[ 0.18257418  0.36514837]
     [ 0.54772252  0.73029673]]
    [[ 0.24077171  0.24077171]
     [ 0.60192931  0.72231513]]]

  L2Normalization(x, mode='channel')
  =[[[ 0.31622776  0.44721359]
     [ 0.94868326  0.89442718]]
    [[ 0.37139067  0.31622776]
     [ 0.92847669  0.94868326]]]

  L2Normalization(x, mode='spatial')
  =[[[ 0.44721359  0.89442718]
     [ 0.60000002  0.80000001]]
    [[ 0.70710677  0.70710677]
     [ 0.6401844   0.76822126]]]

)code" ADD_FILELINE)
.add_argument("data", "NDArray-or-Symbol", "Input array to normalize.")
.add_arguments(L2NormalizationParam::__FIELDS__());
}  // namespace op
}  // namespace mxnet