1import numpy as np
2
3from numba.core.compiler import compile_isolated
4from numba import jit
5from numba.core import types
6
7from numba.tests.support import TestCase, tag
8import unittest
9
10
def dobool(a):
    """Return the truth value of ``a`` using the ``bool`` constructor."""
    truth = bool(a)
    return truth
13
14
def doint(a):
    """Return ``a`` converted with the ``int`` constructor."""
    as_int = int(a)
    return as_int
17
18
def dofloat(a):
    """Return ``a`` converted with the ``float`` constructor."""
    as_float = float(a)
    return as_float
21
22
def docomplex(a):
    """Return ``a`` converted with the one-argument ``complex`` constructor."""
    as_complex = complex(a)
    return as_complex
25
26
def docomplex2(a, b):
    """Return the complex number built from real part ``a`` and imaginary part ``b``."""
    as_complex = complex(a, b)
    return as_complex
29
30
def complex_calc(a):
    """
    Convert ``a`` with ``complex`` and return the sum of the squares of the
    result's real and imaginary parts.
    """
    z = complex(a)
    real_squared = z.real ** 2
    imag_squared = z.imag ** 2
    return real_squared + imag_squared
34
35
def complex_calc2(a, b):
    """
    Build ``complex(a, b)`` and return the sum of the squares of its real
    and imaginary parts.
    """
    z = complex(a, b)
    real_squared = z.real ** 2
    imag_squared = z.imag ** 2
    return real_squared + imag_squared
39
40
def converter(tp):
    """
    Build and return a one-argument function that calls ``tp`` on its
    argument and returns the result.
    """
    def convert(a):
        converted = tp(a)
        return converted
    return convert
45
46
def real_np_types():
    """Yield the names of the non-complex NumPy scalar types used by the tests."""
    signed = ('int8', 'int16', 'int32', 'int64')
    unsigned = ('uint8', 'uint16', 'uint32', 'uint64')
    platform_sized = ('intc', 'uintc', 'intp', 'uintp')
    others = ('float32', 'float64', 'bool_')
    yield from signed + unsigned + platform_sized + others
53
def complex_np_types():
    """Yield the names of the complex NumPy scalar types used by the tests."""
    yield from ('complex64', 'complex128')
57
58
class TestScalarNumberCtor(TestCase):
    """
    Test <number class>(some scalar)

    Each test compiles a constructor wrapper for a set of argument types and
    compares the compiled result with the result of the pure-Python call.
    """

    def check_int_constructor(self, pyfunc):
        """
        Compile ``pyfunc`` for boolean, integer and float argument types and
        check that the compiled and interpreted results are precisely equal.
        """
        x_types = [
            types.boolean, types.int32, types.int64, types.float32, types.float64
        ]
        x_values = [1, 0, 1000, 12.2, 23.4]

        for ty, x in zip(x_types, x_values):
            cres = compile_isolated(pyfunc, [ty])
            cfunc = cres.entry_point
            self.assertPreciseEqual(pyfunc(x), cfunc(x))

    def test_bool(self):
        """Check ``bool(scalar)``."""
        self.check_int_constructor(dobool)

    def test_int(self):
        """Check ``int(scalar)``."""
        self.check_int_constructor(doint)

    def test_float(self):
        """Check ``float(scalar)`` for integer and float argument types."""
        pyfunc = dofloat

        x_types = [
            types.int32, types.int64, types.float32, types.float64
        ]
        x_values = [1, 1000, 12.2, 23.4]

        for ty, x in zip(x_types, x_values):
            cres = compile_isolated(pyfunc, [ty])
            cfunc = cres.entry_point
            # A float32 argument is only compared at single precision.
            self.assertPreciseEqual(pyfunc(x), cfunc(x),
                prec='single' if ty is types.float32 else 'exact')

    def test_complex(self):
        """
        Check ``complex(scalar)`` for integer, float and complex argument
        types, then check the precision of the complex type created from a
        float32 argument.
        """
        pyfunc = docomplex

        x_types = [
            types.int32, types.int64, types.float32, types.float64,
            types.complex64, types.complex128,
        ]
        x_values = [1, 1000, 12.2, 23.4, 1.5-5j, 1-4.75j]

        for ty, x in zip(x_types, x_values):
            cres = compile_isolated(pyfunc, [ty])
            cfunc = cres.entry_point
            # Evaluate each side once and compare the stored results.
            got = cfunc(x)
            expected = pyfunc(x)
            self.assertPreciseEqual(expected, got,
                prec='single' if ty is types.float32 else 'exact')

        # Check that complex(float32) really creates a complex64,
        # by checking the accuracy of computations.
        pyfunc = complex_calc
        x = 1.0 + 2**-50
        cres = compile_isolated(pyfunc, [types.float32])
        cfunc = cres.entry_point
        self.assertPreciseEqual(cfunc(x), 1.0)
        # Control (complex128)
        cres = compile_isolated(pyfunc, [types.float64])
        cfunc = cres.entry_point
        self.assertGreater(cfunc(x), 1.0)

    def test_complex2(self):
        """
        Check ``complex(real, imag)`` for integer and float argument types,
        then check the precision of the complex type created from two
        float32 arguments.
        """
        pyfunc = docomplex2

        x_types = [
            types.int32, types.int64, types.float32, types.float64
        ]
        x_values = [1, 1000, 12.2, 23.4]
        y_values = [x - 3 for x in x_values]

        for ty, x, y in zip(x_types, x_values, y_values):
            cres = compile_isolated(pyfunc, [ty, ty])
            cfunc = cres.entry_point
            self.assertPreciseEqual(pyfunc(x, y), cfunc(x, y),
                prec='single' if ty is types.float32 else 'exact')

        # Check that complex(float32, float32) really creates a complex64,
        # by checking the accuracy of computations.
        pyfunc = complex_calc2
        x = 1.0 + 2**-50
        cres = compile_isolated(pyfunc, [types.float32, types.float32])
        cfunc = cres.entry_point
        self.assertPreciseEqual(cfunc(x, x), 2.0)
        # Control (complex128)
        # NOTE(review): this signature mixes float64 and float32; presumably
        # the result is still complex128 -- confirm whether
        # [float64, float64] was intended.
        cres = compile_isolated(pyfunc, [types.float64, types.float32])
        cfunc = cres.entry_point
        self.assertGreater(cfunc(x, x), 2.0)

    def check_type_converter(self, tp, np_type, values):
        """
        Check that calling ``tp`` inside nopython-compiled code matches
        calling the Numpy type ``np_type`` in the interpreter, for each of
        ``values``.
        """
        pyfunc = converter(tp)
        cfunc = jit(nopython=True)(pyfunc)
        if issubclass(np_type, np.integer):
            # Converting from a Python int to a small Numpy int on 32-bit
            # builds can raise "OverflowError: Python int too large to
            # convert to C long".  Work around by going through a large
            # Numpy int first.
            def np_converter(x):
                return np_type(np.int64(x))
        else:
            np_converter = np_type
        dtype = np.dtype(np_type)
        for val in values:
            if dtype.kind == 'u' and isinstance(val, float) and val < 0.0:
                # Converting negative float to unsigned int yields undefined
                # behaviour (and concretely different on ARM vs. x86)
                continue
            expected = np_converter(val)
            got = cfunc(val)
            self.assertPreciseEqual(got, expected,
                                    msg="for type %s with arg %s" % (np_type, val))

    def check_number_types(self, tp_factory):
        """
        Run ``check_type_converter`` for every real and complex type name.

        ``tp_factory`` maps a type name to the callable used inside the
        compiled function.  Complex types are also given a complex value.
        """
        real_values = [0, 1, -1, 100003, 10000000000007, -100003,
                       -10000000000007, 1.5, -3.5]
        for tp_name in real_np_types():
            np_type = getattr(np, tp_name)
            tp = tp_factory(tp_name)
            self.check_type_converter(tp, np_type, real_values)
        # Build a separate list instead of mutating the one used above.
        complex_values = real_values + [1.5+3j]
        for tp_name in complex_np_types():
            np_type = getattr(np, tp_name)
            tp = tp_factory(tp_name)
            self.check_type_converter(tp, np_type, complex_values)

    def test_numba_types(self):
        """
        Test explicit casting to Numba number types.
        """
        def tp_factory(tp_name):
            return getattr(types, tp_name)
        self.check_number_types(tp_factory)

    def test_numpy_types(self):
        """
        Test explicit casting to Numpy number types.
        """
        def tp_factory(tp_name):
            return getattr(np, tp_name)
        self.check_number_types(tp_factory)
201
202
class TestArrayNumberCtor(TestCase):
    """
    Test <number class>(some sequence)
    """

    def check_type_constructor(self, np_type, values):
        """
        Check that calling ``np_type`` on each of ``values`` gives precisely
        the same result in nopython-compiled code as in the interpreter.
        """
        cfunc = jit(nopython=True)(converter(np_type))
        for seq in values:
            reference = np_type(seq)
            self.assertPreciseEqual(cfunc(seq), reference)

    def test_1d(self):
        """Check the constructors on flat tuples and lists."""
        real_inputs = [(1.0, 2.5), (1, 2.5), [1.0, 2.5], ()]
        complex_inputs = [(1j, 2.5), [1.0, 2.5]]
        cases = (
            (real_np_types(), real_inputs),
            (complex_np_types(), complex_inputs),
        )
        # Real types are checked first, then complex types.
        for type_names, inputs in cases:
            for tp_name in type_names:
                self.check_type_constructor(getattr(np, tp_name), inputs)

    def test_2d(self):
        """Check the constructors on nested tuples and lists."""
        nested_inputs = [
            ((1.0, 2.5), (3.5, 4)),
            [(1.0, 2.5), (3.5, 4.0)],
            ([1.0, 2.5], [3.5, 4.0]),
            [(), ()],
        ]
        # The same inputs are used for the real types and then the complex types.
        for type_names in (real_np_types(), complex_np_types()):
            for tp_name in type_names:
                self.check_type_constructor(getattr(np, tp_name), nested_inputs)
247
248
# Run this module's tests when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
251