1import asyncio
2from contextlib import asynccontextmanager, AbstractAsyncContextManager, AsyncExitStack
3import functools
4from test import support
5import unittest
6
7from test.test_contextlib import TestBaseExitStack
8
9
10def _async_test(func):
11    """Decorator to turn an async function into a test case."""
12    @functools.wraps(func)
13    def wrapper(*args, **kwargs):
14        coro = func(*args, **kwargs)
15        loop = asyncio.new_event_loop()
16        asyncio.set_event_loop(loop)
17        try:
18            return loop.run_until_complete(coro)
19        finally:
20            loop.close()
21            asyncio.set_event_loop(None)
22    return wrapper
23
24
class TestAbstractAsyncContextManager(unittest.TestCase):
    """Tests for the contextlib.AbstractAsyncContextManager ABC."""

    @_async_test
    async def test_enter(self):
        # The ABC's default __aenter__ hands back the manager itself, both
        # when awaited directly and when bound by ``async with ... as``.
        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *exc_info):
                await super().__aexit__(*exc_info)

        cm = DefaultEnter()
        entered = await cm.__aenter__()
        self.assertIs(entered, cm)

        async with cm as bound:
            self.assertIs(cm, bound)

    @_async_test
    async def test_async_gen_propagates_generator_exit(self):
        # Regression test for https://bugs.python.org/issue33786
        # (GeneratorExit propagation through @asynccontextmanager).

        @asynccontextmanager
        async def ctx():
            yield

        async def gen():
            async with ctx():
                yield 11

        seen = []
        err = ValueError(22)
        with self.assertRaises(ValueError):
            async with ctx():
                async for item in gen():
                    seen.append(item)
                    raise err

        self.assertEqual(seen, [11])

    def test_exit_is_abstract(self):
        # Omitting __aexit__ leaves the subclass abstract, so it can't be
        # instantiated.
        class MissingAexit(AbstractAsyncContextManager):
            pass

        self.assertRaises(TypeError, MissingAexit)

    def test_structural_subclassing(self):
        # A class defining both __aenter__ and __aexit__ is a virtual
        # subclass; setting either method to None opts it out again.
        class ManagerFromScratch:
            async def __aenter__(self):
                return self
            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *exc_info):
                await super().__aexit__(*exc_info)

        class NoneAenter(ManagerFromScratch):
            __aenter__ = None

        class NoneAexit(ManagerFromScratch):
            __aexit__ = None

        for candidate in (ManagerFromScratch, DefaultEnter):
            with self.subTest(cls=candidate.__name__):
                self.assertTrue(
                    issubclass(candidate, AbstractAsyncContextManager))

        for candidate in (NoneAenter, NoneAexit):
            with self.subTest(cls=candidate.__name__):
                self.assertFalse(
                    issubclass(candidate, AbstractAsyncContextManager))
92
93
class AsyncContextManagerTestCase(unittest.TestCase):
    """Behavioural tests for the contextlib.asynccontextmanager decorator."""

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before the yield runs on entry; code after it on a clean exit.
        trace = []

        @asynccontextmanager
        async def woohoo():
            trace.append(1)
            yield 42
            trace.append(999)

        async with woohoo() as value:
            self.assertEqual(trace, [1])
            self.assertEqual(value, 42)
            trace.append(value)
        self.assertEqual(trace, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # A ``finally`` around the yield runs even though the body raises,
        # and the exception still escapes the ``async with``.
        trace = []

        @asynccontextmanager
        async def woohoo():
            trace.append(1)
            try:
                yield 42
            finally:
                trace.append(999)

        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as value:
                self.assertEqual(trace, [1])
                self.assertEqual(value, 42)
                trace.append(value)
                raise ZeroDivisionError()
        self.assertEqual(trace, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def whee():
            yield

        cm = whee()
        await cm.__aenter__()
        # __aexit__ must not raise here; it only reports "not suppressed"
        # by returning a false value.
        suppressed = await cm.__aexit__(TypeError, TypeError("foo"), None)
        self.assertFalse(suppressed)

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        # Yielding again after catching the thrown-in exception is an error.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except BaseException:
                yield

        cm = whoo()
        await cm.__aenter__()
        with self.assertRaises(RuntimeError):
            await cm.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        # An async generator that finishes without yielding can't be entered.
        @asynccontextmanager
        async def whoo():
            if False:
                yield  # only here to make whoo an async generator

        cm = whoo()
        with self.assertRaises(RuntimeError):
            await cm.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        # Yielding a second time on a clean exit is an error.
        @asynccontextmanager
        async def whoo():
            yield
            yield

        cm = whoo()
        await cm.__aenter__()
        with self.assertRaises(RuntimeError):
            await cm.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        # __aexit__ receives only an exception *type* (value is None). The
        # generator's ``except RuntimeError`` must still fire.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        cm = whoo()
        await cm.__aenter__()
        with self.assertRaises(SyntaxError):
            await cm.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # The generator may catch and swallow the body's exception.
        trace = []

        @asynccontextmanager
        async def woohoo():
            trace.append(1)
            try:
                yield 42
            except ZeroDivisionError as err:
                trace.append(err.args[0])
                self.assertEqual(trace, [1, 42, 999])

        async with woohoo() as value:
            self.assertEqual(trace, [1])
            self.assertEqual(value, 42)
            trace.append(value)
            raise ZeroDivisionError(999)
        self.assertEqual(trace, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration / StopAsyncIteration raised in the body propagate as
        # the very same object instead of being swallowed.
        @asynccontextmanager
        async def woohoo():
            yield

        stop_excs = (StopIteration('spam'), StopAsyncIteration('ham'))
        for stop_exc in stop_excs:
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as caught:
                    self.assertIs(caught, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        # An ordinary exception wrapped by the generator surfaces as the
        # RuntimeError.
        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # When the context manager wraps StopAsyncIteration in a
        # RuntimeError, the wrapper is removed again. The machinery cannot
        # tell whether the wrapping came from the generator protocol or from
        # the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Build a decorated manager whose function carries a docstring and an
        # extra attribute, so the tests below can check what is preserved.
        def attribs(**kw):
            def decorate(func):
                for attr_name, attr_value in kw.items():
                    setattr(func, attr_name, attr_value)
                return func
            return decorate

        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield

        return baz

    def test_contextmanager_attribs(self):
        # The decorator keeps the function's name and custom attributes.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        # The decorator keeps the function's docstring.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        # The manager *instance* also exposes the function's docstring.
        cm = self._create_contextmanager_attribs()(None)
        self.assertEqual(cm.__doc__, "Whee!")
        async with cm:
            pass  # use the manager so it does not trigger a warning

    @_async_test
    async def test_keywords(self):
        # Names like self/func/args/kwds must remain usable as keyword
        # arguments of the decorated function.
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)

        async with woohoo(self=11, func=22, args=33, kwds=44) as got:
            self.assertEqual(got, (11, 22, 33, 44))
280
281
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    """AsyncExitStack tests.

    Tests that go through ``self.exit_stack`` use SyncAsyncExitStack, which
    drives the async API on an event loop. The ``test_async_*`` methods below
    exercise the async-only API (push_async_callback, push_async_exit,
    enter_async_context).
    """

    class SyncAsyncExitStack(AsyncExitStack):
        """AsyncExitStack with blocking close/__enter__/__exit__ wrappers.

        Each wrapper runs the matching coroutine to completion on the current
        event loop via run_coroutine().
        """

        @staticmethod
        def run_coroutine(coro):
            """Run *coro* on the current event loop and return its result.

            If the coroutine raised, that same exception object is re-raised
            with the ``__context__`` it had when the coroutine finished.
            """
            loop = asyncio.get_event_loop()

            f = asyncio.ensure_future(coro)
            # Stop the otherwise endless run_forever() once the task is done.
            f.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = f.exception()

            if exc is None:
                return f.result()
            else:
                # A plain ``raise exc`` would overwrite exc.__context__ with
                # whatever exception the caller is currently handling, e.g.
                # when __exit__ runs while an exception propagates. So:
                #   1. raise it once,
                #   2. restore the saved context inside the handler,
                #   3. re-raise it from there.
                # An exception is never chained to itself, so the restored
                # context survives the second raise.
                context = exc.__context__

                try:
                    raise exc
                except BaseException:
                    exc.__context__ = context
                    raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    # Tests create their stacks through this attribute (presumably including
    # the ones inherited from TestBaseExitStack).
    exit_stack = SyncAsyncExitStack

    def setUp(self):
        # run_coroutine() relies on asyncio.get_event_loop(). Install a fresh
        # loop per test and close it at cleanup.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)

    @_async_test
    async def test_async_callback(self):
        expected = [
            ((), {}),
            ((1,), {}),
            ((1,2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1,2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                # The registered wrapper must not carry the callback's
                # docstring. (Previously _exit.__doc__ was passed here by
                # mistake as the assertion's ``msg`` argument.)
                self.assertIsNone(wrapper[1].__doc__)

        self.assertEqual(result, expected)

        result = []
        async with AsyncExitStack() as stack:
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [((), {'arg': 3})])

    @_async_test
    async def test_async_push(self):
        exc_raised = ZeroDivisionError
        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                # push_async_exit() must only use __aexit__. ``self`` is the
                # ExitCM here, not the TestCase, so self.fail() is unavailable.
                raise AssertionError("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        async with self.exit_stack() as stack:
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_async_enter_context(self):
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
452
453
# Running this file directly executes its test cases.
if __name__ == "__main__":
    unittest.main()
456