1import sys
2
3import numpy as np
4
5from numba.np.ufunc.ufuncbuilder import GUFuncBuilder
6from numba import vectorize, guvectorize
7from numba.np.ufunc import PyUFunc_One
8from numba.np.ufunc.dufunc import DUFunc as UFuncBuilder
9from numba.tests.support import tag, TestCase
10from numba.core import config
11import unittest
12
13
def add(a, b):
    """An addition"""
    # The docstring text is checked through the built ufunc's __doc__
    # (test_basic_ufunc), so keep it unchanged.
    total = a + b
    return total
17
def equals(a, b):
    """Return the result of comparing *a* and *b* with ``==``."""
    is_equal = a == b
    return is_equal
20
def mul(a, b):
    """A multiplication"""
    product = a * b
    return product
24
def guadd(a, b, c):
    """A generalized addition"""
    # Fills the 2-D output `c` in place with the element-wise sum of `a`
    # and `b`, iterating over c's shape. The docstring is checked through
    # the built gufunc's __doc__ (test_basic_gufunc), so keep it unchanged.
    nrows, ncols = c.shape
    for row in range(nrows):
        for col in range(ncols):
            c[row, col] = a[row, col] + b[row, col]
31
@vectorize(nopython=True)
def inner(a, b):
    # Implicitly-typed ufunc (no signatures given); `outer` calls it from
    # compiled code in test_nested_call.
    total = a + b
    return total
35
@vectorize(["int64(int64, int64)"], nopython=True)
def inner_explicit(a, b):
    # Explicitly-typed ufunc with a single int64 signature; `outer_explicit`
    # calls it from compiled code in test_nested_call_explicit.
    total = a + b
    return total
39
def outer(a, b):
    """Forward both arguments to the implicitly-typed ``inner`` ufunc."""
    result = inner(a, b)
    return result
42
def outer_explicit(a, b):
    """Forward both arguments to the explicitly-typed ``inner_explicit``."""
    result = inner_explicit(a, b)
    return result
45
46
class Dummy:
    """Empty placeholder class; instantiating it is used to force object mode."""
48
49
def guadd_obj(a, b, c):
    """Object-mode variant of ``guadd``: fill ``c`` with ``a + b``."""
    Dummy()  # to force object mode
    nrows, ncols = c.shape
    for row in range(nrows):
        for col in range(ncols):
            c[row, col] = a[row, col] + b[row, col]
56
def guadd_scalar_obj(a, b, c):
    """Object-mode kernel: fill the 2-D ``c`` with ``a`` plus the scalar ``b``."""
    Dummy()  # to force object mode
    nrows, ncols = c.shape
    for row in range(nrows):
        for col in range(ncols):
            c[row, col] = a[row, col] + b
63
64
class MyException(Exception):
    """Test-only error raised by ``guerror`` to check exception propagation."""
67
68
def guerror(a, b, c):
    """Ignore all arguments and unconditionally raise ``MyException``."""
    raise MyException()
71
72
class TestUfuncBuilding(TestCase):
    """Build ufuncs through DUFunc (imported in this module as UFuncBuilder)."""

    def _check_doubling(self, ufunc, arr):
        # ufunc(arr, arr) must match arr + arr exactly and keep arr's dtype.
        result = ufunc(arr, arr)
        self.assertPreciseEqual(arr + arr, result)
        self.assertEqual(result.dtype, arr.dtype)

    def _check_layouts(self, ufunc, arr):
        # Exercise the contiguous 1-D input, a strided (non-contiguous) view
        # of it, and that strided view reshaped to 2x3.
        self._check_doubling(ufunc, arr)
        strided = arr[::2]
        self._check_doubling(ufunc, strided)
        self._check_doubling(ufunc, strided.reshape((2, 3)))

    def _build_nopython(self, pyfunc):
        # Compile pyfunc for (int64, int64) in nopython mode and build it.
        builder = UFuncBuilder(pyfunc, targetoptions={'nopython': True})
        builder.add("(int64, int64)")
        return builder.build_ufunc()

    def test_basic_ufunc(self):
        builder = UFuncBuilder(add)
        for sig in ("int32(int32, int32)", "int64(int64, int64)"):
            cres = builder.add(sig)
            self.assertFalse(cres.objectmode)
        ufunc = builder.build_ufunc()

        self._check_layouts(ufunc, np.arange(12, dtype='int32'))

        # The ufunc inherits add's name and docstring.
        self.assertEqual(ufunc.__name__, "add")
        self.assertIn("An addition", ufunc.__doc__)

    def test_ufunc_struct(self):
        builder = UFuncBuilder(add)
        cres = builder.add("complex64(complex64, complex64)")
        self.assertFalse(cres.objectmode)
        ufunc = builder.build_ufunc()

        self._check_layouts(ufunc, np.arange(12, dtype='complex64') + 1j)

    def test_ufunc_forceobj(self):
        builder = UFuncBuilder(add, targetoptions={'forceobj': True})
        cres = builder.add("int32(int32, int32)")
        self.assertTrue(cres.objectmode)
        ufunc = builder.build_ufunc()

        arr = np.arange(10, dtype='int32')
        self.assertPreciseEqual(arr + arr, ufunc(arr, arr))

    def test_nested_call(self):
        """
        Check nested call to an implicitly-typed ufunc.
        """
        self.assertEqual(self._build_nopython(outer)(-1, 3), 2)

    def test_nested_call_explicit(self):
        """
        Check nested call to an explicitly-typed ufunc.
        """
        self.assertEqual(self._build_nopython(outer_explicit)(-1, 3), 2)
148
149
class TestUfuncBuildingJitDisabled(TestUfuncBuilding):
    """Re-run the TestUfuncBuilding tests with ``config.DISABLE_JIT`` overridden.

    NOTE(review): despite the class name, setUp assigns False (not True) to
    ``config.DISABLE_JIT`` -- confirm this is the intended setting.
    """

    def setUp(self):
        # Remember the global flag for tearDown, then override it.
        self.old_disable_jit, config.DISABLE_JIT = config.DISABLE_JIT, False

    def tearDown(self):
        # Put the global flag back to what setUp found.
        config.DISABLE_JIT = self.old_disable_jit
158
159
class TestGUfuncBuilding(TestCase):
    """Build gufuncs from ``guadd`` directly through GUFuncBuilder."""

    # Core-dimension layout shared by every gufunc built in this class.
    _LAYOUT = "(x, y),(x, y)->(x, y)"

    def _build(self, sig, objmode, **builder_options):
        # Compile guadd for `sig`, assert whether it ended up in object mode,
        # and return the finished gufunc.
        builder = GUFuncBuilder(guadd, self._LAYOUT, **builder_options)
        cres = builder.add(sig)
        if objmode:
            self.assertTrue(cres.objectmode)
        else:
            self.assertFalse(cres.objectmode)
        return builder.build_ufunc()

    def _check_complex_doubling(self, ufunc):
        # Adding a complex 2x5 array to itself must match NumPy exactly.
        arr = np.arange(10, dtype="complex64").reshape(2, 5) + 1j
        self.assertPreciseEqual(arr + arr, ufunc(arr, arr))

    def test_basic_gufunc(self):
        ufunc = self._build("void(int32[:,:], int32[:,:], int32[:,:])",
                            objmode=False)

        arr = np.arange(10, dtype="int32").reshape(2, 5)
        result = ufunc(arr, arr)
        self.assertPreciseEqual(arr + arr, result)
        self.assertEqual(result.dtype, np.dtype('int32'))

        # The gufunc inherits guadd's name and docstring.
        self.assertEqual(ufunc.__name__, "guadd")
        self.assertIn("A generalized addition", ufunc.__doc__)

    def test_gufunc_struct(self):
        ufunc = self._build(
            "void(complex64[:,:], complex64[:,:], complex64[:,:])",
            objmode=False)
        self._check_complex_doubling(ufunc)

    def test_gufunc_struct_forceobj(self):
        ufunc = self._build(
            "void(complex64[:,:], complex64[:,:], complex64[:,:])",
            objmode=True,
            targetoptions={'forceobj': True})
        self._check_complex_doubling(ufunc)
201
202
class TestGUfuncBuildingJitDisabled(TestGUfuncBuilding):
    """Re-run the TestGUfuncBuilding tests with ``config.DISABLE_JIT`` overridden.

    NOTE(review): despite the class name, setUp assigns False (not True) to
    ``config.DISABLE_JIT`` -- confirm this is the intended setting.
    """

    def setUp(self):
        # Remember the global flag for tearDown, then override it.
        self.old_disable_jit, config.DISABLE_JIT = config.DISABLE_JIT, False

    def tearDown(self):
        # Put the global flag back to what setUp found.
        config.DISABLE_JIT = self.old_disable_jit
211
212
class TestVectorizeDecor(TestCase):
    """Tests for the ``@vectorize`` and ``@guvectorize`` decorators."""

    # Values accepted for the ``identity`` keyword. "reorderable" is accepted
    # but is reported back as ``ufunc.identity is None`` (checked by the
    # identity tests in this class).
    _supported_identities = [0, 1, None, "reorderable"]

    def test_vectorize(self):
        ufunc = vectorize(['int32(int32, int32)'])(add)
        a = np.arange(10, dtype='int32')
        b = ufunc(a, a)
        self.assertPreciseEqual(a + a, b)

    def test_vectorize_objmode(self):
        ufunc = vectorize(['int32(int32, int32)'], forceobj=True)(add)
        a = np.arange(10, dtype='int32')
        b = ufunc(a, a)
        self.assertPreciseEqual(a + a, b)

    def test_vectorize_bool_return(self):
        ufunc = vectorize(['bool_(int32, int32)'])(equals)
        a = np.arange(10, dtype='int32')
        r = ufunc(a, a)
        self.assertPreciseEqual(r, np.ones(r.shape, dtype=np.bool_))

    def test_vectorize_identity(self):
        sig = 'int32(int32, int32)'
        for identity in self._supported_identities:
            ufunc = vectorize([sig], identity=identity)(add)
            expected = None if identity == 'reorderable' else identity
            self.assertEqual(ufunc.identity, expected)
        # Default value is None
        ufunc = vectorize([sig])(add)
        self.assertIs(ufunc.identity, None)
        # Invalid values
        with self.assertRaises(ValueError):
            vectorize([sig], identity='none')(add)
        with self.assertRaises(ValueError):
            vectorize([sig], identity=2)(add)

    def test_vectorize_no_args(self):
        a = np.linspace(0, 1, 10)
        b = np.linspace(1, 2, 10)
        ufunc = vectorize(add)
        self.assertPreciseEqual(ufunc(a, b), a + b)
        # A second, independently built ufunc can write into an explicit
        # output array passed positionally.
        ufunc2 = vectorize(add)
        c = np.empty(10)
        ufunc2(a, b, c)
        self.assertPreciseEqual(c, a + b)

    def test_vectorize_only_kws(self):
        a = np.linspace(0, 1, 10)
        b = np.linspace(1, 2, 10)
        ufunc = vectorize(identity=PyUFunc_One, nopython=True)(mul)
        self.assertPreciseEqual(ufunc(a, b), a * b)

    def test_vectorize_output_kwarg(self):
        """
        Passing the output array as a keyword argument (issue #1867).
        """
        def check(ufunc):
            a = np.arange(10, 16, dtype='int32')
            out = np.zeros_like(a)
            got = ufunc(a, a, out=out)
            self.assertIs(got, out)
            self.assertPreciseEqual(out, a + a)
            # Unknown keyword arguments must be rejected.
            with self.assertRaises(TypeError):
                ufunc(a, a, zzz=out)

        # With explicit sigs
        ufunc = vectorize(['int32(int32, int32)'], nopython=True)(add)
        check(ufunc)
        # With implicit sig
        ufunc = vectorize(nopython=True)(add)
        check(ufunc)  # compiling
        check(ufunc)  # after compiling

    def test_guvectorize(self):
        ufunc = guvectorize(['(int32[:,:], int32[:,:], int32[:,:])'],
                            "(x,y),(x,y)->(x,y)")(guadd)
        a = np.arange(10, dtype='int32').reshape(2, 5)
        b = ufunc(a, a)
        self.assertPreciseEqual(a + a, b)

    def test_guvectorize_no_output(self):
        # Layout without "->": the third argument is filled in place.
        ufunc = guvectorize(['(int32[:,:], int32[:,:], int32[:,:])'],
                            "(x,y),(x,y),(x,y)")(guadd)
        a = np.arange(10, dtype='int32').reshape(2, 5)
        out = np.zeros_like(a)
        ufunc(a, a, out)
        self.assertPreciseEqual(a + a, out)

    def test_guvectorize_objectmode(self):
        ufunc = guvectorize(['(int32[:,:], int32[:,:], int32[:,:])'],
                            "(x,y),(x,y)->(x,y)")(guadd_obj)
        a = np.arange(10, dtype='int32').reshape(2, 5)
        b = ufunc(a, a)
        self.assertPreciseEqual(a + a, b)

    def test_guvectorize_scalar_objectmode(self):
        """
        Test passing of scalars to object mode gufuncs.
        """
        ufunc = guvectorize(['(int32[:,:], int32, int32[:,:])'],
                            "(x,y),()->(x,y)")(guadd_scalar_obj)
        a = np.arange(10, dtype='int32').reshape(2, 5)
        b = ufunc(a, 3)
        self.assertPreciseEqual(a + 3, b)

    def test_guvectorize_error_in_objectmode(self):
        ufunc = guvectorize(['(int32[:,:], int32[:,:], int32[:,:])'],
                            "(x,y),(x,y)->(x,y)", forceobj=True)(guerror)
        a = np.arange(10, dtype='int32').reshape(2, 5)
        with self.assertRaises(MyException):
            ufunc(a, a)

    def test_guvectorize_identity(self):
        args = (['(int32[:,:], int32[:,:], int32[:,:])'], "(x,y),(x,y)->(x,y)")
        for identity in self._supported_identities:
            ufunc = guvectorize(*args, identity=identity)(guadd)
            expected = None if identity == 'reorderable' else identity
            self.assertEqual(ufunc.identity, expected)
        # Default value is None
        ufunc = guvectorize(*args)(guadd)
        self.assertIs(ufunc.identity, None)
        # Invalid values. Decorate `guadd`, which matches the 3-array
        # signature, so the ValueError can only come from the bad identity
        # (the scalar 2-argument `add` does not fit this gufunc signature).
        with self.assertRaises(ValueError):
            guvectorize(*args, identity='none')(guadd)
        with self.assertRaises(ValueError):
            guvectorize(*args, identity=2)(guadd)

    def test_guvectorize_invalid_layout(self):
        sigs = ['(int32[:,:], int32[:,:], int32[:,:])']
        # Syntax error
        with self.assertRaises(ValueError) as raises:
            guvectorize(sigs, ")-:")(guadd)
        self.assertIn("bad token in signature", str(raises.exception))
        # Output shape can't be inferred from inputs
        with self.assertRaises(NameError) as raises:
            guvectorize(sigs, "(x,y),(x,y)->(x,z,v)")(guadd)
        self.assertEqual(str(raises.exception),
                         "undefined output symbols: v,z")
        # Arrow but no outputs. The message is not checked because it
        # depends on the Numpy version.
        with self.assertRaises(ValueError):
            guvectorize(sigs, "(x,y),(x,y),(x,y)->")(guadd)
356
357
class TestVectorizeDecorJitDisabled(TestVectorizeDecor):
    """Re-run the TestVectorizeDecor tests with ``config.DISABLE_JIT`` overridden.

    NOTE(review): despite the class name, setUp assigns False (not True) to
    ``config.DISABLE_JIT`` -- confirm this is the intended setting.
    """

    def setUp(self):
        # Remember the global flag for tearDown, then override it.
        self.old_disable_jit, config.DISABLE_JIT = config.DISABLE_JIT, False

    def tearDown(self):
        # Put the global flag back to what setUp found.
        config.DISABLE_JIT = self.old_disable_jit
366
367
if __name__ == '__main__':
    # Running this module directly executes all of its test cases.
    unittest.main()
370