1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 /*!
20 * \file roi_align.cc
21 * \brief roi align operator
22 * \author Hang Zhang, Shesung
23 * Adapted from Caffe2
24 */
#include "./roi_align-inl.h"

#include <algorithm>
#include <cmath>
#include <vector>

27
28 namespace mxnet {
29 namespace op {
30
31 template <typename T>
32 struct PreCalc {
33 int pos1;
34 int pos2;
35 int pos3;
36 int pos4;
37 T w1;
38 T w2;
39 T w3;
40 T w4;
41 };
42
43 template <typename T>
pre_calc_for_bilinear_interpolate(const int height,const int width,const int pooled_height,const int pooled_width,const int iy_upper,const int ix_upper,T roi_start_h,T roi_start_w,T bin_size_h,T bin_size_w,int roi_bin_grid_h,int roi_bin_grid_w,std::vector<PreCalc<T>> * pre_calc)44 void pre_calc_for_bilinear_interpolate(
45 const int height,
46 const int width,
47 const int pooled_height,
48 const int pooled_width,
49 const int iy_upper,
50 const int ix_upper,
51 T roi_start_h,
52 T roi_start_w,
53 T bin_size_h,
54 T bin_size_w,
55 int roi_bin_grid_h,
56 int roi_bin_grid_w,
57 std::vector<PreCalc<T>>* pre_calc) {
58 int pre_calc_index = 0;
59 for (int ph = 0; ph < pooled_height; ph++) {
60 for (int pw = 0; pw < pooled_width; pw++) {
61 for (int iy = 0; iy < iy_upper; iy++) {
62 const T yy = roi_start_h + ph * bin_size_h +
63 static_cast<T>(iy + .5f) * bin_size_h /
64 static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
65 for (int ix = 0; ix < ix_upper; ix++) {
66 const T xx = roi_start_w + pw * bin_size_w +
67 static_cast<T>(ix + .5f) * bin_size_w /
68 static_cast<T>(roi_bin_grid_w);
69
70 T x = xx;
71 T y = yy;
72 // deal with: inverse elements are out of feature map boundary
73 if (y < -1.0 || y > height || x < -1.0 || x > width) {
74 // empty
75 PreCalc<T> pc;
76 pc.pos1 = 0;
77 pc.pos2 = 0;
78 pc.pos3 = 0;
79 pc.pos4 = 0;
80 pc.w1 = 0;
81 pc.w2 = 0;
82 pc.w3 = 0;
83 pc.w4 = 0;
84 pre_calc->at(pre_calc_index) = pc;
85 pre_calc_index += 1;
86 continue;
87 }
88
89 if (y <= 0) {
90 y = 0;
91 }
92 if (x <= 0) {
93 x = 0;
94 }
95
96 int y_low = static_cast<int>(y);
97 int x_low = static_cast<int>(x);
98 int y_high;
99 int x_high;
100
101 if (y_low >= height - 1) {
102 y_high = y_low = height - 1;
103 y = (T)y_low;
104 } else {
105 y_high = y_low + 1;
106 }
107
108 if (x_low >= width - 1) {
109 x_high = x_low = width - 1;
110 x = (T)x_low;
111 } else {
112 x_high = x_low + 1;
113 }
114
115 T ly = y - y_low;
116 T lx = x - x_low;
117 T hy = 1. - ly, hx = 1. - lx;
118 T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
119
120 // save weights and indeces
121 PreCalc<T> pc;
122 pc.pos1 = y_low * width + x_low;
123 pc.pos2 = y_low * width + x_high;
124 pc.pos3 = y_high * width + x_low;
125 pc.pos4 = y_high * width + x_high;
126 pc.w1 = w1;
127 pc.w2 = w2;
128 pc.w3 = w3;
129 pc.w4 = w4;
130 pre_calc->at(pre_calc_index) = pc;
131
132 pre_calc_index += 1;
133 }
134 }
135 }
136 }
137 }
138
139 template <typename T>
ROIAlignForward(const int nthreads,const T * bottom_data,const T & spatial_scale,const bool position_sensitive,const bool continuous_coordinate,const int channels,const int height,const int width,const int pooled_height,const int pooled_width,const int sampling_ratio,const T * bottom_rois,int roi_cols,T * top_data)140 void ROIAlignForward(
141 const int nthreads,
142 const T* bottom_data,
143 const T& spatial_scale,
144 const bool position_sensitive,
145 const bool continuous_coordinate,
146 const int channels,
147 const int height,
148 const int width,
149 const int pooled_height,
150 const int pooled_width,
151 const int sampling_ratio,
152 const T* bottom_rois,
153 int roi_cols,
154 T* top_data) {
155 DCHECK(roi_cols == 4 || roi_cols == 5);
156
157 int n_rois = nthreads / channels / pooled_width / pooled_height;
158 // (n, c, ph, pw) is an element in the pooled output
159 // can be parallelized using omp
160 #pragma omp parallel for \
161 num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
162 for (int n = 0; n < n_rois; n++) {
163 int index_n = n * channels * pooled_width * pooled_height;
164
165 // roi could have 4 or 5 columns
166 const T* offset_bottom_rois = bottom_rois + n * roi_cols;
167 int roi_batch_ind = 0;
168 if (roi_cols == 5) {
169 roi_batch_ind = offset_bottom_rois[0];
170 if (roi_batch_ind < 0) {
171 top_data[n] = 0;
172 continue;
173 }
174 offset_bottom_rois++;
175 }
176
177 // Do not using rounding; this implementation detail is critical
178 T roi_offset = continuous_coordinate ? static_cast<T>(0.5) : static_cast<T>(0);
179 T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset;
180 T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset;
181 T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset;
182 T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset;
183
184 T roi_width = roi_end_w - roi_start_w;
185 T roi_height = roi_end_h - roi_start_h;
186 if (continuous_coordinate) {
187 CHECK_GT(roi_width, 0.);
188 CHECK_GT(roi_height, 0.);
189 } else { // backward compatiblity
190 // Force malformed ROIs to be 1x1
191 roi_width = std::max(roi_width, (T)1.);
192 roi_height = std::max(roi_height, (T)1.);
193 }
194 T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
195 T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
196
197 // We use roi_bin_grid to sample the grid and mimic integral
198 int roi_bin_grid_h = (sampling_ratio > 0)
199 ? sampling_ratio
200 : std::ceil(roi_height / pooled_height); // e.g., = 2
201 int roi_bin_grid_w =
202 (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_width / pooled_width);
203
204 // We do average (integral) pooling inside a bin
205 const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
206
207 // we want to precalculate indeces and weights shared by all chanels,
208 // this is the key point of optimiation
209 std::vector<PreCalc<T>> pre_calc(
210 roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
211 pre_calc_for_bilinear_interpolate(
212 height,
213 width,
214 pooled_height,
215 pooled_width,
216 roi_bin_grid_h,
217 roi_bin_grid_w,
218 roi_start_h,
219 roi_start_w,
220 bin_size_h,
221 bin_size_w,
222 roi_bin_grid_h,
223 roi_bin_grid_w,
224 &pre_calc);
225
226 for (int c = 0; c < channels; c++) {
227 int index_n_c = index_n + c * pooled_width * pooled_height;
228 int pre_calc_index = 0;
229
230 for (int ph = 0; ph < pooled_height; ph++) {
231 for (int pw = 0; pw < pooled_width; pw++) {
232 int index = index_n_c + ph * pooled_width + pw;
233
234 int c_unpooled = c;
235 int channels_unpooled = channels;
236 if (position_sensitive) {
237 c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw;
238 channels_unpooled = channels * pooled_height * pooled_width;
239 }
240 const T* offset_bottom_data =
241 bottom_data + (roi_batch_ind * channels_unpooled + c_unpooled)
242 * height * width;
243 T output_val = 0.;
244 for (int iy = 0; iy < roi_bin_grid_h; iy++) {
245 for (int ix = 0; ix < roi_bin_grid_w; ix++) {
246 PreCalc<T> pc = pre_calc[pre_calc_index];
247 output_val += pc.w1 * offset_bottom_data[pc.pos1] +
248 pc.w2 * offset_bottom_data[pc.pos2] +
249 pc.w3 * offset_bottom_data[pc.pos3] +
250 pc.w4 * offset_bottom_data[pc.pos4];
251
252 pre_calc_index += 1;
253 }
254 }
255 output_val /= count;
256
257 top_data[index] = output_val;
258 } // for pw
259 } // for ph
260 } // for c
261 } // for n
262 }
263
264
/*!
 * \brief Compute the four bilinear weights and neighbour indices needed to
 *        scatter the gradient of one sampling point (x, y).
 *
 * Out-of-range samples get zero weights and indices of -1 so the caller can
 * skip them.  In-range coordinates are clamped to the feature-map border.
 */
template <typename T>
void bilinear_interpolate_gradient(
    const int height,
    const int width,
    T y,
    T x,
    T* w1,
    T* w2,
    T* w3,
    T* w4,
    int* x_low,
    int* x_high,
    int* y_low,
    int* y_high,
    const int /*index*/ /* index for debug only*/) {
  // Inverse elements that fall outside the feature map contribute nothing.
  const bool out_of_range = y < -1.0 || y > height || x < -1.0 || x > width;
  if (out_of_range) {
    *w1 = *w2 = *w3 = *w4 = 0.;
    *x_low = *x_high = *y_low = *y_high = -1;
    return;
  }

  y = (y <= 0) ? static_cast<T>(0) : y;
  x = (x <= 0) ? static_cast<T>(0) : x;

  *y_low = static_cast<int>(y);
  *x_low = static_cast<int>(x);
  *y_high = *y_low + 1;
  *x_high = *x_low + 1;

  // Clamp the 2x2 neighbourhood to the last valid row/column.
  if (*y_low >= height - 1) {
    *y_high = *y_low = height - 1;
    y = static_cast<T>(*y_low);
  }
  if (*x_low >= width - 1) {
    *x_high = *x_low = width - 1;
    x = static_cast<T>(*x_low);
  }

  const T ly = y - *y_low;
  const T lx = x - *x_low;
  const T hy = 1. - ly;
  const T hx = 1. - lx;

  *w1 = hy * hx;
  *w2 = hy * lx;
  *w3 = ly * hx;
  *w4 = ly * lx;
}
320
// Accumulate `val` into `*address`.  A plain (non-atomic) add is sufficient
// here because the CPU backward pass runs single-threaded.
template <class T>
inline void add(const T& val, T* address) {
  const T updated = *address + val;
  *address = updated;
}
325
/*!
 * \brief CPU backward pass of ROIAlign: scatters the gradient of each pooled
 *        output element back to the four feature-map pixels that contributed
 *        to each of its bilinear sampling points.
 *
 * \param nthreads              total number of pooled output elements
 *                              (num_rois * channels * pooled_h * pooled_w)
 * \param top_diff              gradient w.r.t. the pooled output
 * \param spatial_scale         scale mapping ROI coords to feature-map coords
 * \param position_sensitive    if true, use PSROIAlign-style channel mapping
 * \param continuous_coordinate if true, use the half-pixel coordinate
 *                              convention (no 1x1 clamping of malformed ROIs)
 * \param sampling_ratio        samples per bin per axis; <= 0 means adaptive
 * \param bottom_diff           gradient w.r.t. the input feature map;
 *                              accumulated into, so the caller must pre-fill
 * \param bottom_rois           ROI array, rois_cols values per ROI
 * \param rois_cols             4 ([x1,y1,x2,y2]) or 5 ([batch_ind,x1,y1,x2,y2])
 */
template <typename T>
void ROIAlignBackward(
    const int nthreads,
    const T* top_diff,
    const int /*num_rois*/,
    const T& spatial_scale,
    const bool position_sensitive,
    const bool continuous_coordinate,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio,
    T* bottom_diff,
    const T* bottom_rois,
    int rois_cols) {
  DCHECK(rois_cols == 4 || rois_cols == 5);

  for (int index = 0; index < nthreads; index++) {
    // Decompose the flat index into (n, c, ph, pw) of the pooled output.
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_bottom_rois = bottom_rois + n * rois_cols;
    int roi_batch_ind = 0;
    if (rois_cols == 5) {
      roi_batch_ind = offset_bottom_rois[0];
      // ROIs with a negative batch index were ignored in the forward pass,
      // so they contribute no gradient.
      if (roi_batch_ind < 0) continue;
      offset_bottom_rois++;
    }

    // Do not use rounding; this implementation detail is critical
    T roi_offset = continuous_coordinate ? static_cast<T>(0.5) : static_cast<T>(0);
    T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset;
    T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset;
    T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset;
    T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!continuous_coordinate) {  // backward compatibility
      // Force malformed ROIs to be 1x1
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    }
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // For position-sensitive pooling each output bin reads its own slice
    // of the folded channel dimension (must mirror the forward pass).
    int c_unpooled = c;
    int channels_unpooled = channels;
    if (position_sensitive) {
      c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw;
      channels_unpooled = channels * pooled_height * pooled_width;
    }
    T* offset_bottom_diff =
        bottom_diff + (roi_batch_ind * channels_unpooled + c_unpooled)
        * height * width;

    int top_offset = (n * channels + c) * pooled_height * pooled_width;
    const T* offset_top_diff = top_diff + top_offset;
    const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : std::ceil(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin, so each sampling point
    // receives 1/count of this bin's gradient.
    const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      const T y = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
          static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
            static_cast<T>(roi_bin_grid_w);

        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        // Recompute the forward bilinear weights for this sampling point.
        bilinear_interpolate_gradient(
            height,
            width,
            y,
            x,
            &w1,
            &w2,
            &w3,
            &w4,
            &x_low,
            &x_high,
            &y_low,
            &y_high,
            index);

        T g1 = top_diff_this_bin * w1 / count;
        T g2 = top_diff_this_bin * w2 / count;
        T g3 = top_diff_this_bin * w3 / count;
        T g4 = top_diff_this_bin * w4 / count;

        // Negative indices mark out-of-range samples (see
        // bilinear_interpolate_gradient) — they get no gradient.
        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          // atomic add is not needed for now since it is single threaded
          add(static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
          add(static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
          add(static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
          add(static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
        }  // if
      }  // ix
    }  // iy
  }  // for
}  // ROIAlignBackward
444
445
446 template<typename xpu>
ROIAlignForwardCompute(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & in_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & out_data)447 void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
448 const OpContext& ctx,
449 const std::vector<TBlob>& in_data,
450 const std::vector<OpReqType>& req,
451 const std::vector<TBlob>& out_data) {
452 using namespace mshadow;
453 size_t expected_in = 2;
454 size_t expected_out = 1;
455 CHECK_EQ(in_data.size(), expected_in);
456 CHECK_EQ(out_data.size(), expected_out);
457 CHECK_EQ(out_data[roialign::kOut].shape_[0], in_data[roialign::kBox].shape_[0]);
458
459 const ROIAlignParam& param = nnvm::get<ROIAlignParam>(attrs.parsed);
460
461 const int count = out_data[roialign::kOut].Size();
462 // const int num_rois = in_data[roialign::kBox].size(0);
463 const int channels = out_data[roialign::kOut].size(1); // channels of pooled output
464 const int height = in_data[roialign::kData].size(2);
465 const int width = in_data[roialign::kData].size(3);
466 const int pooled_height = out_data[roialign::kOut].size(2);
467 const int pooled_width = out_data[roialign::kOut].size(3);
468 const int rois_cols = in_data[roialign::kBox].size(1);
469
470 // assume all the data and gradient have the same type
471 MSHADOW_REAL_TYPE_SWITCH(in_data[0].type_flag_, DType, {
472 const DType *bottom_data = in_data[roialign::kData].dptr<DType>();
473 const DType *bottom_rois = in_data[roialign::kBox].dptr<DType>();
474 DType *top_data = out_data[roialign::kOut].dptr<DType>();
475
476 ROIAlignForward<DType>(count, bottom_data, param.spatial_scale, param.position_sensitive,
477 param.aligned, channels, height, width, pooled_height, pooled_width,
478 param.sample_ratio, bottom_rois, rois_cols, top_data);
479 })
480 }
481
/*!
 * \brief CPU backward entry for the _contrib_ROIAlign operator.
 * inputs = [out_grad, rois] (the order set up by FGradient below);
 * outputs = [grad_data, grad_rois].  No gradient is ever computed for the
 * ROI coordinates — grad_rois is only zero-filled when requested.
 */
template<typename xpu>
void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
  using namespace mshadow;

  CHECK_EQ(inputs.size(), 2);
  CHECK_EQ(outputs.size(), 2);
  // the order here relates to the order in ROIAlignGrad
  std::vector<TBlob> out_grad(1, inputs[0]);
  std::vector<TBlob> in_data(1, inputs[1]);
  // std::vector<TBlob> out_data(1, inputs[2]);

  CHECK_EQ(out_grad[0].shape_[0], in_data[0].shape_[0]);
  CHECK_NE(req[0], kWriteInplace) <<
    "ROIAlign: Backward doesn't support kWriteInplace.";
  CHECK_NE(req[1], kWriteInplace) <<
    "ROIAlign: Backward doesn't support kWriteInplace.";

  const ROIAlignParam& param = nnvm::get<ROIAlignParam>(attrs.parsed);

  const int count = out_grad[0].Size();
  const int num_rois = in_data[0].size(0);
  const int channels = out_grad[0].size(1);  // channels of pooled output
  const int height = outputs[0].size(2);
  const int width = outputs[0].size(3);
  const int pooled_height = out_grad[0].size(2);
  const int pooled_width = out_grad[0].size(3);
  const int rois_cols = in_data[0].size(1);

  Stream<cpu> *s = ctx.get_stream<cpu>();
  // assume all the data and gradient have the same type
  MSHADOW_REAL_TYPE_SWITCH(out_grad[0].type_flag_, DType, {
    const DType *top_diff = out_grad[0].dptr<DType>();
    const DType *bottom_rois = in_data[0].dptr<DType>();
    DType *grad_in = outputs[0].dptr<DType>();

    // ROIAlignBackward only accumulates into grad_in, so for kWriteTo the
    // buffer must be zero-filled first; for kAddTo it is used as-is.
    if (kAddTo == req[roialign::kData] || kWriteTo == req[roialign::kData]) {
      if (kWriteTo == req[roialign::kData]) {
        Fill<false>(s, outputs[0], kWriteTo, static_cast<DType>(0));
      }
      ROIAlignBackward<DType>(count, top_diff, num_rois, param.spatial_scale,
          param.position_sensitive, param.aligned, channels, height, width,
          pooled_height, pooled_width, param.sample_ratio, grad_in,
          bottom_rois, rois_cols);
    }
    // ROI coordinates receive no gradient; just zero the output if requested.
    if (kWriteTo == req[roialign::kBox]) {
      Fill<false>(s, outputs[1], kWriteTo, static_cast<DType>(0));
    }
  })
}
535
DMLC_REGISTER_PARAMETER(ROIAlignParam);

NNVM_REGISTER_OP(_contrib_ROIAlign)
.describe(R"code(
This operator takes a 4D feature map as an input array and region proposals as `rois`,
then align the feature map over sub-regions of input and produces a fixed-sized output array.
This operator is typically used in Faster R-CNN & Mask R-CNN networks. If roi batchid is less
than 0, it will be ignored, and the corresponding output will be set to 0.

Different from ROI pooling, ROI Align removes the harsh quantization, properly aligning
the extracted features with the input. RoIAlign computes the value of each sampling point
by bilinear interpolation from the nearby grid points on the feature map. No quantization is
performed on any coordinates involved in the RoI, its bins, or the sampling points.
Bilinear interpolation is used to compute the exact values of the
input features at four regularly sampled locations in each RoI bin.
Then the feature map can be aggregated by avgpooling.


References
----------

He, Kaiming, et al. "Mask R-CNN." ICCV, 2017
)code" ADD_FILELINE)
// inputs: [data, rois]; single output: pooled features per ROI
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
    [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"data", "rois"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
    [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"output"};
})
.set_attr_parser(ParamParser<ROIAlignParam>)
.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
      mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape){
  using namespace mshadow;
  const ROIAlignParam& param = nnvm::get<ROIAlignParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), 2) << "Input:[data, rois]";
  // data: [batch_size, c, h, w]
  mxnet::TShape dshape = in_shape->at(roialign::kData);
  CHECK_EQ(dshape.ndim(), 4) << "data should be a 4D tensor";
  // bbox: [num_rois, 5]
  mxnet::TShape bshape = in_shape->at(roialign::kBox);
  CHECK_EQ(bshape.ndim(), 2) << "bbox should be a 2D tensor of shape [batch, 5]";
  CHECK_EQ(bshape[1], 5) << "bbox should be a 2D tensor of shape [batch, 5]";
  // out: [num_rois, c, pooled_h, pooled_w]
  out_shape->clear();
  if (param.position_sensitive) {
    // Position-sensitive pooling folds a pooled_h*pooled_w grid into the
    // channel dimension, so input channels must be a multiple of it.
    CHECK_EQ(dshape[1] % (param.pooled_size[0]*param.pooled_size[1]), 0) <<
      "Input channels should be divided by pooled_size[0]*pooled_size[1]"
      "when position_sensitive is true.";
    out_shape->push_back(
        Shape4(bshape[0], dshape[1]/param.pooled_size[0]/param.pooled_size[1],
               param.pooled_size[0], param.pooled_size[1]));
  } else {
    out_shape->push_back(
        Shape4(bshape[0], dshape[1], param.pooled_size[0], param.pooled_size[1]));
  }
  return true;
})
.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
      std::vector<int> *in_type, std::vector<int> *out_type) {
  CHECK_EQ(in_type->size(), 2);
  // data and rois must share one dtype, which also becomes the output dtype
  int dtype = (*in_type)[0];
  CHECK_EQ(dtype, (*in_type)[1]);
  CHECK_NE(dtype, -1) << "Input must have specified type";

  out_type->clear();
  out_type->push_back(dtype);
  return true;
})
.set_attr<FCompute>("FCompute<cpu>", ROIAlignForwardCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
  // The gradient node consumes [output gradient, rois]; this order must
  // match what ROIAlignBackwardCompute expects as its inputs.
  std::vector<nnvm::NodeEntry> heads;
  heads.push_back(ograds[roialign::kOut]);
  heads.push_back(n->inputs[roialign::kBox]);
  return MakeGradNode("_backward_ROIAlign", n, heads, n->attrs.dict);
})
.add_argument("data", "NDArray-or-Symbol", "Input data to the pooling operator, a 4D Feature maps")
.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 2D array, "
              "if batchid is less than 0, it will be ignored.")
.add_arguments(ROIAlignParam::__FIELDS__());
620
621
NNVM_REGISTER_OP(_backward_ROIAlign)
// inputs: [out_grad, rois] (see the FGradient heads of _contrib_ROIAlign)
.set_num_inputs(2)
// outputs: [grad_data, grad_rois]
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<ROIAlignParam>)
.set_attr<FCompute>("FCompute<cpu>", ROIAlignBackwardCompute<cpu>);
628
629 } // namespace op
630 } // namespace mxnet
631
632