import contextlib
import os
import shutil
import sys
import unittest

from numba.tests.support import (
    captured_stdout,
    SerialMixin,
    redirect_c_stdout,
    TestCase,
)
from numba.cuda.cuda_paths import get_conda_ctk
from numba.core import config


class CUDATestCase(SerialMixin, TestCase):
    """
    For tests that use a CUDA device. Test methods in a CUDATestCase must not
    be run out of module order, because the ContextResettingTestCase may reset
    the context and destroy resources used by a normal CUDATestCase if any of
    its tests are run between tests from a CUDATestCase.
    """


class ContextResettingTestCase(CUDATestCase):
    """
    For tests where the context needs to be reset after each test. Typically
    these inspect or modify parts of the context that would usually be expected
    to be internal implementation details (such as the state of allocations and
    deallocations, etc.).
    """

    def tearDown(self):
        """Run the base-class teardown, then reset the CUDA context.

        The reset is done in a ``finally`` clause so that it still happens
        even if a base-class ``tearDown`` raises.
        """
        try:
            super().tearDown()
        finally:
            # Imported lazily, presumably so that importing this module does
            # not pull in the CUDA driver layer (e.g. under the simulator).
            from numba.cuda.cudadrv.devices import reset
            reset()


def skip_on_cudasim(reason):
    """Skip this test if running on the CUDA simulator"""
    return unittest.skipIf(config.ENABLE_CUDASIM, reason)


def skip_unless_cudasim(reason):
    """Skip this test if running on CUDA hardware"""
    return unittest.skipUnless(config.ENABLE_CUDASIM, reason)


def skip_unless_conda_cudatoolkit(reason):
    """Skip test if the CUDA toolkit was not installed by Conda"""
    return unittest.skipUnless(get_conda_ctk() is not None, reason)


def skip_if_external_memmgr(reason):
    """Skip test if an EMM Plugin is in use"""
    return unittest.skipIf(config.CUDA_MEMORY_MANAGER != 'default', reason)


def skip_under_cuda_memcheck(reason):
    """Skip test if the ``CUDA_MEMCHECK`` environment variable is set.

    Any value, including an empty string, triggers the skip. Presumably the
    harness sets this variable when running the suite under cuda-memcheck.
    """
    return unittest.skipIf(os.environ.get('CUDA_MEMCHECK') is not None, reason)


def skip_without_nvdisasm(reason):
    """Skip test if the ``nvdisasm`` executable cannot be found on ``PATH``"""
    nvdisasm_path = shutil.which('nvdisasm')
    return unittest.skipIf(nvdisasm_path is None, reason)


def skip_with_nvdisasm(reason):
    """Skip test if the ``nvdisasm`` executable is found on ``PATH``"""
    nvdisasm_path = shutil.which('nvdisasm')
    return unittest.skipIf(nvdisasm_path is not None, reason)


class CUDATextCapture:
    """Adapter giving a readable file-like stream a ``getvalue()`` method.

    Used for the stream yielded by ``redirect_c_stdout``. ``getvalue()``
    returns the result of ``stream.read()``, i.e. whatever text has not yet
    been read from the stream.
    """

    def __init__(self, stream):
        self._stream = stream

    def getvalue(self):
        return self._stream.read()


class PythonTextCapture:
    """Adapter forwarding ``getvalue()`` to the wrapped stream.

    Used for the stream yielded by ``captured_stdout`` so that both capture
    modes expose the same ``getvalue()`` interface.
    """

    def __init__(self, stream):
        self._stream = stream

    def getvalue(self):
        return self._stream.getvalue()


@contextlib.contextmanager
def captured_cuda_stdout():
    """
    Return a minimal stream-like object capturing the text output of
    either CUDA or the simulator.

    The yielded object provides ``getvalue()`` in both modes.
    """
    # Prevent accidentally capturing previously output text
    sys.stdout.flush()

    if config.ENABLE_CUDASIM:
        # The simulator calls print() on Python stdout
        with captured_stdout() as stream:
            yield PythonTextCapture(stream)
    else:
        # The CUDA runtime writes onto the system stdout
        from numba import cuda
        with redirect_c_stdout() as stream:
            yield CUDATextCapture(stream)
            # Wait for outstanding device work while stdout is still
            # redirected, presumably so pending kernel output lands in the
            # capture rather than on the real stdout.
            cuda.synchronize()


class ForeignArray:
    """
    Class for emulating an array coming from another library through the CUDA
    Array interface. This just hides a DeviceNDArray so that it doesn't look
    like a DeviceNDArray.
    """

    def __init__(self, arr):
        # The wrapped array is kept referenced, presumably so its device
        # memory stays alive while the interface dict is in use.
        self._arr = arr
        # Only the CUDA Array Interface is exposed to consumers.
        self.__cuda_array_interface__ = arr.__cuda_array_interface__