// This uses a lot of code from Caffe (http://caffe.berkeleyvision.org/);
// sources are clearly marked. Below we reproduce the original license of
// the Caffe software.
/*
Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cpp)
// Loops for fast unfold + copy
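// im3d2col unfolds one image of shape (channels, height, width, depth) into a
// 2D column buffer data_col of shape
//   (channels * kernel_h * kernel_w * kernel_d, height_col * width_col * depth_col)
// stored row-major: each column holds the (dilated, padded) receptive field of
// one output position, so the correlation below reduces to a single matrix
// product per image.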
void im3d2col(const %(float_type)s* data_im, const int channels,
    const int height, const int width, const int depth,
    const int kernel_h, const int kernel_w, const int kernel_d,
    const int dilation_h, const int dilation_w, const int dilation_d,
    const int pad_h, const int pad_w, const int pad_d,
    const int stride_h, const int stride_w, const int stride_d,
    %(float_type)s* data_col) {
  // Implicit dilated kernel size
  int dil_kernel_h = (kernel_h - 1) * dilation_h + 1;
  int dil_kernel_w = (kernel_w - 1) * dilation_w + 1;
  int dil_kernel_d = (kernel_d - 1) * dilation_d + 1;
  int height_col = (height + 2 * pad_h - dil_kernel_h) / stride_h + 1;
  int width_col = (width + 2 * pad_w - dil_kernel_w) / stride_w + 1;
  int depth_col = (depth + 2 * pad_d - dil_kernel_d) / stride_d + 1;
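  // For example (hypothetical sizes, for illustration only): with height = 7,
  // kernel_h = 3, dilation_h = 2, pad_h = 0 and stride_h = 1, the dilated
  // kernel spans dil_kernel_h = 5 rows, so height_col = (7 + 0 - 5) / 1 + 1 = 3.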
  int channels_col = channels * kernel_h * kernel_w * kernel_d;
  for (int c = 0; c < channels_col; ++c) {
    int d_offset = c %% kernel_d;
    int w_offset = (c / kernel_d) %% kernel_w;
    int h_offset = (c / kernel_w / kernel_d) %% kernel_h;
    int c_im = c / kernel_h / kernel_w / kernel_d;
    for (int h = 0; h < height_col; ++h) {
      int h_pad = h * stride_h - pad_h + h_offset * dilation_h;
      for (int w = 0; w < width_col; ++w) {
        int w_pad = w * stride_w - pad_w + w_offset * dilation_w;
        for (int d = 0; d < depth_col; ++d) {
          int d_pad = d * stride_d - pad_d + d_offset * dilation_d;
          if (h_pad >= 0 && h_pad < height
              && w_pad >= 0 && w_pad < width
              && d_pad >= 0 && d_pad < depth)
            data_col[(npy_intp)((c * height_col + h) * width_col + w) * depth_col + d] =
              data_im[(npy_intp)((c_im * height + h_pad) * width + w_pad) * depth + d_pad];
          else
            data_col[(npy_intp)((c * height_col + h) * width_col + w) * depth_col + d] = 0.;
        }
      }
    }
  }
}

// Unlike the Caffe and Theano GPU versions, the data_im array is set to zero
// before the col2im call rather than inside this function, so the result is
// simply accumulated into data_im.
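// col2im3d is the adjoint of im3d2col: every entry of data_col is added back
// into the input location it was copied from, so voxels covered by several
// overlapping receptive fields receive the sum of their contributions.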
void col2im3d(const %(float_type)s* data_col, const int channels,
    const int height, const int width, const int depth,
    const int patch_h, const int patch_w, const int patch_d,
    const int dilation_h, const int dilation_w, const int dilation_d,
    const int pad_h, const int pad_w, const int pad_d,
    const int stride_h, const int stride_w, const int stride_d,
    %(float_type)s* data_im) {
  // Implicit dilated patch
  int dil_patch_h = (patch_h - 1) * dilation_h + 1;
  int dil_patch_w = (patch_w - 1) * dilation_w + 1;
  int dil_patch_d = (patch_d - 1) * dilation_d + 1;
  int height_col = (height + 2 * pad_h - dil_patch_h) / stride_h + 1;
  int width_col = (width + 2 * pad_w - dil_patch_w) / stride_w + 1;
  int depth_col = (depth + 2 * pad_d - dil_patch_d) / stride_d + 1;
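  // Note: num_kernels below is not used by the CPU loops in this function.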
  int num_kernels = channels * height * width * depth;
  int channels_col = channels * patch_h * patch_w * patch_d;
  for (int c = 0; c < channels_col; ++c) {
    int d_offset = c %% patch_d;
    int w_offset = (c / patch_d) %% patch_w;
    int h_offset = (c / patch_w / patch_d) %% patch_h;
    int c_im = c / patch_h / patch_w / patch_d;
    for (int h = 0; h < height_col; ++h) {
      int h_pad = h * stride_h - pad_h + h_offset * dilation_h;
      for (int w = 0; w < width_col; ++w) {
        int w_pad = w * stride_w - pad_w + w_offset * dilation_w;
        for (int d = 0; d < depth_col; ++d) {
          int d_pad = d * stride_d - pad_d + d_offset * dilation_d;
          if (h_pad >= 0 && h_pad < height
              && w_pad >= 0 && w_pad < width
              && d_pad >= 0 && d_pad < depth)
            data_im[(npy_intp)((c_im * height + h_pad) * width + w_pad) * depth + d_pad] +=
              data_col[(npy_intp)((c * height_col + h) * width_col + w) * depth_col + d];
        }
      }
    }
  }
}


// Theano op code
// GPU version authors: Arjun Jain, Frederic Bastien, Jan Schlueter
// Reference code: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
//   and https://github.com/torch/cunn/blob/master/SpatialConvolutionMM.cu
// CPU version author: Jesse Livezey
// CPU version adapted from GPU version
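// direction selects which of the three arrays is computed from the other two:
//   0: forward pass              -- compute top from bottom and weight
//   1: gradient w.r.t. weights   -- compute weight from bottom and top
//   2: gradient w.r.t. inputs    -- compute bottom from top and weight
// The array being computed must already be allocated with the correct shape;
// it is also returned (aliased) as the function's output.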
PyArrayObject* corr3dMM(PyArrayObject* bottom,
                        PyArrayObject* weight,
                        PyArrayObject* top,
                        const int direction,
                        const int dH = 1,
                        const int dW = 1,
                        const int dD = 1,
                        const int dilH = 1,
                        const int dilW = 1,
                        const int dilD = 1,
                        const int padH = 0,
                        const int padW = 0,
                        const int padD = 0,
                        const int numgroups = 1)
{
    if (PyArray_NDIM(bottom) != 5)
    {
        PyErr_SetString(PyExc_ValueError, "Corr3dMM requires bottom of 5D");
        return NULL;
    }
    if (PyArray_TYPE(bottom) != %(float_typenum)s)
    {
        PyErr_SetString(PyExc_ValueError, "Corr3dMM received bottom with wrong type.");
        return NULL;
    }

    if (PyArray_NDIM(weight) != 5)
    {
        PyErr_SetString(PyExc_ValueError, "Corr3dMM requires weight of 5D");
        return NULL;
    }
    if (PyArray_TYPE(weight) != %(float_typenum)s)
    {
        PyErr_SetString(PyExc_ValueError, "Corr3dMM received weight with wrong type.");
        return NULL;
    }

    if (PyArray_NDIM(top) != 5)
    {
        PyErr_SetString(PyExc_ValueError, "Corr3dMM requires top of 5D");
        return NULL;
    }
    if (PyArray_TYPE(top) != %(float_typenum)s)
    {
        PyErr_SetString(PyExc_ValueError, "Corr3dMM received top with wrong type.");
        return NULL;
    }
    // Ensure data is contiguous
    bottom = PyArray_GETCONTIGUOUS(bottom);
    weight = PyArray_GETCONTIGUOUS(weight);
    top = PyArray_GETCONTIGUOUS(top);

    // Extract some shape information for later and check shape consistency
    // bottom: (batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth)
    const int batchSize = PyArray_DIMS(bottom)[0];
    const int nChannels = PyArray_DIMS(bottom)[1];
    const int bottomHeight = PyArray_DIMS(bottom)[2];
    const int bottomWidth = PyArray_DIMS(bottom)[3];
    const int bottomDepth = PyArray_DIMS(bottom)[4];
    // weights: (nFilters, nChannels / numgroups, rows, columns, slices)
    const int nFilters = PyArray_DIMS(weight)[0];
    const int kH = PyArray_DIMS(weight)[2];
    const int kW = PyArray_DIMS(weight)[3];
    const int kD = PyArray_DIMS(weight)[4];
    if (nChannels != PyArray_DIMS(weight)[1] * numgroups) {
        PyErr_SetString(PyExc_ValueError,
                "Corr3dMM images and kernel must have the same stack size\n");
        return NULL;
    }
    if ((nFilters %% numgroups) != 0) {
        PyErr_SetString(PyExc_ValueError,
                "Corr3dMM: the number of filters must be divisible by the number of groups\n");
        return NULL;
    }
    // implicit dilated filter
    const int dil_kH = (kH - 1) * dilH + 1;
    const int dil_kW = (kW - 1) * dilW + 1;
    const int dil_kD = (kD - 1) * dilD + 1;
    // top: (batchSize, nFilters, topHeight, topWidth, topDepth)
    const int topHeightNoDH = (bottomHeight + 2*padH - dil_kH);
    const int topWidthNoDW  = (bottomWidth + 2*padW - dil_kW);
    const int topDepthNoDD  = (bottomDepth + 2*padD - dil_kD);
    // The values above might be negative, so we need Python-style flooring
    // integer division to stay compatible with get_conv_output.
    // Note: C integer division already floors for non-negative x (with y > 0),
    // so this macro only needs to special-case negative x.
#define _CONV_FLOORDIV_X(x,y) ((x < 0) ? (- ((-x) / y) - (((-x) %% y) == 0 ? 0 : 1)) : (x / y))
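    // For example, _CONV_FLOORDIV_X(-3, 2) evaluates to -2, matching Python's
    // -3 // 2, whereas plain C integer division truncates toward zero and
    // would give -1.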
    const int topHeight = _CONV_FLOORDIV_X(topHeightNoDH, dH) + 1;
    const int topWidth  = _CONV_FLOORDIV_X(topWidthNoDW, dW) + 1;
    const int topDepth  = _CONV_FLOORDIV_X(topDepthNoDD, dD) + 1;
#undef _CONV_FLOORDIV_X
    if (batchSize != PyArray_DIMS(top)[0] ||
            nFilters != PyArray_DIMS(top)[1] ||
            topHeight != PyArray_DIMS(top)[2] ||
            topWidth != PyArray_DIMS(top)[3] ||
            topDepth != PyArray_DIMS(top)[4]) {
        PyErr_Format(PyExc_ValueError,
                "Corr3dMM shape inconsistency:\n"
                "  bottom shape: %%d %%d %%d %%d %%d\n"
                "  weight shape: %%d %%d %%d %%d %%d\n"
                "  top shape: %%ld %%ld %%ld %%ld %%ld (expected %%d %%d %%d %%d %%d)\n",
                batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth,
                nFilters, nChannels / numgroups, kH, kW, kD,
                PyArray_DIMS(top)[0], PyArray_DIMS(top)[1],
                PyArray_DIMS(top)[2], PyArray_DIMS(top)[3], PyArray_DIMS(top)[4],
                batchSize, nFilters, topHeight, topWidth, topDepth);
        return NULL;
    }

    // Create temporary columns
    int max_threads = %(omp_get_max_threads)s;
    if (batchSize < max_threads) {
        max_threads = batchSize;
    }
    npy_intp col_dim[3];
    col_dim[0] = (npy_intp)max_threads;
    col_dim[1] = (npy_intp)(nChannels * kW * kH * kD);
    col_dim[2] = (npy_intp)(topHeight * topWidth * topDepth);
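    // One (nChannels * kH * kW * kD) x (topHeight * topWidth * topDepth) slab
    // is allocated per OpenMP thread, so different batch elements can be
    // unfolded and multiplied concurrently without sharing scratch memory.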

    // Allocate the per-thread column buffer; PyArray_ZEROS returns it
    // zero-initialized.
    PyArrayObject* col = (PyArrayObject*)PyArray_ZEROS(3,
            col_dim,
            PyArray_TYPE(top),
            0);
    if (NULL == col) {
        PyErr_Format(PyExc_RuntimeError,
                "Corr3dMM failed to allocate working memory of"
                " %%ld x %%ld x %%ld\n",
                col_dim[0], col_dim[1], col_dim[2]);
        return NULL;
    }

    // Define some useful variables
    const int batch_bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f;
    const int group_bottom_stride = (PyArray_STRIDES(bottom)[1] * nChannels / numgroups)/%(n_bytes)f;
    const int batch_top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f;
    const int group_top_stride = (PyArray_STRIDES(top)[1] * nFilters / numgroups)/%(n_bytes)f;
    const int K_ = col_dim[1] / numgroups;
    const int N_ = col_dim[2];
    const int col_stride = (K_ * N_ * numgroups);
    const int group_col_stride = (K_ * N_);
    const int group_weight_stride = (PyArray_STRIDES(weight)[0] * nFilters / numgroups)/%(n_bytes)f;
    const int M_ = nFilters / numgroups;
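    // Per group: M_ = number of filters, K_ = elements per filter
    // (nChannels / numgroups * kH * kW * kD), and N_ = number of output
    // positions (topHeight * topWidth * topDepth).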
    const %(c_float_type)s one = 1.0;
    const %(c_float_type)s zero = 0.0;
    char NTrans = 'N';
    char Trans = 'T';
    PyArrayObject *output;

    if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
        switch(direction) {
        case 0:
            output = top;
            break;
        case 1:
            output = weight;
            break;
        case 2:
            output = bottom;
            break;
        default:
            PyErr_SetString(PyExc_ValueError,
                    "Corr3dMM: direction must be 0, 1 or 2");
            return NULL;
        }
        PyArray_FILLWBYTE(output, 0);
    }
    else if (direction == 0) {  // forward pass
        output = top;
        // valid correlation: im3d2col, then gemm
        // Iterate over batch
        int blas_threads_saved = %(blas_get_num_threads)s;
        // Always force gemm to a single thread when OpenMP is enabled, for
        // best and stable performance.
        %(blas_set_num_threads)s(1);
        %(omp_flags)s
        for (int n = 0; n < batchSize; ++n) {
            int tid = %(omp_get_thread_num)s;
            // First, im3d2col
            im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride,
                     nChannels, bottomHeight, bottomWidth, bottomDepth,
                     kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD,
                     (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);

            for ( int g = 0; g < numgroups; ++g){
                // Second, gemm
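                // Row-major view: top (M_ x N_) = weight (M_ x K_) * col (K_ x N_);
                // the Fortran GEMM below is column-major, hence the swapped
                // operand order.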
                %(gemm)s(&NTrans, &NTrans,
                         &N_, &M_, &K_,
                         &one,
                         (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride + g * group_col_stride, &N_,
                         (%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_,
                         &zero,
                         (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_);
            }
        }
        // Restore to previous blas threads
        %(blas_set_num_threads)s(blas_threads_saved);
    }
    else if (direction == 1) {  // backprop wrt. weights
        output = weight;
        npy_intp weight_dim[2];
        weight_dim[0] = (npy_intp)max_threads;
        weight_dim[1] = (npy_intp)(M_ * K_ * numgroups);
        PyArrayObject* local_weight = (PyArrayObject*)PyArray_ZEROS(2,
                                   weight_dim, PyArray_TYPE(weight), 0);

        if (NULL == local_weight)
        {
            PyErr_Format(PyExc_RuntimeError,
                    "Corr3dMM failed to allocate weight memory of %%ld x %%ld\n",
                    weight_dim[0], weight_dim[1]);
            return NULL;
        }

        // valid convolution: im3d2col, then gemm
        // Iterate over batch
        int blas_threads_saved = %(blas_get_num_threads)s;
        // Always force gemm to a single thread when OpenMP is enabled, for
        // best and stable performance.
        %(blas_set_num_threads)s(1);
        // OpenMP parallelization over the batch
        %(omp_flags)s
        for (int n = 0; n < batchSize; ++n) {
            int tid = %(omp_get_thread_num)s;
            // First, im3d2col
            im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride,
                     nChannels, bottomHeight, bottomWidth, bottomDepth,
                     kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD,
                     (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);

            for ( int g = 0; g < numgroups; ++g){
                // Second, gemm
                // Accumulate into the per-thread local_weight buffer: beta is
                // 0 for n == 0 and 1 afterwards. local_weight is allocated
                // zero-initialized, so threads that never process n == 0 still
                // start from zero and simply accumulate.
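                // Row-major view: local_weight (M_ x K_) += top (M_ x N_) * col (K_ x N_)^T
                // for each group.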
                %(gemm)s(&Trans, &NTrans,
                         &K_, &M_, &N_,
                         &one,
                         (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_,
                         (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_,
                         (n == 0) ? &zero : &one,
                         (%(float_type)s*)PyArray_DATA(local_weight) + g * group_weight_stride +
                         tid * weight_dim[1], &K_);
            }
        }
        // Restore to previous blas threads
        %(blas_set_num_threads)s(blas_threads_saved);

        // aggregate weights
        memset((%(float_type)s*)PyArray_DATA(weight), 0, M_ * K_ * numgroups * sizeof(%(float_type)s));
        /*
         * Keep index "j" in the outer (parallel) loop so that, with OpenMP,
         * each thread writes a disjoint set of weight elements and the
         * reduction over threads stays correct.
         */
        %(omp_flags)s
        for(int j = 0; j < weight_dim[1]; ++j){
            for(int i = 0; i < max_threads; ++i){
                ((%(float_type)s*)PyArray_DATA(weight))[j] +=
                    *((%(float_type)s*)PyArray_DATA(local_weight) +
                    i * weight_dim[1] + j);
            }
        }
        Py_DECREF(local_weight);
    }
    else if (direction == 2) {  // backprop wrt. inputs
        output = bottom;
        // bottom is set to zero here rather than inside col2im3d
        PyArray_FILLWBYTE(bottom, 0);
        // full convolution: gemm, then col2im3d
        // Iterate over batch

        int blas_threads_saved = %(blas_get_num_threads)s;
        // Always force gemm to a single thread when OpenMP is enabled, for
        // best and stable performance.
        %(blas_set_num_threads)s(1);
        %(omp_flags)s
        for (int n = 0; n < batchSize; ++n) {

            int tid = %(omp_get_thread_num)s;
            for ( int g = 0; g < numgroups; ++g){
                // gemm into columns
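                // Row-major view: col (K_ x N_) = weight (M_ x K_)^T * top (M_ x N_)
                // for each group.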
                %(gemm)s(&NTrans, &Trans,
                         &N_, &K_, &M_,
                         &one,
                         (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_,
                         (%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_,
                         &zero,
                         (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_);
            }
            // col2im back to the data
            col2im3d((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels,
                     bottomHeight, bottomWidth, bottomDepth,
                     kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD,
                     (%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride);
        }
        // Restore to previous blas threads
        %(blas_set_num_threads)s(blas_threads_saved);
    }
    // Free temporary columns
    Py_DECREF(col);
    // decref from contiguous check
    Py_DECREF(bottom);
    Py_DECREF(weight);
    Py_DECREF(top);

    // Note that we don't change the refcount of the output array here. Output
    // (re)allocation and refcounting are handled in BaseCorr3dMM.c_code_helper();
    // here, output is just an alias for one of bottom, weight, or top.
    return output;
}
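
// A minimal usage sketch (hypothetical call site; in practice the call is
// generated by BaseCorr3dMM.c_code_helper()): the forward pass would be
// invoked roughly as
//   PyArrayObject* out = corr3dMM(bottom, weight, top, 0,
//                                 dH, dW, dD, dilH, dilW, dilD,
//                                 padH, padW, padD, numgroups);
// with 5D arrays of matching shapes; directions 1 and 2 reuse the same call
// with the array to be computed passed in pre-allocated.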