1 /*!
2  * Copyright (c) 2017 Microsoft
3  * Licensed under The MIT License [see LICENSE for details]
4  * \file deformable_psroi_pooling.cu
5  * \brief
6  * \author Yi Li, Guodong Zhang, Jifeng Dai
7  *
8  * Code from https://github.com/msracver/Deformable-ConvNets/blob/d51075968c5fd40b37a55d20c8e945c1f181d529/rfcn/operator_cxx/deformable_psroi_pooling.cu
9  */
10 #include "./deformable_psroi_pooling-inl.h"
11 #include <mshadow/tensor.h>
12 #include <mshadow/cuda/reduce.cuh>
13 #include <algorithm>
14 #include <vector>
15 #include "../../common/cuda_utils.h"
16 #include "../mxnet_op.h"
17 
/*!
 * \brief Evaluates a CUDA runtime expression once and feeds the resulting
 *        cudaError_t to CHECK_EQ against cudaSuccess, appending the text from
 *        cudaGetErrorString to the check message.
 * \param condition expression yielding a cudaError_t
 */
#define DeformablePSROIPOOLING_CUDA_CHECK(condition) \
  /* do/while scope keeps the temporary private to one expansion; the */ \
  /* long name avoids capturing a caller variable used in `condition` */ \
  do { \
    const cudaError_t deformable_psroi_cuda_err_ = (condition); \
    CHECK_EQ(deformable_psroi_cuda_err_, cudaSuccess) \
      << " " << cudaGetErrorString(deformable_psroi_cuda_err_); \
  } while (0)
24 
25 namespace mshadow {
26 namespace cuda {
/*!
 * \brief Bilinear interpolation of a row-major 2D plane at a fractional position.
 * \param data pointer to element (0, 0) of the plane; consecutive rows are `width` elements apart
 * \param x horizontal sampling coordinate
 * \param y vertical sampling coordinate
 * \param width row stride of the plane, in elements
 * \param height accepted for interface symmetry; not read by this function
 * \return the four samples at (floor(y)|ceil(y), floor(x)|ceil(x)) blended by the
 *         fractional parts of x and y
 * \note No bounds checking is performed here; the caller is responsible for passing
 *       coordinates whose floor/ceil stay inside the plane.
 */
template <typename DType>
__device__ DType bilinear_interp(const DType* data,
                                 const DType x, const DType y,
                                 const index_t width, const index_t height) {
  // Integer grid lines surrounding the sampling point
  const index_t col_lo = floor(x);
  const index_t col_hi = ceil(x);
  const index_t row_lo = floor(y);
  const index_t row_hi = ceil(y);

  // Fractional distance from the low corner along each axis
  const DType frac_x = static_cast<DType>(x - col_lo);
  const DType frac_y = static_cast<DType>(y - row_lo);

  // Fetch the four neighbouring samples, one row pointer per grid row
  const DType* upper_row = data + row_lo * width;
  const DType* lower_row = data + row_hi * width;
  const DType v_lo_lo = upper_row[col_lo];  // (row_lo, col_lo)
  const DType v_hi_lo = lower_row[col_lo];  // (row_hi, col_lo)
  const DType v_lo_hi = upper_row[col_hi];  // (row_lo, col_hi)
  const DType v_hi_hi = lower_row[col_hi];  // (row_hi, col_hi)

  return (1 - frac_x) * (1 - frac_y) * v_lo_lo + (1 - frac_x) * frac_y * v_hi_lo +
    frac_x * (1 - frac_y) * v_lo_hi + frac_x * frac_y * v_hi_hi;
}
45 
/*!
 * \brief Forward kernel of deformable position-sensitive ROI pooling.
 *
 * Each iteration of CUDA_KERNEL_LOOP produces one output element, decoded from the
 * flat index in (n, ctop, ph, pw) order. The bin of the ROI that belongs to (ph, pw)
 * is optionally shifted by a learned offset, sampled on a
 * sample_per_part x sample_per_part grid with bilinear interpolation, and averaged
 * over the samples that land inside the feature map.
 *
 * \param count number of output elements to process
 * \param bottom_data input features, indexed here as (batch, channels, height, width)
 * \param spatial_scale factor applied to ROI coordinates to map them onto the feature map
 * \param channels number of channels in bottom_data
 * \param height feature map height
 * \param width feature map width
 * \param pooled_height number of output bins along h
 * \param pooled_width number of output bins along w
 * \param bottom_rois 5 values per ROI: batch index, start_w, start_h, end_w, end_h
 * \param bottom_trans offsets, indexed here as (n, class * 2 + {0: x, 1: y}, part_h, part_w);
 *        only read when no_trans is false
 * \param no_trans when true, offsets are treated as zero and bottom_trans is not read
 * \param trans_std multiplier applied to the raw offsets
 * \param sample_per_part number of samples per bin along each axis
 * \param output_dim number of output channels
 * \param group_size number of position-sensitive groups along each axis
 * \param part_size spatial size of the offset field along each axis
 * \param num_classes number of offset classes
 * \param channels_each_class output channels sharing one offset class; must be non-zero
 * \param top_data output: mean of the valid samples (0 if none were valid)
 * \param top_count output: number of valid samples per element
 */
template <typename DType>
__global__ void DeformablePSROIPoolForwardKernel(const index_t count,
                                                 const DType* bottom_data,
                                                 const DType spatial_scale,
                                                 const index_t channels,
                                                 const index_t height, const index_t width,
                                                 const index_t pooled_height,
                                                 const index_t pooled_width,
                                                 const DType* bottom_rois,
                                                 const DType* bottom_trans,
                                                 const bool no_trans, const DType trans_std,
                                                 const index_t sample_per_part,
                                                 const index_t output_dim,
                                                 const index_t group_size,
                                                 const index_t part_size,
                                                 const index_t num_classes,
                                                 const index_t channels_each_class,
                                                 DType* top_data, DType* top_count) {
  CUDA_KERNEL_LOOP(index, count) {
    // The output is in order (n, ctop, ph, pw)
    index_t pw = index % pooled_width;
    index_t ph = (index / pooled_width) % pooled_height;
    index_t ctop = (index / pooled_width / pooled_height) % output_dim;
    index_t n = index / pooled_width / pooled_height / output_dim;

    // [start, end) interval for spatial sampling
    const DType* offset_bottom_rois = bottom_rois + n * 5;
    index_t roi_batch_ind = offset_bottom_rois[0];
    DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;

    // Keep the ROI extent strictly positive so the bin sizes below never become 0
    DType roi_width = max(roi_end_w - roi_start_w, 0.1);  // avoid 0
    DType roi_height = max(roi_end_h - roi_start_h, 0.1);

    // Size of one output bin, and of one sample cell inside it, on the feature map
    DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
    DType bin_size_w = roi_width / static_cast<DType>(pooled_width);

    DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
    DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);

    // Locate the offset-field cell and offset class that govern this bin
    index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
    index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
    index_t class_id = ctop / channels_each_class;
    DType trans_x = no_trans ? static_cast<DType>(0) :
      bottom_trans[(((n * num_classes + class_id) * 2)
                      * part_size + part_h)
                      * part_size + part_w] * trans_std;
    DType trans_y = no_trans ? static_cast<DType>(0) :
      bottom_trans[(((n * num_classes + class_id) * 2 + 1)
                      * part_size + part_h)
                      * part_size + part_w] * trans_std;

    // Top-left corner of the (shifted) bin; offsets are relative to the ROI size
    DType wstart = static_cast<DType>(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    DType hstart = static_cast<DType>(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;

    // Position-sensitive group of this bin, clamped to [0, group_size - 1]
    index_t gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
    index_t gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
    gw = min(max(gw, static_cast<index_t>(0)), group_size - 1);
    gh = min(max(gh, static_cast<index_t>(0)), group_size - 1);

    // The input channel depends only on (ctop, gh, gw), so resolve its plane once
    // instead of once per sample.
    const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
    const index_t c = (ctop * group_size + gh) * group_size + gw;
    const DType* channel_plane = offset_bottom_data + c * height * width;

    DType sum = 0;
    // Named differently from the `count` kernel parameter, which bounds the
    // enclosing CUDA_KERNEL_LOOP and must not be shadowed.
    index_t num_samples = 0;
    for (index_t ih = 0; ih < sample_per_part; ih++) {
      for (index_t iw = 0; iw < sample_per_part; iw++) {
        DType w = wstart + iw * sub_bin_size_w;
        DType h = hstart + ih * sub_bin_size_h;
        // Samples more than half a pixel outside the feature map are ignored
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) {
          continue;
        }
        // Clamp so that floor/ceil inside bilinear_interp stay within the plane
        w = min(max(w, 0.), width - 1.);
        h = min(max(h, 0.), height - 1.);
        sum += bilinear_interp(channel_plane, w, h, width, height);
        num_samples++;
      }
    }
    top_data[index] = num_samples == 0 ? static_cast<DType>(0) : sum / num_samples;
    top_count[index] = num_samples;
  }
}
136 
/*!
 * \brief Host-side launcher for DeformablePSROIPoolForwardKernel.
 *
 * Derives the kernel arguments from the tensors and launches on the stream
 * attached to `out`.
 *
 * \param out pooled output; its total size determines the number of kernel work items
 * \param data input feature map; dimensions 1..3 are passed as channels, height, width
 * \param bbox ROI tensor
 * \param trans offset tensor; only accessed when no_trans is false
 * \param top_count receives the number of valid samples per output element
 * \param no_trans when true, no offsets are applied and `trans` is ignored
 * \param spatial_scale scale from ROI coordinates to feature-map coordinates
 * \param output_dim number of output channels
 * \param group_size number of position-sensitive groups per axis
 * \param pooled_size number of output bins per axis (used for both height and width)
 * \param part_size spatial size of the offset field per axis
 * \param sample_per_part samples per bin along each axis
 * \param trans_std multiplier applied to the offsets
 */
template<typename DType>
inline void DeformablePSROIPoolForward(const Tensor<gpu, 4, DType> &out,
                                       const Tensor<gpu, 4, DType> &data,
                                       const Tensor<gpu, 2, DType> &bbox,
                                       const Tensor<gpu, 4, DType> &trans,
                                       const Tensor<gpu, 4, DType> &top_count,
                                       const bool no_trans, const float spatial_scale,
                                       const index_t output_dim, const index_t group_size,
                                       const index_t pooled_size, const index_t part_size,
                                       const index_t sample_per_part, const float trans_std) {
  const index_t count = out.shape_.Size();
  if (count == 0) {
    // Empty output (e.g. no ROIs): there is nothing to compute, and a launch
    // with a zero-sized grid is not a valid CUDA execution configuration.
    return;
  }

  const DType *bottom_data = data.dptr_;
  const DType *bottom_rois = bbox.dptr_;
  const DType *bottom_trans = no_trans ? nullptr : trans.dptr_;
  DType *top_data = out.dptr_;
  DType *top_count_data = top_count.dptr_;
  const index_t channels = data.size(1);
  const index_t height = data.size(2);
  const index_t width = data.size(3);
  const index_t pooled_height = pooled_size;
  const index_t pooled_width = pooled_size;

  // trans stores an (x, y) pair per class along dimension 1
  const index_t num_classes = no_trans ? 1 : trans.size(1) / 2;
  // Guards the host-side division just below
  CHECK_GT(num_classes, static_cast<index_t>(0))
    << "trans must provide at least one (x, y) offset pair along dimension 1";
  const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;
  // The kernel divides by channels_each_class to pick the offset class
  CHECK_GT(channels_each_class, static_cast<index_t>(0))
    << "output_dim must be at least the number of offset classes";

  cudaStream_t stream = Stream<gpu>::GetStream(out.stream_);
  DeformablePSROIPoolForwardKernel<DType><<<
    mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum,
    0, stream>>>(count, bottom_data, spatial_scale, channels, height, width,
                 pooled_height, pooled_width, bottom_rois, bottom_trans,
                 no_trans, trans_std, sample_per_part, output_dim,
                 group_size, part_size, num_classes,
                 channels_each_class, top_data, top_count_data);
  DeformablePSROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
}
171 
172 
/*!
 * \brief Backward kernel of deformable position-sensitive ROI pooling.
 *
 * Each iteration of CUDA_KERNEL_LOOP handles one output element, decoded from the
 * flat index in (n, ctop, ph, pw) order, and replays the forward sampling pattern.
 * For every valid sample it accumulates, with atomicAdd,
 *  - the gradient w.r.t. the four neighbouring input pixels into bottom_data_diff, and
 *  - unless no_trans is set, the gradient w.r.t. the x/y offsets into bottom_trans_diff.
 * Both gradient buffers are accumulated into, never overwritten.
 *
 * \param count number of output elements to process
 * \param top_diff gradient w.r.t. the pooled output
 * \param top_count number of valid samples per output element, as written by the forward pass
 * \param num_rois number of ROIs (not read by the kernel)
 * \param spatial_scale factor applied to ROI coordinates to map them onto the feature map
 * \param channels number of channels of the input feature map
 * \param height feature map height
 * \param width feature map width
 * \param pooled_height number of output bins along h
 * \param pooled_width number of output bins along w
 * \param output_dim number of output channels
 * \param bottom_data_diff gradient accumulator for the input feature map
 * \param bottom_trans_diff gradient accumulator for the offsets; only touched when
 *        no_trans is false
 * \param bottom_data input features, indexed here as (batch, channels, height, width)
 * \param bottom_rois 5 values per ROI: batch index, start_w, start_h, end_w, end_h
 * \param bottom_trans offsets, indexed here as (n, class * 2 + {0: x, 1: y}, part_h, part_w);
 *        only read when no_trans is false
 * \param no_trans when true, offsets are treated as zero
 * \param trans_std multiplier applied to the raw offsets
 * \param sample_per_part number of samples per bin along each axis
 * \param group_size number of position-sensitive groups along each axis
 * \param part_size spatial size of the offset field along each axis
 * \param num_classes number of offset classes
 * \param channels_each_class output channels sharing one offset class; must be non-zero
 */
template <typename DType>
__global__ void DeformablePSROIPoolBackwardAccKernel(const index_t count,
                                                     const DType* top_diff,
                                                     const DType* top_count,
                                                     const index_t num_rois,
                                                     const DType spatial_scale,
                                                     const index_t channels,
                                                     const index_t height,
                                                     const index_t width,
                                                     const index_t pooled_height,
                                                     const index_t pooled_width,
                                                     const index_t output_dim,
                                                     DType* bottom_data_diff,
                                                     DType* bottom_trans_diff,
                                                     const DType* bottom_data,
                                                     const DType* bottom_rois,
                                                     const DType* bottom_trans,
                                                     const bool no_trans,
                                                     const DType trans_std,
                                                     const index_t sample_per_part,
                                                     const index_t group_size,
                                                     const index_t part_size,
                                                     const index_t num_classes,
                                                     const index_t channels_each_class) {
  CUDA_KERNEL_LOOP(index, count) {
    // No sample contributed to this output in the forward pass, so there is no
    // gradient to propagate; bail out before doing any of the geometry work.
    if (top_count[index] <= 0) {
      continue;
    }

    // The output is in order (n, ctop, ph, pw)
    index_t pw = index % pooled_width;
    index_t ph = (index / pooled_width) % pooled_height;
    index_t ctop = (index / pooled_width / pooled_height) % output_dim;
    index_t n = index / pooled_width / pooled_height / output_dim;

    // [start, end) interval for spatial sampling
    const DType* offset_bottom_rois = bottom_rois + n * 5;
    index_t roi_batch_ind = offset_bottom_rois[0];
    DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;

    // Keep the ROI extent strictly positive so the bin sizes below never become 0
    DType roi_width = max(roi_end_w - roi_start_w, 0.1);  // avoid 0
    DType roi_height = max(roi_end_h - roi_start_h, 0.1);

    // Size of one output bin, and of one sample cell inside it, on the feature map
    DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
    DType bin_size_w = roi_width / static_cast<DType>(pooled_width);

    DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
    DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);

    // Locate the offset-field cell and offset class that govern this bin
    index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
    index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
    index_t class_id = ctop / channels_each_class;
    // Flat positions of the x and y offsets; shared by the offset read below and
    // the offset-gradient accumulation inside the sampling loop.
    const index_t trans_x_index = (((n * num_classes + class_id) * 2)
                                     * part_size + part_h)
                                     * part_size + part_w;
    const index_t trans_y_index = (((n * num_classes + class_id) * 2 + 1)
                                     * part_size + part_h)
                                     * part_size + part_w;
    DType trans_x = no_trans ? static_cast<DType>(0) :
      bottom_trans[trans_x_index] * trans_std;
    DType trans_y = no_trans ? static_cast<DType>(0) :
      bottom_trans[trans_y_index] * trans_std;

    // Top-left corner of the (shifted) bin; offsets are relative to the ROI size
    DType wstart = static_cast<DType>(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    DType hstart = static_cast<DType>(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;

    // The forward pass averaged over the valid samples, so each one receives an equal share
    DType diff_val = top_diff[index] / top_count[index];
    const DType* offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
    DType* offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;

    // Position-sensitive group of this bin, clamped to [0, group_size - 1]
    index_t gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
    index_t gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
    gw = min(max(gw, static_cast<index_t>(0)), group_size - 1);
    gh = min(max(gh, static_cast<index_t>(0)), group_size - 1);

    // The input channel depends only on (ctop, gh, gw), so resolve its planes once
    // instead of once per sample.
    const index_t c = (ctop * group_size + gh) * group_size + gw;
    const index_t bottom_index_base = c * height * width;
    const DType* data_plane = offset_bottom_data + bottom_index_base;
    DType* diff_plane = offset_bottom_data_diff + bottom_index_base;

    for (index_t ih = 0; ih < sample_per_part; ih++) {
      for (index_t iw = 0; iw < sample_per_part; iw++) {
        DType w = wstart + iw * sub_bin_size_w;
        DType h = hstart + ih * sub_bin_size_h;
        // Samples more than half a pixel outside the feature map were skipped in forward
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) {
          continue;
        }
        // Clamp so that floor/ceil below stay within the plane
        w = min(max(w, 0.), width - 1.);
        h = min(max(h, 0.), height - 1.);

        // backward on feature: distribute diff_val with the bilinear weights
        index_t x0 = floor(w);
        index_t x1 = ceil(w);
        index_t y0 = floor(h);
        index_t y1 = ceil(h);
        DType dist_x = w - x0, dist_y = h - y0;
        DType q00 = (1 - dist_x) * (1 - dist_y);
        DType q01 = (1 - dist_x) * dist_y;
        DType q10 = dist_x * (1 - dist_y);
        DType q11 = dist_x * dist_y;
        // Different outputs can touch the same input pixel, hence the atomics
        atomicAdd(diff_plane + y0 * width + x0, q00 * diff_val);
        atomicAdd(diff_plane + y1 * width + x0, q01 * diff_val);
        atomicAdd(diff_plane + y0 * width + x1, q10 * diff_val);
        atomicAdd(diff_plane + y1 * width + x1, q11 * diff_val);

        if (no_trans) {
          continue;
        }

        // backward on offsets: derivative of the interpolated value w.r.t. the
        // sampling position, scaled by the same factors the forward pass applied
        // to the raw offset (trans_std and the ROI extent)
        DType U00 = data_plane[y0 * width + x0];
        DType U01 = data_plane[y1 * width + x0];
        DType U10 = data_plane[y0 * width + x1];
        DType U11 = data_plane[y1 * width + x1];
        DType diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y);
        diff_x *= trans_std * diff_val * roi_width;
        DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x);
        diff_y *= trans_std * diff_val * roi_height;

        atomicAdd(bottom_trans_diff + trans_x_index, diff_x);
        atomicAdd(bottom_trans_diff + trans_y_index, diff_y);
      }
    }
  }
}
300 
301 
/*!
 * \brief Host-side launcher for DeformablePSROIPoolBackwardAccKernel.
 *
 * Derives the kernel arguments from the tensors and launches on the stream
 * attached to `in_grad`. The kernel accumulates into in_grad / trans_grad with
 * atomicAdd, so this function adds to whatever those buffers already contain.
 *
 * \param in_grad gradient accumulator for the input feature map; dimensions 1..3 are
 *        passed as channels, height, width
 * \param trans_grad gradient accumulator for the offsets; only accessed when no_trans is false
 * \param out_grad gradient w.r.t. the pooled output; its total size determines the
 *        number of kernel work items
 * \param data input feature map from the forward pass
 * \param bbox ROI tensor
 * \param trans offset tensor; only accessed when no_trans is false
 * \param top_count number of valid samples per output element from the forward pass
 * \param no_trans when true, no offsets are applied and trans / trans_grad are ignored
 * \param spatial_scale scale from ROI coordinates to feature-map coordinates
 * \param output_dim number of output channels
 * \param group_size number of position-sensitive groups per axis
 * \param pooled_size number of output bins per axis (used for both height and width)
 * \param part_size spatial size of the offset field per axis
 * \param sample_per_part samples per bin along each axis
 * \param trans_std multiplier applied to the offsets
 */
template<typename DType>
inline void DeformablePSROIPoolBackwardAcc(const Tensor<gpu, 4, DType> &in_grad,
                                           const Tensor<gpu, 4, DType> &trans_grad,
                                           const Tensor<gpu, 4, DType> &out_grad,
                                           const Tensor<gpu, 4, DType> &data,
                                           const Tensor<gpu, 2, DType> &bbox,
                                           const Tensor<gpu, 4, DType> &trans,
                                           const Tensor<gpu, 4, DType> &top_count,
                                           const bool no_trans, const float spatial_scale,
                                           const index_t output_dim, const index_t group_size,
                                           const index_t pooled_size, const index_t part_size,
                                           const index_t sample_per_part, const float trans_std) {
  const index_t count = out_grad.shape_.Size();
  if (count == 0) {
    // No output gradient (e.g. no ROIs): nothing to accumulate, and a launch
    // with a zero-sized grid is not a valid CUDA execution configuration.
    return;
  }

  const DType *top_diff = out_grad.dptr_;
  const DType *bottom_data = data.dptr_;
  const DType *bottom_rois = bbox.dptr_;
  const DType *bottom_trans = no_trans ? nullptr : trans.dptr_;
  DType *bottom_data_diff = in_grad.dptr_;
  DType *bottom_trans_diff = no_trans ? nullptr : trans_grad.dptr_;
  const DType *top_count_data = top_count.dptr_;
  const index_t num_rois = bbox.size(0);
  const index_t channels = in_grad.size(1);
  const index_t height = in_grad.size(2);
  const index_t width = in_grad.size(3);
  const index_t pooled_height = pooled_size;
  const index_t pooled_width = pooled_size;

  // trans_grad stores an (x, y) pair per class along dimension 1
  const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2;
  // Guards the host-side division just below
  CHECK_GT(num_classes, static_cast<index_t>(0))
    << "trans_grad must provide at least one (x, y) offset pair along dimension 1";
  const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;
  // The kernel divides by channels_each_class to pick the offset class
  CHECK_GT(channels_each_class, static_cast<index_t>(0))
    << "output_dim must be at least the number of offset classes";

  cudaStream_t stream = Stream<gpu>::GetStream(in_grad.stream_);
  DeformablePSROIPoolBackwardAccKernel<DType><<<
    mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum,
    0, stream >>>(count, top_diff, top_count_data, num_rois, spatial_scale,
                  channels, height, width, pooled_height, pooled_width,
                  output_dim, bottom_data_diff, bottom_trans_diff,
                  bottom_data, bottom_rois, bottom_trans,
                  no_trans, trans_std, sample_per_part, group_size,
                  part_size, num_classes, channels_each_class);
  DeformablePSROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
}
342 
343 }  // namespace cuda
344 
/*!
 * \brief mshadow-level GPU entry point for the deformable PS-ROI pooling forward pass.
 *
 * Thin adapter: every argument is handed on unchanged, in the same order, to
 * cuda::DeformablePSROIPoolForward.
 */
template<typename DType>
inline void DeformablePSROIPoolForward(const Tensor<gpu, 4, DType> &out,
                                       const Tensor<gpu, 4, DType> &data,
                                       const Tensor<gpu, 2, DType> &bbox,
                                       const Tensor<gpu, 4, DType> &trans,
                                       const Tensor<gpu, 4, DType> &top_count,
                                       const bool no_trans, const float spatial_scale,
                                       const index_t output_dim, const index_t group_size,
                                       const index_t pooled_size, const index_t part_size,
                                       const index_t sample_per_part, const float trans_std) {
  cuda::DeformablePSROIPoolForward<DType>(
      // tensors
      out, data, bbox, trans, top_count,
      // offset switch and coordinate scale
      no_trans, spatial_scale,
      // pooling geometry
      output_dim, group_size, pooled_size, part_size, sample_per_part,
      // offset multiplier
      trans_std);
}
360 
/*!
 * \brief mshadow-level GPU entry point for the deformable PS-ROI pooling backward pass.
 *
 * Thin adapter: every argument is handed on unchanged, in the same order, to
 * cuda::DeformablePSROIPoolBackwardAcc.
 */
template<typename DType>
inline void DeformablePSROIPoolBackwardAcc(const Tensor<gpu, 4, DType> &in_grad,
                                           const Tensor<gpu, 4, DType> &trans_grad,
                                           const Tensor<gpu, 4, DType> &out_grad,
                                           const Tensor<gpu, 4, DType> &data,
                                           const Tensor<gpu, 2, DType> &bbox,
                                           const Tensor<gpu, 4, DType> &trans,
                                           const Tensor<gpu, 4, DType> &top_count,
                                           const bool no_trans, const float spatial_scale,
                                           const index_t output_dim, const index_t group_size,
                                           const index_t pooled_size, const index_t part_size,
                                           const index_t sample_per_part, const float trans_std) {
  cuda::DeformablePSROIPoolBackwardAcc<DType>(
      // gradient tensors
      in_grad, trans_grad, out_grad,
      // forward-pass tensors
      data, bbox, trans, top_count,
      // offset switch and coordinate scale
      no_trans, spatial_scale,
      // pooling geometry
      output_dim, group_size, pooled_size, part_size, sample_per_part,
      // offset multiplier
      trans_std);
}
378 
379 }  // namespace mshadow
380 
381 
382 namespace mxnet {
383 namespace op {
384 
/*!
 * \brief GPU specialization of the operator factory.
 * \param param operator parameters handed to the operator's constructor
 * \param dtype runtime type flag, dispatched through MSHADOW_REAL_TYPE_SWITCH
 * \return a heap-allocated DeformablePSROIPoolingOp<gpu, DType> for the selected
 *         DType; stays nullptr if the switch does not run the assignment
 *         (how unhandled dtypes are treated is up to that macro)
 */
template<>
Operator* CreateOp<gpu>(DeformablePSROIPoolingParam param, int dtype) {
  Operator* created_op = nullptr;
  // The macro binds DType to the concrete type matching `dtype` inside the braces
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    created_op = new DeformablePSROIPoolingOp<gpu, DType>(param);
  });
  return created_op;
}
393 
394 }  // namespace op
395 }  // namespace mxnet
396