1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file mrcnn_mask_target-inl.h
22  * \brief Mask-RCNN target generator
23  * \author Serge Panev
24  */
25 
26 
27 #ifndef MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
28 #define MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
29 
30 #include <mxnet/operator.h>
31 #include <vector>
32 #include "../operator_common.h"
33 #include "../mshadow_op.h"
34 #include "../tensor/init_op.h"
35 
36 namespace mxnet {
37 namespace op {
38 
/*!
 * \brief Positional indices of the operator's inputs and outputs.
 *
 * NOTE(review): the enum type names mention ROIAlign although they are
 * declared for the Mask-RCNN target operator; they look inherited from the
 * ROIAlign operator - confirm before renaming.
 */
namespace mrcnn_index {
  // Input order; matches the "[rois, gt_masks, matches, cls_targets]" list
  // quoted in the MRCNNMaskTargetShape input-count check.
  enum ROIAlignOpInputs {kRoi, kGtMask, kMatches, kClasses};
  // Output order; both outputs are given the same shape and dtype by the
  // inference functions in this header.
  enum ROIAlignOpOutputs {kMask, kMaskClasses};
}  // namespace mrcnn_index
43 
/*!
 * \brief Parameters of the Mask-RCNN mask target operator.
 *
 * Registered through the dmlc parameter macros below. num_rois, num_classes
 * and mask_size declare no default; sample_ratio and aligned do.
 */
struct MRCNNMaskTargetParam : public dmlc::Parameter<MRCNNMaskTargetParam> {
  // Number of sampled RoIs. Not read by MRCNNMaskTargetShape in this header,
  // which takes the RoI count from the rois input instead.
  int num_rois;
  // Number of classes; becomes the third dimension of both output shapes.
  int num_classes;
  // Sampling ratio of ROI align (default 2; -1 selects the adaptive size).
  int sample_ratio;
  // Center-aligned ROIAlign switch (default false).
  bool aligned;
  // Pooled mask size (h, w); must have exactly two non-zero entries.
  mxnet::TShape mask_size;

  DMLC_DECLARE_PARAMETER(MRCNNMaskTargetParam) {
    DMLC_DECLARE_FIELD(num_rois)
    .describe("Number of sampled RoIs.");
    DMLC_DECLARE_FIELD(num_classes)
    .describe("Number of classes.");
    DMLC_DECLARE_FIELD(mask_size)
    .set_expect_ndim(2).enforce_nonzero()
    .describe("Size of the pooled masks height and width: (h, w).");
    DMLC_DECLARE_FIELD(sample_ratio).set_default(2)
    .describe("Sampling ratio of ROI align. Set to -1 to use adaptative size.");
    DMLC_DECLARE_FIELD(aligned).set_default(false)
    .describe("Center-aligned ROIAlign introduced in Detectron2. "
    "To enable, set aligned to True.");
  }
};
66 
MRCNNMaskTargetShape(const NodeAttrs & attrs,std::vector<mxnet::TShape> * in_shape,std::vector<mxnet::TShape> * out_shape)67 inline bool MRCNNMaskTargetShape(const NodeAttrs& attrs,
68                                  std::vector<mxnet::TShape>* in_shape,
69                                  std::vector<mxnet::TShape>* out_shape) {
70   using namespace mshadow;
71   const MRCNNMaskTargetParam& param = nnvm::get<MRCNNMaskTargetParam>(attrs.parsed);
72 
73   CHECK_EQ(in_shape->size(), 4) << "Input:[rois, gt_masks, matches, cls_targets]";
74 
75   // (B, N, 4)
76   mxnet::TShape tshape = in_shape->at(mrcnn_index::kRoi);
77   CHECK_EQ(tshape.ndim(), 3) << "rois should be a 2D tensor of shape [batch, rois, 4]";
78   CHECK_EQ(tshape[2], 4) << "rois should be a 2D tensor of shape [batch, rois, 4]";
79   auto batch_size = tshape[0];
80   auto num_rois = tshape[1];
81 
82   // (B, M, H, W)
83   tshape = in_shape->at(mrcnn_index::kGtMask);
84   CHECK_EQ(tshape.ndim(), 4) << "gt_masks should be a 4D tensor";
85   CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs.";
86 
87   // (B, N)
88   tshape = in_shape->at(mrcnn_index::kMatches);
89   CHECK_EQ(tshape.ndim(), 2) << "matches should be a 2D tensor";
90   CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs.";
91 
92   // (B, N)
93   tshape = in_shape->at(mrcnn_index::kClasses);
94   CHECK_EQ(tshape.ndim(), 2) << "matches should be a 2D tensor";
95   CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs.";
96 
97   // out: 2 * (B, N, C, MS, MS)
98   auto oshape = Shape5(batch_size, num_rois, param.num_classes,
99                        param.mask_size[0], param.mask_size[1]);
100   out_shape->clear();
101   out_shape->push_back(oshape);
102   out_shape->push_back(oshape);
103   return true;
104 }
105 
MRCNNMaskTargetType(const NodeAttrs & attrs,std::vector<int> * in_type,std::vector<int> * out_type)106 inline bool MRCNNMaskTargetType(const NodeAttrs& attrs,
107                                 std::vector<int>* in_type,
108                                 std::vector<int>* out_type) {
109   CHECK_EQ(in_type->size(), 4);
110   int dtype = (*in_type)[1];
111   CHECK_NE(dtype, -1) << "Input must have specified type";
112 
113   out_type->clear();
114   out_type->push_back(dtype);
115   out_type->push_back(dtype);
116   return true;
117 }
118 
/*!
 * \brief Device-templated entry point that MRCNNMaskTargetCompute forwards to.
 *
 * Only declared in this header; no definition is visible here.
 * NOTE(review): presumably specialized per device in the matching .cc/.cu
 * sources - confirm.
 *
 * \param param parsed operator parameters
 * \param inputs operator inputs
 * \param outputs operator outputs
 * \param s stream of the target device
 */
template<typename xpu>
void MRCNNMaskTargetRun(const MRCNNMaskTargetParam& param, const std::vector<TBlob> &inputs,
                        const std::vector<TBlob> &outputs, mshadow::Stream<xpu> *s);
122 
123 template<typename xpu>
MRCNNMaskTargetCompute(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)124 void MRCNNMaskTargetCompute(const nnvm::NodeAttrs& attrs,
125                             const OpContext &ctx,
126                             const std::vector<TBlob> &inputs,
127                             const std::vector<OpReqType> &req,
128                             const std::vector<TBlob> &outputs) {
129   auto s = ctx.get_stream<xpu>();
130   const auto& p = dmlc::get<MRCNNMaskTargetParam>(attrs.parsed);
131   MRCNNMaskTargetRun<xpu>(p, inputs, outputs, s);
132 }
133 
134 }  // namespace op
135 }  // namespace mxnet
136 
137 #endif  // MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
138