1"""
2Declarations of cuDNN types and constants used in Theano gpuarray DNN module.
3
4For every cuDNN API supported by Theano, this module defines a class that
5provides the set of cuDNN definitions to be used in Theano Ops.
6
7Use :func:`get_definitions` to get the right cuDNN definitions
8for a given cuDNN version.
9
10Currently supported cuDNN APIs:
11
 - v5.1
 - v6.0
 - v7.0
15
16"""
17
18from __future__ import absolute_import, print_function, division
19
20from theano.gof import CEnumType
21
# Canonical Theano dtype names used throughout this module.
HALF = 'float16'
FLOAT = 'float32'
DOUBLE = 'float64'
# (dtype, precision) pairs that cuDNN convolution descriptors may use.
TRUE_HALF_CONFIG = (HALF, HALF)
PSEUDO_HALF_CONFIG = (HALF, FLOAT)
FLOAT_CONFIG = (FLOAT, FLOAT)
DOUBLE_CONFIG = (DOUBLE, DOUBLE)
27
28
def is_true_half_config(dtype, precision):
    """Return True if both data type and precision are float16."""
    return dtype == HALF and precision == HALF
31
32
def is_pseudo_half_config(dtype, precision):
    """Return True for float16 data computed at float32 precision."""
    return (dtype, precision) == (HALF, FLOAT)
35
36
def is_float_config(dtype, precision):
    """Return True if both data type and precision are float32."""
    return dtype == FLOAT and precision == FLOAT
39
40
def is_double_config(dtype, precision):
    """Return True if both data type and precision are float64."""
    return (dtype, precision) == DOUBLE_CONFIG
43
44
45# NB: Some cuDNN algorithms are listed in cuDNN enums but not implemented.
46# We still register them here because we try to exactly copy cuDNN enums
47# in Python side, but they will have no aliases associated, to help
48# exclude them from lists of supported algorithms.
49
50
class CuDNNV51(object):
    """
    cuDNN definitions for the cuDNN v5.1 API.

    Wraps the cuDNN enumerations Theano uses as :class:`CEnumType`
    instances, lists per convolution direction (fwd, bwd-filter, bwd-data)
    which algorithm aliases work for 3D convolutions and which are
    deterministic, and tells which (dtype, precision) configurations each
    convolution algorithm supports. Subclasses extend these definitions
    for later cuDNN versions.
    """

    # Major cuDNN version covered by these definitions.
    version = 5

    cudnnConvolutionMode_t = CEnumType(('CUDNN_CONVOLUTION', 'conv'),
                                       ('CUDNN_CROSS_CORRELATION', 'cross'),
                                       ctype='cudnnConvolutionMode_t')

    cudnnDataType_t = CEnumType(('CUDNN_DATA_FLOAT', 'float32'),
                                ('CUDNN_DATA_DOUBLE', 'float64'),
                                ('CUDNN_DATA_HALF', 'float16'),
                                ctype='cudnnDataType_t')

    # NB: enum members without an alias (plain string, no tuple) are
    # registered only to keep Python enum values aligned with the C enum.
    cudnnConvolutionFwdAlgo_t = CEnumType(('CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM', 'none'),
                                          ('CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM', 'small'),
                                          ('CUDNN_CONVOLUTION_FWD_ALGO_GEMM', 'large'),
                                          # not implemented:
                                          ('CUDNN_CONVOLUTION_FWD_ALGO_DIRECT'),
                                          ('CUDNN_CONVOLUTION_FWD_ALGO_FFT', 'fft'),
                                          ('CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING', 'fft_tiling'),
                                          ('CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD', 'winograd'),
                                          # TODO: Not yet tested/documented:
                                          ('CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'),
                                          ctype='cudnnConvolutionFwdAlgo_t')

    # Forward algorithm aliases usable for 3D convolutions.
    conv3d_fwd_algorithms = ('none', 'small', 'fft_tiling')

    # Every registered forward alias is considered deterministic in v5.1.
    deterministic_fwd_algorithms = cudnnConvolutionFwdAlgo_t.get_aliases()

    cudnnConvolutionBwdFilterAlgo_t = CEnumType(('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0', 'none'),
                                                ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1', 'deterministic'),
                                                ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT', 'fft'),
                                                ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3', 'small'),
                                                # TODO: not yet tested/documented:
                                                ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'),
                                                ctype='cudnnConvolutionBwdFilterAlgo_t')

    # Backward-filter algorithm aliases usable for 3D convolutions.
    conv3d_bwd_filter_algorithms = ('none', 'small')

    deterministic_bwd_filter_algorithms = ('deterministic', 'fft', 'winograd_non_fused')

    cudnnConvolutionBwdDataAlgo_t = CEnumType(('CUDNN_CONVOLUTION_BWD_DATA_ALGO_0', 'none'),
                                              ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_1', 'deterministic'),
                                              ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT', 'fft'),
                                              ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING', 'fft_tiling'),
                                              ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD', 'winograd'),
                                              # TODO: not yet tested/documented:
                                              ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'),
                                              ctype='cudnnConvolutionBwdDataAlgo_t')

    # Backward-data algorithm aliases usable for 3D convolutions.
    conv3d_bwd_data_algorithms = ('none', 'deterministic', 'fft_tiling')

    deterministic_bwd_data_algorithms = ('deterministic', 'fft', 'fft_tiling', 'winograd', 'winograd_non_fused')

    cudnnPoolingMode_t = CEnumType(('CUDNN_POOLING_MAX', 'max'),
                                   ('CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING', 'average_inc_pad'),
                                   ('CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING', 'average_exc_pad'),
                                   ctype='cudnnPoolingMode_t')

    cudnnSoftmaxAlgorithm_t = CEnumType(('CUDNN_SOFTMAX_FAST', 'fast'),
                                        ('CUDNN_SOFTMAX_ACCURATE', 'accurate'),
                                        ('CUDNN_SOFTMAX_LOG', 'log'),
                                        ctype='cudnnSoftmaxAlgorithm_t')

    cudnnSoftmaxMode_t = CEnumType(('CUDNN_SOFTMAX_MODE_INSTANCE', 'instance'),
                                   ('CUDNN_SOFTMAX_MODE_CHANNEL', 'channel'),
                                   ctype='cudnnSoftmaxMode_t')

    cudnnBatchNormMode_t = CEnumType(('CUDNN_BATCHNORM_PER_ACTIVATION', 'per-activation'),
                                     ('CUDNN_BATCHNORM_SPATIAL', 'spatial'),
                                     ctype='cudnnBatchNormMode_t')
    # cudnnReduceTensorOp_t was introduced in cuDNN v6, but we define it
    # here as an empty enum so that code referencing it does not crash
    # when running against cuDNN v5.
    cudnnReduceTensorOp_t = CEnumType()

    def get_supported_dtype_configs(self, check_runtime=None):
        """
        Return the tuple of data type configurations supported by this version of cuDNN.
        This is currently convenient for all supported cuDNN versions, as Theano does not
        yet support new data types (like INT8, INT8x4, etc.).

        ``check_runtime`` may be a function that tests if a data type configuration is supported.::

            is_supported = check_runtime(dtype, precision)

        .. warning::

            From documentation for cudnnConvolutionForward (for both v5.1 and v6):

            .. code-block::

                TRUE_HALF_CONFIG is only supported on architectures with true fp16 support
                (compute capability 5.3 and 6.0)

            This seems to be a general remark about f16 support (not only for FWD).
            It can be checked at runtime only.

        """

        if check_runtime is None or check_runtime(*TRUE_HALF_CONFIG):
            return (TRUE_HALF_CONFIG, PSEUDO_HALF_CONFIG, FLOAT_CONFIG, DOUBLE_CONFIG)
        return (PSEUDO_HALF_CONFIG, FLOAT_CONFIG, DOUBLE_CONFIG)

    def fwd_algo_supports_dtype_config(self, algo, dtype, precision, ndim):
        """
        Tell whether forward algorithm ``algo`` (given as an alias string)
        supports the (``dtype``, ``precision``) configuration for a
        convolution with ``ndim`` spatial dimensions (2 or 3).
        """
        algorithms = self.cudnnConvolutionFwdAlgo_t
        algo = algorithms.fromalias(algo)
        if algo == algorithms.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
            return not is_true_half_config(dtype, precision)
        if algo == algorithms.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
            # true-half config only allowed for 2D here.
            return ndim == 2 or not is_true_half_config(dtype, precision)
        if algo == algorithms.CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
            return ndim == 2 and not is_true_half_config(dtype, precision)
        # CUDNN_CONVOLUTION_FWD_ALGO_DIRECT: not implemented.
        if algo == algorithms.CUDNN_CONVOLUTION_FWD_ALGO_FFT:
            return ndim == 2 and (is_pseudo_half_config(dtype, precision) or is_float_config(dtype, precision))
        if algo == algorithms.CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
            if ndim == 2:
                return is_pseudo_half_config(dtype, precision) or is_float_config(dtype, precision)
            if ndim == 3:
                return not is_true_half_config(dtype, precision)
        if algo == algorithms.CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
            return ndim == 2 and (is_pseudo_half_config(dtype, precision) or is_float_config(dtype, precision))
        if algo == algorithms.CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
            # NB: "If wDesc 's filter (height, width) is (5,5), data type config TRUE_HALF_CONFIG is not supported".
            # We could not check it before being in C code.
            return ndim == 2 and not is_double_config(dtype, precision)
        return False

    def bwd_filter_algo_supports_dtype_config(self, algo, dtype, precision, ndim):
        """
        Tell whether backward-filter algorithm ``algo`` (alias string)
        supports the (``dtype``, ``precision``) configuration for a
        convolution with ``ndim`` spatial dimensions (2 or 3).
        """
        # NB: Theano does not support float16 precision anymore for backward cuDNN convolutions.
        if is_true_half_config(dtype, precision):
            return False
        algorithms = self.cudnnConvolutionBwdFilterAlgo_t
        algo = algorithms.fromalias(algo)
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
            return not is_true_half_config(dtype, precision)
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
            return ndim == 2
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
            return ndim == 2 and (is_pseudo_half_config(dtype, precision) or is_float_config(dtype, precision))
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
            return not is_true_half_config(dtype, precision)
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
            # NB: "If wDesc 's filter (height, width) is (5,5), data type config TRUE_HALF_CONFIG is not supported".
            # We could not check it before being in C code.
            return ndim == 2 and not is_double_config(dtype, precision)
        return False

    def bwd_data_algo_supports_dtype_config(self, algo, dtype, precision, ndim):
        """
        Tell whether backward-data algorithm ``algo`` (alias string)
        supports the (``dtype``, ``precision``) configuration for a
        convolution with ``ndim`` spatial dimensions (2 or 3).
        """
        # NB: Theano does not support float16 precision anymore for backward cuDNN convolutions.
        if is_true_half_config(dtype, precision):
            return False
        algorithms = self.cudnnConvolutionBwdDataAlgo_t
        algo = algorithms.fromalias(algo)
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
            return not is_true_half_config(dtype, precision)
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
            # CUDNN_CONVOLUTION_BWD_DATA_ALGO_1: all data type configs supported.
            # NB: Let's avoid float16 precision, as some strange errors may be encountered
            # with that precision ( see https://github.com/Theano/Theano/pull/5932/ )
            return not is_true_half_config(dtype, precision)
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
            return ndim == 2 and (is_pseudo_half_config(dtype, precision) or is_float_config(dtype, precision))
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
            if ndim == 2:
                return is_pseudo_half_config(dtype, precision) or is_float_config(dtype, precision)
            if ndim == 3:
                return not is_true_half_config(dtype, precision)
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
            return ndim == 2 and (is_pseudo_half_config(dtype, precision) or is_float_config(dtype, precision))
        if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
            # NB: "If wDesc 's filter (height, width) is (5,5), data type config TRUE_HALF_CONFIG is not supported".
            # We could not check it before being in C code.
            return ndim == 2 and not is_double_config(dtype, precision)
        return False
225
226
class CuDNNV6(CuDNNV51):
    """
    cuDNN definitions for the cuDNN v6.0 API.

    Overrides the v5.1 definitions with what changed in v6: new data
    types, deterministic max pooling, the backward-filter FFT-tiling
    algorithm and the tensor-reduction operators.
    """

    # Major cuDNN version covered by these definitions.
    version = 6

    cudnnDataType_t = CEnumType(('CUDNN_DATA_FLOAT', 'float32'),
                                ('CUDNN_DATA_DOUBLE', 'float64'),
                                ('CUDNN_DATA_HALF', 'float16'),
                                # new in v6
                                ('CUDNN_DATA_INT8', 'int8'),
                                ('CUDNN_DATA_INT32', 'int32'),
                                # ('CUDNN_DATA_INT8X4', 'int8x4'),
                                ctype='cudnnDataType_t')

    cudnnPoolingMode_t = CEnumType(('CUDNN_POOLING_MAX', 'max'),
                                   ('CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING', 'average_inc_pad'),
                                   ('CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING', 'average_exc_pad'),
                                   # new in v6:
                                   ('CUDNN_POOLING_MAX_DETERMINISTIC', 'max_deterministic'),
                                   ctype='cudnnPoolingMode_t')

    cudnnConvolutionBwdFilterAlgo_t = CEnumType(('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0', 'none'),
                                                ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1', 'deterministic'),
                                                ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT', 'fft'),
                                                ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3', 'small'),
                                                # registered without alias (not usable from Theano):
                                                ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD'),
                                                ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'),
                                                # new in v6:
                                                ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING', 'fft_tiling'),
                                                ctype='cudnnConvolutionBwdFilterAlgo_t')

    # v6 adds a deterministic FFT-tiling backward-filter algorithm.
    deterministic_bwd_filter_algorithms = CuDNNV51.deterministic_bwd_filter_algorithms + ('fft_tiling',)

    # Tensor reduction operators, introduced in cuDNN v6.
    cudnnReduceTensorOp_t = CEnumType(('CUDNN_REDUCE_TENSOR_ADD', 'add'),
                                      ('CUDNN_REDUCE_TENSOR_MUL', 'mul'),
                                      ('CUDNN_REDUCE_TENSOR_MIN', 'minimum'),
                                      ('CUDNN_REDUCE_TENSOR_MAX', 'maximum'),
                                      ('CUDNN_REDUCE_TENSOR_AMAX', 'absmax'),
                                      ('CUDNN_REDUCE_TENSOR_AVG', 'avg'),
                                      ('CUDNN_REDUCE_TENSOR_NORM1', 'norm1'),
                                      ('CUDNN_REDUCE_TENSOR_NORM2', 'norm2'),
                                      ctype='cudnnReduceTensorOp_t')

    def fwd_algo_supports_dtype_config(self, algo, dtype, precision, ndim):
        """
        Same as the v5.1 method, except that forward FFT-tiling also
        accepts DOUBLE_CONFIG in v6.
        """
        is_supported = super(CuDNNV6, self).fwd_algo_supports_dtype_config(algo, dtype, precision, ndim)
        if not is_supported:
            algorithms = self.cudnnConvolutionFwdAlgo_t
            algo = algorithms.fromalias(algo)
            if algo == algorithms.CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
                # NB: For cuDNN V6:
                # "Data Type Config Support: PSEUDO_HALF_CONFIG, FLOAT_CONFIG
                # (DOUBLE_CONFIG is also supported when the task can be handled by 1D FFT,
                # ie, one of the filter dimension, width or height is 1)"
                # Could be checked only in C code. By default, let's allow DOUBLE_CONFIG.
                return ndim == 2 and (is_pseudo_half_config(dtype, precision) or
                                      is_float_config(dtype, precision) or
                                      is_double_config(dtype, precision))
        return is_supported

    def bwd_filter_algo_supports_dtype_config(self, algo, dtype, precision, ndim):
        """
        Same as the v5.1 method, extended for the backward-filter
        FFT-tiling algorithm added in v6.
        """
        is_supported = super(CuDNNV6, self).bwd_filter_algo_supports_dtype_config(algo, dtype, precision, ndim)
        if not is_supported:
            algorithms = self.cudnnConvolutionBwdFilterAlgo_t
            algo = algorithms.fromalias(algo)
            if algo == algorithms.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
                return ndim == 2 and (is_pseudo_half_config(dtype, precision) or
                                      is_float_config(dtype, precision) or
                                      is_double_config(dtype, precision))
        return is_supported

    def bwd_data_algo_supports_dtype_config(self, algo, dtype, precision, ndim):
        """
        Same as the v5.1 method, except that backward-data FFT-tiling
        also accepts DOUBLE_CONFIG in v6.
        """
        is_supported = super(CuDNNV6, self).bwd_data_algo_supports_dtype_config(algo, dtype, precision, ndim)
        if not is_supported:
            algorithms = self.cudnnConvolutionBwdDataAlgo_t
            algo = algorithms.fromalias(algo)
            if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
                # NB: For cuDNN V6:
                # "Data Type Config Support: PSEUDO_HALF_CONFIG, FLOAT_CONFIG
                # (DOUBLE_CONFIG is also supported when the task can be handled by 1D FFT,
                # ie, one of the filter dimension, width or height is 1)"
                # Could be checked only in C code. By default, let's allow DOUBLE_CONFIG.
                return ndim == 2 and (is_pseudo_half_config(dtype, precision) or
                                      is_float_config(dtype, precision) or
                                      is_double_config(dtype, precision))
        return is_supported
310
311
class CuDNNV7(CuDNNV6):
    """
    cuDNN definitions for the cuDNN v7.0 API.

    Adds the math-type (tensor-op) and determinism enums on top of the
    v6 definitions.
    """

    # Major cuDNN version covered by these definitions.
    version = 7
    # Whether Tensor Core operations may be used (new in v7).
    cudnnMathType_t = CEnumType(('CUDNN_DEFAULT_MATH', 'non_tensor_op'),
                                ('CUDNN_TENSOR_OP_MATH', 'tensor_op'),
                                ctype='cudnnMathType_t')
    # Determinism flag reported by cuDNN algorithm descriptors (new in v7).
    cudnnDeterminism_t = CEnumType(('CUDNN_NON_DETERMINISTIC', 'non_deterministic'),
                                   ('CUDNN_DETERMINISTIC', 'deterministic'),
                                   ctype='cudnnDeterminism_t')
320
321
def get_definitions(cudnn_version=None):
    """
    Return cuDNN definitions to be used by Theano for the given cuDNN version.

    ``cudnn_version`` must be None or an integer
    (typically the version returned by :func:`theano.gpuarray.dnn.version`).
    If None, return definitions for the most recent supported cuDNN version.

    """
    # Map cuDNN major version to the class providing its definitions.
    definitions_by_major = {5: CuDNNV51, 6: CuDNNV6}
    if cudnn_version is not None:
        definitions_class = definitions_by_major.get(cudnn_version // 1000)
        if definitions_class is not None:
            return definitions_class()
    # By default, use definitions for the last supported cuDNN version.
    return CuDNNV7()
338