1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file multibox_detection-inl.h
22  * \brief post-process multibox detection predictions
23  * \author Joshua Zhang
24 */
25 #ifndef MXNET_OPERATOR_CONTRIB_MULTIBOX_DETECTION_INL_H_
26 #define MXNET_OPERATOR_CONTRIB_MULTIBOX_DETECTION_INL_H_
27 #include <dmlc/logging.h>
28 #include <dmlc/parameter.h>
29 #include <mxnet/operator.h>
30 #include <mxnet/base.h>
31 #include <map>
32 #include <vector>
33 #include <string>
34 #include <utility>
35 #include <valarray>
36 #include "../operator_common.h"
37 
38 namespace mxnet {
39 namespace op {
namespace mboxdet_enum {
/*! \brief index of each operator input: class probabilities, location predictions, anchors */
enum MultiBoxDetectionOpInputs {
  kClsProb,
  kLocPred,
  kAnchor
};
/*! \brief index of each operator output */
enum MultiBoxDetectionOpOutputs {
  kOut
};
/*! \brief index of each resource requested by the operator */
enum MultiBoxDetectionOpResource {
  kTempSpace
};
}  // namespace mboxdet_enum
45 
46 struct MultiBoxDetectionParam : public dmlc::Parameter<MultiBoxDetectionParam> {
47   bool clip;
48   float threshold;
49   int background_id;
50   float nms_threshold;
51   bool force_suppress;
52   int keep_topk;
53   int nms_topk;
54   mxnet::Tuple<float> variances;
DMLC_DECLARE_PARAMETERMultiBoxDetectionParam55   DMLC_DECLARE_PARAMETER(MultiBoxDetectionParam) {
56     DMLC_DECLARE_FIELD(clip).set_default(true)
57     .describe("Clip out-of-boundary boxes.");
58     DMLC_DECLARE_FIELD(threshold).set_default(0.01f)
59     .describe("Threshold to be a positive prediction.");
60     DMLC_DECLARE_FIELD(background_id).set_default(0)
61     .describe("Background id.");
62     DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5f)
63     .describe("Non-maximum suppression threshold.");
64     DMLC_DECLARE_FIELD(force_suppress).set_default(false)
65     .describe("Suppress all detections regardless of class_id.");
66     DMLC_DECLARE_FIELD(variances).set_default({0.1f, 0.1f, 0.2f, 0.2f})
67     .describe("Variances to be decoded from box regression output.");
68     DMLC_DECLARE_FIELD(nms_topk).set_default(-1)
69     .describe("Keep maximum top k detections before nms, -1 for no limit.");
70   }
71 };  // struct MultiBoxDetectionParam
72 
73 template<typename xpu, typename DType>
74 class MultiBoxDetectionOp : public Operator {
75  public:
MultiBoxDetectionOp(MultiBoxDetectionParam param)76   explicit MultiBoxDetectionOp(MultiBoxDetectionParam param) {
77     this->param_ = param;
78   }
79 
Forward(const OpContext & ctx,const std::vector<TBlob> & in_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & out_data,const std::vector<TBlob> & aux_args)80   virtual void Forward(const OpContext &ctx,
81                        const std::vector<TBlob> &in_data,
82                        const std::vector<OpReqType> &req,
83                        const std::vector<TBlob> &out_data,
84                        const std::vector<TBlob> &aux_args) {
85      using namespace mshadow;
86      using namespace mshadow::expr;
87      CHECK_EQ(in_data.size(), 3U) << "Input: [cls_prob, loc_pred, anchor]";
88      mxnet::TShape ashape = in_data[mboxdet_enum::kAnchor].shape_;
89      CHECK_EQ(out_data.size(), 1U);
90 
91      Stream<xpu> *s = ctx.get_stream<xpu>();
92      Tensor<xpu, 3, DType> cls_prob = in_data[mboxdet_enum::kClsProb]
93        .get<xpu, 3, DType>(s);
94      Tensor<xpu, 2, DType> loc_pred = in_data[mboxdet_enum::kLocPred]
95        .get<xpu, 2, DType>(s);
96      Tensor<xpu, 2, DType> anchors = in_data[mboxdet_enum::kAnchor]
97        .get_with_shape<xpu, 2, DType>(Shape2(ashape[1], 4), s);
98      Tensor<xpu, 3, DType> out = out_data[mboxdet_enum::kOut]
99        .get<xpu, 3, DType>(s);
100      Tensor<xpu, 3, DType> temp_space = ctx.requested[mboxdet_enum::kTempSpace]
101        .get_space_typed<xpu, 3, DType>(out.shape_, s);
102      out = -1.f;
103      MultiBoxDetectionForward(out, cls_prob, loc_pred, anchors, temp_space,
104        param_.threshold, param_.clip, param_.variances, param_.nms_threshold,
105        param_.force_suppress, param_.nms_topk);
106   }
107 
Backward(const OpContext & ctx,const std::vector<TBlob> & out_grad,const std::vector<TBlob> & in_data,const std::vector<TBlob> & out_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & in_grad,const std::vector<TBlob> & aux_states)108   virtual void Backward(const OpContext &ctx,
109                         const std::vector<TBlob> &out_grad,
110                         const std::vector<TBlob> &in_data,
111                         const std::vector<TBlob> &out_data,
112                         const std::vector<OpReqType> &req,
113                         const std::vector<TBlob> &in_grad,
114                         const std::vector<TBlob> &aux_states) {
115     using namespace mshadow;
116     using namespace mshadow::expr;
117     Stream<xpu> *s = ctx.get_stream<xpu>();
118     Tensor<xpu, 2, DType> gradc = in_grad[mboxdet_enum::kClsProb].FlatTo2D<xpu, DType>(s);
119     Tensor<xpu, 2, DType> gradl = in_grad[mboxdet_enum::kLocPred].FlatTo2D<xpu, DType>(s);
120     Tensor<xpu, 2, DType> grada = in_grad[mboxdet_enum::kAnchor].FlatTo2D<xpu, DType>(s);
121     gradc = 0.f;
122     gradl = 0.f;
123     grada = 0.f;
124 }
125 
126  private:
127   MultiBoxDetectionParam param_;
128 };  // class MultiBoxDetectionOp
129 
130 template<typename xpu>
131 Operator *CreateOp(MultiBoxDetectionParam, int dtype);
132 
133 #if DMLC_USE_CXX11
134 class MultiBoxDetectionProp : public OperatorProperty {
135  public:
Init(const std::vector<std::pair<std::string,std::string>> & kwargs)136   void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
137     param_.Init(kwargs);
138   }
139 
GetParams()140   std::map<std::string, std::string> GetParams() const override {
141     return param_.__DICT__();
142   }
143 
ListArguments()144   std::vector<std::string> ListArguments() const override {
145     return {"cls_prob", "loc_pred", "anchor"};
146   }
147 
InferShape(mxnet::ShapeVector * in_shape,mxnet::ShapeVector * out_shape,mxnet::ShapeVector * aux_shape)148   bool InferShape(mxnet::ShapeVector *in_shape,
149                   mxnet::ShapeVector *out_shape,
150                   mxnet::ShapeVector *aux_shape) const override {
151     using namespace mshadow;
152     CHECK_EQ(in_shape->size(), 3U) << "Inputs: [cls_prob, loc_pred, anchor]";
153     mxnet::TShape cshape = in_shape->at(mboxdet_enum::kClsProb);
154     mxnet::TShape lshape = in_shape->at(mboxdet_enum::kLocPred);
155     mxnet::TShape ashape = in_shape->at(mboxdet_enum::kAnchor);
156     CHECK_EQ(cshape.ndim(), 3U) << "Provided: " << cshape;
157     CHECK_EQ(lshape.ndim(), 2U) << "Provided: " << lshape;
158     CHECK_EQ(ashape.ndim(), 3U) << "Provided: " << ashape;
159     CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch";
160     CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc";
161     CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0";
162     CHECK_EQ(ashape[2], 4U);
163     mxnet::TShape oshape = mxnet::TShape(3, -1);
164     oshape[0] = cshape[0];
165     oshape[1] = ashape[1];
166     oshape[2] = 6;  // [id, prob, xmin, ymin, xmax, ymax]
167     out_shape->clear();
168     out_shape->push_back(oshape);
169     return true;
170   }
171 
Copy()172   OperatorProperty* Copy() const override {
173     auto ptr = new MultiBoxDetectionProp();
174     ptr->param_ = param_;
175     return ptr;
176   }
177 
TypeString()178   std::string TypeString() const override {
179     return "_contrib_MultiBoxDetection";
180   }
181 
ForwardResource(const mxnet::ShapeVector & in_shape)182   std::vector<ResourceRequest> ForwardResource(
183       const mxnet::ShapeVector &in_shape) const override {
184     return {ResourceRequest::kTempSpace};
185   }
186 
CreateOperator(Context ctx)187   Operator* CreateOperator(Context ctx) const override {
188     LOG(FATAL) << "Not implemented";
189     return nullptr;
190   }
191 
192   Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
193                              std::vector<int> *in_type) const override;
194 
195  private:
196   MultiBoxDetectionParam param_;
197 };  // class MultiBoxDetectionProp
198 #endif  // DMLC_USE_CXX11
199 
200 }  // namespace op
201 }  // namespace mxnet
202 
203 #endif  // MXNET_OPERATOR_CONTRIB_MULTIBOX_DETECTION_INL_H_
204