1"""
2Testing object mode specifics.
3
4"""
5
6import numpy as np
7
8import unittest
9from numba.core.compiler import compile_isolated, Flags
10from numba import jit
11from numba.core import utils
12from numba.tests.support import TestCase
13
14
def complex_constant(n):
    """Return ``n + 4`` with the complex constant ``3j`` added."""
    return (n + 4) + 3j
18
def long_constant(n):
    """Return *n* plus an integer literal far wider than 64 bits."""
    big = 100000000000000000000000000000000000000000000000
    return n + big
21
def delitem_usecase(x):
    """Delete every element of *x* with a full-slice ``del`` statement.

    Used by ``TestObjectMode.test_delitem``; a non-sliceable argument
    (e.g. an int) makes the ``del`` raise ``TypeError``.
    """
    del x[:]
24
25
# Compilation flags with ``force_pyobject`` set, passed to the
# ``compile_isolated`` calls below so those functions compile in object mode.
forceobj = Flags()
forceobj.set("force_pyobject")
28
29
def loop_nest_3(x, y):
    """
    Triple-nested loop that adds ``outer * middle`` once per innermost
    iteration; the innermost loop runs ``x + y`` times.
    """
    total = 0
    for outer in range(x):
        for middle in range(y):
            for _ in range(x + y):
                total += outer * middle

    return total
38
39
def array_of_object(x):
    """Return *x* unchanged.

    Used to check that a jitted function hands back the very same object
    (``assertIs``) when given a NumPy array of Python objects.
    """
    return x
42
43
class TestObjectMode(TestCase):
    """
    Tests for functions compiled in object mode, either through the
    ``forceobj`` flags with ``compile_isolated`` or through ``jit``.
    """

    def test_complex_constant(self):
        pyfunc = complex_constant
        cres = compile_isolated(pyfunc, (), flags=forceobj)
        cfunc = cres.entry_point
        self.assertPreciseEqual(pyfunc(12), cfunc(12))

    def test_long_constant(self):
        pyfunc = long_constant
        cres = compile_isolated(pyfunc, (), flags=forceobj)
        cfunc = cres.entry_point
        self.assertPreciseEqual(pyfunc(12), cfunc(12))

    def test_loop_nest(self):
        """
        Regression test for a bug that decref'd a loop iterator too early.
        If the bug recurs, calling the compiled function is expected to
        segfault rather than fail an assertion.
        """
        pyfunc = loop_nest_3
        cres = compile_isolated(pyfunc, (), flags=forceobj)
        cfunc = cres.entry_point
        expected = pyfunc(5, 5)
        self.assertEqual(expected, cfunc(5, 5))

        # Call the compiled function repeatedly so that a premature decref
        # has a chance to corrupt memory. This replaces a timing benchmark
        # whose printed output was never asserted on and only added noise
        # and run time to the test suite.
        for _ in range(100):
            self.assertEqual(expected, cfunc(5, 5))

    def test_array_of_object(self):
        cfunc = jit(array_of_object)
        objarr = np.array([object()] * 10)
        self.assertIs(cfunc(objarr), objarr)

    def test_sequence_contains(self):
        """
        Test handling of the `in` comparison
        """
        @jit(forceobj=True)
        def foo(x, y):
            return x in y

        self.assertTrue(foo(1, [0, 1]))
        self.assertTrue(foo(0, [0, 1]))
        self.assertFalse(foo(2, [0, 1]))

        with self.assertRaises(TypeError) as raises:
            foo(None, None)

        self.assertIn("is not iterable", str(raises.exception))

    def test_delitem(self):
        pyfunc = delitem_usecase
        cres = compile_isolated(pyfunc, (), flags=forceobj)
        cfunc = cres.entry_point

        lst = [3, 4, 5]
        cfunc(lst)
        self.assertPreciseEqual(lst, [])
        # Deleting a slice of a non-sequence must surface the TypeError.
        with self.assertRaises(TypeError):
            cfunc(42)
109
110
class TestObjectModeInvalidRewrite(TestCase):
    """
    Tests ensuring that rewrite passes do not break object-mode lowering.
    """

    def _ensure_objmode(self, disp):
        """
        Assert that dispatcher *disp* has compiled at least one signature and
        that none of its signatures were compiled in nopython mode.

        Returns *disp* unchanged.
        """
        self.assertTrue(disp.signatures)
        self.assertFalse(disp.nopython_signatures)
        return disp

    def test_static_raise_in_objmode_fallback(self):
        """
        Test code based on user submitted issue at
        https://github.com/numba/numba/issues/2159
        """
        def test0(n):
            return n

        def test1(n):
            if n == 0:
                # static raise will fail in objmode if the IR is modified by
                # rewrite pass
                raise ValueError()
            return test0(n)  # trigger objmode fallback

        compiled = jit(test1)
        self.assertEqual(test1(10), compiled(10))
        self._ensure_objmode(compiled)
        # Exercise the raising branch itself, not only its compilation.
        with self.assertRaises(ValueError):
            compiled(0)

    def test_static_setitem_in_objmode_fallback(self):
        """
        Test code based on user submitted issue at
        https://github.com/numba/numba/issues/2169
        """

        def test0(n):
            return n

        def test(a1, a2):
            a1 = np.asarray(a1)
            # static setitem here will fail in objmode if the IR is modified by
            # rewrite pass
            a2[0] = 1
            return test0(a1.sum() + a2.sum())   # trigger objmode fallback

        compiled = jit(test)

        # Fresh arrays for each call: ``test`` mutates ``a2``. Reusing one
        # tuple would hand the compiled function an already-mutated array, so
        # a compiled version that skipped the setitem would go unnoticed.
        py_args = np.array([3]), np.array([4])
        jit_args = np.array([3]), np.array([4])
        self.assertEqual(test(*py_args), compiled(*jit_args))
        # The compiled call must perform the setitem itself.
        np.testing.assert_array_equal(jit_args[1], py_args[1])
        self._ensure_objmode(compiled)

    def test_dynamic_func_objmode(self):
        """
        Test issue https://github.com/numba/numba/issues/3355
        """
        func_text = "def func():\n"
        func_text += "    np.array([1,2,3])\n"
        loc_vars = {}
        # ``np`` is only reachable through the globals given to ``exec``.
        custom_globals = {'np': np}
        exec(func_text, custom_globals, loc_vars)
        func = loc_vars['func']
        jitted = jit(forceobj=True)(func)
        # The generated function has no ``return``, so it yields None.
        self.assertIsNone(jitted())
        self._ensure_objmode(jitted)
173
174
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
177