import ast
import sys
import unittest


# Source fixtures parsed by TypeCommentTests below.  Their exact text matters:
# the tests assert on the extracted type-comment strings and, for `ignores`,
# on line numbers, so edit them with care.

# Function type comments in both supported positions: on its own line right
# after the header (foo) and on the header line itself (bar).
funcdef = """\
def foo():
    # type: () -> int
    pass

def bar():  # type: () -> None
    pass
"""

# Same two positions for `async def`; test_asyncdef expects this to parse only
# for feature versions >= 3.5.
asyncdef = """\
async def foo():
    # type: () -> int
    return await bar()

async def bar():  # type: () -> int
    return await bar()
"""

# `async`/`await` used as ordinary names; test_asyncvar expects this to parse
# only for feature versions <= 3.6.
asyncvar = """\
async = 12
await = 13
"""

# Asynchronous comprehension; test_asynccomp expects feature version >= 3.6.
asynccomp = """\
async def foo(xs):
    [x async for x in xs]
"""

# Matrix-multiplication operator; test_matmul expects feature version >= 3.5.
matmul = """\
a = b @ c
"""

# f-string literal; test_fstring expects feature version >= 3.6.
fstring = """\
a = 42
f"{a}"
"""

# Underscores in a numeric literal; test_underscorednumber expects feature
# version >= 3.6.
underscorednumber = """\
a = 42_42_42
"""

# A def carrying a type comment in both positions at once; test_redundantdef
# expects "Cannot have two type comments on def" for every feature version.
redundantdef = """\
def foo():  # type: () -> int
    # type: () -> str
    return ''
"""

# Non-ASCII characters in a function type comment must be preserved verbatim.
nonasciidef = """\
def foo():
    # type: () -> àçčéñt
    pass
"""

# Type comments on `for`, `with` and assignment statements; the matching tests
# expect the statement's type_comment to be "int".
forstmt = """\
for a in []:  # type: int
    pass
"""

withstmt = """\
with context() as a:  # type: int
    pass
"""

vardecl = """\
a = 0  # type: int
"""

# `# type: ignore` comments with and without a trailing tag.  test_ignores
# checks the reported (lineno, tag) pairs, so adding or removing lines here
# shifts the expected line numbers.
ignores = """\
def foo():
    pass  # type: ignore

def bar():
    x = 1  # type: ignore

def baz():
    pass  # type: ignore[excuse]
    pass  # type: ignore=excuse
    pass  # type: ignore [excuse]
    x = 1  # type: ignore whatever
"""

# Test for long-form type comments in arguments (the `longargs` string defined
# below).  Each function name encodes its expected parameters: a function
# named 'fabvk' would have two positional args, a and b, plus a var-arg *v,
# plus a kw-arg **k, and every parameter carries a type comment equal to its
# upper-cased name.  Some variants insert a bare '/' to make the leading
# parameters positional-only.  test_longargs() verifies that each function
# has exactly these arguments, no more, no fewer.
longargs = """\
def fa(
    a = 1,  # type: A
):
    pass

def fa(
    a = 1  # type: A
):
    pass

def fa(
    a = 1,  # type: A
    /
):
    pass

def fab(
    a,  # type: A
    b,  # type: B
):
    pass

def fab(
    a,  # type: A
    /,
    b,  # type: B
):
    pass

def fab(
    a,  # type: A
    b  # type: B
):
    pass

def fv(
    *v,  # type: V
):
    pass

def fv(
    *v  # type: V
):
    pass

def fk(
    **k,  # type: K
):
    pass

def fk(
    **k  # type: K
):
    pass

def fvk(
    *v,  # type: V
    **k,  # type: K
):
    pass

def fvk(
    *v,  # type: V
    **k  # type: K
):
    pass

def fav(
    a,  # type: A
    *v,  # type: V
):
    pass

def fav(
    a,  # type: A
    /,
    *v,  # type: V
):
    pass

def fav(
    a,  # type: A
    *v  # type: V
):
    pass

def fak(
    a,  # type: A
    **k,  # type: K
):
    pass

def fak(
    a,  # type: A
    /,
    **k,  # type: K
):
    pass

def fak(
    a,  # type: A
    **k  # type: K
):
    pass

def favk(
    a,  # type: A
    *v,  # type: V
    **k,  # type: K
):
    pass

def favk(
    a,  # type: A
    /,
    *v,  # type: V
    **k,  # type: K
):
    pass

def favk(
    a,  # type: A
    *v,  # type: V
    **k  # type: K
):
    pass
"""


class TypeCommentTests(unittest.TestCase):
    """Tests for ``ast.parse(..., type_comments=True)`` across feature versions."""

    lowest = 4  # Lowest minor version supported
    highest = sys.version_info[1]  # Highest minor version

    def parse(self, source, feature_version=highest):
        """Parse *source* with type comments enabled.

        *feature_version* is passed straight to ``ast.parse``; by default it
        is the running interpreter's minor version.
        """
        return ast.parse(source, type_comments=True,
                         feature_version=feature_version)

    def parse_all(self, source, minver=lowest, maxver=highest, expected_regex=""):
        """Parse *source* once per feature version 3.lowest .. 3.highest.

        For every minor version in ``minver..maxver`` (inclusive) the parsed
        tree is yielded; a SyntaxError there is re-raised with the offending
        feature_version appended to its message (the original error is kept
        as ``__cause__``).  For every other version, parsing is asserted to
        raise a SyntaxError whose message matches *expected_regex*.
        """
        for version in range(self.lowest, self.highest + 1):
            feature_version = (3, version)
            if minver <= version <= maxver:
                # Keep the try body to the parse call only, so nothing raised
                # while the caller consumes the tree is misattributed here.
                try:
                    tree = self.parse(source, feature_version)
                except SyntaxError as err:
                    raise SyntaxError(
                        str(err) + f" feature_version={feature_version}"
                    ) from err
                yield tree
            else:
                with self.assertRaisesRegex(SyntaxError, expected_regex,
                                            msg=f"feature_version={feature_version}"):
                    self.parse(source, feature_version)

    def classic_parse(self, source):
        """Parse *source* with default settings (type comments disabled)."""
        return ast.parse(source)

    def test_funcdef(self):
        for tree in self.parse_all(funcdef):
            self.assertEqual(tree.body[0].type_comment, "() -> int")
            self.assertEqual(tree.body[1].type_comment, "() -> None")
        tree = self.classic_parse(funcdef)
        self.assertEqual(tree.body[0].type_comment, None)
        self.assertEqual(tree.body[1].type_comment, None)

    def test_asyncdef(self):
        for tree in self.parse_all(asyncdef, minver=5):
            self.assertEqual(tree.body[0].type_comment, "() -> int")
            self.assertEqual(tree.body[1].type_comment, "() -> int")
        tree = self.classic_parse(asyncdef)
        self.assertEqual(tree.body[0].type_comment, None)
        self.assertEqual(tree.body[1].type_comment, None)

    def test_asyncvar(self):
        for tree in self.parse_all(asyncvar, maxver=6):
            pass

    def test_asynccomp(self):
        for tree in self.parse_all(asynccomp, minver=6):
            pass

    def test_matmul(self):
        for tree in self.parse_all(matmul, minver=5):
            pass

    def test_fstring(self):
        for tree in self.parse_all(fstring, minver=6):
            pass

    def test_underscorednumber(self):
        for tree in self.parse_all(underscorednumber, minver=6):
            pass

    def test_redundantdef(self):
        for tree in self.parse_all(redundantdef, maxver=0,
                                   expected_regex="^Cannot have two type comments on def"):
            pass

    def test_nonasciidef(self):
        for tree in self.parse_all(nonasciidef):
            self.assertEqual(tree.body[0].type_comment, "() -> àçčéñt")

    def test_forstmt(self):
        for tree in self.parse_all(forstmt):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(forstmt)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_withstmt(self):
        for tree in self.parse_all(withstmt):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(withstmt)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_vardecl(self):
        for tree in self.parse_all(vardecl):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(vardecl)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_ignores(self):
        for tree in self.parse_all(ignores):
            self.assertEqual(
                [(ti.lineno, ti.tag) for ti in tree.type_ignores],
                [
                    (2, ''),
                    (5, ''),
                    (8, '[excuse]'),
                    (9, '=excuse'),
                    (10, ' [excuse]'),
                    (11, ' whatever'),
                ])
        tree = self.classic_parse(ignores)
        self.assertEqual(tree.type_ignores, [])

    def test_longargs(self):
        # `longargs` uses positional-only parameters ('/'), which only exist
        # since Python 3.8.  Newer parsers reject them for feature_version
        # < (3, 8) while older ones may not, so only the versions where the
        # source is valid on every interpreter are exercised.
        for version in range(max(8, self.lowest), self.highest + 1):
            with self.subTest(feature_version=(3, version)):
                tree = self.parse(longargs, (3, version))
                for t in tree.body:
                    # The expected args are encoded in the function name
                    self.assertTrue(t.name.startswith('f'), t.name)
                    todo = set(t.name[1:])
                    self.assertEqual(
                        len(t.args.args) + len(t.args.posonlyargs),
                        len(todo) - bool(t.args.vararg) - bool(t.args.kwarg))
                    # Positional-only args precede regular ones, so letter
                    # 'a' maps to index 0 of the combined list, 'b' to 1, ...
                    positional = t.args.posonlyargs + t.args.args
                    for c in t.name[1:]:
                        todo.remove(c)
                        if c == 'v':
                            arg = t.args.vararg
                        elif c == 'k':
                            arg = t.args.kwarg
                        else:
                            position = ord(c) - ord('a')
                            self.assertTrue(0 <= position < len(positional),
                                            t.name)
                            arg = positional[position]
                        self.assertEqual(arg.arg, c)  # That's the argument name
                        self.assertEqual(arg.type_comment, arg.arg.upper())
                    self.assertFalse(todo, t.name)
        tree = self.classic_parse(longargs)
        for t in tree.body:
            args = t.args
            # Include positional-only args too: with type comments disabled
            # none of the parameters may carry one.
            for arg in args.posonlyargs + args.args + [args.vararg, args.kwarg]:
                if arg is not None:
                    self.assertIsNone(arg.type_comment, "%s(%s:%r)" %
                                      (t.name, arg.arg, arg.type_comment))

    def test_inappropriate_type_comments(self):
        """Tests for inappropriately-placed type comments.

        These should be silently ignored with type comments off,
        but raise SyntaxError with type comments on.

        This is not meant to be exhaustive.
        """

        def check_both_ways(source):
            ast.parse(source, type_comments=False)
            for tree in self.parse_all(source, maxver=0):
                pass

        check_both_ways("pass  # type: int\n")
        check_both_ways("foo()  # type: int\n")
        check_both_ways("x += 1  # type: int\n")
        check_both_ways("while True:  # type: int\n    continue\n")
        check_both_ways("while True:\n    continue  # type: int\n")
        check_both_ways("try:  # type: int\n    pass\nfinally:\n    pass\n")
        check_both_ways("try:\n    pass\nfinally:  # type: int\n    pass\n")
        check_both_ways("pass  # type: ignorewhatever\n")
        check_both_ways("pass  # type: ignoreé\n")

    def test_func_type_input(self):

        def parse_func_type_input(source):
            return ast.parse(source, "<unknown>", "func_type")

        # Some checks below will crash if the returned structure is wrong
        tree = parse_func_type_input("() -> int")
        self.assertEqual(tree.argtypes, [])
        self.assertEqual(tree.returns.id, "int")

        tree = parse_func_type_input("(int) -> List[str]")
        self.assertEqual(len(tree.argtypes), 1)
        arg = tree.argtypes[0]
        self.assertEqual(arg.id, "int")
        self.assertEqual(tree.returns.value.id, "List")
        # Python 3.8 wraps a simple subscript in ast.Index (the expression is
        # its .value); since 3.9 the slice is the expression node itself.
        # A Name node has no .value, so this unwraps only the 3.8 wrapper.
        subscript = tree.returns.slice
        subscript = getattr(subscript, "value", subscript)
        self.assertEqual(subscript.id, "str")

        tree = parse_func_type_input("(int, *str, **Any) -> float")
        self.assertEqual(tree.argtypes[0].id, "int")
        self.assertEqual(tree.argtypes[1].id, "str")
        self.assertEqual(tree.argtypes[2].id, "Any")
        self.assertEqual(tree.returns.id, "float")

        with self.assertRaises(SyntaxError):
            parse_func_type_input("(int, *str, *Any) -> float")

        with self.assertRaises(SyntaxError):
            parse_func_type_input("(int, **str, Any) -> float")

        with self.assertRaises(SyntaxError):
            parse_func_type_input("(**int, **str) -> float")


if __name__ == '__main__':
    unittest.main()