1"""Tests support for new syntax introduced by PEP 492."""
2
3import sys
4import types
5import unittest
6
7from unittest import mock
8
9import asyncio
10from test.test_asyncio import utils as test_utils
11
12
def tearDownModule():
    """Reset the global asyncio event loop policy after this module runs.

    Passing None to asyncio.set_event_loop_policy() reinstates the default
    policy, so later test modules do not inherit policy state from here.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
15
16
# Duck-typed coroutine used to test that asyncio.iscoroutine() also accepts
# objects implementing collections.abc.Coroutine, not only native coroutines.
class FakeCoro:
    """Minimal non-native object that satisfies collections.abc.Coroutine.

    It defines exactly the methods the ABC looks for (send, throw, close
    and __await__); each one is a no-op.
    """

    def send(self, value):
        return None

    def throw(self, typ, val=None, tb=None):
        return None

    def close(self):
        return None

    def __await__(self):
        # Generator function: iterating it produces a single None.
        yield None
30
31
class BaseTest(test_utils.TestCase):
    """Fixture providing a BaseEventLoop whose I/O layer is mocked out.

    The selector is a Mock whose select() returns an empty tuple, and
    _process_events is a Mock.  No real I/O is performed.
    """

    def setUp(self):
        super().setUp()
        loop = asyncio.BaseEventLoop()
        loop._process_events = mock.Mock()

        # Fake selector: select() reports no ready events.
        fake_selector = mock.Mock()
        fake_selector.select.return_value = ()
        loop._selector = fake_selector

        self.loop = loop
        self.set_event_loop(loop)
41
42
class LockTests(BaseTest):
    """``async with`` support for asyncio synchronization primitives."""

    def test_context_manager_async_with(self):
        primitives = [
            asyncio.Lock(),
            asyncio.Condition(),
            asyncio.Semaphore(),
            asyncio.BoundedSemaphore(),
        ]

        async def check(sync_obj):
            await asyncio.sleep(0.01)
            self.assertFalse(sync_obj.locked())
            async with sync_obj as entered:
                # The value bound by ``as`` is None, not the primitive.
                self.assertIs(entered, None)
                self.assertTrue(sync_obj.locked())
                await asyncio.sleep(0.01)
                self.assertTrue(sync_obj.locked())
            # Exiting the block must have released it again.
            self.assertFalse(sync_obj.locked())

        for prim in primitives:
            self.loop.run_until_complete(check(prim))
            self.assertFalse(prim.locked())

    def test_context_manager_with_await(self):
        primitives = [
            asyncio.Lock(),
            asyncio.Condition(),
            asyncio.Semaphore(),
            asyncio.BoundedSemaphore(),
        ]

        async def check(sync_obj):
            await asyncio.sleep(0.01)
            self.assertFalse(sync_obj.locked())
            # Awaiting the primitive itself is not supported: ``with await``
            # must fail with a TypeError instead of acquiring it.
            with self.assertRaisesRegex(
                TypeError,
                "can't be used in 'await' expression"
            ):
                with await sync_obj:
                    pass

        for prim in primitives:
            self.loop.run_until_complete(check(prim))
            self.assertFalse(prim.locked())
88
89
class StreamReaderTests(BaseTest):
    """``async for`` iteration over an asyncio.StreamReader."""

    def test_readline(self):
        payload = b'line1\nline2\nline3'
        expected = [b'line1\n', b'line2\n', b'line3']

        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(payload)
        stream.feed_eof()

        async def collect():
            # Each step yields one line including its b'\n'; the trailing
            # unterminated chunk is yielded once EOF has been fed.
            lines = []
            async for chunk in stream:
                lines.append(chunk)
            return lines

        self.assertEqual(self.loop.run_until_complete(collect()), expected)
107
108
class CoroutineTests(BaseTest):
    """Interaction of ``async def`` coroutines with asyncio and the loop."""

    def test_iscoroutine(self):
        async def foo(): pass

        f = foo()
        try:
            self.assertTrue(asyncio.iscoroutine(f))
        finally:
            f.close()  # silence the "never awaited" warning

        # Non-native objects implementing the collections.abc.Coroutine
        # protocol must be recognized as well.
        self.assertTrue(asyncio.iscoroutine(FakeCoro()))

    def test_iscoroutinefunction(self):
        async def foo(): pass
        self.assertTrue(asyncio.iscoroutinefunction(foo))

    def test_async_def_coroutines(self):
        async def bar():
            return 'spam'
        async def foo():
            return await bar()

        # production mode
        data = self.loop.run_until_complete(foo())
        self.assertEqual(data, 'spam')

        # debug mode
        self.loop.set_debug(True)
        data = self.loop.run_until_complete(foo())
        self.assertEqual(data, 'spam')

    def test_debug_mode_manages_coroutine_origin_tracking(self):
        async def start():
            # Inside a debug-mode loop, origin tracking must be enabled.
            self.assertGreater(sys.get_coroutine_origin_tracking_depth(), 0)

        self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)
        self.loop.set_debug(True)
        self.loop.run_until_complete(start())
        # ...and switched off again once the loop stops running.
        self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)

    def test_types_coroutine(self):
        def gen():
            yield from ()
            return 'spam'

        @types.coroutine
        def func():
            return gen()

        async def coro():
            # A plain generator returned through types.coroutine is wrapped
            # so that it can be awaited.
            wrapper = func()
            self.assertIsInstance(wrapper, types._GeneratorWrapper)
            return await wrapper

        data = self.loop.run_until_complete(coro())
        self.assertEqual(data, 'spam')

    def test_task_print_stack(self):
        T = None

        async def foo():
            f = T.get_stack(limit=1)
            try:
                self.assertEqual(f[0].f_code.co_name, 'foo')
            finally:
                # Drop the frame list so the frames are not kept alive.
                f = None

        async def runner():
            nonlocal T
            T = asyncio.ensure_future(foo(), loop=self.loop)
            await T

        self.loop.run_until_complete(runner())

    def test_double_await(self):
        async def afunc():
            await asyncio.sleep(0.1)

        async def runner():
            coro = afunc()
            t = self.loop.create_task(coro)
            try:
                # Let the task start so `coro` is suspended inside sleep();
                # a second await on it must then be rejected.
                await asyncio.sleep(0)
                await coro
            finally:
                t.cancel()

        self.loop.set_debug(True)
        # assertRaises(msg=...) only sets the failure message shown when no
        # exception is raised; it never checks the exception text.  Use
        # assertRaisesRegex so the RuntimeError's message is verified.
        with self.assertRaisesRegex(
                RuntimeError, 'coroutine is being awaited already'):
            self.loop.run_until_complete(runner())
203
204
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
207