"""Tests for exceptions raised from Numba-compiled functions.

Each use case below is executed both as plain Python and as compiled code
(object mode and nopython mode); the exceptions raised by the two are then
compared by class, ``args`` and, in nopython mode, by the bottom of the
formatted traceback.
"""
import numpy as np
import sys
import traceback

from numba.core.compiler import compile_isolated, Flags
from numba import jit, njit
from numba.core import types, errors
from numba.tests.support import TestCase
import unittest

# Compilation flags with "force_pyobject" set (used by the *_objmode tests).
force_pyobj_flags = Flags()
force_pyobj_flags.set("force_pyobject")

# Default compilation flags (used by the *_nopython tests).
no_pyobj_flags = Flags()


class MyError(Exception):
    """Plain user-defined exception used by the raise use cases."""


class OtherError(Exception):
    """Second user-defined exception, raised by ``outer_function``."""


class UDEArgsToSuper(Exception):
    """User-defined exception that forwards ``arg`` to ``Exception.__init__``.

    ``arg`` therefore ends up in ``self.args``; ``value0`` is only stored as
    an attribute.
    """

    def __init__(self, arg, value0):
        super().__init__(arg)
        self.value0 = value0

    def __eq__(self, other):
        """Return True only if *other* is of this class and all fields match."""
        if not isinstance(other, self.__class__):
            return False
        # Every field must match: accumulating with ``|=`` from an initial
        # True would make any two instances compare equal.
        return self.args == other.args and self.value0 == other.value0

    def __hash__(self):
        # Hash exactly the fields compared in __eq__ so that equal instances
        # hash equal (requires those fields to be hashable).
        return hash((self.args, self.value0))


class UDENoArgSuper(Exception):
    """User-defined exception that passes nothing to ``Exception.__init__``.

    ``self.args`` is left empty; ``arg`` is kept in ``self.deferarg`` and
    ``value0`` in ``self.value0``.
    """

    def __init__(self, arg, value0):
        super().__init__()
        self.deferarg = arg
        self.value0 = value0

    def __eq__(self, other):
        """Return True only if *other* is of this class and all fields match."""
        if not isinstance(other, self.__class__):
            return False
        # Every field must match (see UDEArgsToSuper.__eq__).
        return (self.args == other.args
                and self.deferarg == other.deferarg
                and self.value0 == other.value0)

    def __hash__(self):
        # Hash exactly the fields compared in __eq__ so that equal instances
        # hash equal (requires those fields to be hashable).
        return hash((self.args, self.deferarg, self.value0))


def raise_class(exc):
    """Return a function raising an exception *class* chosen by its argument.

    The returned ``raiser(i)`` raises ``exc`` for 1, ``ValueError`` for 2,
    ``np.linalg.LinAlgError`` for 3, and otherwise returns ``i``.
    """
    def raiser(i):
        if i == 1:
            raise exc
        elif i == 2:
            raise ValueError
        elif i == 3:
            # The exception type is looked up on a module (issue #1624)
            raise np.linalg.LinAlgError
        return i
    return raiser


def raise_instance(exc, arg):
    """Return a function raising an exception *instance* chosen by its argument.

    The returned ``raiser(i)`` raises ``exc(arg, 1)`` for 1,
    ``ValueError(arg, 2)`` for 2, ``np.linalg.LinAlgError(arg, 3)`` for 3,
    and otherwise returns ``i``.
    """
    def raiser(i):
        if i == 1:
            raise exc(arg, 1)
        elif i == 2:
            raise ValueError(arg, 2)
        elif i == 3:
            raise np.linalg.LinAlgError(arg, 3)
        return i
    return raiser


def reraise():
    """Re-raise the exception currently being handled (bare ``raise``)."""
    raise


def outer_function(inner):
    """Return a function that raises ``OtherError`` for 3, else calls *inner*."""
    def outer(i):
        if i == 3:
            raise OtherError("bar", 3)
        return inner(i)
    return outer


def assert_usecase(i):
    """Fail an ``assert`` with message "bar" unless ``i == 1``."""
    assert i == 1, "bar"


def ude_bug_usecase():
    """Raise a user-defined exception whose constructor call is itself buggy."""
    raise UDEArgsToSuper()  # oops user forgot args to exception ctor


class TestRaising(TestCase):
    """Compare exceptions raised by compiled code against plain Python."""

    def test_unituple_index_error(self):
        def pyfunc(a, i):
            return a.shape[i]

        cres = compile_isolated(pyfunc, (types.Array(types.int32, 1, 'A'),
                                         types.int32))

        cfunc = cres.entry_point
        a = np.empty(2, dtype=np.int32)

        self.assertEqual(cfunc(a, 0), pyfunc(a, 0))

        with self.assertRaises(IndexError) as cm:
            cfunc(a, 2)
        self.assertEqual(str(cm.exception), "tuple index out of range")

    def _bottom_frames(self, func, args):
        """Call ``func(*args)`` and return the last two formatted traceback
        entries (the raising location and the exception line).

        Fails the test if the call does not raise.
        """
        try:
            func(*args)
        except Exception:
            return traceback.format_exception(*sys.exc_info())[-2:]
        self.fail("expected %r to raise for args %r" % (func, args))

    def check_against_python(self, exec_mode, pyfunc, cfunc,
                             expected_error_class, *args):
        """Check that ``cfunc(*args)`` raises the same error as ``pyfunc(*args)``.

        *exec_mode* must be one of the module-level flag objects. The error
        class and ``args`` are compared in both modes; when *exec_mode* is
        ``no_pyobj_flags`` the bottom two traceback entries must match too.
        """
        self.assertIn(exec_mode, (force_pyobj_flags, no_pyobj_flags))

        # invariant of mode, check the error class and args are the same
        with self.assertRaises(expected_error_class) as pyerr:
            pyfunc(*args)
        with self.assertRaises(expected_error_class) as jiterr:
            cfunc(*args)
        self.assertEqual(pyerr.exception.args, jiterr.exception.args)

        # special equality check for UDEs
        if isinstance(pyerr.exception, (UDEArgsToSuper, UDENoArgSuper)):
            self.assertTrue(pyerr.exception == jiterr.exception)

        # in npm check bottom of traceback matches as frame injection with
        # location info should ensure this
        if exec_mode is no_pyobj_flags:
            # we only care about the bottom two frames, the error and the
            # location it was raised.
            expected_frames = self._bottom_frames(pyfunc, args)
            got_frames = self._bottom_frames(cfunc, args)

            # check exception and the injected frame are the same
            self.assertEqual(expected_frames, got_frames)

    def check_raise_class(self, flags):
        """Check raising bare exception classes under *flags*."""
        pyfunc = raise_class(MyError)
        cres = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cres.entry_point
        self.assertEqual(cfunc(0), 0)
        self.check_against_python(flags, pyfunc, cfunc, MyError, 1)
        self.check_against_python(flags, pyfunc, cfunc, ValueError, 2)
        # Same public name the use case raises; ``np.linalg.linalg`` is a
        # private module path.
        self.check_against_python(flags, pyfunc, cfunc,
                                  np.linalg.LinAlgError, 3)

    def test_raise_class_nopython(self):
        self.check_raise_class(flags=no_pyobj_flags)

    def test_raise_class_objmode(self):
        self.check_raise_class(flags=force_pyobj_flags)

    def check_raise_instance(self, flags):
        """Check raising exception instances with arguments under *flags*."""
        for clazz in [MyError, UDEArgsToSuper, UDENoArgSuper]:
            pyfunc = raise_instance(clazz, "some message")
            cres = compile_isolated(pyfunc, (types.int32,), flags=flags)
            cfunc = cres.entry_point

            self.assertEqual(cfunc(0), 0)
            self.check_against_python(flags, pyfunc, cfunc, clazz, 1)
            self.check_against_python(flags, pyfunc, cfunc, ValueError, 2)
            self.check_against_python(flags, pyfunc, cfunc,
                                      np.linalg.LinAlgError, 3)

    def test_raise_instance_objmode(self):
        self.check_raise_instance(flags=force_pyobj_flags)

    def test_raise_instance_nopython(self):
        self.check_raise_instance(flags=no_pyobj_flags)

    def check_raise_nested(self, flags, **jit_args):
        """
        Check exception propagation from nested functions.
        """
        for clazz in [MyError, UDEArgsToSuper, UDENoArgSuper]:
            inner_pyfunc = raise_instance(clazz, "some message")
            pyfunc = outer_function(inner_pyfunc)
            inner_cfunc = jit(**jit_args)(inner_pyfunc)
            cfunc = jit(**jit_args)(outer_function(inner_cfunc))

            self.check_against_python(flags, pyfunc, cfunc, clazz, 1)
            self.check_against_python(flags, pyfunc, cfunc, ValueError, 2)
            self.check_against_python(flags, pyfunc, cfunc, OtherError, 3)

    def test_raise_nested_objmode(self):
        self.check_raise_nested(force_pyobj_flags, forceobj=True)

    def test_raise_nested_nopython(self):
        self.check_raise_nested(no_pyobj_flags, nopython=True)

    def check_reraise(self, flags):
        """Check that a bare ``raise`` re-raises the handled exception."""
        def raise_exc(exc):
            raise exc

        def gen_impl(fn, op, err):
            # ``op`` and ``err`` are passed in explicitly so each ``impl`` is
            # bound to its own pair rather than to the loop variables.
            def impl():
                try:
                    op()
                except err:
                    fn()
            return impl

        pyfunc = reraise
        cres = compile_isolated(pyfunc, (), flags=flags)
        cfunc = cres.entry_point
        cases = [
            (lambda: raise_exc(ZeroDivisionError), ZeroDivisionError),
            (lambda: raise_exc(UDEArgsToSuper("msg", 1)), UDEArgsToSuper),
            (lambda: raise_exc(UDENoArgSuper("msg", 1)), UDENoArgSuper),
        ]
        for op, err in cases:
            pybased = gen_impl(pyfunc, op, err)
            cbased = gen_impl(cfunc, op, err)
            self.check_against_python(flags, pybased, cbased, err)

    def test_reraise_objmode(self):
        self.check_reraise(flags=force_pyobj_flags)

    def test_reraise_nopython(self):
        self.check_reraise(flags=no_pyobj_flags)

    def check_raise_invalid_class(self, cls, flags):
        """Check that raising the non-exception *cls* gives a ``TypeError``."""
        pyfunc = raise_class(cls)
        cres = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cres.entry_point
        with self.assertRaises(TypeError) as cm:
            cfunc(1)
        self.assertEqual(str(cm.exception),
                         "exceptions must derive from BaseException")

    def test_raise_invalid_class_objmode(self):
        self.check_raise_invalid_class(int, flags=force_pyobj_flags)
        self.check_raise_invalid_class(1, flags=force_pyobj_flags)

    def test_raise_invalid_class_nopython(self):
        msg = "Encountered unsupported constant type used for exception"
        with self.assertRaises(errors.UnsupportedError) as raises:
            self.check_raise_invalid_class(int, flags=no_pyobj_flags)
        self.assertIn(msg, str(raises.exception))
        with self.assertRaises(errors.UnsupportedError) as raises:
            self.check_raise_invalid_class(1, flags=no_pyobj_flags)
        self.assertIn(msg, str(raises.exception))

    def test_raise_bare_string_nopython(self):
        @njit
        def foo():
            raise "illegal"
        msg = ("Directly raising a string constant as an exception is not "
               "supported")
        with self.assertRaises(errors.UnsupportedError) as raises:
            foo()
        self.assertIn(msg, str(raises.exception))

    def check_assert_statement(self, flags):
        """Check that a failing ``assert`` matches Python under *flags*."""
        pyfunc = assert_usecase
        cres = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cres.entry_point
        cfunc(1)
        self.check_against_python(flags, pyfunc, cfunc, AssertionError, 2)

    def test_assert_statement_objmode(self):
        self.check_assert_statement(flags=force_pyobj_flags)

    def test_assert_statement_nopython(self):
        self.check_assert_statement(flags=no_pyobj_flags)

    def check_raise_from_exec_string(self, flags):
        """Check raising from functions whose source came from ``exec``."""
        # issue #3428
        simple_raise = "def f(a):\n raise exc('msg', 10)"
        assert_raise = "def f(a):\n assert a != 1"
        for f_text, exc in [(assert_raise, AssertionError),
                            (simple_raise, UDEArgsToSuper),
                            (simple_raise, UDENoArgSuper)]:
            loc = {}
            exec(f_text, {'exc': exc}, loc)
            pyfunc = loc['f']
            cres = compile_isolated(pyfunc, (types.int32,), flags=flags)
            cfunc = cres.entry_point
            self.check_against_python(flags, pyfunc, cfunc, exc, 1)

    def test_assert_from_exec_string_objmode(self):
        self.check_raise_from_exec_string(flags=force_pyobj_flags)

    def test_assert_from_exec_string_nopython(self):
        self.check_raise_from_exec_string(flags=no_pyobj_flags)

    def check_user_code_error_traceback(self, flags):
        # this test checks that if a user tries to compile code that contains
        # a bug in exception initialisation (e.g. missing arg) then this also
        # has a frame injected with the location information.
        pyfunc = ude_bug_usecase
        cres = compile_isolated(pyfunc, (), flags=flags)
        cfunc = cres.entry_point
        self.check_against_python(flags, pyfunc, cfunc, TypeError)

    def test_user_code_error_traceback_objmode(self):
        self.check_user_code_error_traceback(flags=force_pyobj_flags)

    def test_user_code_error_traceback_nopython(self):
        self.check_user_code_error_traceback(flags=no_pyobj_flags)


if __name__ == '__main__':
    unittest.main()