/*!
 * Copyright (c) 2017 by Contributors
 * Copyright (c) 2017 Microsoft
 * Licensed under The MIT License [see LICENSE for details]
 * \file psroi_pooling-inl.h
 * \brief psroi pooling operator and symbol
 * \author Yi Li, Tairui Chen, Guodong Zhang, Jifeng Dai
 *
 * Code from https://github.com/msracver/Deformable-ConvNets/blob/d51075968c5fd40b37a55d20c8e945c1f181d529/rfcn/operator_cxx/psroi_pooling-inl.h
*/
#ifndef MXNET_OPERATOR_CONTRIB_PSROI_POOLING_INL_H_
#define MXNET_OPERATOR_CONTRIB_PSROI_POOLING_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <cfloat>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "../mshadow_op.h"
#include "../operator_common.h"


namespace mxnet {
namespace op {

// Declare enumeration of input order to make code more intuitive.
// These enums are only visible within this header.
namespace psroipool {
enum PSROIPoolingOpInputs {kData, kBox};
enum PSROIPoolingOpOutputs {kOut};
}  // namespace psroipool

struct PSROIPoolingParam : public dmlc::Parameter<PSROIPoolingParam> {
  float spatial_scale;
  int output_dim;
  int pooled_size;
  int group_size;
  DMLC_DECLARE_PARAMETER(PSROIPoolingParam) {
    DMLC_DECLARE_FIELD(spatial_scale).set_range(0.0, 1.0)
    .describe("Ratio of input feature map height (or width) to raw image height (or width). "
    "Equals the reciprocal of the total stride in the convolutional layers.");
    DMLC_DECLARE_FIELD(output_dim).describe("Number of output channels per ROI.");
    DMLC_DECLARE_FIELD(pooled_size).describe("Height and width of the pooled output.");
    DMLC_DECLARE_FIELD(group_size).set_default(0)
    .describe("Number of groups along each spatial axis; defaults to pooled_size when 0.");
  }
};
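
// For illustration only (hypothetical values, not enforced here): with the
// R-FCN settings from the paper, a 20-class detector might use
// output_dim = 21 (20 classes + background), pooled_size = group_size = 7,
// and spatial_scale = 1.0f / 16 for a backbone with a total stride of 16.
// The input feature map then needs output_dim * group_size * group_size
// = 1029 channels, since each (class, bin) pair reads its own score map.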

template<typename xpu, typename DType>
class PSROIPoolingOp : public Operator {
 public:
  explicit PSROIPoolingOp(PSROIPoolingParam p) {
    this->param_ = p;
  }

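  // Forward pass (descriptive note, not from the original source): consumes
  // data of shape [batch, output_dim * group_size^2, H, W] and rois of shape
  // [num_rois, 5] holding (batch_index, x1, y1, x2, y2) in raw-image
  // coordinates, and writes out of shape
  // [num_rois, output_dim, pooled_size, pooled_size]. The pooling itself is
  // delegated to the device-specific PSROIPoolForward kernel.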
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    CHECK_EQ(in_data.size(), 2);
    CHECK_EQ(out_data.size(), 1);
    CHECK_EQ(out_data[psroipool::kOut].shape_[0], in_data[psroipool::kBox].shape_[0]);
    Stream<xpu> *s = ctx.get_stream<xpu>();

    Tensor<xpu, 4, DType> data = in_data[psroipool::kData].get<xpu, 4, DType>(s);
    Tensor<xpu, 2, DType> bbox = in_data[psroipool::kBox].get<xpu, 2, DType>(s);
    Tensor<xpu, 4, DType> out = out_data[psroipool::kOut].get<xpu, 4, DType>(s);
    CHECK_EQ(data.CheckContiguous(), true);
    CHECK_EQ(bbox.CheckContiguous(), true);
    CHECK_EQ(out.CheckContiguous(), true);
    out = -FLT_MAX;
    PSROIPoolForward(out, data, bbox, param_.spatial_scale, param_.output_dim, param_.group_size);
  }

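  // Backward pass (descriptive note, not from the original source): gradients
  // are accumulated into the data input only; the roi gradient is always
  // zeroed, since this implementation does not differentiate through the box
  // coordinates.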
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    CHECK_EQ(in_data.size(), 2);
    CHECK_EQ(out_data.size(), 1);
    CHECK_EQ(out_grad[psroipool::kOut].shape_[0], in_data[psroipool::kBox].shape_[0]);
    CHECK_NE(req[psroipool::kData], kWriteInplace) <<
      "PSROIPooling: Backward doesn't support kWriteInplace.";
    CHECK_NE(req[psroipool::kBox], kWriteInplace) <<
      "PSROIPooling: Backward doesn't support kWriteInplace.";
    Stream<xpu> *s = ctx.get_stream<xpu>();

    Tensor<xpu, 4, DType> grad_out = out_grad[psroipool::kOut].get<xpu, 4, DType>(s);
    Tensor<xpu, 2, DType> bbox = in_data[psroipool::kBox].get<xpu, 2, DType>(s);
    Tensor<xpu, 4, DType> grad_in = in_grad[psroipool::kData].get<xpu, 4, DType>(s);
    Tensor<xpu, 2, DType> grad_roi = in_grad[psroipool::kBox].get<xpu, 2, DType>(s);

    CHECK_EQ(grad_out.CheckContiguous(), true);
    CHECK_EQ(bbox.CheckContiguous(), true);
    CHECK_EQ(grad_in.CheckContiguous(), true);

    if (kAddTo == req[psroipool::kData] || kWriteTo == req[psroipool::kData]) {
      if (kWriteTo == req[psroipool::kData]) {
        grad_in = 0.0f;
      }
      PSROIPoolBackwardAcc(grad_in, grad_out, bbox, param_.spatial_scale,
                           param_.output_dim, param_.group_size);
    }
    if (kWriteTo == req[psroipool::kBox]) {
      grad_roi = 0.0f;
    }
  }

 private:
  PSROIPoolingParam param_;
};  // class PSROIPoolingOp
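
// For reference only: once registered, this operator surfaces in the Python
// frontend under the name given by TypeString() below. A hypothetical call
// (values chosen for illustration, assuming the R-FCN defaults) might look like:
//   out = mx.sym.contrib.PSROIPooling(data=feat, rois=rois,
//                                     spatial_scale=0.0625, output_dim=21,
//                                     pooled_size=7, group_size=7)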

// Declare Factory function, used for dispatch specialization
template<typename xpu>
Operator* CreateOp(PSROIPoolingParam param, int dtype);

#if DMLC_USE_CXX11
class PSROIPoolingProp : public OperatorProperty {
 public:
  std::vector<std::string> ListArguments() const override {
    return {"data", "rois"};
  }

  std::vector<std::string> ListOutputs() const override {
    return {"output"};
  }

  int NumOutputs() const override {
    return 1;
  }

  int NumVisibleOutputs() const override {
    return 1;
  }

  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    param_.Init(kwargs);
    if (param_.group_size == 0) {
      param_.group_size = param_.pooled_size;
    }
  }

  std::map<std::string, std::string> GetParams() const override {
    return param_.__DICT__();
  }

  bool InferShape(mxnet::ShapeVector *in_shape,
                  mxnet::ShapeVector *out_shape,
                  mxnet::ShapeVector *aux_shape) const override {
    using namespace mshadow;
    CHECK_EQ(in_shape->size(), 2) << "Input:[data, rois]";

    // data: [batch_size, c, h, w]
    mxnet::TShape dshape = in_shape->at(psroipool::kData);
    CHECK_EQ(dshape.ndim(), 4) << "data should be a 4D tensor";

    // bbox: [num_rois, 5]
    mxnet::TShape bshape = in_shape->at(psroipool::kBox);
    CHECK_EQ(bshape.ndim(), 2) << "bbox should be a 2D tensor of shape [batch, 5]";
    CHECK_EQ(bshape[1], 5) << "bbox should be a 2D tensor of shape [batch, 5]";

    // out: [num_rois, c, pooled_h, pooled_w]
    out_shape->clear();
    out_shape->push_back(
         Shape4(bshape[0], param_.output_dim, param_.pooled_size, param_.pooled_size));
    return true;
  }
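
  // Worked example (illustrative numbers, not from the original source): for a
  // 600x1000 image with spatial_scale = 1/16, output_dim = 21, and
  // pooled_size = group_size = 7, data would be [1, 1029, 38, 63]
  // (1029 = 21 * 7 * 7); with 300 proposals, rois is [300, 5] and the
  // inferred output shape is [300, 21, 7, 7].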

  bool InferType(std::vector<int> *in_type,
                 std::vector<int> *out_type,
                 std::vector<int> *aux_type) const override {
    CHECK_EQ(in_type->size(), 2);
    int dtype = (*in_type)[0];
    CHECK_EQ(dtype, (*in_type)[1]);
    CHECK_NE(dtype, -1) << "Input must have specified type";

    out_type->clear();
    out_type->push_back(dtype);
    return true;
  }

  OperatorProperty* Copy() const override {
    PSROIPoolingProp* psroi_pooling_sym = new PSROIPoolingProp();
    psroi_pooling_sym->param_ = this->param_;
    return psroi_pooling_sym;
  }

  std::string TypeString() const override {
    return "_contrib_PSROIPooling";
  }

  // Declare dependency and inplace optimization options. Backward needs only
  // the output gradient and the roi coordinates: each pooled value is an
  // average over a bin determined by the box alone, so the input data values
  // are not required to compute the gradient.
  std::vector<int> DeclareBackwardDependency(
    const std::vector<int> &out_grad,
    const std::vector<int> &in_data,
    const std::vector<int> &out_data) const override {
    return {out_grad[psroipool::kOut], in_data[psroipool::kBox]};
  }

  Operator* CreateOperator(Context ctx) const override {
    LOG(FATAL) << "Not Implemented.";
    return nullptr;
  }

  Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
                             std::vector<int> *in_type) const override;

 private:
  PSROIPoolingParam param_;
};  // class PSROIPoolingProp
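
// Note (assumption, not stated in this header): by the usual "-inl.h"
// convention, CreateOperatorEx is defined alongside the operator registration
// in the corresponding psroi_pooling.cc / psroi_pooling.cu source files,
// which is where the "_contrib_PSROIPooling" name is registered.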
#endif  // DMLC_USE_CXX11
}  // namespace op
}  // namespace mxnet
#endif  // MXNET_OPERATOR_CONTRIB_PSROI_POOLING_INL_H_