1import logging
2import pytest
3
4from traceback import (
5    extract_tb,
6    print_exception,
7    format_exception,
8)
9from traceback import _cause_message  # type: ignore
10import sys
11import os
12import re
13from pathlib import Path
14import subprocess
15
16from .tutil import slow
17
18from .._multierror import MultiError, concat_tb
19from ..._core import open_nursery
20
21
class NotHashableException(Exception):
    """Exception that compares equal by ``code`` and is therefore unhashable.

    Tests use it to exercise code paths (nurseries, ``MultiError.filter``,
    traceback formatting) that must cope with exceptions which cannot be
    placed in a set or used as a dict key.
    """

    code = None
    # Written out explicitly for readability: Python already sets __hash__
    # to None for any class that defines __eq__ without __hash__.
    __hash__ = None

    def __init__(self, code):
        # No positional args are forwarded, so ``self.args`` stays empty.
        super().__init__()
        self.code = code

    def __eq__(self, other):
        if isinstance(other, NotHashableException):
            return self.code == other.code
        return False
33
34
async def raise_nothashable(code):
    """Coroutine that immediately raises ``NotHashableException(code)``.

    Started as a nursery task so that an unhashable exception escapes it.
    """
    exc = NotHashableException(code)
    raise exc
37
38
def raiser1():
    """First frame of the raiser1 -> raiser1_2 -> raiser1_3 chain.

    The chain ends in ValueError("raiser1_string"). Tests match on these
    frame names (e.g. "in raiser1_3") in formatted tracebacks.
    """
    raiser1_2()
41
42
def raiser1_2():
    """Middle frame of the raiser1 chain; only forwards to raiser1_3."""
    raiser1_3()
45
46
def raiser1_3():
    """Innermost frame of the raiser1 chain: raises ValueError("raiser1_string")."""
    raise ValueError("raiser1_string")
49
50
def raiser2():
    """First frame of the raiser2 -> raiser2_2 chain, ending in KeyError."""
    raiser2_2()
53
54
def raiser2_2():
    """Innermost frame of the raiser2 chain: raises KeyError("raiser2_string")."""
    raise KeyError("raiser2_string")
57
58
def raiser3():
    """Single-frame raiser: raises a bare NameError (no message)."""
    raise NameError
61
62
def get_exc(raiser):
    """Call *raiser* and return the ``Exception`` it raises.

    The returned exception's traceback begins at this function's frame,
    which is why some tests account for an extra ``get_exc`` frame.

    Raises:
        AssertionError: if *raiser* returns normally. Failing here gives a
            clear message instead of handing ``None`` to callers that then
            break on ``None.__traceback__``.
        BaseException: anything that is not an ``Exception`` subclass
            propagates unchanged.
    """
    try:
        raiser()
    except Exception as exc:
        return exc
    raise AssertionError("{!r} did not raise an exception".format(raiser))
68
69
def get_tb(raiser):
    """Return the traceback attached to the exception raised by *raiser*."""
    exc = get_exc(raiser)
    return exc.__traceback__
72
73
def einfo(exc):
    """Return the ``(type, value, traceback)`` triple describing *exc*.

    The triple matches the leading positional arguments of
    ``traceback.format_exception`` / ``print_exception``, so callers can
    write ``format_exception(*einfo(exc))``.
    """
    exc_type = type(exc)
    tb = exc.__traceback__
    return exc_type, exc, tb
76
77
def test_concat_tb():
    """concat_tb(a, b) yields a traceback with a's frames followed by b's.

    Also checks the None degenerate cases and that the input tracebacks are
    left untouched.
    """
    tb1 = get_tb(raiser1)
    tb2 = get_tb(raiser2)

    # extract_tb returns a StackSummary: a list of FrameSummary entries
    # (filename, lineno, function name, source line).
    # https://docs.python.org/3/library/traceback.html#traceback.extract_tb
    entries1 = extract_tb(tb1)
    entries2 = extract_tb(tb2)

    tb12 = concat_tb(tb1, tb2)
    assert extract_tb(tb12) == entries1 + entries2

    tb21 = concat_tb(tb2, tb1)
    assert extract_tb(tb21) == entries2 + entries1

    # Check degenerate cases
    assert extract_tb(concat_tb(None, tb1)) == entries1
    assert extract_tb(concat_tb(tb1, None)) == entries1
    assert concat_tb(None, None) is None

    # Make sure the original tracebacks didn't get mutated by mistake. Inspect
    # the very objects that were passed to concat_tb: if it had re-linked
    # tb1.tb_next onto tb2 (or vice versa), these summaries would have grown.
    assert extract_tb(tb1) == entries1
    assert extract_tb(tb2) == entries2
102
103
def test_MultiError():
    """MultiError passes lone exceptions through and validates its argument."""
    first = get_exc(raiser1)
    second = get_exc(raiser2)

    # A single-element list yields that exception itself, not a wrapper.
    assert MultiError([first]) is first

    combined = MultiError([first, second])
    assert combined.exceptions == [first, second]
    for render in (str, repr):
        assert "ValueError" in render(combined)

    # A non-list argument, and a list holding an exception class rather
    # than an instance, are both rejected with TypeError.
    for bad_arg in (object(), [KeyError(), ValueError]):
        with pytest.raises(TypeError):
            MultiError(bad_arg)
118
119
def test_MultiErrorOfSingleMultiError():
    # Wrapping a single MultiError hands back that same MultiError from
    # __new__. Python then calls __init__ on it again; this checks that
    # doing so neither recurses badly nor alters its .exceptions.
    inner_excs = [KeyError(), ValueError()]
    inner = MultiError(inner_excs)
    outer = MultiError([inner])
    assert outer == inner
    assert outer.exceptions == inner_excs
128
129
async def test_MultiErrorNotHashable():
    """Unhashable exceptions from two failing tasks still yield a MultiError."""
    left = NotHashableException(42)
    right = NotHashableException(4242)
    other = ValueError()
    # Sanity-check the comparison semantics of the helper class.
    assert left != right
    assert left != other

    with pytest.raises(MultiError):
        async with open_nursery() as nursery:
            for code in (42, 4242):
                nursery.start_soon(raise_nothashable, code)
141
142
def test_MultiError_filter_NotHashable():
    excs = MultiError([NotHashableException(42), ValueError()])

    def handle_ValueError(exc):
        # Swallow ValueErrors; pass everything else through untouched.
        return None if isinstance(exc, ValueError) else exc

    # Only the unhashable exception survives, and filter() returns it on its
    # own rather than inside a MultiError.
    filtered_excs = MultiError.filter(handle_ValueError, excs)
    assert isinstance(filtered_excs, NotHashableException)
154
155
def test_traceback_recursion():
    """Formatting must terminate when an exception appears in its own cause."""
    exc1 = RuntimeError()
    exc2 = KeyError()
    exc3 = NotHashableException(42)
    # exc1.__cause__ is a MultiError that contains exc1 itself, forming a
    # cycle. The 'seen' set used while formatting is what prevents infinite
    # recursion here.
    exc1.__cause__ = MultiError([exc1, exc2, exc3])
    if sys.version_info >= (3, 6, 4):
        format_exception(*einfo(exc1))
    else:
        # traceback.TracebackException before 3.6.4 cannot handle
        # unhashable exceptions and raises TypeError instead.
        with pytest.raises(TypeError):
            format_exception(*einfo(exc1))
171
172
def make_tree():
    """Build a fresh two-level MultiError tree for the filter/format tests.

    Every call creates brand-new exception objects, so tests can call it
    again to get an untouched reference tree to compare against.
    """
    # Returns an object like:
    #   MultiError([
    #     MultiError([
    #       ValueError,
    #       KeyError,
    #     ]),
    #     NameError,
    #   ])
    # where all exceptions except the root have a non-trivial traceback.
    exc1 = get_exc(raiser1)
    exc2 = get_exc(raiser2)
    exc3 = get_exc(raiser3)

    # Give m12 a non-trivial traceback
    try:
        raise MultiError([exc1, exc2])
    except BaseException as m12:
        # The root is constructed but never raised, so it has no traceback;
        # m12's traceback is the single make_tree frame from the raise above.
        return MultiError([m12, exc3])
192
193
def assert_tree_eq(m1, m2):
    """Recursively assert that two exception trees have the same shape.

    Compares exception types, extracted tracebacks, ``__cause__`` and
    ``__context__`` chains, and (for MultiErrors) every child pairwise.
    A ``None`` on either side must be matched by ``None`` on the other.
    """
    if m1 is None or m2 is None:
        # A missing link must be missing on both sides.
        assert m1 is m2
        return
    assert type(m1) is type(m2)
    assert extract_tb(m1.__traceback__) == extract_tb(m2.__traceback__)
    for link in ("__cause__", "__context__"):
        assert_tree_eq(getattr(m1, link), getattr(m2, link))
    if not isinstance(m1, MultiError):
        return
    children1, children2 = m1.exceptions, m2.exceptions
    assert len(children1) == len(children2)
    for child1, child2 in zip(children1, children2):
        assert_tree_eq(child1, child2)
206
207
def test_MultiError_filter():
    """Exercise MultiError.filter: identity handlers, dropping and replacing
    leaves, promotion of a surviving branch, and filtering everything away."""

    def null_handler(exc):
        # Identity handler: every leaf is kept as-is.
        return exc

    m = make_tree()
    assert_tree_eq(m, m)
    # With an identity handler, filter returns the very same object...
    assert MultiError.filter(null_handler, m) is m
    # ...and leaves it structurally unchanged.
    assert_tree_eq(m, make_tree())

    # Make sure we don't pick up any detritus if run in a context where
    # implicit exception chaining would like to kick in
    m = make_tree()
    try:
        raise ValueError
    except ValueError:
        assert MultiError.filter(null_handler, m) is m
    assert_tree_eq(m, make_tree())

    def simple_filter(exc):
        # Drop ValueErrors, replace KeyErrors with a new RuntimeError, and
        # keep everything else.
        if isinstance(exc, ValueError):
            return None
        if isinstance(exc, KeyError):
            return RuntimeError()
        return exc

    new_m = MultiError.filter(simple_filter, make_tree())
    assert isinstance(new_m, MultiError)
    assert len(new_m.exceptions) == 2
    # was: [[ValueError, KeyError], NameError]
    # ValueError disappeared & KeyError became RuntimeError, so now:
    assert isinstance(new_m.exceptions[0], RuntimeError)
    assert isinstance(new_m.exceptions[1], NameError)

    # implicit chaining:
    assert isinstance(new_m.exceptions[0].__context__, KeyError)

    # also, the traceback on the KeyError incorporates what used to be the
    # traceback on its parent MultiError
    orig = make_tree()
    # make sure we have the right path
    assert isinstance(orig.exceptions[0].exceptions[1], KeyError)
    # get original traceback summary
    orig_extracted = (
        extract_tb(orig.__traceback__)
        + extract_tb(orig.exceptions[0].__traceback__)
        + extract_tb(orig.exceptions[0].exceptions[1].__traceback__)
    )

    def p(exc):
        # Debug aid: printed output is shown by pytest only on failure.
        print_exception(type(exc), exc, exc.__traceback__)

    p(orig)
    p(orig.exceptions[0])
    p(orig.exceptions[0].exceptions[1])
    p(new_m.exceptions[0].__context__)
    # compare to the new path
    assert new_m.__traceback__ is None
    new_extracted = extract_tb(new_m.exceptions[0].__context__.__traceback__)
    assert orig_extracted == new_extracted

    # check preserving partial tree
    def filter_NameError(exc):
        if isinstance(exc, NameError):
            return None
        return exc

    m = make_tree()
    new_m = MultiError.filter(filter_NameError, m)
    # with the NameError gone, the other branch gets promoted
    assert new_m is m.exceptions[0]

    # check fully handling everything
    def filter_all(exc):
        return None

    # When every leaf is handled, filter returns None rather than an empty
    # MultiError.
    assert MultiError.filter(filter_all, make_tree()) is None
284
285
def test_MultiError_catch():
    """Exercise the MultiError.catch context manager: no-op, pass-through,
    swallow-all, filtering, and preservation of __cause__ / __context__ /
    __suppress_context__ on the re-raised exception."""
    # No exception to catch

    def noop(_):
        # Never called: the with-block below raises nothing.
        pass  # pragma: no cover

    with MultiError.catch(noop):
        pass

    # Simple pass-through of all exceptions
    m = make_tree()
    with pytest.raises(MultiError) as excinfo:
        with MultiError.catch(lambda exc: exc):
            raise m
    assert excinfo.value is m
    # Should be unchanged, except that we added a traceback frame by raising
    # it here
    assert m.__traceback__ is not None
    assert m.__traceback__.tb_frame.f_code.co_name == "test_MultiError_catch"
    assert m.__traceback__.tb_next is None
    # Drop the frame added by raising, so m can be compared with a fresh
    # (never-raised) tree.
    m.__traceback__ = None
    assert_tree_eq(m, make_tree())

    # Swallows everything
    with MultiError.catch(lambda _: None):
        raise make_tree()

    def simple_filter(exc):
        # Drop ValueErrors, replace KeyErrors with a new RuntimeError, and
        # keep everything else.
        if isinstance(exc, ValueError):
            return None
        if isinstance(exc, KeyError):
            return RuntimeError()
        return exc

    with pytest.raises(MultiError) as excinfo:
        with MultiError.catch(simple_filter):
            raise make_tree()
    new_m = excinfo.value
    assert isinstance(new_m, MultiError)
    assert len(new_m.exceptions) == 2
    # was: [[ValueError, KeyError], NameError]
    # ValueError disappeared & KeyError became RuntimeError, so now:
    assert isinstance(new_m.exceptions[0], RuntimeError)
    assert isinstance(new_m.exceptions[1], NameError)
    # Make sure that Python did not successfully attach the old MultiError to
    # our new MultiError's __context__
    assert not new_m.__suppress_context__
    assert new_m.__context__ is None

    # check preservation of __cause__ and __context__
    v = ValueError()
    v.__cause__ = KeyError()
    with pytest.raises(ValueError) as excinfo:
        with MultiError.catch(lambda exc: exc):
            raise v
    assert isinstance(excinfo.value.__cause__, KeyError)

    v = ValueError()
    context = KeyError()
    v.__context__ = context
    with pytest.raises(ValueError) as excinfo:
        with MultiError.catch(lambda exc: exc):
            raise v
    assert excinfo.value.__context__ is context
    assert not excinfo.value.__suppress_context__

    # Filtering out the distractor leaves only v, which must come back out
    # with its original __context__ and __suppress_context__ intact.
    for suppress_context in [True, False]:
        v = ValueError()
        context = KeyError()
        v.__context__ = context
        v.__suppress_context__ = suppress_context
        distractor = RuntimeError()
        with pytest.raises(ValueError) as excinfo:

            def catch_RuntimeError(exc):
                if isinstance(exc, RuntimeError):
                    return None
                else:
                    return exc

            with MultiError.catch(catch_RuntimeError):
                raise MultiError([v, distractor])
        assert excinfo.value.__context__ is context
        assert excinfo.value.__suppress_context__ == suppress_context
370
371
def assert_match_in_seq(pattern_list, string):
    """Assert that every regex in *pattern_list* matches *string*, in order.

    Each pattern is searched for starting where the previous match ended,
    so the matches must occur sequentially and without overlapping.
    Progress is printed so a failing pattern is visible in captured output.
    """
    position = 0
    print("looking for pattern matches...")
    for pattern in pattern_list:
        print("checking pattern:", pattern)
        found = re.compile(pattern).search(string, position)
        assert found is not None
        position = found.end()
381
382
def test_assert_match_in_seq():
    # Patterns that appear in the given order are accepted...
    for patterns, text in [
        (["a", "b"], "xx a xx b xx"),
        (["b", "a"], "xx b xx a xx"),
    ]:
        assert_match_in_seq(patterns, text)
    # ...but the same patterns out of order are rejected.
    with pytest.raises(AssertionError):
        assert_match_in_seq(["a", "b"], "xx b xx a xx")
388
389
def test_format_exception():
    """Check traceback.format_exception output for single exceptions (cause,
    context, suppression, chain=False, limit, self-referential cause) and
    for MultiError trees, including an exception shared by two branches."""
    exc = get_exc(raiser1)
    formatted = "".join(format_exception(*einfo(exc)))
    assert "raiser1_string" in formatted
    assert "in raiser1_3" in formatted
    assert "raiser2_string" not in formatted
    assert "in raiser2_2" not in formatted
    assert "direct cause" not in formatted
    assert "During handling" not in formatted

    exc = get_exc(raiser1)
    exc.__cause__ = get_exc(raiser2)
    formatted = "".join(format_exception(*einfo(exc)))
    assert "raiser1_string" in formatted
    assert "in raiser1_3" in formatted
    assert "raiser2_string" in formatted
    assert "in raiser2_2" in formatted
    assert "direct cause" in formatted
    assert "During handling" not in formatted
    # ensure cause included
    assert _cause_message in formatted

    exc = get_exc(raiser1)
    exc.__context__ = get_exc(raiser2)
    formatted = "".join(format_exception(*einfo(exc)))
    assert "raiser1_string" in formatted
    assert "in raiser1_3" in formatted
    assert "raiser2_string" in formatted
    assert "in raiser2_2" in formatted
    assert "direct cause" not in formatted
    assert "During handling" in formatted

    exc.__suppress_context__ = True
    formatted = "".join(format_exception(*einfo(exc)))
    assert "raiser1_string" in formatted
    assert "in raiser1_3" in formatted
    assert "raiser2_string" not in formatted
    assert "in raiser2_2" not in formatted
    assert "direct cause" not in formatted
    assert "During handling" not in formatted

    # chain=False
    exc = get_exc(raiser1)
    exc.__context__ = get_exc(raiser2)
    formatted = "".join(format_exception(*einfo(exc), chain=False))
    assert "raiser1_string" in formatted
    assert "in raiser1_3" in formatted
    assert "raiser2_string" not in formatted
    assert "in raiser2_2" not in formatted
    assert "direct cause" not in formatted
    assert "During handling" not in formatted

    # limit
    exc = get_exc(raiser1)
    exc.__context__ = get_exc(raiser2)
    # get_exc adds a frame that counts against the limit, so limit=2 means we
    # get 1 deep into the raiser stack
    formatted = "".join(format_exception(*einfo(exc), limit=2))
    print(formatted)
    assert "raiser1_string" in formatted
    assert "in raiser1" in formatted
    assert "in raiser1_2" not in formatted
    assert "raiser2_string" in formatted
    assert "in raiser2" in formatted
    assert "in raiser2_2" not in formatted

    exc = get_exc(raiser1)
    exc.__context__ = get_exc(raiser2)
    formatted = "".join(format_exception(*einfo(exc), limit=1))
    print(formatted)
    assert "raiser1_string" in formatted
    assert "in raiser1" not in formatted
    assert "raiser2_string" in formatted
    assert "in raiser2" not in formatted

    # handles loops
    exc = get_exc(raiser1)
    exc.__cause__ = exc
    formatted = "".join(format_exception(*einfo(exc)))
    assert "raiser1_string" in formatted
    assert "in raiser1_3" in formatted
    assert "raiser2_string" not in formatted
    assert "in raiser2_2" not in formatted
    # ensure duplicate exception is not included as cause
    assert _cause_message not in formatted

    # MultiError
    formatted = "".join(format_exception(*einfo(make_tree())))
    print(formatted)

    assert_match_in_seq(
        [
            # Outer exception is MultiError
            r"MultiError:",
            # First embedded exception is the embedded MultiError
            r"\nDetails of embedded exception 1",
            # Which has a single stack frame from make_tree raising it
            r"in make_tree",
            # Then it has two embedded exceptions
            r"  Details of embedded exception 1",
            r"in raiser1_2",
            # ValueError's str() is just its message, so it has no quotes
            r"ValueError: raiser1_string",
            r"  Details of embedded exception 2",
            r"in raiser2_2",
            # KeyError's str() is the repr() of its key, hence the quotes
            r"KeyError: 'raiser2_string'",
            # And finally the NameError, which is a sibling of the embedded
            # MultiError
            r"\nDetails of embedded exception 2:",
            r"in raiser3",
            r"NameError",
        ],
        formatted,
    )

    # Prints duplicate exceptions in sub-exceptions
    exc1 = get_exc(raiser1)

    def raise1_raiser1():
        # Raising while handling exc1 makes exc1 the new error's __context__.
        try:
            raise exc1
        except ValueError:
            raise ValueError("foo")

    def raise2_raiser1():
        # Same shared exc1 as context, so it must be printed in both branches.
        try:
            raise exc1
        except ValueError:
            raise KeyError("bar")

    exc2 = get_exc(raise1_raiser1)
    exc3 = get_exc(raise2_raiser1)

    try:
        raise MultiError([exc2, exc3])
    except MultiError as e:
        exc = e

    formatted = "".join(format_exception(*einfo(exc)))
    print(formatted)

    assert_match_in_seq(
        [
            r"Traceback",
            # Outer exception is MultiError
            r"MultiError:",
            # First embedded exception is ValueError("foo"), whose __context__
            # is raiser1's ValueError
            r"\nDetails of embedded exception 1",
            # Print details of exc1
            r"  Traceback",
            r"in get_exc",
            r"in raiser1",
            r"ValueError: raiser1_string",
            # Print details of exc2
            r"\n  During handling of the above exception, another exception occurred:",
            r"  Traceback",
            r"in get_exc",
            r"in raise1_raiser1",
            r"  ValueError: foo",
            # Second embedded exception is KeyError("bar"), whose __context__
            # is the same raiser1 ValueError
            r"\nDetails of embedded exception 2",
            # Print details of exc1 again
            r"  Traceback",
            r"in get_exc",
            r"in raiser1",
            r"ValueError: raiser1_string",
            # Print details of exc3
            r"\n  During handling of the above exception, another exception occurred:",
            r"  Traceback",
            r"in get_exc",
            r"in raise2_raiser1",
            r"  KeyError: 'bar'",
        ],
        formatted,
    )
566
567
def test_logging(caplog):
    """logger.exception() on a MultiError records the message plus the full
    formatted MultiError traceback."""
    m = MultiError([get_exc(raiser1), get_exc(raiser2)])

    message = "test test test"
    try:
        raise m
    except MultiError as exc:
        logging.getLogger().exception(message)
        # The log record must contain exactly what format_exception produces.
        formatted = "".join(format_exception(*einfo(exc)))
        for needle in (message, formatted):
            assert needle in caplog.text
583
584
def run_script(name, use_ipython=False):
    """Run a helper from ``test_multierror_scripts`` in a fresh interpreter.

    The child's PYTHONPATH is prefixed with the scripts directory and the
    directory containing the imported ``trio`` package, ahead of any
    inherited entries.

    Args:
        name: file name of the script inside ``test_multierror_scripts``.
        use_ipython: if true, pass the script's source (plus ``exit()``) to
            ``python -m IPython --quick`` instead of running the file directly.

    Returns:
        The ``subprocess.CompletedProcess``; ``stdout`` holds stdout and
        stderr merged, as bytes.
    """
    import trio

    trio_path = Path(trio.__file__).parent.parent
    script_path = Path(__file__).parent / "test_multierror_scripts" / name

    env = dict(os.environ)
    print("parent PYTHONPATH:", env.get("PYTHONPATH"))
    if "PYTHONPATH" in env:  # pragma: no cover
        inherited = env["PYTHONPATH"].split(os.pathsep)
    else:
        inherited = []
    # Earlier PYTHONPATH entries win: scripts dir first, then trio's parent.
    search_path = [str(script_path.parent), str(trio_path)] + inherited
    env["PYTHONPATH"] = os.pathsep.join(search_path)
    print("subprocess PYTHONPATH:", env.get("PYTHONPATH"))

    if use_ipython:
        code_to_run = "\n".join([script_path.read_text(), "exit()"])
        cmd = [
            sys.executable,
            "-u",
            "-m",
            "IPython",
            # no startup files
            "--quick",
            "--TerminalIPythonApp.code_to_run=" + code_to_run,
        ]
    else:
        cmd = [sys.executable, "-u", str(script_path)]
    print("running:", cmd)
    completed = subprocess.run(
        cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    )
    print("process output:")
    print(completed.stdout.decode("utf-8"))
    return completed
623
624
def check_simple_excepthook(completed):
    """Assert that *completed*'s output shows the MultiError from the simple
    excepthook script: a ValueError from exc1_fn, then a KeyError from exc2_fn."""
    output = completed.stdout.decode("utf-8")
    expected_sequence = [
        "in <module>",
        "MultiError",
        "Details of embedded exception 1",
        "in exc1_fn",
        "ValueError",
        "Details of embedded exception 2",
        "in exc2_fn",
        "KeyError",
    ]
    assert_match_in_seq(expected_sequence, output)
639
640
def test_simple_excepthook():
    """An uncaught MultiError in a plain script is rendered in full."""
    check_simple_excepthook(run_script("simple_excepthook.py"))
644
645
def test_custom_excepthook():
    # Check that user-defined excepthooks aren't overridden
    output = run_script("custom_excepthook.py").stdout.decode("utf-8")
    expected_sequence = [
        # The warning
        "RuntimeWarning",
        "already have a custom",
        # The message printed by the custom hook, proving we didn't
        # override it
        "custom running!",
        # The MultiError
        "MultiError:",
    ]
    assert_match_in_seq(expected_sequence, output)
662
663
# This warning is triggered by ipython 7.5.0 on python 3.8. The filter is
# installed at module import time, before IPython is imported below.
import warnings

warnings.filterwarnings(
    "ignore",
    message='.*"@coroutine" decorator is deprecated',
    category=DeprecationWarning,
    module="IPython.*",
)
# Probe whether IPython is importable in this interpreter; the IPython-based
# tests run sys.executable, so they need it installed here.
try:
    import IPython
except ImportError:  # pragma: no cover
    have_ipython = False
else:
    have_ipython = True

# Decorator that skips a test when IPython is unavailable.
need_ipython = pytest.mark.skipif(not have_ipython, reason="need IPython")
681
682
@slow
@need_ipython
def test_ipython_exc_handler():
    """Running the simple script under IPython still shows the full MultiError."""
    check_simple_excepthook(run_script("simple_excepthook.py", use_ipython=True))
688
689
@slow
@need_ipython
def test_ipython_imported_but_unused():
    """A script that imports IPython without running it still gets the
    normal MultiError excepthook output."""
    check_simple_excepthook(run_script("simple_excepthook_IPython.py"))
695
696
@slow
def test_partial_imported_but_unused():
    # A functools.partial installed as sys.excepthook has no .__name__, which
    # used to make importing trio fail (seen e.g. inside pytest-qt test
    # cases). The script must simply exit successfully.
    run_script("simple_excepthook_partial.py").check_returncode()
704
705
@slow
@need_ipython
def test_ipython_custom_exc_handler():
    # Check we get a nice warning (but only one!) if the user is using IPython
    # and already has some other set_custom_exc handler installed.
    completed = run_script("ipython_custom_exc.py", use_ipython=True)
    output = completed.stdout.decode("utf-8")
    expected_sequence = [
        # The warning
        "RuntimeWarning",
        "IPython detected",
        "skip installing Trio",
        # The MultiError
        "MultiError",
        "ValueError",
        "KeyError",
    ]
    assert_match_in_seq(expected_sequence, output)
    # Make sure our other warning doesn't show up
    assert "custom sys.excepthook" not in output
727
728
@slow
@pytest.mark.skipif(
    not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(),
    reason="need Ubuntu with python3-apport installed",
)
def test_apport_excepthook_monkeypatch_interaction():
    """With Ubuntu's apport excepthook present, no custom-excepthook warning
    is printed and the MultiError details still appear."""
    stdout = run_script("apport_excepthook.py").stdout.decode("utf-8")

    # No warning
    assert "custom sys.excepthook" not in stdout

    # Proper traceback
    expected_sequence = [
        "Details of embedded",
        "KeyError",
        "Details of embedded",
        "ValueError",
    ]
    assert_match_in_seq(expected_sequence, stdout)
746