1# Copyright 2013-2019 The Meson development team
2
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6
7#     http://www.apache.org/licenses/LICENSE-2.0
8
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import glob
16import re
17import os
18import typing as T
19from pathlib import Path
20
21from .. import mesonlib
22from .. import mlog
23from ..environment import detect_cpu_family
24from .base import DependencyException, SystemDependency
25
26
27if T.TYPE_CHECKING:
28    from ..environment import Environment
29    from ..compilers import Compiler
30
# Result of CUDA Toolkit detection: (toolkit path, toolkit version, found flag).
TV_ResultTuple = T.Tuple[T.Union[str, None], T.Union[str, None], bool]
32
class CudaDependency(SystemDependency):
    """System dependency on the NVIDIA CUDA Toolkit.

    Locates a CUDA Toolkit installation (from environment variables and, on
    non-Windows machines, ``/usr/local/cuda*`` directories), selects a toolkit
    whose version satisfies the requested constraints (or, when the dependency
    language is ``cuda``, matches the nvcc version), and resolves the link
    arguments for every requested module. ``cudart`` is always requested.
    """

    supported_languages = ['cuda', 'cpp', 'c'] # see also _detect_language

    def __init__(self, environment: 'Environment', kwargs: T.Dict[str, T.Any]) -> None:
        compilers = environment.coredata.compilers[self.get_for_machine_from_kwargs(kwargs)]
        language = self._detect_language(compilers)
        if language not in self.supported_languages:
            raise DependencyException(f'Language \'{language}\' is not supported by the CUDA Toolkit. Supported languages are {self.supported_languages}.')

        super().__init__('cuda', environment, kwargs, language=language)
        self.lib_modules: T.Dict[str, T.List[str]] = {}
        self.requested_modules = self.get_requested(kwargs)
        # the CUDA runtime is always requested, ahead of any user modules
        if 'cudart' not in self.requested_modules:
            self.requested_modules = ['cudart'] + self.requested_modules

        (self.cuda_path, self.version, self.is_found) = self._detect_cuda_path_and_version()
        if not self.is_found:
            return

        if not os.path.isabs(self.cuda_path):
            raise DependencyException(f'CUDA Toolkit path must be absolute, got \'{self.cuda_path}\'.')

        # nvcc already knows where to find the CUDA Toolkit, but if we're compiling
        # a mixed C/C++/CUDA project, we still need to make the include dir searchable
        if self.language != 'cuda' or len(compilers) > 1:
            self.incdir = os.path.join(self.cuda_path, 'include')
            self.compile_args += [f'-I{self.incdir}']

        # only non-nvcc compilers need an explicit library search directory
        if self.language != 'cuda':
            arch_libdir = self._detect_arch_libdir()
            self.libdir = os.path.join(self.cuda_path, arch_libdir)
            mlog.debug('CUDA library directory is', mlog.bold(self.libdir))
        else:
            self.libdir = None

        self.is_found = self._find_requested_libraries()

    @classmethod
    def _detect_language(cls, compilers: T.Dict[str, 'Compiler']) -> str:
        """Choose the language used to look up the CUDA Toolkit.

        Returns the first entry of ``supported_languages`` that has a
        configured compiler; otherwise the first configured language, which
        the caller rejects as unsupported.

        :raises DependencyException: if no compilers are configured at all.
        """
        for lang in cls.supported_languages:
            if lang in compilers:
                return lang
        if not compilers:
            raise DependencyException('Could not detect a language for the CUDA Toolkit: no compilers are configured.')
        return next(iter(compilers))

    def _detect_cuda_path_and_version(self) -> TV_ResultTuple:
        """Locate a CUDA Toolkit and determine its version.

        Sets ``self.env_var`` to the environment variable designating the
        default toolkit (or None). When the dependency language is ``cuda``,
        the nvcc version must satisfy the user's version requirements, and the
        selected toolkit must then match nvcc's ``major.minor`` version.

        :returns: ``(path, version, True)`` on success, otherwise
            ``(None, None, False)``.
        :raises DependencyException: through ``_report_dependency_error`` when
            the dependency is required and either nvcc fails the version
            requirements or no default toolkit could be chosen.
        """
        self.env_var = self._default_path_env_var()
        mlog.debug('Default path env var:', mlog.bold(self.env_var))

        version_reqs = self.version_reqs
        if self.language == 'cuda':
            nvcc_version = self._strip_patch_version(self.get_compiler().version)
            mlog.debug('nvcc version:', mlog.bold(nvcc_version))
            if version_reqs:
                # make sure nvcc version satisfies specified version requirements
                (_, not_found, _) = mesonlib.version_compare_many(nvcc_version, version_reqs)
                if not_found:
                    msg = f'The current nvcc version {nvcc_version} does not satisfy the specified CUDA Toolkit version requirements {version_reqs}.'
                    return self._report_dependency_error(msg, (None, None, False))

            # use nvcc version to find a matching CUDA Toolkit
            version_reqs = [f'={nvcc_version}']
        else:
            nvcc_version = None

        paths = [(path, self._cuda_toolkit_version(path), default) for (path, default) in self._cuda_paths()]
        if version_reqs:
            return self._find_matching_toolkit(paths, version_reqs, nvcc_version)

        defaults = [(path, version) for (path, version, default) in paths if default]
        if defaults:
            return (defaults[0][0], defaults[0][1], True)

        platform_msg = 'set the CUDA_PATH environment variable' if self._is_windows() \
            else 'set the CUDA_PATH environment variable/create the \'/usr/local/cuda\' symbolic link'
        msg = f'Please specify the desired CUDA Toolkit version (e.g. dependency(\'cuda\', version : \'>=10.1\')) or {platform_msg} to point to the location of your desired version.'
        return self._report_dependency_error(msg, (None, None, False))

    def _find_matching_toolkit(self, paths: T.List[TV_ResultTuple], version_reqs: T.List[str], nvcc_version: T.Optional[str]) -> TV_ResultTuple:
        """Return the first toolkit in ``paths`` satisfying ``version_reqs``.

        Default toolkits are tried first in their given order, then the others
        in descending version order. If ``nvcc_version`` is set and a
        non-default toolkit is chosen (or nothing matches) while a default
        exists, a warning says the default was ignored.

        :returns: ``(path, version, True)`` or ``(None, None, False)``.
        """
        # keep the default paths order intact, sort the rest in the descending order
        # according to the toolkit version
        part_func: T.Callable[[TV_ResultTuple], bool] = lambda t: not t[2]
        defaults_it, rest_it = mesonlib.partition(part_func, paths)
        defaults = list(defaults_it)
        paths = defaults + sorted(rest_it, key=lambda t: mesonlib.Version(t[1]), reverse=True)
        mlog.debug(f'Search paths: {paths}')

        if nvcc_version and defaults:
            default_src = f"the {self.env_var} environment variable" if self.env_var else "the \'/usr/local/cuda\' symbolic link"
            nvcc_warning = 'The default CUDA Toolkit as designated by {} ({}) doesn\'t match the current nvcc version {} and will be ignored.'.format(default_src, os.path.realpath(defaults[0][0]), nvcc_version)
        else:
            nvcc_warning = None

        for (path, version, default) in paths:
            (_, not_found, _) = mesonlib.version_compare_many(version, version_reqs)
            if not not_found:
                if not default and nvcc_warning:
                    mlog.warning(nvcc_warning)
                return (path, version, True)

        if nvcc_warning:
            mlog.warning(nvcc_warning)
        return (None, None, False)

    def _default_path_env_var(self) -> T.Optional[str]:
        """Return the first set env var designating the default toolkit, or None.

        Windows only consults ``CUDA_PATH``; other platforms also consult
        ``CUDA_HOME`` and ``CUDA_ROOT``. Warns when the set variables disagree.
        """
        env_vars = ['CUDA_PATH'] if self._is_windows() else ['CUDA_PATH', 'CUDA_HOME', 'CUDA_ROOT']
        env_vars = [var for var in env_vars if var in os.environ]
        user_defaults = {os.environ[var] for var in env_vars}
        if len(user_defaults) > 1:
            # sort so the warning text is deterministic across runs
            mlog.warning('Environment variables {} point to conflicting toolkit locations ({}). Toolkit selection might produce unexpected results.'.format(', '.join(env_vars), ', '.join(sorted(user_defaults))))
        return env_vars[0] if env_vars else None

    def _cuda_paths(self) -> T.List[T.Tuple[str, bool]]:
        """Return candidate ``(path, is_default)`` pairs, env-var default first."""
        return ([(os.environ[self.env_var], True)] if self.env_var else []) \
            + (self._cuda_paths_win() if self._is_windows() else self._cuda_paths_nix())

    def _cuda_paths_win(self) -> T.List[T.Tuple[str, bool]]:
        """Return non-default candidates from ``CUDA_PATH_*`` env vars."""
        env_vars = os.environ.keys()
        return [(os.environ[var], False) for var in env_vars if var.startswith('CUDA_PATH_')]

    def _cuda_paths_nix(self) -> T.List[T.Tuple[str, bool]]:
        """Return candidates globbed under ``/usr/local``; ``cuda`` is the default."""
        # include /usr/local/cuda default only if no env_var was found
        pattern = '/usr/local/cuda-*' if self.env_var else '/usr/local/cuda*'
        return [(path, os.path.basename(path) == 'cuda') for path in glob.iglob(pattern)]

    toolkit_version_regex = re.compile(r'^CUDA Version\s+(.*)$')
    path_version_win_regex = re.compile(r'^v(.*)$')
    path_version_nix_regex = re.compile(r'^cuda-(.*)$')
    cudart_version_regex = re.compile(r'#define\s+CUDART_VERSION\s+([0-9]+)')

    def _cuda_toolkit_version(self, path: str) -> str:
        """Determine the toolkit version at ``path``.

        Tries ``version.txt``, then ``cuda_runtime_api.h``, then the directory
        name (``v<ver>`` on Windows, ``cuda-<ver>`` elsewhere). Returns
        ``'0.0'`` when all of these fail.
        """
        version = self._read_toolkit_version_txt(path)
        if version:
            return version
        version = self._read_cuda_runtime_api_version(path)
        if version:
            return version

        mlog.debug('Falling back to extracting version from path')
        path_version_regex = self.path_version_win_regex if self._is_windows() else self.path_version_nix_regex
        m = path_version_regex.match(os.path.basename(path))
        if m:
            return m.group(1)
        mlog.warning(f'Could not detect CUDA Toolkit version for {path}')
        return '0.0'

    def _read_cuda_runtime_api_version(self, path_str: str) -> T.Optional[str]:
        """Derive ``major.minor`` from ``CUDART_VERSION`` in ``cuda_runtime_api.h``.

        Searches ``path_str`` recursively for the header and decodes the
        macro as ``major * 1000 + minor * 10`` (e.g. 10010 -> '10.1').
        Headers that cannot be read or decoded are skipped. Returns None when
        no usable header is found.
        """
        path = Path(path_str)
        for header in path.rglob('cuda_runtime_api.h'):
            try:
                raw = header.read_text(encoding='utf-8')
            except (OSError, UnicodeDecodeError) as e:
                # best effort: fall through to the next header / path-based detection
                mlog.debug(f'Could not read {header}: {e!s}')
                continue
            m = self.cudart_version_regex.search(raw)
            if not m:
                continue
            # the regex only captures ASCII digits, so int() cannot fail
            major, remainder = divmod(int(m.group(1)), 1000)
            minor = remainder // 10
            return f'{major}.{minor}'
        return None

    def _read_toolkit_version_txt(self, path: str) -> T.Optional[str]:
        """Return ``major.minor`` from ``<path>/version.txt``, or None."""
        # Read 'version.txt' at the root of the CUDA Toolkit directory to determine the toolkit version
        version_file_path = os.path.join(path, 'version.txt')
        try:
            with open(version_file_path, encoding='utf-8') as version_file:
                version_str = version_file.readline() # e.g. 'CUDA Version 10.1.168'
                m = self.toolkit_version_regex.match(version_str)
                if m:
                    return self._strip_patch_version(m.group(1))
        except (OSError, UnicodeDecodeError) as e:
            mlog.debug(f'Could not read CUDA Toolkit\'s version file {version_file_path}: {e!s}')

        return None

    @classmethod
    def _strip_patch_version(cls, version: str) -> str:
        """Keep only the first two dot-separated components of ``version``."""
        return '.'.join(version.split('.')[:2])

    def _detect_arch_libdir(self) -> str:
        """Return the toolkit-relative library directory for the target machine.

        :raises DependencyException: for unsupported OS/architecture pairs.
        """
        # use the compilers of the machine this dependency is for, matching
        # the machine info looked up below (and the compilers used in __init__)
        arch = detect_cpu_family(self.env.coredata.compilers[self.for_machine])
        machine = self.env.machines[self.for_machine]
        msg = '{} architecture is not supported in {} version of the CUDA Toolkit.'
        if machine.is_windows():
            libdirs = {'x86': 'Win32', 'x86_64': 'x64'}
            if arch not in libdirs:
                raise DependencyException(msg.format(arch, 'Windows'))
            return os.path.join('lib', libdirs[arch])
        elif machine.is_linux():
            libdirs = {'x86_64': 'lib64', 'ppc64': 'lib', 'aarch64': 'lib64', 'loongarch64': 'lib64'}
            if arch not in libdirs:
                raise DependencyException(msg.format(arch, 'Linux'))
            return libdirs[arch]
        elif machine.is_darwin():
            libdirs = {'x86_64': 'lib64'}
            if arch not in libdirs:
                raise DependencyException(msg.format(arch, 'macOS'))
            return libdirs[arch]
        else:
            raise DependencyException('CUDA Toolkit: unsupported platform.')

    def _find_requested_libraries(self) -> bool:
        """Resolve link args for every requested module; True if all were found."""
        all_found = True

        for module in self.requested_modules:
            args = self.clib_compiler.find_library(module, self.env, [self.libdir] if self.libdir else [])
            if args is None:
                self._report_dependency_error(f'Couldn\'t find requested CUDA module \'{module}\'')
                all_found = False
            else:
                mlog.debug(f'Link args for CUDA module \'{module}\' are {args}')
                self.lib_modules[module] = args

        return all_found

    def _is_windows(self) -> bool:
        """Return True if the dependency's target machine is Windows."""
        return self.env.machines[self.for_machine].is_windows()

    @T.overload
    def _report_dependency_error(self, msg: str) -> None: ...

    @T.overload
    def _report_dependency_error(self, msg: str, ret_val: TV_ResultTuple) -> TV_ResultTuple: ...

    def _report_dependency_error(self, msg: str, ret_val: T.Optional[TV_ResultTuple] = None) -> T.Optional[TV_ResultTuple]:
        """Raise ``msg`` if the dependency is required, else log it and return ``ret_val``."""
        if self.required:
            raise DependencyException(msg)

        mlog.debug(msg)
        return ret_val

    def log_details(self) -> str:
        module_str = ', '.join(self.requested_modules)
        return 'modules: ' + module_str

    def log_info(self) -> str:
        return self.cuda_path if self.cuda_path else ''

    def get_requested(self, kwargs: T.Dict[str, T.Any]) -> T.List[str]:
        """Return the ``modules`` kwarg as a list of strings.

        :raises DependencyException: if any module is not a string.
        """
        candidates = mesonlib.extract_as_list(kwargs, 'modules')
        for c in candidates:
            if not isinstance(c, str):
                raise DependencyException('CUDA module argument is not a string.')
        return candidates

    def get_link_args(self, language: T.Optional[str] = None, raw: bool = False) -> T.List[str]:
        """Return the library search args (if any) followed by each module's link args."""
        args: T.List[str] = []
        if self.libdir:
            args += self.clib_compiler.get_linker_search_args(self.libdir)
        for lib in self.requested_modules:
            args += self.lib_modules[lib]
        return args
292