1# mako/_ast_util.py
2# Copyright 2006-2019 the Mako authors and contributors <see AUTHORS file>
3#
4# This module is part of Mako and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
7"""
8    ast
9    ~~~
10
11    This is a stripped down version of Armin Ronacher's ast module.
12
13    :copyright: Copyright 2008 by Armin Ronacher.
14    :license: Python License.
15"""
16
17
from _ast import Add
from _ast import And
from _ast import AST
from _ast import BitAnd
from _ast import BitOr
from _ast import BitXor
from _ast import Div
from _ast import Eq
from _ast import FloorDiv
from _ast import Gt
from _ast import GtE
from _ast import If
from _ast import In
from _ast import Invert
from _ast import Is
from _ast import IsNot
from _ast import LShift
from _ast import Lt
from _ast import LtE
from _ast import Mod
from _ast import Mult
from _ast import Name
from _ast import Not
from _ast import NotEq
from _ast import NotIn
from _ast import Or
from _ast import Pow
from _ast import PyCF_ONLY_AST
from _ast import RShift
from _ast import Sub
from _ast import UAdd
from _ast import USub

from mako.compat import arg_stringname
51
# Operator lookup tables used by SourceGenerator.  Each maps an ``_ast``
# operator *class* (looked up via ``type(node.op)``) to the source token
# emitted for it.

# ``and`` / ``or`` inside a BoolOp node.
BOOLOP_SYMBOLS = {And: "and", Or: "or"}

# Arithmetic and bitwise operators inside BinOp / AugAssign nodes.
BINOP_SYMBOLS = {
    Add: "+",
    Sub: "-",
    Mult: "*",
    Div: "/",
    FloorDiv: "//",
    Mod: "%",
    # Exponentiation was previously missing, so ``a ** b`` raised KeyError
    # during source generation.
    Pow: "**",
    LShift: "<<",
    RShift: ">>",
    BitOr: "|",
    BitAnd: "&",
    BitXor: "^",
}

# Comparison operators inside Compare nodes.
CMPOP_SYMBOLS = {
    Eq: "==",
    Gt: ">",
    GtE: ">=",
    In: "in",
    Is: "is",
    IsNot: "is not",
    Lt: "<",
    LtE: "<=",
    NotEq: "!=",
    NotIn: "not in",
}

# Prefix operators inside UnaryOp nodes.
UNARYOP_SYMBOLS = {Invert: "~", Not: "not", UAdd: "+", USub: "-"}

# Union of all of the tables above.
ALL_SYMBOLS = {}
ALL_SYMBOLS.update(BOOLOP_SYMBOLS)
ALL_SYMBOLS.update(BINOP_SYMBOLS)
ALL_SYMBOLS.update(CMPOP_SYMBOLS)
ALL_SYMBOLS.update(UNARYOP_SYMBOLS)
88
89
def parse(expr, filename="<unknown>", mode="exec"):
    """Compile *expr* to an abstract syntax tree and return its root node.

    :param expr: Python source text.
    :param filename: name reported in tracebacks and ``SyntaxError``.
    :param mode: ``"exec"``, ``"eval"`` or ``"single"``, as for ``compile``.
    """
    # PyCF_ONLY_AST makes compile() stop after parsing and hand back the
    # tree instead of a code object.
    tree = compile(expr, filename, mode, PyCF_ONLY_AST)
    return tree
93
94
def iter_fields(node):
    """Yield ``(name, value)`` for every field of *node* that is set.

    Objects without a (non-empty) ``_fields`` sequence yield nothing, and
    fields missing from the instance are skipped silently.
    """
    # CPython 2.5 compat: some nodes lack or have an empty ``_fields``.
    field_names = getattr(node, "_fields", None)
    if not field_names:
        return
    absent = object()
    for name in field_names:
        value = getattr(node, name, absent)
        if value is not absent:
            yield name, value
105
106
class NodeVisitor(object):

    """
    Walk an abstract syntax tree, calling a visitor method for each node.

    The visitor for a node is the method named ``'visit_'`` followed by
    the node's class name; for example a `TryFinally` node is handled by
    `visit_TryFinally`.  Override `get_visitor` to change that lookup.
    When no visitor method exists (`get_visitor` returns `None`),
    `generic_visit` is used instead.  Whatever the chosen method returns
    is passed back by `visit`.

    This class is meant for read-only traversal.  To modify nodes while
    walking, use `NodeTransformer`.
    """

    def get_visitor(self, node):
        """
        Return the bound visitor method for *node*, or `None` when this
        visitor defines none (in which case `generic_visit` applies).
        """
        return getattr(self, "visit_%s" % node.__class__.__name__, None)

    def visit(self, node):
        """Dispatch *node* to its visitor and return that visitor's result."""
        visitor = self.get_visitor(node)
        if visitor is None:
            return self.generic_visit(node)
        return visitor(node)

    def generic_visit(self, node):
        """Fallback visitor: recurse into every child AST node of *node*."""
        for _name, child in iter_fields(node):
            # A field holds either one node or a list of them; non-AST
            # values (identifiers, numbers, ...) are skipped.
            candidates = child if isinstance(child, list) else [child]
            for candidate in candidates:
                if isinstance(candidate, AST):
                    self.visit(candidate)
150
151
class NodeTransformer(NodeVisitor):

    """
    Walk an abstract syntax tree, replacing or removing nodes as it goes.

    Each visitor's return value decides what happens to the visited node.
    Returning `None` removes the node from its parent.  Returning the node
    itself keeps it unchanged.  Returning any other node puts that node in
    its place.

    For example, this transformer rewrites every name `foo` into
    `data['foo']`::

        class RewriteName(NodeTransformer):

            def visit_Name(self, node):
                return copy_location(Subscript(
                    value=Name(id='data', ctx=Load()),
                    slice=Index(value=Str(s=node.id)),
                    ctx=node.ctx
                ), node)

    A visitor that handles a node with children must either transform
    those children itself or call `generic_visit` on the node first.

    Nodes stored in a list field, such as statement bodies, may also be
    replaced by a list of nodes.  The list is spliced into the parent
    list in place of the original node.

    Typical use::

        node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        """Transform every child of *node* in place and return *node*."""
        for name, current in iter_fields(node):
            if isinstance(current, list):
                rebuilt = []
                for item in current:
                    if not isinstance(item, AST):
                        # Non-node list entries are carried over untouched.
                        rebuilt.append(item)
                        continue
                    outcome = self.visit(item)
                    if outcome is None:
                        continue  # the visitor removed this node
                    if isinstance(outcome, AST):
                        rebuilt.append(outcome)
                    else:
                        # A sequence of nodes is spliced into the list.
                        rebuilt.extend(outcome)
                # Mutate the existing list object rather than rebinding it.
                current[:] = rebuilt
            elif isinstance(current, AST):
                outcome = self.visit(current)
                if outcome is None:
                    delattr(node, name)
                else:
                    setattr(node, name, outcome)
        return node
210
211
class SourceGenerator(NodeVisitor):

    """
    This visitor is able to transform a well formed syntax tree into python
    sourcecode.

    Output is accumulated as a list of string fragments in ``result``;
    concatenating them yields the generated source.  ``indent_with`` is the
    string written once per level of block nesting.

    Binary, boolean, comparison, unary and conditional expressions are
    always wrapped in parentheses, so the generated code never relies on
    operator precedence.  Nodes from the Python 2 grammar (``TryExcept``,
    ``TryFinally``, ``Print``, ``Repr``) and from the Python 3 grammar
    (``Try``, ``Nonlocal``, ``Starred``, ``arg``, ``Constant``) are both
    understood.
    """

    def __init__(self, indent_with):
        # Emitted source fragments, in order.
        self.result = []
        # String written once per indentation level.
        self.indent_with = indent_with
        # Current nesting depth of statement bodies.
        self.indentation = 0
        # Line breaks requested but not yet written; flushed lazily by
        # write() so that consecutive requests collapse into one.
        self.new_lines = 0

    def write(self, x):
        """Append fragment *x*, first flushing any pending line breaks.

        Pending breaks are never emitted before the very first fragment, so
        the output does not start with blank lines.  Whenever breaks are
        flushed, the current indentation is written too.
        """
        if self.new_lines:
            if self.result:
                self.result.append("\n" * self.new_lines)
            self.result.append(self.indent_with * self.indentation)
            self.new_lines = 0
        self.result.append(x)

    def newline(self, n=1):
        """Request that at least *n* line breaks precede the next write."""
        self.new_lines = max(self.new_lines, n)

    def body(self, statements):
        """Emit *statements* as a block indented one level deeper."""
        # Previously this assigned an unused ``self.new_line`` attribute;
        # request a real line break so that statements whose visitor does
        # not call newline() still start on their own line.
        self.newline()
        self.indentation += 1
        for stmt in statements:
            self.visit(stmt)
        self.indentation -= 1

    def body_or_else(self, node):
        """Emit ``node.body`` followed by an ``else:`` block if present."""
        self.body(node.body)
        if node.orelse:
            self.newline()
            self.write("else:")
            self.body(node.orelse)

    def signature(self, node):
        """Emit the parameter list described by an ``arguments`` node.

        Covers positional-only parameters (``posonlyargs``), positional
        parameters with defaults, ``*args``, keyword-only parameters
        (``kwonlyargs``) and ``**kwargs``.  Fields that an older grammar
        lacks are treated as empty.  Annotations are not emitted.
        """
        want_comma = []

        def write_comma():
            if want_comma:
                self.write(", ")
            else:
                want_comma.append(True)

        def write_param(arg, default):
            self.visit(arg)
            if default is not None:
                self.write("=")
                self.visit(default)

        posonly = list(getattr(node, "posonlyargs", None) or [])
        positional = posonly + list(node.args)
        defaults = list(node.defaults)
        kwonly = list(getattr(node, "kwonlyargs", None) or [])
        kw_defaults = list(getattr(node, "kw_defaults", None) or [])

        # ``defaults`` belong to the *last* positional parameters, which
        # can include positional-only ones.
        padding = [None] * (len(positional) - len(defaults))
        for idx, (arg, default) in enumerate(
            zip(positional, padding + defaults)
        ):
            write_comma()
            write_param(arg, default)
            if idx == len(posonly) - 1:
                # ``/`` closes the positional-only section.
                write_comma()
                self.write("/")

        if node.vararg is not None:
            write_comma()
            self.write("*" + arg_stringname(node.vararg))
        elif kwonly:
            # A bare ``*`` introduces keyword-only parameters.
            write_comma()
            self.write("*")

        # kw_defaults parallels kwonlyargs; None means "no default".
        kw_defaults += [None] * (len(kwonly) - len(kw_defaults))
        for arg, default in zip(kwonly, kw_defaults):
            write_comma()
            write_param(arg, default)

        if node.kwarg is not None:
            write_comma()
            self.write("**" + arg_stringname(node.kwarg))

    def decorators(self, node):
        """Emit each ``@decorator`` line of a function or class."""
        for decorator in node.decorator_list:
            self.newline()
            self.write("@")
            self.visit(decorator)

    def _write_keyword(self, keyword):
        """Emit a call/class keyword: ``name=value`` or ``**value``."""
        # ``keyword.arg`` is None for ``**mapping`` unpacking.
        if keyword.arg is None:
            self.write("**")
        else:
            self.write(keyword.arg + "=")
        self.visit(keyword.value)

    # Statements

    def visit_Assign(self, node):
        self.newline()
        # ``a = b = value`` carries several targets; each one is followed
        # by " = ".  Joining them with ", " would turn a chained assignment
        # into tuple unpacking.
        for target in node.targets:
            self.visit(target)
            self.write(" = ")
        self.visit(node.value)

    def visit_AugAssign(self, node):
        self.newline()
        self.visit(node.target)
        self.write(BINOP_SYMBOLS[type(node.op)] + "=")
        self.visit(node.value)

    def visit_ImportFrom(self, node):
        self.newline()
        # ``module`` is None for purely relative imports (``from . import x``).
        self.write(
            "from %s%s import "
            % ("." * (node.level or 0), node.module or "")
        )
        for idx, item in enumerate(node.names):
            if idx:
                self.write(", ")
            # ``item`` is an alias node, so it must be visited, not written.
            self.visit(item)

    def visit_Import(self, node):
        self.newline()
        self.write("import ")
        for idx, item in enumerate(node.names):
            if idx:
                self.write(", ")
            self.visit(item)

    def visit_Expr(self, node):
        self.newline()
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.newline(n=2)
        self.decorators(node)
        self.newline()
        self.write("def %s(" % node.name)
        self.signature(node.args)
        self.write("):")
        self.body(node.body)

    def visit_ClassDef(self, node):
        have_args = []

        def paren_or_comma():
            if have_args:
                self.write(", ")
            else:
                have_args.append(True)
                self.write("(")

        self.newline(n=3)
        self.decorators(node)
        self.newline()
        self.write("class %s" % node.name)
        for base in node.bases:
            paren_or_comma()
            self.visit(base)
        # XXX: the if here is used to keep this module compatible
        #      with python 2.6.
        if hasattr(node, "keywords"):
            for keyword in node.keywords:
                paren_or_comma()
                self._write_keyword(keyword)
            if getattr(node, "starargs", None):
                paren_or_comma()
                self.write("*")
                self.visit(node.starargs)
            if getattr(node, "kwargs", None):
                paren_or_comma()
                self.write("**")
                self.visit(node.kwargs)
        self.write("):" if have_args else ":")
        self.body(node.body)

    def visit_If(self, node):
        self.newline()
        self.write("if ")
        self.visit(node.test)
        self.write(":")
        self.body(node.body)
        # An ``orelse`` holding exactly one If is rendered as ``elif``.
        while True:
            else_ = node.orelse
            if len(else_) == 1 and isinstance(else_[0], If):
                node = else_[0]
                self.newline()
                self.write("elif ")
                self.visit(node.test)
                self.write(":")
                self.body(node.body)
            else:
                # Only emit ``else:`` when there is a body for it; an empty
                # ``else:`` block is a syntax error.
                if else_:
                    self.newline()
                    self.write("else:")
                    self.body(else_)
                break

    def visit_For(self, node):
        self.newline()
        self.write("for ")
        self.visit(node.target)
        self.write(" in ")
        self.visit(node.iter)
        self.write(":")
        self.body_or_else(node)

    def visit_While(self, node):
        self.newline()
        self.write("while ")
        self.visit(node.test)
        self.write(":")
        self.body_or_else(node)

    def visit_With(self, node):
        self.newline()
        self.write("with ")
        # Python 3.3+ keeps the context managers in ``items`` (withitem
        # nodes); older grammars store a single one on the With node
        # itself, using the same attribute names.
        items = getattr(node, "items", None)
        if items is None:
            items = [node]
        for idx, item in enumerate(items):
            if idx:
                self.write(", ")
            self.visit(item.context_expr)
            if item.optional_vars is not None:
                self.write(" as ")
                self.visit(item.optional_vars)
        self.write(":")
        self.body(node.body)

    def visit_Pass(self, node):
        self.newline()
        self.write("pass")

    def visit_Print(self, node):
        # XXX: python 2.6 only
        self.newline()
        self.write("print ")
        want_comma = False
        if node.dest is not None:
            self.write(" >> ")
            self.visit(node.dest)
            want_comma = True
        for value in node.values:
            if want_comma:
                self.write(", ")
            self.visit(value)
            want_comma = True
        if not node.nl:
            self.write(",")

    def visit_Delete(self, node):
        self.newline()
        self.write("del ")
        # The targets live in ``node.targets``; the node is not iterable.
        for idx, target in enumerate(node.targets):
            if idx:
                self.write(", ")
            self.visit(target)

    def visit_TryExcept(self, node):
        self.newline()
        self.write("try:")
        self.body(node.body)
        for handler in node.handlers:
            self.visit(handler)
        if node.orelse:
            self.newline()
            self.write("else:")
            self.body(node.orelse)

    def visit_TryFinally(self, node):
        self.newline()
        self.write("try:")
        self.body(node.body)
        self.newline()
        self.write("finally:")
        self.body(node.finalbody)

    def visit_Try(self, node):
        # Python 3 merges TryExcept and TryFinally into one Try node whose
        # body/handlers/orelse fields match TryExcept's.
        self.visit_TryExcept(node)
        if node.finalbody:
            self.newline()
            self.write("finally:")
            self.body(node.finalbody)

    def visit_Global(self, node):
        self.newline()
        self.write("global " + ", ".join(node.names))

    def visit_Nonlocal(self, node):
        self.newline()
        self.write("nonlocal " + ", ".join(node.names))

    def visit_Return(self, node):
        self.newline()
        self.write("return")
        # A bare ``return`` has no value node.
        if node.value is not None:
            self.write(" ")
            self.visit(node.value)

    def visit_Break(self, node):
        self.newline()
        self.write("break")

    def visit_Continue(self, node):
        self.newline()
        self.write("continue")

    def visit_Raise(self, node):
        # XXX: Python 2.6 / 3.0 compatibility
        self.newline()
        self.write("raise")
        if hasattr(node, "exc") and node.exc is not None:
            self.write(" ")
            self.visit(node.exc)
            if node.cause is not None:
                self.write(" from ")
                self.visit(node.cause)
        elif hasattr(node, "type") and node.type is not None:
            # Separate the keyword from the exception (was "raiseFoo").
            self.write(" ")
            self.visit(node.type)
            if node.inst is not None:
                self.write(", ")
                self.visit(node.inst)
            if node.tback is not None:
                self.write(", ")
                self.visit(node.tback)

    # Expressions

    def visit_Attribute(self, node):
        self.visit(node.value)
        self.write("." + node.attr)

    def visit_Call(self, node):
        want_comma = []

        def write_comma():
            if want_comma:
                self.write(", ")
            else:
                want_comma.append(True)

        self.visit(node.func)
        self.write("(")
        for arg in node.args:
            write_comma()
            self.visit(arg)
        for keyword in node.keywords:
            write_comma()
            self._write_keyword(keyword)
        # Pre-3.5 grammars store *args / **kwargs as separate fields.
        if getattr(node, "starargs", None):
            write_comma()
            self.write("*")
            self.visit(node.starargs)
        if getattr(node, "kwargs", None):
            write_comma()
            self.write("**")
            self.visit(node.kwargs)
        self.write(")")

    def visit_Name(self, node):
        self.write(node.id)

    def visit_NameConstant(self, node):
        self.write(str(node.value))

    def visit_arg(self, node):
        self.write(node.arg)

    def visit_Str(self, node):
        self.write(repr(node.s))

    def visit_Bytes(self, node):
        self.write(repr(node.s))

    def visit_Num(self, node):
        self.write(repr(node.n))

    # newly needed in Python 3.8
    def visit_Constant(self, node):
        self.write(repr(node.value))

    def visit_Tuple(self, node):
        self.write("(")
        idx = -1
        for idx, item in enumerate(node.elts):
            if idx:
                self.write(", ")
            self.visit(item)
        # A one-element tuple needs a trailing comma; idx stays -1 for ().
        self.write(")" if idx else ",)")

    def sequence_visit(left, right):
        # Build a visitor that writes comma-separated ``elts`` between the
        # given delimiters.
        def visit(self, node):
            self.write(left)
            for idx, item in enumerate(node.elts):
                if idx:
                    self.write(", ")
                self.visit(item)
            self.write(right)

        return visit

    visit_List = sequence_visit("[", "]")
    visit_Set = sequence_visit("{", "}")
    del sequence_visit

    def visit_Dict(self, node):
        self.write("{")
        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
            if idx:
                self.write(", ")
            if key is None:
                # ``{**mapping}`` unpacking is stored with a None key.
                self.write("**")
            else:
                self.visit(key)
                self.write(": ")
            self.visit(value)
        self.write("}")

    def visit_BinOp(self, node):
        self.write("(")
        self.visit(node.left)
        self.write(" %s " % BINOP_SYMBOLS[type(node.op)])
        self.visit(node.right)
        self.write(")")

    def visit_BoolOp(self, node):
        self.write("(")
        for idx, value in enumerate(node.values):
            if idx:
                self.write(" %s " % BOOLOP_SYMBOLS[type(node.op)])
            self.visit(value)
        self.write(")")

    def visit_Compare(self, node):
        self.write("(")
        self.visit(node.left)
        for op, right in zip(node.ops, node.comparators):
            self.write(" %s " % CMPOP_SYMBOLS[type(op)])
            self.visit(right)
        self.write(")")

    def visit_UnaryOp(self, node):
        self.write("(")
        op = UNARYOP_SYMBOLS[type(node.op)]
        self.write(op)
        if op == "not":
            self.write(" ")
        self.visit(node.operand)
        self.write(")")

    def visit_Subscript(self, node):
        self.visit(node.value)
        self.write("[")
        self.visit(node.slice)
        self.write("]")

    def visit_Slice(self, node):
        if node.lower is not None:
            self.visit(node.lower)
        self.write(":")
        if node.upper is not None:
            self.visit(node.upper)
        if node.step is not None:
            self.write(":")
            if not (isinstance(node.step, Name) and node.step.id == "None"):
                self.visit(node.step)

    def visit_ExtSlice(self, node):
        for idx, item in enumerate(node.dims):
            if idx:
                self.write(", ")
            self.visit(item)
        # ``x[a:b,]`` has a single dimension; keep the trailing comma so it
        # stays a tuple subscript rather than a plain slice.
        if len(node.dims) == 1:
            self.write(",")

    def visit_Yield(self, node):
        self.write("yield")
        # A bare ``yield`` has no value node.
        if node.value is not None:
            self.write(" ")
            self.visit(node.value)

    def visit_Lambda(self, node):
        self.write("lambda ")
        self.signature(node.args)
        self.write(": ")
        self.visit(node.body)

    def visit_Ellipsis(self, node):
        self.write("Ellipsis")

    def generator_visit(left, right):
        # Build a visitor for ``elt`` followed by its ``for``/``if``
        # clauses, between the given delimiters.
        def visit(self, node):
            self.write(left)
            self.visit(node.elt)
            for comprehension in node.generators:
                self.visit(comprehension)
            self.write(right)

        return visit

    visit_ListComp = generator_visit("[", "]")
    visit_GeneratorExp = generator_visit("(", ")")
    visit_SetComp = generator_visit("{", "}")
    del generator_visit

    def visit_DictComp(self, node):
        self.write("{")
        self.visit(node.key)
        self.write(": ")
        self.visit(node.value)
        for comprehension in node.generators:
            self.visit(comprehension)
        self.write("}")

    def visit_IfExp(self, node):
        # Parenthesised like the other operator expressions: unwrapped,
        # ``(a if b else c) + 1`` would come out as ``a if b else (c + 1)``.
        self.write("(")
        self.visit(node.body)
        self.write(" if ")
        self.visit(node.test)
        self.write(" else ")
        self.visit(node.orelse)
        self.write(")")

    def visit_Starred(self, node):
        self.write("*")
        self.visit(node.value)

    def visit_Repr(self, node):
        # XXX: python 2.6 only
        self.write("`")
        self.visit(node.value)
        self.write("`")

    # Helper Nodes

    def visit_alias(self, node):
        self.write(node.name)
        if node.asname is not None:
            self.write(" as " + node.asname)

    def visit_comprehension(self, node):
        self.write(" for ")
        self.visit(node.target)
        self.write(" in ")
        self.visit(node.iter)
        if node.ifs:
            for if_ in node.ifs:
                self.write(" if ")
                self.visit(if_)

    def visit_excepthandler(self, node):
        self.newline()
        self.write("except")
        if node.type is not None:
            self.write(" ")
            self.visit(node.type)
            if node.name is not None:
                self.write(" as ")
                # Python 2 stores the target as an expression node;
                # Python 3 stores a plain identifier string.
                if isinstance(node.name, AST):
                    self.visit(node.name)
                else:
                    self.write(node.name)
        self.write(":")
        self.body(node.body)

    # Handler nodes are instances of ``ExceptHandler`` (the lowercase name
    # is only the abstract base), so dispatch must find this spelling too.
    visit_ExceptHandler = visit_excepthandler
717