1import asyncio
2from contextlib import (
3    asynccontextmanager, AbstractAsyncContextManager,
4    AsyncExitStack, nullcontext, aclosing, contextmanager)
5import functools
6from test import support
7import unittest
8
9from test.test_contextlib import TestBaseExitStack
10
11
12def _async_test(func):
13    """Decorator to turn an async function into a test case."""
14    @functools.wraps(func)
15    def wrapper(*args, **kwargs):
16        coro = func(*args, **kwargs)
17        loop = asyncio.new_event_loop()
18        asyncio.set_event_loop(loop)
19        try:
20            return loop.run_until_complete(coro)
21        finally:
22            loop.close()
23            asyncio.set_event_loop_policy(None)
24    return wrapper
25
26
class TestAbstractAsyncContextManager(unittest.TestCase):
    """Behaviour of the ``contextlib.AbstractAsyncContextManager`` ABC."""

    @_async_test
    async def test_enter(self):
        # A subclass that only defines __aexit__ inherits an __aenter__
        # that hands back the manager itself.
        class OnlyAexit(AbstractAsyncContextManager):
            async def __aexit__(self, *exc_info):
                await super().__aexit__(*exc_info)

        cm = OnlyAexit()
        entered = await cm.__aenter__()
        self.assertIs(entered, cm)

        async with cm as bound:
            self.assertIs(cm, bound)

    @_async_test
    async def test_async_gen_propagates_generator_exit(self):
        # Regression test for https://bugs.python.org/issue33786.

        @asynccontextmanager
        async def cm():
            yield

        async def agen():
            async with cm():
                yield 11

        seen = []
        error = ValueError(22)
        with self.assertRaises(ValueError):
            async with cm():
                async for item in agen():
                    seen.append(item)
                    raise error

        self.assertEqual(seen, [11])

    def test_exit_is_abstract(self):
        # Without an __aexit__ implementation the ABC cannot be instantiated.
        class NoAexit(AbstractAsyncContextManager):
            pass

        self.assertRaises(TypeError, NoAexit)

    def test_structural_subclassing(self):
        class Scratch:
            async def __aenter__(self):
                return self
            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        class OnlyAexit(AbstractAsyncContextManager):
            async def __aexit__(self, *exc_info):
                await super().__aexit__(*exc_info)

        class AenterBlocked(Scratch):
            __aenter__ = None

        class AexitBlocked(Scratch):
            __aexit__ = None

        # Defining both methods (directly or via the ABC) makes a class a
        # virtual subclass of the ABC ...
        for cls in (Scratch, OnlyAexit):
            self.assertTrue(issubclass(cls, AbstractAsyncContextManager))

        # ... while setting either method to None opts out again.
        for cls in (AenterBlocked, AexitBlocked):
            self.assertFalse(issubclass(cls, AbstractAsyncContextManager))
94
95
class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for ``@asynccontextmanager``.

    Covers the enter/exit protocol, how exceptions from the ``async with``
    body interact with the generator, propagation of attributes and
    docstrings, and use of the resulting object as a decorator.
    """

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before the yield runs on entry, code after it runs on a
        # normal exit, and the yielded value is bound by ``as``.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # An exception in the body is thrown into the generator at the
        # yield, so its ``finally`` clause runs and the exception propagates.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        # If the generator lets the passed-in exception escape, __aexit__
        # reports "not suppressed" (a false value) rather than raising it.
        @asynccontextmanager
        async def whee():
            yield
        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after the exception is thrown into
        # it is misbehaving: __aexit__ must raise RuntimeError.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        # A generator that never yields cannot be entered: __aenter__
        # raises RuntimeError.
        @asynccontextmanager
        async def whoo():
            if False:
                yield
        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        # A generator that yields a second time on a normal exit makes
        # __aexit__ raise RuntimeError.
        @asynccontextmanager
        async def whoo():
            yield
            yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        # __aexit__ receives an exception type with value None (a
        # "non-normalised" exception); the generator still catches it as
        # RuntimeError, and the SyntaxError it raises instead propagates.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # The generator catches the body's exception, so the exception is
        # suppressed and the ``async with`` statement completes normally.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration / StopAsyncIteration (and their subclasses) raised
        # in the body must propagate as the very same object instead of
        # being mistaken for the generator finishing.
        @asynccontextmanager
        async def woohoo():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        class StopAsyncIterationSubclass(StopAsyncIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopAsyncIteration('ham'),
            StopIterationSubclass('spam'),
            StopAsyncIterationSubclass('spam')
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        # A generator that converts the body's exception into RuntimeError
        # has that RuntimeError propagate.
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Helper: build an @asynccontextmanager factory whose underlying
        # function carries a custom attribute (foo='bar') and a docstring,
        # for the attribute/docstring propagation tests.
        def attribs(**kw):
            # Decorator factory that sets each keyword as an attribute on
            # the decorated function.
            def decorate(func):
                for k,v in kw.items():
                    setattr(func,k,v)
                return func
            return decorate
        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        # The factory keeps the wrapped function's __name__ and attributes.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__,'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        # The factory also keeps the wrapped function's docstring.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        # A context manager instance exposes the generator function's
        # docstring too.
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))

    @_async_test
    async def test_recursive(self):
        # A context manager object used as a decorator can be re-entered
        # recursively.  ncols counts invocations of the generator function
        # (one per decorated call); depth must be restored on every exit.
        depth = 0
        ncols = 0

        @asynccontextmanager
        async def woohoo():
            nonlocal ncols
            ncols += 1

            nonlocal depth
            before = depth
            depth += 1
            yield
            depth -= 1
            self.assertEqual(depth, before)

        @woohoo()
        async def recursive():
            if depth < 10:
                await recursive()

        await recursive()

        self.assertEqual(ncols, 10)
        self.assertEqual(depth, 0)

    @_async_test
    async def test_decorator(self):
        # As a decorator, the context is active only while the decorated
        # coroutine runs.
        entered = False

        @asynccontextmanager
        async def context():
            nonlocal entered
            entered = True
            yield
            entered = False

        @context()
        async def test():
            self.assertTrue(entered)

        self.assertFalse(entered)
        await test()
        self.assertFalse(entered)

    @_async_test
    async def test_decorator_with_exception(self):
        # If the decorated coroutine raises, the context is still exited
        # (its ``finally`` runs) and the exception propagates to the caller.
        entered = False

        @asynccontextmanager
        async def context():
            nonlocal entered
            try:
                entered = True
                yield
            finally:
                entered = False

        @context()
        async def test():
            self.assertTrue(entered)
            raise NameError('foo')

        self.assertFalse(entered)
        with self.assertRaisesRegex(NameError, 'foo'):
            await test()
        self.assertFalse(entered)

    @_async_test
    async def test_decorating_method(self):
        # Decorating a method must forward self plus positional, default
        # and keyword arguments unchanged.

        @asynccontextmanager
        async def context():
            yield


        class Test(object):

            @context()
            async def method(self, a, b, c=None):
                self.a = a
                self.b = b
                self.c = c

        # these tests are for argument passing when used as a decorator
        test = Test()
        await test.method(1, 2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        self.assertEqual(test.c, None)

        test = Test()
        await test.method('a', 'b', 'c')
        self.assertEqual(test.a, 'a')
        self.assertEqual(test.b, 'b')
        self.assertEqual(test.c, 'c')

        test = Test()
        await test.method(a=1, b=2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
396
397
class AclosingTestCase(unittest.TestCase):
    """Tests for ``contextlib.aclosing``."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Instances expose the class docstring.
        self.assertEqual(aclosing(None).__doc__, aclosing.__doc__)

    @_async_test
    async def test_aclosing(self):
        # On a normal exit, aclose() is awaited exactly once and the
        # wrapped object itself is bound by ``as``.
        calls = []

        class Closeable:
            async def aclose(self):
                calls.append(1)

        resource = Closeable()
        self.assertEqual(calls, [])
        async with aclosing(resource) as bound:
            self.assertEqual(resource, bound)
        self.assertEqual(calls, [1])

    @_async_test
    async def test_aclosing_error(self):
        # aclose() is still awaited when the body raises, and the
        # exception propagates.
        calls = []

        class Closeable:
            async def aclose(self):
                calls.append(1)

        resource = Closeable()
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(resource) as bound:
                self.assertEqual(resource, bound)
                1 / 0
        self.assertEqual(calls, [1])

    @_async_test
    async def test_aclosing_bpo41229(self):
        # Regression test for bpo-41229: closing a partially consumed async
        # generator via aclosing() must exit the synchronous ``with`` block
        # the generator is suspended in.
        calls = []

        @contextmanager
        def sync_resource():
            try:
                yield
            finally:
                calls.append(1)

        async def agenfunc():
            with sync_resource():
                yield -1
                yield -2

        agen = agenfunc()
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(agen) as bound:
                self.assertEqual(agen, bound)
                self.assertEqual(-1, await agen.__anext__())
                1 / 0
        self.assertEqual(calls, [1])
456
457
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    """Run the shared ExitStack tests against AsyncExitStack, plus async-only tests.

    ``SyncAsyncExitStack`` wraps ``AsyncExitStack`` with synchronous
    ``close``/``__enter__``/``__exit__`` methods so that it can be used via
    ``self.exit_stack`` with the plain ``with``/``close()`` API (presumably
    how the inherited ``TestBaseExitStack`` tests drive it).
    """

    class SyncAsyncExitStack(AsyncExitStack):
        @staticmethod
        def run_coroutine(coro):
            # Run *coro* as a task on the current event loop until it is
            # done, then return its result or re-raise its exception.
            loop = asyncio.get_event_loop_policy().get_event_loop()
            t = loop.create_task(coro)
            t.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = t.exception()
            if exc is None:
                return t.result()
            else:
                context = exc.__context__

                try:
                    # Raising here may overwrite exc.__context__ with an
                    # exception the caller is currently handling (e.g. when
                    # called from __exit__).
                    raise exc
                except:
                    # Inside this handler the handled exception is exc
                    # itself, so restoring the saved context and re-raising
                    # keeps the context the coroutine actually produced.
                    exc.__context__ = context
                    raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack

    def setUp(self):
        # Provide a current event loop for SyncAsyncExitStack.run_coroutine.
        # Cleanups run LIFO: reset the policy first, then close the loop.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        self.addCleanup(asyncio.set_event_loop_policy, None)

    @_async_test
    async def test_async_callback(self):
        # push_async_callback returns the callback unchanged, runs the
        # callbacks in reverse registration order with their arguments, and
        # stores a wrapper that exposes the callback via __wrapped__ but
        # does not copy its name or docstring.
        expected = [
            ((), {}),
            ((1,), {}),
            ((1,2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1,2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)

        self.assertEqual(result, expected)

        # The callback must be passed positionally; a missing callback or a
        # ``callback=`` keyword is a TypeError and registers nothing.
        result = []
        async with AsyncExitStack() as stack:
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            with self.assertRaises(TypeError):
                stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [])

    @_async_test
    async def test_async_push(self):
        # push_async_exit accepts both plain async exit callables and async
        # context manager objects (registering their bound __aexit__); the
        # callbacks run newest-first, and once _suppress_exc returns True
        # the older callbacks see no exception.
        exc_raised = ZeroDivisionError
        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                # ExitCM is not a TestCase, so self.fail() is unavailable
                # here; raise the assertion error directly.
                raise AssertionError("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        async with self.exit_stack() as stack:
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_enter_async_context(self):
        # enter_async_context awaits __aenter__ immediately and schedules
        # __aexit__ to run before callbacks registered earlier.
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_enter_async_context_errors(self):
        # Objects missing __aenter__ and/or __aexit__ are rejected with a
        # TypeError mentioning "asynchronous context manager", and nothing
        # is pushed onto the stack.
        class LacksEnterAndExit:
            pass
        class LacksEnter:
            async def __aexit__(self, *exc_info):
                pass
        class LacksExit:
            async def __aenter__(self):
                pass

        async with self.exit_stack() as stack:
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksEnterAndExit())
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksEnter())
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksExit())
            self.assertFalse(stack._exit_callbacks)

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)

    @_async_test
    async def test_async_exit_exception_explicit_none_context(self):
        # Ensure AsyncExitStack chaining matches actual nested `with` statements
        # regarding explicit __context__ = None.

        class MyException(Exception):
            pass

        @asynccontextmanager
        async def my_cm():
            try:
                yield
            except BaseException:
                exc = MyException()
                try:
                    raise exc
                finally:
                    exc.__context__ = None

        @asynccontextmanager
        async def my_cm_with_exit_stack():
            async with self.exit_stack() as stack:
                await stack.enter_async_context(my_cm())
                yield stack

        for cm in (my_cm, my_cm_with_exit_stack):
            # Label each subtest so a failure identifies the variant.
            with self.subTest(cm=cm.__name__):
                try:
                    async with cm():
                        raise IndexError()
                except MyException as exc:
                    self.assertIsNone(exc.__context__)
                else:
                    self.fail("Expected MyException, but no exception was raised")

    @_async_test
    async def test_instance_bypass_async(self):
        # __aenter__/__aexit__ set on the instance (not the class) are
        # ignored by enter_async_context, which raises TypeError; with no
        # class-level __aexit__, push_async_exit stores the object itself
        # as the exit callback.
        class Example(object): pass
        cm = Example()
        cm.__aenter__ = object()
        cm.__aexit__ = object()
        stack = self.exit_stack()
        with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
            await stack.enter_async_context(cm)
        stack.push_async_exit(cm)
        self.assertIs(stack._exit_callbacks[-1][1], cm)
695
696
class TestAsyncNullcontext(unittest.TestCase):
    """``contextlib.nullcontext`` used with ``async with``."""

    @_async_test
    async def test_async_nullcontext(self):
        # The object passed to nullcontext is what ``as`` binds.
        class Sentinel:
            pass

        target = Sentinel()
        async with nullcontext(target) as entered:
            self.assertIs(entered, target)
705
706
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
709