1 #section support_code
2
/*
 * Create and configure a cuDNN RNN descriptor.
 *
 * hidden_size, num_layers : forwarded unchanged to the cuDNN setter.
 * ddesc                   : dropout descriptor attached to the RNN.
 * input_mode              : a cudnnRNNInputMode_t value, passed as int.
 * direction_mode          : a cudnnDirectionMode_t value, passed as int.
 * rnn_mode                : a cudnnRNNMode_t value, passed as int.
 * dtype                   : libgpuarray type code; only GA_FLOAT, GA_DOUBLE
 *                           and GA_HALF are accepted.
 * odesc                   : receives the new descriptor on success.  The
 *                           caller owns it and must release it with
 *                           cudnnDestroyRNNDescriptor.  *odesc is left
 *                           untouched on failure.
 * _handle                 : cuDNN handle; only consulted with cuDNN >= 7.
 *
 * Returns 0 on success, or -1 with a Python exception set.
 *
 * NOTE(review): the algorithm is hard-coded to CUDNN_RNN_ALGO_STANDARD, and
 * the legacy (pre-v8) RNN setter is used.  Newer cuDNN majors may drop it
 * entirely -- confirm before building against them.
 */
int dnn_rnn_desc(int hidden_size, int num_layers,
                 cudnnDropoutDescriptor_t ddesc,
                 int input_mode, int direction_mode, int rnn_mode,
                 int dtype, cudnnRNNDescriptor_t *odesc,
                 cudnnHandle_t _handle) {
  cudnnRNNDescriptor_t desc;
  cudnnDataType_t data_type;
  cudnnStatus_t err;

  /* Map the libgpuarray element type onto the matching cuDNN data type. */
  switch (dtype) {
  case GA_FLOAT:
    data_type = CUDNN_DATA_FLOAT;
    break;
  case GA_DOUBLE:
    data_type = CUDNN_DATA_DOUBLE;
    break;
  case GA_HALF:
    data_type = CUDNN_DATA_HALF;
    break;
  default:
    PyErr_SetString(PyExc_ValueError, "Unsupported data type");
    return -1;
  }

  err = cudnnCreateRNNDescriptor(&desc);
  if (err != CUDNN_STATUS_SUCCESS) {
    /* Keep the original message as a prefix, but surface cuDNN's reason. */
    PyErr_Format(PyExc_RuntimeError, "Can't create RNN descriptor: %s",
                 cudnnGetErrorString(err));
    return -1;
  }

#if CUDNN_MAJOR < 7
  /* Pre-7 signature: no handle and no algorithm argument. */
  (void)_handle;
  err = cudnnSetRNNDescriptor(desc, hidden_size, num_layers, ddesc,
                              (cudnnRNNInputMode_t)input_mode,
                              (cudnnDirectionMode_t)direction_mode,
                              (cudnnRNNMode_t)rnn_mode, data_type);
#elif CUDNN_MAJOR < 8
  err = cudnnSetRNNDescriptor(_handle, desc, hidden_size, num_layers, ddesc,
                              (cudnnRNNInputMode_t)input_mode,
                              (cudnnDirectionMode_t)direction_mode,
                              (cudnnRNNMode_t)rnn_mode,
                              CUDNN_RNN_ALGO_STANDARD, data_type);
#else
  /* cuDNN 8 dropped the unsuffixed setter; the _v6 variant keeps the same
     argument list as the cuDNN 7 call above. */
  err = cudnnSetRNNDescriptor_v6(_handle, desc, hidden_size, num_layers,
                                 ddesc,
                                 (cudnnRNNInputMode_t)input_mode,
                                 (cudnnDirectionMode_t)direction_mode,
                                 (cudnnRNNMode_t)rnn_mode,
                                 CUDNN_RNN_ALGO_STANDARD, data_type);
#endif
  if (err != CUDNN_STATUS_SUCCESS) {
    /* Don't leak the descriptor created above. */
    cudnnDestroyRNNDescriptor(desc);
    PyErr_Format(PyExc_RuntimeError, "Can't set RNN descriptor: %s",
                 cudnnGetErrorString(err));
    return -1;
  }

  *odesc = desc;
  return 0;
}
52