/*!
 ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
 *
 * COPYRIGHT
 *
 * All contributions by the University of California:
 * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
 * All rights reserved.
 *
 * All other contributions:
 * Copyright (c) 2014-2017, the respective contributors
 * All rights reserved.
 *
 * Caffe uses a shared copyright model: each contributor holds copyright over
 * their contributions to Caffe. The project versioning records all such
 * contribution and copyright details. If a contributor wants to further mark
 * their specific copyright on a particular contribution, they should indicate
 * their copyright solely in the commit message of the change when it is
 * committed.
 *
 * LICENSE
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * CONTRIBUTION AGREEMENT
 *
 * By contributing to the BVLC/caffe repository through pull-request, comment,
 * or otherwise, the contributor releases their content to the
 * license and copyright terms herein.
 *
 ***************** END Caffe Copyright Notice and Disclaimer ********************
 *
 * \file im2col.cuh
 * \brief Function definitions for converting an image to a
 * column matrix based on kernel, padding, and dilation.
 * These functions are mainly used in convolution operators.
 * The implementations of the im2col and col2im algorithms
 * are copied from Caffe with minor interface modifications
 * to adapt to MXNet data structures.
 */
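/*
 * Illustrative note (editorial, not part of the original Caffe comment): for a
 * single image of shape (C, H, W), a kernel (kernel_h, kernel_w), padding
 * (pad_h, pad_w), stride (stride_h, stride_w) and dilation (dilation_h,
 * dilation_w), im2col produces a column buffer of shape
 *   (C * kernel_h * kernel_w, height_col, width_col)
 * where
 *   height_col = (H + 2*pad_h - (dilation_h*(kernel_h-1) + 1)) / stride_h + 1
 *   width_col  = (W + 2*pad_w - (dilation_w*(kernel_w-1) + 1)) / stride_w + 1,
 * so that the convolution reduces to one matrix multiplication between the
 * filter weights and this column buffer.
 */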

#ifndef MXNET_OPERATOR_NN_IM2COL_CUH_
#define MXNET_OPERATOR_NN_IM2COL_CUH_

#include <mxnet/base.h>
#include <mxnet/operator.h>
#include <algorithm>
#include <cstring>
#include <vector>
#include "../mxnet_op.h"

namespace mxnet {
namespace op {

/*!
 * \brief im2col gpu kernel.
 * DO NOT call this directly. Use wrapper function im2col() instead;
 */
template <typename DType>
__global__ void im2col_gpu_kernel(const int n, const DType* data_im,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int height_col, const int width_col,
    DType* data_col) {
  CUDA_KERNEL_LOOP(index, n) {
    // Decode the linear index into (c_im, h_col, w_col):
    // index = (c_im * height_col + h_col) * width_col + w_col.
    const int h_index = index / width_col;
    const int h_col = h_index % height_col;
    const int w_col = index % width_col;
    const int c_im = h_index / height_col;
    // First output channel written by this thread; each input channel expands
    // into kernel_h * kernel_w rows of the column buffer.
    const int c_col = c_im * kernel_h * kernel_w;
    // Top-left corner of the input window, which may lie in the padding.
    const int h_offset = h_col * stride_h - pad_h;
    const int w_offset = w_col * stride_w - pad_w;
    DType* data_col_ptr = data_col;
    data_col_ptr += (c_col * height_col + h_col) * width_col + w_col;
    const DType* data_im_ptr = data_im;
    data_im_ptr += (c_im * height + h_offset) * width + w_offset;
    for (int i = 0; i < kernel_h; ++i) {
      for (int j = 0; j < kernel_w; ++j) {
        int h_im = h_offset + i * dilation_h;
        int w_im = w_offset + j * dilation_w;
        // Copy the pixel if it lies inside the image; otherwise write zero padding.
        *data_col_ptr =
            (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
            data_im_ptr[i * dilation_h * width + j * dilation_w] : static_cast<DType>(0);
        data_col_ptr += height_col * width_col;
      }
    }
  }
}
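/*
 * Index layout of the kernel above (editorial note): each thread handles one
 * (c_im, h_col, w_col) triple decoded from the linear index as
 *   index = (c_im * height_col + h_col) * width_col + w_col,
 * and copies the kernel_h * kernel_w input window whose top-left corner is at
 * (h_col * stride_h - pad_h, w_col * stride_w - pad_w) into the column-buffer
 * rows c_im * kernel_h * kernel_w through c_im * kernel_h * kernel_w +
 * kernel_h * kernel_w - 1, writing zeros wherever the window falls into the
 * padding.
 */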

/*!
 * \brief DO NOT call this directly. Use wrapper function im2col() instead;
 */
template <typename DType>
inline void im2col_gpu(mshadow::Stream<gpu>* s,
    const DType* data_im, const int channels,
    const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    DType* data_col) {
  // We launch channels * height_col * width_col threads, each thread
  // responsible for copying one kernel_h * kernel_w window of a single
  // channel into the column buffer.
  int height_col = (height + 2 * pad_h -
      (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w -
      (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
  int num_kernels = channels * height_col * width_col;
  using namespace mxnet_op;
  // NOLINT_NEXT_LINE(whitespace/operators)
  im2col_gpu_kernel<DType><<<cuda_get_num_blocks(num_kernels), mshadow::cuda::kBaseThreadNum,
                             0, mshadow::Stream<gpu>::GetStream(s)>>>(
      num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h,
      pad_w, stride_h, stride_w, dilation_h, dilation_w, height_col,
      width_col, data_col);
  MSHADOW_CUDA_POST_KERNEL_CHECK(im2col_gpu_kernel);
}
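/*
 * Worked example (illustrative): with height = 5, kernel_h = 3, pad_h = 1,
 * stride_h = 1 and dilation_h = 1, the effective kernel extent is
 * dilation_h * (kernel_h - 1) + 1 = 3, so
 *   height_col = (5 + 2*1 - 3) / 1 + 1 = 5,
 * i.e. a "same"-sized output; width_col follows the analogous formula.
 */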

/*!
 * \brief DO NOT call this directly. Use wrapper function col2im() instead;
 */
template <typename DType>
__global__ void col2im_gpu_kernel(const int n, const DType* data_col,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int height_col, const int width_col,
    DType* data_im, OpReqType req) {
  CUDA_KERNEL_LOOP(index, n) {
    DType val = 0;
    const int w_im = index % width + pad_w;
    const int h_im = (index / width) % height + pad_h;
    const int c_im = index / (width * height);
    int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
    int kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
    // compute the start and end of the output
    const int w_col_start =
        (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
    const int w_col_end = min(w_im / stride_w + 1, width_col);
    const int h_col_start =
        (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
    const int h_col_end = min(h_im / stride_h + 1, height_col);
    // TODO(caffe): use LCM of stride and dilation to avoid unnecessary loops
    for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
      for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
        int h_k = (h_im - h_col * stride_h);
        int w_k = (w_im - w_col * stride_w);
        if (h_k % dilation_h == 0 && w_k % dilation_w == 0) {
          h_k /= dilation_h;
          w_k /= dilation_w;
          int data_col_index = (((c_im * kernel_h + h_k) * kernel_w + w_k) *
              height_col + h_col) * width_col + w_col;
          val += data_col[data_col_index];
        }
      }
    }
    KERNEL_ASSIGN(data_im[index], req, val);
  }
}
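/*
 * Bounds example for the kernel above (illustrative): each thread owns one
 * input pixel (c_im, h_im, w_im), with the padding offset already added to
 * h_im and w_im, and sums every column-buffer entry that read that pixel
 * during im2col. With stride_h = 2, kernel_h = 3, dilation_h = 1 and h_im = 4,
 * the overlapping output rows are h_col in [(4 - 3)/2 + 1, 4/2 + 1) = [1, 3),
 * clipped to height_col, i.e. h_col = 1 (kernel row 2) and h_col = 2
 * (kernel row 0).
 */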

/*!
 * \brief DO NOT call this directly. Use wrapper function im2col() instead;
 */
using mshadow::Shape;
template <typename DType, int num_axes>
__global__ void im2col_nd_gpu_kernel(const int n, const DType* data_im,
    const Shape<num_axes+2> im_shape, const Shape<num_axes+1> col_shape,
    const Shape<num_axes> kernel_shape, const Shape<num_axes> pad, const Shape<num_axes> stride,
    const Shape<num_axes> dilation, DType* data_col) {
  int d_temp[num_axes];  // NOLINT(runtime/arrays)
  int d_iter[num_axes];  // NOLINT(runtime/arrays)

  __shared__ int shared_dilation[num_axes];
  __shared__ int shared_kernel_shape[num_axes];
  __shared__ int shared_pad[num_axes];
  __shared__ int shared_stride[num_axes];
  __shared__ int shared_col_shape[num_axes + 1];
  __shared__ int shared_im_shape[num_axes + 1];

  if (threadIdx.x < num_axes) {
    shared_dilation[threadIdx.x] = dilation[threadIdx.x];
    shared_kernel_shape[threadIdx.x] = kernel_shape[threadIdx.x];
    shared_pad[threadIdx.x] = pad[threadIdx.x];
    shared_stride[threadIdx.x] = stride[threadIdx.x];
  }
  if (threadIdx.x < num_axes + 1) {
    shared_col_shape[threadIdx.x] = col_shape[threadIdx.x];
    shared_im_shape[threadIdx.x] = im_shape[threadIdx.x + 1];  // skip batch dim
  }
  __syncthreads();

  int i;
  CUDA_KERNEL_LOOP(index, n) {
    // Decompose the linear index: channel_in ends up holding the input offset
    // and channel_out the column-buffer offset, built up from the
    // intermediate spatial indices d_temp[] in the two loops below.
    int channel_in = index;
    int channel_out = 1;
    for (i = num_axes - 1; i >= 0; --i) {
      d_temp[i] = channel_in % shared_col_shape[i + 1];
      channel_in /= shared_col_shape[i + 1];
      channel_out *= shared_kernel_shape[i];
    }
    channel_out *= channel_in;
    int data_col_inc = 1;
    for (i = 0; i < num_axes; ++i) {
      channel_out *= shared_col_shape[i + 1];
      channel_out += d_temp[i];
      d_temp[i] = d_temp[i] * shared_stride[i] - shared_pad[i];
      channel_in *= shared_im_shape[i + 1];
      channel_in += d_temp[i];
      data_col_inc *= shared_col_shape[i + 1];
      d_iter[i] = 0;
    }
    DType* data_col_ptr = data_col + channel_out;
    const DType* data_im_ptr = data_im + channel_in;
    bool incremented;
    do {
      bool in_range = true;
      for (i = 0; i < num_axes; ++i) {
        const int d_iter_im = d_iter[i] * shared_dilation[i] + d_temp[i];
        in_range &= d_iter_im >= 0 && d_iter_im < shared_im_shape[i + 1];
        if (!in_range) { break; }
      }
      if (in_range) {
        int data_im_offset = d_iter[0] * shared_dilation[0];
        for (i = 1; i < num_axes; ++i) {
          data_im_offset *= shared_im_shape[i + 1];
          data_im_offset += d_iter[i] * shared_dilation[i];
        }
        *data_col_ptr = data_im_ptr[data_im_offset];
      } else {
        *data_col_ptr = 0;
      }
      data_col_ptr += data_col_inc;
      incremented = false;
      for (i = num_axes - 1; i >= 0; --i) {
        const int d_max = shared_kernel_shape[i];
        if (d_iter[i] == d_max - 1) {
          d_iter[i] = 0;
        } else {  // d_iter[i] < d_max - 1
          ++d_iter[i];
          incremented = true;
          break;
        }
      }  // for (int i = num_axes - 1; i >= 0; --i)
    } while (incremented);  // do
  }  // CUDA_KERNEL_LOOP(index, n)
}
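/*
 * Editorial note on the loop above: d_iter[] is advanced like an odometer
 * over all kernel offsets. For num_axes = 2 with a 3x3 kernel it visits
 * (0,0), (0,1), (0,2), (1,0), ..., (2,2) and then reports incremented ==
 * false, so the do/while performs exactly kernel_h * kernel_w writes per
 * output location.
 */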

/*!\brief im2col gpu version
 * \param s device stream
 * \param data_im pointer to an image (C, H, W, ...) in the image batch
 * \param im_shape input image shape in dimensions (N, C, H, W, ...)
 * \param col_shape column buffer shape (#channels, output_im_height, output_im_width, ...)
 * \param kernel_shape kernel filter shape
 * \param pad pad shape
 * \param stride stride shape
 * \param dilation dilation shape
 * \param data_col column buffer pointer
 */
template <typename DType>
inline void im2col(mshadow::Stream<gpu>* s,
    const DType* data_im, const mxnet::TShape& im_shape,
    const mxnet::TShape& col_shape, const mxnet::TShape& kernel_shape,
    const mxnet::TShape& pad, const mxnet::TShape& stride,
    const mxnet::TShape& dilation, DType* data_col) {
  // num_axes should be smaller than block size
  index_t num_spatial_axes = kernel_shape.ndim();
  CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum);
  index_t num_kernels = im_shape[1] * col_shape.ProdShape(1, col_shape.ndim());
  using namespace mxnet_op;
  switch (num_spatial_axes) {
    case 1:
      im2col_nd_gpu_kernel<DType, 1>  // NOLINT_NEXT_LINE(whitespace/operators)
          <<<cuda_get_num_blocks(num_kernels), mshadow::cuda::kBaseThreadNum,
             0, mshadow::Stream<gpu>::GetStream(s)>>>(
          num_kernels, data_im, im_shape.get<3>(), col_shape.get<2>(),
          kernel_shape.get<1>(), pad.get<1>(), stride.get<1>(), dilation.get<1>(), data_col);
      break;
    case 2:
      im2col_gpu_kernel<DType>  // NOLINT_NEXT_LINE(whitespace/operators)
          <<<cuda_get_num_blocks(num_kernels), mshadow::cuda::kBaseThreadNum,
             0, mshadow::Stream<gpu>::GetStream(s)>>>(
          num_kernels, data_im, im_shape[2], im_shape[3], kernel_shape[0], kernel_shape[1],
          pad[0], pad[1], stride[0], stride[1], dilation[0], dilation[1],
          col_shape[1], col_shape[2], data_col);
      break;
    case 3:
      im2col_nd_gpu_kernel<DType, 3>  // NOLINT_NEXT_LINE(whitespace/operators)
          <<<cuda_get_num_blocks(num_kernels), mshadow::cuda::kBaseThreadNum,
             0, mshadow::Stream<gpu>::GetStream(s)>>>(
          num_kernels, data_im, im_shape.get<5>(), col_shape.get<4>(),
          kernel_shape.get<3>(), pad.get<3>(), stride.get<3>(), dilation.get<3>(), data_col);
      break;
    default:
      LOG(FATAL) << "im2col_nd_gpu does not support computation with "
                 << num_spatial_axes << " spatial axes";
  }
  MSHADOW_CUDA_POST_KERNEL_CHECK(im2col_nd_gpu_kernel);
}
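/*
 * Usage sketch (illustrative only; the shapes and variable names below are
 * hypothetical, not taken from this file): for a 2-D NCHW convolution one
 * would typically call, per image in the batch,
 *
 *   // im_shape  carries the full (N, C, H, W) shape, while im_ptr points at
 *   // one image inside the batch; col_shape = (C * k_h * k_w, H_out, W_out).
 *   im2col(s, im_ptr, im_shape, col_shape,
 *          mxnet::TShape({k_h, k_w}), mxnet::TShape({p_h, p_w}),
 *          mxnet::TShape({s_h, s_w}), mxnet::TShape({d_h, d_w}), col_ptr);
 *
 * and then multiply the (num_filters, C * k_h * k_w) weight matrix by the
 * column buffer to obtain the (num_filters, H_out * W_out) output.
 */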

/*!
 * \brief DO NOT call this directly. Use wrapper function col2im() instead;
 */
template <typename DType>
inline void col2im_gpu(mshadow::Stream<gpu>* s, const DType* data_col, const int channels,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h,
    const int stride_w, const int dilation_h, const int dilation_w,
    DType* data_im, OpReqType req) {
  int height_col = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) /
      stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) /
      stride_w + 1;
  int num_kernels = channels * height * width;
  using namespace mxnet_op;
  // To avoid involving atomic operations, we launch one thread per input
  // (bottom) element and accumulate all of its column-buffer (top)
  // contributions inside the kernel.
  // NOLINT_NEXT_LINE(whitespace/operators)
  col2im_gpu_kernel<DType><<<cuda_get_num_blocks(num_kernels), mshadow::cuda::kBaseThreadNum,
                             0, mshadow::Stream<gpu>::GetStream(s)>>>(
      num_kernels, data_col, channels, height, width, kernel_h, kernel_w,
      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
      height_col, width_col, data_im, req);
  MSHADOW_CUDA_POST_KERNEL_CHECK(col2im_gpu_kernel);
}
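/*
 * Editorial note: num_kernels equals channels * height * width, i.e. one
 * thread per input element. Because each thread only writes its own
 * data_im[index] after privately accumulating every overlapping column-buffer
 * entry, no atomic additions are needed, unlike a design that scatters from
 * data_col into data_im.
 */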

/*!
 * \brief DO NOT call this directly. Use wrapper function col2im() instead;
 */
template <typename DType, int num_axes>
__global__ void col2im_nd_gpu_kernel(const int n, const DType* data_col,
    const Shape<num_axes+2> im_shape, const Shape<num_axes+1> col_shape,
    const Shape<num_axes> kernel_shape, const Shape<num_axes> pad, const Shape<num_axes> stride,
    const Shape<num_axes> dilation, DType* data_im, OpReqType req) {
  int d_im[num_axes];  // NOLINT(runtime/arrays)
  int d_col_iter[num_axes];  // NOLINT(runtime/arrays)
  int d_col_start[num_axes];  // NOLINT(runtime/arrays)
  int d_col_end[num_axes];  // NOLINT(runtime/arrays)

  __shared__ int shared_dilation[num_axes];
  __shared__ int shared_kernel_shape[num_axes];
  __shared__ int shared_pad[num_axes];
  __shared__ int shared_stride[num_axes];
  __shared__ int shared_col_shape[num_axes + 1];
  __shared__ int shared_im_shape[num_axes + 1];

  if (threadIdx.x < num_axes) {
    shared_dilation[threadIdx.x] = dilation[threadIdx.x];
    shared_kernel_shape[threadIdx.x] = kernel_shape[threadIdx.x];
    shared_pad[threadIdx.x] = pad[threadIdx.x];
    shared_stride[threadIdx.x] = stride[threadIdx.x];
  }
  if (threadIdx.x < num_axes + 1) {
    shared_col_shape[threadIdx.x] = col_shape[threadIdx.x];
    shared_im_shape[threadIdx.x] = im_shape[threadIdx.x + 1];  // skip batch dim
  }
  __syncthreads();

  CUDA_KERNEL_LOOP(index, n) {
    // Decompose the linear index into the channel c_im and the (padded)
    // spatial coordinates d_im[] of the input element this thread owns.
    int c_im = index;
    // Calculate d_im (image dimensions).
    for (int i = num_axes - 1; i >= 0; --i) {
      d_im[i] = c_im % shared_im_shape[i + 1] + shared_pad[i];
      c_im /= shared_im_shape[i + 1];
    }
    // Calculate col start/end indices.
    bool done = false;
    for (int i = 0; i < num_axes; ++i) {
      const int kernel_extent =
          shared_dilation[i] * (shared_kernel_shape[i] - 1) + 1;
      d_col_start[i] = d_col_iter[i] =
          (d_im[i] < kernel_extent) ? 0 :
          (d_im[i] - kernel_extent) / shared_stride[i] + 1;
      d_col_end[i] =
          min(d_im[i] / shared_stride[i] + 1, shared_col_shape[i + 1]);
      if (d_col_start[i] >= d_col_end[i]) {
        // Skip computation if the dimension is 0 at any spatial axis --
        // final val will be 0.
        data_im[index] = 0;
        done = true;
        break;  // for (int i = 0; i < num_axes; ++i)
      }
    }
    if (done) {
      continue;  // CUDA_KERNEL_LOOP(index, n)
    }
    // Loop over the col to compute the output val.
    DType val = 0;
    bool incremented = true;
    bool skip = false;
    do {
      // Compute the final offset.
      int final_offset = 0;
      int kernel_shape_prod = 1;
      int kernel_index;
      for (int i = num_axes - 1; i >= 0; --i) {
        kernel_index = d_im[i] - d_col_iter[i] * shared_stride[i];
        if (kernel_index % shared_dilation[i]) {
          skip = true;
          break;
        } else {
          kernel_index /= shared_dilation[i];
          final_offset += kernel_index * kernel_shape_prod;
          kernel_shape_prod *= shared_kernel_shape[i];
        }
      }
      if (!skip) {
        final_offset += kernel_shape_prod * c_im;
        for (int i = 0; i < num_axes; ++i) {
          final_offset *= shared_col_shape[i + 1];
          final_offset += d_col_iter[i];
        }
        val += data_col[final_offset];
      }
      skip = false;
      incremented = false;
      for (int i = num_axes - 1; i >= 0; --i) {
        const int d_max = d_col_end[i];
        if (d_col_iter[i] == d_max - 1) {
          d_col_iter[i] = d_col_start[i];
        } else {  // d_col_iter[i] < d_max - 1
          ++d_col_iter[i];
          incremented = true;
          break;  // for (int i = num_axes - 1; i >= 0; --i)
        }
      }  // for (int i = num_axes - 1; i >= 0; --i)
    } while (incremented);
    KERNEL_ASSIGN(data_im[index], req, val);
  }  // CUDA_KERNEL_LOOP(index, n)
}

/*!\brief
 * gpu function of col2im algorithm
 * \param s device stream
 * \param data_col start pointer of the column buffer to be folded back into the image
 * \param im_shape input image shape in dimensions (N, C, H, W)
 * \param col_shape column buffer shape
 * \param kernel_shape kernel filter shape
 * \param pad pad shape
 * \param stride stride shape
 * \param dilation dilation shape
 * \param data_im pointer to an image (C, H, W, ...) in the image batch, to be filled
 * \param req output request type controlling how the result is written into data_im
 */
template <typename DType>
inline void col2im(mshadow::Stream<gpu>* s,
    const DType* data_col, const mxnet::TShape& im_shape,
    const mxnet::TShape& col_shape, const mxnet::TShape& kernel_shape,
    const mxnet::TShape& pad, const mxnet::TShape& stride,
    const mxnet::TShape& dilation, DType* data_im, OpReqType req) {
  index_t num_spatial_axes = kernel_shape.ndim();
  index_t im_size = im_shape.ProdShape(1, im_shape.ndim());
  // num_axes should be smaller than block size
  CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum);
  using namespace mxnet_op;
  switch (num_spatial_axes) {
    case 1:
      col2im_nd_gpu_kernel<DType, 1>  // NOLINT_NEXT_LINE(whitespace/operators)
          <<<cuda_get_num_blocks(im_size), mshadow::cuda::kBaseThreadNum,
             0, mshadow::Stream<gpu>::GetStream(s)>>>(
          im_size, data_col, im_shape.get<3>(), col_shape.get<2>(),
          kernel_shape.get<1>(), pad.get<1>(), stride.get<1>(), dilation.get<1>(),
          data_im, req);
      MSHADOW_CUDA_POST_KERNEL_CHECK(col2im_nd_gpu_kernel);
      break;
    case 2:
      // To avoid involving atomic operations, we launch one thread per input
      // (bottom) element and accumulate the column-buffer (top) contributions
      // inside the kernel.
      // NOLINT_NEXT_LINE(whitespace/operators)
      col2im_gpu_kernel<DType><<<cuda_get_num_blocks(im_size), mshadow::cuda::kBaseThreadNum,
                                 0, mshadow::Stream<gpu>::GetStream(s)>>>(
          im_size, data_col, im_shape[1], im_shape[2], im_shape[3],
          kernel_shape[0], kernel_shape[1], pad[0], pad[1], stride[0], stride[1],
          dilation[0], dilation[1], col_shape[1], col_shape[2], data_im, req);
      MSHADOW_CUDA_POST_KERNEL_CHECK(col2im_gpu_kernel);
      break;
    case 3:
      col2im_nd_gpu_kernel<DType, 3>  // NOLINT_NEXT_LINE(whitespace/operators)
          <<<cuda_get_num_blocks(im_size), mshadow::cuda::kBaseThreadNum,
             0, mshadow::Stream<gpu>::GetStream(s)>>>(
          im_size, data_col, im_shape.get<5>(), col_shape.get<4>(),
          kernel_shape.get<3>(), pad.get<3>(), stride.get<3>(), dilation.get<3>(),
          data_im, req);
      MSHADOW_CUDA_POST_KERNEL_CHECK(col2im_nd_gpu_kernel);
      break;
    default:
      LOG(FATAL) << "col2im_nd_gpu does not support computation with "
                 << num_spatial_axes << " spatial axes";
  }
}
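/*
 * Usage sketch (illustrative; the names below are hypothetical): in the
 * backward pass of a convolution the gradient with respect to the column
 * buffer, grad_col_ptr, is folded back into the image-gradient layout with
 *
 *   col2im(s, grad_col_ptr, im_shape, col_shape, kernel_shape,
 *          pad, stride, dilation, grad_im_ptr, kWriteTo);
 *
 * where overlapping kernel windows are summed into each input pixel and the
 * request type (e.g. kWriteTo or kAddTo) controls how the result is combined
 * with any existing contents of grad_im_ptr.
 */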

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_NN_IM2COL_CUH_