1"""Unit tests for contextlib.py, and other context managers."""
2
3import io
4import sys
5import tempfile
6import threading
7import unittest
8from contextlib import *  # Tests __all__
9from test import support
10from test.support import os_helper
11import weakref
12
13
class TestAbstractContextManager(unittest.TestCase):
    """Tests for the contextlib.AbstractContextManager abstract base class."""

    def test_enter(self):
        # The ABC provides a default __enter__ that returns the manager.
        class DefaultEnter(AbstractContextManager):
            def __exit__(self, *args):
                super().__exit__(*args)

        cm = DefaultEnter()
        entered = cm.__enter__()
        self.assertIs(entered, cm)

    def test_exit_is_abstract(self):
        # __exit__ is abstract, so a subclass that omits it is not
        # instantiable.
        class MissingExit(AbstractContextManager):
            pass

        self.assertRaises(TypeError, MissingExit)

    def test_structural_subclassing(self):
        # Any class defining both __enter__ and __exit__ is a virtual
        # subclass; explicitly setting either one to None opts out.
        class ManagerFromScratch:
            def __enter__(self):
                return self
            def __exit__(self, exc_type, exc_value, traceback):
                return None

        class DefaultEnter(AbstractContextManager):
            def __exit__(self, *args):
                super().__exit__(*args)

        class NoEnter(ManagerFromScratch):
            __enter__ = None

        class NoExit(ManagerFromScratch):
            __exit__ = None

        expectations = (
            (ManagerFromScratch, True),
            (DefaultEnter, True),
            (NoEnter, False),
            (NoExit, False),
        )
        for candidate, expected in expectations:
            with self.subTest(cls=candidate.__name__):
                self.assertIs(
                    issubclass(candidate, AbstractContextManager), expected)
55
56
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A ``finally`` after the yield must run even when the body raises.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that swallows the thrown exception and yields again
        # breaks the protocol; __exit__ must report it as RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except BaseException:
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # An exception handled inside the generator is suppressed.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_except_stopiter(self):
        @contextmanager
        def woohoo():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
            with self.subTest(type=type(stop_exc)):
                try:
                    with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    def test_contextmanager_except_pep479(self):
        code = """\
from __future__ import generator_stop
from contextlib import contextmanager
@contextmanager
def woohoo():
    yield
"""
        # Use a dedicated namespace instead of shadowing the ``locals``
        # builtin.
        namespace = {}
        exec(code, namespace, namespace)
        woohoo = namespace['woohoo']

        stop_exc = StopIteration('spam')
        try:
            with woohoo():
                raise stop_exc
        except Exception as ex:
            self.assertIs(ex, stop_exc)
        else:
            self.fail('StopIteration was suppressed')

    def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
        @contextmanager
        def test_issue29692():
            try:
                yield
            except Exception as exc:
                raise RuntimeError('issue29692:Chained') from exc
        try:
            with test_issue29692():
                raise ZeroDivisionError
        except Exception as ex:
            self.assertIs(type(ex), RuntimeError)
            self.assertEqual(ex.args[0], 'issue29692:Chained')
            self.assertIsInstance(ex.__cause__, ZeroDivisionError)
        else:
            # Without this branch a manager that wrongly suppressed the
            # error would make the test pass silently.
            self.fail('RuntimeError was not raised')

        try:
            with test_issue29692():
                raise StopIteration('issue29692:Unchained')
        except Exception as ex:
            self.assertIs(type(ex), StopIteration)
            self.assertEqual(ex.args[0], 'issue29692:Unchained')
            self.assertIsNone(ex.__cause__)
        else:
            self.fail('StopIteration was suppressed')

    def _create_contextmanager_attribs(self):
        """Return a @contextmanager function carrying a docstring and a
        custom ``foo`` attribute, for metadata-propagation tests."""
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    def test_instance_docstring_given_cm_docstring(self):
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")

    def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @contextmanager
        def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))

    def test_nokeepref(self):
        # The manager must not keep its call arguments alive once entered.
        class A:
            pass

        @contextmanager
        def woohoo(a, b):
            a = weakref.ref(a)
            b = weakref.ref(b)
            # Allow test to work with a non-refcounted GC
            support.gc_collect()
            self.assertIsNone(a())
            self.assertIsNone(b())
            yield

        with woohoo(A(), b=A()):
            pass

    def test_param_errors(self):
        @contextmanager
        def woohoo(a, *, b):
            yield

        with self.assertRaises(TypeError):
            woohoo()
        with self.assertRaises(TypeError):
            woohoo(3, 5)
        with self.assertRaises(TypeError):
            woohoo(b=3)

    def test_recursive(self):
        # Used as a decorator, the manager is recreated on every call, so
        # recursion nests cleanly.
        depth = 0
        @contextmanager
        def woohoo():
            nonlocal depth
            before = depth
            depth += 1
            yield
            depth -= 1
            self.assertEqual(depth, before)

        @woohoo()
        def recursive():
            if depth < 10:
                recursive()

        recursive()
        self.assertEqual(depth, 0)
273
274
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        instance = closing(None)
        self.assertEqual(instance.__doc__, closing.__doc__)

    def _closeable(self, log):
        """Return an object whose close() appends 1 to *log*."""
        class Closeable:
            def close(self):
                log.append(1)
        return Closeable()

    def test_closing(self):
        calls = []
        target = self._closeable(calls)
        self.assertEqual(calls, [])
        with closing(target) as entered:
            self.assertEqual(target, entered)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() must still run when the body raises.
        calls = []
        target = self._closeable(calls)
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(target) as entered:
                self.assertEqual(target, entered)
                1 / 0
        self.assertEqual(calls, [1])
307
308
class NullcontextTestCase(unittest.TestCase):
    """Tests for contextlib.nullcontext."""

    def test_nullcontext(self):
        # nullcontext(obj) must hand back the very same object on entry.
        class Marker:
            pass

        marker = Marker()
        with nullcontext(marker) as entered:
            self.assertIs(entered, marker)
316
317
318class FileContextTestCase(unittest.TestCase):
319
320    def testWithOpen(self):
321        tfn = tempfile.mktemp()
322        try:
323            f = None
324            with open(tfn, "w", encoding="utf-8") as f:
325                self.assertFalse(f.closed)
326                f.write("Booh\n")
327            self.assertTrue(f.closed)
328            f = None
329            with self.assertRaises(ZeroDivisionError):
330                with open(tfn, "r", encoding="utf-8") as f:
331                    self.assertFalse(f.closed)
332                    self.assertEqual(f.read(), "Booh\n")
333                    1 / 0
334            self.assertTrue(f.closed)
335        finally:
336            os_helper.unlink(tfn)
337
class LockContextTestCase(unittest.TestCase):
    """Check that threading synchronisation primitives work with ``with``."""

    def boilerPlate(self, lock, locked):
        # ``locked`` is a zero-argument probe that reports whether ``lock``
        # is currently held.  The lock must be released after a normal exit
        # from the with-block and also after an exit caused by an exception.
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())

        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 / 0
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_probe(sem):
        """Build a probe reporting whether *sem* has no free slot.

        It tries a non-blocking acquire and immediately undoes it on success.
        """
        def locked():
            acquired = sem.acquire(False)
            if acquired:
                sem.release()
            return not acquired
        return locked

    def testWithLock(self):
        mutex = threading.Lock()
        self.boilerPlate(mutex, mutex.locked)

    def testWithRLock(self):
        rlock = threading.RLock()
        self.boilerPlate(rlock, rlock._is_owned)

    def testWithCondition(self):
        cond = threading.Condition()
        self.boilerPlate(cond, lambda: cond._is_owned())

    def testWithSemaphore(self):
        sem = threading.Semaphore()
        self.boilerPlate(sem, self._semaphore_probe(sem))

    def testWithBoundedSemaphore(self):
        sem = threading.BoundedSemaphore()
        self.boilerPlate(sem, self._semaphore_probe(sem))
384
385
class mycontext(ContextDecorator):
    """Example decoration-compatible context manager for testing.

    It records that it was entered and what exception details reached
    __exit__. Setting ``catch`` to True makes it suppress exceptions.
    """

    # Class-level defaults; instances shadow them when entered or exited.
    started = False
    exc = None
    catch = False

    def __enter__(self):
        self.started = True
        return self

    def __exit__(self, *exc_info):
        # Store the raw (type, value, traceback) triple for inspection.
        self.exc = exc_info
        suppress = self.catch
        return suppress
399
400
class TestContextDecorator(unittest.TestCase):
    """Exercise ContextDecorator both as a ``with`` target and a decorator."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        instance = mycontext()
        self.assertEqual(instance.__doc__, mycontext.__doc__)

    def test_contextdecorator(self):
        cm = mycontext()
        with cm as entered:
            self.assertIs(entered, cm)
            self.assertTrue(cm.started)
        # A clean exit hands three Nones to __exit__.
        self.assertEqual(cm.exc, (None, None, None))

    def test_contextdecorator_with_exception(self):
        # Without ``catch`` the error propagates ...
        propagating = mycontext()
        with self.assertRaisesRegex(NameError, 'foo'):
            with propagating:
                raise NameError('foo')
        self.assertIsNotNone(propagating.exc)
        self.assertIs(propagating.exc[0], NameError)

        # ... and with ``catch`` set it is swallowed.
        swallowing = mycontext()
        swallowing.catch = True
        with swallowing:
            raise NameError('foo')
        self.assertIsNotNone(swallowing.exc)
        self.assertIs(swallowing.exc[0], NameError)

    def test_decorator(self):
        cm = mycontext()

        @cm
        def body():
            self.assertIsNone(cm.exc)
            self.assertTrue(cm.started)

        body()
        self.assertEqual(cm.exc, (None, None, None))

    def test_decorator_with_exception(self):
        cm = mycontext()

        @cm
        def body():
            self.assertIsNone(cm.exc)
            self.assertTrue(cm.started)
            raise NameError('foo')

        with self.assertRaisesRegex(NameError, 'foo'):
            body()
        self.assertIsNotNone(cm.exc)
        self.assertIs(cm.exc[0], NameError)

    def test_decorating_method(self):
        cm = mycontext()

        class Recorder(object):

            @cm
            def method(self, a, b, c=None):
                self.a = a
                self.b = b
                self.c = c

        # Argument passing through the decorator, positional and keyword.
        def call(*args, **kwargs):
            rec = Recorder()
            rec.method(*args, **kwargs)
            return rec

        rec = call(1, 2)
        self.assertEqual((rec.a, rec.b, rec.c), (1, 2, None))

        rec = call('a', 'b', 'c')
        self.assertEqual((rec.a, rec.b, rec.c), ('a', 'b', 'c'))

        rec = call(a=1, b=2)
        self.assertEqual((rec.a, rec.b), (1, 2))

    def test_typo_enter(self):
        # A misspelled __enter__ means the object is not a context manager.
        class Misspelled(ContextDecorator):
            def __unter__(self):
                pass
            def __exit__(self, *exc):
                pass

        with self.assertRaises(AttributeError):
            with Misspelled():
                pass

    def test_typo_exit(self):
        # Likewise for a misspelled __exit__.
        class Misspelled(ContextDecorator):
            def __enter__(self):
                pass
            def __uxit__(self, *exc):
                pass

        with self.assertRaises(AttributeError):
            with Misspelled():
                pass

    def test_contextdecorator_as_mixin(self):
        class somecontext(object):
            started = False
            exc = None

            def __enter__(self):
                self.started = True
                return self

            def __exit__(self, *exc):
                self.exc = exc

        class Mixed(somecontext, ContextDecorator):
            pass

        cm = Mixed()

        @cm
        def body():
            self.assertIsNone(cm.exc)
            self.assertTrue(cm.started)

        body()
        self.assertEqual(cm.exc, (None, None, None))

    def test_contextmanager_as_decorator(self):
        @contextmanager
        def woohoo(y):
            state.append(y)
            yield
            state.append(999)

        @woohoo(1)
        def decorated(x):
            self.assertEqual(state, [1])
            state.append(x)

        # Issue #11647: Ensure the decorated function is 'reusable'
        for arg in ('something', 'something else'):
            state = []
            decorated(arg)
            self.assertEqual(state, [1, arg, 999])
559
560
561class TestBaseExitStack:
562    exit_stack = None
563
564    @support.requires_docstrings
565    def test_instance_docs(self):
566        # Issue 19330: ensure context manager instances have good docstrings
567        cm_docstring = self.exit_stack.__doc__
568        obj = self.exit_stack()
569        self.assertEqual(obj.__doc__, cm_docstring)
570
571    def test_no_resources(self):
572        with self.exit_stack():
573            pass
574
575    def test_callback(self):
576        expected = [
577            ((), {}),
578            ((1,), {}),
579            ((1,2), {}),
580            ((), dict(example=1)),
581            ((1,), dict(example=1)),
582            ((1,2), dict(example=1)),
583            ((1,2), dict(self=3, callback=4)),
584        ]
585        result = []
586        def _exit(*args, **kwds):
587            """Test metadata propagation"""
588            result.append((args, kwds))
589        with self.exit_stack() as stack:
590            for args, kwds in reversed(expected):
591                if args and kwds:
592                    f = stack.callback(_exit, *args, **kwds)
593                elif args:
594                    f = stack.callback(_exit, *args)
595                elif kwds:
596                    f = stack.callback(_exit, **kwds)
597                else:
598                    f = stack.callback(_exit)
599                self.assertIs(f, _exit)
600            for wrapper in stack._exit_callbacks:
601                self.assertIs(wrapper[1].__wrapped__, _exit)
602                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
603                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
604        self.assertEqual(result, expected)
605
606        result = []
607        with self.exit_stack() as stack:
608            with self.assertRaises(TypeError):
609                stack.callback(arg=1)
610            with self.assertRaises(TypeError):
611                self.exit_stack.callback(arg=2)
612            with self.assertRaises(TypeError):
613                stack.callback(callback=_exit, arg=3)
614        self.assertEqual(result, [])
615
616    def test_push(self):
617        exc_raised = ZeroDivisionError
618        def _expect_exc(exc_type, exc, exc_tb):
619            self.assertIs(exc_type, exc_raised)
620        def _suppress_exc(*exc_details):
621            return True
622        def _expect_ok(exc_type, exc, exc_tb):
623            self.assertIsNone(exc_type)
624            self.assertIsNone(exc)
625            self.assertIsNone(exc_tb)
626        class ExitCM(object):
627            def __init__(self, check_exc):
628                self.check_exc = check_exc
629            def __enter__(self):
630                self.fail("Should not be called!")
631            def __exit__(self, *exc_details):
632                self.check_exc(*exc_details)
633        with self.exit_stack() as stack:
634            stack.push(_expect_ok)
635            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
636            cm = ExitCM(_expect_ok)
637            stack.push(cm)
638            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
639            stack.push(_suppress_exc)
640            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
641            cm = ExitCM(_expect_exc)
642            stack.push(cm)
643            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
644            stack.push(_expect_exc)
645            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
646            stack.push(_expect_exc)
647            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
648            1/0
649
650    def test_enter_context(self):
651        class TestCM(object):
652            def __enter__(self):
653                result.append(1)
654            def __exit__(self, *exc_details):
655                result.append(3)
656
657        result = []
658        cm = TestCM()
659        with self.exit_stack() as stack:
660            @stack.callback  # Registered first => cleaned up last
661            def _exit():
662                result.append(4)
663            self.assertIsNotNone(_exit)
664            stack.enter_context(cm)
665            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
666            result.append(2)
667        self.assertEqual(result, [1, 2, 3, 4])
668
669    def test_close(self):
670        result = []
671        with self.exit_stack() as stack:
672            @stack.callback
673            def _exit():
674                result.append(1)
675            self.assertIsNotNone(_exit)
676            stack.close()
677            result.append(2)
678        self.assertEqual(result, [1, 2])
679
680    def test_pop_all(self):
681        result = []
682        with self.exit_stack() as stack:
683            @stack.callback
684            def _exit():
685                result.append(3)
686            self.assertIsNotNone(_exit)
687            new_stack = stack.pop_all()
688            result.append(1)
689        result.append(2)
690        new_stack.close()
691        self.assertEqual(result, [1, 2, 3])
692
693    def test_exit_raise(self):
694        with self.assertRaises(ZeroDivisionError):
695            with self.exit_stack() as stack:
696                stack.push(lambda *exc: False)
697                1/0
698
699    def test_exit_suppress(self):
700        with self.exit_stack() as stack:
701            stack.push(lambda *exc: True)
702            1/0
703
704    def test_exit_exception_chaining_reference(self):
705        # Sanity check to make sure that ExitStack chaining matches
706        # actual nested with statements
707        class RaiseExc:
708            def __init__(self, exc):
709                self.exc = exc
710            def __enter__(self):
711                return self
712            def __exit__(self, *exc_details):
713                raise self.exc
714
715        class RaiseExcWithContext:
716            def __init__(self, outer, inner):
717                self.outer = outer
718                self.inner = inner
719            def __enter__(self):
720                return self
721            def __exit__(self, *exc_details):
722                try:
723                    raise self.inner
724                except:
725                    raise self.outer
726
727        class SuppressExc:
728            def __enter__(self):
729                return self
730            def __exit__(self, *exc_details):
731                type(self).saved_details = exc_details
732                return True
733
734        try:
735            with RaiseExc(IndexError):
736                with RaiseExcWithContext(KeyError, AttributeError):
737                    with SuppressExc():
738                        with RaiseExc(ValueError):
739                            1 / 0
740        except IndexError as exc:
741            self.assertIsInstance(exc.__context__, KeyError)
742            self.assertIsInstance(exc.__context__.__context__, AttributeError)
743            # Inner exceptions were suppressed
744            self.assertIsNone(exc.__context__.__context__.__context__)
745        else:
746            self.fail("Expected IndexError, but no exception was raised")
747        # Check the inner exceptions
748        inner_exc = SuppressExc.saved_details[1]
749        self.assertIsInstance(inner_exc, ValueError)
750        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
751
752    def test_exit_exception_chaining(self):
753        # Ensure exception chaining matches the reference behaviour
754        def raise_exc(exc):
755            raise exc
756
757        saved_details = None
758        def suppress_exc(*exc_details):
759            nonlocal saved_details
760            saved_details = exc_details
761            return True
762
763        try:
764            with self.exit_stack() as stack:
765                stack.callback(raise_exc, IndexError)
766                stack.callback(raise_exc, KeyError)
767                stack.callback(raise_exc, AttributeError)
768                stack.push(suppress_exc)
769                stack.callback(raise_exc, ValueError)
770                1 / 0
771        except IndexError as exc:
772            self.assertIsInstance(exc.__context__, KeyError)
773            self.assertIsInstance(exc.__context__.__context__, AttributeError)
774            # Inner exceptions were suppressed
775            self.assertIsNone(exc.__context__.__context__.__context__)
776        else:
777            self.fail("Expected IndexError, but no exception was raised")
778        # Check the inner exceptions
779        inner_exc = saved_details[1]
780        self.assertIsInstance(inner_exc, ValueError)
781        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
782
783    def test_exit_exception_explicit_none_context(self):
784        # Ensure ExitStack chaining matches actual nested `with` statements
785        # regarding explicit __context__ = None.
786
787        class MyException(Exception):
788            pass
789
790        @contextmanager
791        def my_cm():
792            try:
793                yield
794            except BaseException:
795                exc = MyException()
796                try:
797                    raise exc
798                finally:
799                    exc.__context__ = None
800
801        @contextmanager
802        def my_cm_with_exit_stack():
803            with self.exit_stack() as stack:
804                stack.enter_context(my_cm())
805                yield stack
806
807        for cm in (my_cm, my_cm_with_exit_stack):
808            with self.subTest():
809                try:
810                    with cm():
811                        raise IndexError()
812                except MyException as exc:
813                    self.assertIsNone(exc.__context__)
814                else:
815                    self.fail("Expected IndexError, but no exception was raised")
816
817    def test_exit_exception_non_suppressing(self):
818        # http://bugs.python.org/issue19092
819        def raise_exc(exc):
820            raise exc
821
822        def suppress_exc(*exc_details):
823            return True
824
825        try:
826            with self.exit_stack() as stack:
827                stack.callback(lambda: None)
828                stack.callback(raise_exc, IndexError)
829        except Exception as exc:
830            self.assertIsInstance(exc, IndexError)
831        else:
832            self.fail("Expected IndexError, but no exception was raised")
833
834        try:
835            with self.exit_stack() as stack:
836                stack.callback(raise_exc, KeyError)
837                stack.push(suppress_exc)
838                stack.callback(raise_exc, IndexError)
839        except Exception as exc:
840            self.assertIsInstance(exc, KeyError)
841        else:
842            self.fail("Expected KeyError, but no exception was raised")
843
844    def test_exit_exception_with_correct_context(self):
845        # http://bugs.python.org/issue20317
846        @contextmanager
847        def gets_the_context_right(exc):
848            try:
849                yield
850            finally:
851                raise exc
852
853        exc1 = Exception(1)
854        exc2 = Exception(2)
855        exc3 = Exception(3)
856        exc4 = Exception(4)
857
858        # The contextmanager already fixes the context, so prior to the
859        # fix, ExitStack would try to fix it *again* and get into an
860        # infinite self-referential loop
861        try:
862            with self.exit_stack() as stack:
863                stack.enter_context(gets_the_context_right(exc4))
864                stack.enter_context(gets_the_context_right(exc3))
865                stack.enter_context(gets_the_context_right(exc2))
866                raise exc1
867        except Exception as exc:
868            self.assertIs(exc, exc4)
869            self.assertIs(exc.__context__, exc3)
870            self.assertIs(exc.__context__.__context__, exc2)
871            self.assertIs(exc.__context__.__context__.__context__, exc1)
872            self.assertIsNone(
873                       exc.__context__.__context__.__context__.__context__)
874
875    def test_exit_exception_with_existing_context(self):
876        # Addresses a lack of test coverage discovered after checking in a
877        # fix for issue 20317 that still contained debugging code.
878        def raise_nested(inner_exc, outer_exc):
879            try:
880                raise inner_exc
881            finally:
882                raise outer_exc
883        exc1 = Exception(1)
884        exc2 = Exception(2)
885        exc3 = Exception(3)
886        exc4 = Exception(4)
887        exc5 = Exception(5)
888        try:
889            with self.exit_stack() as stack:
890                stack.callback(raise_nested, exc4, exc5)
891                stack.callback(raise_nested, exc2, exc3)
892                raise exc1
893        except Exception as exc:
894            self.assertIs(exc, exc5)
895            self.assertIs(exc.__context__, exc4)
896            self.assertIs(exc.__context__.__context__, exc3)
897            self.assertIs(exc.__context__.__context__.__context__, exc2)
898            self.assertIs(
899                 exc.__context__.__context__.__context__.__context__, exc1)
900            self.assertIsNone(
901                exc.__context__.__context__.__context__.__context__.__context__)
902
903    def test_body_exception_suppress(self):
904        def suppress_exc(*exc_details):
905            return True
906        try:
907            with self.exit_stack() as stack:
908                stack.push(suppress_exc)
909                1/0
910        except IndexError as exc:
911            self.fail("Expected no exception, got IndexError")
912
913    def test_exit_exception_chaining_suppress(self):
914        with self.exit_stack() as stack:
915            stack.push(lambda *exc: True)
916            stack.push(lambda *exc: 1/0)
917            stack.push(lambda *exc: {}[1])
918
919    def test_excessive_nesting(self):
920        # The original implementation would die with RecursionError here
921        with self.exit_stack() as stack:
922            for i in range(10000):
923                stack.callback(int)
924
925    def test_instance_bypass(self):
926        class Example(object): pass
927        cm = Example()
928        cm.__exit__ = object()
929        stack = self.exit_stack()
930        self.assertRaises(AttributeError, stack.enter_context, cm)
931        stack.push(cm)
932        self.assertIs(stack._exit_callbacks[-1][1], cm)
933
934    def test_dont_reraise_RuntimeError(self):
935        # https://bugs.python.org/issue27122
936        class UniqueException(Exception): pass
937        class UniqueRuntimeError(RuntimeError): pass
938
939        @contextmanager
940        def second():
941            try:
942                yield 1
943            except Exception as exc:
944                raise UniqueException("new exception") from exc
945
946        @contextmanager
947        def first():
948            try:
949                yield 1
950            except Exception as exc:
951                raise exc
952
953        # The UniqueRuntimeError should be caught by second()'s exception
954        # handler which chain raised a new UniqueException.
955        with self.assertRaises(UniqueException) as err_ctx:
956            with self.exit_stack() as es_ctx:
957                es_ctx.enter_context(second())
958                es_ctx.enter_context(first())
959                raise UniqueRuntimeError("please no infinite loop.")
960
961        exc = err_ctx.exception
962        self.assertIsInstance(exc, UniqueException)
963        self.assertIsInstance(exc.__context__, UniqueRuntimeError)
964        self.assertIsNone(exc.__context__.__context__)
965        self.assertIsNone(exc.__context__.__cause__)
966        self.assertIs(exc.__cause__, exc.__context__)
967
968
class TestExitStack(TestBaseExitStack, unittest.TestCase):
    # Concrete test case: the shared TestBaseExitStack tests create their
    # stack through self.exit_stack, so binding it here runs them against
    # contextlib.ExitStack.
    exit_stack = ExitStack
971
972
973class TestRedirectStream:
974
975    redirect_stream = None
976    orig_stream = None
977
978    @support.requires_docstrings
979    def test_instance_docs(self):
980        # Issue 19330: ensure context manager instances have good docstrings
981        cm_docstring = self.redirect_stream.__doc__
982        obj = self.redirect_stream(None)
983        self.assertEqual(obj.__doc__, cm_docstring)
984
985    def test_no_redirect_in_init(self):
986        orig_stdout = getattr(sys, self.orig_stream)
987        self.redirect_stream(None)
988        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
989
990    def test_redirect_to_string_io(self):
991        f = io.StringIO()
992        msg = "Consider an API like help(), which prints directly to stdout"
993        orig_stdout = getattr(sys, self.orig_stream)
994        with self.redirect_stream(f):
995            print(msg, file=getattr(sys, self.orig_stream))
996        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
997        s = f.getvalue().strip()
998        self.assertEqual(s, msg)
999
1000    def test_enter_result_is_target(self):
1001        f = io.StringIO()
1002        with self.redirect_stream(f) as enter_result:
1003            self.assertIs(enter_result, f)
1004
1005    def test_cm_is_reusable(self):
1006        f = io.StringIO()
1007        write_to_f = self.redirect_stream(f)
1008        orig_stdout = getattr(sys, self.orig_stream)
1009        with write_to_f:
1010            print("Hello", end=" ", file=getattr(sys, self.orig_stream))
1011        with write_to_f:
1012            print("World!", file=getattr(sys, self.orig_stream))
1013        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1014        s = f.getvalue()
1015        self.assertEqual(s, "Hello World!\n")
1016
1017    def test_cm_is_reentrant(self):
1018        f = io.StringIO()
1019        write_to_f = self.redirect_stream(f)
1020        orig_stdout = getattr(sys, self.orig_stream)
1021        with write_to_f:
1022            print("Hello", end=" ", file=getattr(sys, self.orig_stream))
1023            with write_to_f:
1024                print("World!", file=getattr(sys, self.orig_stream))
1025        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1026        s = f.getvalue()
1027        self.assertEqual(s, "Hello World!\n")
1028
1029
class TestRedirectStdout(TestRedirectStream, unittest.TestCase):
    # Run the shared redirection tests against sys.stdout.
    orig_stream = "stdout"
    redirect_stream = redirect_stdout
1034
1035
class TestRedirectStderr(TestRedirectStream, unittest.TestCase):
    # Run the shared redirection tests against sys.stderr.
    orig_stream = "stderr"
    redirect_stream = redirect_stderr
1040
1041
class TestSuppress(unittest.TestCase):
    """Tests for the ``suppress`` context manager."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        cm_docstring = suppress.__doc__
        obj = suppress()
        self.assertEqual(obj.__doc__, cm_docstring)

    def test_no_result_from_enter(self):
        # Entering suppress binds None to the as-target.
        with suppress(ValueError) as enter_result:
            self.assertIsNone(enter_result)

    def test_no_exception(self):
        with suppress(ValueError):
            self.assertEqual(pow(2, 5), 32)

    def test_exact_exception(self):
        with suppress(TypeError):
            len(5)

    def test_exception_hierarchy(self):
        # The IndexError is suppressed through its base class LookupError.
        with suppress(LookupError):
            'Hello'[50]

    def test_other_exception(self):
        with self.assertRaises(ZeroDivisionError):
            with suppress(TypeError):
                1/0

    def test_no_args(self):
        # With no exception types given, nothing is suppressed.
        with self.assertRaises(ZeroDivisionError):
            with suppress():
                1/0

    def test_multiple_exception_args(self):
        with suppress(ZeroDivisionError, TypeError):
            1/0
        with suppress(ZeroDivisionError, TypeError):
            len(5)

    def test_cm_is_reentrant(self):
        ignore_exceptions = suppress(Exception)
        # Bind the flag up front: if the nested suppress below failed, the
        # outer one would swallow the TypeError and skip the assignment, and
        # assertTrue would then die with UnboundLocalError instead of
        # reporting an ordinary test failure.
        outer_continued = False
        with ignore_exceptions:
            pass
        with ignore_exceptions:
            len(5)
        with ignore_exceptions:
            with ignore_exceptions: # Check nested usage
                len(5)
            outer_continued = True
            1/0
        self.assertTrue(outer_continued)
1095
# Run this file's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
1098