1""" Test suite for the code in fixer_util """
2
3# Testing imports
4from . import support
5
6# Local imports
7from lib2to3.pytree import Node, Leaf
8from lib2to3 import fixer_util
9from lib2to3.fixer_util import Attr, Name, Call, Comma
10from lib2to3.pgen2 import token
11
def parse(code, strip_levels=0):
    """Parse *code* and return a detached node from the resulting tree.

    The root produced by support.parse_string is a file_input node, and the
    level below it is a statement node; neither is usually interesting to
    the tests.  *strip_levels* says how many times to step down to the first
    child before returning.  The returned node's parent link is cleared so
    it can be used as a standalone tree.
    """
    node = support.parse_string(code)
    for _ in range(strip_levels):
        first_child = node.children[0]
        node = first_child
    node.parent = None
    return node
20
class MacroTestCase(support.TestCase):
    """Base class for tests that compare built nodes against source text."""

    def assertStr(self, node, string):
        """Assert that the rendered text of *node* equals *string*.

        A list or tuple of nodes is first wrapped in a simple_stmt Node so a
        sequence of nodes can be rendered as one string.
        """
        rendered = node
        if isinstance(rendered, (list, tuple)):
            rendered = Node(fixer_util.syms.simple_stmt, rendered)
        self.assertEqual(str(rendered), string)
26
27
class Test_is_tuple(support.TestCase):
    def is_tuple(self, string):
        """Parse *string* as an expression and run fixer_util.is_tuple on it."""
        expr = parse(string, strip_levels=2)
        return fixer_util.is_tuple(expr)

    def test_valid(self):
        tuples = ("(a, b)", "(a, (b, c))", "((a, (b, c)),)", "(a,)", "()")
        for source in tuples:
            self.assertTrue(self.is_tuple(source))

    def test_invalid(self):
        # A parenthesized single name and a binary operation involving a
        # tuple must not be reported as tuples.
        non_tuples = ("(a)", "('foo') % (b, c)")
        for source in non_tuples:
            self.assertFalse(self.is_tuple(source))
42
43
class Test_is_list(support.TestCase):
    def is_list(self, string):
        """Parse *string* as an expression and run fixer_util.is_list on it."""
        expr = parse(string, strip_levels=2)
        return fixer_util.is_list(expr)

    def test_valid(self):
        lists = ("[]", "[a]", "[a, b]", "[a, [b, c]]", "[[a, [b, c]],]")
        for source in lists:
            self.assertTrue(self.is_list(source))

    def test_invalid(self):
        # Adding two lists together is an expression, not a list display.
        concatenation = "[]+[]"
        self.assertFalse(self.is_list(concatenation))
57
58
class Test_Attr(MacroTestCase):
    def test(self):
        self.assertStr(Attr(Name("a"), Name("b")), "a.b")
        # The object of the attribute access may itself be a parsed call.
        call_node = parse("foo()", strip_levels=2)
        self.assertStr(Attr(call_node, Name("b")), "foo().b")

    def test_returns(self):
        # Attr hands back a plain list of nodes rather than a single Node.
        result = Attr(Name("a"), Name("b"))
        self.assertIs(type(result), list)
69
70
class Test_Name(MacroTestCase):
    def test(self):
        # (value, keyword arguments for Name, expected rendering)
        cases = (
            ("a", {}, "a"),
            ("foo.foo().bar", {}, "foo.foo().bar"),
            ("a", {"prefix": "b"}, "ba"),
        )
        for value, extra, expected in cases:
            self.assertStr(Name(value, **extra), expected)
76
77
class Test_Call(MacroTestCase):
    def _Call(self, name, args=None, prefix=None):
        """Build ``Call(Name(name), ...)`` with *args* separated by commas.

        *args* is a list of leaves, and a Comma leaf is placed between each
        adjacent pair.  Anything that is not a list (including the default
        None) yields a call with no arguments, and so does an empty list.
        *prefix* is passed straight through to Call.
        """
        children = []
        if isinstance(args, list):
            for position, arg in enumerate(args):
                # Put a comma before every argument except the first.  This
                # avoids building a trailing comma and then popping it, which
                # raised IndexError when args was an empty list.
                if position:
                    children.append(Comma())
                children.append(arg)
        return Call(Name(name), children, prefix)

    def test(self):
        ascending = [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 2),
                     Leaf(token.NUMBER, 3)]
        shuffled = [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 3),
                    Leaf(token.NUMBER, 2), Leaf(token.NUMBER, 4)]
        strings = [Leaf(token.STRING, "b"),
                   Leaf(token.STRING, "j", prefix=" ")]
        self.assertStr(self._Call("A"), "A()")
        # An explicitly empty argument list must also render as a bare call.
        self.assertStr(self._Call("A", []), "A()")
        self.assertStr(self._Call("b", ascending), "b(1,2,3)")
        self.assertStr(self._Call("a.b().c", shuffled), "a.b().c(1,3,2,4)")
        self.assertStr(self._Call("d", strings, prefix=" "), " d(b, j)")
101
102
class Test_does_tree_import(support.TestCase):
    def _find_bind_rec(self, name, node):
        """Search *node* and its descendants for a binding of *name*.

        Used to find the starting point for these tests.  fixer_util.find_binding
        is tried at *node* first, then recursively on each child in order.
        Returns the first truthy result, or None when nothing binds *name*.
        """
        c = fixer_util.find_binding(name, node)
        if c:
            return c
        for child in node.children:
            c = self._find_bind_rec(name, child)
            if c:
                return c
        return None

    def does_tree_import(self, package, name, string):
        """Run fixer_util.does_tree_import starting from the ``start`` binding.

        Fails the test with a clear message if *string* has no binding for
        ``start``, instead of handing None to fixer_util.
        """
        node = parse(string)
        # Find the binding of start -- that's what we'll go from
        start = self._find_bind_rec('start', node)
        if start is None:
            self.fail("no binding for 'start' found in %r" % (string,))
        return fixer_util.does_tree_import(package, name, start)

    def try_with(self, string):
        """Check does_tree_import with each import placed before and after *string*."""
        failing_tests = (("a", "a", "from a import b"),
                         ("a.d", "a", "from a.d import b"),
                         ("d.a", "a", "from d.a import b"),
                         (None, "a", "import b"),
                         (None, "a", "import b, c, d"))
        for package, name, import_ in failing_tests:
            n = self.does_tree_import(package, name, import_ + "\n" + string)
            self.assertFalse(n)
            n = self.does_tree_import(package, name, string + "\n" + import_)
            self.assertFalse(n)

        passing_tests = (("a", "a", "from a import a"),
                         ("x", "a", "from x import a"),
                         ("x", "a", "from x import b, c, a, d"),
                         ("x.b", "a", "from x.b import a"),
                         ("x.b", "a", "from x.b import b, c, a, d"),
                         (None, "a", "import a"),
                         (None, "a", "import b, c, a, d"))
        for package, name, import_ in passing_tests:
            n = self.does_tree_import(package, name, import_ + "\n" + string)
            self.assertTrue(n)
            n = self.does_tree_import(package, name, string + "\n" + import_)
            self.assertTrue(n)

    def test_in_function(self):
        self.try_with("def foo():\n\tbar.baz()\n\tstart=3")
146
class Test_find_binding(support.TestCase):
    """Exercise fixer_util.find_binding over many kinds of binding statements."""

    def find_binding(self, name, string, package=None):
        """Parse *string* and look up the binding of *name* in the result."""
        tree = parse(string)
        return fixer_util.find_binding(name, tree, package)

    def _check_all(self, expected, name, sources, package=None):
        """Assert, in order, that each source binds *name* iff *expected*.

        Every source in *sources* is looked up with the same *package*.
        """
        check = self.assertTrue if expected else self.assertFalse
        for source in sources:
            check(self.find_binding(name, source, package))

    def test_simple_assignment(self):
        self._check_all(True, "a", [
            "a = b",
            "a = [b, c, d]",
            "a = foo()",
            "a = foo().foo.foo[6][foo]",
        ])
        self._check_all(False, "a", [
            "foo = a",
            "foo = (a, b, c)",
        ])

    def test_tuple_assignment(self):
        self._check_all(True, "a", [
            "(a,) = b",
            "(a, b, c) = [b, c, d]",
            "(c, (d, a), b) = foo()",
            "(a, b) = foo().foo[6][foo]",
        ])
        self._check_all(False, "a", [
            "(foo, b) = (b, a)",
            "(foo, (b, c)) = (a, b, c)",
        ])

    def test_list_assignment(self):
        self._check_all(True, "a", [
            "[a] = b",
            "[a, b, c] = [b, c, d]",
            "[c, [d, a], b] = foo()",
            "[a, b] = foo().foo[a][foo]",
        ])
        self._check_all(False, "a", [
            "[foo, b] = (b, a)",
            "[foo, [b, c]] = (a, b, c)",
        ])

    def test_invalid_assignments(self):
        self._check_all(False, "a", [
            "foo.a = 5",
            "foo[a] = 5",
            "foo(a) = 5",
            "foo(a, b) = 5",
        ])

    def test_simple_import(self):
        self._check_all(True, "a", ["import a", "import b, c, a, d"])
        self._check_all(False, "a", ["import b", "import b, c, d"])

    def test_from_import(self):
        self._check_all(True, "a", [
            "from x import a",
            "from a import a",
            "from x import b, c, a, d",
            "from x.b import a",
            "from x.b import b, c, a, d",
        ])
        self._check_all(False, "a", [
            "from a import b",
            "from a.d import b",
            "from d.a import b",
        ])

    def test_import_as(self):
        self._check_all(True, "a", [
            "import b as a",
            "import b as a, c, a as f, d",
        ])
        self._check_all(False, "a", [
            "import a as f",
            "import b, c as f, d as e",
        ])

    def test_from_import_as(self):
        self._check_all(True, "a", [
            "from x import b as a",
            "from x import g as a, d as b",
            "from x.b import t as a",
            "from x.b import g as a, d",
        ])
        self._check_all(False, "a", [
            "from a import b as t",
            "from a.d import b as t",
            "from d.a import b as t",
        ])

    def test_simple_import_with_package(self):
        self._check_all(True, "b", ["import b", "import b, c, d"])
        for source, package in (("import b", "b"), ("import b, c, d", "c")):
            self.assertFalse(self.find_binding("b", source, package))

    def test_from_import_with_package(self):
        bound = (
            ("from x import a", "x"),
            ("from a import a", "a"),
            ("from x import *", "x"),
            ("from x import b, c, a, d", "x"),
            ("from x.b import a", "x.b"),
            ("from x.b import *", "x.b"),
            ("from x.b import b, c, a, d", "x.b"),
        )
        for source, package in bound:
            self.assertTrue(self.find_binding("a", source, package))
        unbound = (
            ("from a import b", "a"),
            ("from a.d import b", "a.d"),
            ("from d.a import b", "a.d"),
            ("from x.y import *", "a.b"),
        )
        for source, package in unbound:
            self.assertFalse(self.find_binding("a", source, package))

    def test_import_as_with_package(self):
        unbound = (
            ("import b.c as a", "b.c"),
            ("import a as f", "f"),
            ("import a as f", "a"),
        )
        for source, package in unbound:
            self.assertFalse(self.find_binding("a", source, package))

    def test_from_import_as_with_package(self):
        # Because it would take a lot of special-case code in the fixers
        # to deal with from foo import bar as baz, we'll simply always
        # fail if there is an "from ... import ... as ..."
        unbound = (
            ("from x import b as a", "x"),
            ("from x import g as a, d as b", "x"),
            ("from x.b import t as a", "x.b"),
            ("from x.b import g as a, d", "x.b"),
            ("from a import b as t", "a"),
            ("from a import b as t", "b"),
            ("from a import b as t", "t"),
        )
        for source, package in unbound:
            self.assertFalse(self.find_binding("a", source, package))

    def test_function_def(self):
        self._check_all(True, "a", [
            "def a(): pass",
            "def a(b, c, d): pass",
            "def a(): b = 7",
        ])
        # Parameters, locals and nested functions do not bind at top level.
        self._check_all(False, "a", [
            "def d(b, (c, a), e): pass",
            "def d(a=7): pass",
            "def d(a): pass",
            "def d(): a = 7",
            """
            def d():
                def a():
                    pass""",
        ])

    def test_class_def(self):
        self._check_all(True, "a", [
            "class a: pass",
            "class a(): pass",
            "class a(b): pass",
            "class a(b, c=8): pass",
        ])
        # Base classes, class-body names and nested classes do not count.
        self._check_all(False, "a", [
            "class d: pass",
            "class d(a): pass",
            "class d(b, a=7): pass",
            "class d(b, *a): pass",
            "class d(b, **a): pass",
            "class d: a = 7",
            """
            class d():
                class a():
                    pass""",
        ])

    def test_for(self):
        self._check_all(True, "a", [
            "for a in r: pass",
            "for a, b in r: pass",
            "for (a, b) in r: pass",
            "for c, (a,) in r: pass",
            "for c, (a, b) in r: pass",
            "for c in r: a = c",
        ])
        self._check_all(False, "a", ["for c in a: pass"])

    def test_for_nested(self):
        self._check_all(True, "a", [
            """
            for b in r:
                for a in b:
                    pass""",
            """
            for b in r:
                for a, c in b:
                    pass""",
            """
            for b in r:
                for (a, c) in b:
                    pass""",
            """
            for b in r:
                for (a,) in b:
                    pass""",
            """
            for b in r:
                for c, (a, d) in b:
                    pass""",
            """
            for b in r:
                for c in b:
                    a = 7""",
        ])
        self._check_all(False, "a", [
            """
            for b in r:
                for c in b:
                    d = a""",
            """
            for b in r:
                for c in a:
                    d = 7""",
        ])

    def test_if(self):
        self._check_all(True, "a", ["if b in r: a = c"])
        self._check_all(False, "a", ["if a in r: d = e"])

    def test_if_nested(self):
        self._check_all(True, "a", [
            """
            if b in r:
                if c in d:
                    a = c""",
        ])
        self._check_all(False, "a", [
            """
            if b in r:
                if c in d:
                    c = a""",
        ])

    def test_while(self):
        self._check_all(True, "a", ["while b in r: a = c"])
        self._check_all(False, "a", ["while a in r: d = e"])

    def test_while_nested(self):
        self._check_all(True, "a", [
            """
            while b in r:
                while c in d:
                    a = c""",
        ])
        self._check_all(False, "a", [
            """
            while b in r:
                while c in d:
                    c = a""",
        ])

    def test_try_except(self):
        self._check_all(True, "a", [
            """
            try:
                a = 6
            except:
                b = 8""",
            """
            try:
                b = 8
            except:
                a = 6""",
            """
            try:
                b = 8
            except KeyError:
                pass
            except:
                a = 6""",
        ])
        self._check_all(False, "a", [
            """
            try:
                b = 8
            except:
                b = 6""",
        ])

    def test_try_except_nested(self):
        self._check_all(True, "a", [
            """
            try:
                try:
                    a = 6
                except:
                    pass
            except:
                b = 8""",
            """
            try:
                b = 8
            except:
                try:
                    a = 6
                except:
                    pass""",
            """
            try:
                b = 8
            except:
                try:
                    pass
                except:
                    a = 6""",
            """
            try:
                try:
                    b = 8
                except KeyError:
                    pass
                except:
                    a = 6
            except:
                pass""",
            """
            try:
                pass
            except:
                try:
                    b = 8
                except KeyError:
                    pass
                except:
                    a = 6""",
        ])
        self._check_all(False, "a", [
            """
            try:
                b = 8
            except:
                b = 6""",
            """
            try:
                try:
                    b = 8
                except:
                    c = d
            except:
                try:
                    b = 6
                except:
                    t = 8
                except:
                    o = y""",
        ])

    def test_try_except_finally(self):
        self._check_all(True, "a", [
            """
            try:
                c = 6
            except:
                b = 8
            finally:
                a = 9""",
            """
            try:
                b = 8
            finally:
                a = 6""",
        ])
        self._check_all(False, "a", [
            """
            try:
                b = 8
            finally:
                b = 6""",
            """
            try:
                b = 8
            except:
                b = 9
            finally:
                b = 6""",
        ])

    def test_try_except_finally_nested(self):
        self._check_all(True, "a", [
            """
            try:
                c = 6
            except:
                b = 8
            finally:
                try:
                    a = 9
                except:
                    b = 9
                finally:
                    c = 9""",
            """
            try:
                b = 8
            finally:
                try:
                    pass
                finally:
                    a = 6""",
        ])
        self._check_all(False, "a", [
            """
            try:
                b = 8
            finally:
                try:
                    b = 6
                finally:
                    b = 7""",
        ])
548
class Test_touch_import(support.TestCase):

    def _touched(self, source, package, name):
        """Parse *source*, apply touch_import(package, name), return the text."""
        tree = parse(source)
        fixer_util.touch_import(package, name, tree)
        return str(tree)

    def test_after_docstring(self):
        result = self._touched('"""foo"""\nbar()', None, "foo")
        self.assertEqual(result, '"""foo"""\nimport foo\nbar()\n\n')

    def test_after_imports(self):
        result = self._touched('"""foo"""\nimport bar\nbar()', None, "foo")
        self.assertEqual(result, '"""foo"""\nimport bar\nimport foo\nbar()\n\n')

    def test_beginning(self):
        result = self._touched('bar()', None, "foo")
        self.assertEqual(result, 'import foo\nbar()\n\n')

    def test_from_import(self):
        result = self._touched('bar()', "html", "escape")
        self.assertEqual(result, 'from html import escape\nbar()\n\n')

    def test_name_import(self):
        result = self._touched('bar()', None, "cgi")
        self.assertEqual(result, 'import cgi\nbar()\n\n')
575
class Test_find_indentation(support.TestCase):

    def test_nothing(self):
        # Source without any indented body yields an empty indentation.
        for source in ("node()", ""):
            tree = parse(source)
            self.assertEqual(fixer_util.find_indentation(tree), "")

    def test_simple(self):
        find = fixer_util.find_indentation
        tree = parse("def f():\n    x()")
        # The module-level tree itself is not indented.
        self.assertEqual(find(tree), "")
        # presumably children[0] is the funcdef and its child 4 is the
        # suite holding the body of f -- the nodes below are inside it.
        body = tree.children[0].children[4]
        self.assertEqual(find(body.children[2]), "    ")
        tree = parse("def f():\n    x()\n    y()")
        body = tree.children[0].children[4]
        self.assertEqual(find(body.children[4]), "    ")
592