1from unittest import TestCase
2import gast as ast
3import beniget
4import sys
5
6
class StrictDefUseChains(beniget.DefUseChains):
    """``DefUseChains`` variant used by the tests below.

    It overrides the ``unbound_identifier`` hook so that every unbound
    identifier raises ``RuntimeError`` (reporting the name and its
    line/column) instead of going through the base-class handling.
    """

    def unbound_identifier(self, name, node):
        # Fail loudly, pointing at the exact source position of the name.
        template = "W: unbound identifier '{}' at {}:{}"
        message = template.format(name, node.lineno, node.col_offset)
        raise RuntimeError(message)
14
15
class TestGlobals(TestCase):
    """Module-level definitions reported by ``StrictDefUseChains``."""

    def checkGlobals(self, code, ref):
        """Assert that the definitions of the module parsed from *code* equal *ref*.

        *ref* is compared directly with ``dump_definitions``; the expected
        lists below are written in alphabetical order.
        """
        module = ast.parse(code)
        chains = StrictDefUseChains()
        chains.visit(module)
        self.assertEqual(chains.dump_definitions(module), ref)

    # -- functions ---------------------------------------------------------

    def test_SingleFunctionDef(self):
        self.checkGlobals("def foo(): pass", ["foo"])

    def test_MultipleFunctionDef(self):
        self.checkGlobals("def foo(): pass\ndef bar(): return", ["bar", "foo"])

    def testFuntionRedefinion(self):
        self.checkGlobals("def foo(): pass\ndef foo(): return", ["foo", "foo"])

    def testFuntionNested(self):
        # The nested function is local to ``foo``, not a global.
        self.checkGlobals("def foo():\n def bar(): return", ["foo"])

    if sys.version_info.major >= 3:

        def testAsyncFunctionDef(self):
            self.checkGlobals("async def foo(): pass", ["foo"])

    # -- classes -----------------------------------------------------------

    def testClassDef(self):
        self.checkGlobals("class C:pass", ["C"])

    def testDelClassDef(self):
        self.checkGlobals("class C:pass\ndel C", ["C"])

    def testDelClassDefReDef(self):
        self.checkGlobals("class C:pass\ndel C\nclass C:pass", ["C", "C"])

    def testNestedClassDef(self):
        self.checkGlobals("class C:\n class D: pass", ["C"])

    def testMultipleClassDef(self):
        self.checkGlobals("class C: pass\nclass D: pass", ["C", "D"])

    def testClassRedefinition(self):
        self.checkGlobals("class C: pass\nclass C: pass", ["C", "C"])

    def testClassMethodDef(self):
        self.checkGlobals("class C:\n def some(self):pass", ["C"])

    # -- assignments -------------------------------------------------------

    def testGlobalDef(self):
        self.checkGlobals("x = 1", ["x"])

    if sys.version_info.major >= 3:

        def testGlobalAnnotatedDef(self):
            self.checkGlobals("x : 1", ["x"])

    def testMultipleGlobalDef(self):
        self.checkGlobals("x = 1; x = 2", ["x", "x"])

    def testGlobalDestructuring(self):
        self.checkGlobals("x, y = 1, 2", ["x", "y"])

    def testGlobalAugAssign(self):
        self.checkGlobals("x = 1; x += 2", ["x"])

    # -- loops -------------------------------------------------------------

    def testGlobalFor(self):
        self.checkGlobals("for x in (1,2): pass", ["x"])

    def testGlobalForDestructuring(self):
        self.checkGlobals("for x, y in [(1,2)]: pass", ["x", "y"])

    def testGlobalNestedFor(self):
        self.checkGlobals("for x in (1,2):\n for y in (2, 1): pass", ["x", "y"])

    def testGlobalInFor(self):
        self.checkGlobals("for x in (1,2): y = x", ["x", "y"])

    if sys.version_info >= (3, 7):

        def testGlobalAsyncFor(self):
            self.checkGlobals("async for x in (1,2): pass", ["x"])

    def testGlobalInWhile(self):
        self.checkGlobals("while True: x = 1", ["x"])

    # -- conditionals ------------------------------------------------------

    def testGlobalInIfTrueBranch(self):
        self.checkGlobals("if 1: a = 1", ["a"])

    def testGlobalInIfFalseBranch(self):
        self.checkGlobals("if 1: pass\nelse: a = 1", ["a"])

    def testGlobalInIfBothBranch(self):
        self.checkGlobals("if 1: a = 1\nelse: a = 2", ["a", "a"])

    def testGlobalInIfBothBranchDifferent(self):
        self.checkGlobals("if 1: a = 1\nelse: b = 2", ["a", "b"])

    # -- with / try --------------------------------------------------------

    def testGlobalWith(self):
        self.checkGlobals(
            "from some import foo\nwith foo() as x: pass", ["foo", "x"]
        )

    if sys.version_info >= (3, 7):

        def testGlobalAsyncWith(self):
            self.checkGlobals(
                "from some import foo\nasync with foo() as x: pass", ["foo", "x"]
            )

    def testGlobalTry(self):
        self.checkGlobals("try: x = 1\nexcept Exception: pass", ["x"])

    def testGlobalTryExcept(self):
        self.checkGlobals(
            "from some import foo\ntry: foo()\nexcept Exception as e: pass",
            ["e", "foo"],
        )

    def testGlobalTryExceptFinally(self):
        self.checkGlobals(
            "try: w = 1\nexcept Exception as x: y = 1\nfinally: z = 1",
            ["w", "x", "y", "z"],
        )

    # -- ``global`` keyword ------------------------------------------------

    def testGlobalThroughKeyword(self):
        self.checkGlobals("def foo(): global x", ["foo", "x"])

    def testGlobalThroughKeywords(self):
        self.checkGlobals("def foo(): global x, y", ["foo", "x", "y"])

    def testGlobalThroughMultipleKeyword(self):
        self.checkGlobals(
            "def foo(): global x\ndef bar(): global x", ["bar", "foo", "x"]
        )

    def testGlobalBeforeKeyword(self):
        self.checkGlobals("x = 1\ndef foo(): global x", ["foo", "x"])

    def testGlobalsBeforeKeyword(self):
        self.checkGlobals("x = 1\ndef foo(): global x, y", ["foo", "x", "y"])

    if sys.version_info.major >= 3:

        def testGlobalAfterKeyword(self):
            self.checkGlobals("def foo(): global x\nx : 1", ["foo", "x"])

        def testGlobalsAfterKeyword(self):
            self.checkGlobals("def foo(): global x, y\ny : 1", ["foo", "x", "y"])

    # -- imports -----------------------------------------------------------

    def testGlobalImport(self):
        self.checkGlobals("import foo", ["foo"])

    def testGlobalImports(self):
        self.checkGlobals("import foo, bar", ["bar", "foo"])

    def testGlobalImportSubModule(self):
        # Importing a dotted path only binds the top-level package name.
        self.checkGlobals("import foo.bar", ["foo"])

    def testGlobalImportSubModuleAs(self):
        self.checkGlobals("import foo.bar as foobar", ["foobar"])

    def testGlobalImportAs(self):
        self.checkGlobals("import foo as bar", ["bar"])

    def testGlobalImportsAs(self):
        self.checkGlobals("import foo as bar, foobar", ["bar", "foobar"])

    def testGlobalImportFrom(self):
        self.checkGlobals("from foo import bar", ["bar"])

    def testGlobalImportFromAs(self):
        self.checkGlobals("from foo import bar as BAR", ["BAR"])

    def testGlobalImportFromStar(self):
        self.checkGlobals("from foo import *", ["*"])

    def testGlobalImportFromStarRedefine(self):
        self.checkGlobals("from foo import *\nx+=1", ["*", "x"])

    def testGlobalImportsFrom(self):
        self.checkGlobals("from foo import bar, man", ["bar", "man"])

    def testGlobalImportsFromAs(self):
        self.checkGlobals("from foo import bar, man as maid", ["bar", "maid"])

    # -- comprehensions and lambdas ----------------------------------------
    # The expected globals for comprehension targets differ between
    # Python 2 (target ``x`` is reported) and Python 3 (it is not).

    def testGlobalListComp(self):
        expected = ["x", "y"] if sys.version_info.major == 2 else ["y"]
        self.checkGlobals("from some import y; [1 for x in y]", expected)

    def testGlobalSetComp(self):
        expected = ["x", "y"] if sys.version_info.major == 2 else ["y"]
        self.checkGlobals("from some import y; {1 for x in y}", expected)

    def testGlobalDictComp(self):
        expected = ["x", "y"] if sys.version_info.major == 2 else ["y"]
        self.checkGlobals("from some import y; {1:1 for x in y}", expected)

    def testGlobalGeneratorExpr(self):
        expected = ["x", "y"] if sys.version_info.major == 2 else ["y"]
        self.checkGlobals("from some import y; (1 for x in y)", expected)

    def testGlobalLambda(self):
        self.checkGlobals("lambda x: x", [])
268
269
class TestClasses(TestCase):
    """Definitions local to a class body, as reported by ``StrictDefUseChains``."""

    def checkClasses(self, code, ref):
        """Assert that the definitions of the single top-level class in *code*
        equal *ref*.

        Only ``ClassDef`` nodes directly in the module body are considered;
        the test fails unless there is exactly one of them.
        """
        node = ast.parse(code)
        c = StrictDefUseChains()
        c.visit(node)
        classes = [n for n in node.body if isinstance(n, ast.ClassDef)]
        # A unittest assertion (unlike a bare ``assert``) still runs under
        # ``python -O``, so a malformed test case cannot silently pick the
        # first of several classes.
        self.assertEqual(
            len(classes), 1, "only one top-level class per test case"
        )
        cls = classes[0]
        self.assertEqual(c.dump_definitions(cls), ref)

    def test_class_method_assign(self):
        # ``bar = foo`` inside the class body defines ``bar`` alongside ``foo``.
        code = "class C:\n def foo(self):pass\n bar = foo"
        self.checkClasses(code, ["bar", "foo"])
283
284
class TestLocals(TestCase):
    """Definitions local to a function, as reported by ``StrictDefUseChains``."""

    def checkLocals(self, code, ref):
        """Assert that the definitions of the single top-level function in
        *code* equal *ref*.

        Only ``FunctionDef`` nodes directly in the module body are considered;
        the test fails unless there is exactly one of them.
        """
        node = ast.parse(code)
        c = StrictDefUseChains()
        c.visit(node)
        functions = [n for n in node.body if isinstance(n, ast.FunctionDef)]
        # A unittest assertion (unlike a bare ``assert``) still runs under
        # ``python -O``, so a malformed test case cannot silently pick the
        # first of several functions.
        self.assertEqual(
            len(functions), 1, "only one top-level function per test case"
        )
        f = functions[0]
        self.assertEqual(c.dump_definitions(f), ref)

    def testLocalFunctionDef(self):
        code = "def foo(): pass"
        self.checkLocals(code, [])

    def testLocalFunctionDefOneArg(self):
        code = "def foo(a): pass"
        self.checkLocals(code, ["a"])

    def testLocalFunctionDefOneArgDefault(self):
        code = "def foo(a=1): pass"
        self.checkLocals(code, ["a"])

    def testLocalFunctionDefArgsDefault(self):
        code = "def foo(a, b=1): pass"
        self.checkLocals(code, ["a", "b"])

    def testLocalFunctionDefStarArgs(self):
        code = "def foo(a, *b): pass"
        self.checkLocals(code, ["a", "b"])

    def testLocalFunctionDefKwArgs(self):
        code = "def foo(a, **b): pass"
        self.checkLocals(code, ["a", "b"])

    if sys.version_info.major >= 3:

        def testLocalFunctionDefKwOnly(self):
            code = "def foo(a, *, b=1): pass"
            self.checkLocals(code, ["a", "b"])

    if sys.version_info.major == 2:

        def testLocalFunctionDefDestructureArg(self):
            code = "def foo((a, b)): pass"
            self.checkLocals(code, ["a", "b"])

    def test_LocalAssign(self):
        code = "def foo(): a = 1"
        self.checkLocals(code, ["a"])

    def test_LocalAssignRedef(self):
        code = "def foo(a): a = 1"
        self.checkLocals(code, ["a", "a"])

    def test_LocalNestedFun(self):
        code = "def foo(a):\n def bar(): return a\n return bar"
        self.checkLocals(code, ["a", "bar"])

    if sys.version_info.major >= 3:

        def test_LocalNonLocalBefore(self):
            code = "def foo(a):\n def bar():\n  nonlocal a; a = 1\n bar(); return a"
            self.checkLocals(code, ["a", "bar"])

        def test_LocalNonLocalAfter(self):
            code = (
                "def foo():\n def bar():\n  nonlocal a; a = 1\n a = 2; bar(); return a"
            )
            self.checkLocals(code, ["a", "bar"])

    def test_LocalGlobal(self):
        # ``global a`` makes the assignment module-level, so foo has no locals.
        code = "def foo(): global a; a = 1"
        self.checkLocals(code, [])

    def test_ListCompInLoop(self):
        code = "def foo(i):\n for j in i:\n  [k for k in j]"
        if sys.version_info.major == 2:
            self.checkLocals(code, ["i", "j", "k"])
        else:
            self.checkLocals(code, ["i", "j"])

    def test_AugAssignInLoop(self):
        code = """
def foo(X, f):
    for i in range(2):
        if i == 0: A = f * X[:, i]
        else: A += f * X[:, i]
    return A"""
        self.checkLocals(code, ["A", "X", "f", "i"])

    def test_IfInWhile(self):
        code = """
def foo(a):
    while(a):
        if a == 1: print(b)
        else: b = a"""
        self.checkLocals(code, ["a", "b"])
382