"""
Testing object mode specifics.

"""

import numpy as np

import unittest
from numba.core.compiler import compile_isolated, Flags
from numba import jit
from numba.core import utils
from numba.tests.support import TestCase


def complex_constant(n):
    """Return ``n + 4 + 3j``; exercises a complex literal."""
    tmp = n + 4
    return tmp + 3j


def long_constant(n):
    """Return *n* plus an integer literal far wider than 64 bits."""
    return n + 100000000000000000000000000000000000000000000000


def delitem_usecase(x):
    """Delete every element of *x* in place via ``del x[:]``."""
    del x[:]


# Compiler flags with "force_pyobject" set, shared by the
# compile_isolated-based tests below.
forceobj = Flags()
forceobj.set("force_pyobject")


def loop_nest_3(x, y):
    """Accumulate ``i * j`` over a triply nested ``range`` loop."""
    n = 0
    for i in range(x):
        for j in range(y):
            for k in range(x + y):
                n += i * j

    return n


def array_of_object(x):
    """Return *x* unchanged."""
    return x


class TestObjectMode(TestCase):
    """Tests compiling the module-level use cases with ``forceobj``/``jit``."""

    def _compile_objmode(self, pyfunc):
        """
        Compile *pyfunc* with the ``forceobj`` flags and return its
        entry point.

        The argument-type tuple is left empty, exactly as every test here
        did inline before this helper existed.
        """
        cres = compile_isolated(pyfunc, (), flags=forceobj)
        return cres.entry_point

    def test_complex_constant(self):
        """The compiled function must agree with the interpreter."""
        pyfunc = complex_constant
        cfunc = self._compile_objmode(pyfunc)
        self.assertPreciseEqual(pyfunc(12), cfunc(12))

    def test_long_constant(self):
        """The compiled function must agree with the interpreter."""
        pyfunc = long_constant
        cfunc = self._compile_objmode(pyfunc)
        self.assertPreciseEqual(pyfunc(12), cfunc(12))

    def test_loop_nest(self):
        """
        Test for a bug that decref'ed the iterator too early.
        If the bug occurs, a segfault should occur.
        """
        pyfunc = loop_nest_3
        cfunc = self._compile_objmode(pyfunc)
        self.assertEqual(pyfunc(5, 5), cfunc(5, 5))

        def bm_pyfunc():
            pyfunc(5, 5)

        def bm_cfunc():
            cfunc(5, 5)

        # Keep driving both versions through utils.benchmark (presumably it
        # invokes the callable repeatedly -- confirm), but do not print the
        # result: a unit test should not write timing noise to stdout.
        utils.benchmark(bm_pyfunc)
        utils.benchmark(bm_cfunc)

    def test_array_of_object(self):
        """An object array passed through the jitted function keeps its identity."""
        cfunc = jit(array_of_object)
        objarr = np.array([object()] * 10)
        self.assertIs(cfunc(objarr), objarr)

    def test_sequence_contains(self):
        """
        Test handling of the `in` comparison
        """
        @jit(forceobj=True)
        def foo(x, y):
            return x in y

        self.assertTrue(foo(1, [0, 1]))
        self.assertTrue(foo(0, [0, 1]))
        self.assertFalse(foo(2, [0, 1]))

        # A non-iterable right-hand side must surface the TypeError.
        with self.assertRaises(TypeError) as raises:
            foo(None, None)

        self.assertIn("is not iterable", str(raises.exception))

    def test_delitem(self):
        """``del x[:]`` empties a list and raises TypeError on an int."""
        cfunc = self._compile_objmode(delitem_usecase)

        seq = [3, 4, 5]
        cfunc(seq)
        self.assertPreciseEqual(seq, [])
        with self.assertRaises(TypeError):
            cfunc(42)


class TestObjectModeInvalidRewrite(TestCase):
    """
    Tests to ensure that rewrite passes didn't affect objmode lowering.
    """

    def _ensure_objmode(self, disp):
        """
        Assert that *disp* has at least one signature and no nopython
        signatures, then return *disp*.
        """
        self.assertTrue(disp.signatures)
        self.assertFalse(disp.nopython_signatures)
        return disp

    def test_static_raise_in_objmode_fallback(self):
        """
        Test code based on user submitted issue at
        https://github.com/numba/numba/issues/2159
        """
        def test0(n):
            return n

        def test1(n):
            if n == 0:
                # static raise will fail in objmode if the IR is modified by
                # rewrite pass
                raise ValueError()
            return test0(n)  # trigger objmode fallback

        compiled = jit(test1)
        self.assertEqual(test1(10), compiled(10))
        self._ensure_objmode(compiled)

    def test_static_setitem_in_objmode_fallback(self):
        """
        Test code based on user submitted issue at
        https://github.com/numba/numba/issues/2169
        """

        def test0(n):
            return n

        def test(a1, a2):
            a1 = np.asarray(a1)
            # static setitem here will fail in objmode if the IR is modified by
            # rewrite pass
            a2[0] = 1
            return test0(a1.sum() + a2.sum())  # trigger objmode fallback

        compiled = jit(test)
        args = np.array([3]), np.array([4])
        self.assertEqual(test(*args), compiled(*args))
        self._ensure_objmode(compiled)

    def test_dynamic_func_objmode(self):
        """
        Test issue https://github.com/numba/numba/issues/3355
        """
        func_text = "def func():\n"
        func_text += " np.array([1,2,3])\n"
        loc_vars = {}
        custom_globals = {'np': np}
        # The function is built from source text at runtime with its own
        # globals; the test only requires that compiling and calling it
        # does not raise.
        exec(func_text, custom_globals, loc_vars)
        func = loc_vars['func']
        jitted = jit(forceobj=True)(func)
        jitted()


if __name__ == '__main__':
    unittest.main()