1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file mrcnn_mask_target-inl.h
22 * \brief Mask-RCNN target generator
23 * \author Serge Panev
24 */
25
26
27 #ifndef MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
28 #define MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
29
30 #include <mxnet/operator.h>
31 #include <vector>
32 #include "../operator_common.h"
33 #include "../mshadow_op.h"
34 #include "../tensor/init_op.h"
35
36 namespace mxnet {
37 namespace op {
38
/*! \brief Positional slots of the MRCNNMaskTarget operator's inputs and outputs. */
namespace mrcnn_index {
// Input slots, in order: rois, gt_masks, matches, cls_targets.
enum ROIAlignOpInputs {
  kRoi = 0,
  kGtMask = 1,
  kMatches = 2,
  kClasses = 3
};
// Output slots, in order: kMask, kMaskClasses.
enum ROIAlignOpOutputs {
  kMask = 0,
  kMaskClasses = 1
};
}  // namespace mrcnn_index
43
44 struct MRCNNMaskTargetParam : public dmlc::Parameter<MRCNNMaskTargetParam> {
45 int num_rois;
46 int num_classes;
47 int sample_ratio;
48 bool aligned;
49 mxnet::TShape mask_size;
50
DMLC_DECLARE_PARAMETERMRCNNMaskTargetParam51 DMLC_DECLARE_PARAMETER(MRCNNMaskTargetParam) {
52 DMLC_DECLARE_FIELD(num_rois)
53 .describe("Number of sampled RoIs.");
54 DMLC_DECLARE_FIELD(num_classes)
55 .describe("Number of classes.");
56 DMLC_DECLARE_FIELD(mask_size)
57 .set_expect_ndim(2).enforce_nonzero()
58 .describe("Size of the pooled masks height and width: (h, w).");
59 DMLC_DECLARE_FIELD(sample_ratio).set_default(2)
60 .describe("Sampling ratio of ROI align. Set to -1 to use adaptative size.");
61 DMLC_DECLARE_FIELD(aligned).set_default(false)
62 .describe("Center-aligned ROIAlign introduced in Detectron2. "
63 "To enable, set aligned to True.");
64 }
65 };
66
/*!
 * \brief Shape inference for MRCNNMaskTarget.
 *
 * Expected input shapes:
 *   rois        (B, N, 4)
 *   gt_masks    (B, M, H, W)
 *   matches     (B, N)
 *   cls_targets (B, N)
 * Both outputs get shape (B, N, num_classes, mask_size[0], mask_size[1]).
 *
 * \param attrs node attributes holding a parsed MRCNNMaskTargetParam
 * \param in_shape the four input shapes (read only)
 * \param out_shape replaced with the two output shapes
 * \return true (a malformed input aborts via CHECK)
 */
inline bool MRCNNMaskTargetShape(const NodeAttrs& attrs,
                                 std::vector<mxnet::TShape>* in_shape,
                                 std::vector<mxnet::TShape>* out_shape) {
  using namespace mshadow;
  const MRCNNMaskTargetParam& param = nnvm::get<MRCNNMaskTargetParam>(attrs.parsed);

  CHECK_EQ(in_shape->size(), 4U) << "Input:[rois, gt_masks, matches, cls_targets]";

  // rois: (B, N, 4)
  const mxnet::TShape& roi_shape = in_shape->at(mrcnn_index::kRoi);
  CHECK_EQ(roi_shape.ndim(), 3) << "rois should be a 3D tensor of shape [batch, rois, 4]";
  CHECK_EQ(roi_shape[2], 4) << "rois should be a 3D tensor of shape [batch, rois, 4]";
  const auto batch_size = roi_shape[0];
  const auto num_rois = roi_shape[1];

  // gt_masks: (B, M, H, W)
  const mxnet::TShape& mask_shape = in_shape->at(mrcnn_index::kGtMask);
  CHECK_EQ(mask_shape.ndim(), 4) << "gt_masks should be a 4D tensor";
  CHECK_EQ(mask_shape[0], batch_size) << " batch size should be the same for all the inputs.";

  // matches and cls_targets are both (B, N). They must agree with rois on
  // both the batch size and the per-image roi count.
  auto check_per_roi_input = [&](int index, const char* name) {
    const mxnet::TShape& shape = in_shape->at(index);
    CHECK_EQ(shape.ndim(), 2) << name << " should be a 2D tensor";
    CHECK_EQ(shape[0], batch_size) << " batch size should be the same for all the inputs.";
    CHECK_EQ(shape[1], num_rois)
        << name << " should have as many entries per batch element as rois";
  };
  check_per_roi_input(mrcnn_index::kMatches, "matches");
  check_per_roi_input(mrcnn_index::kClasses, "cls_targets");

  // out: 2 * (B, N, C, MS, MS)
  auto oshape = Shape5(batch_size, num_rois, param.num_classes,
                       param.mask_size[0], param.mask_size[1]);
  out_shape->clear();
  out_shape->push_back(oshape);
  out_shape->push_back(oshape);
  return true;
}
105
MRCNNMaskTargetType(const NodeAttrs & attrs,std::vector<int> * in_type,std::vector<int> * out_type)106 inline bool MRCNNMaskTargetType(const NodeAttrs& attrs,
107 std::vector<int>* in_type,
108 std::vector<int>* out_type) {
109 CHECK_EQ(in_type->size(), 4);
110 int dtype = (*in_type)[1];
111 CHECK_NE(dtype, -1) << "Input must have specified type";
112
113 out_type->clear();
114 out_type->push_back(dtype);
115 out_type->push_back(dtype);
116 return true;
117 }
118
119 template<typename xpu>
120 void MRCNNMaskTargetRun(const MRCNNMaskTargetParam& param, const std::vector<TBlob> &inputs,
121 const std::vector<TBlob> &outputs, mshadow::Stream<xpu> *s);
122
123 template<typename xpu>
MRCNNMaskTargetCompute(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)124 void MRCNNMaskTargetCompute(const nnvm::NodeAttrs& attrs,
125 const OpContext &ctx,
126 const std::vector<TBlob> &inputs,
127 const std::vector<OpReqType> &req,
128 const std::vector<TBlob> &outputs) {
129 auto s = ctx.get_stream<xpu>();
130 const auto& p = dmlc::get<MRCNNMaskTargetParam>(attrs.parsed);
131 MRCNNMaskTargetRun<xpu>(p, inputs, outputs, s);
132 }
133
134 } // namespace op
135 } // namespace mxnet
136
137 #endif // MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
138