/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file l2_normalization-inl.h
 * \brief instance l2 normalization op
 */
#ifndef MXNET_OPERATOR_L2_NORMALIZATION_INL_H_
#define MXNET_OPERATOR_L2_NORMALIZATION_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "./operator_common.h"
#include "./mshadow_op.h"

namespace mxnet {
namespace op {

// Declare enumeration of input order to make code more intuitive.
// These enums are only visible within this header.
namespace l2_normalization {
enum L2NormalizationOpInputs {kData};
enum L2NormalizationOpOutputs {kOut, kNorm};
enum L2NormalizationOpType {kInstance, kChannel, kSpatial};
enum L2NormalizationBackResource {kTempSpace};
}  // namespace l2_normalization

struct L2NormalizationParam : public dmlc::Parameter<L2NormalizationParam> {
  float eps;
  int mode;
  DMLC_DECLARE_PARAMETER(L2NormalizationParam) {
    DMLC_DECLARE_FIELD(eps).set_default(1e-10f)
    .describe("A small constant for numerical stability.");
    DMLC_DECLARE_FIELD(mode)
    .add_enum("instance", l2_normalization::kInstance)
    .add_enum("spatial", l2_normalization::kSpatial)
    .add_enum("channel", l2_normalization::kChannel)
    .set_default(l2_normalization::kInstance)
    .describe("Specify the dimension along which to compute L2 norm.");
  }
};
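
// ---------------------------------------------------------------------------
// Reference semantics (a minimal sketch; this helper is illustrative only and
// is not used by the operator below). For mode='instance', each example i in
// the batch is divided by the L2 norm over all of its remaining dimensions:
//
//   out[i][j] = x[i][j] / sqrt(eps + sum_j x[i][j]^2)
//
// The name `L2NormalizeInstanceRef` is hypothetical and not part of MXNet;
// std::sqrt assumes <cmath>, which is available transitively via mshadow.
inline void L2NormalizeInstanceRef(const float* x, float* out,
                                   const int batch, const int dim,
                                   const float eps) {
  for (int i = 0; i < batch; ++i) {
    // eps is added to the sum of squares before the square root,
    // exactly as L2NormalizationOp::Forward does below.
    float ssq = eps;
    for (int j = 0; j < dim; ++j) ssq += x[i * dim + j] * x[i * dim + j];
    const float inv_norm = 1.0f / std::sqrt(ssq);
    for (int j = 0; j < dim; ++j) out[i * dim + j] = x[i * dim + j] * inv_norm;
  }
}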

/**
 * \brief This is the implementation of the l2 normalization operator.
 * \tparam xpu The device that the op will be executed on.
 * \tparam DType The data type the op operates on.
 */
template<typename xpu, typename DType>
class L2NormalizationOp : public Operator {
 public:
  explicit L2NormalizationOp(L2NormalizationParam p) {
    this->param_ = p;
  }

  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    if (req[l2_normalization::kOut] == kNullOp) return;
    CHECK_EQ(req[l2_normalization::kOut], kWriteTo);
    CHECK_EQ(in_data.size(), 1U);
    CHECK_EQ(out_data.size(), 2U);
    Stream<xpu> *s = ctx.get_stream<xpu>();
    mxnet::TShape orig_shape = in_data[l2_normalization::kData].shape_;
    if (param_.mode == l2_normalization::kInstance) {
      // Flatten everything but the batch axis: (N, d1, d2, ...) -> (N, D).
      Shape<2> dshape = Shape2(orig_shape[0],
        orig_shape.ProdShape(1, orig_shape.ndim()));
      Tensor<xpu, 2, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<xpu, 2, DType>(dshape, s);
      Tensor<xpu, 2, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 2, DType>(dshape, s);
      Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
      norm = sumall_except_dim<0>(F<mxnet::op::mshadow_op::square>(data));
      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
          s, norm.size(0), norm.dptr_, norm.dptr_, DType(param_.eps));
      });
      norm = F<mxnet::op::mshadow_op::square_root>(norm);
      out = data / broadcast<0>(norm, out.shape_);
    } else if (param_.mode == l2_normalization::kChannel) {
      CHECK_GE(orig_shape.ndim(), 3);
      // Collapse the trailing spatial axes: (N, C, d1, ...) -> (N, C, D).
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
        orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<xpu, 2, DType>(norm_shape, s);
      norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 1);
      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
      });
      norm = F<mxnet::op::mshadow_op::square_root>(norm);
      out = data / broadcast_with_axis(norm, 0, orig_shape[1]);
    } else if (param_.mode == l2_normalization::kSpatial) {
      CHECK_GE(orig_shape.ndim(), 3);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
        orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<xpu, 2, DType>(norm_shape, s);
      norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 2);
      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
      });
      norm = F<mxnet::op::mshadow_op::square_root>(norm);
      out = data / broadcast_with_axis(norm, 1, dshape[2]);
    } else {
      LOG(FATAL) << "Unexpected mode in l2 normalization";
    }
  }
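
  // -------------------------------------------------------------------------
  // Gradient computed by Backward below. With y = x / ||x|| (where ||x||
  // already contains eps from the forward pass), the chain rule gives
  //
  //   dL/dx = (dL/dy - y * sum(dL/dy * y)) / ||x||,
  //
  // with the sum taken over the normalized axis. This is why Backward reads
  // only out_data (y and the saved norm) and out_grad, never in_data, and
  // why it stages the reduction sum(dL/dy * y) in the requested kTempSpace.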
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(out_grad.size(), 1U);
    CHECK(in_data.size() == 1U && in_grad.size() == 1U);
    CHECK_EQ(req.size(), 1U);

    Stream<xpu> *s = ctx.get_stream<xpu>();
    mxnet::TShape orig_shape = out_data[l2_normalization::kOut].shape_;
    if (param_.mode == l2_normalization::kInstance) {
      Shape<2> dshape = Shape2(orig_shape[0],
        orig_shape.ProdShape(1, orig_shape.ndim()));
      Tensor<xpu, 2, DType> data = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 2, DType>(dshape, s);
      Tensor<xpu, 2, DType> grad_in = in_grad[l2_normalization::kData]
        .get_with_shape<xpu, 2, DType>(dshape, s);
      Tensor<xpu, 2, DType> grad_out = out_grad[l2_normalization::kOut]
        .get_with_shape<xpu, 2, DType>(dshape, s);
      Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
      Tensor<xpu, 1, DType> temp = ctx.requested[l2_normalization::kTempSpace]
        .get_space_typed<xpu, 1, DType>(mshadow::Shape1(data.shape_[0]), s);
      temp = sumall_except_dim<0>(grad_out * data);
      Assign(grad_in, req[l2_normalization::kData],
        (grad_out - data * broadcast<0>(temp, data.shape_)) /
        broadcast<0>(norm, data.shape_));
    } else if (param_.mode == l2_normalization::kChannel) {
      CHECK_GE(orig_shape.ndim(), 3);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
        orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<xpu, 2, DType>(norm_shape, s);
      Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
        .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s);
      temp = reduce_with_axis<red::sum, false>(grad_out * data, 1);
      Assign(grad_in, req[l2_normalization::kData],
        (grad_out - data * broadcast_with_axis(temp, 0, orig_shape[1])) /
        broadcast_with_axis(norm, 0, orig_shape[1]));
    } else if (param_.mode == l2_normalization::kSpatial) {
      CHECK_GE(orig_shape.ndim(), 3);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
        orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
        .get_with_shape<xpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<xpu, 2, DType>(norm_shape, s);
      Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
        .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s);
      temp = reduce_with_axis<red::sum, false>(grad_out * data, 2);
      Assign(grad_in, req[l2_normalization::kData],
        (grad_out - data * broadcast_with_axis(temp, 1, dshape[2])) /
        broadcast_with_axis(norm, 1, dshape[2]));
    } else {
      LOG(FATAL) << "Unexpected mode in l2 normalization";
    }
  }

 protected:
  L2NormalizationParam param_;
};  // class L2NormalizationOp

// Declare factory function, used for dispatch specialization.
template<typename xpu>
Operator* CreateOp(L2NormalizationParam param, int dtype);
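
// A typical device specialization lives in the corresponding .cc/.cu file.
// Sketched here for reference (the authoritative definition is in the
// implementation files, not in this header):
//
//   template<>
//   Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
//     Operator* op = nullptr;
//     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
//       op = new L2NormalizationOp<cpu, DType>(param);
//     });
//     return op;
//   }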

#if DMLC_USE_CXX11
class L2NormalizationProp : public OperatorProperty {
 public:
  std::vector<std::string> ListArguments() const override {
    return {"data"};
  }

  std::vector<std::string> ListOutputs() const override {
    return {"output", "norm"};
  }

  int NumVisibleOutputs() const override {
    // The saved norm is kept only for the backward pass;
    // only `output` is user visible.
    return 1;
  }

  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    param_.Init(kwargs);
  }

  std::map<std::string, std::string> GetParams() const override {
    return param_.__DICT__();
  }

  bool InferType(std::vector<int> *in_type,
                 std::vector<int> *out_type,
                 std::vector<int> *aux_type) const override {
    int dtype = (*in_type)[0];
    type_assign(&dtype, (*out_type)[0]);
    type_assign(&dtype, (*out_type)[1]);

    TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
    TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
    TYPE_ASSIGN_CHECK(*out_type, 1, dtype);
    return dtype != -1;
  }
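
  // Worked example of the shapes produced by InferShape for an
  // (N, C, H, W) input:
  //   instance: out (N, C, H, W), norm (N,)          one norm per example
  //   channel:  out (N, C, H, W), norm (N, 1, H, W)  one norm per spatial position
  //   spatial:  out (N, C, H, W), norm (N, C)        one norm per channel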
  bool InferShape(mxnet::ShapeVector *in_shape,
                  mxnet::ShapeVector *out_shape,
                  mxnet::ShapeVector *aux_shape) const override {
    using namespace mshadow;
    CHECK_EQ(in_shape->size(), 1U) << "L2Normalization layer only accepts data as input";
    const mxnet::TShape &dshape = (*in_shape)[l2_normalization::kData];
    // require the data shape to be known
    if ((*in_shape)[l2_normalization::kData].ndim() == 0) return false;
    out_shape->clear();
    out_shape->push_back(dshape);
    if (param_.mode == l2_normalization::kInstance) {
      out_shape->push_back(Shape1(dshape[0]));
    } else if (param_.mode == l2_normalization::kChannel) {
      CHECK_GE(dshape.ndim(), 3) << "At least 3 dimensions required in channel mode";
      mxnet::TShape norm_shape = dshape;
      norm_shape[1] = 1;
      out_shape->push_back(norm_shape);
    } else if (param_.mode == l2_normalization::kSpatial) {
      CHECK_GE(dshape.ndim(), 3) << "At least 3 dimensions required in spatial mode";
      out_shape->push_back(Shape2(dshape[0], dshape[1]));
    } else {
      return false;
    }
    return true;
  }

  OperatorProperty* Copy() const override {
    L2NormalizationProp* norm_sym = new L2NormalizationProp();
    norm_sym->param_ = this->param_;
    return norm_sym;
  }

  std::string TypeString() const override {
    return "L2Normalization";
  }

  // declare dependency and inplace optimization options
  std::vector<int> DeclareBackwardDependency(
    const std::vector<int> &out_grad,
    const std::vector<int> &in_data,
    const std::vector<int> &out_data) const override {
    return {out_grad[l2_normalization::kOut],
            out_data[l2_normalization::kOut],
            out_data[l2_normalization::kNorm]};
  }

  std::vector<std::pair<int, void*> > BackwardInplaceOption(
    const std::vector<int> &out_grad,
    const std::vector<int> &in_data,
    const std::vector<int> &out_data,
    const std::vector<void*> &in_grad) const override {
    return {{out_grad[l2_normalization::kOut], in_grad[l2_normalization::kData]}};
  }

  std::vector<ResourceRequest> BackwardResource(
    const mxnet::ShapeVector &in_shape) const override {
    return {ResourceRequest::kTempSpace};
  }

  Operator* CreateOperator(Context ctx) const override {
    LOG(FATAL) << "Not Implemented.";
    return nullptr;
  }

  Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
                             std::vector<int> *in_type) const override;

 private:
  L2NormalizationParam param_;
};  // class L2NormalizationProp
#endif  // DMLC_USE_CXX11
}  // namespace op
}  // namespace mxnet
#endif  // MXNET_OPERATOR_L2_NORMALIZATION_INL_H_