import ast
import sys
import unittest
4
5
# Source fixtures fed to ast.parse() by TypeCommentTests.  Each one is a
# small, self-contained module; the trailing backslash after the opening
# quotes keeps the first source line on line 1 of the fixture.

# Function type comments, both on the line after "def" and on the same line.
funcdef = """\
def foo():
    # type: () -> int
    pass

def bar():  # type: () -> None
    pass
"""

# Same as funcdef, but with "async def" functions.
asyncdef = """\
async def foo():
    # type: () -> int
    return await bar()

async def bar():  # type: () -> int
    return await bar()
"""

# "async" and "await" used as plain variable names.
asyncvar = """\
async = 12
await = 13
"""

# Asynchronous comprehension.
asynccomp = """\
async def foo(xs):
    [x async for x in xs]
"""

# Matrix-multiplication operator.
matmul = """\
a = b @ c
"""

# Formatted string literal.
fstring = """\
a = 42
f"{a}"
"""

# Numeric literal with underscores.
underscorednumber = """\
a = 42_42_42
"""

# A function carrying two type comments at once.
redundantdef = """\
def foo():  # type: () -> int
    # type: () -> str
    return ''
"""

# Function type comment containing non-ASCII characters.
nonasciidef = """\
def foo():
    # type: () -> àçčéñt
    pass
"""

# Type comment on a "for" statement.
forstmt = """\
for a in []:  # type: int
    pass
"""

# Type comment on a "with" statement.
withstmt = """\
with context() as a:  # type: int
    pass
"""

# Type comment on a plain assignment.
vardecl = """\
a = 0  # type: int
"""

# "# type: ignore" comments, with and without a trailing tag.  The line
# numbers of these comments (2, 5, 8, 9, 10, 11) are checked by test_ignores().
ignores = """\
def foo():
    pass  # type: ignore

def bar():
    x = 1  # type: ignore

def baz():
    pass  # type: ignore[excuse]
    pass  # type: ignore=excuse
    pass  # type: ignore [excuse]
    x = 1  # type: ignore whatever
"""
86
# Fixture for long-form (per-argument) type comments.  The expected
# arguments are encoded in each function's name after the leading 'f':
# a function named 'fabvk' would have two positional args, a and b, plus
# a var-arg *v, plus a kw-arg **k.  test_longargs() verifies that each
# function has exactly these arguments, no more, no fewer, and that every
# argument's type comment is its own name in upper case.  Variants with
# and without a trailing comma, and with a positional-only '/' marker,
# are included.
longargs = """\
def fa(
    a = 1,  # type: A
):
    pass

def fa(
    a = 1  # type: A
):
    pass

def fa(
    a = 1,  # type: A
    /
):
    pass

def fab(
    a,  # type: A
    b,  # type: B
):
    pass

def fab(
    a,  # type: A
    /,
    b,  # type: B
):
    pass

def fab(
    a,  # type: A
    b   # type: B
):
    pass

def fv(
    *v,  # type: V
):
    pass

def fv(
    *v  # type: V
):
    pass

def fk(
    **k,  # type: K
):
    pass

def fk(
    **k  # type: K
):
    pass

def fvk(
    *v,  # type: V
    **k,  # type: K
):
    pass

def fvk(
    *v,  # type: V
    **k  # type: K
):
    pass

def fav(
    a,  # type: A
    *v,  # type: V
):
    pass

def fav(
    a,  # type: A
    /,
    *v,  # type: V
):
    pass

def fav(
    a,  # type: A
    *v  # type: V
):
    pass

def fak(
    a,  # type: A
    **k,  # type: K
):
    pass

def fak(
    a,  # type: A
    /,
    **k,  # type: K
):
    pass

def fak(
    a,  # type: A
    **k  # type: K
):
    pass

def favk(
    a,  # type: A
    *v,  # type: V
    **k,  # type: K
):
    pass

def favk(
    a,  # type: A
    /,
    *v,  # type: V
    **k,  # type: K
):
    pass

def favk(
    a,  # type: A
    *v,  # type: V
    **k  # type: K
):
    pass
"""
219
220
class TypeCommentTests(unittest.TestCase):
    """Tests for type-comment handling in ast.parse().

    Fixtures are parsed with ``type_comments=True`` under every minor
    ``feature_version`` from ``lowest`` to ``highest``; most tests also
    parse the fixture with plain ``ast.parse()`` and check that no type
    comments are recorded in that case.
    """

    lowest = 4  # Lowest minor version supported
    highest = sys.version_info[1]  # Highest minor version

    def parse(self, source, feature_version=highest):
        """Return the AST of *source*, parsed with type comments enabled."""
        return ast.parse(source, type_comments=True,
                         feature_version=feature_version)

    def parse_all(self, source, minver=lowest, maxver=highest, expected_regex=""):
        """Parse *source* under each feature version (3, lowest)..(3, highest).

        This is a generator, so nothing is checked until it is iterated.
        For minor versions within ``minver..maxver`` (inclusive) the parsed
        tree is yielded; a SyntaxError there is re-raised with the offending
        feature version appended to its message.  For every other minor
        version the parse must raise a SyntaxError matching
        *expected_regex*.
        """
        for version in range(self.lowest, self.highest + 1):
            feature_version = (3, version)
            if minver <= version <= maxver:
                # Keep the yield outside the try block so that only parser
                # errors are rewrapped.
                try:
                    tree = self.parse(source, feature_version)
                except SyntaxError as err:
                    raise SyntaxError(
                        str(err) + f" feature_version={feature_version}"
                    ) from err
                yield tree
            else:
                with self.assertRaisesRegex(SyntaxError, expected_regex,
                                            msg=f"feature_version={feature_version}"):
                    self.parse(source, feature_version)

    def classic_parse(self, source):
        """Return the AST of *source* from ast.parse() with default options."""
        return ast.parse(source)

    def test_funcdef(self):
        for tree in self.parse_all(funcdef):
            self.assertEqual(tree.body[0].type_comment, "() -> int")
            self.assertEqual(tree.body[1].type_comment, "() -> None")
        tree = self.classic_parse(funcdef)
        self.assertEqual(tree.body[0].type_comment, None)
        self.assertEqual(tree.body[1].type_comment, None)

    def test_asyncdef(self):
        for tree in self.parse_all(asyncdef, minver=5):
            self.assertEqual(tree.body[0].type_comment, "() -> int")
            self.assertEqual(tree.body[1].type_comment, "() -> int")
        tree = self.classic_parse(asyncdef)
        self.assertEqual(tree.body[0].type_comment, None)
        self.assertEqual(tree.body[1].type_comment, None)

    # The tests below only care whether parsing succeeds or fails for each
    # feature version; exhausting the generator performs those checks.

    def test_asyncvar(self):
        for _ in self.parse_all(asyncvar, maxver=6):
            pass

    def test_asynccomp(self):
        for _ in self.parse_all(asynccomp, minver=6):
            pass

    def test_matmul(self):
        for _ in self.parse_all(matmul, minver=5):
            pass

    def test_fstring(self):
        for _ in self.parse_all(fstring, minver=6):
            pass

    def test_underscorednumber(self):
        for _ in self.parse_all(underscorednumber, minver=6):
            pass

    def test_redundantdef(self):
        # maxver=0 is below ``lowest``, so every feature version must fail.
        for _ in self.parse_all(redundantdef, maxver=0,
                                expected_regex="^Cannot have two type comments on def"):
            pass

    def test_nonasciidef(self):
        for tree in self.parse_all(nonasciidef):
            self.assertEqual(tree.body[0].type_comment, "() -> àçčéñt")

    def test_forstmt(self):
        for tree in self.parse_all(forstmt):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(forstmt)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_withstmt(self):
        for tree in self.parse_all(withstmt):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(withstmt)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_vardecl(self):
        for tree in self.parse_all(vardecl):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(vardecl)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_ignores(self):
        for tree in self.parse_all(ignores):
            self.assertEqual(
                [(ti.lineno, ti.tag) for ti in tree.type_ignores],
                [
                    (2, ''),
                    (5, ''),
                    (8, '[excuse]'),
                    (9, '=excuse'),
                    (10, ' [excuse]'),
                    (11, ' whatever'),
                ])
        tree = self.classic_parse(ignores)
        self.assertEqual(tree.type_ignores, [])

    def test_longargs(self):
        for tree in self.parse_all(longargs):
            for t in tree.body:
                # The expected args are encoded in the function name
                self.assertTrue(t.name.startswith('f'), t.name)
                todo = set(t.name[1:])
                # Positional-only and regular arguments, in declaration order
                positional = t.args.posonlyargs + t.args.args
                self.assertEqual(len(positional),
                                 len(todo) - bool(t.args.vararg) - bool(t.args.kwarg))
                for c in t.name[1:]:
                    todo.remove(c)
                    if c == 'v':
                        arg = t.args.vararg
                    elif c == 'k':
                        arg = t.args.kwarg
                    else:
                        # 'a' is the first positional argument, 'b' the second
                        offset = ord(c) - ord('a')
                        self.assertTrue(0 <= offset < len(positional), t.name)
                        arg = positional[offset]
                    self.assertEqual(arg.arg, c)  # That's the argument name
                    self.assertEqual(arg.type_comment, arg.arg.upper())
                self.assertFalse(todo, t.name)
        tree = self.classic_parse(longargs)
        for t in tree.body:
            every_arg = (t.args.posonlyargs + t.args.args
                         + [t.args.vararg, t.args.kwarg])
            for arg in every_arg:
                if arg is not None:
                    self.assertIsNone(arg.type_comment, "%s(%s:%r)" %
                                      (t.name, arg.arg, arg.type_comment))

    def test_inappropriate_type_comments(self):
        """Tests for inappropriately-placed type comments.

        These should be silently ignored with type comments off,
        but raise SyntaxError with type comments on.

        This is not meant to be exhaustive.
        """

        def check_both_ways(source):
            # Must parse cleanly when type comments are off ...
            ast.parse(source, type_comments=False)
            # ... and fail for every feature version when they are on
            # (maxver=0 is below ``lowest``).
            for _ in self.parse_all(source, maxver=0):
                pass

        check_both_ways("pass  # type: int\n")
        check_both_ways("foo()  # type: int\n")
        check_both_ways("x += 1  # type: int\n")
        check_both_ways("while True:  # type: int\n  continue\n")
        check_both_ways("while True:\n  continue  # type: int\n")
        check_both_ways("try:  # type: int\n  pass\nfinally:\n  pass\n")
        check_both_ways("try:\n  pass\nfinally:  # type: int\n  pass\n")
        check_both_ways("pass  # type: ignorewhatever\n")
        check_both_ways("pass  # type: ignoreé\n")

    def test_func_type_input(self):

        def parse_func_type_input(source):
            return ast.parse(source, "<unknown>", "func_type")

        # Some checks below will crash if the returned structure is wrong
        tree = parse_func_type_input("() -> int")
        self.assertEqual(tree.argtypes, [])
        self.assertEqual(tree.returns.id, "int")

        tree = parse_func_type_input("(int) -> List[str]")
        self.assertEqual(len(tree.argtypes), 1)
        arg = tree.argtypes[0]
        self.assertEqual(arg.id, "int")
        self.assertEqual(tree.returns.value.id, "List")
        # Older ASTs wrap a simple subscript in an Index node that holds the
        # expression in ``.value``; newer ones store the expression directly.
        # Accept either shape without touching the deprecated ast.Index name.
        subscript = tree.returns.slice
        subscript = getattr(subscript, "value", subscript)
        self.assertEqual(subscript.id, "str")

        tree = parse_func_type_input("(int, *str, **Any) -> float")
        self.assertEqual(tree.argtypes[0].id, "int")
        self.assertEqual(tree.argtypes[1].id, "str")
        self.assertEqual(tree.argtypes[2].id, "Any")
        self.assertEqual(tree.returns.id, "float")

        with self.assertRaises(SyntaxError):
            parse_func_type_input("(int, *str, *Any) -> float")

        with self.assertRaises(SyntaxError):
            parse_func_type_input("(int, **str, Any) -> float")

        with self.assertRaises(SyntaxError):
            parse_func_type_input("(**int, **str) -> float")
410
# Run the tests in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
413