#section init_code_struct
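// Initialize per-apply state: start from the user-requested algorithm,
// with no previous selection to reuse yet.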
prev_algo.algo = PARAMS->conv_algo;
prev_algo.mathType = CUDNN_DEFAULT_MATH;
reuse_algo = 0;
hash_prefix = std::string("GI|GPU#");
#ifdef DEBUG_TIMING
total_computation_time = 0;
total_selection_time = 0;
n_computations = 0;
n_selections = 0;
if (PARAMS->choose_algo) {
    if (PARAMS->choose_time) {
        selection_name = "fastest";
    } else {
        selection_name = "best suited";
    }
}
#endif

#section support_code_struct
#line 22 "dnn_gi.c"
int     reuse_algo;
AlgoRec prev_algo;
std::string hash_prefix;

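// Do not zero this struct, likely because it holds a C++ std::string
// (hash_prefix) whose internals would be corrupted by memset.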
#define THEANO_DONT_MEMSET_STRUCT

#ifdef DEBUG
char algorithm_name[128];
#endif
#ifdef DEBUG_TIMING
double total_computation_time;
double total_selection_time;
size_t n_computations;
size_t n_selections;
const char* selection_name;
#endif

/** Check the given algorithm against the inputs and the convolution
    descriptor, changing the algorithm in place to a fallback if the
    checks fail.
    Return 0 on success, non-0 on error. **/
int dnn_conv_gi_fallback(cudnnConvolutionBwdDataAlgo_t* _algo,
                         const PyGpuArrayObject* input,
                         const PyGpuArrayObject* kerns,
                         cudnnConvolutionDescriptor_t desc) {
  cudnnConvolutionBwdDataAlgo_t algo = *_algo;

  // The FFT implementation does not support strides, 1x1 filters, or inputs
  // with a spatial dimension larger than 1024. The tiled-FFT implementation
  // does not support strides.
  // If the chosen implementation is FFT or tiled-FFT, validate that it can
  // be used on the current data, and default to a safe implementation if it
  // can't.
  // The following code is 2d-specific, but that is fine since FFT and
  // tiled-FFT are only defined for 2d filters.
  if ((algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
       algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT) && PyGpuArray_NDIM(kerns) == 4) {

    // Extract the properties of the convolution descriptor
    int nd;
    int pad[2];
    int stride[2];
    int upscale[2];
    cudnnConvolutionMode_t mode;
    cudnnDataType_t data_type;
    cudnnStatus_t err = cudnnGetConvolutionNdDescriptor(desc, 2, &nd, pad, stride, upscale, &mode, &data_type);
    if (err != CUDNN_STATUS_SUCCESS) {
      PyErr_Format(PyExc_RuntimeError, "error getting convolution properties: %s",
                   cudnnGetErrorString(err));
      return 1;
    }

    if (algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT) {
      if (stride[0] != 1 || stride[1] != 1 ||
          PyGpuArray_DIM(input, 2) > 1024 || PyGpuArray_DIM(input, 3) > 1024 ||
          (PyGpuArray_DIM(kerns, 2) == 1 && PyGpuArray_DIM(kerns, 3) == 1))
      {
        algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
        #ifdef DEBUG
        fprintf(stderr, "(replacing gradinput algo fft with none)\n");
        #endif
      }
    } else {
      // algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING
      if (stride[0] != 1 || stride[1] != 1) {
        algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
        #ifdef DEBUG
        fprintf(stderr, "(replacing gradinput algo fft_tiling with none)\n");
        #endif
      }
    }
  }
  *_algo = algo;
  return 0;
}

int
APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
                        PyGpuArrayObject *im,
                        cudnnConvolutionDescriptor_t desc,
                        double alpha, double beta, PyGpuArrayObject **input,
                        PARAMS_TYPE* params) {
  PyGpuContextObject *c = kerns->context;
  void *alpha_p;
  void *beta_p;
  float af = alpha, bf = beta;
  cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
  bool use_cached = false;
  #ifdef DEBUG
  if (_cppver) fprintf(stderr, "%s\n", _cppver);
  #endif
  #ifdef DEBUG_TIMING
  TheanoTimer timer;
  #endif

  if (PyGpuArray_DIMS(im)[1] != PyGpuArray_DIMS(kerns)[1] * params->num_groups) {
    PyErr_SetString(PyExc_ValueError, "images and kernel must have the same "
                    "stack size");
    return 1;
  }
  if ((PyGpuArray_DIMS(kerns)[0] % params->num_groups) != 0) {
    PyErr_SetString(PyExc_ValueError,
                    "Number of filters must be divisible by number of groups");
    return 1;
  }

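  // cuDNN expects host scalars whose type matches the data: double scalars
  // for double data, float scalars for float and half data.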
  switch (im->ga.typecode) {
  case GA_DOUBLE:
    alpha_p = (void *)&alpha;
    beta_p = (void *)&beta;
    break;
  case GA_FLOAT:
  case GA_HALF:
    alpha_p = (void *)&af;
    beta_p = (void *)&bf;
    break;
  default:
    PyErr_SetString(PyExc_TypeError, "Unsupported type in convolution");
    return 1;
  }

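  // In-place: the gradient is written over im. Otherwise allocate the
  // output and, when beta != 0, seed it with im so it can be accumulated.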
  if (params->inplace) {
    Py_XDECREF(*input);
    *input = im;
    Py_INCREF(*input);
  } else {
    if (theano_prep_output(input, PyGpuArray_NDIM(im), PyGpuArray_DIMS(im),
                           im->ga.typecode, GA_C_ORDER, c) != 0)
      return 1;
    if (beta != 0.0 && pygpu_move(*input, im))
      return 1;
  }

  if (PyGpuArray_DIMS(im)[0] == 0 || PyGpuArray_DIMS(kerns)[0] == 0 || PyGpuArray_DIMS(kerns)[1] == 0) {
    int err2 = GpuArray_memset(&(*input)->ga, 0);
    if (err2 != GA_NO_ERROR) {
        PyErr_Format(PyExc_RuntimeError,
                     "GpuDnnConv grad wrt. inputs could not fill the output with zeros: %d", err2);
        return 1;
    }
    return 0;
  }

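  // Configure the cuDNN tensor and filter descriptors for a (possibly
  // grouped) convolution, then validate the expected output shape.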
  int groups = c_get_groups_for_conv(desc, params->num_groups);
  if (groups == -1)
    return 1;
  if (c_set_tensor_for_conv(output, APPLY_SPECIFIC(output), groups) == -1)
    return 1;
  if (c_set_filter(kerns, APPLY_SPECIFIC(kerns), groups) == -1)
    return 1;
  if (c_set_tensor_for_conv(*input, APPLY_SPECIFIC(input), groups) == -1)
    return 1;

  if (0 != dnn_check_convolution_output(desc, APPLY_SPECIFIC(input), APPLY_SPECIFIC(kerns),
                                        PyGpuArray_NDIM(kerns), output, groups))
    return 1;

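  // Byte offsets separating consecutive groups' slices in the input,
  // kernel and output buffers.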
  size_t input_offset = PyGpuArray_STRIDE(*input, 0) / groups;
  size_t kern_offset = PyGpuArray_STRIDE(kerns, 0) * PyGpuArray_DIM(kerns, 0) / groups;
  size_t output_offset = PyGpuArray_STRIDE(output, 0) / groups;

  cudnnConvolutionBwdDataAlgo_t algo = params->conv_algo;
  size_t worksize = 0;
  cudnnMathType_t mathtype = CUDNN_DEFAULT_MATH;

  std::string hashkey;

  cuda_enter(c->ctx);

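  // The largest contiguous free block bounds the workspace that may be
  // used during algorithm selection.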
  size_t maxfree = c_get_largest_free_block_size(c);
  if (PyErr_Occurred()) {
    cuda_exit(c->ctx);
    return 1;
  }

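  // Algorithm selection: reuse a previous choice if available, otherwise
  // consult the cache, otherwise ask cuDNN (by timing or by heuristics).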
  if (params->choose_algo) {

    if (!reuse_algo) {
      char pci_id[16];
      gpucontext_property(c->ctx, GA_CTX_PROP_UNIQUE_ID, pci_id);
      // Look up the algorithm cache, keyed by device and problem shape.
      hashkey = dnn_conv_shape(APPLY_SPECIFIC(input), *input, APPLY_SPECIFIC(kerns), kerns, desc, output, groups);
      if (hashkey.empty()) {
        cuda_exit(c->ctx);
        return 1;
      }
      hashkey = hash_prefix + pci_id + (params->choose_time ? " -t " : " ") + hashkey;
      const AlgoRec* cached = dnn_conv_check_cache(hashkey);
      if (cached) {
        prev_algo = *cached;
        use_cached = 1;
      }
    }

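    // A previous or cached selection exists: reuse its algorithm,
    // workspace size and math type without selecting again.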
    if (reuse_algo || use_cached) {
      algo = (cudnnConvolutionBwdDataAlgo_t)prev_algo.algo;
      worksize = prev_algo.wsSize;
      mathtype = prev_algo.mathType;
    } else {
      if (params->choose_time) {
        int count;
        cudnnConvolutionBwdDataAlgoPerf_t choice;
        gpudata *tmpmem;

        // Set the 'tensor math ok' flag.
        if (im->ga.typecode == GA_HALF)
          c_set_math_type_for_conv(desc, CUDNN_TENSOR_OP_MATH);

        tmpmem = gpudata_alloc(c->ctx, maxfree, NULL, 0, NULL);
        if (tmpmem == NULL) {
          PyErr_SetString(PyExc_MemoryError, "Could not allocate working GPU memory");
          cuda_exit(c->ctx);
          return -1;
        }

        /* cudnnFindConvolutionBackwardDataAlgorithmEx() may write to output (input).
           We don't want that if output is used in computation (i.e. if beta != 0). */
        PyGpuArrayObject* ip = *input;
        if (beta != 0) {
            ip = pygpu_empty(PyGpuArray_NDIM(*input), PyGpuArray_DIMS(*input), (*input)->ga.typecode, GA_C_ORDER, c, Py_None);
        }

        #ifdef DEBUG_TIMING
        timer.start();
        #endif
        err = cudnnFindConvolutionBackwardDataAlgorithmEx(
          params->handle, APPLY_SPECIFIC(kerns), PyGpuArray_DEV_DATA(kerns),
          APPLY_SPECIFIC(output), PyGpuArray_DEV_DATA(output), desc,
          APPLY_SPECIFIC(input), PyGpuArray_DEV_DATA(ip),
          1, &count, &choice, *(void **)tmpmem, maxfree);
        #ifdef DEBUG_TIMING
        timer.end();
        #endif
        gpudata_release(tmpmem);
        if (beta != 0) {
            Py_XDECREF(ip);
        }

        if (err != CUDNN_STATUS_SUCCESS) {
          PyErr_Format(PyExc_RuntimeError, "error selecting convolution algo: %s",
                       cudnnGetErrorString(err));
          cuda_exit(c->ctx);
          return 1;
        }

        #ifdef DEBUG
        if (count == 0) {
            PyErr_SetString(PyExc_RuntimeError, "No best-timed conv gradinput algorithm found");
            cuda_exit(c->ctx);
            return 1;
        } else if (choice.status != CUDNN_STATUS_SUCCESS) {
            PyErr_Format(PyExc_RuntimeError, "error getting best-timed gradinput algo: %s",
                         cudnnGetErrorString(choice.status));
            cuda_exit(c->ctx);
            return 1;
        } // Else, count is necessarily 1 for the current implementation.
        #endif

        algo = choice.algo;
        worksize = choice.memory;
#if CUDNN_MAJOR >= 7
        if (im->ga.typecode == GA_HALF)
          mathtype = choice.mathType;
#endif
      } else {
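        // No timing requested: ask cuDNN's heuristics for the best
        // algorithm that fits within the available workspace.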
        #ifdef DEBUG_TIMING
        timer.start();
        #endif
        err = cudnnGetConvolutionBackwardDataAlgorithm(
          params->handle, APPLY_SPECIFIC(kerns), APPLY_SPECIFIC(output),
          desc, APPLY_SPECIFIC(input),
          CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, maxfree, &algo);
        #ifdef DEBUG_TIMING
        timer.end();
        #endif
        if (err != CUDNN_STATUS_SUCCESS) {
          PyErr_Format(PyExc_RuntimeError, "error selecting convolution algo: %s",
                       cudnnGetErrorString(err));
          cuda_exit(c->ctx);
          return 1;
        }
      }
      #ifdef DEBUG_TIMING
      total_selection_time += timer.milliseconds;
      ++n_selections;
      #endif
    }
  }

  if (c_set_math_type_for_conv(desc, mathtype) == -1 ||
      dnn_conv_gi_fallback(&algo, *input, kerns, desc) != 0) {
    cuda_exit(c->ctx);
    return 1;
  }

  // If FindEx was used (choose_time), the workspace size is already set.
  if (!(reuse_algo || use_cached || params->choose_time))
  {
    err = cudnnGetConvolutionBackwardDataWorkspaceSize(
      params->handle, APPLY_SPECIFIC(kerns), APPLY_SPECIFIC(output), desc,
      APPLY_SPECIFIC(input), algo, &worksize);
    if (err == CUDNN_STATUS_NOT_SUPPORTED) {
      // Fall back to the 'none' algo if the chosen one is not supported.
      #ifdef DEBUG
      if (0 != theano_enum_to_string_cudnnConvolutionBwdDataAlgo_t(algo, algorithm_name)) {
        cuda_exit(c->ctx);
        return 1;
      }
      fprintf(stderr, "(error getting worksize for %s: falling back to CUDNN_CONVOLUTION_BWD_DATA_ALGO_0)\n",
              algorithm_name);
      #endif
      algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
      err = cudnnGetConvolutionBackwardDataWorkspaceSize(
        params->handle, APPLY_SPECIFIC(kerns), APPLY_SPECIFIC(output), desc,
        APPLY_SPECIFIC(input), algo, &worksize);
    }

    if (err != CUDNN_STATUS_SUCCESS) {
      PyErr_Format(PyExc_RuntimeError, "error getting worksize: %s",
                   cudnnGetErrorString(err));
      cuda_exit(c->ctx);
      return 1;
    }
  }  // !(reuse_algo || use_cached || params->choose_time)

  if (params->choose_algo) {

#ifdef DEBUG
    if (0 != theano_enum_to_string_cudnnConvolutionBwdDataAlgo_t(algo, algorithm_name)) {
        cuda_exit(c->ctx);
        return 1;
    }
    fprintf(stderr, "(using %s%s %s%s%s, ws:%zu, hash:%s)\n",
            algorithm_name,
            mathtype == CUDNN_TENSOR_OP_MATH ? "(tensor_op)" : "",
            params->choose_time ? "(timed)" : "",
            reuse_algo ? "(reused)" : "",
            use_cached ? "(cache)" : "",
            worksize,
            hashkey.c_str()
    );
#endif
#ifdef DEBUG_TIMING
    if (!(reuse_algo || use_cached)) {
        // We have selected an algorithm at runtime.
        // `timer` still contains the timing of the selection step.
        fprintf(stderr, "\t(selected %s gradinput algo in %g milliseconds)\n", selection_name, timer.milliseconds);
        if (n_selections > 1) {
            fprintf(stderr, "\t(selected %zu gradinput algos in %g milliseconds (average: %g milliseconds per selection))\n",
                    n_selections, total_selection_time, total_selection_time / n_selections);
        }
    }
#endif

    if (!reuse_algo) {
      // Save the selection for next time and for the cache.
      prev_algo.algo = algo;
      prev_algo.wsSize = worksize;
      prev_algo.mathType = mathtype;

      // Add to the cache
      if (!use_cached)
        dnn_conv_update_cache(hashkey, prev_algo);

      if (params->choose_once)
        reuse_algo = 1;
    }

  } // params->choose_algo

  gpudata *workspace = 0;
  if (worksize != 0) {
    workspace = gpudata_alloc(c->ctx, worksize, NULL, 0, NULL);
    if (workspace == NULL) {
      PyErr_SetString(PyExc_RuntimeError, "Could not allocate working memory");
      cuda_exit(c->ctx);
      return 1;
    }
  }

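  // Order this convolution after any pending work on the buffers it uses.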
  if (worksize != 0)
    cuda_wait(workspace, GPUARRAY_CUDA_WAIT_WRITE);
  cuda_wait(kerns->ga.data, GPUARRAY_CUDA_WAIT_READ);
  cuda_wait(output->ga.data, GPUARRAY_CUDA_WAIT_READ);
  cuda_wait((*input)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);

  #ifdef DEBUG_TIMING
  GpuArray_sync(&(*input)->ga);
  timer.start();
  #endif

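  // Launch one backward-data convolution per group, offsetting each
  // pointer to the corresponding group's slice.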
  for (int g = 0; g < groups; g++) {
    err = cudnnConvolutionBackwardData(
      params->handle,
      alpha_p,
      APPLY_SPECIFIC(kerns), ((char *)PyGpuArray_DEV_DATA(kerns)) + kern_offset * g,
      APPLY_SPECIFIC(output), ((char *)PyGpuArray_DEV_DATA(output)) + output_offset * g,
      desc, algo, worksize == 0 ? NULL : *(void **)workspace, worksize,
      beta_p,
      APPLY_SPECIFIC(input), ((char *)PyGpuArray_DEV_DATA(*input)) + input_offset * g);
  }


  if (worksize != 0) {
    cuda_record(workspace, GPUARRAY_CUDA_WAIT_WRITE);
    gpudata_release(workspace);
  }

  cuda_record(kerns->ga.data, GPUARRAY_CUDA_WAIT_READ);
  cuda_record(output->ga.data, GPUARRAY_CUDA_WAIT_READ);
  cuda_record((*input)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);

  #ifdef DEBUG_TIMING
  GpuArray_sync(&(*input)->ga);
  timer.end();
  total_computation_time += timer.milliseconds;
  ++n_computations;
  #endif

  cuda_exit(c->ctx);

  if (err != CUDNN_STATUS_SUCCESS) {
    PyErr_Format(PyExc_RuntimeError, "error doing cuDNN conv gradinput operation: %s",
                 cudnnGetErrorString(err));
    return 1;
  }
  #ifdef DEBUG_TIMING
  fprintf(stderr, "\t(ran gradinput algo in %g milliseconds)\n", timer.milliseconds);
  if (n_computations > 1) {
    fprintf(stderr, "\t(ran %zu gradinput computations in %g milliseconds (average: %g milliseconds per call))\n",
            n_computations, total_computation_time, total_computation_time / n_computations);
  }
  #endif
  return 0;
}