1import asyncio
2from contextlib import asynccontextmanager, AbstractAsyncContextManager, AsyncExitStack
3import functools
4from test import support
5import unittest
6
7from test.test_contextlib import TestBaseExitStack
8
9
10def _async_test(func):
11    """Decorator to turn an async function into a test case."""
12    @functools.wraps(func)
13    def wrapper(*args, **kwargs):
14        coro = func(*args, **kwargs)
15        loop = asyncio.new_event_loop()
16        asyncio.set_event_loop(loop)
17        try:
18            return loop.run_until_complete(coro)
19        finally:
20            loop.close()
21            asyncio.set_event_loop_policy(None)
22    return wrapper
23
24
class TestAbstractAsyncContextManager(unittest.TestCase):
    """Tests for contextlib.AbstractAsyncContextManager.

    Also hosts a regression test for asynccontextmanager used inside an
    async generator.
    """

    @_async_test
    async def test_enter(self):
        # The inherited default __aenter__ returns the manager itself, both
        # when awaited directly and when used via ``async with ... as``.
        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        manager = DefaultEnter()
        self.assertIs(await manager.__aenter__(), manager)

        async with manager as context:
            self.assertIs(manager, context)

    @_async_test
    async def test_async_gen_propagates_generator_exit(self):
        # A regression test for https://bugs.python.org/issue33786.
        # gen() is left suspended inside its own ``async with ctx()`` when
        # the loop body raises; the exception must still escape the outer
        # ``async with ctx()`` unchanged.

        @asynccontextmanager
        async def ctx():
            yield

        async def gen():
            async with ctx():
                yield 11

        ret = []
        exc = ValueError(22)
        with self.assertRaises(ValueError) as cm:
            async with ctx():
                async for val in gen():
                    ret.append(val)
                    raise exc

        # Not merely *a* ValueError: the exact instance raised in the body
        # must propagate, neither replaced nor wrapped by the managers.
        self.assertIs(cm.exception, exc)
        self.assertEqual(ret, [11])

    def test_exit_is_abstract(self):
        # __aexit__ is abstract: a subclass that omits it cannot be
        # instantiated.
        class MissingAexit(AbstractAsyncContextManager):
            pass

        with self.assertRaises(TypeError):
            MissingAexit()

    def test_structural_subclassing(self):
        # issubclass() accepts any class defining both __aenter__ and
        # __aexit__, with no explicit inheritance or registration; setting
        # either method to None opts a class out again.
        class ManagerFromScratch:
            async def __aenter__(self):
                return self
            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager))

        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager))

        class NoneAenter(ManagerFromScratch):
            __aenter__ = None

        self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager))

        class NoneAexit(ManagerFromScratch):
            __aexit__ = None

        self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager))
92
93
class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for the contextlib.asynccontextmanager decorator."""

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # ``as``, and code after the yield runs on a normal exit.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # A ``finally`` around the yield still runs when the body raises, and
        # the body's exception propagates out of ``async with``.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        # Driving the protocol by hand: __aexit__ with exception details that
        # the generator does not handle returns a false value rather than
        # raising them itself.
        @asynccontextmanager
        async def whee():
            yield
        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        # Misuse: the generator catches the thrown exception (deliberately
        # with a bare except) and yields again; __aexit__ reports this as
        # RuntimeError.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        # ``if False: yield`` makes whoo an async generator that finishes
        # without ever yielding; entering it must raise RuntimeError.
        @asynccontextmanager
        async def whoo():
            if False:
                yield
        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        # Misuse: the generator yields a second time on a normal exit;
        # __aexit__ rejects this with RuntimeError.
        @asynccontextmanager
        async def whoo():
            yield
            yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        # __aexit__ receives an exception *type* with no instance (value is
        # None). The generator must still catch it as RuntimeError, and the
        # SyntaxError it raises in response propagates out.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # The generator catches the body's exception and then finishes
        # normally, which suppresses it: ``async with`` completes without
        # raising. The assertion inside the handler runs in the generator.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration / StopAsyncIteration raised in the body must come out
        # as the very same object, not be mistaken for the generator
        # finishing (which would suppress it).
        @asynccontextmanager
        async def woohoo():
            yield

        for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        # The generator converts any exception from the body into a chained
        # RuntimeError.
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Helper (not a test): build an asynccontextmanager-decorated
        # function that carries a custom attribute (foo='bar') and a
        # docstring, so the tests below can check that the decorator
        # preserves function metadata.
        def attribs(**kw):
            # Decorator factory: set every keyword argument as an attribute
            # on the decorated function.
            def decorate(func):
                for k,v in kw.items():
                    setattr(func,k,v)
                return func
            return decorate
        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        # __name__ and custom attributes survive the decorator.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__,'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        # The docstring survives the decorator; only meaningful when
        # docstrings are available, hence requires_docstrings.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        # The context-manager object returned by calling baz(...) also
        # exposes the generator function's docstring. The object is then
        # entered and exited so its underlying generator is actually run.
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        # (``self``, ``func``, ``args`` and ``kwds`` are all forwarded to
        # the generator function as ordinary keyword arguments.)
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))
280
281
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    """AsyncExitStack tests.

    ``exit_stack`` is set to SyncAsyncExitStack, a synchronous facade over
    AsyncExitStack. This presumably lets the shared TestBaseExitStack tests
    run unchanged against the async implementation. The methods defined
    here cover the async-only API.
    """

    class SyncAsyncExitStack(AsyncExitStack):
        """AsyncExitStack with synchronous close/__enter__/__exit__.

        Each synchronous method runs the corresponding coroutine to
        completion on the current event loop.
        """

        @staticmethod
        def run_coroutine(coro):
            """Run *coro* on the current event loop; return or raise its outcome.

            An exception raised by *coro* is re-raised with its original
            ``__context__`` preserved.
            """
            loop = asyncio.get_event_loop()

            f = asyncio.ensure_future(coro)
            # Stop run_forever() as soon as the future settles.
            f.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = f.exception()

            if exc is None:
                return f.result()
            else:
                context = exc.__context__

                # A plain ``raise exc`` would let Python overwrite
                # exc.__context__ with whatever exception the caller is
                # currently handling (e.g. when invoked from __exit__).
                # So: raise once, then restore the saved context inside the
                # handler, where the active exception is exc itself (Python
                # never chains an exception to itself), and raise again.
                try:
                    raise exc
                except BaseException:
                    exc.__context__ = context
                    raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack

    def setUp(self):
        # The synchronous facade runs coroutines on the current event loop,
        # so give every test a fresh one. Cleanups run in reverse order:
        # first the policy is reset, then the loop is closed.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        self.addCleanup(asyncio.set_event_loop_policy, None)

    @_async_test
    async def test_async_callback(self):
        # push_async_callback() returns the callback unchanged and calls it
        # on exit with the stored positional/keyword arguments, in LIFO
        # order. The registered wrapper exposes the original via __wrapped__
        # but does not copy its __name__ or __doc__.
        expected = [
            ((), {}),
            ((1,), {}),
            ((1,2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1,2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            # Entry [1] of each registered item is the callback wrapper.
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                self.assertIsNone(wrapper[1].__doc__)

        self.assertEqual(result, expected)

        # A missing callback is a TypeError, both on an instance and on the
        # class itself. Passing it as the ``callback`` keyword still works
        # but is deprecated.
        result = []
        async with AsyncExitStack() as stack:
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            with self.assertWarns(DeprecationWarning):
                stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [((), {'arg': 3})])

    @_async_test
    async def test_async_push(self):
        # push_async_exit() accepts both bare coroutine functions and objects
        # with __aexit__ (registering the bound __aexit__). On unwinding, the
        # callbacks pushed after _suppress_exc see the ZeroDivisionError;
        # _suppress_exc swallows it, so the earlier ones see a clean exit.
        exc_raised = ZeroDivisionError
        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        test_case = self
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                # Only __aexit__ is pushed, so entering would be a test bug.
                # ``self`` here is the ExitCM, which has no fail(); report
                # through the enclosing TestCase instead.
                test_case.fail("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        async with self.exit_stack() as stack:
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_async_enter_context(self):
        # enter_async_context() awaits __aenter__ immediately and registers
        # the bound __aexit__. Cleanup runs in reverse registration order,
        # giving the sequence 1 (enter), 2 (body), 3 (cm exit),
        # 4 (earlier callback).
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        # Unwinding order: ValueError replaces the ZeroDivisionError,
        # suppress_exc swallows it, then AttributeError, KeyError and
        # IndexError are raised in turn, each chained onto the previous one.
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
454
455
if __name__ == "__main__":
    # Executing this file directly runs every test case defined in it.
    unittest.main()
458