1import numpy as np
2import sys
3import traceback
4
5from numba.core.compiler import compile_isolated, Flags
6from numba import jit, njit
7from numba.core import types, errors
8from numba.tests.support import TestCase
9import unittest
10
# Compilation flags with "force_pyobject" set; passed by the *_objmode tests.
force_pyobj_flags = Flags()
force_pyobj_flags.set("force_pyobject")

# Default compilation flags; passed by the *_nopython tests.
no_pyobj_flags = Flags()
15
16
class MyError(Exception):
    """Plain user-defined exception raised by the use cases in this module."""
19
20
class OtherError(Exception):
    """Second user-defined exception, raised directly by ``outer_function``."""
23
24
class UDEArgsToSuper(Exception):
    """User-defined exception that forwards ``arg`` to ``Exception.__init__``.

    ``arg`` therefore ends up in ``self.args``; ``value0`` is stored only as
    an instance attribute.

    Two instances compare equal when both ``args`` and ``value0`` match.
    """

    def __init__(self, arg, value0):
        super(UDEArgsToSuper, self).__init__(arg)
        self.value0 = value0

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        # Every field has to match.  (Accumulating with ``|=`` onto an
        # initial ``True`` made any two instances compare equal.)
        return self.args == other.args and self.value0 == other.value0

    def __hash__(self):
        # Hash exactly the fields __eq__ compares so that equal instances
        # hash equally.  ``super(UDEArgsToSuper).__hash__()`` hashed a
        # throwaway unbound ``super`` object, i.e. an identity-based value
        # unrelated to the exception's contents.
        return hash((self.args, self.value0))
40
41
class UDENoArgSuper(Exception):
    """User-defined exception that passes nothing to ``Exception.__init__``.

    ``self.args`` is therefore empty; ``arg`` is kept in ``self.deferarg``
    and ``value0`` in ``self.value0``.

    Two instances compare equal when ``args``, ``deferarg`` and ``value0``
    all match.
    """

    def __init__(self, arg, value0):
        super(UDENoArgSuper, self).__init__()
        self.deferarg = arg
        self.value0 = value0

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        # Every field has to match.  (Accumulating with ``|=`` onto an
        # initial ``True`` made any two instances compare equal.)
        return (self.args == other.args
                and self.deferarg == other.deferarg
                and self.value0 == other.value0)

    def __hash__(self):
        # Hash exactly the fields __eq__ compares so that equal instances
        # hash equally.  ``super(UDENoArgSuper).__hash__()`` hashed a
        # throwaway unbound ``super`` object, i.e. an identity-based value
        # unrelated to the exception's contents.
        return hash((self.args, self.deferarg, self.value0))
60
61
def raise_class(exc):
    """Build a use case that raises exception *classes* (not instances).

    The returned ``raiser(i)`` raises ``exc`` for ``i == 1``, ``ValueError``
    for ``i == 2``, ``np.linalg.LinAlgError`` for ``i == 3`` and otherwise
    returns ``i`` unchanged.
    """
    def raiser(i):
        # Each guard raises, so flat ``if`` statements are sufficient.
        if i == 1:
            raise exc
        if i == 2:
            raise ValueError
        if i == 3:
            # The exception type is looked up on a module (issue #1624)
            raise np.linalg.LinAlgError
        return i

    return raiser
73
74
def raise_instance(exc, arg):
    """Build a use case that raises exception *instances* carrying ``arg``.

    The returned ``raiser(i)`` raises ``exc(arg, 1)`` for ``i == 1``,
    ``ValueError(arg, 2)`` for ``i == 2``, ``np.linalg.LinAlgError(arg, 3)``
    for ``i == 3`` and otherwise returns ``i`` unchanged.
    """
    def raiser(i):
        # Each guard raises, so flat ``if`` statements are sufficient.
        if i == 1:
            raise exc(arg, 1)
        if i == 2:
            raise ValueError(arg, 2)
        if i == 3:
            raise np.linalg.LinAlgError(arg, 3)
        return i

    return raiser
85
86
def reraise():
    """Use case consisting solely of a bare ``raise`` statement.

    The tests call it from inside an ``except`` block so that it re-raises
    the exception being handled there.
    """
    raise
89
90
def outer_function(inner):
    """Wrap ``inner`` in a caller that raises ``OtherError`` itself for 3.

    The returned ``outer(i)`` raises ``OtherError("bar", 3)`` when ``i == 3``
    and otherwise returns ``inner(i)``, letting anything ``inner`` raises
    propagate.
    """
    def outer(i):
        if i != 3:
            return inner(i)
        raise OtherError("bar", 3)

    return outer
97
98
def assert_usecase(i):
    """Use case for the ``assert`` statement: fails with message "bar"
    unless ``i == 1``."""
    assert i == 1, "bar"
101
102
def ude_bug_usecase():
    """Use case containing a deliberate bug in exception construction.

    ``UDEArgsToSuper.__init__`` requires ``arg`` and ``value0``; none are
    supplied here.  The tests expect a ``TypeError`` from this call.
    """
    raise UDEArgsToSuper()  # oops user forgot args to exception ctor
105
106
class TestRaising(TestCase):
    """Compare exceptions raised by compiled code with the pure-Python ones.

    Most scenarios are implemented once as a ``check_*`` helper taking the
    compilation flags and are then run twice, with ``force_pyobj_flags``
    (``*_objmode`` tests) and with ``no_pyobj_flags`` (``*_nopython`` tests).
    """

    def test_unituple_index_error(self):
        # Indexing ``a.shape`` in range must match Python; out of range must
        # raise IndexError with the usual tuple-index message.
        def pyfunc(a, i):
            return a.shape[i]

        cres = compile_isolated(pyfunc, (types.Array(types.int32, 1, 'A'),
                                         types.int32))

        cfunc = cres.entry_point
        a = np.empty(2, dtype=np.int32)

        self.assertEqual(cfunc(a, 0), pyfunc(a, 0))

        with self.assertRaises(IndexError) as cm:
            cfunc(a, 2)
        self.assertEqual(str(cm.exception), "tuple index out of range")

    def check_against_python(self, exec_mode, pyfunc, cfunc,
                             expected_error_class, *args):
        """Assert that ``cfunc(*args)`` fails the same way ``pyfunc(*args)`` does.

        Parameters
        ----------
        exec_mode
            Either ``force_pyobj_flags`` or ``no_pyobj_flags``.
        pyfunc, cfunc
            The pure-Python callable and its compiled counterpart.
        expected_error_class
            Exception class both callables must raise.
        *args
            Positional arguments passed to both callables.

        Checks performed: both raise ``expected_error_class`` with equal
        ``args``; the two exceptions compare equal when they are one of the
        user-defined exception types; and, for ``no_pyobj_flags`` only, the
        last two entries of the formatted tracebacks are identical.
        """
        # Use a unittest assertion rather than a bare ``assert`` so that the
        # check is not stripped when Python runs with -O.
        self.assertIn(exec_mode, (force_pyobj_flags, no_pyobj_flags))

        # invariant of mode, check the error class and args are the same
        with self.assertRaises(expected_error_class) as pyerr:
            pyfunc(*args)
        with self.assertRaises(expected_error_class) as jiterr:
            cfunc(*args)
        self.assertEqual(pyerr.exception.args, jiterr.exception.args)

        # special equality check for UDEs
        if isinstance(pyerr.exception, (UDEArgsToSuper, UDENoArgSuper)):
            self.assertEqual(pyerr.exception, jiterr.exception)

        # in npm check bottom of traceback matches as frame injection with
        # location info should ensure this
        if exec_mode is no_pyobj_flags:

            def bottom_frames(func):
                # The functions are invoked again because the traceback has
                # to be formatted while the exception is being handled.
                # We only care about the bottom two entries: the location
                # the error was raised at and the error itself.
                try:
                    func(*args)
                except Exception:
                    formatted = traceback.format_exception(*sys.exc_info())
                    return formatted[-2:]
                # Fail explicitly instead of hitting an unbound local later.
                self.fail("%r did not raise when called a second time"
                          % (func,))

            # Comparing the lists as a whole also catches a length mismatch,
            # which an element-wise zip() would silently truncate.
            self.assertEqual(bottom_frames(pyfunc), bottom_frames(cfunc))

    def check_raise_class(self, flags):
        """Compile ``raise_class(MyError)`` with ``flags`` and check each branch."""
        pyfunc = raise_class(MyError)
        cres = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cres.entry_point
        self.assertEqual(cfunc(0), 0)
        self.check_against_python(flags, pyfunc, cfunc, MyError, 1)
        self.check_against_python(flags, pyfunc, cfunc, ValueError, 2)
        # Use the public ``np.linalg.LinAlgError`` (the spelling the use case
        # raises) rather than going through the private ``np.linalg.linalg``
        # submodule.
        self.check_against_python(flags, pyfunc, cfunc,
                                  np.linalg.LinAlgError, 3)

    def test_raise_class_nopython(self):
        self.check_raise_class(flags=no_pyobj_flags)

    def test_raise_class_objmode(self):
        self.check_raise_class(flags=force_pyobj_flags)

    def check_raise_instance(self, flags):
        """Compile ``raise_instance`` for each exception class and check each
        branch."""
        for clazz in (MyError, UDEArgsToSuper, UDENoArgSuper):
            pyfunc = raise_instance(clazz, "some message")
            cres = compile_isolated(pyfunc, (types.int32,), flags=flags)
            cfunc = cres.entry_point

            self.assertEqual(cfunc(0), 0)
            self.check_against_python(flags, pyfunc, cfunc, clazz, 1)
            self.check_against_python(flags, pyfunc, cfunc, ValueError, 2)
            # Public spelling, matching what the use case raises.
            self.check_against_python(flags, pyfunc, cfunc,
                                      np.linalg.LinAlgError, 3)

    def test_raise_instance_objmode(self):
        self.check_raise_instance(flags=force_pyobj_flags)

    def test_raise_instance_nopython(self):
        self.check_raise_instance(flags=no_pyobj_flags)

    def check_raise_nested(self, flags, **jit_args):
        """
        Check exception propagation from nested functions.

        ``jit_args`` are forwarded to ``jit`` for both the inner and the
        outer function; ``flags`` selects the checks performed by
        ``check_against_python``.
        """
        for clazz in (MyError, UDEArgsToSuper, UDENoArgSuper):
            inner_pyfunc = raise_instance(clazz, "some message")
            pyfunc = outer_function(inner_pyfunc)
            inner_cfunc = jit(**jit_args)(inner_pyfunc)
            cfunc = jit(**jit_args)(outer_function(inner_cfunc))

            # 1 and 2 are raised by the inner function, 3 by the outer one.
            self.check_against_python(flags, pyfunc, cfunc, clazz, 1)
            self.check_against_python(flags, pyfunc, cfunc, ValueError, 2)
            self.check_against_python(flags, pyfunc, cfunc, OtherError, 3)

    def test_raise_nested_objmode(self):
        self.check_raise_nested(force_pyobj_flags, forceobj=True)

    def test_raise_nested_nopython(self):
        self.check_raise_nested(no_pyobj_flags, nopython=True)

    def check_reraise(self, flags):
        """Check that a compiled bare ``raise`` re-raises the exception that
        is being handled by its caller."""
        def raise_exc(exc):
            raise exc

        def gen_impl(op, err, fn):
            # ``op`` raises ``err``; ``fn`` is then called from inside the
            # handler and is expected to re-raise it.
            def impl():
                try:
                    op()
                except err:
                    fn()
            return impl

        pyfunc = reraise
        cres = compile_isolated(pyfunc, (), flags=flags)
        cfunc = cres.entry_point

        cases = [
            (lambda: raise_exc(ZeroDivisionError), ZeroDivisionError),
            (lambda: raise_exc(UDEArgsToSuper("msg", 1)), UDEArgsToSuper),
            (lambda: raise_exc(UDENoArgSuper("msg", 1)), UDENoArgSuper),
        ]
        for op, err in cases:
            pybased = gen_impl(op, err, pyfunc)
            cbased = gen_impl(op, err, cfunc)
            self.check_against_python(flags, pybased, cbased, err)

    def test_reraise_objmode(self):
        self.check_reraise(flags=force_pyobj_flags)

    def test_reraise_nopython(self):
        self.check_reraise(flags=no_pyobj_flags)

    def check_raise_invalid_class(self, cls, flags):
        """Raising ``cls`` (not an exception type) must produce the standard
        ``TypeError`` message when the compiled function is called."""
        pyfunc = raise_class(cls)
        cres = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cres.entry_point
        with self.assertRaises(TypeError) as cm:
            cfunc(1)
        self.assertEqual(str(cm.exception),
                         "exceptions must derive from BaseException")

    def test_raise_invalid_class_objmode(self):
        # Both a non-exception class and a non-class object are rejected.
        for invalid in (int, 1):
            self.check_raise_invalid_class(invalid, flags=force_pyobj_flags)

    def test_raise_invalid_class_nopython(self):
        # With no_pyobj_flags the invalid raise is expected to surface as an
        # UnsupportedError instead of the TypeError checked by the helper.
        msg = "Encountered unsupported constant type used for exception"
        for invalid in (int, 1):
            with self.assertRaises(errors.UnsupportedError) as raises:
                self.check_raise_invalid_class(invalid, flags=no_pyobj_flags)
            self.assertIn(msg, str(raises.exception))

    def test_raise_bare_string_nopython(self):
        # Raising a plain string is expected to be rejected by njit.
        @njit
        def foo():
            raise "illegal"
        msg = ("Directly raising a string constant as an exception is not "
               "supported")
        with self.assertRaises(errors.UnsupportedError) as raises:
            foo()
        self.assertIn(msg, str(raises.exception))

    def check_assert_statement(self, flags):
        """Compile ``assert_usecase`` and check the passing and the failing
        case."""
        pyfunc = assert_usecase
        cres = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cres.entry_point
        # Assertion holds: must not raise.
        cfunc(1)
        self.check_against_python(flags, pyfunc, cfunc, AssertionError, 2)

    def test_assert_statement_objmode(self):
        self.check_assert_statement(flags=force_pyobj_flags)

    def test_assert_statement_nopython(self):
        self.check_assert_statement(flags=no_pyobj_flags)

    def check_raise_from_exec_string(self, flags):
        """Check raising from functions created via ``exec`` on a source
        string (issue #3428)."""
        # issue #3428
        simple_raise = "def f(a):\n  raise exc('msg', 10)"
        assert_raise = "def f(a):\n  assert a != 1"
        cases = [(assert_raise, AssertionError),
                 (simple_raise, UDEArgsToSuper),
                 (simple_raise, UDENoArgSuper)]
        for f_text, exc in cases:
            loc = {}
            # ``exc`` is made visible to the generated function as a global.
            exec(f_text, {'exc': exc}, loc)
            pyfunc = loc['f']
            cres = compile_isolated(pyfunc, (types.int32,), flags=flags)
            cfunc = cres.entry_point
            self.check_against_python(flags, pyfunc, cfunc, exc, 1)

    def test_assert_from_exec_string_objmode(self):
        self.check_raise_from_exec_string(flags=force_pyobj_flags)

    def test_assert_from_exec_string_nopython(self):
        self.check_raise_from_exec_string(flags=no_pyobj_flags)

    def check_user_code_error_traceback(self, flags):
        """Check the error produced by a buggy exception constructor call."""
        # this test checks that if a user tries to compile code that contains
        # a bug in exception initialisation (e.g. missing arg) then this also
        # has a frame injected with the location information.
        pyfunc = ude_bug_usecase
        cres = compile_isolated(pyfunc, (), flags=flags)
        cfunc = cres.entry_point
        self.check_against_python(flags, pyfunc, cfunc, TypeError)

    def test_user_code_error_traceback_objmode(self):
        self.check_user_code_error_traceback(flags=force_pyobj_flags)

    def test_user_code_error_traceback_nopython(self):
        self.check_user_code_error_traceback(flags=no_pyobj_flags)
328
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
331