1 // This uses a lot of code from Caffe (http://caffe.berkeleyvision.org/); 2 // sources are clearly marked. Below we reproduce the original license of 3 // the Caffe software. 4 /* 5 Copyright (c) 2014, The Regents of the University of California (Regents) 6 All rights reserved. 7 8 Redistribution and use in source and binary forms, with or without 9 modification, are permitted provided that the following conditions are met: 10 11 1. Redistributions of source code must retain the above copyright notice, this 12 list of conditions and the following disclaimer. 13 2. Redistributions in binary form must reproduce the above copyright notice, 14 this list of conditions and the following disclaimer in the documentation 15 and/or other materials provided with the distribution. 16 17 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 21 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/

// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cpp)
// Loops for fast unfold + copy
//
// im3d2col unfolds one 3D image into a "column" matrix so that a 3D
// correlation can be computed as a single matrix product:
//   data_im:  input volume laid out as (channels, height, width, depth),
//             C-contiguous (see the indexing expression below).
//   data_col: output buffer indexed as
//             (channels * kernel_h * kernel_w * kernel_d) rows by
//             (height_col * width_col * depth_col) columns, where the
//             *_col values are the number of kernel positions per axis.
// Each row gathers, for one (channel, kernel-offset) triple, the input value
// seen at every kernel position; positions that fall into the zero-padding
// are written as 0.
// Note: this is a Theano code template -- %(float_type)s is substituted at
// compile time and %% stands for a literal '%' (the modulus operator below).
void im3d2col(const %(float_type)s* data_im, const int channels,
    const int height, const int width, const int depth,
    const int kernel_h, const int kernel_w, const int kernel_d,
    const int dilation_h, const int dilation_w, const int dilation_d,
    const int pad_h, const int pad_w, const int pad_d,
    const int stride_h, const int stride_w, const int stride_d,
    %(float_type)s* data_col) {
  // Implicit dilated kernel size
  int dil_kernel_h = (kernel_h - 1) * dilation_h + 1;
  int dil_kernel_w = (kernel_w - 1) * dilation_w + 1;
  int dil_kernel_d = (kernel_d - 1) * dilation_d + 1;
  // Number of kernel positions along each spatial axis.
  int height_col = (height + 2 * pad_h - dil_kernel_h) / stride_h + 1;
  int width_col = (width + 2 * pad_w - dil_kernel_w) / stride_w + 1;
  int depth_col = (depth + 2 * pad_d - dil_kernel_d) / stride_d + 1;
  int channels_col = channels * kernel_h * kernel_w * kernel_d;
  for (int c = 0; c < channels_col; ++c) {
    // Decompose the row index into (channel, h-, w-, d-offset) within the
    // kernel; depth is the fastest-varying axis.
    int d_offset = c %% kernel_d;
    int w_offset = (c / kernel_d) %% kernel_w;
    int h_offset = (c / kernel_w / kernel_d) %% kernel_h;
    int c_im = c / kernel_h / kernel_w / kernel_d;
    for (int h = 0; h < height_col; ++h) {
      int h_pad = h * stride_h - pad_h + h_offset * dilation_h;
      for (int w = 0; w < width_col; ++w) {
        int w_pad = w * stride_w - pad_w + w_offset * dilation_w;
        for (int d = 0; d < depth_col; ++d) {
          int d_pad = d * stride_d - pad_d + d_offset * dilation_d;
          // Copy from the input when the sampled location lies inside the
          // image; otherwise the location is in the padding: write zero.
          if (h_pad >= 0 && h_pad < height
              && w_pad >= 0 && w_pad < width
              && d_pad >= 0 && d_pad < depth)
            data_col[(npy_intp)((c * height_col + h) * width_col + w) * depth_col + d] =
              data_im[(npy_intp)((c_im * height + h_pad) * width + w_pad) * depth + d_pad];
          else
            data_col[(npy_intp)((c * height_col + h) * width_col + w) * depth_col + d] = 0.;
        }
      }
    }
  }
}

// Unlike the Caffe and Theano GPU versions, the
data_im array is set to zero 71 // before the col2im call rather than doing it here. So, the result is just 72 // accumulated into data_im. 73 void col2im3d(const %(float_type)s* data_col, const int channels, 74 const int height, const int width, const int depth, 75 const int patch_h, const int patch_w, const int patch_d, 76 const int dilation_h, const int dilation_w, const int dilation_d, 77 const int pad_h, const int pad_w, const int pad_d, 78 const int stride_h, const int stride_w, const int stride_d, 79 %(float_type)s* data_im) { 80 // Implicit dilated patch 81 int dil_patch_h = (patch_h - 1) * dilation_h + 1; 82 int dil_patch_w = (patch_w - 1) * dilation_w + 1; 83 int dil_patch_d = (patch_d - 1) * dilation_d + 1; 84 int height_col = (height + 2 * pad_h - dil_patch_h) / stride_h + 1; 85 int width_col = (width + 2 * pad_w - dil_patch_w) / stride_w + 1; 86 int depth_col = (depth + 2 * pad_d - dil_patch_d) / stride_d + 1; 87 int num_kernels = channels * height * width * depth; 88 int channels_col = channels * patch_h * patch_w * patch_d; 89 for (int c = 0; c < channels_col; ++c) { 90 int d_offset = c %% patch_d; 91 int w_offset = (c / patch_d) %% patch_w; 92 int h_offset = (c / patch_w / patch_d) %% patch_h; 93 int c_im = c / patch_h / patch_w / patch_d; 94 for (int h = 0; h < height_col; ++h) { 95 int h_pad = h * stride_h - pad_h + h_offset * dilation_h; 96 for (int w = 0; w < width_col; ++w) { 97 int w_pad = w * stride_w - pad_w + w_offset * dilation_w; 98 for (int d = 0; d < depth_col; ++d) { 99 int d_pad = d * stride_d - pad_d + d_offset * dilation_d; 100 if (h_pad >= 0 && h_pad < height 101 && w_pad >= 0 && w_pad < width 102 && d_pad >= 0 && d_pad < depth) 103 data_im[(npy_intp)((c_im * height + h_pad) * width + w_pad) * depth + d_pad] += 104 data_col[(npy_intp)((c * height_col + h) * width_col + w) * depth_col + d]; 105 } 106 } 107 } 108 } 109 } 110 111 112 // Theano op code 113 // GPU version authors: Arjun Jain, Frederic Bastien, Jan Schlueter 114 // 
Reference code: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu 115 // and https://github.com/torch/cunn/blob/master/SpatialConvolutionMM.cu 116 // CPU version author: Jesse Livezey 117 // CPU version adapted from GPU version 118 PyArrayObject* corr3dMM(PyArrayObject* bottom, 119 PyArrayObject* weight, 120 PyArrayObject* top, 121 const int direction, 122 const int dH = 1, 123 const int dW = 1, 124 const int dD = 1, 125 const int dilH = 1, 126 const int dilW = 1, 127 const int dilD = 1, 128 const int padH = 0, 129 const int padW = 0, 130 const int padD = 0, 131 const int numgroups=1) 132 { 133 if (PyArray_NDIM(bottom) != 5) 134 { 135 PyErr_SetString(PyExc_ValueError, "Corr3dMM requires bottom of 5D"); 136 return NULL; 137 } 138 if (PyArray_TYPE(bottom) != %(float_typenum)s) 139 { 140 PyErr_SetString(PyExc_ValueError, "Corr3dMM received bottom with wrong type."); 141 return NULL; 142 } 143 144 if (PyArray_NDIM(weight) != 5) 145 { 146 PyErr_SetString(PyExc_ValueError, "Corr3dMM requires weight of 5D"); 147 return NULL; 148 } 149 if (PyArray_TYPE(weight) != %(float_typenum)s) 150 { 151 PyErr_SetString(PyExc_ValueError, "Corr3dMM received weight with wrong type."); 152 return NULL; 153 } 154 155 if (PyArray_NDIM(top) != 5) 156 { 157 PyErr_SetString(PyExc_ValueError, "Corr3dMM requires top of 5D"); 158 return NULL; 159 } 160 if (PyArray_TYPE(top) != %(float_typenum)s) 161 { 162 PyErr_SetString(PyExc_ValueError, "Corr3dMM received top with wrong type."); 163 return NULL; 164 } 165 // Ensure data is contiguous 166 bottom = PyArray_GETCONTIGUOUS(bottom); 167 weight = PyArray_GETCONTIGUOUS(weight); 168 top = PyArray_GETCONTIGUOUS(top); 169 170 // Extract some shape information for later and check shape consistency 171 // bottom: (batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth) 172 const int batchSize = PyArray_DIMS(bottom)[0]; 173 const int nChannels = PyArray_DIMS(bottom)[1]; 174 const int bottomHeight = PyArray_DIMS(bottom)[2]; 175 
const int bottomWidth = PyArray_DIMS(bottom)[3]; 176 const int bottomDepth = PyArray_DIMS(bottom)[4]; 177 // weights: (nFilters, nChannels, rows, columns, slices) 178 const int nFilters = PyArray_DIMS(weight)[0]; 179 const int kH = PyArray_DIMS(weight)[2]; 180 const int kW = PyArray_DIMS(weight)[3]; 181 const int kD = PyArray_DIMS(weight)[4]; 182 if (nChannels != PyArray_DIMS(weight)[1] * numgroups) { 183 PyErr_SetString(PyExc_ValueError, 184 "Corr3dMM images and kernel must have the same stack size\n"); 185 return NULL; 186 } 187 if ((nFilters %% numgroups) != 0) { 188 PyErr_SetString(PyExc_ValueError, 189 "CorrMM the number of filters must be divisible by the number of groups\n"); 190 return NULL; 191 } 192 // implicit dilated filter 193 const int dil_kH = (kH - 1) * dilH + 1; 194 const int dil_kW = (kW - 1) * dilW + 1; 195 const int dil_kD = (kD - 1) * dilD + 1; 196 // top: (batchSize, nFilters, topHeight, topWidth, topDepth) 197 const int topHeightNoDH = (bottomHeight + 2*padH - dil_kH); 198 const int topWidthNoDW = (bottomWidth + 2*padW - dil_kW); 199 const int topDepthNoDD = (bottomDepth + 2*padD - dil_kD); 200 // the above values might be negative so we need to use Python-like 201 // flooring integer division to be compatible with get_conv_output. 202 // note: this macro implements Python's // for negative x only 203 #define _CONV_FLOORDIV_X(x,y) ((x < 0) ? (- ((-x) / y) - (((-x) %% y) == 0 ? 
0 : 1)) : (x / y)) 204 const int topHeight = _CONV_FLOORDIV_X(topHeightNoDH, dH) + 1; 205 const int topWidth = _CONV_FLOORDIV_X(topWidthNoDW, dW) + 1; 206 const int topDepth = _CONV_FLOORDIV_X(topDepthNoDD, dD) + 1; 207 #undef _CONV_FLOORDIV 208 if (batchSize != PyArray_DIMS(top)[0] || 209 nFilters != PyArray_DIMS(top)[1] || 210 topHeight != PyArray_DIMS(top)[2] || 211 topWidth != PyArray_DIMS(top)[3] || 212 topDepth != PyArray_DIMS(top)[4]) { 213 PyErr_Format(PyExc_ValueError, 214 "Corr3dMM shape inconsistency:\n" 215 " bottom shape: %%d %%d %%d %%d %%d\n" 216 " weight shape: %%d %%d %%d %%d %%d\n" 217 " top shape: %%ld %%ld %%ld %%ld %%ld (expected %%d %%d %%d %%d %%d)\n", 218 batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth, 219 nFilters, nChannels / numgroups, kH, kW, kD, 220 PyArray_DIMS(top)[0], PyArray_DIMS(top)[1], 221 PyArray_DIMS(top)[2], PyArray_DIMS(top)[3], PyArray_DIMS(top)[4], 222 batchSize, nFilters, topHeight, topWidth, topDepth); 223 return NULL; 224 } 225 226 // Create temporary columns 227 int max_threads = %(omp_get_max_threads)s; 228 if (batchSize < max_threads) { 229 max_threads = batchSize; 230 } 231 npy_intp col_dim[3]; 232 col_dim[0] = (npy_intp)max_threads; 233 col_dim[1] = (npy_intp)(nChannels * kW * kH * kD); 234 col_dim[2] = (npy_intp)(topHeight * topWidth * topDepth); 235 236 //Change to PyArray_ZEROS which is faster than PyArray_EMPTY. 
237 PyArrayObject* col = (PyArrayObject*)PyArray_ZEROS(3, 238 col_dim, 239 PyArray_TYPE(top), 240 0); 241 if (NULL == col) { 242 PyErr_Format(PyExc_RuntimeError, 243 "Corr3dMM failed to allocate working memory of" 244 " %%ld x %%ld x %%ld\n", 245 col_dim[0], col_dim[1], col_dim[2]); 246 return NULL; 247 } 248 249 // Define some useful variables 250 const int batch_bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f; 251 const int group_bottom_stride = (PyArray_STRIDES(bottom)[1] * nChannels / numgroups)/%(n_bytes)f; 252 const int batch_top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f; 253 const int group_top_stride = (PyArray_STRIDES(top)[1] * nFilters / numgroups)/%(n_bytes)f; 254 const int K_ = col_dim[1] / numgroups; 255 const int N_ = col_dim[2]; 256 const int col_stride = (K_ * N_ * numgroups); 257 const int group_col_stride = (K_ * N_); 258 const int group_weight_stride = (PyArray_STRIDES(weight)[0] * nFilters / numgroups)/%(n_bytes)f; 259 const int M_ = nFilters / numgroups; 260 const %(c_float_type)s one = 1.0; 261 const %(c_float_type)s zero = 0.0; 262 char NTrans = 'N'; 263 char Trans = 'T'; 264 PyArrayObject *output; 265 266 if (batchSize == 0 || nChannels == 0 || nFilters == 0) { 267 switch(direction) { 268 case 0: 269 output = top; 270 break; 271 case 1: 272 output = weight; 273 break; 274 case 2: 275 output = bottom; 276 break; 277 default: 278 return NULL; 279 } 280 PyArray_FILLWBYTE(output, 0); 281 } 282 else if (direction == 0) { // forward pass 283 output = top; 284 // valid correlation: im3d2col, then gemm 285 // Iterate over batch 286 int blas_threads_saved = %(blas_get_num_threads)s; 287 // Always forcing gemm to one thread when OpenMP is enalbed for best and stable performance. 
288 %(blas_set_num_threads)s(1); 289 %(omp_flags)s 290 for (int n = 0; n < batchSize; ++n) { 291 int tid = %(omp_get_thread_num)s; 292 // First, im3d2col 293 im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride, 294 nChannels, bottomHeight, bottomWidth, bottomDepth, 295 kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD, 296 (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); 297 298 for ( int g = 0; g < numgroups; ++g){ 299 // Second, gemm 300 %(gemm)s(&NTrans, &NTrans, 301 &N_, &M_, &K_, 302 &one, 303 (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride + g * group_col_stride, &N_, 304 (%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_, 305 &zero, 306 (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_); 307 } 308 } 309 // Restore to previous blas threads 310 %(blas_set_num_threads)s(blas_threads_saved); 311 } 312 else if (direction == 1) { // backprop wrt. weights 313 output = weight; 314 npy_intp weight_dim[2]; 315 weight_dim[0] = (npy_intp)max_threads; 316 weight_dim[1] = (npy_intp)(M_ * K_ * numgroups); 317 PyArrayObject* local_weight = (PyArrayObject*)PyArray_ZEROS(2, 318 weight_dim, PyArray_TYPE(weight), 0); 319 320 if (NULL == local_weight) 321 { 322 PyErr_Format(PyExc_RuntimeError, 323 "Corr3dMM failed to allocate weight memory of %%ld x %%ld\n", 324 weight_dim[0], weight_dim[1]); 325 return NULL; 326 } 327 328 // valid convolution: im2col, then gemm 329 // Iterate over batch 330 int blas_threads_saved = %(blas_get_num_threads)s; 331 // Always forcing gemm to one thread when OpenMP is enalbed for best and stable performance. 
332 %(blas_set_num_threads)s(1); 333 // OMP for batch-level paralization 334 %(omp_flags)s 335 for (int n = 0; n < batchSize; ++n) { 336 int tid = %(omp_get_thread_num)s; 337 // First, im2col 338 im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride, 339 nChannels, bottomHeight, bottomWidth, bottomDepth, 340 kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD, 341 (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); 342 343 for ( int g = 0; g < numgroups; ++g){ 344 // Second, gemm 345 // Note that we accumulate into weight. We do so by setting beta = 0 346 // for the first iteration and beta = 1 for subsequent ones. (This 347 // is faster than setting weight to all zeros before the loop.) 348 %(gemm)s(&Trans, &NTrans, 349 &K_, &M_, &N_, 350 &one, 351 (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_, 352 (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_, 353 (n == 0) ? &zero : &one, 354 (%(float_type)s*)PyArray_DATA(local_weight) + g * group_weight_stride + 355 tid * weight_dim[1], &K_); 356 } 357 } 358 // Restore to previous blas threads 359 %(blas_set_num_threads)s(blas_threads_saved); 360 361 //aggregate weights 362 memset((%(float_type)s*)PyArray_DATA(weight), 0, M_ * K_*sizeof(%(float_type)s)); 363 /* 364 * Put index "j" into outer loop to get the 365 * correct result when openmp is used. 366 */ 367 %(omp_flags)s 368 for(int j = 0; j < weight_dim[1]; ++j){ 369 for(int i = 0; i < max_threads; ++i){ 370 ((%(float_type)s*)PyArray_DATA(weight))[j] += 371 *((%(float_type)s*)PyArray_DATA(local_weight) + 372 i * weight_dim[1] + j); 373 } 374 } 375 Py_DECREF(local_weight); 376 } 377 else if (direction == 2) { // backprop wrt. 
inputs 378 output = bottom; 379 // bottom is set to zero here rather than inside of col2im 380 PyArray_FILLWBYTE(bottom, 0); 381 // full convolution: gemm, then col2im3d 382 // Iterate over batch 383 384 int blas_threads_saved = %(blas_get_num_threads)s; 385 // Always forcing gemm to one thread when OpenMP is enalbed for best and stable performance. 386 %(blas_set_num_threads)s(1); 387 %(omp_flags)s 388 for (int n = 0; n < batchSize; ++n) { 389 390 int tid = %(omp_get_thread_num)s; 391 for ( int g = 0; g < numgroups; ++g){ 392 // gemm into columns 393 %(gemm)s(&NTrans, &Trans, 394 &N_, &K_, &M_, 395 &one, 396 (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_, 397 (%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_, 398 &zero, 399 (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_); 400 } 401 // col2im back to the data 402 col2im3d((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels, 403 bottomHeight, bottomWidth, bottomDepth, 404 kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD, 405 (%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride); 406 } 407 // Restore to previous blas threads 408 %(blas_set_num_threads)s(blas_threads_saved); 409 } 410 // Free temporary columns 411 Py_DECREF(col); 412 // decref from contiguous check 413 Py_DECREF(bottom); 414 Py_DECREF(weight); 415 Py_DECREF(top); 416 417 // Note that we don't change the refcount of the output matrix here. Output 418 // (re)allocation and refcounting is done in BaseCorr3dMM.c_code_helper(); 419 // in here output is just aliased to one of bottom, weights, or top. 420 return output; 421 } 422 423