1# Copyright 2017 The Meson development team
2
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6
7#     http://www.apache.org/licenses/LICENSE-2.0
8
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import typing as T
16import re
17
18from ..mesonlib import version_compare
19from ..compilers import CudaCompiler, Compiler
20
21from . import NewExtensionModule
22
23from ..interpreterbase import (
24    flatten, permittedKwargs, noKwargs,
25    InvalidArguments, FeatureNew
26)
27
28if T.TYPE_CHECKING:
29    from . import ModuleState
30
class CudaModule(NewExtensionModule):
    """Extension module exposing CUDA helper methods.

    Three methods are registered in ``self.methods``:

    * ``min_driver_version``: look up the minimum driver version for a CUDA
      Toolkit version string.
    * ``nvcc_arch_flags``: compute the ``-gencode`` flags for a set of target
      GPU architectures.
    * ``nvcc_arch_readable``: the same selection as ``nvcc_arch_flags``, but
      returned as ``sm_XY`` / ``compute_XY`` names.
    """

    @FeatureNew('CUDA module', '0.50.0')
    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not forwarded.
        super().__init__()
        self.methods.update({
            "min_driver_version": self.min_driver_version,
            "nvcc_arch_flags":    self.nvcc_arch_flags,
            "nvcc_arch_readable": self.nvcc_arch_readable,
        })

    @noKwargs
    def min_driver_version(self, state: 'ModuleState',
                           args: T.Tuple[str],
                           kwargs: T.Dict[str, T.Any]) -> str:
        """Return the minimum driver version for a CUDA Toolkit version.

        ``args`` must hold exactly one string, the CUDA Toolkit version.
        The table below is scanned from newest to oldest and the first entry
        whose version constraint is satisfied wins. The value for
        ``state.host_machine.system`` is returned, falling back to the
        ``'linux'`` column when the table has no column for that system.

        Returns ``'unknown'`` when no table entry matches.

        Raises InvalidArguments when ``args`` is not exactly one string.
        """
        argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' +
                                    'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' +
                                    'the CUDA Toolkit\'s components (including NVCC) are versioned ' +
                                    'independently from each other (and the CUDA Toolkit as a whole).')

        if len(args) != 1 or not isinstance(args[0], str):
            raise argerror

        cuda_version = args[0]
        driver_version_table = [
            {'cuda_version': '>=11.5.0',   'windows': '496.04', 'linux': '495.29.05'},
            {'cuda_version': '>=11.4.1',   'windows': '471.41', 'linux': '470.57.02'},
            {'cuda_version': '>=11.4.0',   'windows': '471.11', 'linux': '470.42.01'},
            {'cuda_version': '>=11.3.0',   'windows': '465.89', 'linux': '465.19.01'},
            {'cuda_version': '>=11.2.2',   'windows': '461.33', 'linux': '460.32.03'},
            {'cuda_version': '>=11.2.1',   'windows': '461.09', 'linux': '460.32.03'},
            {'cuda_version': '>=11.2.0',   'windows': '460.82', 'linux': '460.27.03'},
            {'cuda_version': '>=11.1.1',   'windows': '456.81', 'linux': '455.32'},
            {'cuda_version': '>=11.1.0',   'windows': '456.38', 'linux': '455.23'},
            {'cuda_version': '>=11.0.3',   'windows': '451.82', 'linux': '450.51.06'},
            {'cuda_version': '>=11.0.2',   'windows': '451.48', 'linux': '450.51.05'},
            {'cuda_version': '>=11.0.1',   'windows': '451.22', 'linux': '450.36.06'},
            {'cuda_version': '>=10.2.89',  'windows': '441.22', 'linux': '440.33'},
            {'cuda_version': '>=10.1.105', 'windows': '418.96', 'linux': '418.39'},
            {'cuda_version': '>=10.0.130', 'windows': '411.31', 'linux': '410.48'},
            {'cuda_version': '>=9.2.148',  'windows': '398.26', 'linux': '396.37'},
            {'cuda_version': '>=9.2.88',   'windows': '397.44', 'linux': '396.26'},
            {'cuda_version': '>=9.1.85',   'windows': '391.29', 'linux': '390.46'},
            {'cuda_version': '>=9.0.76',   'windows': '385.54', 'linux': '384.81'},
            {'cuda_version': '>=8.0.61',   'windows': '376.51', 'linux': '375.26'},
            {'cuda_version': '>=8.0.44',   'windows': '369.30', 'linux': '367.48'},
            {'cuda_version': '>=7.5.16',   'windows': '353.66', 'linux': '352.31'},
            {'cuda_version': '>=7.0.28',   'windows': '347.62', 'linux': '346.46'},
        ]

        driver_version = 'unknown'
        for d in driver_version_table:
            if version_compare(cuda_version, d['cuda_version']):
                driver_version = d.get(state.host_machine.system, d['linux'])
                break

        return driver_version

    @permittedKwargs(['detected'])
    def nvcc_arch_flags(self, state: 'ModuleState',
                        args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
                        kwargs: T.Dict[str, T.Any]) -> T.List[str]:
        """Return the NVCC ``-gencode`` flag list for the requested architectures.

        ``args[0]`` is a compiler object or a version string; any remaining
        positional arguments name the target architectures. The optional
        ``detected`` keyword supplies the architectures used for ``'Auto'``.
        """
        nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
        return self._nvcc_arch_flags(*nvcc_arch_args)[0]

    @permittedKwargs(['detected'])
    def nvcc_arch_readable(self, state: 'ModuleState',
                           args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
                           kwargs: T.Dict[str, T.Any]) -> T.List[str]:
        """Return the selected architectures as ``sm_XY`` / ``compute_XY`` names.

        Accepts the same arguments as ``nvcc_arch_flags``.
        """
        nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
        return self._nvcc_arch_flags(*nvcc_arch_args)[1]

    @staticmethod
    def _break_arch_string(s):
        """Split ``s`` on runs of whitespace, commas and semicolons.

        Leading and trailing separators are dropped. An empty or
        separator-only input yields ``['']``; callers filter out empty items.
        """
        s = re.sub('[ \t\r\n,;]+', ';', s)
        return s.strip(';').split(';')

    @staticmethod
    def _detected_cc_from_compiler(c):
        """Return ``c.detected_cc`` for a CudaCompiler, otherwise ``''``."""
        if isinstance(c, CudaCompiler):
            return c.detected_cc
        return ''

    @staticmethod
    def _version_from_compiler(c):
        """Return a version string for ``c``.

        A CudaCompiler yields its ``version`` attribute, a plain string is
        returned unchanged, and anything else yields ``'unknown'``.
        """
        if isinstance(c, CudaCompiler):
            return c.version
        if isinstance(c, str):
            return c
        return 'unknown'

    def _validate_nvcc_arch_args(self, args, kwargs):
        """Validate and normalise the arguments of the ``nvcc_arch_*`` methods.

        Returns a tuple ``(cuda_version, arch_list, detected)`` where
        ``arch_list`` is a single string when exactly one architecture was
        given and a list otherwise, and ``detected`` is always a list.

        Raises InvalidArguments when the first argument is missing or is
        neither a CudaCompiler nor a string, or when one of the special names
        'All', 'Common' or 'Auto' is combined with other architectures or
        appears in ``detected``.
        """
        special_names = {'All', 'Common', 'Auto'}
        argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')

        if len(args) < 1:
            raise argerror

        compiler = args[0]
        cuda_version = self._version_from_compiler(compiler)
        if cuda_version == 'unknown':
            raise argerror

        arch_list = [] if len(args) <= 1 else flatten(args[1:])
        arch_list = [self._break_arch_string(a) for a in arch_list]
        arch_list = flatten(arch_list)
        if len(arch_list) > 1 and not set(arch_list).isdisjoint(special_names):
            raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
        # A lone entry is passed on as a bare string so that the special names
        # compare equal to it in _nvcc_arch_flags.
        arch_list = arch_list[0] if len(arch_list) == 1 else arch_list

        detected = kwargs.get('detected', self._detected_cc_from_compiler(compiler))
        detected = flatten([detected])
        detected = [self._break_arch_string(a) for a in detected]
        detected = flatten(detected)
        if not set(detected).isdisjoint(special_names):
            raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')

        return cuda_version, arch_list, detected

    def _filter_cuda_arch_list(self, cuda_arch_list, lo=None, hi=None, saturate=None):
        """
        Filter CUDA arch list (no codenames) for >= low and < hi architecture
        bounds, and deduplicate.
        If saturate is provided, architectures >= hi are replaced with saturate.
        A bound that is None (or empty) is not applied; empty entries are dropped.
        """

        filtered_cuda_arch_list = []
        for arch in cuda_arch_list:
            if arch:
                if lo and version_compare(arch, '<' + lo):
                    continue
                if hi and version_compare(arch, '>=' + hi):
                    if not saturate:
                        continue
                    arch = saturate
                if arch not in filtered_cuda_arch_list:
                    filtered_cuda_arch_list.append(arch)
        return filtered_cuda_arch_list

    def _nvcc_arch_flags(self, cuda_version, cuda_arch_list='Auto', detected=''):
        """
        Using the CUDA Toolkit version and the target architectures, compute
        the NVCC architecture flags.

        cuda_arch_list may be one of the special names 'All', 'Common' or
        'Auto', a separator-delimited string, or a list of architecture names.
        Each name is either a numeric architecture such as '7.5' or a codename
        from the table below, optionally suffixed with '+PTX'.

        Returns a tuple (nvcc_flags, nvcc_archs_readable).

        Raises InvalidArguments for an unrecognised architecture name.
        """

        # Replicates much of the logic of
        #     https://github.com/Kitware/CMake/blob/master/Modules/FindCUDA/select_compute_arch.cmake
        # except that a bug with cuda_arch_list="All" is worked around by
        # tracking both lower and upper limits on GPU architectures.

        cuda_known_gpu_architectures   = ['Fermi', 'Kepler', 'Maxwell']  # noqa: E221
        cuda_common_gpu_architectures  = ['3.0', '3.5', '5.0']           # noqa: E221
        cuda_hi_limit_gpu_architecture = None                            # noqa: E221
        cuda_lo_limit_gpu_architecture = '2.0'                           # noqa: E221
        cuda_all_gpu_architectures     = ['3.0', '3.2', '3.5', '5.0']    # noqa: E221

        if version_compare(cuda_version, '<7.0'):
            cuda_hi_limit_gpu_architecture = '5.2'

        if version_compare(cuda_version, '>=7.0'):
            cuda_known_gpu_architectures  += ['Kepler+Tegra', 'Kepler+Tesla', 'Maxwell+Tegra']  # noqa: E221
            cuda_common_gpu_architectures += ['5.2']                                            # noqa: E221

            if version_compare(cuda_version, '<8.0'):
                cuda_common_gpu_architectures += ['5.2+PTX']  # noqa: E221
                cuda_hi_limit_gpu_architecture = '6.0'        # noqa: E221

        if version_compare(cuda_version, '>=8.0'):
            cuda_known_gpu_architectures  += ['Pascal', 'Pascal+Tegra']  # noqa: E221
            cuda_common_gpu_architectures += ['6.0', '6.1']              # noqa: E221
            cuda_all_gpu_architectures    += ['6.0', '6.1', '6.2']       # noqa: E221

            if version_compare(cuda_version, '<9.0'):
                cuda_common_gpu_architectures += ['6.1+PTX']  # noqa: E221
                cuda_hi_limit_gpu_architecture = '7.0'        # noqa: E221

        if version_compare(cuda_version, '>=9.0'):
            cuda_known_gpu_architectures  += ['Volta', 'Xavier'] # noqa: E221
            cuda_common_gpu_architectures += ['7.0']             # noqa: E221
            cuda_all_gpu_architectures    += ['7.0', '7.2']      # noqa: E221
            # https://docs.nvidia.com/cuda/archive/9.0/cuda-toolkit-release-notes/index.html#unsupported-features
            cuda_lo_limit_gpu_architecture = '3.0'               # noqa: E221

            if version_compare(cuda_version, '<10.0'):
                cuda_common_gpu_architectures += ['7.2+PTX']  # noqa: E221
                cuda_hi_limit_gpu_architecture = '8.0'        # noqa: E221

        if version_compare(cuda_version, '>=10.0'):
            cuda_known_gpu_architectures  += ['Turing'] # noqa: E221
            cuda_common_gpu_architectures += ['7.5']    # noqa: E221
            cuda_all_gpu_architectures    += ['7.5']    # noqa: E221

            if version_compare(cuda_version, '<11.0'):
                cuda_common_gpu_architectures += ['7.5+PTX']  # noqa: E221
                cuda_hi_limit_gpu_architecture = '8.0'        # noqa: E221

        if version_compare(cuda_version, '>=11.0'):
            cuda_known_gpu_architectures  += ['Ampere'] # noqa: E221
            cuda_common_gpu_architectures += ['8.0']    # noqa: E221
            cuda_all_gpu_architectures    += ['8.0']    # noqa: E221
            # https://docs.nvidia.com/cuda/archive/11.0/cuda-toolkit-release-notes/index.html#deprecated-features
            cuda_lo_limit_gpu_architecture = '3.5'      # noqa: E221

            if version_compare(cuda_version, '<11.1'):
                cuda_common_gpu_architectures += ['8.0+PTX']  # noqa: E221
                cuda_hi_limit_gpu_architecture = '8.6'        # noqa: E221

        if version_compare(cuda_version, '>=11.1'):
            cuda_common_gpu_architectures += ['8.6', '8.6+PTX']  # noqa: E221
            cuda_all_gpu_architectures    += ['8.6']             # noqa: E221

            if version_compare(cuda_version, '<12.0'):
                cuda_hi_limit_gpu_architecture = '9.0'        # noqa: E221

        if not cuda_arch_list:
            cuda_arch_list = 'Auto'

        if   cuda_arch_list == 'All':     # noqa: E271
            cuda_arch_list = cuda_known_gpu_architectures
        elif cuda_arch_list == 'Common':  # noqa: E271
            cuda_arch_list = cuda_common_gpu_architectures
        elif cuda_arch_list == 'Auto':    # noqa: E271
            if detected:
                if isinstance(detected, list):
                    cuda_arch_list = detected
                else:
                    cuda_arch_list = self._break_arch_string(detected)
                cuda_arch_list = self._filter_cuda_arch_list(cuda_arch_list,
                                                             cuda_lo_limit_gpu_architecture,
                                                             cuda_hi_limit_gpu_architecture,
                                                             cuda_common_gpu_architectures[-1])
            else:
                cuda_arch_list = cuda_common_gpu_architectures
        elif isinstance(cuda_arch_list, str):
            cuda_arch_list = self._break_arch_string(cuda_arch_list)

        cuda_arch_list = sorted(x for x in set(cuda_arch_list) if x)

        # Codename -> (binary architectures, PTX architectures).
        codename_archs = {
            'Fermi':         (['2.0', '2.1(2.0)'], []),
            'Kepler+Tegra':  (['3.2'],             []),
            'Kepler+Tesla':  (['3.7'],             []),
            'Kepler':        (['3.0', '3.5'],      ['3.5']),
            'Maxwell+Tegra': (['5.3'],             []),
            'Maxwell':       (['5.0', '5.2'],      ['5.2']),
            'Pascal':        (['6.0', '6.1'],      ['6.1']),
            'Pascal+Tegra':  (['6.2'],             []),
            'Volta':         (['7.0'],             ['7.0']),
            'Xavier':        (['7.2'],             []),
            'Turing':        (['7.5'],             ['7.5']),
            'Ampere':        (['8.0'],             ['8.0']),
        }

        cuda_arch_bin = []
        cuda_arch_ptx = []
        for arch_name in cuda_arch_list:
            add_ptx = arch_name.endswith('+PTX')
            if add_ptx:
                arch_name = arch_name[:-len('+PTX')]

            if re.fullmatch('[0-9]+\\.[0-9](\\([0-9]+\\.[0-9]\\))?', arch_name):
                arch_bin, arch_ptx = [arch_name], [arch_name]
            else:
                arch_bin, arch_ptx = codename_archs.get(arch_name, (None, None))

            if arch_bin is None:
                raise InvalidArguments(f'Unknown CUDA Architecture Name {arch_name}!')

            cuda_arch_bin += arch_bin

            if add_ptx:
                # Codenames without a dedicated PTX entry reuse their binary list.
                if not arch_ptx:
                    arch_ptx = arch_bin
                cuda_arch_ptx += arch_ptx

        cuda_arch_bin = sorted(set(cuda_arch_bin))
        cuda_arch_ptx = sorted(set(cuda_arch_ptx))

        # Matches 'X.Y' or 'X.Y(A.B)'; group 2 is the parenthesised part, if any.
        arch_pattern = re.compile('([0-9]+\\.[0-9])(?:\\(([0-9]+\\.[0-9])\\))?')

        def outside_limits(arch):
            # The upper limit stays None for toolkit versions where no branch
            # above assigns it (>= 12.0); in that case only the lower bound
            # applies, mirroring _filter_cuda_arch_list.
            if version_compare(arch, '<' + cuda_lo_limit_gpu_architecture):
                return True
            if cuda_hi_limit_gpu_architecture and version_compare(arch, '>=' + cuda_hi_limit_gpu_architecture):
                return True
            return False

        nvcc_flags = []
        nvcc_archs_readable = []

        for arch in cuda_arch_bin:
            arch, codev = arch_pattern.fullmatch(arch).groups()

            if outside_limits(arch):
                continue

            # With a parenthesised part, that value names the compute_ arch
            # while the leading value names the sm_ code target.
            sm = arch.replace('.', '')
            compute = (codev or arch).replace('.', '')
            nvcc_flags += ['-gencode', 'arch=compute_' + compute + ',code=sm_' + sm]
            nvcc_archs_readable += ['sm_' + sm]

        for arch in cuda_arch_ptx:
            arch, codev = arch_pattern.fullmatch(arch).groups()

            if codev:
                arch = codev

            if outside_limits(arch):
                continue

            arch = arch.replace('.', '')
            nvcc_flags += ['-gencode', 'arch=compute_' + arch + ',code=compute_' + arch]
            nvcc_archs_readable += ['compute_' + arch]

        return nvcc_flags, nvcc_archs_readable
350
def initialize(*args, **kwargs):
    """Build and return a CudaModule, forwarding all arguments to its constructor."""
    module = CudaModule(*args, **kwargs)
    return module
353