1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file spatial_transformer.cu
22 * \brief
23 * \author Wei Wu
24 */
25
#include "./spatial_transformer-inl.h"
#include <algorithm>
#include <cstdint>
#if MXNET_USE_CUDNN == 1
#include "./cudnn_spatial_transformer-inl.h"
#endif  // MXNET_USE_CUDNN
31
32 namespace mshadow {
/*!
 * \brief Inclusive range test used for the sampler's bounds checks.
 * \return true iff lowerBound <= value <= upperBound.
 */
template<typename DType>
__device__ bool between(DType value, int lowerBound, int upperBound) {
  const bool not_below_lower = value >= lowerBound;
  const bool not_above_upper = value <= upperBound;
  return not_below_lower && not_above_upper;
}
37
/*
 * Bilinear sampling, forward pass.
 *
 * Computes every element out[n][c][h][w] of the contiguous (o_n, o_c, o_h, o_w)
 * output by bilinearly interpolating channel c of sample n of `data` (laid out
 * with i_c channels of i_h x i_w pixels) at a source location read from `grid`.
 * For output pixel (n, h, w) the normalized x coordinate is
 * grid[n * 2 * o_h * o_w + h * o_w + w] and the normalized y coordinate sits
 * o_h * o_w elements further. (g + 1) * (size - 1) / 2 maps -1 to 0 and 1 to
 * size - 1. Neighbours that fall outside the input contribute 0.
 *
 * Launch layout: a 2-D grid of 1-D blocks. The linear block id is
 * blockIdx.x + blockIdx.y * gridDim.x and the loop advances by the total number
 * of launched threads (grid-stride), so any grid size covers every element.
 *
 * In order to not generate the code that uses too many
 * registers (resulting in too many resources requested
 * error) we need to tell the compiler that we will be
 * launching this kernel with cuda::kMaxThreadsPerBlock
 * threads per block. Setting __launch_bounds__ ensures
 * that such configuration can always be launched.
 */
template<typename DType>
__global__ void
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
BilinearSamplingForwardKernel(const int i_c, const int i_h,
                              const int i_w, const DType* data,
                              const DType* grid, const int o_n,
                              const int o_c, const int o_h,
                              const int o_w, DType* out) {
  // Flattened sizes and offsets are 64-bit: o_n * o_c * o_h * o_w and the
  // matching input offsets can exceed INT_MAX for large batches.
  const int64_t o_hw = static_cast<int64_t>(o_h) * o_w;
  const int64_t i_hw = static_cast<int64_t>(i_h) * i_w;
  const int64_t total = static_cast<int64_t>(o_n) * o_c * o_hw;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * gridDim.y;
  for (int64_t index =
           (static_cast<int64_t>(blockIdx.y) * gridDim.x + blockIdx.x) * blockDim.x +
           threadIdx.x;
       index < total; index += stride) {
    // index enumerates out in (n, c, h, w) order, so it is also out's offset.
    const int64_t hw = index % o_hw;  // h * o_w + w
    const int64_t nc = index / o_hw;  // n * o_c + c
    const int64_t c = nc % o_c;
    const int64_t n = nc / o_c;
    const int64_t grid_index = n * 2 * o_hw + hw;
    // Map normalized [-1, 1] grid coordinates to input pixel coordinates.
    DType y_real = (grid[grid_index + o_hw] + 1) * (i_h - 1) / 2;
    DType x_real = (grid[grid_index] + 1) * (i_w - 1) / 2;
    int top_left_y = static_cast<int>(floor(y_real));
    int top_left_x = static_cast<int>(floor(x_real));
    // Float literals: a bare 1.0 is a double and would promote float math.
    DType top_left_y_w = 1.0f - (y_real - top_left_y);
    DType top_left_x_w = 1.0f - (x_real - top_left_x);
    DType bottom_y_w = 1.0f - top_left_y_w;  // weight of row top_left_y + 1
    DType right_x_w = 1.0f - top_left_x_w;   // weight of column top_left_x + 1
    const bool left_in = between(top_left_x, 0, i_w - 1);
    const bool right_in = between(top_left_x + 1, 0, i_w - 1);
    const bool top_in = between(top_left_y, 0, i_h - 1);
    const bool bottom_in = between(top_left_y + 1, 0, i_h - 1);
    // Offset of the top-left neighbour. It may lie outside plane (n, c), so each
    // neighbour is read only after its own bounds check passes.
    const int64_t data_index = (n * i_c + c) * i_hw +
                               static_cast<int64_t>(top_left_y) * i_w + top_left_x;
    DType top_left_v = 0;
    DType top_right_v = 0;
    DType bottom_left_v = 0;
    DType bottom_right_v = 0;
    if (left_in && top_in)
      top_left_v = data[data_index];
    if (right_in && top_in)
      top_right_v = data[data_index + 1];
    if (left_in && bottom_in)
      bottom_left_v = data[data_index + i_w];
    if (right_in && bottom_in)
      bottom_right_v = data[data_index + i_w + 1];
    out[index] = top_left_v * top_left_y_w * top_left_x_w +
                 top_right_v * top_left_y_w * right_x_w +
                 bottom_left_v * bottom_y_w * top_left_x_w +
                 bottom_right_v * bottom_y_w * right_x_w;
  }
}
89
/*
 * Bilinear sampling, backward pass.
 *
 * One loop iteration handles one output pixel (n, h, w) and walks all o_c
 * channels of `grad` (contiguous (o_n, o_c, o_h, o_w)):
 *  - The data gradient is scattered with atomicAdd into the in-bounds input
 *    neighbours (up to four) of the sampling point, because different output
 *    pixels may sample the same input pixel. g_input is accumulated into, not
 *    overwritten. NOTE(review): presumably zero-filled by the caller; confirm.
 *  - The gradient w.r.t. the sampling point is summed over channels. It is then
 *    written back into grid_src, overwriting the grid coordinates read at the
 *    top of the same iteration (x at grid_src_index, y o_h * o_w later). Each
 *    iteration touches only its own two grid entries.
 *
 * Launch layout matches the forward kernel: a 2-D grid of 1-D blocks with a
 * grid-stride loop over the o_n * o_h * o_w output pixels.
 *
 * In order to not generate the code that uses too many
 * registers (resulting in too many resources requested
 * error) we need to tell the compiler that we will be
 * launching this kernel with cuda::kMaxThreadsPerBlock
 * threads per block. Setting __launch_bounds__ ensures
 * that such configuration can always be launched.
 */
template<typename DType>
__global__ void
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
BilinearSamplingBackwardKernel(const int i_c, const int i_h,
                               const int i_w, const DType* grad,
                               const DType* data, const int o_n,
                               const int o_c, const int o_h,
                               const int o_w, DType* g_input,
                               DType* grid_src) {
  // 64-bit indexing: flattened tensor offsets can exceed INT_MAX.
  const int64_t o_hw = static_cast<int64_t>(o_h) * o_w;
  const int64_t i_hw = static_cast<int64_t>(i_h) * i_w;
  const int64_t total = static_cast<int64_t>(o_n) * o_hw;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * gridDim.y;
  for (int64_t index =
           (static_cast<int64_t>(blockIdx.y) * gridDim.x + blockIdx.x) * blockDim.x +
           threadIdx.x;
       index < total; index += stride) {
    // index enumerates output pixels (n, h, w); channels are looped below.
    const int64_t hw = index % o_hw;  // h * o_w + w
    const int64_t n = index / o_hw;
    const int64_t grid_src_index = n * 2 * o_hw + hw;
    // Map normalized [-1, 1] grid coordinates to input pixel coordinates.
    DType y_real = (grid_src[grid_src_index + o_hw] + 1) * (i_h - 1) / 2;
    DType x_real = (grid_src[grid_src_index] + 1) * (i_w - 1) / 2;
    int top_left_y = static_cast<int>(floor(y_real));
    int top_left_x = static_cast<int>(floor(x_real));
    // Float literals: a bare 1.0 is a double and would promote float math.
    DType top_left_y_w = 1.0f - (y_real - top_left_y);
    DType top_left_x_w = 1.0f - (x_real - top_left_x);
    DType bottom_y_w = 1.0f - top_left_y_w;  // weight of row top_left_y + 1
    DType right_x_w = 1.0f - top_left_x_w;   // weight of column top_left_x + 1
    const bool left_in = between(top_left_x, 0, i_w - 1);
    const bool right_in = between(top_left_x + 1, 0, i_w - 1);
    const bool top_in = between(top_left_y, 0, i_h - 1);
    const bool bottom_in = between(top_left_y + 1, 0, i_h - 1);
    // Offset of the top-left neighbour in channel 0 of sample n. It may be out
    // of range, so each neighbour is touched only after its bounds check.
    const int64_t corner_index = n * i_c * i_hw +
                                 static_cast<int64_t>(top_left_y) * i_w + top_left_x;
    DType top_left_y_gw = 0.0f;
    DType top_left_x_gw = 0.0f;
    for (int c = 0; c < o_c; ++c) {
      const int64_t grad_index = (n * o_c + c) * o_hw + hw;
      const int64_t data_index = corner_index + c * i_hw;
      // Load once: the atomicAdds below write through g_input, which the
      // compiler cannot prove does not alias grad, so it would reload each use.
      const DType g = grad[grad_index];
      // calc 4 vertex value in input data
      DType top_left_v = 0;
      DType top_right_v = 0;
      DType bottom_left_v = 0;
      DType bottom_right_v = 0;
      // calc input grad. The index is summed before adding to the pointer, so
      // no pointer outside the buffer is ever formed.
      if (left_in && top_in) {
        atomicAdd(g_input + data_index, g * top_left_y_w * top_left_x_w);
        top_left_v = data[data_index];
      }
      if (right_in && top_in) {
        atomicAdd(g_input + (data_index + 1), g * top_left_y_w * right_x_w);
        top_right_v = data[data_index + 1];
      }
      if (left_in && bottom_in) {
        atomicAdd(g_input + (data_index + i_w), g * bottom_y_w * top_left_x_w);
        bottom_left_v = data[data_index + i_w];
      }
      if (right_in && bottom_in) {
        atomicAdd(g_input + (data_index + i_w + 1), g * bottom_y_w * right_x_w);
        bottom_right_v = data[data_index + i_w + 1];
      }
      // d(out)/d(top_left_*_w). Since top_left_*_w = 1 - (real - floor(real)),
      // d(weight)/d(real) = -1, hence the subtraction.
      const DType cross = top_left_v - top_right_v - bottom_left_v + bottom_right_v;
      top_left_y_gw -= g * (top_right_v - bottom_right_v + cross * top_left_x_w);
      top_left_x_gw -= g * (bottom_left_v - bottom_right_v + cross * top_left_y_w);
    }
    // calc grid_src grad: chain rule through (g + 1) * (size - 1) / 2.
    grid_src[grid_src_index + o_hw] = top_left_y_gw * (i_h - 1) / 2;
    grid_src[grid_src_index] = top_left_x_gw * (i_w - 1) / 2;
  }
}
164
/*!
 * \brief Host-side launcher for BilinearSamplingForwardKernel.
 * \param output   4-D result; size(0..3) give (o_n, o_c, o_h, o_w) and its
 *                 stream is used for the launch.
 * \param input    4-D tensor being sampled; size(1..3) give (i_c, i_h, i_w).
 * \param grid_src sampling grid. The kernel reads 2 * o_h * o_w values per
 *                 sample (x plane, then y plane). NOTE(review): presumably
 *                 shaped (o_n, 2, o_h * o_w); confirm.
 * The launch is asynchronous and is followed by MSHADOW_CUDA_POST_KERNEL_CHECK.
 */
template<typename DType>
inline void BilinearSamplingForward(const Tensor<gpu, 4, DType> &output,
                                    const Tensor<gpu, 4, DType> &input,
                                    const Tensor<gpu, 3, DType> grid_src) {
  // Nothing to compute for an empty output. Launching anyway would request a
  // grid with a zero dimension, which is an invalid launch configuration.
  if (output.shape_.Size() == 0) return;
  DType *out = output.dptr_;
  const DType *data = input.dptr_;
  const DType *grid = grid_src.dptr_;
  int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3);
  int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
  using namespace cuda;
  const int max_block = (output.shape_.Size() + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
  // Launch only as many blocks as there is work for. x is capped at
  // kMaxGridDim and the remainder spills into y; the kernel's grid-stride
  // loop covers any leftover elements.
  const int grid_dim_x = std::min<int>(max_block, kMaxGridDim);
  const int grid_dim_y = (max_block + kMaxGridDim - 1) / kMaxGridDim;
  dim3 num_blocks(grid_dim_x, grid_dim_y);
  dim3 threads_per_block(kMaxThreadsPerBlock);
  CheckLaunchParam(num_blocks, threads_per_block, "spatial transformer forward");
  cudaStream_t stream = Stream<gpu>::GetStream(output.stream_);
  BilinearSamplingForwardKernel<DType><<<num_blocks, threads_per_block, 0, stream>>>(
      i_c, i_h, i_w, data, grid, o_n, o_c, o_h, o_w, out);
  MSHADOW_CUDA_POST_KERNEL_CHECK(BilinearSamplingForwardKernel);
}
184
/*!
 * \brief Host-side launcher for BilinearSamplingBackwardKernel.
 * \param input_grad    gradient w.r.t. the sampled input. The kernel atomically
 *                      adds into it. NOTE(review): presumably zero-filled by the
 *                      caller; confirm.
 * \param grid_src_data on entry, the sampling grid. The kernel overwrites it
 *                      in place with the gradient w.r.t. the grid (same layout).
 * \param output_grad   4-D upstream gradient; size(0..3) give (o_n, o_c, o_h, o_w).
 * \param input_data    4-D forward input; size(1..3) give (i_c, i_h, i_w).
 * Launched asynchronously on input_grad's stream and followed by
 * MSHADOW_CUDA_POST_KERNEL_CHECK.
 */
template<typename DType>
inline void BilinearSamplingBackward(const Tensor<gpu, 4, DType> &input_grad,
                                     const Tensor<gpu, 3, DType> &grid_src_data,
                                     const Tensor<gpu, 4, DType> &output_grad,
                                     const Tensor<gpu, 4, DType> &input_data) {
  DType *g_input = input_grad.dptr_;
  DType *grid_src = grid_src_data.dptr_;
  const DType *grad = output_grad.dptr_;
  const DType *data = input_data.dptr_;
  int o_n = output_grad.size(0), o_c = output_grad.size(1),
      o_h = output_grad.size(2), o_w = output_grad.size(3);
  int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
  using namespace cuda;
  // The kernel does one loop iteration per output pixel (n, h, w) and loops
  // over channels itself. Count pixels directly: Size() / o_c would divide by
  // zero when o_c == 0. In that case the kernel still runs and writes a zero
  // grid gradient.
  const size_t num_pixels = static_cast<size_t>(o_n) * o_h * o_w;
  // No pixels means no work. A zero-sized grid would be an invalid launch.
  if (num_pixels == 0) return;
  const int max_block =
      static_cast<int>((num_pixels + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock);
  // Launch only as many blocks as needed. x is capped at kMaxGridDim and the
  // remainder spills into y.
  const int grid_dim_x = std::min<int>(max_block, kMaxGridDim);
  const int grid_dim_y = (max_block + kMaxGridDim - 1) / kMaxGridDim;
  dim3 num_blocks(grid_dim_x, grid_dim_y);
  dim3 threads_per_block(kMaxThreadsPerBlock);
  CheckLaunchParam(num_blocks, threads_per_block, "spatial transformer backward");
  cudaStream_t stream = Stream<gpu>::GetStream(input_grad.stream_);
  BilinearSamplingBackwardKernel<DType><<<num_blocks, threads_per_block, 0, stream>>>(
      i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src);
  MSHADOW_CUDA_POST_KERNEL_CHECK(BilinearSamplingBackwardKernel);
}
208
209 } // namespace mshadow
210
211 namespace mxnet {
212 namespace op {
/*!
 * \brief GPU factory for the SpatialTransformer operator.
 *
 * When built with cuDNN, the cuDNN implementation is returned unless
 * param.cudnn_off holds a value that is true. Otherwise, and in builds without
 * cuDNN, the native CUDA implementation is returned. The element type is
 * chosen from dtype via MSHADOW_REAL_TYPE_SWITCH.
 */
template<>
Operator* CreateOp<gpu>(SpatialTransformerParam param, int dtype) {
  Operator *op = nullptr;
#if MXNET_USE_CUDNN == 1
  const bool cudnn_disabled = param.cudnn_off.has_value() && param.cudnn_off.value();
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    if (!cudnn_disabled) {
      op = new CuDNNSpatialTransformerOp<DType>(param);
    } else {
      op = new SpatialTransformerOp<gpu, DType>(param);
    }
  })
#else
  // No cuDNN in this build: always use the native CUDA kernels.
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    op = new SpatialTransformerOp<gpu, DType>(param);
  })
#endif  // MXNET_USE_CUDNN
  return op;
}
231
232 } // namespace op
233 } // namespace mxnet
234