1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file multibox_detection-inl.h 22 * \brief post-process multibox detection predictions 23 * \author Joshua Zhang 24 */ 25 #ifndef MXNET_OPERATOR_CONTRIB_MULTIBOX_DETECTION_INL_H_ 26 #define MXNET_OPERATOR_CONTRIB_MULTIBOX_DETECTION_INL_H_ 27 #include <dmlc/logging.h> 28 #include <dmlc/parameter.h> 29 #include <mxnet/operator.h> 30 #include <mxnet/base.h> 31 #include <map> 32 #include <vector> 33 #include <string> 34 #include <utility> 35 #include <valarray> 36 #include "../operator_common.h" 37 38 namespace mxnet { 39 namespace op { 40 namespace mboxdet_enum { 41 enum MultiBoxDetectionOpInputs {kClsProb, kLocPred, kAnchor}; 42 enum MultiBoxDetectionOpOutputs {kOut}; 43 enum MultiBoxDetectionOpResource {kTempSpace}; 44 } // namespace mboxdet_enum 45 46 struct MultiBoxDetectionParam : public dmlc::Parameter<MultiBoxDetectionParam> { 47 bool clip; 48 float threshold; 49 int background_id; 50 float nms_threshold; 51 bool force_suppress; 52 int keep_topk; 53 int nms_topk; 54 mxnet::Tuple<float> variances; DMLC_DECLARE_PARAMETERMultiBoxDetectionParam55 DMLC_DECLARE_PARAMETER(MultiBoxDetectionParam) { 56 DMLC_DECLARE_FIELD(clip).set_default(true) 57 .describe("Clip out-of-boundary boxes."); 58 DMLC_DECLARE_FIELD(threshold).set_default(0.01f) 59 .describe("Threshold to be a positive prediction."); 60 DMLC_DECLARE_FIELD(background_id).set_default(0) 61 .describe("Background id."); 62 DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5f) 63 .describe("Non-maximum suppression threshold."); 64 DMLC_DECLARE_FIELD(force_suppress).set_default(false) 65 .describe("Suppress all detections regardless of class_id."); 66 DMLC_DECLARE_FIELD(variances).set_default({0.1f, 0.1f, 0.2f, 0.2f}) 67 .describe("Variances to be decoded from box regression output."); 68 DMLC_DECLARE_FIELD(nms_topk).set_default(-1) 69 .describe("Keep maximum top k detections before nms, -1 for no limit."); 70 } 71 }; // struct MultiBoxDetectionParam 72 73 template<typename xpu, typename DType> 74 class MultiBoxDetectionOp : public Operator { 75 public: MultiBoxDetectionOp(MultiBoxDetectionParam param)76 explicit MultiBoxDetectionOp(MultiBoxDetectionParam param) { 77 this->param_ = param; 78 } 79 Forward(const OpContext & ctx,const std::vector<TBlob> & in_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & out_data,const std::vector<TBlob> & aux_args)80 virtual void Forward(const OpContext &ctx, 81 const std::vector<TBlob> &in_data, 82 const std::vector<OpReqType> &req, 83 const std::vector<TBlob> &out_data, 84 const std::vector<TBlob> &aux_args) { 85 using namespace mshadow; 86 using namespace mshadow::expr; 87 CHECK_EQ(in_data.size(), 3U) << "Input: [cls_prob, loc_pred, anchor]"; 88 mxnet::TShape ashape = in_data[mboxdet_enum::kAnchor].shape_; 89 CHECK_EQ(out_data.size(), 1U); 90 91 Stream<xpu> *s = ctx.get_stream<xpu>(); 92 Tensor<xpu, 3, DType> cls_prob = in_data[mboxdet_enum::kClsProb] 93 .get<xpu, 3, DType>(s); 94 Tensor<xpu, 2, DType> loc_pred = in_data[mboxdet_enum::kLocPred] 95 .get<xpu, 2, DType>(s); 96 Tensor<xpu, 2, DType> anchors = in_data[mboxdet_enum::kAnchor] 97 .get_with_shape<xpu, 2, DType>(Shape2(ashape[1], 4), s); 98 Tensor<xpu, 3, DType> out = out_data[mboxdet_enum::kOut] 99 .get<xpu, 3, DType>(s); 100 Tensor<xpu, 3, DType> temp_space = ctx.requested[mboxdet_enum::kTempSpace] 101 .get_space_typed<xpu, 3, DType>(out.shape_, s); 102 out = -1.f; 103 MultiBoxDetectionForward(out, cls_prob, loc_pred, anchors, temp_space, 104 param_.threshold, param_.clip, param_.variances, param_.nms_threshold, 105 param_.force_suppress, param_.nms_topk); 106 } 107 Backward(const OpContext & ctx,const std::vector<TBlob> & out_grad,const std::vector<TBlob> & in_data,const std::vector<TBlob> & out_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & in_grad,const std::vector<TBlob> & aux_states)108 virtual void Backward(const OpContext &ctx, 109 const std::vector<TBlob> &out_grad, 110 const std::vector<TBlob> &in_data, 111 const std::vector<TBlob> &out_data, 112 const std::vector<OpReqType> &req, 113 const std::vector<TBlob> &in_grad, 114 const std::vector<TBlob> &aux_states) { 115 using namespace mshadow; 116 using namespace mshadow::expr; 117 Stream<xpu> *s = ctx.get_stream<xpu>(); 118 Tensor<xpu, 2, DType> gradc = in_grad[mboxdet_enum::kClsProb].FlatTo2D<xpu, DType>(s); 119 Tensor<xpu, 2, DType> gradl = in_grad[mboxdet_enum::kLocPred].FlatTo2D<xpu, DType>(s); 120 Tensor<xpu, 2, DType> grada = in_grad[mboxdet_enum::kAnchor].FlatTo2D<xpu, DType>(s); 121 gradc = 0.f; 122 gradl = 0.f; 123 grada = 0.f; 124 } 125 126 private: 127 MultiBoxDetectionParam param_; 128 }; // class MultiBoxDetectionOp 129 130 template<typename xpu> 131 Operator *CreateOp(MultiBoxDetectionParam, int dtype); 132 133 #if DMLC_USE_CXX11 134 class MultiBoxDetectionProp : public OperatorProperty { 135 public: Init(const std::vector<std::pair<std::string,std::string>> & kwargs)136 void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { 137 param_.Init(kwargs); 138 } 139 GetParams()140 std::map<std::string, std::string> GetParams() const override { 141 return param_.__DICT__(); 142 } 143 ListArguments()144 std::vector<std::string> ListArguments() const override { 145 return {"cls_prob", "loc_pred", "anchor"}; 146 } 147 InferShape(mxnet::ShapeVector * in_shape,mxnet::ShapeVector * out_shape,mxnet::ShapeVector * aux_shape)148 bool InferShape(mxnet::ShapeVector *in_shape, 149 mxnet::ShapeVector *out_shape, 150 mxnet::ShapeVector *aux_shape) const override { 151 using namespace mshadow; 152 CHECK_EQ(in_shape->size(), 3U) << "Inputs: [cls_prob, loc_pred, anchor]"; 153 mxnet::TShape cshape = in_shape->at(mboxdet_enum::kClsProb); 154 mxnet::TShape lshape = in_shape->at(mboxdet_enum::kLocPred); 155 mxnet::TShape ashape = in_shape->at(mboxdet_enum::kAnchor); 156 CHECK_EQ(cshape.ndim(), 3U) << "Provided: " << cshape; 157 CHECK_EQ(lshape.ndim(), 2U) << "Provided: " << lshape; 158 CHECK_EQ(ashape.ndim(), 3U) << "Provided: " << ashape; 159 CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch"; 160 CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc"; 161 CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0"; 162 CHECK_EQ(ashape[2], 4U); 163 mxnet::TShape oshape = mxnet::TShape(3, -1); 164 oshape[0] = cshape[0]; 165 oshape[1] = ashape[1]; 166 oshape[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] 167 out_shape->clear(); 168 out_shape->push_back(oshape); 169 return true; 170 } 171 Copy()172 OperatorProperty* Copy() const override { 173 auto ptr = new MultiBoxDetectionProp(); 174 ptr->param_ = param_; 175 return ptr; 176 } 177 TypeString()178 std::string TypeString() const override { 179 return "_contrib_MultiBoxDetection"; 180 } 181 ForwardResource(const mxnet::ShapeVector & in_shape)182 std::vector<ResourceRequest> ForwardResource( 183 const mxnet::ShapeVector &in_shape) const override { 184 return {ResourceRequest::kTempSpace}; 185 } 186 CreateOperator(Context ctx)187 Operator* CreateOperator(Context ctx) const override { 188 LOG(FATAL) << "Not implemented"; 189 return nullptr; 190 } 191 192 Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, 193 std::vector<int> *in_type) const override; 194 195 private: 196 MultiBoxDetectionParam param_; 197 }; // class MultiBoxDetectionProp 198 #endif // DMLC_USE_CXX11 199 200 } // namespace op 201 } // namespace mxnet 202 203 #endif // MXNET_OPERATOR_CONTRIB_MULTIBOX_DETECTION_INL_H_ 204