1"""
Test helper functions from numba.np.numpy_support.
3"""
4
5
6import sys
7from itertools import product
8
9import numpy as np
10
11import unittest
12from numba.core import types
13from numba.tests.support import TestCase
14from numba.tests.enum_usecases import Shake, RequestError
15from numba.np import numpy_support
16
17
class TestFromDtype(TestCase):
    """
    Round-trip tests for numpy_support.from_dtype() and as_dtype().
    """

    def test_number_types(self):
        """
        Test from_dtype() and as_dtype() with the various scalar number types.
        """
        f = numpy_support.from_dtype

        def check(typechar, numba_type):
            # Only native ordering and alignment is supported
            dtype = np.dtype(typechar)
            self.assertIs(f(dtype), numba_type)
            # An explicit native byte-order prefix must map identically
            self.assertIs(f(np.dtype('=' + typechar)), numba_type)
            self.assertEqual(dtype, numpy_support.as_dtype(numba_type))

        check('?', types.bool_)
        check('f', types.float32)
        check('f4', types.float32)
        check('d', types.float64)
        check('f8', types.float64)

        check('F', types.complex64)
        check('c8', types.complex64)
        check('D', types.complex128)
        check('c16', types.complex128)

        check('O', types.pyobject)

        check('b', types.int8)
        check('i1', types.int8)
        check('B', types.uint8)
        check('u1', types.uint8)

        check('h', types.int16)
        check('i2', types.int16)
        check('H', types.uint16)
        check('u2', types.uint16)

        check('i', types.int32)
        check('i4', types.int32)
        check('I', types.uint32)
        check('u4', types.uint32)

        check('q', types.int64)
        check('Q', types.uint64)
        # Named NumPy integer dtypes map onto the same-named Numba types
        for name in ('int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32',
                     'int64', 'uint64', 'intp', 'uintp'):
            self.assertIs(f(np.dtype(name)), getattr(types, name))

        # Non-native byte orders are unsupported.  Only multi-byte types are
        # listed, since a byte-order prefix is meaningless for 1-byte types.
        foreign_order = '>' if sys.byteorder == 'little' else '<'
        for letter in 'hHiIlLqQfdFD':
            self.assertRaises(NotImplementedError, f,
                              np.dtype(foreign_order + letter))

    def test_string_types(self):
        """
        Test from_dtype() and as_dtype() with the character string types.
        """
        def check(typestring, numba_type):
            # Only native ordering and alignment is supported
            dtype = np.dtype(typestring)
            self.assertEqual(numpy_support.from_dtype(dtype), numba_type)
            self.assertEqual(dtype, numpy_support.as_dtype(numba_type))

        check('S10', types.CharSeq(10))
        check('a11', types.CharSeq(11))
        check('U12', types.UnicodeCharSeq(12))

    def check_datetime_types(self, letter, nb_class):
        """
        Round-trip the unit-less dtype for *letter* ('M' or 'm') through
        from_dtype()/as_dtype() and compare it with *nb_class*('').
        """
        def check(dtype, numba_type, code):
            tp = numpy_support.from_dtype(dtype)
            self.assertEqual(tp, numba_type)
            self.assertEqual(tp.unit_code, code)
            self.assertEqual(numpy_support.as_dtype(numba_type), dtype)
            self.assertEqual(numpy_support.as_dtype(tp), dtype)

        # Unit-less ("generic") type; its unit code is expected to be 14
        check(np.dtype(letter), nb_class(''), 14)

    def test_datetime_types(self):
        """
        Test from_dtype() and as_dtype() with the datetime types.
        """
        self.check_datetime_types('M', types.NPDatetime)

    def test_timedelta_types(self):
        """
        Test from_dtype() and as_dtype() with the timedelta types.
        """
        self.check_datetime_types('m', types.NPTimedelta)

    def test_struct_types(self):
        """
        Test from_dtype() with structured dtypes, both packed and aligned.
        """
        def check(dtype, fields, size, aligned):
            tp = numpy_support.from_dtype(dtype)
            self.assertIsInstance(tp, types.Record)
            # Only check for dtype equality, as the Numba type may be interned
            self.assertEqual(tp.dtype, dtype)
            self.assertEqual(tp.fields, fields)
            self.assertEqual(tp.size, size)
            self.assertEqual(tp.aligned, aligned)

        dtype = np.dtype([('a', np.int16), ('b', np.int32)])
        check(dtype,
              fields={'a': (types.int16, 0, None, None),
                      'b': (types.int32, 2, None, None)},
              size=6, aligned=False)

        # With align=True, 'b' is padded to a 4-byte offset
        dtype = np.dtype([('a', np.int16), ('b', np.int32)], align=True)
        check(dtype,
              fields={'a': (types.int16, 0, None, None),
                      'b': (types.int32, 4, None, None)},
              size=8, aligned=True)

        dtype = np.dtype([('m', np.int32), ('n', 'S5')])
        check(dtype,
              fields={'m': (types.int32, 0, None, None),
                      'n': (types.CharSeq(5), 4, None, None)},
              size=9, aligned=False)

    def test_enum_type(self):
        """
        Test that as_dtype() recovers the base dtype of EnumMember and
        IntEnumMember types built on top of from_dtype() results.
        """
        def check(base_inst, enum_def, type_class):
            np_dt = np.dtype(base_inst)
            nb_ty = numpy_support.from_dtype(np_dt)
            inst = type_class(enum_def, nb_ty)
            recovered = numpy_support.as_dtype(inst)
            self.assertEqual(np_dt, recovered)

        # Use np.bool_, not np.bool: the bare alias was removed in
        # NumPy 1.24 (and only reinstated in 2.0), so it raises there.
        dts = [np.float64, np.int32, np.complex128, np.bool_]
        enums = [Shake, RequestError]
        member_classes = [types.EnumMember, types.IntEnumMember]

        # Same iteration order as checking all EnumMember cases first,
        # then all IntEnumMember cases.
        for type_class, dt, enum in product(member_classes, dts, enums):
            check(dt, enum, type_class)
155
156
class ValueTypingTestBase(object):
    """
    Mixin of value-typing checks, shared with test_special.

    Each ``check_*`` helper receives a callable *func* that maps a value to
    a Numba type and asserts on what it returns.  The assertion methods come
    from the TestCase the concrete test class also derives from.
    """

    def check_number_values(self, func):
        """
        Check that *func* types scalar numeric values as expected.
        """
        # Plain Python scalars are inferred through NumPy; integers may come
        # out as either int32 or int64.
        accepted_ints = (types.int32, types.int64)
        for py_int in (1, 2**31 - 1, -2**31):
            self.assertIn(func(py_int), accepted_ints)
        self.assertIs(func(1.0), types.float64)
        self.assertIs(func(1.0j), types.complex128)
        for py_bool in (True, False):
            self.assertIs(func(py_bool), types.bool_)

        # A default-constructed NumPy scalar of each listed type must map
        # onto the Numba type of the same name.
        scalar_names = [
            'int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32',
            'int64', 'uint64', 'intc', 'uintc', 'intp', 'uintp',
            'float32', 'float64', 'complex64', 'complex128',
            'bool_',
        ]
        for scalar_name in scalar_names:
            scalar = getattr(np, scalar_name)()
            self.assertIs(func(scalar), getattr(types, scalar_name))

    def _base_check_datetime_values(self, func, np_type, nb_type):
        """
        Build an *np_type* value for every supported unit, type it with
        *func* and compare the result against *nb_type*(unit).
        """
        all_units = ('', 'Y', 'M', 'D', 'h', 'm', 's',
                     'ms', 'us', 'ns', 'ps', 'fs', 'as')
        for unit in all_units:
            # The empty unit stands for the "generic" datetime / timedelta,
            # built here from a NaT value.
            value = np_type(3, unit) if unit else np_type('Nat')
            # Comparing with nb_type(unit) proves the unit was not dropped
            self.assertEqual(func(value), nb_type(unit))

    def check_datetime_values(self, func):
        """
        Check *func* on np.datetime64 values of every unit.
        """
        self._base_check_datetime_values(func, np.datetime64,
                                         types.NPDatetime)

    def check_timedelta_values(self, func):
        """
        Check *func* on np.timedelta64 values of every unit.
        """
        self._base_check_datetime_values(func, np.timedelta64,
                                         types.NPTimedelta)
210
211
class TestArrayScalars(ValueTypingTestBase, TestCase):
    """
    Typing of NumPy array scalars via numpy_support.map_arrayscalar_type().
    """

    def test_number_values(self):
        """
        Test map_arrayscalar_type() with scalar number values.
        """
        typer = numpy_support.map_arrayscalar_type
        self.check_number_values(typer)

    def test_datetime_values(self):
        """
        Test map_arrayscalar_type() with np.datetime64 values.
        """
        typer = numpy_support.map_arrayscalar_type
        self.check_datetime_values(typer)
        # A unit with a factor other than one ('10Y') must be rejected
        self.assertRaises(NotImplementedError, typer,
                          np.datetime64('2014', '10Y'))

    def test_timedelta_values(self):
        """
        Test map_arrayscalar_type() with np.timedelta64 values.
        """
        typer = numpy_support.map_arrayscalar_type
        self.check_timedelta_values(typer)
        # A unit with a factor other than one ('10Y') must be rejected
        self.assertRaises(NotImplementedError, typer,
                          np.timedelta64(10, '10Y'))
241
242
class FakeUFunc(object):
    """
    Minimal stand-in for a NumPy ufunc, as handed to
    ufunc_find_matching_loop() in TestUFuncs.

    It only carries ``nin``, ``nout``, ``types`` and ``ntypes``.  *types* is
    a sequence of loop signatures such as ``'ii->i'``: one dtype character
    per input, then ``->``, then one per output.

    Raises ValueError if *types* is empty or if its loops disagree on the
    number of inputs/outputs.
    """
    __slots__ = ('nin', 'nout', 'types', 'ntypes')

    def __init__(self, types):
        if not types:
            raise ValueError("FakeUFunc needs at least one loop signature")
        self.types = types
        in_, out = types[0].split('->')
        self.nin = len(in_)
        self.nout = len(out)
        self.ntypes = len(types)
        # Every loop must share the arity of the first one
        for tp in types:
            in_, out = tp.split('->')
            if len(in_) != self.nin or len(out) != self.nout:
                raise ValueError(
                    "loop %r does not match arity %d->%d of %r"
                    % (tp, self.nin, self.nout, types[0]))
256
257
# Loop tables modelled on np.add and np.multiply.  Each entry reads
# "<input dtype chars>-><output dtype chars>" using NumPy character codes
# (e.g. 'd' is float64, 'M' is datetime64, 'm' is timedelta64).
_add_types = [
    # boolean and integer loops
    '??->?', 'bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I',
    'll->l', 'LL->L', 'qq->q', 'QQ->Q',
    # floating-point and complex loops
    'ee->e', 'ff->f', 'dd->d', 'gg->g', 'FF->F', 'DD->D', 'GG->G',
    # datetime64 / timedelta64 loops
    'Mm->M', 'mm->m', 'mM->M',
    # object loop
    'OO->O',
]

_mul_types = [
    # boolean and integer loops
    '??->?', 'bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I',
    'll->l', 'LL->L', 'qq->q', 'QQ->Q',
    # floating-point and complex loops
    'ee->e', 'ff->f', 'dd->d', 'gg->g', 'FF->F', 'DD->D', 'GG->G',
    # timedelta64 scaled by int64 / float64, in either operand order
    'mq->m', 'qm->m', 'md->m', 'dm->m',
    # object loop
    'OO->O',
]

# Loop tables modelled on np.isnan and np.sqrt, which only have
# floating-point (and, for sqrt, object) loops
_isnan_types = ['e->?', 'f->?', 'd->?', 'g->?', 'F->?', 'D->?', 'G->?']
_sqrt_types = ['e->e', 'f->f', 'd->d', 'g->g', 'F->F', 'D->D', 'G->G', 'O->O']
272
273
class TestUFuncs(TestCase):
    """
    Test ufunc helpers.
    """

    def test_ufunc_find_matching_loop(self):
        """
        Test ufunc_find_matching_loop() against FakeUFunc loop tables.
        """
        f = numpy_support.ufunc_find_matching_loop
        np_add = FakeUFunc(_add_types)
        np_mul = FakeUFunc(_mul_types)
        np_isnan = FakeUFunc(_isnan_types)
        np_sqrt = FakeUFunc(_sqrt_types)

        def check(ufunc, input_types, sigs, output_types=()):
            """
            Check that ufunc_find_matching_loop() finds one of the given
            *sigs* for *ufunc*, *input_types* and optional *output_types*.
            """
            loop = f(ufunc, input_types + output_types)
            self.assertTrue(loop)
            if isinstance(sigs, str):
                sigs = (sigs,)
            self.assertIn(loop.ufunc_sig, sigs,
                          "inputs=%s and outputs=%s should have selected "
                          "one of %s, got %s"
                          % (input_types, output_types, sigs, loop.ufunc_sig))
            self.assertEqual(len(loop.numpy_inputs), len(loop.inputs))
            self.assertEqual(len(loop.numpy_outputs), len(loop.outputs))
            if not output_types:
                # Add explicit outputs and check the result is the same
                loop_explicit = f(ufunc, list(input_types) + loop.outputs)
                self.assertEqual(loop_explicit, loop)
            else:
                self.assertEqual(loop.outputs, list(output_types))
            # Round-tripping inputs and outputs
            loop_rt = f(ufunc, loop.inputs + loop.outputs)
            self.assertEqual(loop_rt, loop)
            return loop

        def check_exact(ufunc, input_types, sigs, output_types=()):
            """
            Like check(), but also ensure no casting of inputs occurred.
            """
            loop = check(ufunc, input_types, sigs, output_types)
            self.assertEqual(loop.inputs, list(input_types))
            return loop

        def check_no_match(ufunc, input_types):
            """
            Check that no loop of *ufunc* accepts *input_types*.
            """
            loop = f(ufunc, input_types)
            self.assertIs(loop, None)

        # Exact matching for number types
        check_exact(np_add, (types.bool_, types.bool_), '??->?')
        check_exact(np_add, (types.int8, types.int8), 'bb->b')
        check_exact(np_add, (types.uint8, types.uint8), 'BB->B')
        check_exact(np_add, (types.int64, types.int64), ('ll->l', 'qq->q'))
        check_exact(np_add, (types.uint64, types.uint64), ('LL->L', 'QQ->Q'))
        check_exact(np_add, (types.float32, types.float32), 'ff->f')
        check_exact(np_add, (types.float64, types.float64), 'dd->d')
        check_exact(np_add, (types.complex64, types.complex64), 'FF->F')
        check_exact(np_add, (types.complex128, types.complex128), 'DD->D')

        # Exact matching for datetime64 and timedelta64 types
        check_exact(np_add, (types.NPTimedelta('s'), types.NPTimedelta('s')),
                    'mm->m', output_types=(types.NPTimedelta('s'),))
        check_exact(np_add, (types.NPTimedelta('ms'), types.NPDatetime('s')),
                    'mM->M', output_types=(types.NPDatetime('ms'),))
        check_exact(np_add, (types.NPDatetime('s'), types.NPTimedelta('s')),
                    'Mm->M', output_types=(types.NPDatetime('s'),))

        check_exact(np_mul, (types.NPTimedelta('s'), types.int64),
                    'mq->m', output_types=(types.NPTimedelta('s'),))
        check_exact(np_mul, (types.float64, types.NPTimedelta('s')),
                    'dm->m', output_types=(types.NPTimedelta('s'),))

        # Mix and match number types, with casting
        check(np_add, (types.bool_, types.int8), 'bb->b')
        check(np_add, (types.uint8, types.bool_), 'BB->B')
        check(np_add, (types.int16, types.uint16), 'ii->i')
        check(np_add, (types.complex64, types.float64), 'DD->D')
        check(np_add, (types.float64, types.complex64), 'DD->D')
        # Integers, when used together with floating-point numbers,
        # should cast to any real or complex (see #2006)
        int_types = [types.int32, types.uint32, types.int64, types.uint64]
        for intty in int_types:
            check(np_add, (types.float32, intty), 'ff->f')
            check(np_add, (types.float64, intty), 'dd->d')
            check(np_add, (types.complex64, intty), 'FF->F')
            check(np_add, (types.complex128, intty), 'DD->D')
        # However, when used alone, they should cast only to
        # floating-point types of sufficient precision
        # (typical use case: np.sqrt(2) should give an accurate enough value)
        for intty in int_types:
            check(np_sqrt, (intty,), 'd->d')
            check(np_isnan, (intty,), 'd->?')

        # With some timedelta64 arguments as well
        check(np_mul, (types.NPTimedelta('s'), types.int32),
              'mq->m', output_types=(types.NPTimedelta('s'),))
        check(np_mul, (types.NPTimedelta('s'), types.uint32),
              'mq->m', output_types=(types.NPTimedelta('s'),))
        check(np_mul, (types.NPTimedelta('s'), types.float32),
              'md->m', output_types=(types.NPTimedelta('s'),))
        check(np_mul, (types.float32, types.NPTimedelta('s')),
              'dm->m', output_types=(types.NPTimedelta('s'),))

        # No match
        check_no_match(np_add, (types.NPDatetime('s'), types.NPDatetime('s')))
        # No implicit casting from int64 to timedelta64 (Numpy would allow
        # this).
        check_no_match(np_add, (types.NPTimedelta('s'), types.int64))

    def test_layout_checker(self):
        """
        Test is_contiguous() and is_fortran() against NumPy's own
        C_CONTIGUOUS / F_CONTIGUOUS flags on a range of views.
        """
        def check_arr(arr):
            dims = arr.shape
            strides = arr.strides
            itemsize = arr.dtype.itemsize
            is_c = numpy_support.is_contiguous(dims, strides, itemsize)
            is_f = numpy_support.is_fortran(dims, strides, itemsize)
            expect_c = arr.flags['C_CONTIGUOUS']
            expect_f = arr.flags['F_CONTIGUOUS']
            # Report the layout so a failure identifies the offending view
            msg = "dims=%s strides=%s" % (dims, strides)
            self.assertEqual(is_c, expect_c, msg)
            self.assertEqual(is_f, expect_f, msg)

        arr = np.arange(24)
        # 1D
        check_arr(arr)
        # 2D
        check_arr(arr.reshape((3, 8)))
        check_arr(arr.reshape((3, 8)).T)
        check_arr(arr.reshape((3, 8))[::2])
        # 3D
        check_arr(arr.reshape((2, 3, 4)))
        check_arr(arr.reshape((2, 3, 4)).T)
        # middle axis is shape 1
        check_arr(arr.reshape((2, 3, 4))[:, ::3])
        check_arr(arr.reshape((2, 3, 4)).T[:, ::3])

        # leading axis (before any transpose) is shape 1
        check_arr(arr.reshape((2, 3, 4))[::2])
        check_arr(arr.reshape((2, 3, 4)).T[:, :, ::2])
        # 2 leading axis (before any transpose) are shape 1
        check_arr(arr.reshape((2, 3, 4))[::2, ::3])
        check_arr(arr.reshape((2, 3, 4)).T[:, ::3, ::2])
        # single item slices for all axis
        check_arr(arr.reshape((2, 3, 4))[::2, ::3, ::4])
        check_arr(arr.reshape((2, 3, 4)).T[::4, ::3, ::2])
        # 4D
        check_arr(arr.reshape((2, 2, 3, 2))[::2, ::2, ::3])
        check_arr(arr.reshape((2, 2, 3, 2)).T[:, ::3, ::2, ::2])
        # outer dims reduced to size 1 by a step larger than the axis
        # (e.g. [::5] on a length-2 axis), leaving unusual strides on
        # size-1 axes
        check_arr(arr.reshape((2, 2, 3, 2))[::5, ::2, ::3])
        check_arr(arr.reshape((2, 2, 3, 2)).T[:, ::3, ::2, ::5])
425
426
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
429