1import contextlib
2import os
3import shutil
4import sys
5
6from numba.tests.support import (
7    captured_stdout,
8    SerialMixin,
9    redirect_c_stdout,
10)
11from numba.cuda.cuda_paths import get_conda_ctk
12from numba.core import config
13from numba.tests.support import TestCase
14import unittest
15
16
class CUDATestCase(SerialMixin, TestCase):
    """
    Base class for tests that need a CUDA device.

    The test methods of a CUDATestCase have to run in module order. A
    ContextResettingTestCase may reset the CUDA context, which would destroy
    resources still in use by a plain CUDATestCase if one of its tests were
    scheduled in between that class's tests.
    """
24
25
class ContextResettingTestCase(CUDATestCase):
    """
    For tests where the context needs to be reset after each test. Typically
    these inspect or modify parts of the context that would usually be expected
    to be internal implementation details (such as the state of allocations and
    deallocations, etc.).
    """

    def tearDown(self):
        """Reset the CUDA context, then run the base classes' tearDown."""
        # Imported lazily rather than at module level; presumably so importing
        # this module does not require the CUDA driver - TODO confirm.
        from numba.cuda.cudadrv.devices import reset
        try:
            reset()
        finally:
            # Cooperate with any tearDown defined further up the MRO
            # (SerialMixin / TestCase), even if the reset itself fails.
            super().tearDown()
37
38
def skip_on_cudasim(reason):
    """
    Return a decorator that skips the test when the CUDA simulator is enabled.

    :param reason: message unittest reports for the skipped test.
    """
    simulator_enabled = config.ENABLE_CUDASIM
    return unittest.skipIf(simulator_enabled, reason)
42
43
def skip_unless_cudasim(reason):
    """
    Return a decorator that skips the test when running on real CUDA hardware,
    i.e. whenever the simulator is not enabled.

    :param reason: message unittest reports for the skipped test.
    """
    on_hardware = not config.ENABLE_CUDASIM
    return unittest.skipIf(on_hardware, reason)
47
48
def skip_unless_conda_cudatoolkit(reason):
    """
    Return a decorator that skips the test unless a Conda-installed CUDA
    toolkit was found (``get_conda_ctk()`` returned something other than None).

    :param reason: message unittest reports for the skipped test.
    """
    conda_toolkit = get_conda_ctk()
    return unittest.skipIf(conda_toolkit is None, reason)
52
53
def skip_if_external_memmgr(reason):
    """
    Return a decorator that skips the test when an External Memory Management
    (EMM) plugin is configured, i.e. ``config.CUDA_MEMORY_MANAGER`` is not
    ``'default'``.

    :param reason: message unittest reports for the skipped test.
    """
    builtin_manager = config.CUDA_MEMORY_MANAGER == 'default'
    return unittest.skipUnless(builtin_manager, reason)
57
58
def skip_under_cuda_memcheck(reason):
    """
    Return a decorator that skips the test when the ``CUDA_MEMCHECK``
    environment variable is set (to any value, including an empty string).

    The environment is inspected when this function is called, not when the
    decorated test runs.

    :param reason: message unittest reports for the skipped test.
    """
    memcheck_requested = 'CUDA_MEMCHECK' in os.environ
    return unittest.skipIf(memcheck_requested, reason)
61
62
def skip_without_nvdisasm(reason):
    """
    Return a decorator that skips the test when no ``nvdisasm`` executable
    can be found on ``PATH`` (looked up via ``shutil.which`` at call time).

    :param reason: message unittest reports for the skipped test.
    """
    return unittest.skipUnless(shutil.which('nvdisasm') is not None, reason)
66
67
def skip_with_nvdisasm(reason):
    """
    Return a decorator that skips the test when an ``nvdisasm`` executable
    is found on ``PATH`` (looked up via ``shutil.which`` at call time).

    :param reason: message unittest reports for the skipped test.
    """
    return unittest.skipUnless(shutil.which('nvdisasm') is None, reason)
71
72
class CUDATextCapture:
    """
    Adapter giving a readable stream a ``getvalue()`` method.

    Used for text captured from the C-level stdout. ``getvalue()`` is
    implemented with ``read()``, which advances the stream, so calling it
    again does not return the already-read text a second time.
    """

    def __init__(self, stream):
        # Readable stream that receives the captured text.
        self._stream = stream

    def getvalue(self):
        """Read and return everything remaining in the wrapped stream."""
        source = self._stream
        return source.read()
80
81
class PythonTextCapture:
    """
    Adapter exposing the same ``getvalue()`` interface as CUDATextCapture for
    a stream that already provides ``getvalue()`` (e.g. the object yielded by
    ``captured_stdout``).
    """

    def __init__(self, stream):
        # Stream whose own getvalue() supplies the captured text.
        self._stream = stream

    def getvalue(self):
        """Return the captured text by delegating to the stream's getvalue()."""
        source = self._stream
        return source.getvalue()
89
90
@contextlib.contextmanager
def captured_cuda_stdout():
    """
    Context manager that captures text printed by CUDA kernels or by the
    simulator.

    Yields a minimal stream-like object exposing ``getvalue()``. With the
    simulator enabled, Python-level stdout is captured. Otherwise the C-level
    stdout is redirected, and the device is synchronized before the
    redirection is torn down. Presumably that is so kernel output still in
    flight lands in the capture.
    """
    # Flush first so text printed earlier is not swept into the capture.
    sys.stdout.flush()

    if not config.ENABLE_CUDASIM:
        # Real hardware: the CUDA runtime writes onto the system stdout.
        from numba import cuda
        with redirect_c_stdout() as c_stream:
            yield CUDATextCapture(c_stream)
            # Still inside the redirect, so output produced while waiting
            # for the device is captured too.
            cuda.synchronize()
        return

    # Simulator: kernels call print() on Python's stdout.
    with captured_stdout() as py_stream:
        yield PythonTextCapture(py_stream)
110
111
class ForeignArray:
    """
    Stand-in for an array that another library shares through the CUDA Array
    Interface.

    It wraps a DeviceNDArray and re-exports only its
    ``__cuda_array_interface__``, so the wrapper does not look like a
    DeviceNDArray to the code under test.
    """

    def __init__(self, arr):
        # The interface value is read once, here, at construction time.
        interface = arr.__cuda_array_interface__
        # Keep a reference to the wrapped array (presumably so the memory the
        # interface describes stays alive while this wrapper exists).
        self._arr = arr
        self.__cuda_array_interface__ = interface
122