1import asyncio
2from contextlib import (
3    asynccontextmanager, AbstractAsyncContextManager,
4    AsyncExitStack, nullcontext, aclosing, contextmanager)
5import functools
6from test import support
7import unittest
8
9from test.test_contextlib import TestBaseExitStack
10
11
12def _async_test(func):
13    """Decorator to turn an async function into a test case."""
14    @functools.wraps(func)
15    def wrapper(*args, **kwargs):
16        coro = func(*args, **kwargs)
17        loop = asyncio.new_event_loop()
18        asyncio.set_event_loop(loop)
19        try:
20            return loop.run_until_complete(coro)
21        finally:
22            loop.close()
23            asyncio.set_event_loop_policy(None)
24    return wrapper
25
26
class TestAbstractAsyncContextManager(unittest.TestCase):
    # Covers the AbstractAsyncContextManager ABC: the __aenter__ it supplies,
    # the abstract __aexit__, and issubclass() checks against classes that do
    # not inherit from it.

    @_async_test
    async def test_enter(self):
        # Only __aexit__ is overridden here, so __aenter__ is the one the
        # ABC provides; it is expected to hand back the manager itself.
        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        cm = DefaultEnter()
        entered = await cm.__aenter__()
        self.assertIs(entered, cm)

        async with cm as bound:
            self.assertIs(cm, bound)

    @_async_test
    async def test_async_gen_propagates_generator_exit(self):
        # A regression test for https://bugs.python.org/issue33786.

        @asynccontextmanager
        async def passthrough():
            yield

        async def one_item():
            async with passthrough():
                yield 11

        seen = []
        error = ValueError(22)
        with self.assertRaises(ValueError):
            async with passthrough():
                async for item in one_item():
                    seen.append(item)
                    # Abandon the async generator part-way through.
                    raise error

        self.assertEqual(seen, [11])

    def test_exit_is_abstract(self):
        # A subclass that does not define __aexit__ must not be instantiable.
        class MissingAexit(AbstractAsyncContextManager):
            pass

        self.assertRaises(TypeError, MissingAexit)

    def test_structural_subclassing(self):
        def is_async_cm(cls):
            return issubclass(cls, AbstractAsyncContextManager)

        # Defines both protocol methods without inheriting from the ABC.
        class ManagerFromScratch:
            async def __aenter__(self):
                return self
            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        self.assertTrue(is_async_cm(ManagerFromScratch))

        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        self.assertTrue(is_async_cm(DefaultEnter))

        # Setting either protocol method to None opts the class out.
        class NoneAenter(ManagerFromScratch):
            __aenter__ = None

        self.assertFalse(is_async_cm(NoneAenter))

        class NoneAexit(ManagerFromScratch):
            __aexit__ = None

        self.assertFalse(is_async_cm(NoneAexit))
94
95
class AsyncContextManagerTestCase(unittest.TestCase):
    # Tests for generator-based async context managers created with
    # @asynccontextmanager: normal entry/exit, exception propagation, misuse
    # of the generator protocol, metadata propagation, and use of a manager
    # instance as a decorator.
    #
    # Comments are used instead of docstrings on the test methods so that the
    # descriptions unittest prints for them are unaffected.

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # ``as``, and code after the yield runs on a clean exit.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # When the body raises, the generator's ``finally`` clause still runs
        # and the original exception propagates out of the ``async with``.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def whee():
            yield
        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        # The generator catches the thrown exception and yields a second
        # time instead of finishing; __aexit__ must report a RuntimeError.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        # The unreachable ``yield`` makes this an async generator function
        # that never actually yields; __aenter__ must report a RuntimeError.
        @asynccontextmanager
        async def whoo():
            if False:
                yield
        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        # A generator that yields again on a clean exit must make __aexit__
        # report a RuntimeError.
        @asynccontextmanager
        async def whoo():
            yield
            yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        # __aexit__ is called with an exception type but a value of None
        # (a "non-normalised" exception); the generator's handler must still
        # see a RuntimeError, and the SyntaxError it raises must surface.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # The generator handles the ZeroDivisionError raised by the body, so
        # nothing propagates and execution continues after the ``async with``.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration / StopAsyncIteration (and subclasses) raised by the
        # body must come back out as the very same exception object rather
        # than being suppressed or replaced.
        @asynccontextmanager
        async def woohoo():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        class StopAsyncIterationSubclass(StopAsyncIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopAsyncIteration('ham'),
            StopIterationSubclass('spam'),
            StopAsyncIterationSubclass('spam')
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        # The generator converts any exception from the body into a
        # RuntimeError chained to the original via ``from``.
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Helper: build and return ``baz``, an async generator function that
        # first gets the attribute foo='bar' attached by the inner decorator
        # and is then wrapped by @asynccontextmanager.
        def attribs(**kw):
            def decorate(func):
                for k,v in kw.items():
                    setattr(func,k,v)
                return func
            return decorate
        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        # __name__ and custom attributes of the wrapped function must be
        # visible on the object returned by @asynccontextmanager.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__,'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        # The wrapped function's docstring must be propagated as well.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        # Here ``baz`` is the context manager *instance* obtained by calling
        # the factory; it must expose the generator function's docstring too.
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))

    @_async_test
    async def test_recursive(self):
        # A single manager instance, woohoo(), is applied as a decorator to a
        # coroutine function that calls itself.  ``ncols`` counts how many
        # times the generator body starts and ``depth`` tracks nesting.
        depth = 0
        ncols = 0

        @asynccontextmanager
        async def woohoo():
            nonlocal ncols
            ncols += 1

            nonlocal depth
            before = depth
            depth += 1
            yield
            depth -= 1
            # Each exit must restore the depth recorded at its own entry.
            self.assertEqual(depth, before)

        @woohoo()
        async def recursive():
            if depth < 10:
                await recursive()

        await recursive()

        # The generator body ran once for each of the ten nested calls, and
        # every entry was matched by an exit.
        self.assertEqual(ncols, 10)
        self.assertEqual(depth, 0)
320
321
class AclosingTestCase(unittest.TestCase):
    # Tests for contextlib.aclosing: docstring exposure on instances, and
    # that the wrapped object's aclose() runs on both clean and failing exits.

    @support.requires_docstrings
    def test_instance_docs(self):
        # An instance must report the same docstring as the class itself.
        instance = aclosing(None)
        self.assertEqual(instance.__doc__, aclosing.__doc__)

    @_async_test
    async def test_aclosing(self):
        calls = []

        class Closable:
            async def aclose(self):
                calls.append(1)

        target = Closable()
        self.assertEqual(calls, [])
        async with aclosing(target) as bound:
            # ``as`` binds the wrapped object itself.
            self.assertEqual(target, bound)
        self.assertEqual(calls, [1])

    @_async_test
    async def test_aclosing_error(self):
        calls = []

        class Closable:
            async def aclose(self):
                calls.append(1)

        target = Closable()
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(target) as bound:
                self.assertEqual(target, bound)
                1 / 0
        # aclose() ran even though the body raised.
        self.assertEqual(calls, [1])

    @_async_test
    async def test_aclosing_bpo41229(self):
        finalized = []

        @contextmanager
        def sync_resource():
            try:
                yield
            finally:
                finalized.append(1)

        async def agenfunc():
            with sync_resource():
                yield -1
                yield -2

        agen = agenfunc()
        self.assertEqual(finalized, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(agen) as bound:
                self.assertEqual(agen, bound)
                first = await agen.__anext__()
                self.assertEqual(-1, first)
                1 / 0
        # Closing the partially consumed async generator must have run the
        # synchronous resource's ``finally`` clause.
        self.assertEqual(finalized, [1])
380
381
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    # Combines the tests inherited from TestBaseExitStack with tests for the
    # async-only AsyncExitStack API (push_async_callback, push_async_exit,
    # enter_async_context).  ``exit_stack`` below is set to a synchronous
    # adapter over AsyncExitStack; presumably the inherited tests use that
    # attribute to obtain the stack under test -- confirm against
    # test.test_contextlib.

    class SyncAsyncExitStack(AsyncExitStack):
        # Synchronous adapter: close/__enter__/__exit__ drive the matching
        # coroutine methods to completion via run_coroutine().  Documented
        # with comments rather than a docstring so this class's __doc__ is
        # left unchanged.

        @staticmethod
        def run_coroutine(coro):
            # Run *coro* on the current event loop until it finishes.
            # Returns its result, or re-raises the exception it raised with
            # that exception's original __context__ preserved.
            loop = asyncio.get_event_loop()

            f = asyncio.ensure_future(coro)
            # Stop the loop as soon as the future completes so that
            # run_forever() returns.
            f.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = f.exception()

            if exc is None:
                return f.result()

            context = exc.__context__

            try:
                raise exc
            except BaseException:
                # Raising here may have replaced exc.__context__ (e.g. with
                # an exception being handled by the caller); put the saved
                # value back before propagating so chaining assertions see
                # what the coroutine produced.
                exc.__context__ = context
                raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack

    def setUp(self):
        # Install a fresh current event loop for run_coroutine(); close it
        # and reset the event loop policy when the test finishes.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        self.addCleanup(asyncio.set_event_loop_policy, None)

    @_async_test
    async def test_async_callback(self):
        # Callbacks are registered in reverse so that, when they run on
        # exit, ``result`` ends up in the same order as ``expected``.
        expected = [
            ((), {}),
            ((1,), {}),
            ((1,2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1,2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                # The callback itself is returned, allowing decorator use.
                self.assertIs(f, _exit)
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                # NOTE: the second argument is only the failure message.
                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)

        self.assertEqual(result, expected)

        # Invalid call signatures must raise TypeError and register nothing.
        result = []
        async with AsyncExitStack() as stack:
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            with self.assertRaises(TypeError):
                stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [])

    @_async_test
    async def test_async_push(self):
        exc_raised = ZeroDivisionError
        # Inside ExitCM's methods ``self`` is the ExitCM instance, which has
        # no fail(); keep a separate reference to the test case for it.
        test_case = self
        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                # push_async_exit() must register only __aexit__.
                test_case.fail("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        # Exit callbacks run last-in, first-out: those pushed after
        # _suppress_exc see the ZeroDivisionError, those pushed before it
        # see a clean exit.
        async with self.exit_stack() as stack:
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_async_enter_context(self):
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)

    @_async_test
    async def test_async_exit_exception_explicit_none_context(self):
        # Ensure AsyncExitStack chaining matches actual nested `with` statements
        # regarding explicit __context__ = None.

        class MyException(Exception):
            pass

        @asynccontextmanager
        async def my_cm():
            try:
                yield
            except BaseException:
                exc = MyException()
                try:
                    raise exc
                finally:
                    # Explicitly clear the implicit chaining.
                    exc.__context__ = None

        @asynccontextmanager
        async def my_cm_with_exit_stack():
            async with self.exit_stack() as stack:
                await stack.enter_async_context(my_cm())
                yield stack

        for cm in (my_cm, my_cm_with_exit_stack):
            with self.subTest():
                try:
                    async with cm():
                        raise IndexError()
                except MyException as exc:
                    self.assertIsNone(exc.__context__)
                else:
                    self.fail("Expected MyException, but no exception was raised")
589
590
class TestAsyncNullcontext(unittest.TestCase):
    # nullcontext must also be usable in an ``async with`` statement.

    @_async_test
    async def test_async_nullcontext(self):
        class Marker:
            pass

        marker = Marker()
        async with nullcontext(marker) as entered:
            # The object passed in is bound by ``as`` unchanged.
            self.assertIs(entered, marker)
599
600
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
603