1from llvmlite.llvmpy.core import Module, Type, Builder
2from numba.cuda.cudadrv.nvvm import (NVVM, CompilationUnit, llvm_to_ptx,
3                                     set_cuda_kernel, fix_data_layout,
4                                     get_arch_option, get_supported_ccs)
5from ctypes import c_size_t, c_uint64, sizeof
6from numba.cuda.testing import unittest
7from numba.cuda.cudadrv.nvvm import LibDevice, NvvmError
8from numba.cuda.testing import skip_on_cudasim
9
# True when the host's ``size_t`` is as wide as a 64-bit unsigned integer.
# The tests use it to pick between the ``gpu64`` and ``gpu32`` NVVM IR
# fixtures and to decide which ``.address_size`` to expect in generated PTX.
is64bit = sizeof(c_size_t) == sizeof(c_uint64)
11
12
@skip_on_cudasim('NVVM Driver unsupported in the simulator')
class TestNvvmDriver(unittest.TestCase):
    """Exercise compiling NVVM IR to PTX through the NVVM driver bindings."""

    def get_ptx(self):
        """Return the NVVM IR fixture matching the host pointer width.

        Despite the name, the value returned is the IR source text
        (``gpu64`` or ``gpu32``) that the tests pass to ``llvm_to_ptx``,
        not PTX.
        """
        # The instance itself is never used; the call is kept for whatever
        # side effects constructing NVVM has (presumably loading the NVVM
        # library -- confirm before removing).
        NVVM()

        return gpu64 if is64bit else gpu32

    def test_nvvm_compile_simple(self):
        # Compile with default options and check that both functions defined
        # in the fixture are mentioned in the resulting PTX.
        nvvmir = self.get_ptx()
        ptx = llvm_to_ptx(nvvmir).decode('utf8')
        self.assertIn('simple', ptx)
        self.assertIn('ave', ptx)

    def test_nvvm_from_llvm(self):
        # Build a trivial ``void(int)`` function with llvmlite, mark it as a
        # CUDA kernel, and compile the module's textual IR.
        m = Module("test_nvvm_from_llvm")
        fty = Type.function(Type.void(), [Type.int()])
        kernel = m.add_function(fty, name='mycudakernel')
        bldr = Builder(kernel.append_basic_block('entry'))
        bldr.ret_void()
        set_cuda_kernel(kernel)

        fix_data_layout(m)
        ptx = llvm_to_ptx(str(m)).decode('utf8')
        self.assertIn('mycudakernel', ptx)
        # The PTX address size is expected to follow the host pointer width.
        if is64bit:
            self.assertIn('.address_size 64', ptx)
        else:
            self.assertIn('.address_size 32', ptx)

    def _test_nvvm_support(self, arch):
        """Compile the fixture for ``arch`` and check the emitted target.

        ``arch`` is a ``(major, minor)`` pair; it is formatted into both the
        ``compute_XY`` option given to ``llvm_to_ptx`` and the ``sm_XY``
        target expected in the PTX.
        """
        nvvmir = self.get_ptx()
        compute_xx = 'compute_{0}{1}'.format(*arch)
        ptx = llvm_to_ptx(nvvmir, arch=compute_xx, ftz=1, prec_sqrt=0,
                          prec_div=0).decode('utf8')
        self.assertIn(".target sm_{0}{1}".format(*arch), ptx)
        self.assertIn('simple', ptx)
        self.assertIn('ave', ptx)

    def test_nvvm_support(self):
        """Test supported CC by NVVM
        """
        for arch in get_supported_ccs():
            self._test_nvvm_support(arch=arch)

    @unittest.skipIf(True, "No new CC unknown to NVVM yet")
    def test_nvvm_future_support(self):
        """Test unsupported CC to help track the feature support
        """
        # List known CC but unsupported by NVVM
        future_archs = [
            # (5, 2),  # for example
        ]
        for arch in future_archs:
            pat = r"-arch=compute_{0}{1}".format(*arch)
            with self.assertRaises(NvvmError) as raises:
                self._test_nvvm_support(arch=arch)
            # The caught exception lives on ``raises.exception``; the
            # context's ``msg`` attribute is not the error text.
            self.assertIn(pat, str(raises.exception))
73
74
@skip_on_cudasim('NVVM Driver unsupported in the simulator')
class TestArchOption(unittest.TestCase):
    """Check how ``get_arch_option`` maps a compute capability to an arch."""

    def test_get_arch_option(self):
        # Test returning the nearest lowest arch.
        self.assertEqual(get_arch_option(5, 0), 'compute_50')
        self.assertEqual(get_arch_option(5, 1), 'compute_50')
        self.assertEqual(get_arch_option(3, 7), 'compute_35')
        # Test known arch.
        supported_cc = get_supported_ccs()
        for arch in supported_cc:
            self.assertEqual(get_arch_option(*arch), 'compute_%d%d' % arch)
        # A capability far above anything listed is expected to resolve to
        # the last entry of the supported list.
        self.assertEqual(get_arch_option(1000, 0),
                         'compute_%d%d' % supported_cc[-1])
88
89
@skip_on_cudasim('NVVM Driver unsupported in the simulator')
class TestLibDevice(unittest.TestCase):
    """Check the arch that ``LibDevice`` reports for a requested arch."""

    def _libdevice_load(self, arch, expect):
        """Create ``LibDevice(arch=arch)`` and assert its ``arch`` is ``expect``."""
        libdevice = LibDevice(arch=arch)
        self.assertEqual(libdevice.arch, expect)

    def test_libdevice_arch_fix(self):
        # Each call pairs a requested arch with the arch LibDevice is
        # expected to report; some requests are expected to map to a lower
        # arch (e.g. compute_21 -> compute_20).
        self._libdevice_load('compute_20', 'compute_20')
        self._libdevice_load('compute_21', 'compute_20')
        self._libdevice_load('compute_30', 'compute_30')
        self._libdevice_load('compute_35', 'compute_35')
        self._libdevice_load('compute_52', 'compute_50')
102
103
# NVVM IR fixture for 64-bit hosts (``nvptx64-`` triple, 64-bit pointers in
# the data layout). It defines:
#   * ``@ave``    - returns ``(a + b) / 2`` using a signed division;
#   * ``@simple`` - computes ``ctaid.x * ntid.x + tid.x``, calls ``@ave`` on
#                   that index and stores the result at ``data[index]``.
# The trailing ``!nvvm.annotations`` metadata marks ``@simple`` as a kernel.
gpu64 = '''
target triple="nvptx64-"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define i32 @ave(i32 %a, i32 %b) {
entry:
%add = add nsw i32 %a, %b
%div = sdiv i32 %add, 2
ret i32 %div
}

define void @simple(i32* %data) {
entry:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%1 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%mul = mul i32 %0, %1
%2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%add = add i32 %mul, %2
%call = call i32 @ave(i32 %add, i32 %add)
%idxprom = sext i32 %add to i64
%arrayidx = getelementptr inbounds i32, i32* %data, i64 %idxprom
store i32 %call, i32* %arrayidx, align 4
ret void
}

declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() nounwind readnone

declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x() nounwind readnone

declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() nounwind readnone

!nvvm.annotations = !{!1}
!1 = metadata !{void (i32*)* @simple, metadata !"kernel", i32 1}
'''
138
# NVVM IR fixture for 32-bit hosts (``nvptx-`` triple, 32-bit pointers in
# the data layout). It defines:
#   * ``@ave``    - returns ``(a + b) / 2`` using a signed division;
#   * ``@simple`` - computes ``ctaid.x * ntid.x + tid.x``, calls ``@ave`` on
#                   that index and stores the result at ``data[index]``.
# The trailing ``!nvvm.annotations`` metadata marks ``@simple`` as a kernel.
gpu32 = '''
target triple="nvptx-"
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define i32 @ave(i32 %a, i32 %b) {
entry:
%add = add nsw i32 %a, %b
%div = sdiv i32 %add, 2
ret i32 %div
}

define void @simple(i32* %data) {
entry:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%1 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%mul = mul i32 %0, %1
%2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%add = add i32 %mul, %2
%call = call i32 @ave(i32 %add, i32 %add)
%idxprom = sext i32 %add to i64
%arrayidx = getelementptr inbounds i32, i32* %data, i64 %idxprom
store i32 %call, i32* %arrayidx, align 4
ret void
}

declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() nounwind readnone

declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x() nounwind readnone

declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() nounwind readnone

!nvvm.annotations = !{!1}
!1 = metadata !{void (i32*)* @simple, metadata !"kernel", i32 1}

'''
174
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
177