1import sys
2import re
3import os
4from collections import defaultdict, namedtuple
5
6from numba.core.config import IS_WIN32, IS_OSX
7from numba.misc.findlib import find_lib, find_file
8from numba.cuda.envvars import get_numbapro_envvar
9
10
# Record pairing a discovered value (``info``) with the name of the source
# that supplied it (``by``).
_env_path_tuple = namedtuple('_env_path_tuple', ('by', 'info'))
12
13
def _find_valid_path(options):
    """Pick the first usable entry from *options*.

    *options* is an iterable of ``(name, path)`` pairs, examined in order.
    The first pair whose *path* is not None is returned as ``(name, path)``.
    When every path is None, or *options* is empty, the result is
    ``('<unknown>', None)``.
    """
    usable = ((name, path) for name, path in options if path is not None)
    return next(usable, ('<unknown>', None))
24
25
def _get_libdevice_path_decision():
    """Choose the directory to search for libdevice bitcode files.

    Returns the ``(source_name, directory)`` pair selected by
    ``_find_valid_path``; *directory* is None when no source yields a path.
    """
    # Listed in priority order: the first source with a non-None path wins.
    # Every source is queried up front, before the selection is made.
    sources = [
        ('NUMBAPRO_LIBDEVICE', get_numbapro_envvar('NUMBAPRO_LIBDEVICE')),
        ('NUMBAPRO_CUDALIB', get_numbapro_envvar('NUMBAPRO_CUDALIB')),
        ('Conda environment', get_conda_ctk()),
        ('CUDA_HOME', get_cuda_home('nvvm', 'libdevice')),
        ('System', get_system_ctk('nvvm', 'libdevice')),
        ('Debian package', get_debian_pkg_libdevice()),
    ]
    return _find_valid_path(sources)
37
38
def _nvvm_lib_dir():
    """Return the ``('nvvm', <libdir>)`` path components for this platform.

    The second component is 'bin' when ``IS_WIN32`` is set, 'lib' when
    ``IS_OSX`` is set and 'lib64' otherwise.
    """
    if IS_WIN32:
        libdir = 'bin'
    elif IS_OSX:
        libdir = 'lib'
    else:
        libdir = 'lib64'
    return 'nvvm', libdir
46
47
def _get_nvvm_path_decision():
    """Choose the directory to search for the NVVM library.

    Returns the ``(source_name, directory)`` pair selected by
    ``_find_valid_path``; *directory* is None when no source yields a path.
    """
    # Platform-specific ('nvvm', <libdir>) components, shared by the
    # CUDA_HOME and system-wide lookups.
    nvvm_subdirs = _nvvm_lib_dir()
    # Listed in priority order: the first source with a non-None path wins.
    sources = [
        ('NUMBAPRO_NVVM', get_numbapro_envvar('NUMBAPRO_NVVM')),
        ('NUMBAPRO_CUDALIB', get_numbapro_envvar('NUMBAPRO_CUDALIB')),
        ('Conda environment', get_conda_ctk()),
        ('CUDA_HOME', get_cuda_home(*nvvm_subdirs)),
        ('System', get_system_ctk(*nvvm_subdirs)),
    ]
    return _find_valid_path(sources)
58
59
def _get_libdevice_paths():
    """Locate the libdevice bitcode files, keeping one file per architecture.

    Returns an ``_env_path_tuple`` whose ``by`` names the source that
    supplied the search directory and whose ``info`` is a dict mapping the
    architecture tag taken from the file name (e.g. ``'compute_20'``, or
    None when the file name carries no tag) to the path of the most recent
    bitcode file for that architecture.
    """
    by, libdir = _get_libdevice_path_decision()
    # Search for pattern
    pat = r'libdevice(\.(?P<arch>compute_\d+))?(\.\d+)*\.bc$'
    regex = re.compile(pat)
    candidates = find_file(regex, libdir)
    # Dotted numeric components of a matched file name, e.g. the "10" in
    # "libdevice.compute_20.10.bc".  Each one is followed by another ".",
    # which excludes the digits inside the "compute_NN" tag.
    version_part = re.compile(r'\.(\d+)(?=\.)')
    # Grouping
    grouped = defaultdict(list)
    for path in candidates:
        m = regex.search(path)
        if m is None:
            # Not a libdevice file name; there is no group to put it in.
            continue
        version = tuple(int(n) for n in version_part.findall(m.group(0)))
        grouped[m.group('arch')].append((version, path))
    # Keep only the most recent version per architecture.  Versions are
    # compared numerically so that ".10" ranks above ".9"; comparing the raw
    # path strings would order them the other way round.  The path itself
    # only breaks ties between equal versions.
    out = {arch: max(found)[1] for arch, found in grouped.items()}
    return _env_path_tuple(by, out)
74
75
def _cudalib_path():
    """Return the name of the CUDA library directory for this platform.

    The result is 'bin' when ``IS_WIN32`` is set, 'lib' when ``IS_OSX`` is
    set and 'lib64' otherwise.
    """
    if IS_WIN32:
        return 'bin'
    return 'lib' if IS_OSX else 'lib64'
83
84
def _get_cudalib_dir_path_decision():
    """Choose the directory that holds the CUDA libraries.

    Returns the ``(source_name, directory)`` pair selected by
    ``_find_valid_path``; *directory* is None when no source yields a path.
    """
    # Platform-specific library directory name, shared by the CUDA_HOME and
    # system-wide lookups.
    libdir_name = _cudalib_path()
    # Listed in priority order: the first source with a non-None path wins.
    sources = [
        ('NUMBAPRO_CUDALIB', get_numbapro_envvar('NUMBAPRO_CUDALIB')),
        ('Conda environment', get_conda_ctk()),
        ('CUDA_HOME', get_cuda_home(libdir_name)),
        ('System', get_system_ctk(libdir_name)),
    ]
    return _find_valid_path(sources)
94
95
def _get_cudalib_dir():
    """Return the CUDA library directory decision as an ``_env_path_tuple``.

    ``by`` is the name of the selected source and ``info`` is the directory
    path (None when no source yields one).
    """
    source, directory = _get_cudalib_dir_path_decision()
    return _env_path_tuple(by=source, info=directory)
99
100
def get_system_ctk(*subdirs):
    """Return a path inside the system-wide cudatoolkit, or None.

    On Linux, when '/usr/local/cuda' exists, the result is that directory
    joined with *subdirs*.  In every other case the result is None.  Only
    the base directory is checked for existence, not the joined path.
    """
    if not sys.platform.startswith('linux'):
        return None
    # Only the unversioned location is considered; versioned cuda
    # installations are intentionally not looked up.
    toolkit_root = '/usr/local/cuda'
    if not os.path.exists(toolkit_root):
        return None
    return os.path.join(toolkit_root, *subdirs)
111
112
def get_conda_ctk():
    """Return the directory holding cudatoolkit's shared libraries, or None.

    None is returned when ``sys.prefix`` has no 'conda-meta' entry, or when
    ``find_lib('nvvm')`` finds nothing.
    """
    conda_meta = os.path.join(sys.prefix, 'conda-meta')
    if not os.path.exists(conda_meta):
        return None
    # Assume that the presence of NVVM implies cudatoolkit is installed.
    nvvm_libs = find_lib('nvvm')
    if not nvvm_libs:
        return None
    # Use the directory of the greatest path among the matches.
    return os.path.dirname(max(nvvm_libs))
125
126
def get_cuda_home(*subdirs):
    """Return the CUDA home directory joined with *subdirs*, or None.

    The directory is read from the ``CUDA_HOME`` environment variable, or
    from ``CUDA_PATH`` when ``CUDA_HOME`` is not set.  *subdirs* are
    subdirectory names appended to it.  None is returned when neither
    variable is set.
    """
    # CUDA_PATH is the fallback, for a Windows CUDA installation without
    # Anaconda.
    for variable in ('CUDA_HOME', 'CUDA_PATH'):
        root = os.environ.get(variable)
        if root is not None:
            return os.path.join(root, *subdirs)
    return None
138
139
def _get_nvvm_path():
    """Return the NVVM library location as an ``_env_path_tuple``.

    ``by`` is the name of the source that supplied the search directory.
    ``info`` is the greatest path returned by ``find_lib('nvvm', ...)`` for
    that directory, or None when nothing is found.
    """
    by, search_dir = _get_nvvm_path_decision()
    found = find_lib('nvvm', search_dir)
    if found:
        chosen = max(found)
    else:
        chosen = None
    return _env_path_tuple(by, chosen)
145
146
def get_cuda_paths():
    """Return a dictionary mapping component names to a 2-tuple
    of (source_variable, info).

    The returned dictionary has the following keys and infos:
    - "nvvm": file_path
    - "libdevice": dictionary mapping arch to file_path
    - "cudalib_dir": directory_path

    Note: The result is computed on the first call and stored on the
    function object; later calls return that same dictionary.
    """
    # Serve the stored result if an earlier call already produced one.
    try:
        return get_cuda_paths._cached_result
    except AttributeError:
        pass
    # First call: look up every component, then remember the outcome.
    paths = {
        'nvvm': _get_nvvm_path(),
        'libdevice': _get_libdevice_paths(),
        'cudalib_dir': _get_cudalib_dir(),
    }
    get_cuda_paths._cached_result = paths
    return paths
171
172
def get_debian_pkg_libdevice():
    """
    Return the libdevice location used by the Debian NVIDIA
    Maintainers package, or None when that path does not exist.
    """
    location = '/usr/lib/nvidia-cuda-toolkit/libdevice'
    return location if os.path.exists(location) else None
182