1"""Tests support for new syntax introduced by PEP 492."""
2
3import sys
4import types
5import unittest
6
7from unittest import mock
8
9import asyncio
10from test.test_asyncio import utils as test_utils
11
12
def tearDownModule():
    """Reset asyncio's global event loop policy after this module's tests.

    Called by unittest once every test in the module has run.
    """
    # A policy of None tells asyncio to reinstate its default policy.
    asyncio.set_event_loop_policy(None)
15
16
# Minimal duck-typed coroutine, used to test that asyncio.iscoroutine()
# recognizes objects through collections.abc.Coroutine.
class FakeCoro:
    """Object exposing the coroutine protocol methods with no real behavior.

    It defines ``send``, ``throw``, ``close`` and ``__await__`` but does not
    inherit from any coroutine type.
    """

    def send(self, value):
        """Accept *value* and do nothing."""
        return None

    def throw(self, typ, val=None, tb=None):
        """Accept an exception description and do nothing."""
        return None

    def close(self):
        """Do nothing."""
        return None

    def __await__(self):
        """Generator that yields a single ``None`` and then finishes."""
        yield None
30
31
class BaseTest(test_utils.TestCase):
    """Base fixture providing ``self.loop``: a bare BaseEventLoop with mocked internals."""

    def setUp(self):
        super().setUp()
        loop = asyncio.BaseEventLoop()
        # Swap the event-processing hook and the selector for mocks; the
        # selector's select() always reports no ready events. Presumably this
        # lets the bare BaseEventLoop run without a real selector.
        loop._process_events = mock.Mock()
        fake_selector = mock.Mock()
        fake_selector.select.return_value = ()
        loop._selector = fake_selector
        self.loop = loop
        self.set_event_loop(self.loop)
41
42
class LockTests(BaseTest):
    """Context-manager behavior of asyncio's lock-like primitives."""

    def _make_primitives(self):
        """Build one of each lock-like primitive, bound to ``self.loop``.

        Passing ``loop=`` is expected to emit a DeprecationWarning.
        """
        factories = (
            asyncio.Lock,
            asyncio.Condition,
            asyncio.Semaphore,
            asyncio.BoundedSemaphore,
        )
        with self.assertWarns(DeprecationWarning):
            return [factory(loop=self.loop) for factory in factories]

    def test_context_manager_async_with(self):
        primitives = self._make_primitives()

        async def exercise(prim):
            await asyncio.sleep(0.01)
            self.assertFalse(prim.locked())
            # ``async with`` acquires the primitive; __aenter__ yields None.
            async with prim as entered:
                self.assertIsNone(entered)
                self.assertTrue(prim.locked())
                await asyncio.sleep(0.01)
                self.assertTrue(prim.locked())
            self.assertFalse(prim.locked())

        for prim in primitives:
            self.loop.run_until_complete(exercise(prim))
            self.assertFalse(prim.locked())

    def test_context_manager_with_await(self):
        primitives = self._make_primitives()

        async def exercise(prim):
            await asyncio.sleep(0.01)
            self.assertFalse(prim.locked())
            # The ``with await lock`` form must be rejected with a TypeError.
            expected = "can't be used in 'await' expression"
            with self.assertRaisesRegex(TypeError, expected):
                with await prim:
                    pass

        for prim in primitives:
            self.loop.run_until_complete(exercise(prim))
            self.assertFalse(prim.locked())
90
91
class StreamReaderTests(BaseTest):
    """``async for`` iteration over an asyncio.StreamReader."""

    def test_readline(self):
        payload = b'line1\nline2\nline3'

        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(payload)
        stream.feed_eof()

        async def collect_lines():
            # Iterating the reader yields one line per step; the final
            # chunk has no trailing newline.
            return [line async for line in stream]

        lines = self.loop.run_until_complete(collect_lines())
        self.assertEqual(lines, [b'line1\n', b'line2\n', b'line3'])
109
110
class CoroutineTests(BaseTest):
    """How asyncio treats native (``async def``) coroutines and related objects."""

    def test_iscoroutine(self):
        async def foo(): pass

        f = foo()
        try:
            self.assertTrue(asyncio.iscoroutine(f))
        finally:
            # Close the never-started coroutine so it is not reported as
            # "never awaited" when garbage collected.
            f.close()

        # An object implementing the coroutine protocol is also accepted.
        self.assertTrue(asyncio.iscoroutine(FakeCoro()))

    def test_iscoroutinefunction(self):
        async def foo(): pass
        self.assertTrue(asyncio.iscoroutinefunction(foo))

    def test_function_returning_awaitable(self):
        class Awaitable:
            def __await__(self):
                return ('spam',)

        # Decorating with asyncio.coroutine is expected to warn.
        with self.assertWarns(DeprecationWarning):
            @asyncio.coroutine
            def func():
                return Awaitable()

        coro = func()
        self.assertEqual(coro.send(None), 'spam')
        coro.close()

    def test_async_def_coroutines(self):
        async def bar():
            return 'spam'
        async def foo():
            return await bar()

        # production mode
        data = self.loop.run_until_complete(foo())
        self.assertEqual(data, 'spam')

        # debug mode
        self.loop.set_debug(True)
        data = self.loop.run_until_complete(foo())
        self.assertEqual(data, 'spam')

    def test_debug_mode_manages_coroutine_origin_tracking(self):
        async def start():
            # While the debug loop is running, origin tracking is enabled.
            self.assertGreater(sys.get_coroutine_origin_tracking_depth(), 0)

        self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)
        self.loop.set_debug(True)
        self.loop.run_until_complete(start())
        # ...and it is switched off again once the loop stops.
        self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)

    def test_types_coroutine(self):
        def gen():
            yield from ()
            return 'spam'

        @types.coroutine
        def func():
            return gen()

        async def coro():
            wrapper = func()
            self.assertIsInstance(wrapper, types._GeneratorWrapper)
            return await wrapper

        data = self.loop.run_until_complete(coro())
        self.assertEqual(data, 'spam')

    def test_task_print_stack(self):
        T = None

        async def foo():
            f = T.get_stack(limit=1)
            try:
                self.assertEqual(f[0].f_code.co_name, 'foo')
            finally:
                # Drop the frame reference to avoid keeping a cycle alive.
                f = None

        async def runner():
            nonlocal T
            T = asyncio.ensure_future(foo(), loop=self.loop)
            await T

        self.loop.run_until_complete(runner())

    def test_double_await(self):
        async def afunc():
            await asyncio.sleep(0.1)

        async def runner():
            coro = afunc()
            t = self.loop.create_task(coro)
            try:
                # Let the task start driving ``coro``; awaiting it a second
                # time while it is suspended must fail.
                await asyncio.sleep(0)
                await coro
            finally:
                t.cancel()

        self.loop.set_debug(True)
        # assertRaises(msg=...) only sets the failure message and never checks
        # the exception text, so assertRaisesRegex is used to verify it.
        with self.assertRaisesRegex(
                RuntimeError,
                'coroutine is being awaited already'):

            self.loop.run_until_complete(runner())
219
220
221if __name__ == '__main__':
222    unittest.main()
223