1import asyncio
2import unittest
3
4from unittest import mock
5from . import utils as test_utils
6
7
class TestPolicy(asyncio.AbstractEventLoopPolicy):
    """Event loop policy that builds loops from a user-supplied factory.

    The most recent non-None loop passed to ``set_event_loop()`` is kept in
    ``self.loop`` so the test harness can inspect it after the test ran.
    """

    def __init__(self, loop_factory):
        # Nothing has been installed yet.
        self.loop = None
        self.loop_factory = loop_factory

    def get_event_loop(self):
        # asyncio.run() is expected never to take this path; fail loudly
        # if it does.
        raise RuntimeError()

    def new_event_loop(self):
        factory = self.loop_factory
        return factory()

    def set_event_loop(self, loop):
        # Clearing the loop (loop=None) is ignored on purpose, so the last
        # real loop remains available for BaseTest.tearDown to check that
        # it was closed.
        if loop is None:
            return
        self.loop = loop
26
27
class BaseTest(unittest.TestCase):
    """Base class that installs a TestPolicy for the duration of each test.

    Loops handed out by the policy come from :meth:`new_loop`: a plain
    ``asyncio.BaseEventLoop`` whose selector hooks are mocks. After each
    test, :meth:`tearDown` verifies that the last loop recorded by the
    policy was closed and had its stubbed ``shutdown_asyncgens()`` awaited.
    The default policy is then restored.
    """

    def new_loop(self):
        """Return a BaseEventLoop with mocked selector plumbing.

        The returned loop carries an extra ``shutdown_ag_run`` attribute.
        It starts as False and is set to True once the replacement
        ``shutdown_asyncgens()`` coroutine is awaited.
        """
        loop = asyncio.BaseEventLoop()
        # Replace the selector hooks with mocks; select() always reports
        # no ready events.
        loop._process_events = mock.Mock()
        loop._selector = mock.Mock()
        loop._selector.select.return_value = ()
        loop.shutdown_ag_run = False

        async def shutdown_asyncgens():
            # Record that the shutdown step was awaited.
            loop.shutdown_ag_run = True
        loop.shutdown_asyncgens = shutdown_asyncgens

        return loop

    def setUp(self):
        super().setUp()

        policy = TestPolicy(self.new_loop)
        asyncio.set_event_loop_policy(policy)

    def tearDown(self):
        """Check the loop was cleaned up, then restore the default policy.

        Raises:
            AssertionError: if the recorded loop was not closed or its
                ``shutdown_asyncgens()`` was never awaited.
        """
        policy = asyncio.get_event_loop_policy()
        try:
            if policy.loop is not None:
                self.assertTrue(policy.loop.is_closed(),
                                'event loop was not closed')
                self.assertTrue(policy.loop.shutdown_ag_run,
                                'shutdown_asyncgens() was not awaited')
        finally:
            # Always reset the global policy, even when a check above fails.
            # Otherwise a single broken test would leak TestPolicy into
            # every test that runs after it.
            asyncio.set_event_loop_policy(None)
            super().tearDown()
57
58
class RunTests(BaseTest):
    """Behavioural tests for :func:`asyncio.run`.

    Each test runs on the mocked loop from ``BaseTest.new_loop``. After the
    test, ``BaseTest.tearDown`` checks that the loop was closed and that
    ``shutdown_asyncgens()`` was awaited.
    """

    def test_asyncio_run_return(self):
        # The coroutine's return value is handed back to the caller.
        async def main():
            await asyncio.sleep(0)
            return 42

        self.assertEqual(asyncio.run(main()), 42)

    def test_asyncio_run_raises(self):
        # An exception raised inside the coroutine propagates out of run().
        async def main():
            await asyncio.sleep(0)
            raise ValueError('spam')

        with self.assertRaisesRegex(ValueError, 'spam'):
            asyncio.run(main())

    def test_asyncio_run_only_coro(self):
        # Non-coroutine arguments are rejected with ValueError. A tuple
        # (not a set) keeps the subTest order deterministic: a lambda hashes
        # by identity, so set iteration order could differ between runs.
        for o in (1, lambda: None):
            with self.subTest(obj=o), \
                    self.assertRaisesRegex(ValueError,
                                           'a coroutine was expected'):
                asyncio.run(o)

    def test_asyncio_run_debug(self):
        # Without an explicit ``debug`` argument, the loop's debug flag
        # follows asyncio.coroutines._is_debug_mode(). An explicit
        # debug=True/False overrides it.
        async def main(expected):
            loop = asyncio.get_event_loop()
            self.assertIs(loop.get_debug(), expected)

        asyncio.run(main(False))
        asyncio.run(main(True), debug=True)
        with mock.patch('asyncio.coroutines._is_debug_mode', lambda: True):
            asyncio.run(main(True))
            asyncio.run(main(False), debug=False)

    def test_asyncio_run_from_running_loop(self):
        # Calling run() from inside a coroutine that is already running on
        # a loop is refused with RuntimeError.
        async def main():
            coro = main()
            try:
                asyncio.run(coro)
            finally:
                # The inner coroutine is never started; close it so it is
                # not reported as "never awaited".
                coro.close()

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot be called from a running'):
            asyncio.run(main())

    def test_asyncio_run_cancels_hanging_tasks(self):
        # Tasks still pending when main() returns are cancelled (not just
        # finished) before run() returns.
        lo_task = None

        async def leftover():
            await asyncio.sleep(0.1)

        async def main():
            nonlocal lo_task
            lo_task = asyncio.create_task(leftover())
            return 123

        self.assertEqual(asyncio.run(main()), 123)
        self.assertTrue(lo_task.done())
        # done() alone would also hold for a task that completed normally;
        # check that the task was actually cancelled.
        self.assertTrue(lo_task.cancelled())

    def test_asyncio_run_reports_hanging_tasks_errors(self):
        # An exception a leftover task raises while being cancelled during
        # shutdown is reported through the loop's exception handler.
        lo_task = None
        call_exc_handler_mock = mock.Mock()

        async def leftover():
            try:
                await asyncio.sleep(0.1)
            except asyncio.CancelledError:
                1 / 0

        async def main():
            loop = asyncio.get_running_loop()
            loop.call_exception_handler = call_exc_handler_mock

            nonlocal lo_task
            lo_task = asyncio.create_task(leftover())
            return 123

        self.assertEqual(asyncio.run(main()), 123)
        self.assertTrue(lo_task.done())

        call_exc_handler_mock.assert_called_with({
            'message': test_utils.MockPattern(r'asyncio.run.*shutdown'),
            'task': lo_task,
            'exception': test_utils.MockInstanceOf(ZeroDivisionError)
        })

    def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self):
        # main() fails, and the leftover task errors while being cancelled.
        # The async generator that task was iterating must still end up
        # finalized: no frame and not running.
        spinner = None
        lazyboy = None

        class FancyExit(Exception):
            pass

        async def fidget():
            while True:
                yield 1
                await asyncio.sleep(1)

        async def spin():
            nonlocal spinner
            spinner = fidget()
            try:
                async for the_meaning_of_life in spinner:  # NoQA
                    pass
            except asyncio.CancelledError:
                1 / 0

        async def main():
            loop = asyncio.get_running_loop()
            # Silence the ZeroDivisionError report from spin().
            loop.call_exception_handler = mock.Mock()

            nonlocal lazyboy
            lazyboy = asyncio.create_task(spin())
            raise FancyExit

        with self.assertRaises(FancyExit):
            asyncio.run(main())

        self.assertTrue(lazyboy.done())

        self.assertIsNone(spinner.ag_frame)
        self.assertFalse(spinner.ag_running)
183