1 /*!
2  ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
3  *
4  * COPYRIGHT
5  *
6  * All contributions by the University of California:
7  * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
8  * All rights reserved.
9  *
10  * All other contributions:
11  * Copyright (c) 2014-2017, the respective contributors
12  * All rights reserved.
13  *
14  * Caffe uses a shared copyright model: each contributor holds copyright over
15  * their contributions to Caffe. The project versioning records all such
16  * contribution and copyright details. If a contributor wants to further mark
17  * their specific copyright on a particular contribution, they should indicate
18  * their copyright solely in the commit message of the change when it is
19  * committed.
20  *
21  * LICENSE
22  *
23  * Redistribution and use in source and binary forms, with or without
24  * modification, are permitted provided that the following conditions are met:
25  *
26  * 1. Redistributions of source code must retain the above copyright notice, this
27  * list of conditions and the following disclaimer.
28  * 2. Redistributions in binary form must reproduce the above copyright notice,
29  * this list of conditions and the following disclaimer in the documentation
30  * and/or other materials provided with the distribution.
31  *
32  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
33  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
34  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
35  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
36  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
37  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
38  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
39  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
40  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
41  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
42  *
43  * CONTRIBUTION AGREEMENT
44  *
45  * By contributing to the BVLC/caffe repository through pull-request, comment,
46  * or otherwise, the contributor releases their content to the
47  * license and copyright terms herein.
48  *
49  ***************** END Caffe Copyright Notice and Disclaimer ********************
50  *
51  * Copyright (c) 2018 Microsoft
52  * Licensed under The MIT License [see LICENSE for details]
53  * \file modulated_deformable_im2col.h
54  * \brief Function definitions of converting an image to
55  * column matrix based on kernel, padding, dilation, and offset.
56  * These functions are mainly used in deformable convolution operators.
57  * \ref: https://arxiv.org/abs/1811.11168
58  * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
59  */
60 
61 #ifndef MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_H_
62 #define MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_H_
63 
64 #include <mxnet/base.h>
65 #include <mxnet/operator.h>
66 #include <cstring>
67 #include <vector>
68 #include <cmath>
69 #include "../../mxnet_op.h"
70 
71 namespace mxnet {
72 namespace op {
73 
template <typename DType>
inline DType dmcn_im2col_bilinear_cpu(const DType* bottom_data, const int data_width,
  const int height, const int width, DType h, DType w) {
  // Bilinearly sample a (height x width) image (row stride data_width) at the
  // fractional coordinate (h, w).  Any of the four neighbouring pixels that
  // falls outside the image contributes zero to the interpolation.
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional parts and their complements: the bilinear weights.
  const DType lh = h - h_low;
  const DType lw = w - w_low;
  const DType hh = 1 - lh;
  const DType hw = 1 - lw;

  // Which of the four neighbours lie inside the image?
  const bool top_in = h_low >= 0;
  const bool bottom_in = h_high <= height - 1;
  const bool left_in = w_low >= 0;
  const bool right_in = w_high <= width - 1;

  DType v1 = 0, v2 = 0, v3 = 0, v4 = 0;
  if (top_in && left_in)
    v1 = bottom_data[h_low * data_width + w_low];
  if (top_in && right_in)
    v2 = bottom_data[h_low * data_width + w_high];
  if (bottom_in && left_in)
    v3 = bottom_data[h_high * data_width + w_low];
  if (bottom_in && right_in)
    v4 = bottom_data[h_high * data_width + w_high];

  // Weighted combination of the four neighbours.
  return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4;
}
104 
105 /*!
106 * \brief deformable_col2im gpu kernel.
107 * \brief DO NOT call this directly. Use wrapper function deformable_col2im() instead;
108 */
109 struct modulated_deformable_col2im_cpu_kernel {
110   template<typename DType>
Mapmodulated_deformable_col2im_cpu_kernel111   MSHADOW_XINLINE static void Map(const int index,
112   const DType* data_im, const DType* data_offset, const DType* data_mask,
113   const int height, const int width, const int kernel_h, const int kernel_w,
114   const int pad_h, const int pad_w,
115   const int stride_h, const int stride_w,
116   const int dilation_h, const int dilation_w,
117   const int channel_per_deformable_group,
118   const int batch_size, const int num_channels, const int deformable_group,
119   const int height_col, const int width_col,
120   DType* data_col) {
121     // index index of output matrix
122     const int w_col = index % width_col;
123     const int h_col = (index / width_col) % height_col;
124     const int b_col = (index / width_col / height_col) % batch_size;
125     const int c_im = (index / width_col / height_col) / batch_size;
126     const int c_col = c_im * kernel_h * kernel_w;
127 
128     // compute deformable group index
129     const int deformable_group_index = c_im / channel_per_deformable_group;
130 
131     const int h_in = h_col * stride_h - pad_h;
132     const int w_in = w_col * stride_w - pad_w;
133 
134     DType* data_col_ptr = data_col
135       + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
136     // const DType* data_im_ptr = data_im +
137     //  ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
138     const DType* data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
139     const DType* data_offset_ptr = data_offset
140       + (b_col * deformable_group + deformable_group_index) * 2
141       * kernel_h * kernel_w * height_col * width_col;
142 
143     const DType* data_mask_ptr = data_mask
144       + (b_col *  deformable_group + deformable_group_index) * kernel_h
145       * kernel_w * height_col * width_col;
146 
147     for (int i = 0; i < kernel_h; ++i) {
148       for (int j = 0; j < kernel_w; ++j) {
149         const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col)
150           * width_col + w_col;
151         const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col)
152           * width_col + w_col;
153         const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
154         const DType offset_h = data_offset_ptr[data_offset_h_ptr];
155         const DType offset_w = data_offset_ptr[data_offset_w_ptr];
156         const DType mask = data_mask_ptr[data_mask_hw_ptr];
157         DType val = static_cast<DType>(0);
158         const DType h_im = h_in + i * dilation_h + offset_h;
159         const DType w_im = w_in + j * dilation_w + offset_w;
160         // if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
161         if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
162           // const DType map_h = i * dilation_h + offset_h;
163           // const DType map_w = j * dilation_w + offset_w;
164           // const int cur_height = height - h_in;
165           // const int cur_width = width - w_in;
166           // val = dmcn_im2col_bilinear_cpu(
167           // data_im_ptr, width, cur_height, cur_width, map_h, map_w);
168           val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im);
169         }
170         *data_col_ptr = val * mask;
171         data_col_ptr += batch_size * height_col * width_col;
172         // data_col_ptr += height_col * width_col;
173       }
174     }
175   }
176 };
177 
178 /*!\brief
179  * cpu function of deformable_im2col algorithm
180  * \param s device stream
181  * \param data_im pointer of an image (C, H, W, ...) in the image batch
182  * \param data_offset pointer of offset (C, H, W, ...) in the offset batch
183  * \param im_shape input image shape in dimensions (N, C, H, W,)
184  * \param col_shape column buffer shape (#channels, output_im_height, output_im_width, ...)
185  * \param kernel_shape kernel filter shape
186  * \param pad pad shape
187  * \param stride stride shape
188  * \param dilation dilation shape
189  * \param deformable_group #offset group that deformable convolution use
190  * \param data_col column buffer pointer
191  */
192 template <typename DType>
modulated_deformable_im2col(mshadow::Stream<cpu> * s,const DType * data_im,const DType * data_offset,const DType * data_mask,const TShape & im_shape,const TShape & col_shape,const TShape & kernel_shape,const TShape & pad,const TShape & stride,const TShape & dilation,const uint32_t deformable_group,DType * data_col)193 inline void modulated_deformable_im2col(mshadow::Stream<cpu>* s,
194   const DType* data_im, const DType* data_offset, const DType* data_mask,
195   const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape,
196   const TShape& pad, const TShape& stride, const TShape& dilation,
197   const uint32_t deformable_group, DType* data_col) {
198   // num_axes should be smaller than block size
199   index_t num_spatial_axes = kernel_shape.ndim();
200   index_t channel_per_deformable_group = im_shape[1] / deformable_group;
201   index_t num_kernels = im_shape[1] * col_shape.ProdShape(1, col_shape.ndim());
202   using namespace mxnet_op;
203   if (2 == num_spatial_axes) {
204     Kernel<modulated_deformable_col2im_cpu_kernel, cpu>::Launch(
205         s, num_kernels, data_im, data_offset, data_mask,
206         im_shape[2], im_shape[3], kernel_shape[0], kernel_shape[1],
207         pad[0], pad[1], stride[0], stride[1], dilation[0], dilation[1],
208         channel_per_deformable_group, col_shape[1], im_shape[1], deformable_group,
209         col_shape[2], col_shape[3], data_col);
210   } else {
211     LOG(FATAL) << "not implemented";
212   }
213 }
214 
215 
/*!\brief
 * cpu function of deformable_col2im algorithm
 * \param s device stream
 * \param data_col start pointer of the column buffer to be filled
 * \param data_offset pointer of offset (C, H, W, ...) in the offset batch
 * \param data_mask pointer of mask (C, H, W, ...) in the mask batch
 * \param im_shape input image shape in dimensions (N, C, H, W,)
 * \param col_shape column buffer shape
 * \param kernel_shape kernel filter shape
 * \param pad pad shape
 * \param stride stride shape
 * \param dilation dilation shape
 * \param deformable_group #offset group that deformable convolution use
 * \param grad_im pointer of a image (C, H, W,...) in the image batch
 * \param req output request type for grad_im
 * \note The CPU backward-to-image pass is intentionally unimplemented:
 *       calling this always fails via LOG(FATAL). The GPU implementation
 *       lives in modulated_deformable_im2col.cuh (included for CUDA builds
 *       at the bottom of this header).
 */
template <typename DType>
inline void modulated_deformable_col2im(mshadow::Stream<cpu>* s,
  const DType* data_col, const DType* data_offset, const DType* data_mask,
  const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape,
  const TShape& pad, const TShape& stride,
  const TShape& dilation, const uint32_t deformable_group,
  DType* grad_im, OpReqType req) {
  // No CPU implementation exists; all parameters are intentionally unused.
  LOG(FATAL) << "only implemented in GPU";
}
239 
240 
/*!\brief
 * cpu function of deformable_col2im_coord algorithm
 * \param s device stream
 * \param data_col start pointer of the column buffer to be filled
 * \param data_im pointer of an image (C, H, W, ...) in the image batch
 * \param data_offset pointer of offset (C, H, W, ...) in the offset batch
 * \param data_mask pointer of mask (C, H, W, ...) in the mask batch
 * \param im_shape input image shape in dimensions (N, C, H, W,)
 * \param col_shape column buffer shape
 * \param kernel_shape kernel filter shape
 * \param pad pad shape
 * \param stride stride shape
 * \param dilation dilation shape
 * \param deformable_group #offset group that deformable convolution use
 * \param grad_offset pointer of the offset (C, H, W,...) in the offset batch
 * \param grad_mask pointer of the mask (C, H, W,...) in the mask batch
 * \param offset_req output request type for grad_offset
 * \param mask_req output request type for grad_mask
 * \note The CPU backward-to-coordinates pass is intentionally unimplemented:
 *       calling this always fails via LOG(FATAL). The GPU implementation
 *       lives in modulated_deformable_im2col.cuh (included for CUDA builds
 *       at the bottom of this header).
 */

template <typename DType>
inline void modulated_deformable_col2im_coord(mshadow::Stream<cpu>* s,
  const DType* data_col, const DType* data_im, const DType* data_offset, const DType* data_mask,
  const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape,
  const TShape& pad, const TShape& stride,
  const TShape& dilation, const uint32_t deformable_group,
  DType* grad_offset, DType* grad_mask, OpReqType offset_req, OpReqType mask_req) {
  // No CPU implementation exists; all parameters are intentionally unused.
  LOG(FATAL) << "only implemented in GPU";
}
266 
267 }  // namespace op
268 }  // namespace mxnet
269 #ifdef __CUDACC__
270 #include "./modulated_deformable_im2col.cuh"
271 #endif
272 #endif  // MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_H_
273