1 /*!
2 ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
3 *
4 * COPYRIGHT
5 *
6 * All contributions by the University of California:
7 * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
8 * All rights reserved.
9 *
10 * All other contributions:
11 * Copyright (c) 2014-2017, the respective contributors
12 * All rights reserved.
13 *
14 * Caffe uses a shared copyright model: each contributor holds copyright over
15 * their contributions to Caffe. The project versioning records all such
16 * contribution and copyright details. If a contributor wants to further mark
17 * their specific copyright on a particular contribution, they should indicate
18 * their copyright solely in the commit message of the change when it is
19 * committed.
20 *
21 * LICENSE
22 *
23 * Redistribution and use in source and binary forms, with or without
24 * modification, are permitted provided that the following conditions are met:
25 *
26 * 1. Redistributions of source code must retain the above copyright notice, this
27 * list of conditions and the following disclaimer.
28 * 2. Redistributions in binary form must reproduce the above copyright notice,
29 * this list of conditions and the following disclaimer in the documentation
30 * and/or other materials provided with the distribution.
31 *
32 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
33 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
34 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
35 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
36 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
37 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
38 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
39 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
40 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
41 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
42 *
43 * CONTRIBUTION AGREEMENT
44 *
45 * By contributing to the BVLC/caffe repository through pull-request, comment,
46 * or otherwise, the contributor releases their content to the
47 * license and copyright terms herein.
48 *
49 ***************** END Caffe Copyright Notice and Disclaimer ********************
50 *
51 * Copyright (c) 2018 Microsoft
52 * Licensed under The MIT License [see LICENSE for details]
53 * \file deformable_im2col.h
54 * \brief Function definitions of converting an image to
55 * column matrix based on kernel, padding, dilation, and offset.
56 * These functions are mainly used in deformable convolution operators.
57 * \ref: https://arxiv.org/abs/1811.11168
58 * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
59 */
60
61 #ifndef MXNET_OPERATOR_CONTRIB_NN_DEFORMABLE_IM2COL_H_
62 #define MXNET_OPERATOR_CONTRIB_NN_DEFORMABLE_IM2COL_H_
63
64 #include <mxnet/base.h>
65 #include <mxnet/operator.h>
66 #include <cstring>
67 #include <vector>
68 #include <algorithm>
69 #include "../../mxnet_op.h"
70
71 namespace mxnet {
72 namespace op {
73
74 template <typename DType>
im2col_bilinear_cpu(const DType * data,const index_t height,const index_t width,DType h,DType w)75 inline DType im2col_bilinear_cpu(const DType* data,
76 const index_t height,
77 const index_t width,
78 DType h, DType w) {
79 index_t h_low = floor(h);
80 index_t w_low = floor(w);
81 index_t h_high;
82 index_t w_high;
83
84 if (h_low >= height - 1) {
85 h_high = height - 1;
86 h = static_cast<DType>(h_low);
87 } else {
88 h_high = h_low + 1;
89 }
90
91 if (w_low >= width - 1) {
92 w_high = width - 1;
93 w = static_cast<DType>(w_low);
94 } else {
95 w_high = w_low + 1;
96 }
97
98 DType lh = h - h_low;
99 DType lw = w - w_low;
100 DType hh = 1 - lh, hw = 1 - lw;
101
102 DType v1 = data[h_low * width + w_low];
103 DType v2 = data[h_low * width + w_high];
104 DType v3 = data[h_high * width + w_low];
105 DType v4 = data[h_high * width + w_high];
106 DType w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
107
108 return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
109 }
110
111
112 template <typename DType>
get_gradient_weight_cpu(DType argmax_h,DType argmax_w,const index_t h,const index_t w,const index_t height,const index_t width)113 inline DType get_gradient_weight_cpu(DType argmax_h, DType argmax_w,
114 const index_t h, const index_t w,
115 const index_t height, const index_t width) {
116 if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) {
117 // empty
118 return 0;
119 }
120
121 argmax_h = std::max(argmax_h, static_cast<DType>(0.0f));
122 argmax_w = std::max(argmax_w, static_cast<DType>(0.0f));
123
124 index_t argmax_h_low = static_cast<index_t>(argmax_h);
125 index_t argmax_w_low = static_cast<index_t>(argmax_w);
126 index_t argmax_h_high;
127 index_t argmax_w_high;
128 if (argmax_h_low >= height - 1) {
129 argmax_h_high = argmax_h_low = height - 1;
130 argmax_h = static_cast<DType>(argmax_h_low);
131 } else {
132 argmax_h_high = argmax_h_low + 1;
133 }
134 if (argmax_w_low >= width - 1) {
135 argmax_w_high = argmax_w_low = width - 1;
136 argmax_w = static_cast<DType>(argmax_w_low);
137 } else {
138 argmax_w_high = argmax_w_low + 1;
139 }
140 DType weight = 0;
141 if (h == argmax_h_low) {
142 if (w == argmax_w_low) {
143 weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
144 } else if (w == argmax_w_high) {
145 weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
146 }
147 } else if (h == argmax_h_high) {
148 if (w == argmax_w_low) {
149 weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
150 } else if (w == argmax_w_high) {
151 weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
152 }
153 }
154 return weight;
155 }
156
157
158 template <typename DType>
get_coordinate_weight_cpu(DType argmax_h,DType argmax_w,const index_t height,const index_t width,const DType * im_data,const index_t data_width,const index_t bp_dir)159 inline DType get_coordinate_weight_cpu(DType argmax_h, DType argmax_w,
160 const index_t height, const index_t width,
161 const DType* im_data,
162 const index_t data_width, const index_t bp_dir) {
163 if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) {
164 // empty
165 return 0;
166 }
167
168 if (argmax_h < 0) argmax_h = 0;
169 if (argmax_w < 0) argmax_w = 0;
170
171 index_t argmax_h_low = static_cast<index_t>(argmax_h);
172 index_t argmax_w_low = static_cast<index_t>(argmax_w);
173 index_t argmax_h_high;
174 index_t argmax_w_high;
175 if (argmax_h_low >= height - 1) {
176 argmax_h_high = argmax_h_low = height - 1;
177 argmax_h = static_cast<DType>(argmax_h_low);
178 } else {
179 argmax_h_high = argmax_h_low + 1;
180 }
181 if (argmax_w_low >= width - 1) {
182 argmax_w_high = argmax_w_low = width - 1;
183 argmax_w = static_cast<DType>(argmax_w_low);
184 } else {
185 argmax_w_high = argmax_w_low + 1;
186 }
187
188 DType weight = 0;
189 DType im_ll = im_data[argmax_h_low * data_width + argmax_w_low];
190 DType im_lh = im_data[argmax_h_low * data_width + argmax_w_high];
191 DType im_hl = im_data[argmax_h_high * data_width + argmax_w_low];
192 DType im_hh = im_data[argmax_h_high * data_width + argmax_w_high];
193 if (bp_dir == 0) {
194 weight += -1 * (argmax_w_low + 1 - argmax_w) * im_ll;
195 weight += -1 * (argmax_w - argmax_w_low) * im_lh;
196 weight += (argmax_w_low + 1 - argmax_w) * im_hl;
197 weight += (argmax_w - argmax_w_low) * im_hh;
198 } else if (bp_dir == 1) {
199 weight += -1 * (argmax_h_low + 1 - argmax_h) * im_ll;
200 weight += (argmax_h_low + 1 - argmax_h) * im_lh;
201 weight += -1 * (argmax_h - argmax_h_low) * im_hl;
202 weight += (argmax_h - argmax_h_low) * im_hh;
203 }
204
205 return weight;
206 }
207
208
209 /*!
210 * \brief deformable_im2col 2D cpu version.
211 * DO NOT call this function directly.
212 * Use the wrapper function im2col() instead.
213 */
214 template <typename DType>
deformable_im2col_cpu(const DType * data_im,const DType * data_offset,const index_t channels,const index_t height,const index_t width,const index_t kernel_h,const index_t kernel_w,const index_t pad_h,const index_t pad_w,const index_t stride_h,const index_t stride_w,const index_t dilation_h,const index_t dilation_w,const index_t deformable_group,const index_t height_col,const index_t width_col,DType * data_col)215 inline void deformable_im2col_cpu(const DType* data_im,
216 const DType* data_offset,
217 const index_t channels,
218 const index_t height, const index_t width,
219 const index_t kernel_h, const index_t kernel_w,
220 const index_t pad_h, const index_t pad_w,
221 const index_t stride_h, const index_t stride_w,
222 const index_t dilation_h, const index_t dilation_w,
223 const index_t deformable_group,
224 const index_t height_col, const index_t width_col,
225 DType* data_col) {
226 const index_t channel_size = height * width;
227 const index_t offset_size = 2 * kernel_h * kernel_w * height_col * width_col;
228 const index_t channel_per_group = channels / deformable_group;
229 for (index_t channel = 0; channel < channels; channel++, data_im += channel_size) {
230 if (channel % channel_per_group == 0 && channel != 0) {
231 data_offset += offset_size;
232 }
233 for (index_t i = 0; i < kernel_h; i++) {
234 for (index_t j = 0; j < kernel_w; j++) {
235 index_t input_row = -pad_h + i * dilation_h;
236 for (index_t h_col = 0; h_col < height_col; h_col++) {
237 index_t input_col = -pad_w + j * dilation_w;
238 for (index_t w_col = 0; w_col < width_col; w_col++) {
239 index_t offset_h_ptr = ((2 * (i * kernel_w + j)) *
240 height_col + h_col) * width_col + w_col;
241 index_t offset_w_ptr = offset_h_ptr + height_col * width_col;
242 DType im_row = input_row + data_offset[offset_h_ptr];
243 DType im_col = input_col + data_offset[offset_w_ptr];
244 if (im_row >= 0 && im_col >= 0 && im_row < height && im_col < width) {
245 *(data_col++) = im2col_bilinear_cpu(data_im, height, width, im_row, im_col);
246 } else {
247 *(data_col++) = 0;
248 }
249 input_col += stride_w;
250 }
251 input_row += stride_h;
252 }
253 }
254 }
255 }
256 }
257
258
259 /*!\brief
260 * cpu function of deformable_im2col algorithm
261 * \param s device stream
262 * \param data_im pointer of an image (C, H, W, ...) in the image batch
263 * \param data_offset pointer of offset (C, H, W, ...) in the offset batch
264 * \param im_shape input image shape in dimensions (N, C, H, W,)
265 * \param col_shape column buffer shape (#channels, output_im_height, output_im_width, ...)
266 * \param kernel_shape kernel filter shape
267 * \param pad pad shape
268 * \param stride stride shape
269 * \param dilation dilation shape
270 * \param deformable_group #offset group that deformable convolution use
271 * \param data_col column buffer pointer
272 */
273 template <typename DType>
deformable_im2col(mshadow::Stream<cpu> * s,const DType * data_im,const DType * data_offset,const mxnet::TShape & im_shape,const mxnet::TShape & col_shape,const mxnet::TShape & kernel_shape,const mxnet::TShape & pad,const mxnet::TShape & stride,const mxnet::TShape & dilation,const index_t deformable_group,DType * data_col)274 inline void deformable_im2col(mshadow::Stream<cpu>* s,
275 const DType* data_im, const DType* data_offset,
276 const mxnet::TShape& im_shape,
277 const mxnet::TShape& col_shape,
278 const mxnet::TShape& kernel_shape,
279 const mxnet::TShape& pad,
280 const mxnet::TShape& stride,
281 const mxnet::TShape& dilation,
282 const index_t deformable_group,
283 DType* data_col) {
284 if (2 == kernel_shape.ndim()) {
285 deformable_im2col_cpu(data_im, data_offset,
286 im_shape[1], im_shape[2], im_shape[3],
287 kernel_shape[0], kernel_shape[1],
288 pad[0], pad[1],
289 stride[0], stride[1],
290 dilation[0], dilation[1],
291 deformable_group,
292 col_shape[1], col_shape[2], data_col);
293 } else {
294 LOG(FATAL) << "not implemented";
295 }
296 }
297
298
299 /*!
300 * \brief deformable_col2im cpu version.
301 * DO NOT call this directly.
302 * Use wrapper function deformable_col2im() instead;
303 */
304 template <typename DType>
deformable_col2im_cpu(const DType * data_col,const DType * data_offset,const index_t channels,const index_t height,const index_t width,const index_t kernel_h,const index_t kernel_w,const index_t pad_h,const index_t pad_w,const index_t stride_h,const index_t stride_w,const index_t dilation_h,const index_t dilation_w,const index_t deformable_group,const index_t height_col,const index_t width_col,DType * grad_im)305 inline void deformable_col2im_cpu(const DType* data_col,
306 const DType* data_offset, const index_t channels,
307 const index_t height, const index_t width,
308 const index_t kernel_h, const index_t kernel_w,
309 const index_t pad_h, const index_t pad_w,
310 const index_t stride_h, const index_t stride_w,
311 const index_t dilation_h, const index_t dilation_w,
312 const index_t deformable_group,
313 const index_t height_col, const index_t width_col,
314 DType* grad_im) {
315 index_t channel_per_group = channels / deformable_group;
316 index_t count = channels * kernel_h * kernel_w * height_col * width_col;
317 for (index_t index = 0; index < count; ++index) {
318 const index_t j = (index / width_col / height_col) % kernel_w;
319 const index_t i = (index / width_col / height_col / kernel_w) % kernel_h;
320 const index_t c = index / width_col / height_col / kernel_w / kernel_h;
321 // compute the start and end of the output
322
323 const index_t group_index = c / channel_per_group;
324 const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col;
325
326 index_t w_col = index % width_col;
327 index_t h_col = (index / width_col) % height_col;
328 index_t w_in = w_col * stride_w - pad_w;
329 index_t h_in = h_col * stride_h - pad_h;
330
331 const DType* data_offset_ptr = data_offset + group_index * group_offset_step;
332 const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) *
333 height_col + h_col) * width_col + w_col;
334 const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col;
335 const DType offset_h = data_offset_ptr[data_offset_h_ptr];
336 const DType offset_w = data_offset_ptr[data_offset_w_ptr];
337 const DType cur_inv_h_data = h_in + i * dilation_h + offset_h;
338 const DType cur_inv_w_data = w_in + j * dilation_w + offset_w;
339
340 const DType cur_top_grad = data_col[index];
341 const index_t cur_h = static_cast<index_t>(cur_inv_h_data);
342 const index_t cur_w = static_cast<index_t>(cur_inv_w_data);
343 for (int dy = -2; dy <= 2; dy++) {
344 for (int dx = -2; dx <= 2; dx++) {
345 if (cur_h + dy >= 0 && cur_h + dy < height &&
346 cur_w + dx >= 0 && cur_w + dx < width &&
347 std::abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
348 std::abs(cur_inv_w_data - (cur_w + dx)) < 1
349 ) {
350 index_t cur_bottom_grad_pos = (c * height + cur_h + dy) * width + cur_w + dx;
351 DType weight = get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data,
352 cur_h + dy, cur_w + dx, height, width);
353 grad_im[cur_bottom_grad_pos] += weight * cur_top_grad;
354 }
355 }
356 }
357 }
358 }
359
360
361 /*!\brief
362 * cpu function of deformable_col2im algorithm
363 * \param s device stream
364 * \param data_col start pointer of the column buffer to be filled
365 * \param data_offset pointer of offset (C, H, W, ...) in the offset batch
366 * \param im_shape input image shape in dimensions (N, C, H, W,)
367 * \param col_shape column buffer shape
368 * \param kernel_shape kernel filter shape
369 * \param pad pad shape
370 * \param stride stride shape
371 * \param dilation dilation shape
372 * \param deformable_group #offset group that deformable convolution use
373 * \param grad_im pointer of a image (C, H, W,...) in the image batch
374 */
375 template <typename DType>
deformable_col2im(mshadow::Stream<cpu> * s,const DType * data_col,const DType * data_offset,const mxnet::TShape & im_shape,const mxnet::TShape & col_shape,const mxnet::TShape & kernel_shape,const mxnet::TShape & pad,const mxnet::TShape & stride,const mxnet::TShape & dilation,const index_t deformable_group,DType * grad_im)376 inline void deformable_col2im(mshadow::Stream<cpu>* s,
377 const DType* data_col,
378 const DType* data_offset,
379 const mxnet::TShape& im_shape,
380 const mxnet::TShape& col_shape,
381 const mxnet::TShape& kernel_shape,
382 const mxnet::TShape& pad,
383 const mxnet::TShape& stride,
384 const mxnet::TShape& dilation,
385 const index_t deformable_group,
386 DType* grad_im) {
387 if (2 == kernel_shape.ndim()) {
388 deformable_col2im_cpu(data_col, data_offset,
389 im_shape[1], im_shape[2], im_shape[3],
390 kernel_shape[0], kernel_shape[1],
391 pad[0], pad[1], stride[0], stride[1],
392 dilation[0], dilation[1],
393 deformable_group,
394 col_shape[1], col_shape[2], grad_im);
395 } else {
396 LOG(FATAL) << "not implemented";
397 }
398 }
399
400
401 /*!
402 * \brief deformable_col2im_coord cpu version.
403 * DO NOT call this directly.
404 * Use wrapper function deformable_col2im_coord() instead;
405 */
406 template <typename DType>
deformable_col2im_coord_cpu(const DType * data_col,const DType * data_im,const DType * data_offset,const index_t channels,const index_t height,const index_t width,const index_t kernel_h,const index_t kernel_w,const index_t pad_h,const index_t pad_w,const index_t stride_h,const index_t stride_w,const index_t dilation_h,const index_t dilation_w,const index_t deformable_group,const index_t height_col,const index_t width_col,DType * grad_offset)407 inline void deformable_col2im_coord_cpu(const DType* data_col,
408 const DType* data_im,
409 const DType* data_offset,
410 const index_t channels,
411 const index_t height, const index_t width,
412 const index_t kernel_h, const index_t kernel_w,
413 const index_t pad_h, const index_t pad_w,
414 const index_t stride_h, const index_t stride_w,
415 const index_t dilation_h, const index_t dilation_w,
416 const index_t deformable_group,
417 const index_t height_col, const index_t width_col,
418 DType* grad_offset) {
419 index_t channel_per_group = channels * kernel_h * kernel_w / deformable_group;
420 index_t count = height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
421 for (index_t index = 0; index < count; ++index) {
422 DType val = 0;
423 index_t w = index % width_col;
424 index_t h = (index / width_col) % height_col;
425 index_t c = index / width_col / height_col;
426 // compute the start and end of the output
427
428 const index_t group_index = c / (2 * kernel_h * kernel_w);
429 const index_t group_col_step = channel_per_group * width_col * height_col;
430 const index_t group_im_step = channel_per_group / kernel_h / kernel_w * height * width;
431 const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col;
432 const index_t col_step = kernel_h * kernel_w;
433 const DType* data_col_ptr = data_col + group_index * group_col_step;
434 const DType* data_im_ptr = data_im + group_index * group_im_step;
435 const DType* data_offset_ptr = data_offset + group_index * group_offset_step;
436
437 index_t cnt = 0;
438 const index_t offset_c = c - group_index * 2 * kernel_h * kernel_w;
439
440 for (index_t col_c = (offset_c / 2); col_c < channel_per_group; col_c += col_step) {
441 const index_t col_pos = ((col_c * height_col) + h) * width_col + w;
442 const index_t bp_dir = offset_c % 2;
443
444 index_t j = (col_pos / width_col / height_col) % kernel_w;
445 index_t i = (col_pos / width_col / height_col / kernel_w) % kernel_h;
446 index_t w_col = col_pos % width_col;
447 index_t h_col = (col_pos / width_col) % height_col;
448 index_t w_in = w_col * stride_w - pad_w;
449 index_t h_in = h_col * stride_h - pad_h;
450 const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) *
451 height_col + h_col) * width_col + w_col;
452 const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col;
453 const DType offset_h = data_offset_ptr[data_offset_h_ptr];
454 const DType offset_w = data_offset_ptr[data_offset_w_ptr];
455 DType inv_h = h_in + i * dilation_h + offset_h;
456 DType inv_w = w_in + j * dilation_w + offset_w;
457 if (inv_h < 0 || inv_w < 0 || inv_h >= height || inv_w >= width) {
458 inv_h = inv_w = -1;
459 }
460 const DType weight = get_coordinate_weight_cpu(inv_h, inv_w, height, width,
461 data_im_ptr + cnt * height * width,
462 width, bp_dir);
463 val += weight * data_col_ptr[col_pos];
464 cnt += 1;
465 }
466
467 grad_offset[index] = val;
468 }
469 }
470
471
472 /*!\brief
473 * cpu function of deformable_col2im_coord algorithm
474 * \param s device stream
475 * \param data_col start pointer of the column buffer to be filled
476 * \param data_im pointer of an image (C, H, W, ...) in the image batch
477 * \param data_offset pointer of offset (C, H, W, ...) in the offset batch
478 * \param im_shape input image shape in dimensions (N, C, H, W,)
479 * \param col_shape column buffer shape
480 * \param kernel_shape kernel filter shape
481 * \param pad pad shape
482 * \param stride stride shape
483 * \param dilation dilation shape
484 * \param deformable_group #offset group that deformable convolution use
485 * \param grad_offset pointer of the offset (C, H, W,...) in the offset batch
486 */
487 template <typename DType>
deformable_col2im_coord(mshadow::Stream<cpu> * s,const DType * data_col,const DType * data_im,const DType * data_offset,const mxnet::TShape & im_shape,const mxnet::TShape & col_shape,const mxnet::TShape & kernel_shape,const mxnet::TShape & pad,const mxnet::TShape & stride,const mxnet::TShape & dilation,const index_t deformable_group,DType * grad_offset)488 inline void deformable_col2im_coord(mshadow::Stream<cpu>* s,
489 const DType* data_col,
490 const DType* data_im,
491 const DType* data_offset,
492 const mxnet::TShape& im_shape,
493 const mxnet::TShape& col_shape,
494 const mxnet::TShape& kernel_shape,
495 const mxnet::TShape& pad,
496 const mxnet::TShape& stride,
497 const mxnet::TShape& dilation,
498 const index_t deformable_group,
499 DType* grad_offset) {
500 if (2 == kernel_shape.ndim()) {
501 deformable_col2im_coord_cpu(data_col, data_im, data_offset,
502 im_shape[1], im_shape[2], im_shape[3],
503 kernel_shape[0], kernel_shape[1],
504 pad[0], pad[1], stride[0], stride[1],
505 dilation[0], dilation[1],
506 deformable_group,
507 col_shape[1], col_shape[2], grad_offset);
508 } else {
509 LOG(FATAL) << "not implemented";
510 }
511 }
512
513 } // namespace op
514 } // namespace mxnet
515 #ifdef __CUDACC__
516 #include "./deformable_im2col.cuh"
517 #endif
518 #endif // MXNET_OPERATOR_CONTRIB_NN_DEFORMABLE_IM2COL_H_
519