1 #section support_code
2 
/*
 * Create and configure a cuDNN RNN descriptor.
 *
 * hidden_size, num_layers : forwarded unchanged to cuDNN.
 * ddesc                   : dropout descriptor attached to the RNN.
 * input_mode, direction_mode, rnn_mode :
 *     plain ints cast to cudnnRNNInputMode_t, cudnnDirectionMode_t and
 *     cudnnRNNMode_t respectively (presumably callers pass the cuDNN enum
 *     values directly -- confirm at call sites).
 * dtype   : GA_FLOAT, GA_DOUBLE or GA_HALF; any other value is rejected
 *           with ValueError.
 * odesc   : on success receives the new descriptor. Ownership passes to the
 *           caller, who must release it with cudnnDestroyRNNDescriptor().
 * _handle : cuDNN handle; only used by the cuDNN >= 7 code paths.
 *
 * Returns 0 on success. On failure returns -1 with a Python exception set;
 * *odesc is not written and any descriptor created here is destroyed.
 */
int dnn_rnn_desc(int hidden_size, int num_layers,
                 cudnnDropoutDescriptor_t ddesc,
                 int input_mode, int direction_mode, int rnn_mode,
                 int dtype, cudnnRNNDescriptor_t *odesc,
                 cudnnHandle_t _handle) {
  cudnnRNNDescriptor_t desc;
  cudnnDataType_t data_type;
  cudnnStatus_t err;

  switch (dtype) {
  case GA_FLOAT:
    data_type = CUDNN_DATA_FLOAT;
    break;
  case GA_DOUBLE:
    data_type = CUDNN_DATA_DOUBLE;
    break;
  case GA_HALF:
    data_type = CUDNN_DATA_HALF;
    break;
  default:
    PyErr_SetString(PyExc_ValueError, "Unsupported data type");
    return -1;
  }

  err = cudnnCreateRNNDescriptor(&desc);
  if (err != CUDNN_STATUS_SUCCESS) {
    PyErr_SetString(PyExc_RuntimeError, "Can't create RNN descriptor");
    return -1;
  }

#if CUDNN_MAJOR < 7
  /* The pre-7 API takes neither a handle nor an algorithm choice. */
  (void)_handle;
  err = cudnnSetRNNDescriptor(desc, hidden_size, num_layers, ddesc,
                              (cudnnRNNInputMode_t)input_mode,
                              (cudnnDirectionMode_t)direction_mode,
                              (cudnnRNNMode_t)rnn_mode, data_type);
#elif CUDNN_MAJOR < 8
  err = cudnnSetRNNDescriptor(_handle, desc, hidden_size, num_layers, ddesc,
                              (cudnnRNNInputMode_t)input_mode,
                              (cudnnDirectionMode_t)direction_mode,
                              (cudnnRNNMode_t)rnn_mode,
                              CUDNN_RNN_ALGO_STANDARD, data_type);
#else
  /*
   * cuDNN 8 removed cudnnSetRNNDescriptor; cudnnSetRNNDescriptor_v6 keeps
   * the exact cuDNN 7 signature. NOTE(review): cuDNN 9 also drops the _v6
   * entry point (only the v8 API remains, which additionally needs the
   * input size), so this path will not build there.
   */
  err = cudnnSetRNNDescriptor_v6(_handle, desc, hidden_size, num_layers,
                                 ddesc,
                                 (cudnnRNNInputMode_t)input_mode,
                                 (cudnnDirectionMode_t)direction_mode,
                                 (cudnnRNNMode_t)rnn_mode,
                                 CUDNN_RNN_ALGO_STANDARD, data_type);
#endif
  if (err != CUDNN_STATUS_SUCCESS) {
    /* Do not leak the half-configured descriptor. */
    cudnnDestroyRNNDescriptor(desc);
    PyErr_SetString(PyExc_RuntimeError, "Can't set RNN descriptor");
    return -1;
  }

  *odesc = desc;
  return 0;
}
52