/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file spatial_transformer.cu
 * \brief GPU implementation of the spatial transformer operator (bilinear sampling)
 * \author Wei Wu
*/

#include "./spatial_transformer-inl.h"
#include <algorithm>
#if MXNET_USE_CUDNN == 1
#include "./cudnn_spatial_transformer-inl.h"
#endif  // MXNET_USE_CUDNN

namespace mshadow {
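// Returns true when value lies inside the closed interval [lowerBound, upperBound].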
template<typename DType>
__device__ bool between(DType value, int lowerBound, int upperBound) {
  return (value >= lowerBound && value <= upperBound);
}

template<typename DType>
__global__ void
/*
 * To avoid generating code that uses too many registers
 * (which triggers a "too many resources requested for launch"
 * error), we tell the compiler that this kernel is launched
 * with cuda::kMaxThreadsPerBlock threads per block. Setting
 * __launch_bounds__ ensures that such a configuration can
 * always be launched.
 */
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
BilinearSamplingForwardKernel(const int i_c, const int i_h,
                              const int i_w, const DType* data,
                              const DType* grid, const int o_n,
                              const int o_c, const int o_h,
                              const int o_w, DType* out) {
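  // Grid-stride loop: the launch uses a 2-D grid of blocks, so each thread
  // starts at its flattened global id and strides by the total thread count.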
  for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
       index < o_n * o_c * o_h * o_w;
       index += blockDim.x * gridDim.x * gridDim.y) {
    // (n, c, h, w) is the element in out
    int w = index % o_w;
    int h = (index / o_w) % o_h;
    int c = (index / o_w / o_h) % o_c;
    int n = index / o_w / o_h / o_c;
    index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
    index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
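    // The grid stores normalized coordinates in [-1, 1] as two channels per
    // sample (x first, y offset by o_h * o_w); map them to input pixel space.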
    DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
    DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
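    // Top-left corner of the sampling cell and its bilinear weights.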
    int top_left_y = static_cast<int>(floor(y_real));
    int top_left_x = static_cast<int>(floor(x_real));
    DType top_left_y_w = 1.0 - (y_real - top_left_y);
    DType top_left_x_w = 1.0 - (x_real - top_left_x);
    int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
    DType top_left_v = 0;
    DType top_right_v = 0;
    DType bottom_left_v = 0;
    DType bottom_right_v = 0;
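    // Fetch the four neighbouring input pixels; out-of-bounds taps stay zero.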
    if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1))
      top_left_v = *(data + data_index);
    if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, i_h-1))
      top_right_v = *(data + data_index + 1);
    if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1))
      bottom_left_v = *(data + data_index + i_w);
    if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1))
      bottom_right_v = *(data + data_index + i_w + 1);
    *(out + out_index) = top_left_v * top_left_y_w * top_left_x_w +
                         top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
                         bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w +
                         bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w);
  }
}

/*
 * To avoid generating code that uses too many registers
 * (which triggers a "too many resources requested for launch"
 * error), we tell the compiler that this kernel is launched
 * with cuda::kMaxThreadsPerBlock threads per block. Setting
 * __launch_bounds__ ensures that such a configuration can
 * always be launched.
 */
template<typename DType>
__global__ void
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
BilinearSamplingBackwardKernel(const int i_c, const int i_h,
                               const int i_w, const DType* grad,
                               const DType* data, const int o_n,
                               const int o_c, const int o_h,
                               const int o_w, DType* g_input,
                               DType* grid_src) {
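  // One thread per output spatial location (n, h, w); the channel loop below
  // accumulates the grid gradient across all o_c channels.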
  for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
       index < o_n * o_h * o_w;
       index += blockDim.x * gridDim.x * gridDim.y) {
    // (n, h, w) is the spatial location in grad; channels are looped below
    int w = index % o_w;
    int h = (index / o_w) % o_h;
    int n = index / o_w / o_h;
    DType top_left_y_gw = 0.0;
    DType top_left_x_gw = 0.0;
    index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
    DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
    DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
    int top_left_y = static_cast<int>(floor(y_real));
    int top_left_x = static_cast<int>(floor(x_real));
    DType top_left_y_w = 1.0 - (y_real - top_left_y);
    DType top_left_x_w = 1.0 - (x_real - top_left_x);
    for (index_t c = 0; c < o_c; ++c) {
      index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
      int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
      // the four corner values read from the input data
      DType top_left_v = 0;
      DType top_right_v = 0;
      DType bottom_left_v = 0;
      DType bottom_right_v = 0;
      // scatter the output gradient into the input gradient; atomicAdd is needed
      // because different output locations may sample the same input pixel
      if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
        atomicAdd((g_input + data_index), *(grad + grad_index) * top_left_y_w * top_left_x_w);
        top_left_v = *(data + data_index);
      }
      if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
        atomicAdd((g_input + data_index + 1),
                  *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w));
        top_right_v = *(data + data_index + 1);
      }
      if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1)) {
        atomicAdd((g_input + data_index + i_w),
                  *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w);
        bottom_left_v = *(data + data_index + i_w);
      }
      if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1)) {
        atomicAdd((g_input + data_index + i_w + 1),
                  *(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w));
        bottom_right_v = *(data + data_index + i_w + 1);
      }
      // gradient w.r.t. the top-left weights; the negative sign converts it
      // into the gradient w.r.t. the sampling coordinates in grid_src
      top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
                       (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
                       * top_left_x_w);
      top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
                       (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
                       * top_left_y_w);
    }
    // chain rule through the [-1, 1] -> pixel mapping: d(real)/d(grid) = (size - 1) / 2
    *(grid_src + grid_src_index + o_h * o_w) = top_left_y_gw * (i_h - 1) / 2;
    *(grid_src + grid_src_index) = top_left_x_gw * (i_w - 1) / 2;
  }
}

template<typename DType>
inline void BilinearSamplingForward(const Tensor<gpu, 4, DType> &output,
                                    const Tensor<gpu, 4, DType> &input,
                                    const Tensor<gpu, 3, DType> grid_src) {
  DType *out = output.dptr_;
  const DType *data = input.dptr_;
  const DType *grid = grid_src.dptr_;
  int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3);
  int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
  using namespace cuda;
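  // One thread per output element; spread the required block count over a
  // 2-D grid because a single grid dimension is capped at kMaxGridDim blocks.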
  const int max_block = (output.shape_.Size() + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
  dim3 num_blocks(kMaxGridDim, (max_block + kMaxGridDim - 1) / kMaxGridDim);
  dim3 threads_per_block(kMaxThreadsPerBlock);
  CheckLaunchParam(num_blocks, threads_per_block, "spatial transformer forward");
  cudaStream_t stream = Stream<gpu>::GetStream(output.stream_);
  BilinearSamplingForwardKernel<DType><<<num_blocks, threads_per_block, 0, stream>>>(
    i_c, i_h, i_w, data, grid, o_n, o_c, o_h, o_w, out);
  MSHADOW_CUDA_POST_KERNEL_CHECK(BilinearSamplingForwardKernel);
}

template<typename DType>
inline void BilinearSamplingBackward(const Tensor<gpu, 4, DType> &input_grad,
                                     const Tensor<gpu, 3, DType> &grid_src_data,
                                     const Tensor<gpu, 4, DType> &output_grad,
                                     const Tensor<gpu, 4, DType> &input_data) {
  DType *g_input = input_grad.dptr_;
  DType *grid_src = grid_src_data.dptr_;
  const DType *grad = output_grad.dptr_;
  const DType *data = input_data.dptr_;
  int o_n = output_grad.size(0), o_c = output_grad.size(1),
      o_h = output_grad.size(2), o_w = output_grad.size(3);
  int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
  using namespace cuda;
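  // Only one thread per spatial location (size / o_c): the backward kernel
  // iterates over channels internally to accumulate the grid gradient.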
  const int max_block = (output_grad.shape_.Size() / o_c + kMaxThreadsPerBlock - 1)
                        / kMaxThreadsPerBlock;
  dim3 num_blocks(kMaxGridDim, (max_block + kMaxGridDim - 1) / kMaxGridDim);
  dim3 threads_per_block(kMaxThreadsPerBlock);
  CheckLaunchParam(num_blocks, threads_per_block, "spatial transformer backward");
  cudaStream_t stream = Stream<gpu>::GetStream(input_grad.stream_);
  BilinearSamplingBackwardKernel<DType><<<num_blocks, threads_per_block, 0, stream>>>(
    i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src);
  MSHADOW_CUDA_POST_KERNEL_CHECK(BilinearSamplingBackwardKernel);
}

}  // namespace mshadow

namespace mxnet {
namespace op {
template<>
Operator* CreateOp<gpu>(SpatialTransformerParam param, int dtype) {
  Operator *op = nullptr;
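  // When built with cuDNN, prefer the cuDNN implementation unless the user
  // explicitly set cudnn_off.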
#if MXNET_USE_CUDNN == 1
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    if (param.cudnn_off.has_value() && param.cudnn_off.value()) {
      op = new SpatialTransformerOp<gpu, DType>(param);
    } else {
      op = new CuDNNSpatialTransformerOp<DType>(param);
    }
  })
#else
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    op = new SpatialTransformerOp<gpu, DType>(param);
  })
#endif  // MXNET_USE_CUDNN
  return op;
}

}  // namespace op
}  // namespace mxnet