/*!
 ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
 *
 * COPYRIGHT
 *
 * All contributions by the University of California:
 * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
 * All rights reserved.
 *
 * All other contributions:
 * Copyright (c) 2014-2017, the respective contributors
 * All rights reserved.
 *
 * Caffe uses a shared copyright model: each contributor holds copyright over
 * their contributions to Caffe. The project versioning records all such
 * contribution and copyright details. If a contributor wants to further mark
 * their specific copyright on a particular contribution, they should indicate
 * their copyright solely in the commit message of the change when it is
 * committed.
 *
 * LICENSE
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * CONTRIBUTION AGREEMENT
 *
 * By contributing to the BVLC/caffe repository through pull-request, comment,
 * or otherwise, the contributor releases their content to the
 * license and copyright terms herein.
 *
 ***************** END Caffe Copyright Notice and Disclaimer ********************
 *
 * Copyright (c) 2018 Microsoft
 * Licensed under The MIT License [see LICENSE for details]
 * \file modulated_deformable_im2col.cuh
 * \brief Function definitions for converting an image to a
 * column matrix based on kernel, padding, dilation, and offset.
 * These functions are mainly used in modulated deformable convolution operators.
 * \ref https://arxiv.org/abs/1811.11168
 * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
 */

#ifndef MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_CUH_
#define MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_CUH_

#include <mxnet/base.h>
#include <mxnet/operator.h>
#include <algorithm>
#include <cstring>
#include <vector>
#include "../../mxnet_op.h"
#include "../../../common/cuda_utils.h"


namespace mxnet {
namespace op {

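/*!
 * \brief Bilinearly sample bottom_data (one H x W channel plane with row
 * stride data_width) at the fractional location (h, w). Corner taps that
 * fall outside the plane contribute zero.
 */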
template <typename DType>
__device__ DType dmcn_im2col_bilinear(const DType* bottom_data, const int data_width,
    const int height, const int width, DType h, DType w) {
  int h_low = floor(h);
  int w_low = floor(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  DType lh = h - h_low;
  DType lw = w - w_low;
  DType hh = 1 - lh, hw = 1 - lw;

  DType v1 = 0;
  if (h_low >= 0 && w_low >= 0)
    v1 = bottom_data[h_low * data_width + w_low];
  DType v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = bottom_data[h_low * data_width + w_high];
  DType v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = bottom_data[h_high * data_width + w_low];
  DType v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = bottom_data[h_high * data_width + w_high];

  DType w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  DType val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}


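/*!
 * \brief Weight that the input pixel (h, w) received in the bilinear sample
 * taken at (argmax_h, argmax_w); zero if (h, w) is not one of the four
 * sampled corners. Used to route the output gradient back to the image.
 */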
template <typename DType>
__device__ DType dmcn_get_gradient_weight(DType argmax_h, DType argmax_w,
    const int h, const int w, const int height, const int width) {
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {
    // empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  DType weight = 0;
  if (h == argmax_h_low && w == argmax_w_low)
    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
  if (h == argmax_h_low && w == argmax_w_high)
    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
  if (h == argmax_h_high && w == argmax_w_low)
    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
  if (h == argmax_h_high && w == argmax_w_high)
    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
  return weight;
}


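/*!
 * \brief Derivative of the bilinear sample at (argmax_h, argmax_w) with
 * respect to the sampling coordinate itself: the h coordinate when
 * bp_dir == 0 and the w coordinate when bp_dir == 1. Used to route the
 * output gradient back to the learned offsets.
 */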
template <typename DType>
__device__ DType dmcn_get_coordinate_weight(DType argmax_h, DType argmax_w,
    const int height, const int width, const DType* im_data,
    const int data_width, const int bp_dir) {
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {
    // empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  DType weight = 0;

  if (bp_dir == 0) {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  } else if (bp_dir == 1) {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }

  return weight;
}


/*!
 * \brief modulated_deformable_im2col gpu kernel.
 * DO NOT call this directly. Use the wrapper function modulated_deformable_im2col() instead.
 */
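// Each thread handles one (c_im, b_col, h_col, w_col) position: for every
// kernel tap (i, j) it reads the learned offset and mask, bilinearly samples
// the input at the displaced location, and writes the modulated value. The
// resulting column buffer is laid out as
// (num_channels * kernel_h * kernel_w, batch_size, height_col, width_col).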
template <typename DType>
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
    const DType* data_im, const DType* data_offset, const DType* data_mask,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int num_channels, const int deformable_group,
    const int height_col, const int width_col,
    DType* data_col) {
  CUDA_KERNEL_LOOP(index, n) {
    // index of the output matrix element
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    const int b_col = (index / width_col / height_col) % batch_size;
    const int c_im = (index / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    DType* data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    // const DType* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
    const DType* data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    const DType* data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const DType* data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    for (int i = 0; i < kernel_h; ++i) {
      for (int j = 0; j < kernel_w; ++j) {
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
        const DType offset_h = data_offset_ptr[data_offset_h_ptr];
        const DType offset_w = data_offset_ptr[data_offset_w_ptr];
        const DType mask = data_mask_ptr[data_mask_hw_ptr];
        DType val = static_cast<DType>(0);
        const DType h_im = h_in + i * dilation_h + offset_h;
        const DType w_im = w_in + j * dilation_w + offset_w;
        // if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
          // const DType map_h = i * dilation_h + offset_h;
          // const DType map_w = j * dilation_w + offset_w;
          // const int cur_height = height - h_in;
          // const int cur_width = width - w_in;
          // val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
          val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val * mask;
        data_col_ptr += batch_size * height_col * width_col;
        // data_col_ptr += height_col * width_col;
      }
    }
  }
}


/*!\brief
 * gpu function of the modulated deformable_im2col algorithm
 * \param s device stream
 * \param data_im pointer of an image (N, C, H, W, ...) in the image batch
 * \param data_offset pointer of offsets (N, deformable_group*kernel_h*kernel_w*2, H, W, ...) in the offset batch
 * \param data_mask pointer of masks (N, deformable_group*kernel_h*kernel_w, H, W, ...) in the mask batch
 * \param im_shape input image shape in dimensions (N, C, H, W,)
 * \param col_shape column buffer shape (#channels, N, output_im_height, output_im_width, ...)
 * \param kernel_shape kernel filter shape
 * \param pad pad shape
 * \param stride stride shape
 * \param dilation dilation shape
 * \param deformable_group #offset groups that deformable convolution uses
 * \param data_col column buffer pointer
 */
template <typename DType>
inline void modulated_deformable_im2col(mshadow::Stream<gpu>* s,
    const DType* data_im, const DType* data_offset, const DType* data_mask,
    const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape,
    const TShape& pad, const TShape& stride, const TShape& dilation,
    const uint32_t deformable_group, DType* data_col) {
  // num_axes should be smaller than block size
  index_t num_spatial_axes = kernel_shape.ndim();
  CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum);
  index_t channel_per_deformable_group = im_shape[1] / deformable_group;
  index_t num_kernels = im_shape[1] * col_shape.ProdShape(1, col_shape.ndim());
  using namespace mxnet_op;
  switch (num_spatial_axes) {
  case 2:
    modulated_deformable_im2col_gpu_kernel<DType>  // NOLINT_NEXT_LINE(whitespace/operators)
        <<<cuda_get_num_blocks(num_kernels), mshadow::cuda::kBaseThreadNum,
           0, mshadow::Stream<gpu>::GetStream(s)>>>(
        num_kernels, data_im, data_offset, data_mask, im_shape[2], im_shape[3], kernel_shape[0], kernel_shape[1],
        pad[0], pad[1], stride[0], stride[1], dilation[0], dilation[1], channel_per_deformable_group,
        col_shape[1], im_shape[1], deformable_group, col_shape[2], col_shape[3], data_col);
    MSHADOW_CUDA_POST_KERNEL_CHECK(modulated_deformable_im2col_gpu_kernel);
    break;
  default:
    LOG(FATAL) << "im2col_nd_gpu does not support computation with "
               << num_spatial_axes << " spatial axes";
  }
}
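
// A minimal forward-pass sketch. The tensors and the `param` fields below are
// hypothetical caller-side names; only the shapes documented above are assumed:
//
//   // im:     (N, C, H, W)                                        input batch
//   // offset: (N, deformable_group*kernel_h*kernel_w*2, H_col, W_col)
//   // mask:   (N, deformable_group*kernel_h*kernel_w, H_col, W_col)
//   // col:    (C*kernel_h*kernel_w, N, H_col, W_col)              column buffer
//   modulated_deformable_im2col(s, im.dptr_, offset.dptr_, mask.dptr_,
//       im.shape_, col.shape_, param.kernel, param.pad, param.stride,
//       param.dilate, param.num_deformable_group, col.dptr_);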


/*!
 * \brief deformable_col2im gpu kernel.
 * DO NOT call this directly. Use the wrapper function modulated_deformable_col2im() instead.
 */
template <typename DType>
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
    const DType* data_col, const DType* data_offset, const DType* data_mask,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int deformable_group,
    const int height_col, const int width_col,
    DType* grad_im, OpReqType req) {
  CUDA_KERNEL_LOOP(index, n) {
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;

    // compute deformable group index
    const int deformable_group_index = c / channel_per_deformable_group;

    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;

    const DType* data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const DType* data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
    const DType offset_h = data_offset_ptr[data_offset_h_ptr];
    const DType offset_w = data_offset_ptr[data_offset_w_ptr];
    const DType mask = data_mask_ptr[data_mask_hw_ptr];
    const DType cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const DType cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const DType cur_top_grad = data_col[index] * mask;
    const int cur_h = static_cast<int>(cur_inv_h_data);
    const int cur_w = static_cast<int>(cur_inv_w_data);
    for (int dy = -2; dy <= 2; dy++) {
      for (int dx = -2; dx <= 2; dx++) {
        if (cur_h + dy >= 0 && cur_h + dy < height &&
            cur_w + dx >= 0 && cur_w + dx < width &&
            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
            abs(cur_inv_w_data - (cur_w + dx)) < 1) {
          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
          DType weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
        }
      }
    }
  }
}


/*!\brief
 * gpu function of the modulated deformable_col2im algorithm
 * \param s device stream
 * \param data_col pointer of the column buffer holding the output gradient
 * \param data_offset pointer of offsets (N, deformable_group*kernel_h*kernel_w*2, H, W, ...) in the offset batch
 * \param data_mask pointer of masks (N, deformable_group*kernel_h*kernel_w, H, W, ...) in the mask batch
 * \param im_shape input image shape in dimensions (N, C, H, W,)
 * \param col_shape column buffer shape
 * \param kernel_shape kernel filter shape
 * \param pad pad shape
 * \param stride stride shape
 * \param dilation dilation shape
 * \param deformable_group #offset groups that deformable convolution uses
 * \param grad_im pointer of image gradients (N, C, H, W, ...) in the image batch
 * \param req write/add request type for grad_im
 */
template <typename DType>
inline void modulated_deformable_col2im(mshadow::Stream<gpu>* s,
    const DType* data_col, const DType* data_offset, const DType* data_mask,
    const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape,
    const TShape& pad, const TShape& stride,
    const TShape& dilation, const uint32_t deformable_group,
    DType* grad_im, OpReqType req) {
  index_t num_spatial_axes = kernel_shape.ndim();
  index_t im_size = im_shape.ProdShape(1, im_shape.ndim());
  index_t channel_per_deformable_group = im_shape[1] / deformable_group;
  index_t num_kernels = col_shape.ProdShape(0, col_shape.ndim());
  // num_axes should be smaller than block size
  CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum);
  using namespace mxnet_op;
  switch (num_spatial_axes) {
  case 2:
    // Each thread handles one column entry and scatters its contribution
    // into grad_im with atomicAdd.
    // NOLINT_NEXT_LINE(whitespace/operators)
    modulated_deformable_col2im_gpu_kernel<DType><<<cuda_get_num_blocks(num_kernels), mshadow::cuda::kBaseThreadNum,
        0, mshadow::Stream<gpu>::GetStream(s)>>>(
        num_kernels, data_col, data_offset, data_mask, im_shape[1], im_shape[2], im_shape[3],
        kernel_shape[0], kernel_shape[1], pad[0], pad[1], stride[0], stride[1],
        dilation[0], dilation[1], channel_per_deformable_group,
        col_shape[1], deformable_group, col_shape[2], col_shape[3], grad_im, req);
    MSHADOW_CUDA_POST_KERNEL_CHECK(modulated_deformable_col2im_gpu_kernel);
    break;
  default:
    LOG(FATAL) << "col2im_nd_gpu does not support computation with "
               << num_spatial_axes << " spatial axes";
  }
}
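
// Note: the col2im kernel above scatters with atomicAdd and does not consult
// req inside the kernel, so grad_im is accumulated in place; for a kWriteTo
// request the caller is expected to zero grad_im before this call.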


/*!
 * \brief deformable_col2im_coord gpu kernel.
 * DO NOT call this directly. Use the wrapper function modulated_deformable_col2im_coord() instead.
 */
template <typename DType>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
    const DType* data_col, const DType* data_im,
    const DType* data_offset, const DType* data_mask,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int offset_channels, const int deformable_group,
    const int height_col, const int width_col,
    DType* grad_offset, DType* grad_mask, OpReqType offset_req, OpReqType mask_req) {
  CUDA_KERNEL_LOOP(index, n) {
    DType val = 0, mval = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;

    // compute deformable group index
    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;
    const DType* data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
    const DType* data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const DType* data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const DType* data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;

      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
      const DType offset_h = data_offset_ptr[data_offset_h_ptr];
      const DType offset_w = data_offset_ptr[data_offset_w_ptr];
      const DType mask = data_mask_ptr[data_mask_hw_ptr];
      DType inv_h = h_in + i * dilation_h + offset_h;
      DType inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
        inv_h = inv_w = -2;
      } else {
        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
      }
      const DType weight = dmcn_get_coordinate_weight(
        inv_h, inv_w,
        height, width, data_im_ptr + cnt * height * width, width, bp_dir);
      val += weight * data_col_ptr[col_pos] * mask;
      cnt += 1;
    }

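    // Each kernel tap owns one mask channel but two offset channels (h, w),
    // so only the even (h-direction) offset channel writes the mask gradient.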
    // grad_offset[index] = val;
    KERNEL_ASSIGN(grad_offset[index], offset_req, val);
    if (offset_c % 2 == 0)
      KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
  }
}

/*!\brief
 * gpu function of the modulated deformable_col2im_coord algorithm
 * \param s device stream
 * \param data_col pointer of the column buffer holding the output gradient
 * \param data_im pointer of an image (N, C, H, W, ...) in the image batch
 * \param data_offset pointer of offsets (N, deformable_group*kernel_h*kernel_w*2, H, W, ...) in the offset batch
 * \param data_mask pointer of masks (N, deformable_group*kernel_h*kernel_w, H, W, ...) in the mask batch
 * \param im_shape input image shape in dimensions (N, C, H, W,)
 * \param col_shape column buffer shape
 * \param kernel_shape kernel filter shape
 * \param pad pad shape
 * \param stride stride shape
 * \param dilation dilation shape
 * \param deformable_group #offset groups that deformable convolution uses
 * \param grad_offset pointer of offset gradients (N, deformable_group*kernel_h*kernel_w*2, H, W, ...) in the offset batch
 * \param grad_mask pointer of mask gradients (N, deformable_group*kernel_h*kernel_w, H, W, ...) in the mask batch
 * \param offset_req write/add request type for grad_offset
 * \param mask_req write/add request type for grad_mask
 */
template <typename DType>
inline void modulated_deformable_col2im_coord(mshadow::Stream<gpu>* s,
    const DType* data_col, const DType* data_im, const DType* data_offset, const DType* data_mask,
    const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape,
    const TShape& pad, const TShape& stride,
    const TShape& dilation, const uint32_t deformable_group,
    DType* grad_offset, DType* grad_mask, OpReqType offset_req, OpReqType mask_req) {
  index_t num_spatial_axes = kernel_shape.ndim();
  index_t num_kernels = col_shape[1] * col_shape[2] * col_shape[3] * 2 * kernel_shape[0] * kernel_shape[1] * deformable_group;
  index_t channel_per_deformable_group = col_shape[0] / deformable_group;
  // num_axes should be smaller than block size
  CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum);
  using namespace mxnet_op;
  switch (num_spatial_axes) {
  case 2:
    // Each thread owns one offset/mask entry and accumulates over all column
    // entries that sampled with it, so no atomic operations are needed.
    // NOLINT_NEXT_LINE(whitespace/operators)
    modulated_deformable_col2im_coord_gpu_kernel<DType><<<cuda_get_num_blocks(num_kernels), mshadow::cuda::kBaseThreadNum,
        0, mshadow::Stream<gpu>::GetStream(s)>>>(
        num_kernels, data_col, data_im, data_offset, data_mask, im_shape[1], im_shape[2], im_shape[3],
        kernel_shape[0], kernel_shape[1], pad[0], pad[1], stride[0], stride[1],
        dilation[0], dilation[1], channel_per_deformable_group,
        col_shape[1], 2 * kernel_shape[0] * kernel_shape[1] * deformable_group, deformable_group, col_shape[2], col_shape[3],
        grad_offset, grad_mask, offset_req, mask_req);
    MSHADOW_CUDA_POST_KERNEL_CHECK(modulated_deformable_col2im_coord_gpu_kernel);
    break;
  default:
    LOG(FATAL) << "col2im_nd_gpu does not support computation with "
               << num_spatial_axes << " spatial axes";
  }
}
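
// A minimal backward-pass sketch. The tensors, `param` fields, and req
// indices below are hypothetical caller-side names:
//
//   // gradients w.r.t. offset and mask
//   modulated_deformable_col2im_coord(s, col_grad.dptr_, im.dptr_,
//       offset.dptr_, mask.dptr_, im.shape_, col_grad.shape_, param.kernel,
//       param.pad, param.stride, param.dilate, param.num_deformable_group,
//       offset_grad.dptr_, mask_grad.dptr_, req[1], req[2]);
//   // gradient w.r.t. the input image
//   modulated_deformable_col2im(s, col_grad.dptr_, offset.dptr_, mask.dptr_,
//       im.shape_, col_grad.shape_, param.kernel, param.pad, param.stride,
//       param.dilate, param.num_deformable_group, im_grad.dptr_, req[0]);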


}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_CUH_