1from functools import partial
2from itertools import permutations
3
4import numpy as np
5
6import unittest
7from numba.core.compiler import compile_isolated, Flags
8from numba import jit, njit, from_dtype, typeof
9from numba.core.errors import TypingError
10from numba.core import types, errors
11from numba.tests.support import (TestCase, MemoryLeakMixin, CompilationCache,
12                                 tag)
13
# Compilation flags shared by the tests below: the first set enables the
# "enable_pyobject" option, the second only the 'nrt' option.
enable_pyobj_flags, no_pyobj_flags = Flags(), Flags()
enable_pyobj_flags.set("enable_pyobject")
no_pyobj_flags.set('nrt')
19
20
def from_generic(pyfuncs_to_use):
    """Decorator for generic check functions.

    The decorated function must take a Python function as its first
    argument.  Calling the returned wrapper with the remaining arguments
    calls the decorated function once per entry of 'pyfuncs_to_use' (in
    order), passing that entry as the first argument, and returns the list
    of results.  Example:

        @from_generic([numpy_array_reshape, array_reshape])
        def check_only_shape(pyfunc, arr, shape, expected_shape):
            # Only check Numba result to avoid Numpy bugs
            self.memory_leak_setup()
            got = generic_run(pyfunc, arr, shape)
            self.assertEqual(got.shape, expected_shape)
            self.assertEqual(got.size, arr.size)
            del got
            self.memory_leak_teardown()

        check_only_shape(arr, (4, -1), (4, 0))

    :param pyfuncs_to_use: iterable of functions to feed to the decorated
        function.  It is materialised once at decoration time, so one-shot
        iterables such as generators work for repeated calls.
    """
    from functools import wraps

    # Snapshot the functions: re-iterating a generator on every call would
    # silently produce an empty result list after the first call.
    pyfuncs = tuple(pyfuncs_to_use)

    def decorator(func):
        @wraps(func)
        def result(*args, **kwargs):
            return [func(pyfunc, *args, **kwargs) for pyfunc in pyfuncs]
        return result
    return decorator
41
42
def array_reshape(arr, newshape):
    """Reshape ``arr`` to ``newshape`` using the ndarray method."""
    reshaped = arr.reshape(newshape)
    return reshaped
45
46
def numpy_array_reshape(arr, newshape):
    """Reshape ``arr`` to ``newshape`` using the free function ``np.reshape``."""
    out = np.reshape(arr, newshape)
    return out
49
50
def flatten_array(a):
    """Return ``a.flatten()``."""
    flat = a.flatten()
    return flat
53
54
def ravel_array(a):
    """Return ``a.ravel()`` (ndarray method form)."""
    raveled = a.ravel()
    return raveled
57
58
def ravel_array_size(a):
    """Return the number of elements in ``a.ravel()``."""
    raveled = a.ravel()
    return raveled.size
61
62
def numpy_ravel_array(a):
    """Return ``np.ravel(a)`` (free-function form)."""
    raveled = np.ravel(a)
    return raveled
65
66
def transpose_array(a):
    """Return ``a.transpose()`` with no explicit axes."""
    swapped = a.transpose()
    return swapped
69
70
def numpy_transpose_array(a):
    """Return ``np.transpose(a)`` with no explicit axes."""
    swapped = np.transpose(a)
    return swapped
73
def numpy_transpose_array_axes_kwarg(arr, axes):
    """Permute ``arr`` via ``np.transpose``, passing ``axes`` by keyword."""
    permuted = np.transpose(arr, axes=axes)
    return permuted
76
77
def numpy_transpose_array_axes_kwarg_copy(arr, axes):
    """Like the keyword-``axes`` ``np.transpose`` call, but return a copy."""
    view = np.transpose(arr, axes=axes)
    return view.copy()
80
81
def array_transpose_axes(arr, axes):
    """Permute ``arr`` via the ndarray ``transpose`` method with ``axes``."""
    permuted = arr.transpose(axes)
    return permuted
84
85
def array_transpose_axes_copy(arr, axes):
    """Like the ndarray ``transpose(axes)`` call, but return a copy."""
    view = arr.transpose(axes)
    return view.copy()
88
89
def transpose_issue_4708(m, n):
    """Regression case named after issue 4708.

    Broadcast-subtract a (n, 3) array's transpose from an (m, 3, n) array,
    transpose the difference, permute its axes with (2, 0, 1) and add 1.
    """
    left = np.reshape(np.arange(m * n * 3), (m, 3, n))
    right = np.reshape(np.arange(n * 3), (n, 3))
    diff = (left - right.T).T
    diff = np.transpose(diff, (2, 0, 1))
    return diff + 1
97
98
def squeeze_array(a):
    """Return ``a.squeeze()`` (drop all length-1 axes)."""
    squeezed = a.squeeze()
    return squeezed
101
102
def expand_dims(a, axis):
    """Insert a length-1 axis at position ``axis`` via ``np.expand_dims``."""
    expanded = np.expand_dims(a, axis)
    return expanded
105
106
def atleast_1d(*args):
    """Forward all positional arguments to ``np.atleast_1d``."""
    result = np.atleast_1d(*args)
    return result
109
110
def atleast_2d(*args):
    """Forward all positional arguments to ``np.atleast_2d``."""
    result = np.atleast_2d(*args)
    return result
113
114
def atleast_3d(*args):
    """Forward all positional arguments to ``np.atleast_3d``."""
    result = np.atleast_3d(*args)
    return result
117
118
def as_strided1(a):
    """``as_strided()`` with an implicit shape: halve the leading stride only."""
    leading = a.strides[0] // 2
    new_strides = (leading,) + a.strides[1:]
    return np.lib.stride_tricks.as_strided(a, strides=new_strides)
123
124
def as_strided2(a):
    """Rolling window of width 3 over the last axis.

    Example taken from https://github.com/numba/numba/issues/1884
    """
    width = 3
    n_windows = a.shape[-1] - width + 1
    view_shape = a.shape[:-1] + (n_windows, width)
    # Consecutive window entries step by the last axis' stride.
    view_strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=view_shape,
                                           strides=view_strides)
131
132
def add_axis2(a):
    """Prepend a length-1 axis by indexing with ``np.newaxis``."""
    expanded = a[np.newaxis, :]
    return expanded
135
136
def bad_index(arr, arr2d):
    """Assign into ``arr2d`` using a (1-tuple, scalar) index pair.

    The first index component is deliberately a tuple (which causes a new
    axis to be created); the tests expect this indexing to be rejected in
    nopython mode.
    """
    first = (arr.x,)
    second = arr.y
    arr2d[first, second] = 1.0
142
143
def bad_float_index(arr):
    """Index a 2D array with a float second component (invalid index).

    A 2D index is required here because a 1D float index already fails
    typing.
    """
    element = arr[1, 2.0]
    return element
148
149
def numpy_fill_diagonal(arr, val, wrap=False):
    """Call ``np.fill_diagonal(arr, val, wrap)`` and return its result.

    ``arr`` is modified in place.
    """
    outcome = np.fill_diagonal(arr, val, wrap)
    return outcome
152
153
def numpy_shape(arr):
    """Return ``np.shape(arr)``."""
    dims = np.shape(arr)
    return dims
156
157
def numpy_flatnonzero(a):
    """Return ``np.flatnonzero(a)``."""
    indices = np.flatnonzero(a)
    return indices
160
161
def numpy_argwhere(a):
    """Return ``np.argwhere(a)``."""
    positions = np.argwhere(a)
    return positions
164
165
class TestArrayManipulation(MemoryLeakMixin, TestCase):
    """
    Check shape-changing operations on arrays.
    """

    def setUp(self):
        super(TestArrayManipulation, self).setUp()
        # Compilation cache used by the `self.ccache.compile(...)` helpers
        # in the tests below.
        self.ccache = CompilationCache()

    def test_array_reshape(self):
        pyfuncs_to_use = [array_reshape, numpy_array_reshape]

        def generic_run(pyfunc, arr, shape):
            cres = compile_isolated(pyfunc, (typeof(arr), typeof(shape)))
            return cres.entry_point(arr, shape)

        @from_generic(pyfuncs_to_use)
        def check(pyfunc, arr, shape):
            expected = pyfunc(arr, shape)
            self.memory_leak_setup()
            got = generic_run(pyfunc, arr, shape)
            self.assertPreciseEqual(got, expected)
            del got
            self.memory_leak_teardown()

        @from_generic(pyfuncs_to_use)
        def check_only_shape(pyfunc, arr, shape, expected_shape):
            # Only check Numba result to avoid Numpy bugs
            self.memory_leak_setup()
            got = generic_run(pyfunc, arr, shape)
            self.assertEqual(got.shape, expected_shape)
            self.assertEqual(got.size, arr.size)
            del got
            self.memory_leak_teardown()

        @from_generic(pyfuncs_to_use)
        def check_err_shape(pyfunc, arr, shape):
            with self.assertRaises(NotImplementedError) as raises:
                generic_run(pyfunc, arr, shape)
            self.assertEqual(str(raises.exception),
                             "incompatible shape for array")

        @from_generic(pyfuncs_to_use)
        def check_err_size(pyfunc, arr, shape):
            with self.assertRaises(ValueError) as raises:
                generic_run(pyfunc, arr, shape)
            self.assertEqual(str(raises.exception),
                             "total size of new array must be unchanged")

        @from_generic(pyfuncs_to_use)
        def check_err_multiple_negative(pyfunc, arr, shape):
            with self.assertRaises(ValueError) as raises:
                generic_run(pyfunc, arr, shape)
            self.assertEqual(str(raises.exception),
                             "multiple negative shape values")

        # C-contiguous
        arr = np.arange(24)
        check(arr, (24,))
        check(arr, (4, 6))
        check(arr, (8, 3))
        check(arr, (8, 1, 3))
        check(arr, (1, 8, 1, 1, 3, 1))
        arr = np.arange(24).reshape((2, 3, 4))
        check(arr, (24,))
        check(arr, (4, 6))
        check(arr, (8, 3))
        check(arr, (8, 1, 3))
        check(arr, (1, 8, 1, 1, 3, 1))
        check_err_size(arr, ())
        check_err_size(arr, (25,))
        check_err_size(arr, (8, 4))
        arr = np.arange(24).reshape((1, 8, 1, 1, 3, 1))
        check(arr, (24,))
        check(arr, (4, 6))
        check(arr, (8, 3))
        check(arr, (8, 1, 3))

        # F-contiguous
        arr = np.arange(24).reshape((2, 3, 4)).T
        check(arr, (4, 3, 2))
        check(arr, (1, 4, 1, 3, 1, 2, 1))
        check_err_shape(arr, (2, 3, 4))
        check_err_shape(arr, (6, 4))
        check_err_shape(arr, (2, 12))

        # Test negative shape value
        arr = np.arange(25).reshape(5, 5)
        check(arr, -1)
        check(arr, (-1,))
        check(arr, (-1, 5))
        check(arr, (5, -1, 5))
        check(arr, (5, 5, -1))
        check_err_size(arr, (-1, 4))
        check_err_multiple_negative(arr, (-1, -2, 5, 5))
        check_err_multiple_negative(arr, (5, 5, -1, -1))

        # 0-sized arrays
        def check_empty(arr):
            check(arr, 0)
            check(arr, (0,))
            check(arr, (1, 0, 2))
            check(arr, (0, 55, 1, 0, 2))
            # -1 is buggy in Numpy with 0-sized arrays
            check_only_shape(arr, -1, (0,))
            check_only_shape(arr, (-1,), (0,))
            check_only_shape(arr, (0, -1), (0, 0))
            check_only_shape(arr, (4, -1), (4, 0))
            check_only_shape(arr, (-1, 0, 4), (0, 0, 4))
            check_err_size(arr, ())
            check_err_size(arr, 1)
            check_err_size(arr, (1, 2))

        arr = np.array([])
        check_empty(arr)
        check_empty(arr.reshape((3, 2, 0)))

        # Exceptions leak references
        self.disable_leak_check()

    def test_array_transpose_axes(self):
        pyfuncs_to_use = [numpy_transpose_array_axes_kwarg,
                          numpy_transpose_array_axes_kwarg_copy,
                          array_transpose_axes,
                          array_transpose_axes_copy]

        def run(pyfunc, arr, axes):
            cres = self.ccache.compile(pyfunc, (typeof(arr), typeof(axes)))
            return cres.entry_point(arr, axes)

        @from_generic(pyfuncs_to_use)
        def check(pyfunc, arr, axes):
            expected = pyfunc(arr, axes)
            got = run(pyfunc, arr, axes)
            self.assertPreciseEqual(got, expected)
            self.assertEqual(got.flags.f_contiguous,
                             expected.flags.f_contiguous)
            self.assertEqual(got.flags.c_contiguous,
                             expected.flags.c_contiguous)

        @from_generic(pyfuncs_to_use)
        def check_err_axis_repeated(pyfunc, arr, axes):
            with self.assertRaises(ValueError) as raises:
                run(pyfunc, arr, axes)
            self.assertEqual(str(raises.exception),
                             "repeated axis in transpose")

        @from_generic(pyfuncs_to_use)
        def check_err_axis_oob(pyfunc, arr, axes):
            with self.assertRaises(ValueError) as raises:
                run(pyfunc, arr, axes)
            self.assertEqual(str(raises.exception),
                             "axis is out of bounds for array of given dimension")

        @from_generic(pyfuncs_to_use)
        def check_err_invalid_args(pyfunc, arr, axes):
            with self.assertRaises((TypeError, TypingError)):
                run(pyfunc, arr, axes)

        arrs = [np.arange(24),
                np.arange(24).reshape(4, 6),
                np.arange(24).reshape(2, 3, 4),
                np.arange(24).reshape(1, 2, 3, 4),
                np.arange(64).reshape(8, 4, 2)[::3, ::2, :]]

        for arr in arrs:
            # First check `None`, the default, which is to reverse dims
            check(arr, None)
            # Check supplied axis permutations, in both positive and
            # negative (counted from the end) form
            for axes in permutations(tuple(range(arr.ndim))):
                ndim = len(axes)
                neg_axes = tuple([x - ndim for x in axes])
                check(arr, axes)
                check(arr, neg_axes)

        @from_generic([transpose_issue_4708])
        def check_issue_4708(pyfunc, m, n):
            expected = pyfunc(m, n)
            got = njit(pyfunc)(m, n)
            # values in arrays are equals,
            # but stronger assertions not hold (layout and strides equality)
            np.testing.assert_equal(got, expected)

        check_issue_4708(3, 2)
        check_issue_4708(2, 3)
        check_issue_4708(5, 4)

        # Exceptions leak references
        self.disable_leak_check()

        check_err_invalid_args(arrs[1], "foo")
        check_err_invalid_args(arrs[1], ("foo",))
        check_err_invalid_args(arrs[1], 5.3)
        check_err_invalid_args(arrs[2], (1.2, 5))

        check_err_axis_repeated(arrs[1], (0, 0))
        check_err_axis_repeated(arrs[2], (2, 0, 0))
        check_err_axis_repeated(arrs[3], (3, 2, 1, 1))

        check_err_axis_oob(arrs[0], (1,))
        check_err_axis_oob(arrs[0], (-2,))
        check_err_axis_oob(arrs[1], (0, 2))
        check_err_axis_oob(arrs[1], (-3, 2))
        check_err_axis_oob(arrs[1], (0, -3))
        check_err_axis_oob(arrs[2], (3, 1, 2))
        check_err_axis_oob(arrs[2], (-4, 1, 2))
        check_err_axis_oob(arrs[3], (3, 1, 2, 5))
        check_err_axis_oob(arrs[3], (3, 1, 2, -5))

        with self.assertRaises(TypingError) as e:
            jit(nopython=True)(numpy_transpose_array)((np.array([0, 1]),))
        self.assertIn("np.transpose does not accept tuples",
                      str(e.exception))

    def test_expand_dims(self):
        pyfunc = expand_dims

        def run(arr, axis):
            cres = self.ccache.compile(pyfunc, (typeof(arr), typeof(axis)))
            return cres.entry_point(arr, axis)

        def check(arr, axis):
            expected = pyfunc(arr, axis)
            self.memory_leak_setup()
            got = run(arr, axis)
            self.assertPreciseEqual(got, expected)
            del got
            self.memory_leak_teardown()

        def check_all_axes(arr):
            # Every valid insertion position, counted from both ends.
            for axis in range(-arr.ndim - 1, arr.ndim + 1):
                check(arr, axis)

        # 1d
        arr = np.arange(5)
        check_all_axes(arr)
        # 3d (C, F, A)
        arr = np.arange(24).reshape((2, 3, 4))
        check_all_axes(arr)
        check_all_axes(arr.T)
        check_all_axes(arr[::-1])
        # 0d
        arr = np.array(42)
        check_all_axes(arr)

    def check_atleast_nd(self, pyfunc, cfunc):
        def check_result(got, expected):
            # We would like to check the result has the same contiguity,
            # but we can't rely on the "flags" attribute when there are
            # 1-sized dimensions.
            self.assertStridesEqual(got, expected)
            self.assertPreciseEqual(got.flatten(), expected.flatten())

        def check_single(arg):
            check_result(cfunc(arg), pyfunc(arg))

        def check_tuple(*args):
            expected_tuple = pyfunc(*args)
            got_tuple = cfunc(*args)
            self.assertEqual(len(got_tuple), len(expected_tuple))
            for got, expected in zip(got_tuple, expected_tuple):
                check_result(got, expected)

        # 0d
        a1 = np.array(42)
        a2 = np.array(5j)
        check_single(a1)
        check_tuple(a1, a2)
        # 1d
        b1 = np.arange(5)
        b2 = np.arange(6) + 1j
        b3 = b1[::-1]
        check_single(b1)
        check_tuple(b1, b2, b3)
        # 2d
        c1 = np.arange(6).reshape((2, 3))
        c2 = c1.T
        c3 = c1[::-1]
        check_single(c1)
        check_tuple(c1, c2, c3)
        # 3d
        d1 = np.arange(24).reshape((2, 3, 4))
        d2 = d1.T
        d3 = d1[::-1]
        check_single(d1)
        check_tuple(d1, d2, d3)
        # 4d
        e = np.arange(16).reshape((2, 2, 2, 2))
        check_single(e)
        # mixed dimensions
        check_tuple(a1, b2, c3, d2)

    def test_atleast_1d(self):
        pyfunc = atleast_1d
        cfunc = jit(nopython=True)(pyfunc)
        self.check_atleast_nd(pyfunc, cfunc)

    def test_atleast_2d(self):
        pyfunc = atleast_2d
        cfunc = jit(nopython=True)(pyfunc)
        self.check_atleast_nd(pyfunc, cfunc)

    def test_atleast_3d(self):
        pyfunc = atleast_3d
        cfunc = jit(nopython=True)(pyfunc)
        self.check_atleast_nd(pyfunc, cfunc)

    def check_as_strided(self, pyfunc):
        def run(arr):
            cres = self.ccache.compile(pyfunc, (typeof(arr),))
            return cres.entry_point(arr)

        def check(arr):
            expected = pyfunc(arr)
            got = run(arr)
            self.assertPreciseEqual(got, expected)

        arr = np.arange(24)
        check(arr)
        check(arr.reshape((6, 4)))
        check(arr.reshape((4, 1, 6)))

    def test_as_strided(self):
        self.check_as_strided(as_strided1)
        self.check_as_strided(as_strided2)

    def test_flatten_array(self, flags=enable_pyobj_flags, layout='C'):
        a = np.arange(9).reshape(3, 3)
        if layout == 'F':
            a = a.T

        pyfunc = flatten_array
        arraytype1 = typeof(a)
        if layout == 'A':
            # Force A layout
            arraytype1 = arraytype1.copy(layout='A')

        self.assertEqual(arraytype1.layout, layout)
        cr = compile_isolated(pyfunc, (arraytype1,), flags=flags)
        cfunc = cr.entry_point

        expected = pyfunc(a)
        got = cfunc(a)
        np.testing.assert_equal(expected, got)

    def test_flatten_array_npm(self):
        self.test_flatten_array(flags=no_pyobj_flags)
        self.test_flatten_array(flags=no_pyobj_flags, layout='F')
        self.test_flatten_array(flags=no_pyobj_flags, layout='A')

    def test_ravel_array(self, flags=enable_pyobj_flags):
        def generic_check(pyfunc, a, assume_layout):
            # compile
            arraytype1 = typeof(a)
            self.assertEqual(arraytype1.layout, assume_layout)
            cr = compile_isolated(pyfunc, (arraytype1,), flags=flags)
            cfunc = cr.entry_point

            expected = pyfunc(a)
            got = cfunc(a)
            # Check result matches
            np.testing.assert_equal(expected, got)
            # Check copying behavior: ravel only copies for non-C layouts
            py_copied = (a.ctypes.data != expected.ctypes.data)
            nb_copied = (a.ctypes.data != got.ctypes.data)
            self.assertEqual(py_copied, assume_layout != 'C')
            self.assertEqual(py_copied, nb_copied)

        check_method = partial(generic_check, ravel_array)
        check_function = partial(generic_check, numpy_ravel_array)

        def check(*args, **kwargs):
            check_method(*args, **kwargs)
            check_function(*args, **kwargs)

        # Check 2D
        check(np.arange(9).reshape(3, 3), assume_layout='C')
        check(np.arange(9).reshape(3, 3, order='F'), assume_layout='F')
        check(np.arange(18).reshape(3, 3, 2)[:, :, 0], assume_layout='A')

        # Check 3D
        check(np.arange(18).reshape(2, 3, 3), assume_layout='C')
        check(np.arange(18).reshape(2, 3, 3, order='F'), assume_layout='F')
        check(np.arange(36).reshape(2, 3, 3, 2)[:, :, :, 0], assume_layout='A')

    def test_ravel_array_size(self, flags=enable_pyobj_flags):
        a = np.arange(9).reshape(3, 3)

        pyfunc = ravel_array_size
        arraytype1 = typeof(a)
        cr = compile_isolated(pyfunc, (arraytype1,), flags=flags)
        cfunc = cr.entry_point

        expected = pyfunc(a)
        got = cfunc(a)
        np.testing.assert_equal(expected, got)

    def test_ravel_array_npm(self):
        self.test_ravel_array(flags=no_pyobj_flags)

    def test_ravel_array_size_npm(self):
        self.test_ravel_array_size(flags=no_pyobj_flags)

    def test_transpose_array(self, flags=enable_pyobj_flags):
        @from_generic([transpose_array, numpy_transpose_array])
        def check(pyfunc):
            a = np.arange(9).reshape(3, 3)

            arraytype1 = typeof(a)
            cr = compile_isolated(pyfunc, (arraytype1,), flags=flags)
            cfunc = cr.entry_point

            expected = pyfunc(a)
            got = cfunc(a)
            np.testing.assert_equal(expected, got)

        check()

    def test_transpose_array_npm(self):
        self.test_transpose_array(flags=no_pyobj_flags)

    def test_squeeze_array(self, flags=enable_pyobj_flags):
        a = np.arange(2 * 1 * 3 * 1 * 4).reshape(2, 1, 3, 1, 4)

        pyfunc = squeeze_array
        arraytype1 = typeof(a)
        cr = compile_isolated(pyfunc, (arraytype1,), flags=flags)
        cfunc = cr.entry_point

        expected = pyfunc(a)
        got = cfunc(a)
        np.testing.assert_equal(expected, got)

    def test_squeeze_array_npm(self):
        with self.assertRaises(errors.TypingError) as raises:
            self.test_squeeze_array(flags=no_pyobj_flags)

        self.assertIn("squeeze", str(raises.exception))

    def test_add_axis2(self, flags=enable_pyobj_flags):
        a = np.arange(9).reshape(3, 3)

        pyfunc = add_axis2
        arraytype1 = typeof(a)
        cr = compile_isolated(pyfunc, (arraytype1,), flags=flags)
        cfunc = cr.entry_point

        expected = pyfunc(a)
        got = cfunc(a)
        np.testing.assert_equal(expected, got)

    def test_add_axis2_npm(self):
        with self.assertTypingError() as raises:
            self.test_add_axis2(flags=no_pyobj_flags)
        self.assertIn("unsupported array index type none in",
                      str(raises.exception))

    def test_bad_index_npm(self):
        with self.assertTypingError() as raises:
            arraytype1 = from_dtype(np.dtype([('x', np.int32),
                                              ('y', np.int32)]))
            arraytype2 = types.Array(types.int32, 2, 'C')
            compile_isolated(bad_index, (arraytype1, arraytype2),
                             flags=no_pyobj_flags)
        self.assertIn('unsupported array index type', str(raises.exception))

    def test_bad_float_index_npm(self):
        with self.assertTypingError() as raises:
            compile_isolated(bad_float_index,
                             (types.Array(types.float64, 2, 'C'),))
        self.assertIn('unsupported array index type float64',
                      str(raises.exception))

    def test_fill_diagonal_basic(self):
        pyfunc = numpy_fill_diagonal
        cfunc = jit(nopython=True)(pyfunc)

        def _shape_variations(n):
            # square
            yield (n, n)
            # tall and thin
            yield (2 * n, n)
            # short and fat
            yield (n, 2 * n)
            # a bit taller than wide; odd numbers of rows and cols
            yield ((2 * n + 1), (2 * n - 1))
            # 4d, all dimensions same
            yield (n, n, n, n)
            # weird edge case
            yield (1, 1, 1)

        def _val_variations():
            yield 1
            yield 3.142
            yield np.nan
            yield -np.inf
            yield True
            yield np.arange(4)
            yield (4,)
            yield [8, 9]
            yield np.arange(54).reshape(9, 3, 2, 1)  # contiguous C
            yield np.asfortranarray(np.arange(9).reshape(3, 3))  # contiguous F
            yield np.arange(9).reshape(3, 3)[::-1]  # non-contiguous

        # contiguous arrays
        def _multi_dimensional_array_variations(n):
            for shape in _shape_variations(n):
                yield np.zeros(shape, dtype=np.float64)
                yield np.asfortranarray(np.ones(shape, dtype=np.float64))

        # non-contiguous arrays
        def _multi_dimensional_array_variations_strided(n):
            for shape in _shape_variations(n):
                tmp = np.zeros(tuple([x * 2 for x in shape]), dtype=np.float64)
                slicer = tuple(slice(0, x * 2, 2) for x in shape)
                yield tmp[slicer]

        def _check_fill_diagonal(arr, val):
            # `None` means: do not pass `wrap` at all (use the default).
            for wrap in None, True, False:
                a = arr.copy()
                b = arr.copy()

                if wrap is None:
                    params = {}
                else:
                    params = {'wrap': wrap}

                pyfunc(a, val, **params)
                cfunc(b, val, **params)
                self.assertPreciseEqual(a, b)

        for arr in _multi_dimensional_array_variations(3):
            for val in _val_variations():
                _check_fill_diagonal(arr, val)

        for arr in _multi_dimensional_array_variations_strided(3):
            for val in _val_variations():
                _check_fill_diagonal(arr, val)

        # non-numeric input arrays
        arr = np.array([True] * 9).reshape(3, 3)
        _check_fill_diagonal(arr, False)
        _check_fill_diagonal(arr, [False, True, False])
        _check_fill_diagonal(arr, np.array([True, False, True]))

    def test_fill_diagonal_exception_cases(self):
        pyfunc = numpy_fill_diagonal
        cfunc = jit(nopython=True)(pyfunc)
        val = 1

        # Exceptions leak references
        self.disable_leak_check()

        # first argument unsupported number of dimensions
        for a in np.array([]), np.ones(5):
            with self.assertRaises(TypingError) as raises:
                cfunc(a, val)
            self.assertIn("The first argument must be at least 2-D",
                          str(raises.exception))

        # multi-dimensional input where dimensions are not all equal.
        # The message check must sit outside the `with` block: anything
        # after the raising call inside it would never execute.
        a = np.zeros((3, 3, 4))
        with self.assertRaises(ValueError) as raises:
            cfunc(a, val)
        self.assertEqual("All dimensions of input must be of equal length",
                         str(raises.exception))

        # cases where val has incompatible type / value
        def _assert_raises(arr, val):
            with self.assertRaises(ValueError) as raises:
                cfunc(arr, val)
            self.assertEqual("Unable to safely conform val to a.dtype",
                             str(raises.exception))

        arr = np.zeros((3, 3), dtype=np.int32)
        val = np.nan
        _assert_raises(arr, val)

        val = [3.3, np.inf]
        _assert_raises(arr, val)

        val = np.array([1, 2, 1e10], dtype=np.int64)
        _assert_raises(arr, val)

        arr = np.zeros((3, 3), dtype=np.float32)
        val = [1.4, 2.6, -1e100]
        _assert_raises(arr, val)

        val = 1.1e100
        _assert_raises(arr, val)

        val = np.array([-1e100])
        _assert_raises(arr, val)

    def test_shape(self):
        pyfunc = numpy_shape
        cfunc = jit(nopython=True)(pyfunc)

        def check(x):
            expected = pyfunc(x)
            got = cfunc(x)
            self.assertPreciseEqual(got, expected)

        # check arrays
        for t in [(), (1,), (2, 3,), (4, 5, 6)]:
            arr = np.empty(t)
            check(arr)

        # check some types that go via asarray; each value itself is the
        # input (previously the last array was re-checked instead).
        for t in [1, False, [1,], [[1, 2,], [3, 4]], (1,), (1, 2, 3)]:
            check(t)

        with self.assertRaises(TypingError) as raises:
            cfunc('a')

        self.assertIn("The argument to np.shape must be array-like",
                      str(raises.exception))

    def test_flatnonzero_basic(self):
        pyfunc = numpy_flatnonzero
        cfunc = jit(nopython=True)(pyfunc)

        def a_variations():
            yield np.arange(-5, 5)
            yield np.full(5, fill_value=0)
            yield np.array([])
            a = self.random.randn(100)
            a[np.abs(a) > 0.2] = 0.0
            yield a
            yield a.reshape(5, 5, 4)
            yield a.reshape(50, 2, order='F')
            yield a.reshape(25, 4)[1::2]
            yield a * 1j

        for a in a_variations():
            expected = pyfunc(a)
            got = cfunc(a)
            self.assertPreciseEqual(expected, got)

    def test_argwhere_basic(self):
        pyfunc = numpy_argwhere
        cfunc = jit(nopython=True)(pyfunc)

        def a_variations():
            yield np.arange(-5, 5) > 2
            yield np.full(5, fill_value=0)
            yield np.full(5, fill_value=1)
            yield np.array([])
            yield np.array([-1.0, 0.0, 1.0])
            a = self.random.randn(100)
            yield a > 0.2
            yield a.reshape(5, 5, 4) > 0.5
            yield a.reshape(50, 2, order='F') > 0.5
            yield a.reshape(25, 4)[1::2] > 0.5
            yield a == a - 1
            yield a > -a

        for a in a_variations():
            expected = pyfunc(a)
            got = cfunc(a)
            self.assertPreciseEqual(expected, got)

    @staticmethod
    def array_like_variations():
        yield ((1.1, 2.2), (3.3, 4.4), (5.5, 6.6))
        yield (0.0, 1.0, 0.0, -6.0)
        yield ([0, 1], [2, 3])
        yield ()
        yield np.nan
        yield 0
        yield 1
        yield False
        yield True
        yield (True, False, True)
        yield 2 + 1j
        # the following are not array-like, but NumPy does not raise
        yield None
        yield 'a_string'
        yield ''

    def test_flatnonzero_array_like(self):
        pyfunc = numpy_flatnonzero
        cfunc = jit(nopython=True)(pyfunc)

        for a in self.array_like_variations():
            expected = pyfunc(a)
            got = cfunc(a)
            self.assertPreciseEqual(expected, got)

    def test_argwhere_array_like(self):
        pyfunc = numpy_argwhere
        cfunc = jit(nopython=True)(pyfunc)
        for a in self.array_like_variations():
            expected = pyfunc(a)
            got = cfunc(a)
            self.assertPreciseEqual(expected, got)
861
862
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
865