1 /*!
2  ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
3  *
4  * COPYRIGHT
5  *
6  * All contributions by the University of California:
7  * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
8  * All rights reserved.
9  *
10  * All other contributions:
11  * Copyright (c) 2014-2017, the respective contributors
12  * All rights reserved.
13  *
14  * Caffe uses a shared copyright model: each contributor holds copyright over
15  * their contributions to Caffe. The project versioning records all such
16  * contribution and copyright details. If a contributor wants to further mark
17  * their specific copyright on a particular contribution, they should indicate
18  * their copyright solely in the commit message of the change when it is
19  * committed.
20  *
21  * LICENSE
22  *
23  * Redistribution and use in source and binary forms, with or without
24  * modification, are permitted provided that the following conditions are met:
25  *
26  * 1. Redistributions of source code must retain the above copyright notice, this
27  * list of conditions and the following disclaimer.
28  * 2. Redistributions in binary form must reproduce the above copyright notice,
29  * this list of conditions and the following disclaimer in the documentation
30  * and/or other materials provided with the distribution.
31  *
32  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
33  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
34  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
35  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
36  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
37  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
38  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
39  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
40  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
41  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
42  *
43  * CONTRIBUTION AGREEMENT
44  *
45  * By contributing to the BVLC/caffe repository through pull-request, comment,
46  * or otherwise, the contributor releases their content to the
47  * license and copyright terms herein.
48  *
49  ***************** END Caffe Copyright Notice and Disclaimer ********************
50  *
51  * Copyright (c) 2018 Microsoft
52  * Licensed under The MIT License [see LICENSE for details]
53  * \file deformable_im2col.h
54  * \brief Function definitions of converting an image to
55  * column matrix based on kernel, padding, dilation, and offset.
56  * These functions are mainly used in deformable convolution operators.
57  * \ref: https://arxiv.org/abs/1811.11168
58  * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
59  */
60 
61 #ifndef MXNET_OPERATOR_CONTRIB_NN_DEFORMABLE_IM2COL_H_
62 #define MXNET_OPERATOR_CONTRIB_NN_DEFORMABLE_IM2COL_H_
63 
64 #include <mxnet/base.h>
65 #include <mxnet/operator.h>
66 #include <cstring>
67 #include <vector>
68 #include <algorithm>
69 #include "../../mxnet_op.h"
70 
71 namespace mxnet {
72 namespace op {
73 
74 template <typename DType>
im2col_bilinear_cpu(const DType * data,const index_t height,const index_t width,DType h,DType w)75 inline DType im2col_bilinear_cpu(const DType* data,
76                                  const index_t height,
77                                  const index_t width,
78                                  DType h, DType w) {
79   index_t h_low = floor(h);
80   index_t w_low = floor(w);
81   index_t h_high;
82   index_t w_high;
83 
84   if (h_low >= height - 1) {
85     h_high = height - 1;
86     h = static_cast<DType>(h_low);
87   } else {
88     h_high = h_low + 1;
89   }
90 
91   if (w_low >= width - 1) {
92     w_high = width - 1;
93     w = static_cast<DType>(w_low);
94   } else {
95     w_high = w_low + 1;
96   }
97 
98   DType lh = h - h_low;
99   DType lw = w - w_low;
100   DType hh = 1 - lh, hw = 1 - lw;
101 
102   DType v1 = data[h_low * width + w_low];
103   DType v2 = data[h_low * width + w_high];
104   DType v3 = data[h_high * width + w_low];
105   DType v4 = data[h_high * width + w_high];
106   DType w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
107 
108   return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
109 }
110 
111 
112 template <typename DType>
get_gradient_weight_cpu(DType argmax_h,DType argmax_w,const index_t h,const index_t w,const index_t height,const index_t width)113 inline DType get_gradient_weight_cpu(DType argmax_h, DType argmax_w,
114                                     const index_t h, const index_t w,
115                                     const index_t height, const index_t width) {
116   if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) {
117     // empty
118     return 0;
119   }
120 
121   argmax_h = std::max(argmax_h, static_cast<DType>(0.0f));
122   argmax_w = std::max(argmax_w, static_cast<DType>(0.0f));
123 
124   index_t argmax_h_low = static_cast<index_t>(argmax_h);
125   index_t argmax_w_low = static_cast<index_t>(argmax_w);
126   index_t argmax_h_high;
127   index_t argmax_w_high;
128   if (argmax_h_low >= height - 1) {
129     argmax_h_high = argmax_h_low = height - 1;
130     argmax_h = static_cast<DType>(argmax_h_low);
131   } else {
132     argmax_h_high = argmax_h_low + 1;
133   }
134   if (argmax_w_low >= width - 1) {
135     argmax_w_high = argmax_w_low = width - 1;
136     argmax_w = static_cast<DType>(argmax_w_low);
137   } else {
138     argmax_w_high = argmax_w_low + 1;
139   }
140   DType weight = 0;
141   if (h == argmax_h_low) {
142     if (w == argmax_w_low) {
143       weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
144     } else if (w == argmax_w_high) {
145       weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
146     }
147   } else if (h == argmax_h_high) {
148     if (w == argmax_w_low) {
149       weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
150     } else if (w == argmax_w_high) {
151       weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
152     }
153   }
154   return weight;
155 }
156 
157 
158 template <typename DType>
get_coordinate_weight_cpu(DType argmax_h,DType argmax_w,const index_t height,const index_t width,const DType * im_data,const index_t data_width,const index_t bp_dir)159 inline DType get_coordinate_weight_cpu(DType argmax_h, DType argmax_w,
160                                        const index_t height, const index_t width,
161                                        const DType* im_data,
162                                        const index_t data_width, const index_t bp_dir) {
163   if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) {
164     // empty
165     return 0;
166   }
167 
168   if (argmax_h < 0) argmax_h = 0;
169   if (argmax_w < 0) argmax_w = 0;
170 
171   index_t argmax_h_low = static_cast<index_t>(argmax_h);
172   index_t argmax_w_low = static_cast<index_t>(argmax_w);
173   index_t argmax_h_high;
174   index_t argmax_w_high;
175   if (argmax_h_low >= height - 1) {
176     argmax_h_high = argmax_h_low = height - 1;
177     argmax_h = static_cast<DType>(argmax_h_low);
178   } else {
179     argmax_h_high = argmax_h_low + 1;
180   }
181   if (argmax_w_low >= width - 1) {
182     argmax_w_high = argmax_w_low = width - 1;
183     argmax_w = static_cast<DType>(argmax_w_low);
184   } else {
185     argmax_w_high = argmax_w_low + 1;
186   }
187 
188   DType weight = 0;
189   DType im_ll = im_data[argmax_h_low * data_width + argmax_w_low];
190   DType im_lh = im_data[argmax_h_low * data_width + argmax_w_high];
191   DType im_hl = im_data[argmax_h_high * data_width + argmax_w_low];
192   DType im_hh = im_data[argmax_h_high * data_width + argmax_w_high];
193   if (bp_dir == 0) {
194     weight += -1 * (argmax_w_low + 1 - argmax_w) * im_ll;
195     weight += -1 * (argmax_w - argmax_w_low) * im_lh;
196     weight += (argmax_w_low + 1 - argmax_w) * im_hl;
197     weight += (argmax_w - argmax_w_low) * im_hh;
198   } else if (bp_dir == 1) {
199     weight += -1 * (argmax_h_low + 1 - argmax_h) * im_ll;
200     weight += (argmax_h_low + 1 - argmax_h) * im_lh;
201     weight += -1 * (argmax_h - argmax_h_low) * im_hl;
202     weight += (argmax_h - argmax_h_low) * im_hh;
203   }
204 
205   return weight;
206 }
207 
208 
209 /*!
210  * \brief deformable_im2col 2D cpu version.
211  * DO NOT call this function directly.
212  * Use the wrapper function im2col() instead.
213  */
214 template <typename DType>
deformable_im2col_cpu(const DType * data_im,const DType * data_offset,const index_t channels,const index_t height,const index_t width,const index_t kernel_h,const index_t kernel_w,const index_t pad_h,const index_t pad_w,const index_t stride_h,const index_t stride_w,const index_t dilation_h,const index_t dilation_w,const index_t deformable_group,const index_t height_col,const index_t width_col,DType * data_col)215 inline void deformable_im2col_cpu(const DType* data_im,
216                                   const DType* data_offset,
217                                   const index_t channels,
218                                   const index_t height, const index_t width,
219                                   const index_t kernel_h, const index_t kernel_w,
220                                   const index_t pad_h, const index_t pad_w,
221                                   const index_t stride_h, const index_t stride_w,
222                                   const index_t dilation_h, const index_t dilation_w,
223                                   const index_t deformable_group,
224                                   const index_t height_col, const index_t width_col,
225                                   DType* data_col) {
226   const index_t channel_size = height * width;
227   const index_t offset_size = 2 * kernel_h * kernel_w * height_col * width_col;
228   const index_t channel_per_group = channels / deformable_group;
229   for (index_t channel = 0; channel < channels; channel++, data_im += channel_size) {
230     if (channel % channel_per_group == 0 && channel != 0) {
231       data_offset += offset_size;
232     }
233     for (index_t i = 0; i < kernel_h; i++) {
234       for (index_t j = 0; j < kernel_w; j++) {
235         index_t input_row = -pad_h + i * dilation_h;
236         for (index_t h_col = 0; h_col < height_col; h_col++) {
237           index_t input_col = -pad_w + j * dilation_w;
238           for (index_t w_col = 0; w_col < width_col; w_col++) {
239             index_t offset_h_ptr = ((2 * (i * kernel_w + j)) *
240               height_col + h_col) * width_col + w_col;
241             index_t offset_w_ptr = offset_h_ptr + height_col * width_col;
242             DType im_row = input_row + data_offset[offset_h_ptr];
243             DType im_col = input_col + data_offset[offset_w_ptr];
244             if (im_row >= 0 && im_col >= 0 && im_row < height && im_col < width) {
245               *(data_col++) = im2col_bilinear_cpu(data_im, height, width, im_row, im_col);
246             } else {
247               *(data_col++) = 0;
248             }
249             input_col += stride_w;
250           }
251           input_row += stride_h;
252         }
253       }
254     }
255   }
256 }
257 
258 
259 /*!\brief
260  * cpu function of deformable_im2col algorithm
261  * \param s device stream
262  * \param data_im pointer of an image (C, H, W, ...) in the image batch
263  * \param data_offset pointer of offset (C, H, W, ...) in the offset batch
264  * \param im_shape input image shape in dimensions (N, C, H, W,)
265  * \param col_shape column buffer shape (#channels, output_im_height, output_im_width, ...)
266  * \param kernel_shape kernel filter shape
267  * \param pad pad shape
268  * \param stride stride shape
269  * \param dilation dilation shape
270  * \param deformable_group #offset group that deformable convolution use
271  * \param data_col column buffer pointer
272  */
273 template <typename DType>
deformable_im2col(mshadow::Stream<cpu> * s,const DType * data_im,const DType * data_offset,const mxnet::TShape & im_shape,const mxnet::TShape & col_shape,const mxnet::TShape & kernel_shape,const mxnet::TShape & pad,const mxnet::TShape & stride,const mxnet::TShape & dilation,const index_t deformable_group,DType * data_col)274 inline void deformable_im2col(mshadow::Stream<cpu>* s,
275                               const DType* data_im, const DType* data_offset,
276                               const mxnet::TShape& im_shape,
277                               const mxnet::TShape& col_shape,
278                               const mxnet::TShape& kernel_shape,
279                               const mxnet::TShape& pad,
280                               const mxnet::TShape& stride,
281                               const mxnet::TShape& dilation,
282                               const index_t deformable_group,
283                               DType* data_col) {
284   if (2 == kernel_shape.ndim()) {
285     deformable_im2col_cpu(data_im, data_offset,
286                           im_shape[1], im_shape[2], im_shape[3],
287                           kernel_shape[0], kernel_shape[1],
288                           pad[0], pad[1],
289                           stride[0], stride[1],
290                           dilation[0], dilation[1],
291                           deformable_group,
292                           col_shape[1], col_shape[2], data_col);
293   } else {
294     LOG(FATAL) << "not implemented";
295   }
296 }
297 
298 
299 /*!
300  * \brief deformable_col2im cpu version.
301  * DO NOT call this directly.
302  * Use wrapper function deformable_col2im() instead;
303  */
304 template <typename DType>
deformable_col2im_cpu(const DType * data_col,const DType * data_offset,const index_t channels,const index_t height,const index_t width,const index_t kernel_h,const index_t kernel_w,const index_t pad_h,const index_t pad_w,const index_t stride_h,const index_t stride_w,const index_t dilation_h,const index_t dilation_w,const index_t deformable_group,const index_t height_col,const index_t width_col,DType * grad_im)305 inline void deformable_col2im_cpu(const DType* data_col,
306                                   const DType* data_offset, const index_t channels,
307                                   const index_t height, const index_t width,
308                                   const index_t kernel_h, const index_t kernel_w,
309                                   const index_t pad_h, const index_t pad_w,
310                                   const index_t stride_h, const index_t stride_w,
311                                   const index_t dilation_h, const index_t dilation_w,
312                                   const index_t deformable_group,
313                                   const index_t height_col, const index_t width_col,
314                                   DType* grad_im) {
315   index_t channel_per_group = channels / deformable_group;
316   index_t count = channels * kernel_h * kernel_w * height_col * width_col;
317   for (index_t index = 0; index < count; ++index) {
318     const index_t j = (index / width_col / height_col) % kernel_w;
319     const index_t i = (index / width_col / height_col / kernel_w) % kernel_h;
320     const index_t c = index / width_col / height_col / kernel_w / kernel_h;
321     // compute the start and end of the output
322 
323     const index_t group_index = c / channel_per_group;
324     const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col;
325 
326     index_t w_col = index % width_col;
327     index_t h_col = (index / width_col) % height_col;
328     index_t w_in = w_col * stride_w - pad_w;
329     index_t h_in = h_col * stride_h - pad_h;
330 
331     const DType* data_offset_ptr = data_offset + group_index * group_offset_step;
332     const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) *
333       height_col + h_col) * width_col + w_col;
334     const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col;
335     const DType offset_h = data_offset_ptr[data_offset_h_ptr];
336     const DType offset_w = data_offset_ptr[data_offset_w_ptr];
337     const DType cur_inv_h_data = h_in + i * dilation_h + offset_h;
338     const DType cur_inv_w_data = w_in + j * dilation_w + offset_w;
339 
340     const DType cur_top_grad = data_col[index];
341     const index_t cur_h = static_cast<index_t>(cur_inv_h_data);
342     const index_t cur_w = static_cast<index_t>(cur_inv_w_data);
343     for (int dy = -2; dy <= 2; dy++) {
344       for (int dx = -2; dx <= 2; dx++) {
345         if (cur_h + dy >= 0 && cur_h + dy < height &&
346           cur_w + dx >= 0 && cur_w + dx < width &&
347           std::abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
348           std::abs(cur_inv_w_data - (cur_w + dx)) < 1
349           ) {
350           index_t cur_bottom_grad_pos = (c * height + cur_h + dy) * width + cur_w + dx;
351           DType weight = get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data,
352                                                  cur_h + dy, cur_w + dx, height, width);
353           grad_im[cur_bottom_grad_pos] += weight * cur_top_grad;
354         }
355       }
356     }
357   }
358 }
359 
360 
361 /*!\brief
362  * cpu function of deformable_col2im algorithm
363  * \param s device stream
364  * \param data_col start pointer of the column buffer to be filled
365  * \param data_offset pointer of offset (C, H, W, ...) in the offset batch
366  * \param im_shape input image shape in dimensions (N, C, H, W,)
367  * \param col_shape column buffer shape
368  * \param kernel_shape kernel filter shape
369  * \param pad pad shape
370  * \param stride stride shape
371  * \param dilation dilation shape
372  * \param deformable_group #offset group that deformable convolution use
373  * \param grad_im pointer of a image (C, H, W,...) in the image batch
374  */
375 template <typename DType>
deformable_col2im(mshadow::Stream<cpu> * s,const DType * data_col,const DType * data_offset,const mxnet::TShape & im_shape,const mxnet::TShape & col_shape,const mxnet::TShape & kernel_shape,const mxnet::TShape & pad,const mxnet::TShape & stride,const mxnet::TShape & dilation,const index_t deformable_group,DType * grad_im)376 inline void deformable_col2im(mshadow::Stream<cpu>* s,
377                               const DType* data_col,
378                               const DType* data_offset,
379                               const mxnet::TShape& im_shape,
380                               const mxnet::TShape& col_shape,
381                               const mxnet::TShape& kernel_shape,
382                               const mxnet::TShape& pad,
383                               const mxnet::TShape& stride,
384                               const mxnet::TShape& dilation,
385                               const index_t deformable_group,
386                               DType* grad_im) {
387   if (2 == kernel_shape.ndim()) {
388     deformable_col2im_cpu(data_col, data_offset,
389                           im_shape[1], im_shape[2], im_shape[3],
390                           kernel_shape[0], kernel_shape[1],
391                           pad[0], pad[1], stride[0], stride[1],
392                           dilation[0], dilation[1],
393                           deformable_group,
394                           col_shape[1], col_shape[2], grad_im);
395   } else {
396     LOG(FATAL) << "not implemented";
397   }
398 }
399 
400 
401 /*!
402  * \brief deformable_col2im_coord cpu version.
403  * DO NOT call this directly.
404  * Use wrapper function deformable_col2im_coord() instead;
405  */
406 template <typename DType>
deformable_col2im_coord_cpu(const DType * data_col,const DType * data_im,const DType * data_offset,const index_t channels,const index_t height,const index_t width,const index_t kernel_h,const index_t kernel_w,const index_t pad_h,const index_t pad_w,const index_t stride_h,const index_t stride_w,const index_t dilation_h,const index_t dilation_w,const index_t deformable_group,const index_t height_col,const index_t width_col,DType * grad_offset)407 inline void deformable_col2im_coord_cpu(const DType* data_col,
408                                         const DType* data_im,
409                                         const DType* data_offset,
410                                         const index_t channels,
411                                         const index_t height, const index_t width,
412                                         const index_t kernel_h, const index_t kernel_w,
413                                         const index_t pad_h, const index_t pad_w,
414                                         const index_t stride_h, const index_t stride_w,
415                                         const index_t dilation_h, const index_t dilation_w,
416                                         const index_t deformable_group,
417                                         const index_t height_col, const index_t width_col,
418                                         DType* grad_offset) {
419   index_t channel_per_group = channels * kernel_h * kernel_w / deformable_group;
420   index_t count = height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
421   for (index_t index = 0; index < count; ++index) {
422     DType val = 0;
423     index_t w = index % width_col;
424     index_t h = (index / width_col) % height_col;
425     index_t c = index / width_col / height_col;
426     // compute the start and end of the output
427 
428     const index_t group_index = c / (2 * kernel_h * kernel_w);
429     const index_t group_col_step = channel_per_group * width_col * height_col;
430     const index_t group_im_step = channel_per_group / kernel_h / kernel_w * height * width;
431     const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col;
432     const index_t col_step = kernel_h * kernel_w;
433     const DType* data_col_ptr = data_col + group_index * group_col_step;
434     const DType* data_im_ptr = data_im + group_index * group_im_step;
435     const DType* data_offset_ptr = data_offset + group_index * group_offset_step;
436 
437     index_t cnt = 0;
438     const index_t offset_c = c - group_index * 2 * kernel_h * kernel_w;
439 
440     for (index_t col_c = (offset_c / 2); col_c < channel_per_group; col_c += col_step) {
441       const index_t col_pos = ((col_c * height_col) + h) * width_col + w;
442       const index_t bp_dir = offset_c % 2;
443 
444       index_t j = (col_pos / width_col / height_col) % kernel_w;
445       index_t i = (col_pos / width_col / height_col / kernel_w) % kernel_h;
446       index_t w_col = col_pos % width_col;
447       index_t h_col = (col_pos / width_col) % height_col;
448       index_t w_in = w_col * stride_w - pad_w;
449       index_t h_in = h_col * stride_h - pad_h;
450       const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) *
451         height_col + h_col) * width_col + w_col;
452       const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col;
453       const DType offset_h = data_offset_ptr[data_offset_h_ptr];
454       const DType offset_w = data_offset_ptr[data_offset_w_ptr];
455       DType inv_h = h_in + i * dilation_h + offset_h;
456       DType inv_w = w_in + j * dilation_w + offset_w;
457       if (inv_h < 0 || inv_w < 0 || inv_h >= height || inv_w >= width) {
458         inv_h = inv_w = -1;
459       }
460       const DType weight = get_coordinate_weight_cpu(inv_h, inv_w, height, width,
461                                                      data_im_ptr + cnt * height * width,
462                                                      width, bp_dir);
463       val += weight * data_col_ptr[col_pos];
464       cnt += 1;
465     }
466 
467     grad_offset[index] = val;
468   }
469 }
470 
471 
472 /*!\brief
473  * cpu function of deformable_col2im_coord algorithm
474  * \param s device stream
475  * \param data_col start pointer of the column buffer to be filled
476  * \param data_im pointer of an image (C, H, W, ...) in the image batch
477  * \param data_offset pointer of offset (C, H, W, ...) in the offset batch
478  * \param im_shape input image shape in dimensions (N, C, H, W,)
479  * \param col_shape column buffer shape
480  * \param kernel_shape kernel filter shape
481  * \param pad pad shape
482  * \param stride stride shape
483  * \param dilation dilation shape
484  * \param deformable_group #offset group that deformable convolution use
485  * \param grad_offset pointer of the offset (C, H, W,...) in the offset batch
486  */
487 template <typename DType>
deformable_col2im_coord(mshadow::Stream<cpu> * s,const DType * data_col,const DType * data_im,const DType * data_offset,const mxnet::TShape & im_shape,const mxnet::TShape & col_shape,const mxnet::TShape & kernel_shape,const mxnet::TShape & pad,const mxnet::TShape & stride,const mxnet::TShape & dilation,const index_t deformable_group,DType * grad_offset)488 inline void deformable_col2im_coord(mshadow::Stream<cpu>* s,
489                                     const DType* data_col,
490                                     const DType* data_im,
491                                     const DType* data_offset,
492                                     const mxnet::TShape& im_shape,
493                                     const mxnet::TShape& col_shape,
494                                     const mxnet::TShape& kernel_shape,
495                                     const mxnet::TShape& pad,
496                                     const mxnet::TShape& stride,
497                                     const mxnet::TShape& dilation,
498                                     const index_t deformable_group,
499                                     DType* grad_offset) {
500   if (2 == kernel_shape.ndim()) {
501     deformable_col2im_coord_cpu(data_col, data_im, data_offset,
502                                 im_shape[1], im_shape[2], im_shape[3],
503                                 kernel_shape[0], kernel_shape[1],
504                                 pad[0], pad[1], stride[0], stride[1],
505                                 dilation[0], dilation[1],
506                                 deformable_group,
507                                 col_shape[1], col_shape[2], grad_offset);
508   } else {
509     LOG(FATAL) << "not implemented";
510   }
511 }
512 
513 }  // namespace op
514 }  // namespace mxnet
515 #ifdef __CUDACC__
516 #include "./deformable_im2col.cuh"
517 #endif
518 #endif  // MXNET_OPERATOR_CONTRIB_NN_DEFORMABLE_IM2COL_H_
519