/*!
 * Copyright (c) 2017 Microsoft
 * Licensed under The MIT License [see LICENSE for details]
 * \file deformable_psroi_pooling.cu
 * \brief GPU kernels and launchers for deformable position-sensitive ROI
 *        pooling: forward pooling and backward gradient accumulation.
 * \author Yi Li, Guodong Zhang, Jifeng Dai
 *
 * Code from https://github.com/msracver/Deformable-ConvNets/blob/d51075968c5fd40b37a55d20c8e945c1f181d529/rfcn/operator_cxx/deformable_psroi_pooling.cu
 */
#include "./deformable_psroi_pooling-inl.h"
#include <mshadow/tensor.h>
#include <mshadow/cuda/reduce.cuh>
#include <algorithm>
#include <vector>
#include "../../common/cuda_utils.h"
#include "../mxnet_op.h"

#define DeformablePSROIPOOLING_CUDA_CHECK(condition) \
  /* Code block avoids redefinition of cudaError_t error */ \
  do { \
    cudaError_t error = condition; \
    CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
  } while (0)

namespace mshadow {
namespace cuda {

/*!
 * \brief Bilinearly interpolates a single row-major plane at (x, y).
 * \param data   first element of a plane laid out as data[row * width + col]
 * \param x      fractional column; callers clamp it to [0, width - 1] so all
 *               four neighbours are in bounds
 * \param y      fractional row; callers clamp it to [0, height - 1]
 * \param width  row stride of the plane
 * \param height plane height (not read here)
 * \return weighted sum of the four samples surrounding (x, y)
 */
template <typename DType>
__device__ DType bilinear_interp(const DType* data,
                                 const DType x, const DType y,
                                 const index_t width, const index_t height) {
  index_t x1 = floor(x);
  index_t x2 = ceil(x);
  index_t y1 = floor(y);
  index_t y2 = ceil(y);
  DType dist_x = static_cast<DType>(x - x1);
  DType dist_y = static_cast<DType>(y - y1);
  DType value11 = data[y1 * width + x1];
  DType value12 = data[y2 * width + x1];
  DType value21 = data[y1 * width + x2];
  DType value22 = data[y2 * width + x2];
  DType value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 +
                dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
  return value;
}

/*!
 * \brief Forward pass: one loop iteration per output element (n, ctop, ph, pw),
 *        distributed over threads by CUDA_KERNEL_LOOP.
 *
 * Each output bin averages sample_per_part x sample_per_part bilinear samples
 * taken from the position-sensitive input channel selected by (ctop, gh, gw).
 * Unless no_trans is set, the bin is shifted by a learned offset read from
 * bottom_trans, scaled by trans_std and by the ROI extent.
 *
 * Writes top_data[index] = mean of the valid samples (0 if none), and
 * top_count[index] = number of valid samples. The backward kernel reuses this
 * count.
 */
template <typename DType>
__global__ void DeformablePSROIPoolForwardKernel(const index_t count,
                                                 const DType* bottom_data,
                                                 const DType spatial_scale,
                                                 const index_t channels,
                                                 const index_t height, const index_t width,
                                                 const index_t pooled_height,
                                                 const index_t pooled_width,
                                                 const DType* bottom_rois,
                                                 const DType* bottom_trans,
                                                 const bool no_trans, const DType trans_std,
                                                 const index_t sample_per_part,
                                                 const index_t output_dim,
                                                 const index_t group_size,
                                                 const index_t part_size,
                                                 const index_t num_classes,
                                                 const index_t channels_each_class,
                                                 DType* top_data, DType* top_count) {
  CUDA_KERNEL_LOOP(index, count) {
    // The output is in order (n, ctop, ph, pw)
    index_t pw = index % pooled_width;
    index_t ph = (index / pooled_width) % pooled_height;
    index_t ctop = (index / pooled_width / pooled_height) % output_dim;
    index_t n = index / pooled_width / pooled_height / output_dim;

    // Each ROI is 5 values: (batch_index, x_start, y_start, x_end, y_end);
    // coordinates are mapped onto the feature map by spatial_scale.
    const DType* offset_bottom_rois = bottom_rois + n * 5;
    index_t roi_batch_ind = offset_bottom_rois[0];
    DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;

    // Give degenerate ROIs a minimum extent of 0.1 so bin sizes are never 0.
    DType roi_width = max(roi_end_w - roi_start_w, 0.1);
    DType roi_height = max(roi_end_h - roi_start_h, 0.1);

    // Compute w and h at bottom
    DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
    DType bin_size_w = roi_width / static_cast<DType>(pooled_width);

    DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
    DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);

    // Locate the (part_h, part_w) cell of the offset map for this bin's class.
    index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
    index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
    index_t class_id = ctop / channels_each_class;
    const index_t trans_x_index = (((n * num_classes + class_id) * 2)
                                   * part_size + part_h) * part_size + part_w;
    const index_t trans_y_index = (((n * num_classes + class_id) * 2 + 1)
                                   * part_size + part_h) * part_size + part_w;
    DType trans_x = no_trans ? static_cast<DType>(0) :
                    bottom_trans[trans_x_index] * trans_std;
    DType trans_y = no_trans ? static_cast<DType>(0) :
                    bottom_trans[trans_y_index] * trans_std;

    // Offsets are relative to the ROI size.
    DType wstart = static_cast<DType>(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    DType hstart = static_cast<DType>(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;

    // Position-sensitive group cell, clamped to [0, group_size - 1].
    index_t gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
    index_t gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
    gw = min(max(gw, static_cast<index_t>(0)), group_size - 1);
    gh = min(max(gh, static_cast<index_t>(0)), group_size - 1);

    // The input channel depends only on (ctop, gh, gw), so resolve its plane once.
    const index_t c = (ctop * group_size + gh) * group_size + gw;
    const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
    const DType* channel_data = offset_bottom_data + c * height * width;

    DType sum = 0;
    index_t num_samples = 0;  // valid samples in this bin (avoids shadowing `count`)
    for (index_t ih = 0; ih < sample_per_part; ih++) {
      for (index_t iw = 0; iw < sample_per_part; iw++) {
        DType w = wstart + iw * sub_bin_size_w;
        DType h = hstart + ih * sub_bin_size_h;
        // Skip samples more than half a pixel outside the feature map;
        // clamp the rest onto it before bilinear interpolation.
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) {
          continue;
        }
        w = min(max(w, 0.), width - 1.);
        h = min(max(h, 0.), height - 1.);
        DType val = bilinear_interp(channel_data, w, h, width, height);
        sum += val;
        num_samples++;
      }
    }
    top_data[index] = num_samples == 0 ? static_cast<DType>(0) : sum / num_samples;
    top_count[index] = num_samples;
  }
}

/*!
 * \brief Launches DeformablePSROIPoolForwardKernel on out's stream.
 *
 * out and top_count receive one value per (roi, output channel, ph, pw).
 * When no_trans is true, trans is not read and a single class is assumed.
 */
template<typename DType>
inline void DeformablePSROIPoolForward(const Tensor<gpu, 4, DType> &out,
                                       const Tensor<gpu, 4, DType> &data,
                                       const Tensor<gpu, 2, DType> &bbox,
                                       const Tensor<gpu, 4, DType> &trans,
                                       const Tensor<gpu, 4, DType> &top_count,
                                       const bool no_trans, const float spatial_scale,
                                       const index_t output_dim, const index_t group_size,
                                       const index_t pooled_size, const index_t part_size,
                                       const index_t sample_per_part, const float trans_std) {
  const index_t count = out.shape_.Size();
  if (count == 0) {
    // Nothing to compute (e.g. zero ROIs). A zero-block launch would be
    // rejected by CUDA as an invalid configuration and trip the check below.
    return;
  }
  const DType *bottom_data = data.dptr_;
  const DType *bottom_rois = bbox.dptr_;
  const DType *bottom_trans = no_trans ? nullptr : trans.dptr_;
  DType *top_data = out.dptr_;
  DType *top_count_data = top_count.dptr_;
  const index_t channels = data.size(1);
  const index_t height = data.size(2);
  const index_t width = data.size(3);
  const index_t pooled_height = pooled_size;
  const index_t pooled_width = pooled_size;
  // trans channel axis holds an (x, y) offset pair per class.
  const index_t num_classes = no_trans ? 1 : trans.size(1) / 2;
  const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;

  cudaStream_t stream = Stream<gpu>::GetStream(out.stream_);
  DeformablePSROIPoolForwardKernel<DType><<<
    mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum,
    0, stream>>>(count, bottom_data, spatial_scale, channels, height, width,
                 pooled_height, pooled_width, bottom_rois, bottom_trans,
                 no_trans, trans_std, sample_per_part, output_dim,
                 group_size, part_size, num_classes,
                 channels_each_class, top_data, top_count_data);
  DeformablePSROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
}


/*!
 * \brief Backward pass: one loop iteration per output element, distributed
 *        over threads by CUDA_KERNEL_LOOP.
 *
 * Recomputes the forward sampling geometry and scatters
 * top_diff[index] / top_count[index] onto the bilinear neighbours in
 * bottom_data_diff. Unless no_trans is set, it also adds the gradient w.r.t.
 * the (x, y) offsets into bottom_trans_diff.
 *
 * All writes use atomicAdd, so the caller owns the initial contents of the
 * gradient buffers. num_rois is not read.
 */
template <typename DType>
__global__ void DeformablePSROIPoolBackwardAccKernel(const index_t count,
                                                     const DType* top_diff,
                                                     const DType* top_count,
                                                     const index_t num_rois,
                                                     const DType spatial_scale,
                                                     const index_t channels,
                                                     const index_t height,
                                                     const index_t width,
                                                     const index_t pooled_height,
                                                     const index_t pooled_width,
                                                     const index_t output_dim,
                                                     DType* bottom_data_diff,
                                                     DType* bottom_trans_diff,
                                                     const DType* bottom_data,
                                                     const DType* bottom_rois,
                                                     const DType* bottom_trans,
                                                     const bool no_trans,
                                                     const DType trans_std,
                                                     const index_t sample_per_part,
                                                     const index_t group_size,
                                                     const index_t part_size,
                                                     const index_t num_classes,
                                                     const index_t channels_each_class) {
  CUDA_KERNEL_LOOP(index, count) {
    // Bins with no valid forward sample contribute no gradient. Test this
    // first so their sampling geometry is never computed.
    if (top_count[index] <= 0) {
      continue;
    }

    // The output is in order (n, ctop, ph, pw)
    index_t pw = index % pooled_width;
    index_t ph = (index / pooled_width) % pooled_height;
    index_t ctop = (index / pooled_width / pooled_height) % output_dim;
    index_t n = index / pooled_width / pooled_height / output_dim;

    // Each ROI is 5 values: (batch_index, x_start, y_start, x_end, y_end);
    // coordinates are mapped onto the feature map by spatial_scale.
    const DType* offset_bottom_rois = bottom_rois + n * 5;
    index_t roi_batch_ind = offset_bottom_rois[0];
    DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;

    // Give degenerate ROIs a minimum extent of 0.1 so bin sizes are never 0.
    DType roi_width = max(roi_end_w - roi_start_w, 0.1);
    DType roi_height = max(roi_end_h - roi_start_h, 0.1);

    // Compute w and h at bottom
    DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
    DType bin_size_w = roi_width / static_cast<DType>(pooled_width);

    DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
    DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);

    // Locate the (part_h, part_w) cell of the offset map for this bin's class.
    index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
    index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
    index_t class_id = ctop / channels_each_class;
    const index_t trans_x_index = (((n * num_classes + class_id) * 2)
                                   * part_size + part_h) * part_size + part_w;
    const index_t trans_y_index = (((n * num_classes + class_id) * 2 + 1)
                                   * part_size + part_h) * part_size + part_w;
    DType trans_x = no_trans ? static_cast<DType>(0) :
                    bottom_trans[trans_x_index] * trans_std;
    DType trans_y = no_trans ? static_cast<DType>(0) :
                    bottom_trans[trans_y_index] * trans_std;

    DType wstart = static_cast<DType>(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    DType hstart = static_cast<DType>(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;

    // The forward pass averaged top_count[index] samples, so each sample gets
    // an equal share of the incoming gradient.
    DType diff_val = top_diff[index] / top_count[index];
    const DType* offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
    DType* offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;

    // Position-sensitive group cell, clamped to [0, group_size - 1].
    index_t gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
    index_t gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
    gw = min(max(gw, static_cast<index_t>(0)), group_size - 1);
    gh = min(max(gh, static_cast<index_t>(0)), group_size - 1);

    // The input channel depends only on (ctop, gh, gw); hoist it out of the loops.
    const index_t c = (ctop * group_size + gh) * group_size + gw;
    const index_t bottom_index_base = c * height * width;

    for (index_t ih = 0; ih < sample_per_part; ih++) {
      for (index_t iw = 0; iw < sample_per_part; iw++) {
        DType w = wstart + iw * sub_bin_size_w;
        DType h = hstart + ih * sub_bin_size_h;
        // Same validity test and clamping as the forward pass.
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) {
          continue;
        }
        w = min(max(w, 0.), width - 1.);
        h = min(max(h, 0.), height - 1.);

        // backward on feature: distribute diff_val with the bilinear weights
        index_t x0 = floor(w);
        index_t x1 = ceil(w);
        index_t y0 = floor(h);
        index_t y1 = ceil(h);
        DType dist_x = w - x0, dist_y = h - y0;
        DType q00 = (1 - dist_x) * (1 - dist_y);
        DType q01 = (1 - dist_x) * dist_y;
        DType q10 = dist_x * (1 - dist_y);
        DType q11 = dist_x * dist_y;
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);

        if (no_trans) {
          continue;
        }
        // backward on offsets: partial derivatives of the bilinear sample
        // w.r.t. x and y, chained through trans_std and the ROI extent.
        DType U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
        DType U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
        DType U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
        DType U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
        DType diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y);
        diff_x *= trans_std * diff_val * roi_width;
        DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x);
        diff_y *= trans_std * diff_val * roi_height;

        atomicAdd(bottom_trans_diff + trans_x_index, diff_x);
        atomicAdd(bottom_trans_diff + trans_y_index, diff_y);
      }
    }
  }
}


/*!
 * \brief Launches DeformablePSROIPoolBackwardAccKernel on in_grad's stream.
 *
 * Gradients are accumulated into in_grad and, unless no_trans, trans_grad.
 * top_count must be the sample counts produced by the forward pass.
 */
template<typename DType>
inline void DeformablePSROIPoolBackwardAcc(const Tensor<gpu, 4, DType> &in_grad,
                                           const Tensor<gpu, 4, DType> &trans_grad,
                                           const Tensor<gpu, 4, DType> &out_grad,
                                           const Tensor<gpu, 4, DType> &data,
                                           const Tensor<gpu, 2, DType> &bbox,
                                           const Tensor<gpu, 4, DType> &trans,
                                           const Tensor<gpu, 4, DType> &top_count,
                                           const bool no_trans, const float spatial_scale,
                                           const index_t output_dim, const index_t group_size,
                                           const index_t pooled_size, const index_t part_size,
                                           const index_t sample_per_part, const float trans_std) {
  const index_t count = out_grad.shape_.Size();
  if (count == 0) {
    // No output gradient to propagate (e.g. zero ROIs). A zero-block launch
    // would be rejected by CUDA as an invalid configuration.
    return;
  }
  const DType *top_diff = out_grad.dptr_;
  const DType *bottom_data = data.dptr_;
  const DType *bottom_rois = bbox.dptr_;
  const DType *bottom_trans = no_trans ? nullptr : trans.dptr_;
  DType *bottom_data_diff = in_grad.dptr_;
  DType *bottom_trans_diff = no_trans ? nullptr : trans_grad.dptr_;
  const DType *top_count_data = top_count.dptr_;
  const index_t num_rois = bbox.size(0);
  const index_t channels = in_grad.size(1);
  const index_t height = in_grad.size(2);
  const index_t width = in_grad.size(3);
  const index_t pooled_height = pooled_size;
  const index_t pooled_width = pooled_size;
  // trans_grad channel axis holds an (x, y) offset pair per class.
  const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2;
  const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;

  cudaStream_t stream = Stream<gpu>::GetStream(in_grad.stream_);
  DeformablePSROIPoolBackwardAccKernel<DType><<<
    mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum,
    0, stream >>>(count, top_diff, top_count_data, num_rois, spatial_scale,
                  channels, height, width, pooled_height, pooled_width,
                  output_dim, bottom_data_diff, bottom_trans_diff,
                  bottom_data, bottom_rois, bottom_trans,
                  no_trans, trans_std, sample_per_part, group_size,
                  part_size, num_classes, channels_each_class);
  DeformablePSROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
}

}  // namespace cuda

/*! \brief GPU entry point for the forward pass; forwards to cuda::DeformablePSROIPoolForward. */
template<typename DType>
inline void DeformablePSROIPoolForward(const Tensor<gpu, 4, DType> &out,
                                       const Tensor<gpu, 4, DType> &data,
                                       const Tensor<gpu, 2, DType> &bbox,
                                       const Tensor<gpu, 4, DType> &trans,
                                       const Tensor<gpu, 4, DType> &top_count,
                                       const bool no_trans, const float spatial_scale,
                                       const index_t output_dim, const index_t group_size,
                                       const index_t pooled_size, const index_t part_size,
                                       const index_t sample_per_part, const float trans_std) {
  cuda::DeformablePSROIPoolForward(out, data, bbox, trans, top_count,
                                   no_trans, spatial_scale, output_dim,
                                   group_size, pooled_size, part_size,
                                   sample_per_part, trans_std);
}

/*! \brief GPU entry point for gradient accumulation; forwards to cuda::DeformablePSROIPoolBackwardAcc. */
template<typename DType>
inline void DeformablePSROIPoolBackwardAcc(const Tensor<gpu, 4, DType> &in_grad,
                                           const Tensor<gpu, 4, DType> &trans_grad,
                                           const Tensor<gpu, 4, DType> &out_grad,
                                           const Tensor<gpu, 4, DType> &data,
                                           const Tensor<gpu, 2, DType> &bbox,
                                           const Tensor<gpu, 4, DType> &trans,
                                           const Tensor<gpu, 4, DType> &top_count,
                                           const bool no_trans, const float spatial_scale,
                                           const index_t output_dim, const index_t group_size,
                                           const index_t pooled_size, const index_t part_size,
                                           const index_t sample_per_part, const float trans_std) {
  cuda::DeformablePSROIPoolBackwardAcc(in_grad, trans_grad, out_grad, data, bbox,
                                       trans, top_count, no_trans, spatial_scale,
                                       output_dim, group_size, pooled_size,
                                       part_size, sample_per_part, trans_std);
}

}  // namespace mshadow


namespace mxnet {
namespace op {

/*! \brief Instantiates the GPU operator for any real dtype accepted by MSHADOW_REAL_TYPE_SWITCH. */
template<>
Operator* CreateOp<gpu>(DeformablePSROIPoolingParam param, int dtype) {
  Operator* op = nullptr;
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    op = new DeformablePSROIPoolingOp<gpu, DType>(param);
  });
  return op;
}

}  // namespace op
}  // namespace mxnet