1"""
Tests for the async helpers in ``IPython.core.async_helpers``.

These should only run on Python 3.5+; running them on older versions results in syntax errors.
5"""
6from itertools import chain, repeat
7import nose.tools as nt
8from textwrap import dedent, indent
9from unittest import TestCase
10from IPython.testing.decorators import skip_without
11import sys
12
def iprc(x):
    """Run the dedented cell source *x* in the test shell and re-raise errors.

    Calls ``raise_error()`` on the execution result, so a cell that failed
    (e.g. with a ``SyntaxError``) raises inside the calling test.

    NOTE(review): ``ip`` is not defined in this module; it is presumably
    injected into builtins by IPython's test harness -- confirm.
    """
    return ip.run_cell(dedent(x)).raise_error()


def iprc_nr(x):
    """Run the dedented cell source *x* in the test shell without raising.

    Returns the execution result so the caller can inspect it or call
    ``raise_error()`` itself.
    """
    return ip.run_cell(dedent(x))
15
16from IPython.core.async_helpers import _should_be_async
17
class AsyncTest(TestCase):
    """Tests for top-level ``await`` support and return/yield detection.

    Cells are run through the module-level helpers ``iprc`` (re-raises any
    error the cell produced) and ``iprc_nr`` (returns the raw result).
    """

    def test_should_be_async(self):
        """``_should_be_async`` is true only for cells with a top-level await."""
        self.assertFalse(_should_be_async("False"))
        self.assertTrue(_should_be_async("await bar()"))
        self.assertTrue(_should_be_async("x = await bar()"))
        self.assertFalse(
            _should_be_async(
                dedent(
                    """
            async def awaitable():
                pass
        """
                )
            )
        )

    def _get_top_level_cases(self):
        """Return ``(name, template)`` pairs with a ``{val}`` placeholder.

        Each template is valid inside a function but a syntax error at top
        level once ``{val}`` is a ``return``/``yield``.
        """
        test_cases = []
        test_cases.append(('basic', "{val}"))

        # Note, in all conditional cases, I use True instead of
        # False so that the peephole optimizer won't optimize away
        # the return, so CPython will see this as a syntax error:
        #
        # while True:
        #    break
        #    return
        #
        # But not this:
        #
        # while False:
        #    return
        #
        # See https://bugs.python.org/issue1875

        test_cases.append(('if', dedent("""
        if True:
            {val}
        """)))

        test_cases.append(('while', dedent("""
        while True:
            {val}
            break
        """)))

        test_cases.append(('try', dedent("""
        try:
            {val}
        except:
            pass
        """)))

        test_cases.append(('except', dedent("""
        try:
            pass
        except:
            {val}
        """)))

        test_cases.append(('finally', dedent("""
        try:
            pass
        except:
            pass
        finally:
            {val}
        """)))

        test_cases.append(('for', dedent("""
        for _ in range(4):
            {val}
        """)))

        test_cases.append(('nested', dedent("""
        if True:
            while True:
                {val}
                break
        """)))

        test_cases.append(('deep-nested', dedent("""
        if True:
            while True:
                break
                for x in range(3):
                    if True:
                        while True:
                            for x in range(3):
                                {val}
        """)))

        return test_cases

    def _get_ry_syntax_errors(self):
        """Return ``(name, template)`` pairs that are always syntax errors.

        A return/yield directly in a class body is invalid whether or not
        the class itself sits inside a function.
        """
        test_cases = []

        test_cases.append(('class', dedent("""
        class V:
            {val}
        """)))

        test_cases.append(('nested-class', dedent("""
        class V:
            class C:
                {val}
        """)))

        return test_cases

    def test_top_level_return_error(self):
        """Top-level return/yield must raise ``SyntaxError``; ``pass`` must not."""
        tl_err_test_cases = self._get_top_level_cases()
        tl_err_test_cases.extend(self._get_ry_syntax_errors())

        vals = ('return', 'yield', 'yield from (_ for _ in range(3))',
                dedent('''
                    def f():
                        pass
                    return
                    '''),
                )

        for test_name, test_case in tl_err_test_cases:
            # This example should work if 'pass' is used as the value
            with self.subTest((test_name, 'pass')):
                iprc(test_case.format(val='pass'))

            # It should fail with all the values
            for val in vals:
                with self.subTest((test_name, val)):
                    msg = "Syntax error not raised for %s, %s" % (test_name, val)
                    with self.assertRaises(SyntaxError, msg=msg):
                        iprc(test_case.format(val=val))

    def test_in_func_no_error(self):
        """Return/yield nested in a function body must not be rejected.

        Checks that the top-level return/yield detection isn't *too*
        aggressive: every top-level case must run once nested inside each
        function context, while the class-body cases must still raise
        ``SyntaxError`` (as must ``yield from`` inside an async function).
        """
        # (context name, is async, source of the enclosing function header)
        func_contexts = []

        func_contexts.append(('func', False, dedent("""
        def f():""")))

        func_contexts.append(('method', False, dedent("""
        class MyClass:
            def __init__(self):
        """)))

        func_contexts.append(('async-func', True,  dedent("""
        async def f():""")))

        func_contexts.append(('async-method', True,  dedent("""
        class MyClass:
            async def f(self):""")))

        func_contexts.append(('closure', False, dedent("""
        def f():
            def g():
        """)))

        def nest_case(context, case):
            """Append *case* to *context*, one level deeper than its last line."""
            # Leading spaces of the innermost (last) header line.
            last_line = context.strip().splitlines()[-1]
            prefix_len = len(last_line) - len(last_line.lstrip(' '))

            indented_case = indent(case, ' ' * (prefix_len + 4))
            return context + '\n' + indented_case

        # (value, allowed inside an async function). yield is allowed in
        # async functions starting in Python 3.6, and yield from is not
        # allowed in any version.
        vals = (
            ('return', True),
            ('yield', True),
            ('yield from (_ for _ in range(3))', False),
        )

        # Materialized as a list because it is iterated once per context;
        # a bare chain/zip iterator would be exhausted after the first one.
        tests = list(chain(
            zip(self._get_top_level_cases(), repeat(False)),
            zip(self._get_ry_syntax_errors(), repeat(True)),
        ))

        for context_name, async_func, context in func_contexts:
            for (test_name, test_case), should_fail in tests:
                nested_case = nest_case(context, test_case)

                for val, async_safe in vals:
                    val_should_fail = (should_fail or
                                       (async_func and not async_safe))

                    test_id = (context_name, test_name, val)
                    cell = nested_case.format(val=val)

                    with self.subTest(test_id):
                        if val_should_fail:
                            # Include the cell so a failure shows its source.
                            msg = ("SyntaxError not raised for %s:\n%s" %
                                   (str(test_id), cell))
                            with self.assertRaises(SyntaxError, msg=msg):
                                iprc(cell)
                        else:
                            iprc(cell)

    def test_nonlocal(self):
        """``nonlocal`` must name a variable bound in an enclosing function."""
        # Each cell fails because the outer scope is not a function scope or
        # the variable is not defined there. Each gets its own assertRaises:
        # in a shared block the first raise would skip the remaining cells.
        invalid_cells = (
            "nonlocal x",
            """
            x = 1
            def f():
                nonlocal x
                x = 10000
                yield x
            """,
            """
            def f():
                def g():
                    nonlocal x
                    x = 10000
                    yield x
            """,
        )
        for cell in invalid_cells:
            with self.subTest(cell):
                with self.assertRaises(SyntaxError):
                    iprc(cell)

        # works if outer scope is a function scope and var exists
        iprc("""
        def f():
            x = 20
            def g():
                nonlocal x
                x = 10000
                yield x
        """)

    def test_execute(self):
        """A cell with a top-level ``await`` runs under asyncio."""
        iprc("""
        import asyncio
        await asyncio.sleep(0.001)
        """
        )

    def test_autoawait(self):
        """``%autoawait`` can be toggled and top-level await then works."""
        iprc("%autoawait False")
        iprc("%autoawait True")
        iprc("""
        from asyncio import sleep
        await sleep(0.1)
        """
        )

    if sys.version_info < (3,9):
        # The new PEG parser in 3.9 no longer raises MemoryError on too many
        # nested parentheses, so this test only exists on older versions.
        def test_memory_error(self):
            with self.assertRaises(MemoryError):
                iprc("(" * 200 + ")" * 200)

    @skip_without('curio')
    def test_autoawait_curio(self):
        iprc("%autoawait curio")

    @skip_without('trio')
    def test_autoawait_trio(self):
        iprc("%autoawait trio")

    @skip_without('trio')
    def test_autoawait_trio_wrong_sleep(self):
        """Awaiting an asyncio coroutine under the trio runner fails."""
        iprc("%autoawait trio")
        res = iprc_nr("""
        import asyncio
        await asyncio.sleep(0)
        """)
        with self.assertRaises(TypeError):
            res.raise_error()

    @skip_without('trio')
    def test_autoawait_asyncio_wrong_sleep(self):
        """Awaiting a trio coroutine under the asyncio runner fails."""
        iprc("%autoawait asyncio")
        res = iprc_nr("""
        import trio
        await trio.sleep(0)
        """)
        with self.assertRaises(RuntimeError):
            res.raise_error()

    def tearDown(self):
        # Restore the default runner that some tests above switch away from.
        ip.loop_runner = "asyncio"
317