1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*!
20  * \file roi_align.cc
21  * \brief roi align operator
22  * \author Hang Zhang, Shesung
23  * Adapted from Caffe2
24 */
25 #include "./roi_align-inl.h"
26 
27 
28 namespace mxnet {
29 namespace op {
30 
template <typename T>
struct PreCalc {
  // Flattened (y * width + x) indices of the four pixels surrounding one
  // bilinear sampling point.
  int pos1;
  int pos2;
  int pos3;
  int pos4;
  // Matching bilinear interpolation weights for pos1..pos4.  All four are
  // set to 0 when the sampling point falls outside the feature map.
  T w1;
  T w2;
  T w3;
  T w4;
};
42 
43 template <typename T>
pre_calc_for_bilinear_interpolate(const int height,const int width,const int pooled_height,const int pooled_width,const int iy_upper,const int ix_upper,T roi_start_h,T roi_start_w,T bin_size_h,T bin_size_w,int roi_bin_grid_h,int roi_bin_grid_w,std::vector<PreCalc<T>> * pre_calc)44 void pre_calc_for_bilinear_interpolate(
45     const int height,
46     const int width,
47     const int pooled_height,
48     const int pooled_width,
49     const int iy_upper,
50     const int ix_upper,
51     T roi_start_h,
52     T roi_start_w,
53     T bin_size_h,
54     T bin_size_w,
55     int roi_bin_grid_h,
56     int roi_bin_grid_w,
57     std::vector<PreCalc<T>>* pre_calc) {
58   int pre_calc_index = 0;
59   for (int ph = 0; ph < pooled_height; ph++) {
60     for (int pw = 0; pw < pooled_width; pw++) {
61       for (int iy = 0; iy < iy_upper; iy++) {
62         const T yy = roi_start_h + ph * bin_size_h +
63             static_cast<T>(iy + .5f) * bin_size_h /
64                 static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
65         for (int ix = 0; ix < ix_upper; ix++) {
66           const T xx = roi_start_w + pw * bin_size_w +
67               static_cast<T>(ix + .5f) * bin_size_w /
68                   static_cast<T>(roi_bin_grid_w);
69 
70           T x = xx;
71           T y = yy;
72           // deal with: inverse elements are out of feature map boundary
73           if (y < -1.0 || y > height || x < -1.0 || x > width) {
74             // empty
75             PreCalc<T> pc;
76             pc.pos1 = 0;
77             pc.pos2 = 0;
78             pc.pos3 = 0;
79             pc.pos4 = 0;
80             pc.w1 = 0;
81             pc.w2 = 0;
82             pc.w3 = 0;
83             pc.w4 = 0;
84             pre_calc->at(pre_calc_index) = pc;
85             pre_calc_index += 1;
86             continue;
87           }
88 
89           if (y <= 0) {
90             y = 0;
91           }
92           if (x <= 0) {
93             x = 0;
94           }
95 
96           int y_low = static_cast<int>(y);
97           int x_low = static_cast<int>(x);
98           int y_high;
99           int x_high;
100 
101           if (y_low >= height - 1) {
102             y_high = y_low = height - 1;
103             y = (T)y_low;
104           } else {
105             y_high = y_low + 1;
106           }
107 
108           if (x_low >= width - 1) {
109             x_high = x_low = width - 1;
110             x = (T)x_low;
111           } else {
112             x_high = x_low + 1;
113           }
114 
115           T ly = y - y_low;
116           T lx = x - x_low;
117           T hy = 1. - ly, hx = 1. - lx;
118           T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
119 
120           // save weights and indeces
121           PreCalc<T> pc;
122           pc.pos1 = y_low * width + x_low;
123           pc.pos2 = y_low * width + x_high;
124           pc.pos3 = y_high * width + x_low;
125           pc.pos4 = y_high * width + x_high;
126           pc.w1 = w1;
127           pc.w2 = w2;
128           pc.w3 = w3;
129           pc.w4 = w4;
130           pre_calc->at(pre_calc_index) = pc;
131 
132           pre_calc_index += 1;
133         }
134       }
135     }
136   }
137 }
138 
139 template <typename T>
ROIAlignForward(const int nthreads,const T * bottom_data,const T & spatial_scale,const bool position_sensitive,const bool continuous_coordinate,const int channels,const int height,const int width,const int pooled_height,const int pooled_width,const int sampling_ratio,const T * bottom_rois,int roi_cols,T * top_data)140 void ROIAlignForward(
141     const int nthreads,
142     const T* bottom_data,
143     const T& spatial_scale,
144     const bool position_sensitive,
145     const bool continuous_coordinate,
146     const int channels,
147     const int height,
148     const int width,
149     const int pooled_height,
150     const int pooled_width,
151     const int sampling_ratio,
152     const T* bottom_rois,
153     int roi_cols,
154     T* top_data) {
155   DCHECK(roi_cols == 4 || roi_cols == 5);
156 
157   int n_rois = nthreads / channels / pooled_width / pooled_height;
158   // (n, c, ph, pw) is an element in the pooled output
159   // can be parallelized using omp
160 #pragma omp parallel for \
161 num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
162   for (int n = 0; n < n_rois; n++) {
163     int index_n = n * channels * pooled_width * pooled_height;
164 
165     // roi could have 4 or 5 columns
166     const T* offset_bottom_rois = bottom_rois + n * roi_cols;
167     int roi_batch_ind = 0;
168     if (roi_cols == 5) {
169       roi_batch_ind = offset_bottom_rois[0];
170       if (roi_batch_ind < 0) {
171         top_data[n] = 0;
172         continue;
173       }
174       offset_bottom_rois++;
175     }
176 
177     // Do not using rounding; this implementation detail is critical
178     T roi_offset = continuous_coordinate ? static_cast<T>(0.5) : static_cast<T>(0);
179     T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset;
180     T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset;
181     T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset;
182     T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset;
183 
184     T roi_width = roi_end_w - roi_start_w;
185     T roi_height = roi_end_h - roi_start_h;
186     if (continuous_coordinate) {
187       CHECK_GT(roi_width, 0.);
188       CHECK_GT(roi_height, 0.);
189     } else {  // backward compatiblity
190       // Force malformed ROIs to be 1x1
191       roi_width = std::max(roi_width, (T)1.);
192       roi_height = std::max(roi_height, (T)1.);
193     }
194     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
195     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
196 
197     // We use roi_bin_grid to sample the grid and mimic integral
198     int roi_bin_grid_h = (sampling_ratio > 0)
199         ? sampling_ratio
200         : std::ceil(roi_height / pooled_height);  // e.g., = 2
201     int roi_bin_grid_w =
202         (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_width / pooled_width);
203 
204     // We do average (integral) pooling inside a bin
205     const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
206 
207     // we want to precalculate indeces and weights shared by all chanels,
208     // this is the key point of optimiation
209     std::vector<PreCalc<T>> pre_calc(
210         roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
211     pre_calc_for_bilinear_interpolate(
212         height,
213         width,
214         pooled_height,
215         pooled_width,
216         roi_bin_grid_h,
217         roi_bin_grid_w,
218         roi_start_h,
219         roi_start_w,
220         bin_size_h,
221         bin_size_w,
222         roi_bin_grid_h,
223         roi_bin_grid_w,
224         &pre_calc);
225 
226     for (int c = 0; c < channels; c++) {
227       int index_n_c = index_n + c * pooled_width * pooled_height;
228       int pre_calc_index = 0;
229 
230       for (int ph = 0; ph < pooled_height; ph++) {
231         for (int pw = 0; pw < pooled_width; pw++) {
232           int index = index_n_c + ph * pooled_width + pw;
233 
234           int c_unpooled = c;
235           int channels_unpooled = channels;
236           if (position_sensitive) {
237             c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw;
238             channels_unpooled = channels * pooled_height * pooled_width;
239           }
240           const T* offset_bottom_data =
241               bottom_data + (roi_batch_ind * channels_unpooled + c_unpooled)
242               * height * width;
243           T output_val = 0.;
244           for (int iy = 0; iy < roi_bin_grid_h; iy++) {
245             for (int ix = 0; ix < roi_bin_grid_w; ix++) {
246               PreCalc<T> pc = pre_calc[pre_calc_index];
247               output_val += pc.w1 * offset_bottom_data[pc.pos1] +
248                   pc.w2 * offset_bottom_data[pc.pos2] +
249                   pc.w3 * offset_bottom_data[pc.pos3] +
250                   pc.w4 * offset_bottom_data[pc.pos4];
251 
252               pre_calc_index += 1;
253             }
254           }
255           output_val /= count;
256 
257           top_data[index] = output_val;
258         }  // for pw
259       }  // for ph
260     }  // for c
261   }  // for n
262 }
263 
264 
template <typename T>
void bilinear_interpolate_gradient(
    const int height,
    const int width,
    T y,
    T x,
    T* w1,
    T* w2,
    T* w3,
    T* w4,
    int* x_low,
    int* x_high,
    int* y_low,
    int* y_high,
    const int /*index*/ /* index for debug only*/) {
  // Computes the four neighbor coordinates and bilinear weights for the
  // sampling point (y, x) so the backward pass can scatter gradients.
  // Out-of-bounds points get zero weights and -1 coordinates.
  const bool outside = y < -1.0 || y > height || x < -1.0 || x > width;
  if (outside) {
    // Empty: this sample receives no gradient.
    *w1 = *w2 = *w3 = *w4 = 0.;
    *x_low = *x_high = *y_low = *y_high = -1;
    return;
  }

  // Clamp slightly-negative coordinates onto the feature map.
  if (y <= 0) {
    y = 0;
  }
  if (x <= 0) {
    x = 0;
  }

  *y_low = static_cast<int>(y);
  *x_low = static_cast<int>(x);

  // At the bottom/right border both neighbors collapse onto the last pixel.
  if (*y_low >= height - 1) {
    *y_low = height - 1;
    *y_high = *y_low;
    y = static_cast<T>(*y_low);
  } else {
    *y_high = *y_low + 1;
  }

  if (*x_low >= width - 1) {
    *x_low = width - 1;
    *x_high = *x_low;
    x = static_cast<T>(*x_low);
  } else {
    *x_high = *x_low + 1;
  }

  // Weights from the fractional parts of the sample coordinates.
  const T ly = y - *y_low;
  const T lx = x - *x_low;
  const T hy = 1. - ly;
  const T hx = 1. - lx;

  *w1 = hy * hx;
  *w2 = hy * lx;
  *w3 = ly * hx;
  *w4 = ly * lx;
}
320 
// Accumulates val into *address.  Plain (non-atomic) add: the CPU backward
// pass below runs its scatter loop single-threaded, so no synchronization
// is needed here.
template <class T>
inline void add(const T& val, T* address) {
  *address += val;
}
325 
/*!
 * \brief CPU backward pass of ROIAlign.
 *
 * For every pooled output element (n, c, ph, pw), recomputes the same
 * sampling grid as the forward pass and scatters the incoming gradient
 * top_diff, split by the bilinear weights, into bottom_diff.
 * bottom_diff must be zero- (or grad-) initialized by the caller; this
 * function only accumulates into it.  Runs single-threaded.
 */
template <typename T>
void ROIAlignBackward(
    const int nthreads,
    const T* top_diff,
    const int /*num_rois*/,
    const T& spatial_scale,
    const bool position_sensitive,
    const bool continuous_coordinate,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio,
    T* bottom_diff,
    const T* bottom_rois,
    int rois_cols) {
  DCHECK(rois_cols == 4 || rois_cols == 5);

  for (int index = 0; index < nthreads; index++) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_bottom_rois = bottom_rois + n * rois_cols;
    int roi_batch_ind = 0;
    if (rois_cols == 5) {
      roi_batch_ind = offset_bottom_rois[0];
      // Negative batch index marks an ignored ROI: it produced no output,
      // so it propagates no gradient.
      if (roi_batch_ind < 0) continue;
      offset_bottom_rois++;
    }

    // Do not use rounding; this implementation detail is critical
    T roi_offset = continuous_coordinate ? static_cast<T>(0.5) : static_cast<T>(0);
    T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset;
    T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset;
    T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset;
    T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!continuous_coordinate) {  // backward compatibility
      // Force malformed ROIs to be 1x1
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    }
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // For position-sensitive pooling each output bin maps back to its own
    // group of input channels (mirrors the forward indexing).
    int c_unpooled = c;
    int channels_unpooled = channels;
    if (position_sensitive) {
      c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw;
      channels_unpooled = channels * pooled_height * pooled_width;
    }
    T* offset_bottom_diff =
        bottom_diff + (roi_batch_ind * channels_unpooled + c_unpooled)
        * height * width;

    int top_offset = (n * channels + c) * pooled_height * pooled_width;
    const T* offset_top_diff = top_diff + top_offset;
    const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : std::ceil(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin
    const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      const T y = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient(
            height,
            width,
            y,
            x,
            &w1,
            &w2,
            &w3,
            &w4,
            &x_low,
            &x_high,
            &y_low,
            &y_high,
            index);

        // Gradient of the bin average: each sample's share is weight/count.
        T g1 = top_diff_this_bin * w1 / count;
        T g2 = top_diff_this_bin * w2 / count;
        T g3 = top_diff_this_bin * w3 / count;
        T g4 = top_diff_this_bin * w4 / count;

        // Negative coordinates mark an out-of-bounds sample (no gradient).
        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          // atomic add is not needed for now since it is single threaded
          add(static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
          add(static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
          add(static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
          add(static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
        }  // if
      }  // ix
    }  // iy
  }  // for
}  // ROIAlignBackward
444 
445 
// FCompute entry point for the forward pass on CPU: validates input/output
// counts, extracts shape information from the TBlobs, and dispatches to
// ROIAlignForward for the input's real dtype.
template<typename xpu>
void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& in_data,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& out_data) {
  using namespace mshadow;
  size_t expected_in = 2;   // [data, rois]
  size_t expected_out = 1;  // [output]
  CHECK_EQ(in_data.size(), expected_in);
  CHECK_EQ(out_data.size(), expected_out);
  // One output slice per ROI.
  CHECK_EQ(out_data[roialign::kOut].shape_[0], in_data[roialign::kBox].shape_[0]);

  const ROIAlignParam& param = nnvm::get<ROIAlignParam>(attrs.parsed);

  const int count = out_data[roialign::kOut].Size();
  // const int num_rois = in_data[roialign::kBox].size(0);
  const int channels = out_data[roialign::kOut].size(1);  // channels of pooled output
  const int height = in_data[roialign::kData].size(2);
  const int width = in_data[roialign::kData].size(3);
  const int pooled_height = out_data[roialign::kOut].size(2);
  const int pooled_width = out_data[roialign::kOut].size(3);
  const int rois_cols = in_data[roialign::kBox].size(1);

  // assume all the data and gradient have the same type
  MSHADOW_REAL_TYPE_SWITCH(in_data[0].type_flag_, DType, {
    const DType *bottom_data = in_data[roialign::kData].dptr<DType>();
    const DType *bottom_rois = in_data[roialign::kBox].dptr<DType>();
    DType *top_data = out_data[roialign::kOut].dptr<DType>();

    ROIAlignForward<DType>(count, bottom_data, param.spatial_scale, param.position_sensitive,
                           param.aligned, channels, height, width, pooled_height, pooled_width,
                           param.sample_ratio, bottom_rois, rois_cols, top_data);
  })
}
481 
// FCompute entry point for the backward pass on CPU.  inputs = [out_grad,
// rois]; outputs = [grad_data, grad_rois].  Dispatches to ROIAlignBackward
// for the data gradient; the rois gradient is always zero (ROIAlign is not
// differentiated w.r.t. the box coordinates here).
template<typename xpu>
void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
  using namespace mshadow;

  CHECK_EQ(inputs.size(), 2);
  CHECK_EQ(outputs.size(), 2);
  // the order here relates to the order in ROIAlignGrad
  std::vector<TBlob> out_grad(1, inputs[0]);
  std::vector<TBlob> in_data(1, inputs[1]);
  // std::vector<TBlob> out_data(1, inputs[2]);

  CHECK_EQ(out_grad[0].shape_[0], in_data[0].shape_[0]);
  CHECK_NE(req[0], kWriteInplace) <<
    "ROIAlign: Backward doesn't support kWriteInplace.";
  CHECK_NE(req[1], kWriteInplace) <<
    "ROIAlign: Backward doesn't support kWriteInplace.";

  const ROIAlignParam& param = nnvm::get<ROIAlignParam>(attrs.parsed);

  const int count = out_grad[0].Size();
  const int num_rois = in_data[0].size(0);
  const int channels = out_grad[0].size(1);  // channels of pooled output
  const int height = outputs[0].size(2);
  const int width = outputs[0].size(3);
  const int pooled_height = out_grad[0].size(2);
  const int pooled_width = out_grad[0].size(3);
  const int rois_cols = in_data[0].size(1);

  Stream<cpu> *s = ctx.get_stream<cpu>();
  // assume all the data and gradient have the same type
  MSHADOW_REAL_TYPE_SWITCH(out_grad[0].type_flag_, DType, {
    const DType *top_diff = out_grad[0].dptr<DType>();
    const DType *bottom_rois = in_data[0].dptr<DType>();
    DType *grad_in = outputs[0].dptr<DType>();

    if (kAddTo == req[roialign::kData] || kWriteTo == req[roialign::kData]) {
      // ROIAlignBackward only accumulates, so zero-fill first for kWriteTo.
      if (kWriteTo == req[roialign::kData]) {
        Fill<false>(s, outputs[0], kWriteTo, static_cast<DType>(0));
      }
      ROIAlignBackward<DType>(count, top_diff, num_rois, param.spatial_scale,
                     param.position_sensitive, param.aligned, channels, height, width,
                     pooled_height, pooled_width, param.sample_ratio, grad_in,
                     bottom_rois, rois_cols);
    }
    // Gradient w.r.t. the rois is defined as zero.
    if (kWriteTo == req[roialign::kBox]) {
      Fill<false>(s, outputs[1], kWriteTo, static_cast<DType>(0));
    }
  })
}
535 
DMLC_REGISTER_PARAMETER(ROIAlignParam);

// Registration of the forward ROIAlign operator: shape/type inference,
// CPU compute kernel, and gradient wiring.
NNVM_REGISTER_OP(_contrib_ROIAlign)
.describe(R"code(
This operator takes a 4D feature map as an input array and region proposals as `rois`,
then align the feature map over sub-regions of input and produces a fixed-sized output array.
This operator is typically used in Faster R-CNN & Mask R-CNN networks. If roi batchid is less
than 0, it will be ignored, and the corresponding output will be set to 0.

Different from ROI pooling, ROI Align removes the harsh quantization, properly aligning
the extracted features with the input. RoIAlign computes the value of each sampling point
by bilinear interpolation from the nearby grid points on the feature map. No quantization is
performed on any coordinates involved in the RoI, its bins, or the sampling points.
Bilinear interpolation is used to compute the exact values of the
input features at four regularly sampled locations in each RoI bin.
Then the feature map can be aggregated by avgpooling.


References
----------

He, Kaiming, et al. "Mask R-CNN." ICCV, 2017
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
    [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"data", "rois"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
    [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"output"};
})
.set_attr_parser(ParamParser<ROIAlignParam>)
.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
      mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape){
  using namespace mshadow;
  const ROIAlignParam& param = nnvm::get<ROIAlignParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), 2) << "Input:[data, rois]";
  // data: [batch_size, c, h, w]
  mxnet::TShape dshape = in_shape->at(roialign::kData);
  CHECK_EQ(dshape.ndim(), 4) << "data should be a 4D tensor";
  // bbox: [num_rois, 5]
  mxnet::TShape bshape = in_shape->at(roialign::kBox);
  CHECK_EQ(bshape.ndim(), 2) << "bbox should be a 2D tensor of shape [batch, 5]";
  CHECK_EQ(bshape[1], 5) << "bbox should be a 2D tensor of shape [batch, 5]";
  // out: [num_rois, c, pooled_h, pooled_w]
  out_shape->clear();
  if (param.position_sensitive) {
    // PS-ROIAlign folds each pooled cell's channels into the input channel
    // dimension, so output channels shrink by pooled_h * pooled_w.
    CHECK_EQ(dshape[1] % (param.pooled_size[0]*param.pooled_size[1]), 0) <<
      "Input channels should be divided by pooled_size[0]*pooled_size[1]"
      "when position_sensitive is true.";
    out_shape->push_back(
         Shape4(bshape[0], dshape[1]/param.pooled_size[0]/param.pooled_size[1],
                param.pooled_size[0], param.pooled_size[1]));
  } else {
    out_shape->push_back(
         Shape4(bshape[0], dshape[1], param.pooled_size[0], param.pooled_size[1]));
  }
  return true;
})
.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
      std::vector<int> *in_type, std::vector<int> *out_type) {
  CHECK_EQ(in_type->size(), 2);
  int dtype = (*in_type)[0];
  // data and rois must share one dtype, which the output inherits.
  CHECK_EQ(dtype, (*in_type)[1]);
  CHECK_NE(dtype, -1) << "Input must have specified type";

  out_type->clear();
  out_type->push_back(dtype);
  return true;
})
.set_attr<FCompute>("FCompute<cpu>", ROIAlignForwardCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    // Backward needs the output gradient plus the rois input.
    std::vector<nnvm::NodeEntry> heads;
    heads.push_back(ograds[roialign::kOut]);
    heads.push_back(n->inputs[roialign::kBox]);
    return MakeGradNode("_backward_ROIAlign", n, heads, n->attrs.dict);
  })
.add_argument("data", "NDArray-or-Symbol", "Input data to the pooling operator, a 4D Feature maps")
.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 2D array, "
              "if batchid is less than 0, it will be ignored.")
.add_arguments(ROIAlignParam::__FIELDS__());
620 
621 
// Registration of the hidden backward op; inputs/outputs match the heads
// built in the forward op's FGradient above: inputs = [out_grad, rois],
// outputs = [grad_data, grad_rois].
NNVM_REGISTER_OP(_backward_ROIAlign)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<ROIAlignParam>)
.set_attr<FCompute>("FCompute<cpu>", ROIAlignBackwardCompute<cpu>);
628 
629 }  // namespace op
630 }  // namespace mxnet
631 
632