1 /*! 2 * Copyright (c) 2017 by Contributors 3 * Copyright (c) 2017 Microsoft 4 * Licensed under The MIT License [see LICENSE for details] 5 * \file psroi_pooling-inl.h 6 * \brief psroi pooling operator and symbol 7 * \author Yi Li, Tairui Chen, Guodong Zhang, Jifeng Dai 8 * 9 * Code from https://github.com/msracver/Deformable-ConvNets/blob/d51075968c5fd40b37a55d20c8e945c1f181d529/rfcn/operator_cxx/psroi_pooling-inl.h 10 */ 11 #ifndef MXNET_OPERATOR_CONTRIB_PSROI_POOLING_INL_H_ 12 #define MXNET_OPERATOR_CONTRIB_PSROI_POOLING_INL_H_ 13 14 #include <dmlc/logging.h> 15 #include <dmlc/parameter.h> 16 #include <mxnet/operator.h> 17 #include <map> 18 #include <vector> 19 #include <string> 20 #include <utility> 21 #include "../mshadow_op.h" 22 #include "../operator_common.h" 23 24 25 namespace mxnet { 26 namespace op { 27 28 // Declare enumeration of input order to make code more intuitive. 29 // These enums are only visible within this header 30 namespace psroipool { 31 enum PSROIPoolingOpInputs {kData, kBox}; 32 enum PSROIPoolingOpOutputs {kOut}; 33 } // psroipool 34 35 struct PSROIPoolingParam : public dmlc::Parameter<PSROIPoolingParam> { 36 // mxnet::TShape pooled_size; 37 float spatial_scale; 38 int output_dim; 39 int pooled_size; 40 int group_size; DMLC_DECLARE_PARAMETERPSROIPoolingParam41 DMLC_DECLARE_PARAMETER(PSROIPoolingParam) { 42 DMLC_DECLARE_FIELD(spatial_scale).set_range(0.0, 1.0) 43 .describe("Ratio of input feature map height (or w) to raw image height (or w). " 44 "Equals the reciprocal of total stride in convolutional layers"); 45 DMLC_DECLARE_FIELD(output_dim).describe("fix output dim"); 46 DMLC_DECLARE_FIELD(pooled_size).describe("fix pooled size"); 47 DMLC_DECLARE_FIELD(group_size).set_default(0).describe("fix group size"); 48 } 49 }; 50 51 template<typename xpu, typename DType> 52 class PSROIPoolingOp : public Operator { 53 public: PSROIPoolingOp(PSROIPoolingParam p)54 explicit PSROIPoolingOp(PSROIPoolingParam p) { 55 this->param_ = p; 56 } 57 Forward(const OpContext & ctx,const std::vector<TBlob> & in_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & out_data,const std::vector<TBlob> & aux_args)58 virtual void Forward(const OpContext &ctx, 59 const std::vector<TBlob> &in_data, 60 const std::vector<OpReqType> &req, 61 const std::vector<TBlob> &out_data, 62 const std::vector<TBlob> &aux_args) { 63 using namespace mshadow; 64 CHECK_EQ(in_data.size(), 2); 65 CHECK_EQ(out_data.size(), 1); 66 CHECK_EQ(out_data[psroipool::kOut].shape_[0], in_data[psroipool::kBox].shape_[0]); 67 Stream<xpu> *s = ctx.get_stream<xpu>(); 68 69 Tensor<xpu, 4, DType> data = in_data[psroipool::kData].get<xpu, 4, DType>(s); 70 Tensor<xpu, 2, DType> bbox = in_data[psroipool::kBox].get<xpu, 2, DType>(s); 71 Tensor<xpu, 4, DType> out = out_data[psroipool::kOut].get<xpu, 4, DType>(s); 72 CHECK_EQ(data.CheckContiguous(), true); 73 CHECK_EQ(bbox.CheckContiguous(), true); 74 CHECK_EQ(out.CheckContiguous(), true); 75 out = -FLT_MAX; 76 PSROIPoolForward(out, data, bbox, param_.spatial_scale, param_.output_dim, param_.group_size); 77 } 78 Backward(const OpContext & ctx,const std::vector<TBlob> & out_grad,const std::vector<TBlob> & in_data,const std::vector<TBlob> & out_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & in_grad,const std::vector<TBlob> & aux_args)79 virtual void Backward(const OpContext &ctx, 80 const std::vector<TBlob> &out_grad, 81 const std::vector<TBlob> &in_data, 82 const std::vector<TBlob> &out_data, 83 const std::vector<OpReqType> &req, 84 const std::vector<TBlob> &in_grad, 85 const std::vector<TBlob> &aux_args) { 86 using namespace mshadow; 87 CHECK_EQ(in_data.size(), 2); 88 CHECK_EQ(out_data.size(), 1); 89 CHECK_EQ(out_grad[psroipool::kOut].shape_[0], in_data[psroipool::kBox].shape_[0]); 90 CHECK_NE(req[psroipool::kData], kWriteInplace) << 91 "ROIPooling: Backward doesn't support kWriteInplace."; 92 CHECK_NE(req[psroipool::kBox], kWriteInplace) << 93 "ROIPooling: Backward doesn't support kWriteInplace."; 94 Stream<xpu> *s = ctx.get_stream<xpu>(); 95 96 Tensor<xpu, 4, DType> grad_out = out_grad[psroipool::kOut].get<xpu, 4, DType>(s); 97 Tensor<xpu, 2, DType> bbox = in_data[psroipool::kBox].get<xpu, 2, DType>(s); 98 Tensor<xpu, 4, DType> grad_in = in_grad[psroipool::kData].get<xpu, 4, DType>(s); 99 Tensor<xpu, 2, DType> grad_roi = in_grad[psroipool::kBox].get<xpu, 2, DType>(s); 100 101 CHECK_EQ(grad_out.CheckContiguous(), true); 102 CHECK_EQ(bbox.CheckContiguous(), true); 103 CHECK_EQ(grad_in.CheckContiguous(), true); 104 105 if (kAddTo == req[psroipool::kData] || kWriteTo == req[psroipool::kData]) { 106 if (kWriteTo == req[psroipool::kData]) { 107 grad_in = 0.0f; 108 } 109 PSROIPoolBackwardAcc(grad_in, grad_out, bbox, param_.spatial_scale, 110 param_.output_dim, param_.group_size); 111 } 112 if (kWriteTo == req[psroipool::kBox]) { 113 grad_roi = 0.0f; 114 } 115 } 116 117 private: 118 PSROIPoolingParam param_; 119 }; // class PSROIPoolingOp 120 121 // Decalre Factory function, used for dispatch specialization 122 template<typename xpu> 123 Operator* CreateOp(PSROIPoolingParam param, int dtype); 124 125 #if DMLC_USE_CXX11 126 class PSROIPoolingProp : public OperatorProperty { 127 public: ListArguments()128 std::vector<std::string> ListArguments() const override { 129 return {"data", "rois"}; 130 } 131 ListOutputs()132 std::vector<std::string> ListOutputs() const override { 133 return {"output"}; 134 } 135 NumOutputs()136 int NumOutputs() const override { 137 return 1; 138 } 139 NumVisibleOutputs()140 int NumVisibleOutputs() const override { 141 return 1; 142 } 143 Init(const std::vector<std::pair<std::string,std::string>> & kwargs)144 void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { 145 param_.Init(kwargs); 146 if (param_.group_size == 0) { 147 param_.group_size = param_.pooled_size; 148 } 149 } 150 GetParams()151 std::map<std::string, std::string> GetParams() const override { 152 return param_.__DICT__(); 153 } 154 InferShape(mxnet::ShapeVector * in_shape,mxnet::ShapeVector * out_shape,mxnet::ShapeVector * aux_shape)155 bool InferShape(mxnet::ShapeVector *in_shape, 156 mxnet::ShapeVector *out_shape, 157 mxnet::ShapeVector *aux_shape) const override { 158 using namespace mshadow; 159 CHECK_EQ(in_shape->size(), 2) << "Input:[data, rois]"; 160 161 // data: [batch_size, c, h, w] 162 mxnet::TShape dshape = in_shape->at(psroipool::kData); 163 CHECK_EQ(dshape.ndim(), 4) << "data should be a 4D tensor"; 164 165 // bbox: [num_rois, 5] 166 mxnet::TShape bshape = in_shape->at(psroipool::kBox); 167 CHECK_EQ(bshape.ndim(), 2) << "bbox should be a 2D tensor of shape [batch, 5]"; 168 CHECK_EQ(bshape[1], 5) << "bbox should be a 2D tensor of shape [batch, 5]"; 169 170 // out: [num_rois, c, pooled_h, pooled_w] 171 out_shape->clear(); 172 out_shape->push_back( 173 Shape4(bshape[0], param_.output_dim, param_.pooled_size, param_.pooled_size)); 174 return true; 175 } 176 InferType(std::vector<int> * in_type,std::vector<int> * out_type,std::vector<int> * aux_type)177 bool InferType(std::vector<int> *in_type, 178 std::vector<int> *out_type, 179 std::vector<int> *aux_type) const override { 180 CHECK_EQ(in_type->size(), 2); 181 int dtype = (*in_type)[0]; 182 CHECK_EQ(dtype, (*in_type)[1]); 183 CHECK_NE(dtype, -1) << "Input must have specified type"; 184 185 out_type->clear(); 186 out_type->push_back(dtype); 187 return true; 188 } 189 Copy()190 OperatorProperty* Copy() const override { 191 PSROIPoolingProp* psroi_pooling_sym = new PSROIPoolingProp(); 192 psroi_pooling_sym->param_ = this->param_; 193 return psroi_pooling_sym; 194 } 195 TypeString()196 std::string TypeString() const override { 197 return "_contrib_PSROIPooling"; 198 } 199 200 // decalre dependency and inplace optimization options DeclareBackwardDependency(const std::vector<int> & out_grad,const std::vector<int> & in_data,const std::vector<int> & out_data)201 std::vector<int> DeclareBackwardDependency( 202 const std::vector<int> &out_grad, 203 const std::vector<int> &in_data, 204 const std::vector<int> &out_data) const override { 205 return {out_grad[psroipool::kOut], in_data[psroipool::kBox]}; 206 } 207 208 CreateOperator(Context ctx)209 Operator* CreateOperator(Context ctx) const override { 210 LOG(FATAL) << "Not Implemented."; 211 return nullptr; 212 } 213 214 Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, 215 std::vector<int> *in_type) const override; 216 217 218 private: 219 PSROIPoolingParam param_; 220 }; // class PSROIPoolingProp 221 #endif 222 } // namespace op 223 } // namespace mxnet 224 #endif // MXNET_OPERATOR_CONTRIB_PSROI_POOLING_INL_H_ 225