1"""Unit tests for contextlib.py, and other context managers."""
2
3import io
4import os
5import sys
6import tempfile
7import threading
8import unittest
9from contextlib import *  # Tests __all__
10from test import support
11from test.support import os_helper
12import weakref
13
14
class TestAbstractContextManager(unittest.TestCase):
    """Tests for the default behaviour of contextlib.AbstractContextManager."""

    def test_enter(self):
        # The inherited __enter__ hands back the manager object itself.
        class DefaultEnter(AbstractContextManager):
            def __exit__(self, *exc_info):
                super().__exit__(*exc_info)

        cm = DefaultEnter()
        entered = cm.__enter__()
        self.assertIs(entered, cm)

    def test_exit_is_abstract(self):
        # A subclass that does not define __exit__ cannot be instantiated.
        class MissingExit(AbstractContextManager):
            pass

        self.assertRaises(TypeError, MissingExit)

    def test_structural_subclassing(self):
        # Defining both __enter__ and __exit__ makes a class a virtual
        # subclass; setting either one to None opts out again.
        class ManagerFromScratch:
            def __enter__(self):
                return self
            def __exit__(self, exc_type, exc_value, traceback):
                return None

        class DefaultEnter(AbstractContextManager):
            def __exit__(self, *exc_info):
                super().__exit__(*exc_info)

        class NoEnter(ManagerFromScratch):
            __enter__ = None

        class NoExit(ManagerFromScratch):
            __exit__ = None

        expectations = (
            (ManagerFromScratch, True),
            (DefaultEnter, True),
            (NoEnter, False),
            (NoExit, False),
        )
        for candidate, is_manager in expectations:
            self.assertIs(issubclass(candidate, AbstractContextManager),
                          is_manager)
56
57
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # The generator's finally clause runs even though the body raises,
        # and the exception still propagates out of the with statement.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # Yielding a second time after the exception is thrown in is a
        # protocol violation that __exit__ reports as RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except BaseException:
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # An exception caught inside the generator is suppressed.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_except_stopiter(self):
        # StopIteration (and subclasses) raised in the body must propagate
        # unchanged rather than being mistaken for generator exhaustion.
        @contextmanager
        def woohoo():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
            with self.subTest(type=type(stop_exc)):
                try:
                    with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    def test_contextmanager_except_pep479(self):
        code = """\
from __future__ import generator_stop
from contextlib import contextmanager
@contextmanager
def woohoo():
    yield
"""
        # Use a dedicated namespace rather than shadowing the locals() builtin.
        namespace = {}
        exec(code, namespace, namespace)
        woohoo = namespace['woohoo']

        stop_exc = StopIteration('spam')
        try:
            with woohoo():
                raise stop_exc
        except Exception as ex:
            self.assertIs(ex, stop_exc)
        else:
            self.fail('StopIteration was suppressed')

    def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
        @contextmanager
        def test_issue29692():
            try:
                yield
            except Exception as exc:
                raise RuntimeError('issue29692:Chained') from exc
        try:
            with test_issue29692():
                raise ZeroDivisionError
        except Exception as ex:
            self.assertIs(type(ex), RuntimeError)
            self.assertEqual(ex.args[0], 'issue29692:Chained')
            self.assertIsInstance(ex.__cause__, ZeroDivisionError)
        else:
            # Without this the test would pass vacuously.
            self.fail("Expected RuntimeError, but no exception was raised")

        # A StopIteration chained into a RuntimeError by the generator is
        # unwrapped again, so the original StopIteration escapes.
        try:
            with test_issue29692():
                raise StopIteration('issue29692:Unchained')
        except Exception as ex:
            self.assertIs(type(ex), StopIteration)
            self.assertEqual(ex.args[0], 'issue29692:Unchained')
            self.assertIsNone(ex.__cause__)
        else:
            self.fail("Expected StopIteration, but no exception was raised")

    def _create_contextmanager_attribs(self):
        """Return a @contextmanager function carrying a custom attribute.

        The attribute is set on the undecorated generator function, so the
        tests check that @contextmanager preserves it (and __name__/__doc__).
        """
        def attribs(**kw):
            def decorate(func):
                for name, value in kw.items():
                    setattr(func, name, value)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    def test_instance_docstring_given_cm_docstring(self):
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")

    def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @contextmanager
        def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))

    def test_nokeepref(self):
        # Once entered, the context manager must not keep the call
        # arguments alive: only the generator's weak references remain.
        class A:
            pass

        @contextmanager
        def woohoo(a, b):
            a = weakref.ref(a)
            b = weakref.ref(b)
            # Allow test to work with a non-refcounted GC
            support.gc_collect()
            self.assertIsNone(a())
            self.assertIsNone(b())
            yield

        with woohoo(A(), b=A()):
            pass

    def test_param_errors(self):
        @contextmanager
        def woohoo(a, *, b):
            yield

        with self.assertRaises(TypeError):
            woohoo()
        with self.assertRaises(TypeError):
            woohoo(3, 5)
        with self.assertRaises(TypeError):
            woohoo(b=3)

    def test_recursive(self):
        # A @contextmanager instance used as a decorator is re-created for
        # every call, so recursive calls each get a fresh generator.
        depth = 0
        @contextmanager
        def woohoo():
            nonlocal depth
            before = depth
            depth += 1
            yield
            depth -= 1
            self.assertEqual(depth, before)

        @woohoo()
        def recursive():
            if depth < 10:
                recursive()

        recursive()
        self.assertEqual(depth, 0)
274
275
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    @staticmethod
    def _make_closeable(log):
        """Return a fresh object whose close() appends 1 to *log*."""
        class C:
            def close(self):
                log.append(1)
        return C()

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        expected_doc = closing.__doc__
        self.assertEqual(closing(None).__doc__, expected_doc)

    def test_closing(self):
        calls = []
        thing = self._make_closeable(calls)
        self.assertEqual(calls, [])
        with closing(thing) as entered:
            self.assertEqual(thing, entered)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() must still run when the body raises.
        calls = []
        thing = self._make_closeable(calls)
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(thing) as entered:
                self.assertEqual(thing, entered)
                1 / 0
        self.assertEqual(calls, [1])
308
309
class NullcontextTestCase(unittest.TestCase):
    """Tests for contextlib.nullcontext."""

    def test_nullcontext(self):
        # The object passed to nullcontext() is returned as-is by __enter__.
        class C:
            pass
        c = C()
        with nullcontext(c) as c_in:
            self.assertIs(c_in, c)

    def test_nullcontext_default(self):
        # With no argument the enter result is None, and nullcontext does
        # nothing else, so exceptions from the body are not suppressed.
        with nullcontext() as c_in:
            self.assertIsNone(c_in)
        with self.assertRaises(ZeroDivisionError):
            with nullcontext():
                1 / 0
317
318
319class FileContextTestCase(unittest.TestCase):
320
321    def testWithOpen(self):
322        tfn = tempfile.mktemp()
323        try:
324            f = None
325            with open(tfn, "w", encoding="utf-8") as f:
326                self.assertFalse(f.closed)
327                f.write("Booh\n")
328            self.assertTrue(f.closed)
329            f = None
330            with self.assertRaises(ZeroDivisionError):
331                with open(tfn, "r", encoding="utf-8") as f:
332                    self.assertFalse(f.closed)
333                    self.assertEqual(f.read(), "Booh\n")
334                    1 / 0
335            self.assertTrue(f.closed)
336        finally:
337            os_helper.unlink(tfn)
338
class LockContextTestCase(unittest.TestCase):
    """Check that threading's synchronization primitives work with `with`."""

    def boilerPlate(self, lock, locked):
        """Use *lock* in two ``with`` blocks, checking it is released each time.

        *locked* is a zero-argument callable reporting whether *lock* is
        currently held. The second block exits via ZeroDivisionError so the
        error path is checked as well.
        """
        def use_lock(fail):
            with lock:
                self.assertTrue(locked())
                if fail:
                    1 / 0

        self.assertFalse(locked())
        use_lock(False)
        self.assertFalse(locked())
        self.assertRaises(ZeroDivisionError, use_lock, True)
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_locked(sem):
        """Report whether *sem* has no free slot, leaving its count unchanged."""
        if not sem.acquire(False):
            return True
        sem.release()
        return False

    def testWithLock(self):
        mutex = threading.Lock()
        self.boilerPlate(mutex, mutex.locked)

    def testWithRLock(self):
        rlock = threading.RLock()
        self.boilerPlate(rlock, rlock._is_owned)

    def testWithCondition(self):
        cond = threading.Condition()
        self.boilerPlate(cond, lambda: cond._is_owned())

    def testWithSemaphore(self):
        sem = threading.Semaphore()
        self.boilerPlate(sem, lambda: self._semaphore_locked(sem))

    def testWithBoundedSemaphore(self):
        sem = threading.BoundedSemaphore()
        self.boilerPlate(sem, lambda: self._semaphore_locked(sem))
385
386
class mycontext(ContextDecorator):
    """Example decoration-compatible context manager for testing"""
    # Class-level defaults; instances overwrite them as they are used.
    started = False  # becomes True once __enter__ has run
    exc = None       # the exception triple most recently passed to __exit__
    catch = False    # returned by __exit__; True suppresses the exception

    def __enter__(self):
        self.started = True
        return self

    def __exit__(self, *exc_info):
        self.exc = exc_info
        return self.catch
400
401
402class TestContextDecorator(unittest.TestCase):
403
404    @support.requires_docstrings
405    def test_instance_docs(self):
406        # Issue 19330: ensure context manager instances have good docstrings
407        cm_docstring = mycontext.__doc__
408        obj = mycontext()
409        self.assertEqual(obj.__doc__, cm_docstring)
410
411    def test_contextdecorator(self):
412        context = mycontext()
413        with context as result:
414            self.assertIs(result, context)
415            self.assertTrue(context.started)
416
417        self.assertEqual(context.exc, (None, None, None))
418
419
420    def test_contextdecorator_with_exception(self):
421        context = mycontext()
422
423        with self.assertRaisesRegex(NameError, 'foo'):
424            with context:
425                raise NameError('foo')
426        self.assertIsNotNone(context.exc)
427        self.assertIs(context.exc[0], NameError)
428
429        context = mycontext()
430        context.catch = True
431        with context:
432            raise NameError('foo')
433        self.assertIsNotNone(context.exc)
434        self.assertIs(context.exc[0], NameError)
435
436
437    def test_decorator(self):
438        context = mycontext()
439
440        @context
441        def test():
442            self.assertIsNone(context.exc)
443            self.assertTrue(context.started)
444        test()
445        self.assertEqual(context.exc, (None, None, None))
446
447
448    def test_decorator_with_exception(self):
449        context = mycontext()
450
451        @context
452        def test():
453            self.assertIsNone(context.exc)
454            self.assertTrue(context.started)
455            raise NameError('foo')
456
457        with self.assertRaisesRegex(NameError, 'foo'):
458            test()
459        self.assertIsNotNone(context.exc)
460        self.assertIs(context.exc[0], NameError)
461
462
463    def test_decorating_method(self):
464        context = mycontext()
465
466        class Test(object):
467
468            @context
469            def method(self, a, b, c=None):
470                self.a = a
471                self.b = b
472                self.c = c
473
474        # these tests are for argument passing when used as a decorator
475        test = Test()
476        test.method(1, 2)
477        self.assertEqual(test.a, 1)
478        self.assertEqual(test.b, 2)
479        self.assertEqual(test.c, None)
480
481        test = Test()
482        test.method('a', 'b', 'c')
483        self.assertEqual(test.a, 'a')
484        self.assertEqual(test.b, 'b')
485        self.assertEqual(test.c, 'c')
486
487        test = Test()
488        test.method(a=1, b=2)
489        self.assertEqual(test.a, 1)
490        self.assertEqual(test.b, 2)
491
492
493    def test_typo_enter(self):
494        class mycontext(ContextDecorator):
495            def __unter__(self):
496                pass
497            def __exit__(self, *exc):
498                pass
499
500        with self.assertRaisesRegex(TypeError, 'the context manager'):
501            with mycontext():
502                pass
503
504
505    def test_typo_exit(self):
506        class mycontext(ContextDecorator):
507            def __enter__(self):
508                pass
509            def __uxit__(self, *exc):
510                pass
511
512        with self.assertRaisesRegex(TypeError, 'the context manager.*__exit__'):
513            with mycontext():
514                pass
515
516
517    def test_contextdecorator_as_mixin(self):
518        class somecontext(object):
519            started = False
520            exc = None
521
522            def __enter__(self):
523                self.started = True
524                return self
525
526            def __exit__(self, *exc):
527                self.exc = exc
528
529        class mycontext(somecontext, ContextDecorator):
530            pass
531
532        context = mycontext()
533        @context
534        def test():
535            self.assertIsNone(context.exc)
536            self.assertTrue(context.started)
537        test()
538        self.assertEqual(context.exc, (None, None, None))
539
540
541    def test_contextmanager_as_decorator(self):
542        @contextmanager
543        def woohoo(y):
544            state.append(y)
545            yield
546            state.append(999)
547
548        state = []
549        @woohoo(1)
550        def test(x):
551            self.assertEqual(state, [1])
552            state.append(x)
553        test('something')
554        self.assertEqual(state, [1, 'something', 999])
555
556        # Issue #11647: Ensure the decorated function is 'reusable'
557        state = []
558        test('something else')
559        self.assertEqual(state, [1, 'something else', 999])
560
561
562class TestBaseExitStack:
563    exit_stack = None
564
565    @support.requires_docstrings
566    def test_instance_docs(self):
567        # Issue 19330: ensure context manager instances have good docstrings
568        cm_docstring = self.exit_stack.__doc__
569        obj = self.exit_stack()
570        self.assertEqual(obj.__doc__, cm_docstring)
571
572    def test_no_resources(self):
573        with self.exit_stack():
574            pass
575
576    def test_callback(self):
577        expected = [
578            ((), {}),
579            ((1,), {}),
580            ((1,2), {}),
581            ((), dict(example=1)),
582            ((1,), dict(example=1)),
583            ((1,2), dict(example=1)),
584            ((1,2), dict(self=3, callback=4)),
585        ]
586        result = []
587        def _exit(*args, **kwds):
588            """Test metadata propagation"""
589            result.append((args, kwds))
590        with self.exit_stack() as stack:
591            for args, kwds in reversed(expected):
592                if args and kwds:
593                    f = stack.callback(_exit, *args, **kwds)
594                elif args:
595                    f = stack.callback(_exit, *args)
596                elif kwds:
597                    f = stack.callback(_exit, **kwds)
598                else:
599                    f = stack.callback(_exit)
600                self.assertIs(f, _exit)
601            for wrapper in stack._exit_callbacks:
602                self.assertIs(wrapper[1].__wrapped__, _exit)
603                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
604                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
605        self.assertEqual(result, expected)
606
607        result = []
608        with self.exit_stack() as stack:
609            with self.assertRaises(TypeError):
610                stack.callback(arg=1)
611            with self.assertRaises(TypeError):
612                self.exit_stack.callback(arg=2)
613            with self.assertRaises(TypeError):
614                stack.callback(callback=_exit, arg=3)
615        self.assertEqual(result, [])
616
617    def test_push(self):
618        exc_raised = ZeroDivisionError
619        def _expect_exc(exc_type, exc, exc_tb):
620            self.assertIs(exc_type, exc_raised)
621        def _suppress_exc(*exc_details):
622            return True
623        def _expect_ok(exc_type, exc, exc_tb):
624            self.assertIsNone(exc_type)
625            self.assertIsNone(exc)
626            self.assertIsNone(exc_tb)
627        class ExitCM(object):
628            def __init__(self, check_exc):
629                self.check_exc = check_exc
630            def __enter__(self):
631                self.fail("Should not be called!")
632            def __exit__(self, *exc_details):
633                self.check_exc(*exc_details)
634        with self.exit_stack() as stack:
635            stack.push(_expect_ok)
636            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
637            cm = ExitCM(_expect_ok)
638            stack.push(cm)
639            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
640            stack.push(_suppress_exc)
641            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
642            cm = ExitCM(_expect_exc)
643            stack.push(cm)
644            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
645            stack.push(_expect_exc)
646            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
647            stack.push(_expect_exc)
648            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
649            1/0
650
651    def test_enter_context(self):
652        class TestCM(object):
653            def __enter__(self):
654                result.append(1)
655            def __exit__(self, *exc_details):
656                result.append(3)
657
658        result = []
659        cm = TestCM()
660        with self.exit_stack() as stack:
661            @stack.callback  # Registered first => cleaned up last
662            def _exit():
663                result.append(4)
664            self.assertIsNotNone(_exit)
665            stack.enter_context(cm)
666            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
667            result.append(2)
668        self.assertEqual(result, [1, 2, 3, 4])
669
670    def test_enter_context_errors(self):
671        class LacksEnterAndExit:
672            pass
673        class LacksEnter:
674            def __exit__(self, *exc_info):
675                pass
676        class LacksExit:
677            def __enter__(self):
678                pass
679
680        with self.exit_stack() as stack:
681            with self.assertRaisesRegex(TypeError, 'the context manager'):
682                stack.enter_context(LacksEnterAndExit())
683            with self.assertRaisesRegex(TypeError, 'the context manager'):
684                stack.enter_context(LacksEnter())
685            with self.assertRaisesRegex(TypeError, 'the context manager'):
686                stack.enter_context(LacksExit())
687            self.assertFalse(stack._exit_callbacks)
688
689    def test_close(self):
690        result = []
691        with self.exit_stack() as stack:
692            @stack.callback
693            def _exit():
694                result.append(1)
695            self.assertIsNotNone(_exit)
696            stack.close()
697            result.append(2)
698        self.assertEqual(result, [1, 2])
699
700    def test_pop_all(self):
701        result = []
702        with self.exit_stack() as stack:
703            @stack.callback
704            def _exit():
705                result.append(3)
706            self.assertIsNotNone(_exit)
707            new_stack = stack.pop_all()
708            result.append(1)
709        result.append(2)
710        new_stack.close()
711        self.assertEqual(result, [1, 2, 3])
712
713    def test_exit_raise(self):
714        with self.assertRaises(ZeroDivisionError):
715            with self.exit_stack() as stack:
716                stack.push(lambda *exc: False)
717                1/0
718
719    def test_exit_suppress(self):
720        with self.exit_stack() as stack:
721            stack.push(lambda *exc: True)
722            1/0
723
    def test_exit_exception_chaining_reference(self):
        # Sanity check to make sure that ExitStack chaining matches
        # actual nested with statements
        # Reference scenario, built from plain nested `with` statements:
        # ZeroDivisionError -> ValueError (suppressed) -> AttributeError/
        # KeyError -> IndexError, recording the resulting __context__ chain.

        # Unconditionally raises its configured exception on exit.
        class RaiseExc:
            def __init__(self, exc):
                self.exc = exc
            def __enter__(self):
                return self
            def __exit__(self, *exc_details):
                raise self.exc

        # Raises *outer* while handling *inner*, so the outer exception's
        # __context__ is the inner one.
        class RaiseExcWithContext:
            def __init__(self, outer, inner):
                self.outer = outer
                self.inner = inner
            def __enter__(self):
                return self
            def __exit__(self, *exc_details):
                try:
                    raise self.inner
                except:
                    raise self.outer

        # Records the details it received on the class, then suppresses them.
        class SuppressExc:
            def __enter__(self):
                return self
            def __exit__(self, *exc_details):
                type(self).saved_details = exc_details
                return True

        try:
            with RaiseExc(IndexError):
                with RaiseExcWithContext(KeyError, AttributeError):
                    with SuppressExc():
                        with RaiseExc(ValueError):
                            1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = SuppressExc.saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
771
    def test_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        # (the scenario of test_exit_exception_chaining_reference, rebuilt
        # from ExitStack callbacks instead of nested `with` statements).
        def raise_exc(exc):
            raise exc

        saved_details = None
        # Records what reached it, then suppresses that exception.
        def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            with self.exit_stack() as stack:
                # Callbacks unwind LIFO: ValueError replaces the
                # ZeroDivisionError, suppress_exc swallows it, then
                # AttributeError, KeyError and IndexError are raised in turn.
                stack.callback(raise_exc, IndexError)
                stack.callback(raise_exc, KeyError)
                stack.callback(raise_exc, AttributeError)
                stack.push(suppress_exc)
                stack.callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
802
803    def test_exit_exception_explicit_none_context(self):
804        # Ensure ExitStack chaining matches actual nested `with` statements
805        # regarding explicit __context__ = None.
806
807        class MyException(Exception):
808            pass
809
810        @contextmanager
811        def my_cm():
812            try:
813                yield
814            except BaseException:
815                exc = MyException()
816                try:
817                    raise exc
818                finally:
819                    exc.__context__ = None
820
821        @contextmanager
822        def my_cm_with_exit_stack():
823            with self.exit_stack() as stack:
824                stack.enter_context(my_cm())
825                yield stack
826
827        for cm in (my_cm, my_cm_with_exit_stack):
828            with self.subTest():
829                try:
830                    with cm():
831                        raise IndexError()
832                except MyException as exc:
833                    self.assertIsNone(exc.__context__)
834                else:
835                    self.fail("Expected IndexError, but no exception was raised")
836
    def test_exit_exception_non_suppressing(self):
        # http://bugs.python.org/issue19092
        # An exception raised by an exit callback must keep propagating
        # through the callbacks that unwind after it.
        def raise_exc(exc):
            raise exc

        def suppress_exc(*exc_details):
            return True

        # IndexError is raised by the innermost (last-registered) callback;
        # the outer callback returning None must not swallow it.
        try:
            with self.exit_stack() as stack:
                stack.callback(lambda: None)
                stack.callback(raise_exc, IndexError)
        except Exception as exc:
            self.assertIsInstance(exc, IndexError)
        else:
            self.fail("Expected IndexError, but no exception was raised")

        # suppress_exc swallows the IndexError, but the KeyError raised by
        # the outermost callback afterwards must still escape.
        try:
            with self.exit_stack() as stack:
                stack.callback(raise_exc, KeyError)
                stack.push(suppress_exc)
                stack.callback(raise_exc, IndexError)
        except Exception as exc:
            self.assertIsInstance(exc, KeyError)
        else:
            self.fail("Expected KeyError, but no exception was raised")
863
864    def test_exit_exception_with_correct_context(self):
865        # http://bugs.python.org/issue20317
866        @contextmanager
867        def gets_the_context_right(exc):
868            try:
869                yield
870            finally:
871                raise exc
872
873        exc1 = Exception(1)
874        exc2 = Exception(2)
875        exc3 = Exception(3)
876        exc4 = Exception(4)
877
878        # The contextmanager already fixes the context, so prior to the
879        # fix, ExitStack would try to fix it *again* and get into an
880        # infinite self-referential loop
881        try:
882            with self.exit_stack() as stack:
883                stack.enter_context(gets_the_context_right(exc4))
884                stack.enter_context(gets_the_context_right(exc3))
885                stack.enter_context(gets_the_context_right(exc2))
886                raise exc1
887        except Exception as exc:
888            self.assertIs(exc, exc4)
889            self.assertIs(exc.__context__, exc3)
890            self.assertIs(exc.__context__.__context__, exc2)
891            self.assertIs(exc.__context__.__context__.__context__, exc1)
892            self.assertIsNone(
893                       exc.__context__.__context__.__context__.__context__)
894
895    def test_exit_exception_with_existing_context(self):
896        # Addresses a lack of test coverage discovered after checking in a
897        # fix for issue 20317 that still contained debugging code.
898        def raise_nested(inner_exc, outer_exc):
899            try:
900                raise inner_exc
901            finally:
902                raise outer_exc
903        exc1 = Exception(1)
904        exc2 = Exception(2)
905        exc3 = Exception(3)
906        exc4 = Exception(4)
907        exc5 = Exception(5)
908        try:
909            with self.exit_stack() as stack:
910                stack.callback(raise_nested, exc4, exc5)
911                stack.callback(raise_nested, exc2, exc3)
912                raise exc1
913        except Exception as exc:
914            self.assertIs(exc, exc5)
915            self.assertIs(exc.__context__, exc4)
916            self.assertIs(exc.__context__.__context__, exc3)
917            self.assertIs(exc.__context__.__context__.__context__, exc2)
918            self.assertIs(
919                 exc.__context__.__context__.__context__.__context__, exc1)
920            self.assertIsNone(
921                exc.__context__.__context__.__context__.__context__.__context__)
922
923    def test_body_exception_suppress(self):
924        def suppress_exc(*exc_details):
925            return True
926        try:
927            with self.exit_stack() as stack:
928                stack.push(suppress_exc)
929                1/0
930        except IndexError as exc:
931            self.fail("Expected no exception, got IndexError")
932
    def test_exit_exception_chaining_suppress(self):
        # Exceptions raised by exit callbacks can themselves be suppressed by
        # an outer callback. Unwinding LIFO: {}[1] raises KeyError, 1/0 then
        # raises ZeroDivisionError, and the first-pushed callback returns
        # True, so nothing escapes the with statement.
        with self.exit_stack() as stack:
            stack.push(lambda *exc: True)
            stack.push(lambda *exc: 1/0)
            stack.push(lambda *exc: {}[1])
938
939    def test_excessive_nesting(self):
940        # The original implementation would die with RecursionError here
941        with self.exit_stack() as stack:
942            for i in range(10000):
943                stack.callback(int)
944
945    def test_instance_bypass(self):
946        class Example(object): pass
947        cm = Example()
948        cm.__enter__ = object()
949        cm.__exit__ = object()
950        stack = self.exit_stack()
951        with self.assertRaisesRegex(TypeError, 'the context manager'):
952            stack.enter_context(cm)
953        stack.push(cm)
954        self.assertIs(stack._exit_callbacks[-1][1], cm)
955
956    def test_dont_reraise_RuntimeError(self):
957        # https://bugs.python.org/issue27122
958        class UniqueException(Exception): pass
959        class UniqueRuntimeError(RuntimeError): pass
960
961        @contextmanager
962        def second():
963            try:
964                yield 1
965            except Exception as exc:
966                raise UniqueException("new exception") from exc
967
968        @contextmanager
969        def first():
970            try:
971                yield 1
972            except Exception as exc:
973                raise exc
974
975        # The UniqueRuntimeError should be caught by second()'s exception
976        # handler which chain raised a new UniqueException.
977        with self.assertRaises(UniqueException) as err_ctx:
978            with self.exit_stack() as es_ctx:
979                es_ctx.enter_context(second())
980                es_ctx.enter_context(first())
981                raise UniqueRuntimeError("please no infinite loop.")
982
983        exc = err_ctx.exception
984        self.assertIsInstance(exc, UniqueException)
985        self.assertIsInstance(exc.__context__, UniqueRuntimeError)
986        self.assertIsNone(exc.__context__.__context__)
987        self.assertIsNone(exc.__context__.__cause__)
988        self.assertIs(exc.__cause__, exc.__context__)
989
990
class TestExitStack(TestBaseExitStack, unittest.TestCase):
    """Run the shared TestBaseExitStack tests against contextlib.ExitStack."""
    # The shared tests create their stack via self.exit_stack().
    exit_stack = ExitStack
993
994
995class TestRedirectStream:
996
997    redirect_stream = None
998    orig_stream = None
999
1000    @support.requires_docstrings
1001    def test_instance_docs(self):
1002        # Issue 19330: ensure context manager instances have good docstrings
1003        cm_docstring = self.redirect_stream.__doc__
1004        obj = self.redirect_stream(None)
1005        self.assertEqual(obj.__doc__, cm_docstring)
1006
1007    def test_no_redirect_in_init(self):
1008        orig_stdout = getattr(sys, self.orig_stream)
1009        self.redirect_stream(None)
1010        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1011
1012    def test_redirect_to_string_io(self):
1013        f = io.StringIO()
1014        msg = "Consider an API like help(), which prints directly to stdout"
1015        orig_stdout = getattr(sys, self.orig_stream)
1016        with self.redirect_stream(f):
1017            print(msg, file=getattr(sys, self.orig_stream))
1018        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1019        s = f.getvalue().strip()
1020        self.assertEqual(s, msg)
1021
1022    def test_enter_result_is_target(self):
1023        f = io.StringIO()
1024        with self.redirect_stream(f) as enter_result:
1025            self.assertIs(enter_result, f)
1026
1027    def test_cm_is_reusable(self):
1028        f = io.StringIO()
1029        write_to_f = self.redirect_stream(f)
1030        orig_stdout = getattr(sys, self.orig_stream)
1031        with write_to_f:
1032            print("Hello", end=" ", file=getattr(sys, self.orig_stream))
1033        with write_to_f:
1034            print("World!", file=getattr(sys, self.orig_stream))
1035        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1036        s = f.getvalue()
1037        self.assertEqual(s, "Hello World!\n")
1038
1039    def test_cm_is_reentrant(self):
1040        f = io.StringIO()
1041        write_to_f = self.redirect_stream(f)
1042        orig_stdout = getattr(sys, self.orig_stream)
1043        with write_to_f:
1044            print("Hello", end=" ", file=getattr(sys, self.orig_stream))
1045            with write_to_f:
1046                print("World!", file=getattr(sys, self.orig_stream))
1047        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1048        s = f.getvalue()
1049        self.assertEqual(s, "Hello World!\n")
1050
1051
class TestRedirectStdout(TestRedirectStream, unittest.TestCase):
    """Run the TestRedirectStream tests against contextlib.redirect_stdout."""

    redirect_stream = redirect_stdout
    # Name of the sys attribute the mixin's tests inspect (sys.stdout).
    orig_stream = "stdout"
1056
1057
class TestRedirectStderr(TestRedirectStream, unittest.TestCase):
    """Run the TestRedirectStream tests against contextlib.redirect_stderr."""

    redirect_stream = redirect_stderr
    # Name of the sys attribute the mixin's tests inspect (sys.stderr).
    orig_stream = "stderr"
1062
1063
class TestSuppress(unittest.TestCase):
    """Tests for the contextlib.suppress() context manager."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        cm_docstring = suppress.__doc__
        obj = suppress()
        self.assertEqual(obj.__doc__, cm_docstring)

    def test_no_result_from_enter(self):
        with suppress(ValueError) as enter_result:
            self.assertIsNone(enter_result)

    def test_no_exception(self):
        with suppress(ValueError):
            self.assertEqual(pow(2, 5), 32)

    def test_exact_exception(self):
        with suppress(TypeError):
            len(5)

    def test_exception_hierarchy(self):
        # IndexError is a subclass of LookupError, so it must be suppressed.
        with suppress(LookupError):
            'Hello'[50]

    def test_other_exception(self):
        with self.assertRaises(ZeroDivisionError):
            with suppress(TypeError):
                1/0

    def test_no_args(self):
        # With no exception types given, nothing is suppressed.
        with self.assertRaises(ZeroDivisionError):
            with suppress():
                1/0

    def test_multiple_exception_args(self):
        with suppress(ZeroDivisionError, TypeError):
            1/0
        with suppress(ZeroDivisionError, TypeError):
            len(5)

    def test_cm_is_reentrant(self):
        ignore_exceptions = suppress(Exception)
        with ignore_exceptions:
            pass
        with ignore_exceptions:
            len(5)
        # Pre-bind the flag: if the nested block below leaked its TypeError,
        # the outer suppress() would swallow it before the assignment runs,
        # and the final assertion would otherwise die with UnboundLocalError
        # instead of reporting a clean test failure.
        outer_continued = False
        with ignore_exceptions:
            with ignore_exceptions: # Check nested usage
                len(5)
            outer_continued = True
            1/0
        self.assertTrue(outer_continued)
1117
1118
1119class TestChdir(unittest.TestCase):
1120    def test_simple(self):
1121        old_cwd = os.getcwd()
1122        target = os.path.join(os.path.dirname(__file__), 'data')
1123        self.assertNotEqual(old_cwd, target)
1124
1125        with chdir(target):
1126            self.assertEqual(os.getcwd(), target)
1127        self.assertEqual(os.getcwd(), old_cwd)
1128
1129    def test_reentrant(self):
1130        old_cwd = os.getcwd()
1131        target1 = os.path.join(os.path.dirname(__file__), 'data')
1132        target2 = os.path.join(os.path.dirname(__file__), 'ziptestdata')
1133        self.assertNotIn(old_cwd, (target1, target2))
1134        chdir1, chdir2 = chdir(target1), chdir(target2)
1135
1136        with chdir1:
1137            self.assertEqual(os.getcwd(), target1)
1138            with chdir2:
1139                self.assertEqual(os.getcwd(), target2)
1140                with chdir1:
1141                    self.assertEqual(os.getcwd(), target1)
1142                self.assertEqual(os.getcwd(), target2)
1143            self.assertEqual(os.getcwd(), target1)
1144        self.assertEqual(os.getcwd(), old_cwd)
1145
1146    def test_exception(self):
1147        old_cwd = os.getcwd()
1148        target = os.path.join(os.path.dirname(__file__), 'data')
1149        self.assertNotEqual(old_cwd, target)
1150
1151        try:
1152            with chdir(target):
1153                self.assertEqual(os.getcwd(), target)
1154                raise RuntimeError("boom")
1155        except RuntimeError as re:
1156            self.assertEqual(str(re), "boom")
1157        self.assertEqual(os.getcwd(), old_cwd)
1158
1159
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
1162