"""
Test helper functions from numba.np.numpy_support.
"""


import sys
from itertools import product

import numpy as np

import unittest
from numba.core import types
from numba.tests.support import TestCase
from numba.tests.enum_usecases import Shake, RequestError
from numba.np import numpy_support


class TestFromDtype(TestCase):
    """
    Tests for the Numpy dtype <-> Numba type conversions performed by
    numpy_support.from_dtype() and numpy_support.as_dtype().
    """

    def test_number_types(self):
        """
        Test from_dtype() and as_dtype() with the various scalar number types.
        """
        f = numpy_support.from_dtype

        def check(typechar, numba_type):
            # Only native ordering and alignment is supported
            dtype = np.dtype(typechar)
            self.assertIs(f(dtype), numba_type)
            # An explicit '=' (native byte order) prefix must map to the
            # very same Numba type.
            self.assertIs(f(np.dtype('=' + typechar)), numba_type)
            self.assertEqual(dtype, numpy_support.as_dtype(numba_type))

        check('?', types.bool_)
        check('f', types.float32)
        check('f4', types.float32)
        check('d', types.float64)
        check('f8', types.float64)

        check('F', types.complex64)
        check('c8', types.complex64)
        check('D', types.complex128)
        check('c16', types.complex128)

        check('O', types.pyobject)

        check('b', types.int8)
        check('i1', types.int8)
        check('B', types.uint8)
        check('u1', types.uint8)

        check('h', types.int16)
        check('i2', types.int16)
        check('H', types.uint16)
        check('u2', types.uint16)

        check('i', types.int32)
        check('i4', types.int32)
        check('I', types.uint32)
        check('u4', types.uint32)

        check('q', types.int64)
        check('Q', types.uint64)
        for name in ('int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32',
                     'int64', 'uint64', 'intp', 'uintp'):
            self.assertIs(f(np.dtype(name)), getattr(types, name))

        # Non-native byte orders are unsupported. 1-byte types are left out
        # of this check since a byte order does not apply to them.
        foreign_order = '>' if sys.byteorder == 'little' else '<'
        for letter in 'hHiIlLqQfdFD':
            self.assertRaises(NotImplementedError, f,
                              np.dtype(foreign_order + letter))

    def test_string_types(self):
        """
        Test from_dtype() and as_dtype() with the character string types.
        """
        def check(typestring, numba_type):
            # Only native ordering and alignment is supported
            dtype = np.dtype(typestring)
            self.assertEqual(numpy_support.from_dtype(dtype), numba_type)
            self.assertEqual(dtype, numpy_support.as_dtype(numba_type))

        check('S10', types.CharSeq(10))
        check('a11', types.CharSeq(11))
        check('U12', types.UnicodeCharSeq(12))

    def check_datetime_types(self, letter, nb_class):
        """
        Check from_dtype() and as_dtype() for the datetime-like dtype
        character *letter* ('M' or 'm'), which should map to instances of
        the Numba type class *nb_class*.
        """
        def check(dtype, numba_type, code):
            tp = numpy_support.from_dtype(dtype)
            self.assertEqual(tp, numba_type)
            self.assertEqual(tp.unit_code, code)
            # Both the expected and the inferred Numba type must round-trip
            # back to the original dtype.
            self.assertEqual(numpy_support.as_dtype(numba_type), dtype)
            self.assertEqual(numpy_support.as_dtype(tp), dtype)

        # Unit-less ("generic") type; this test expects unit code 14 for it
        check(np.dtype(letter), nb_class(''), 14)

    def test_datetime_types(self):
        """
        Test from_dtype() and as_dtype() with the datetime types.
        """
        self.check_datetime_types('M', types.NPDatetime)

    def test_timedelta_types(self):
        """
        Test from_dtype() and as_dtype() with the timedelta types.
        """
        self.check_datetime_types('m', types.NPTimedelta)

    def test_struct_types(self):
        """
        Test from_dtype() with structured (record) dtypes, both packed and
        aligned.
        """
        def check(dtype, fields, size, aligned):
            tp = numpy_support.from_dtype(dtype)
            self.assertIsInstance(tp, types.Record)
            # Only check for dtype equality, as the Numba type may be interned
            self.assertEqual(tp.dtype, dtype)
            self.assertEqual(tp.fields, fields)
            self.assertEqual(tp.size, size)
            self.assertEqual(tp.aligned, aligned)

        dtype = np.dtype([('a', np.int16), ('b', np.int32)])
        check(dtype,
              fields={'a': (types.int16, 0, None, None),
                      'b': (types.int32, 2, None, None)},
              size=6, aligned=False)

        dtype = np.dtype([('a', np.int16), ('b', np.int32)], align=True)
        check(dtype,
              fields={'a': (types.int16, 0, None, None),
                      'b': (types.int32, 4, None, None)},
              size=8, aligned=True)

        dtype = np.dtype([('m', np.int32), ('n', 'S5')])
        check(dtype,
              fields={'m': (types.int32, 0, None, None),
                      'n': (types.CharSeq(5), 4, None, None)},
              size=9, aligned=False)

    def test_enum_type(self):
        """
        Test that as_dtype() recovers the base dtype of enum member types
        built on top of various scalar dtypes.
        """
        def check(base_inst, enum_def, type_class):
            np_dt = np.dtype(base_inst)
            nb_ty = numpy_support.from_dtype(np_dt)
            inst = type_class(enum_def, nb_ty)
            recovered = numpy_support.as_dtype(inst)
            self.assertEqual(np_dt, recovered)

        # np.bool_ rather than np.bool: the latter alias was deprecated in
        # NumPy 1.20 and is missing from NumPy 1.24-1.26.
        dts = [np.float64, np.int32, np.complex128, np.bool_]
        enums = [Shake, RequestError]
        type_classes = [types.EnumMember, types.IntEnumMember]

        for type_class, dt, enum in product(type_classes, dts, enums):
            check(dt, enum, type_class)


class ValueTypingTestBase(object):
    """
    Common tests for the typing of values.  Also used by test_special.
    """

    def check_number_values(self, func):
        """
        Test *func*() with scalar numeric values.
        """
        f = func
        # Standard Python types get inferred by numpy
        self.assertIn(f(1), (types.int32, types.int64))
        self.assertIn(f(2**31 - 1), (types.int32, types.int64))
        self.assertIn(f(-2**31), (types.int32, types.int64))
        self.assertIs(f(1.0), types.float64)
        self.assertIs(f(1.0j), types.complex128)
        self.assertIs(f(True), types.bool_)
        self.assertIs(f(False), types.bool_)
        # Numpy scalar types get converted by from_dtype()
        for name in ('int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32',
                     'int64', 'uint64', 'intc', 'uintc', 'intp', 'uintp',
                     'float32', 'float64', 'complex64', 'complex128',
                     'bool_'):
            val = getattr(np, name)()
            self.assertIs(f(val), getattr(types, name))

    def _base_check_datetime_values(self, func, np_type, nb_type):
        """
        Check that *func* maps *np_type* scalars of every unit listed below
        (including the generic, unit-less one) to *nb_type* of that unit.
        """
        f = func
        for unit in [
            '', 'Y', 'M', 'D', 'h', 'm', 's',
            'ms', 'us', 'ns', 'ps', 'fs', 'as',
        ]:
            if unit:
                t = np_type(3, unit)
            else:
                # "generic" datetime / timedelta
                t = np_type('Nat')
            tp = f(t)
            # This ensures the unit hasn't been lost
            self.assertEqual(tp, nb_type(unit))

    def check_datetime_values(self, func):
        """
        Test *func*() with np.datetime64 values.
        """
        self._base_check_datetime_values(func, np.datetime64, types.NPDatetime)

    def check_timedelta_values(self, func):
        """
        Test *func*() with np.timedelta64 values.
        """
        self._base_check_datetime_values(func, np.timedelta64,
                                         types.NPTimedelta)


class TestArrayScalars(ValueTypingTestBase, TestCase):
    """
    Run the common value-typing checks against map_arrayscalar_type().
    """

    def test_number_values(self):
        """
        Test map_arrayscalar_type() with scalar number values.
        """
        self.check_number_values(numpy_support.map_arrayscalar_type)

    def test_datetime_values(self):
        """
        Test map_arrayscalar_type() with np.datetime64 values.
        """
        f = numpy_support.map_arrayscalar_type
        self.check_datetime_values(f)
        # datetime64s with a non-one factor shouldn't be supported
        t = np.datetime64('2014', '10Y')
        with self.assertRaises(NotImplementedError):
            f(t)

    def test_timedelta_values(self):
        """
        Test map_arrayscalar_type() with np.timedelta64 values.
        """
        f = numpy_support.map_arrayscalar_type
        self.check_timedelta_values(f)
        # timedelta64s with a non-one factor shouldn't be supported
        t = np.timedelta64(10, '10Y')
        with self.assertRaises(NotImplementedError):
            f(t)


class FakeUFunc(object):
    """
    Minimal stand-in for a Numpy ufunc, built from a list of loop signature
    strings such as 'dd->d'.  It provides the *nin*, *nout*, *types* and
    *ntypes* attributes used by the ufunc helper tests below.
    """
    __slots__ = ('nin', 'nout', 'types', 'ntypes')

    def __init__(self, types):
        # NOTE: the parameter name shadows the `types` module within this
        # method only; it is kept so existing keyword callers still work.
        self.types = types
        in_, out = types[0].split('->')
        self.nin = len(in_)
        self.nout = len(out)
        self.ntypes = len(types)
        # Every loop signature must have the same arity as the first one.
        # (Previously this loop re-checked types[0] on every iteration, so
        # the remaining signatures were never actually validated.)
        for tp in types:
            in_, out = tp.split('->')
            if len(in_) != self.nin or len(out) != self.nout:
                raise ValueError("inconsistent ufunc loop signature %r "
                                 "(expected %d inputs and %d outputs)"
                                 % (tp, self.nin, self.nout))


# Typical types for np.add, np.multiply, np.isnan
_add_types = ['??->?', 'bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I',
              'll->l', 'LL->L', 'qq->q', 'QQ->Q', 'ee->e', 'ff->f', 'dd->d',
              'gg->g', 'FF->F', 'DD->D', 'GG->G', 'Mm->M', 'mm->m', 'mM->M',
              'OO->O']

_mul_types = ['??->?', 'bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I',
              'll->l', 'LL->L', 'qq->q', 'QQ->Q', 'ee->e', 'ff->f', 'dd->d',
              'gg->g', 'FF->F', 'DD->D', 'GG->G', 'mq->m', 'qm->m', 'md->m',
              'dm->m', 'OO->O']

# Those ones only have floating-point loops (plus an object loop for sqrt)
_isnan_types = ['e->?', 'f->?', 'd->?', 'g->?', 'F->?', 'D->?', 'G->?']
_sqrt_types = ['e->e', 'f->f', 'd->d', 'g->g', 'F->F', 'D->D', 'G->G', 'O->O']


class TestUFuncs(TestCase):
    """
    Test ufunc helpers.
    """

    def test_ufunc_find_matching_loop(self):
        """
        Test ufunc_find_matching_loop() loop selection, with and without
        input casting, against fake ufuncs mimicking np.add, np.multiply,
        np.isnan and np.sqrt.
        """
        f = numpy_support.ufunc_find_matching_loop
        np_add = FakeUFunc(_add_types)
        np_mul = FakeUFunc(_mul_types)
        np_isnan = FakeUFunc(_isnan_types)
        np_sqrt = FakeUFunc(_sqrt_types)

        def check(ufunc, input_types, sigs, output_types=()):
            """
            Check that ufunc_find_matching_loop() finds one of the given
            *sigs* for *ufunc*, *input_types* and optional *output_types*.
            """
            loop = f(ufunc, input_types + output_types)
            self.assertTrue(loop)
            if isinstance(sigs, str):
                sigs = (sigs,)
            self.assertIn(loop.ufunc_sig, sigs,
                          "inputs=%s and outputs=%s should have selected "
                          "one of %s, got %s"
                          % (input_types, output_types, sigs, loop.ufunc_sig))
            self.assertEqual(len(loop.numpy_inputs), len(loop.inputs))
            self.assertEqual(len(loop.numpy_outputs), len(loop.outputs))
            if not output_types:
                # Add explicit outputs and check the result is the same
                loop_explicit = f(ufunc, list(input_types) + loop.outputs)
                self.assertEqual(loop_explicit, loop)
            else:
                self.assertEqual(loop.outputs, list(output_types))
            # Round-tripping inputs and outputs
            loop_rt = f(ufunc, loop.inputs + loop.outputs)
            self.assertEqual(loop_rt, loop)
            return loop

        def check_exact(ufunc, input_types, sigs, output_types=()):
            """
            Like check(), but also ensure no casting of inputs occurred.
            """
            loop = check(ufunc, input_types, sigs, output_types)
            self.assertEqual(loop.inputs, list(input_types))

        def check_no_match(ufunc, input_types):
            """
            Check that no loop of *ufunc* accepts *input_types*.
            """
            loop = f(ufunc, input_types)
            self.assertIs(loop, None)

        # Exact matching for number types
        check_exact(np_add, (types.bool_, types.bool_), '??->?')
        check_exact(np_add, (types.int8, types.int8), 'bb->b')
        check_exact(np_add, (types.uint8, types.uint8), 'BB->B')
        check_exact(np_add, (types.int64, types.int64), ('ll->l', 'qq->q'))
        check_exact(np_add, (types.uint64, types.uint64), ('LL->L', 'QQ->Q'))
        check_exact(np_add, (types.float32, types.float32), 'ff->f')
        check_exact(np_add, (types.float64, types.float64), 'dd->d')
        check_exact(np_add, (types.complex64, types.complex64), 'FF->F')
        check_exact(np_add, (types.complex128, types.complex128), 'DD->D')

        # Exact matching for datetime64 and timedelta64 types
        check_exact(np_add, (types.NPTimedelta('s'), types.NPTimedelta('s')),
                    'mm->m', output_types=(types.NPTimedelta('s'),))
        check_exact(np_add, (types.NPTimedelta('ms'), types.NPDatetime('s')),
                    'mM->M', output_types=(types.NPDatetime('ms'),))
        check_exact(np_add, (types.NPDatetime('s'), types.NPTimedelta('s')),
                    'Mm->M', output_types=(types.NPDatetime('s'),))

        check_exact(np_mul, (types.NPTimedelta('s'), types.int64),
                    'mq->m', output_types=(types.NPTimedelta('s'),))
        check_exact(np_mul, (types.float64, types.NPTimedelta('s')),
                    'dm->m', output_types=(types.NPTimedelta('s'),))

        # Mix and match number types, with casting
        check(np_add, (types.bool_, types.int8), 'bb->b')
        check(np_add, (types.uint8, types.bool_), 'BB->B')
        check(np_add, (types.int16, types.uint16), 'ii->i')
        check(np_add, (types.complex64, types.float64), 'DD->D')
        check(np_add, (types.float64, types.complex64), 'DD->D')
        # Integers, when used together with floating-point numbers,
        # should cast to any real or complex (see #2006)
        int_types = [types.int32, types.uint32, types.int64, types.uint64]
        for intty in int_types:
            check(np_add, (types.float32, intty), 'ff->f')
            check(np_add, (types.float64, intty), 'dd->d')
            check(np_add, (types.complex64, intty), 'FF->F')
            check(np_add, (types.complex128, intty), 'DD->D')
        # However, when used alone, they should cast only to
        # floating-point types of sufficient precision
        # (typical use case: np.sqrt(2) should give an accurate enough value)
        for intty in int_types:
            check(np_sqrt, (intty,), 'd->d')
            check(np_isnan, (intty,), 'd->?')

        # With some timedelta64 arguments as well
        check(np_mul, (types.NPTimedelta('s'), types.int32),
              'mq->m', output_types=(types.NPTimedelta('s'),))
        check(np_mul, (types.NPTimedelta('s'), types.uint32),
              'mq->m', output_types=(types.NPTimedelta('s'),))
        check(np_mul, (types.NPTimedelta('s'), types.float32),
              'md->m', output_types=(types.NPTimedelta('s'),))
        check(np_mul, (types.float32, types.NPTimedelta('s')),
              'dm->m', output_types=(types.NPTimedelta('s'),))

        # No match
        check_no_match(np_add, (types.NPDatetime('s'), types.NPDatetime('s')))
        # No implicit casting from int64 to timedelta64 (Numpy would allow
        # this).
        check_no_match(np_add, (types.NPTimedelta('s'), types.int64))

    def test_layout_checker(self):
        """
        Check is_contiguous() and is_fortran() against Numpy's own
        C_CONTIGUOUS / F_CONTIGUOUS flags for a variety of array views.
        """
        def check_arr(arr):
            dims = arr.shape
            strides = arr.strides
            itemsize = arr.dtype.itemsize
            is_c = numpy_support.is_contiguous(dims, strides, itemsize)
            is_f = numpy_support.is_fortran(dims, strides, itemsize)
            expect_c = arr.flags['C_CONTIGUOUS']
            expect_f = arr.flags['F_CONTIGUOUS']
            self.assertEqual(is_c, expect_c)
            self.assertEqual(is_f, expect_f)

        arr = np.arange(24)
        # 1D
        check_arr(arr)
        # 2D
        check_arr(arr.reshape((3, 8)))
        check_arr(arr.reshape((3, 8)).T)
        check_arr(arr.reshape((3, 8))[::2])
        # 3D
        check_arr(arr.reshape((2, 3, 4)))
        check_arr(arr.reshape((2, 3, 4)).T)
        # middle axis is shape 1
        check_arr(arr.reshape((2, 3, 4))[:, ::3])
        check_arr(arr.reshape((2, 3, 4)).T[:, ::3])

        # leading axis is shape 1
        check_arr(arr.reshape((2, 3, 4))[::2])
        check_arr(arr.reshape((2, 3, 4)).T[:, :, ::2])
        # 2 leading axes are shape 1
        check_arr(arr.reshape((2, 3, 4))[::2, ::3])
        check_arr(arr.reshape((2, 3, 4)).T[:, ::3, ::2])
        # single item slices for all axes
        check_arr(arr.reshape((2, 3, 4))[::2, ::3, ::4])
        check_arr(arr.reshape((2, 3, 4)).T[::4, ::3, ::2])
        # 4D
        check_arr(arr.reshape((2, 2, 3, 2))[::2, ::2, ::3])
        check_arr(arr.reshape((2, 2, 3, 2)).T[:, ::3, ::2, ::2])
        # outer axes sliced with a step larger than their length
        # (each such axis ends up with size 1)
        check_arr(arr.reshape((2, 2, 3, 2))[::5, ::2, ::3])
        check_arr(arr.reshape((2, 2, 3, 2)).T[:, ::3, ::2, ::5])


if __name__ == '__main__':
    unittest.main()