1import numpy as np
2from numba import cuda, int32, float32
3from numba.cuda.testing import unittest, CUDATestCase
4from numba.core.config import ENABLE_CUDASIM
5
6
def useless_sync(ary):
    """Store each thread's ``cuda.grid(1)`` index into ``ary`` after a barrier.

    The ``cuda.syncthreads()`` call sits between computing the index and
    storing it; no data is exchanged between threads across the barrier.
    """
    tid = cuda.grid(1)
    # Barrier with nothing to protect: the store below only touches this
    # thread's own slot.
    cuda.syncthreads()
    ary[tid] = tid
11
12
def simple_smem(ary):
    """Copy a 100-element ``cuda.shared.array`` of ``int32`` into ``ary``.

    The thread whose ``cuda.grid(1)`` index is 0 fills the shared array with
    the values 0..99.  After ``cuda.syncthreads()`` every thread copies the
    element at its own index into ``ary``.
    """
    size = 100
    buf = cuda.shared.array(size, int32)
    tid = cuda.grid(1)
    if tid == 0:
        # A single thread populates the whole buffer.
        for k in range(size):
            buf[k] = k
    # Wait for the fill above before any thread reads from the buffer.
    cuda.syncthreads()
    ary[tid] = buf[tid]
22
23
def coop_smem2d(ary):
    """Fill a (10, 20) ``float32`` shared array cooperatively, then copy it out.

    Each thread takes its ``cuda.grid(2)`` coordinates ``(x, y)``, stores
    ``(x + 1) / (y + 1)`` at that position of the shared array, waits on
    ``cuda.syncthreads()`` and writes the same element to ``ary[x, y]``.
    """
    tile = cuda.shared.array((10, 20), float32)
    x, y = cuda.grid(2)
    tile[x, y] = (x + 1) / (y + 1)
    cuda.syncthreads()
    ary[x, y] = tile[x, y]
30
31
def dyn_shared_memory(ary):
    """Round-trip ``2 * index`` through a zero-sized ``cuda.shared.array``.

    The shared array is declared with shape 0 and dtype ``float32``.  Each
    thread writes twice its ``cuda.grid(1)`` index into its own slot, waits
    on ``cuda.syncthreads()`` and copies that slot into ``ary``.
    """
    # NOTE(review): shape 0 presumably requests dynamically sized shared
    # memory whose byte size comes from the launch configuration -- confirm
    # against the Numba CUDA documentation.
    buf = cuda.shared.array(0, float32)
    tid = cuda.grid(1)
    doubled = tid * 2
    buf[tid] = doubled
    cuda.syncthreads()
    ary[tid] = buf[tid]
38
39
def use_threadfence(ary):
    """Add 123 and then 321 to ``ary[0]`` with ``cuda.threadfence()`` between.

    The two updates are separate read-modify-write steps on the same element
    so that the fence call sits between two memory accesses.
    """
    ary[0] = ary[0] + 123
    cuda.threadfence()
    ary[0] = ary[0] + 321
44
45
def use_threadfence_block(ary):
    """Add 123 and then 321 to ``ary[0]`` around ``cuda.threadfence_block()``.

    The two updates are separate read-modify-write steps on the same element
    so that the fence call sits between two memory accesses.
    """
    ary[0] = ary[0] + 123
    cuda.threadfence_block()
    ary[0] = ary[0] + 321
50
51
def use_threadfence_system(ary):
    """Add 123 and then 321 to ``ary[0]`` around ``cuda.threadfence_system()``.

    The two updates are separate read-modify-write steps on the same element
    so that the fence call sits between two memory accesses.
    """
    ary[0] = ary[0] + 123
    cuda.threadfence_system()
    ary[0] = ary[0] + 321
56
57
def use_syncthreads_count(ary_in, ary_out):
    """Write the result of ``cuda.syncthreads_count`` for each thread.

    Every thread passes its own element of ``ary_in`` as the predicate and
    stores the returned value at the same index of ``ary_out``.
    """
    tid = cuda.grid(1)
    predicate = ary_in[tid]
    ary_out[tid] = cuda.syncthreads_count(predicate)
61
62
def use_syncthreads_and(ary_in, ary_out):
    """Write the result of ``cuda.syncthreads_and`` for each thread.

    Every thread passes its own element of ``ary_in`` as the predicate and
    stores the returned value at the same index of ``ary_out``.
    """
    tid = cuda.grid(1)
    predicate = ary_in[tid]
    ary_out[tid] = cuda.syncthreads_and(predicate)
66
67
def use_syncthreads_or(ary_in, ary_out):
    """Write the result of ``cuda.syncthreads_or`` for each thread.

    Every thread passes its own element of ``ary_in`` as the predicate and
    stores the returned value at the same index of ``ary_out``.
    """
    tid = cuda.grid(1)
    predicate = ary_in[tid]
    ary_out[tid] = cuda.syncthreads_or(predicate)
71
72
73
class TestCudaSync(CUDATestCase):
    """Tests for the barrier, shared-memory and memory-fence kernels above.

    Array results are compared with ``np.testing`` assertions so that a
    failure reports the mismatching elements instead of a bare
    ``False is not true``.
    """

    def _check_threadfence_codegen(self, pyfunc, membar):
        """Compile ``pyfunc``, run it on a single thread and inspect its PTX.

        Does not test runtime fence behavior, just the code generation:
        both increments must have been applied to ``ary[0]`` and, when not
        running under the simulator, the PTX must contain ``membar``.
        """
        compiled = cuda.jit("void(int32[:])")(pyfunc)
        ary = np.zeros(10, dtype=np.int32)
        compiled[1, 1](ary)
        self.assertEqual(123 + 321, ary[0])
        # The simulator produces no PTX to inspect.
        if not ENABLE_CUDASIM:
            self.assertIn(membar, compiled.ptx)

    def test_useless_sync(self):
        """Each thread stores its own index despite the extra barrier."""
        compiled = cuda.jit("void(int32[::1])")(useless_sync)
        nelem = 10
        ary = np.empty(nelem, dtype=np.int32)
        exp = np.arange(nelem, dtype=np.int32)
        compiled[1, nelem](ary)
        np.testing.assert_array_equal(ary, exp)

    def test_simple_smem(self):
        """Values written to shared memory by thread 0 are seen by all."""
        compiled = cuda.jit("void(int32[::1])")(simple_smem)
        nelem = 100
        ary = np.empty(nelem, dtype=np.int32)
        compiled[1, nelem](ary)
        np.testing.assert_array_equal(ary, np.arange(nelem, dtype=np.int32))

    def test_coop_smem2d(self):
        """A 2D shared array filled cooperatively round-trips to the output."""
        compiled = cuda.jit("void(float32[:,::1])")(coop_smem2d)
        shape = 10, 20
        ary = np.empty(shape, dtype=np.float32)
        compiled[1, shape](ary)
        exp = np.empty_like(ary)
        for i in range(ary.shape[0]):
            for j in range(ary.shape[1]):
                exp[i, j] = (i + 1) / (j + 1)
        # Same tolerances as np.allclose's defaults.
        np.testing.assert_allclose(ary, exp, rtol=1e-5, atol=1e-8)

    def test_dyn_shared_memory(self):
        """Shared memory sized by the launch configuration is usable."""
        compiled = cuda.jit("void(float32[::1])")(dyn_shared_memory)
        shape = 50
        ary = np.empty(shape, dtype=np.float32)
        # Fourth launch parameter is the shared-memory size in bytes; take it
        # from the array rather than hard-coding the float32 item size.
        compiled[1, shape, 0, ary.nbytes](ary)
        np.testing.assert_array_equal(
            ary, 2 * np.arange(ary.size, dtype=np.int32)
        )

    def test_threadfence_codegen(self):
        """``cuda.threadfence()`` emits ``membar.gl``."""
        self._check_threadfence_codegen(use_threadfence, "membar.gl;")

    def test_threadfence_block_codegen(self):
        """``cuda.threadfence_block()`` emits ``membar.cta``."""
        self._check_threadfence_codegen(use_threadfence_block, "membar.cta;")

    def test_threadfence_system_codegen(self):
        """``cuda.threadfence_system()`` emits ``membar.sys``."""
        self._check_threadfence_codegen(use_threadfence_system, "membar.sys;")

    def test_syncthreads_count(self):
        """Every thread receives the number of non-zero predicates."""
        compiled = cuda.jit("void(int32[:], int32[:])")(use_syncthreads_count)
        nelem = 72
        ary_in = np.ones(nelem, dtype=np.int32)
        ary_out = np.zeros(nelem, dtype=np.int32)
        # Two of the 72 threads supply a zero predicate, leaving 70.
        ary_in[31] = 0
        ary_in[42] = 0
        compiled[1, nelem](ary_in, ary_out)
        np.testing.assert_array_equal(ary_out, np.full_like(ary_out, 70))

    def test_syncthreads_and(self):
        """Result is 1 only while every predicate is non-zero."""
        compiled = cuda.jit("void(int32[:], int32[:])")(use_syncthreads_and)
        nelem = 100
        ary_in = np.ones(nelem, dtype=np.int32)
        ary_out = np.zeros(nelem, dtype=np.int32)
        compiled[1, nelem](ary_in, ary_out)
        np.testing.assert_array_equal(ary_out, np.ones_like(ary_out))
        # A single zero predicate flips the result for all threads.
        ary_in[31] = 0
        compiled[1, nelem](ary_in, ary_out)
        np.testing.assert_array_equal(ary_out, np.zeros_like(ary_out))

    def test_syncthreads_or(self):
        """Result is 1 as soon as any predicate is non-zero."""
        compiled = cuda.jit("void(int32[:], int32[:])")(use_syncthreads_or)
        nelem = 100
        ary_in = np.zeros(nelem, dtype=np.int32)
        ary_out = np.zeros(nelem, dtype=np.int32)
        compiled[1, nelem](ary_in, ary_out)
        np.testing.assert_array_equal(ary_out, np.zeros_like(ary_out))
        # A single non-zero predicate flips the result for all threads.
        ary_in[31] = 1
        compiled[1, nelem](ary_in, ary_out)
        np.testing.assert_array_equal(ary_out, np.ones_like(ary_out))
165
166
# Allow running this file directly; delegates to the ``unittest`` object
# imported from numba.cuda.testing at the top of the file.
if __name__ == "__main__":
    unittest.main()
169