1 #section kernels
2 
3 #kernel max_pool2d_grad_grad_kernel : size, size, size, size, size, size, size, *, size, *, size, *, size, size, size, size, size, size, size, *, size :
4 #include "cluda.h"
5 
max_pool2d_grad_grad_kernel(const ga_size nthreads,const ga_size num,const ga_size channels,const ga_size pooled_height,const ga_size pooled_width,const ga_size height,const ga_size width,GLOBAL_MEM const DTYPE_INPUT_0 * x,const ga_size x_off,GLOBAL_MEM const DTYPE_INPUT_1 * z,const ga_size z_off,GLOBAL_MEM const DTYPE_INPUT_2 * gx,const ga_size gx_off,const ga_size kernel_h,const ga_size kernel_w,const ga_size stride_h,const ga_size stride_w,const ga_size pad_h,const ga_size pad_w,GLOBAL_MEM DTYPE_OUTPUT_0 * gz,const ga_size gz_off)6 KERNEL void max_pool2d_grad_grad_kernel(const ga_size nthreads,
7    const ga_size num, const ga_size channels, const ga_size pooled_height,
8    const ga_size pooled_width, const ga_size height, const ga_size width,
9    GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off, GLOBAL_MEM const DTYPE_INPUT_1 *z, const ga_size z_off, GLOBAL_MEM const DTYPE_INPUT_2 *gx, const ga_size gx_off,
10    const ga_size kernel_h, const ga_size kernel_w, const ga_size stride_h, const ga_size stride_w,
11    const ga_size pad_h, const ga_size pad_w,
12    GLOBAL_MEM DTYPE_OUTPUT_0 *gz, const ga_size gz_off)
13 {
14   x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
15   z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)z) + z_off);
16   gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((GLOBAL_MEM char *)gx) + gx_off);
17   gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gz) + gz_off);
18   // grid stride looping
19   for (ga_size index = GID_0 * LDIM_0 + LID_0;
20        index < nthreads; index += LDIM_0 * GDIM_0) {
21     const ga_size pw = index % pooled_width;
22     const ga_size ph = (index / pooled_width) % pooled_height;
23     const ga_size c = (index / pooled_width / pooled_height) % channels;
24     const ga_size n = (index / pooled_width / pooled_height / channels);
25     ga_int hstart = (ga_int)(ph*stride_h) - (ga_int)(pad_h);
26     const ga_size hend = min(hstart + kernel_h, height);
27     ga_int wstart = (ga_int)(pw*stride_w) - (ga_int)(pad_w);
28     const ga_size wend = min(wstart + kernel_w, width);
29     hstart = max(hstart, 0);
30     wstart = max(wstart, 0);
31 
32     const ga_size offset = (n*channels + c) * height * width;
33 
34     GLOBAL_MEM const DTYPE_INPUT_0* x_slice = x + offset;
35     GLOBAL_MEM const DTYPE_INPUT_2* gx_slice = gx + offset;
36     DTYPE_OUTPUT_0 gradient = 0;
37 
38     for (ga_size h=hstart; h < hend; ++h) {
39       for (ga_size w=wstart; w < wend; ++w) {
40         // maximum in the region
41         if (z[index] == x_slice[h * width + w]) {
42           gradient += gx_slice[h * width + w];
43         }
44       }
45     }
46     gz[index] = gradient;
47   }
48 }
49 
50 #kernel max_pool3d_grad_grad_kernel : size, size, size, size, size, size, size, size, size, *, size, *, size, *, size, size, size, size, size, size, size, size, size, size, *, size :
51 #include "cluda.h"
52 
max_pool3d_grad_grad_kernel(const ga_size nthreads,const ga_size num,const ga_size channels,const ga_size pooled_depth,const ga_size pooled_height,const ga_size pooled_width,const ga_size depth,const ga_size height,const ga_size width,GLOBAL_MEM const DTYPE_INPUT_0 * x,const ga_size x_off,GLOBAL_MEM const DTYPE_INPUT_1 * z,const ga_size z_off,GLOBAL_MEM const DTYPE_INPUT_2 * gx,const ga_size gx_off,const ga_size kernel_d,const ga_size kernel_h,const ga_size kernel_w,const ga_size stride_d,const ga_size stride_h,const ga_size stride_w,const ga_size pad_d,const ga_size pad_h,const ga_size pad_w,GLOBAL_MEM DTYPE_OUTPUT_0 * gz,const ga_size gz_off)53 KERNEL void max_pool3d_grad_grad_kernel(const ga_size nthreads,
54    const ga_size num, const ga_size channels, const ga_size pooled_depth,
55    const ga_size pooled_height, const ga_size pooled_width,
56    const ga_size depth, const ga_size height, const ga_size width,
57    GLOBAL_MEM const DTYPE_INPUT_0 *x, const ga_size x_off, GLOBAL_MEM const DTYPE_INPUT_1 *z, const ga_size z_off, GLOBAL_MEM const DTYPE_INPUT_2 *gx, const ga_size gx_off,
58    const ga_size kernel_d, const ga_size kernel_h, const ga_size kernel_w,
59    const ga_size stride_d, const ga_size stride_h, const ga_size stride_w,
60    const ga_size pad_d, const ga_size pad_h, const ga_size pad_w,
61    GLOBAL_MEM DTYPE_OUTPUT_0 *gz, const ga_size gz_off)
62 {
63   x = (GLOBAL_MEM DTYPE_INPUT_0 *)(((GLOBAL_MEM char *)x) + x_off);
64   z = (GLOBAL_MEM DTYPE_INPUT_1 *)(((GLOBAL_MEM char *)z) + z_off);
65   gx = (GLOBAL_MEM DTYPE_INPUT_2 *)(((GLOBAL_MEM char *)gx) + gx_off);
66   gz = (GLOBAL_MEM DTYPE_OUTPUT_0 *)(((GLOBAL_MEM char *)gz) + gz_off);
67   // grid stride looping
68   for (ga_size index = GID_0 * LDIM_0 + LID_0;
69        index < nthreads; index += LDIM_0 * GDIM_0) {
70     const ga_size pw = index % pooled_width;
71     const ga_size ph = (index / pooled_width) % pooled_height;
72     const ga_size pd = (index / pooled_width / pooled_height) % pooled_depth;
73     const ga_size c = (index / pooled_width / pooled_height / pooled_depth) % channels;
74     const ga_size n = (index / pooled_width / pooled_height / pooled_depth / channels);
75     ga_int dstart = (ga_int)(pd*stride_d) - (ga_int)(pad_d);
76     const ga_size dend = min(dstart + kernel_d, depth);
77     ga_int hstart = (ga_int)(ph*stride_h) - (ga_int)(pad_h);
78     const ga_size hend = min(hstart + kernel_h, height);
79     ga_int wstart = (ga_int)(pw*stride_w) - (ga_int)(pad_w);
80     const ga_size wend = min(wstart + kernel_w, width);
81     dstart = max(dstart, 0);
82     hstart = max(hstart, 0);
83     wstart = max(wstart, 0);
84 
85     const ga_size offset = (n*channels + c) * depth * height * width;
86 
87     GLOBAL_MEM const DTYPE_INPUT_0* x_slice = x + offset;
88     GLOBAL_MEM const DTYPE_INPUT_2* gx_slice = gx + offset;
89     DTYPE_OUTPUT_0 gradient = 0;
90 
91     for (ga_size d=dstart; d < dend; ++d) {
92       for (ga_size h=hstart; h < hend; ++h) {
93         for (ga_size w=wstart; w < wend; ++w) {
94           // maximum in the region
95           if (z[index] == x_slice[(d * height + h) * width + w]) {
96             gradient += gx_slice[(d * height + h)* width + w];
97           }
98         }
99       }
100     }
101     gz[index] = gradient;
102   }
103 }
104 
105 #section support_code_struct
106 
APPLY_SPECIFIC(pool_grad_grad)107 int APPLY_SPECIFIC(pool_grad_grad)(PyGpuArrayObject *x,
108                                    PyGpuArrayObject *z,
109                                    PyGpuArrayObject *gx,
110                                    PyArrayObject *ws,
111                                    PyArrayObject *stride,
112                                    PyArrayObject *pad,
113                                    PyGpuArrayObject **gz,
114                                    PyGpuContextObject *ctx) {
115   if (!GpuArray_IS_C_CONTIGUOUS(&x->ga)
116       || !GpuArray_IS_C_CONTIGUOUS(&z->ga)
117       || !GpuArray_IS_C_CONTIGUOUS(&gx->ga))
118     {
119       PyErr_Format(PyExc_ValueError,
120                    "GpuPoolingGradGrad: requires data to be C-contiguous");
121       return 1;
122     }
123   size_t ndims = PyArray_DIM(ws, 0);
124   if (PyGpuArray_NDIM(x) != ndims + 2
125       || PyGpuArray_NDIM(z) != ndims + 2
126       || PyGpuArray_NDIM(gx) != ndims + 2)
127     {
128       PyErr_SetString(PyExc_ValueError, "GpuPoolingGradGrad: rank error");
129       return 1;
130     }
131   if (theano_prep_output(gz, PyGpuArray_NDIM(z), PyGpuArray_DIMS(z),
132                          z->ga.typecode, GA_C_ORDER, ctx) != 0)
133     {
134       PyErr_SetString(PyExc_RuntimeError,
135                       "GpuPoolingGradGrad: failed to allocate memory");
136       return 1;
137     }
138 
139   {
140     // scope for running kernel
141     size_t w[3];
142     size_t s[3];
143     size_t p[3];
144     for(int i = 0; i < ndims; i++) {
145       w[i] = *((npy_int64*)PyArray_GETPTR1(ws, i));
146       s[i] = *((npy_int64*)PyArray_GETPTR1(stride, i));
147       p[i] = *((npy_int64*)PyArray_GETPTR1(pad, i));
148     }
149 
150     int err;
151     const size_t* z_dims = PyGpuArray_DIMS(z);
152     const size_t* x_dims = PyGpuArray_DIMS(x);
153 
154     if (ndims == 2) {
155       size_t num_kernels = z_dims[0] * z_dims[1] * z_dims[2] * z_dims[3];
156       err = max_pool2d_grad_grad_kernel_scall(1, &num_kernels, 0, num_kernels,
157                                               z_dims[0], z_dims[1], z_dims[2], z_dims[3],
158                                               x_dims[2], x_dims[3],
159                                               x->ga.data, x->ga.offset,
160                                               z->ga.data, z->ga.offset,
161                                               gx->ga.data, gx->ga.offset,
162                                               w[0], w[1], s[0], s[1], p[0], p[1],
163                                               (*gz)->ga.data, (*gz)->ga.offset);
164       if (err != GA_NO_ERROR) {
165         PyErr_Format(PyExc_RuntimeError,
166                      "GpuPoolingGradGrad: max_pool2d_grad_grad_kernel %s.",
167                      GpuKernel_error(&k_max_pool2d_grad_grad_kernel, err));
168         return 1;
169       }
170     }
171     else if (ndims == 3) {
172       size_t num_kernels = z_dims[0] * z_dims[1] * z_dims[2] * z_dims[3] * z_dims[4];
173       err = max_pool3d_grad_grad_kernel_scall(1, &num_kernels, 0, num_kernels,
174                                               z_dims[0], z_dims[1], z_dims[2], z_dims[3], z_dims[4],
175                                               x_dims[2], x_dims[3], x_dims[4],
176                                               x->ga.data, x->ga.offset,
177                                               z->ga.data, z->ga.offset,
178                                               gx->ga.data, gx->ga.offset,
179                                               w[0], w[1], w[2], s[0], s[1], s[2], p[0], p[1], p[2],
180                                               (*gz)->ga.data, (*gz)->ga.offset);
181       if (err != GA_NO_ERROR) {
182         PyErr_Format(PyExc_RuntimeError,
183                      "GpuPoolingGradGrad: max_pool3d_grad_grad_kernel %s.",
184                      GpuKernel_error(&k_max_pool3d_grad_grad_kernel, err));
185         return 1;
186       }
187     }
188   }
189   return 0;
190 }
191