"""
Tests for IPython's async helpers (top-level ``await`` / ``%autoawait``).

Only meant to run on Python 3.5+: the cells exercised here use
``async``/``await`` syntax and would be syntax errors on older versions.
"""
import sys
from itertools import chain, repeat
from textwrap import dedent, indent
from unittest import TestCase

import nose.tools as nt

from IPython.core.async_helpers import _should_be_async
from IPython.testing.decorators import skip_without

# NOTE(review): ``ip`` is not defined in this module; it is presumably the
# running InteractiveShell injected into builtins by IPython's test harness.


def iprc(x):
    """Run cell ``x`` (dedented) in the shell and re-raise any error it hit."""
    return ip.run_cell(dedent(x)).raise_error()


def iprc_nr(x):
    """Run cell ``x`` (dedented) and return the ExecutionResult unraised."""
    return ip.run_cell(dedent(x))


class AsyncTest(TestCase):
    """Syntax and execution checks for cells involving async code."""

    def test_should_be_async(self):
        nt.assert_false(_should_be_async("False"))
        nt.assert_true(_should_be_async("await bar()"))
        nt.assert_true(_should_be_async("x = await bar()"))
        nt.assert_false(
            _should_be_async(
                dedent(
                    """
            async def awaitable():
                pass
        """
                )
            )
        )

    def _get_top_level_cases(self):
        """Return ``(name, template)`` pairs; ``{val}`` is the placeholder.

        These templates should be valid inside a function but invalid
        outside of one when ``{val}`` is ``return``/``yield``.
        """
        test_cases = []
        test_cases.append(('basic', "{val}"))

        # Note, in all conditional cases, I use True instead of
        # False so that the peephole optimizer won't optimize away
        # the return, so CPython will see this as a syntax error:
        #
        # while True:
        #    break
        #    return
        #
        # But not this:
        #
        # while False:
        #    return
        #
        # See https://bugs.python.org/issue1875

        test_cases.append(('if', dedent("""
            if True:
                {val}
            """)))

        test_cases.append(('while', dedent("""
            while True:
                {val}
                break
            """)))

        test_cases.append(('try', dedent("""
            try:
                {val}
            except:
                pass
            """)))

        test_cases.append(('except', dedent("""
            try:
                pass
            except:
                {val}
            """)))

        test_cases.append(('finally', dedent("""
            try:
                pass
            except:
                pass
            finally:
                {val}
            """)))

        test_cases.append(('for', dedent("""
            for _ in range(4):
                {val}
            """)))

        test_cases.append(('nested', dedent("""
            if True:
                while True:
                    {val}
                    break
            """)))

        # The ``for`` sits after ``break`` so it is never executed: the inner
        # ``while True`` would otherwise loop forever when {val} is ``pass``.
        test_cases.append(('deep-nested', dedent("""
            if True:
                while True:
                    break
                    for x in range(3):
                        if True:
                            while True:
                                for x in range(3):
                                    {val}
            """)))

        return test_cases

    def _get_ry_syntax_errors(self):
        """Return templates where ``return``/``yield`` is a syntax error.

        This holds whether or not the template is itself nested inside a
        function, because ``{val}`` lands directly in a class body.
        """
        test_cases = []

        test_cases.append(('class', dedent("""
            class V:
                {val}
            """)))

        test_cases.append(('nested-class', dedent("""
            class V:
                class C:
                    {val}
            """)))

        return test_cases

    def test_top_level_return_error(self):
        tl_err_test_cases = self._get_top_level_cases()
        tl_err_test_cases.extend(self._get_ry_syntax_errors())

        # NOTE(review): the multi-line value below is substituted without
        # re-indenting its continuation lines. In every template except
        # 'basic' it therefore produces an IndentationError (a SyntaxError
        # subclass) rather than exercising the top-level ``return``.
        vals = ('return', 'yield', 'yield from (_ for _ in range(3))',
                dedent('''
                    def f():
                        pass
                    return
                    '''),
                )

        for test_name, test_case in tl_err_test_cases:
            # This example should work if 'pass' is used as the value
            with self.subTest((test_name, 'pass')):
                iprc(test_case.format(val='pass'))

            # It should fail with all the values
            for val in vals:
                with self.subTest((test_name, val)):
                    msg = "Syntax error not raised for %s, %s" % (test_name, val)
                    with self.assertRaises(SyntaxError, msg=msg):
                        iprc(test_case.format(val=val))

    def test_in_func_no_error(self):
        # Test that the implementation of top-level return/yield
        # detection isn't *too* aggressive, and works inside a function
        func_contexts = []

        func_contexts.append(('func', False, dedent("""
            def f():""")))

        func_contexts.append(('method', False, dedent("""
            class MyClass:
                def __init__(self):
            """)))

        func_contexts.append(('async-func', True, dedent("""
            async def f():""")))

        func_contexts.append(('async-method', True, dedent("""
            class MyClass:
                async def f(self):""")))

        func_contexts.append(('closure', False, dedent("""
            def f():
                def g():
            """)))

        def nest_case(context, case):
            """Indent ``case`` one level deeper than the last line of ``context``."""
            last_line = context.strip().splitlines()[-1]
            prefix_len = len(last_line) - len(last_line.lstrip(' '))

            indented_case = indent(case, ' ' * (prefix_len + 4))
            return context + '\n' + indented_case

        # Gather and run the tests

        # yield is allowed in async functions, starting in Python 3.6,
        # and yield from is not allowed in any version
        vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
        async_safe_flags = (True,
                            True,
                            False)
        vals = tuple(zip(vals, async_safe_flags))

        success_tests = zip(self._get_top_level_cases(), repeat(False))
        failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))

        # Materialize: a bare chain() is a one-shot iterator and would be
        # exhausted after the first function context, silently skipping
        # every other context.
        tests = list(chain(success_tests, failure_tests))

        for context_name, async_func, context in func_contexts:
            for (test_name, test_case), should_fail in tests:
                nested_case = nest_case(context, test_case)

                for val, async_safe in vals:
                    val_should_fail = (should_fail or
                                       (async_func and not async_safe))

                    test_id = (context_name, test_name, val)
                    cell = nested_case.format(val=val)

                    with self.subTest(test_id):
                        if val_should_fail:
                            msg = ("SyntaxError not raised for %s" %
                                   str(test_id))
                            with self.assertRaises(SyntaxError, msg=msg):
                                iprc(cell)

                                # Only reached when no SyntaxError was raised:
                                # show the offending cell before the failure.
                                print(cell)
                        else:
                            iprc(cell)

    def test_nonlocal(self):
        # fails if outer scope is not a function scope or if var not defined.
        # Each cell gets its own assertRaises: inside a shared block only the
        # first cell would ever run.
        with self.assertRaises(SyntaxError):
            iprc("nonlocal x")

        with self.assertRaises(SyntaxError):
            iprc("""
            x = 1
            def f():
                nonlocal x
                x = 10000
                yield x
            """)

        with self.assertRaises(SyntaxError):
            iprc("""
            def f():
                def g():
                    nonlocal x
                    x = 10000
                    yield x
            """)

        # works if outer scope is a function scope and var exists
        iprc("""
        def f():
            x = 20
            def g():
                nonlocal x
                x = 10000
                yield x
        """)

    def test_execute(self):
        iprc("""
        import asyncio
        await asyncio.sleep(0.001)
        """
        )

    def test_autoawait(self):
        iprc("%autoawait False")
        iprc("%autoawait True")
        iprc("""
        from asyncio import sleep
        await sleep(0.1)
        """
        )

    if sys.version_info < (3, 9):
        # The new PEG parser in 3.9 no longer raises MemoryError on too many
        # nested parens, so this check only applies to older versions.
        def test_memory_error(self):
            with self.assertRaises(MemoryError):
                iprc("(" * 200 + ")" * 200)

    @skip_without('curio')
    def test_autoawait_curio(self):
        iprc("%autoawait curio")

    @skip_without('trio')
    def test_autoawait_trio(self):
        iprc("%autoawait trio")

    @skip_without('trio')
    def test_autoawait_trio_wrong_sleep(self):
        iprc("%autoawait trio")
        res = iprc_nr("""
        import asyncio
        await asyncio.sleep(0)
        """)
        with nt.assert_raises(TypeError):
            res.raise_error()

    @skip_without('trio')
    def test_autoawait_asyncio_wrong_sleep(self):
        iprc("%autoawait asyncio")
        res = iprc_nr("""
        import trio
        await trio.sleep(0)
        """)
        with nt.assert_raises(RuntimeError):
            res.raise_error()

    def tearDown(self):
        # Restore the default loop runner so %autoawait changes made by one
        # test do not leak into the next.
        ip.loop_runner = "asyncio"