1# mako/pyparser.py
2# Copyright 2006-2019 the Mako authors and contributors <see AUTHORS file>
3#
4# This module is part of Mako and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
7"""Handles parsing of Python code.
8
9Parsing to AST is done via _ast on Python > 2.5, otherwise the compiler
10module is used.
11"""
12
13import operator
14
15import _ast
16
17from mako import _ast_util
18from mako import compat
19from mako import exceptions
20from mako import util
21from mako.compat import arg_stringname
22
# words that cannot be assigned to (notably
# smaller than the total keys in __builtins__)
reserved = set(["True", "False", "None"])
if compat.py3k:
    # "print" is additionally treated as reserved when running on Python 3
    reserved.add("print")

# getter for the attribute that carries the identifier on a function
# argument node: "arg" on Python 3, "id" on Python 2
arg_id = operator.attrgetter("arg" if compat.py3k else "id")

util.restore__ast(_ast)
39
40
def parse(code, mode="exec", **exception_kwargs):
    """Parse Python source text into an AST.

    :param code: the source text to parse.
    :param mode: mode string handed to ``_ast_util.parse``; defaults
     to ``"exec"``.
    :param exception_kwargs: extra keyword arguments forwarded to
     ``exceptions.SyntaxException`` when parsing fails.
    :return: the result of ``_ast_util.parse`` for ``code``.
    :raises exceptions.SyntaxException: when ``_ast_util.parse`` raises
     any ``Exception``.  The message contains the original exception's
     class name, its string form, and the repr of the first 50
     characters of ``code``.
    """

    try:
        tree = _ast_util.parse(code, "<unknown>", mode)
    except Exception:
        # NOTE(review): compat.exception_as() presumably returns the
        # exception currently being handled -- confirm in mako.compat
        err = compat.exception_as()
        message = "(%s) %s (%r)" % (
            err.__class__.__name__,
            err,
            code[0:50],
        )
        raise exceptions.SyntaxException(message, **exception_kwargs)
    return tree
56
57
class FindIdentifiers(_ast_util.NodeVisitor):
    """AST visitor that sorts names into declared and undeclared sets.

    Names bound outside of any function are added to
    ``listener.declared_identifiers``.  Names that are read while being
    neither reserved, declared, nor local to an enclosing function are
    added to ``listener.undeclared_identifiers``.  Names bound inside a
    function or lambda (including every name bound by its signature) are
    tracked in ``local_ident_stack`` only and are never reported as
    declared.
    """

    def __init__(self, listener, **exception_kwargs):
        """Set up visitor state.

        :param listener: object exposing ``declared_identifiers`` and
         ``undeclared_identifiers`` collections that support ``add()``
         and membership tests.
        :param exception_kwargs: keyword arguments forwarded to any
         ``exceptions.CompileException`` raised while visiting.
        """
        self.in_function = False
        # set while assignment targets are being visited; not read
        # anywhere inside this class
        self.in_assign_targets = False
        self.local_ident_stack = set()
        self.listener = listener
        self.exception_kwargs = exception_kwargs

    def _add_declared(self, name):
        """Record ``name`` as bound, either module-level or function-local."""
        if not self.in_function:
            self.listener.declared_identifiers.add(name)
        else:
            self.local_ident_stack.add(name)

    def visit_ClassDef(self, node):
        """Register the class name; this method does not descend further."""
        self._add_declared(node.name)

    def visit_Assign(self, node):
        """Visit the assigned value before the assignment targets."""

        # flip around the visiting of Assign so the expression gets
        # evaluated first, in the case of a clause like "x=x+5" (x
        # is undeclared)

        self.visit(node.value)
        in_a = self.in_assign_targets
        self.in_assign_targets = True
        for n in node.targets:
            self.visit(n)
        self.in_assign_targets = in_a

    if compat.py3k:

        # ExceptHandler is in Python 2, but this block only works in
        # Python 3 (and is required there)

        def visit_ExceptHandler(self, node):
            """Register the ``except ... as name`` target, then visit the
            exception type expression and the handler body."""
            if node.name is not None:
                self._add_declared(node.name)
            if node.type is not None:
                self.visit(node.type)
            for statement in node.body:
                self.visit(statement)

    def visit_Lambda(self, node, *args):
        """Visit a lambda; its body is a single expression node."""
        self._visit_function(node, True)

    def visit_FunctionDef(self, node):
        """Register the function name, then visit the function body."""
        self._add_declared(node.name)
        self._visit_function(node, False)

    def _expand_tuples(self, args):
        """Yield each argument node, flattening one level of tuple
        arguments (``def f((a, b))``) into their elements."""
        for arg in args:
            if isinstance(arg, _ast.Tuple):
                for n in arg.elts:
                    yield n
            else:
                yield arg

    def _argument_names(self, arguments):
        """Return a list of every name bound by a function signature.

        Covers plain positional arguments, positional-only and
        keyword-only arguments where the ``arguments`` node has them,
        and the ``*args`` / ``**kwargs`` names.
        """
        names = [arg_id(arg) for arg in self._expand_tuples(arguments.args)]

        # these fields are missing from the arguments node on
        # interpreters that lack the corresponding syntax
        for field in ("posonlyargs", "kwonlyargs"):
            names.extend(
                arg_id(arg) for arg in getattr(arguments, field, ())
            )

        # vararg / kwarg may be a plain string or an argument node
        # depending on the Python version; arg_stringname() is the same
        # helper ParseFunc uses to obtain the name
        for star_arg in (arguments.vararg, arguments.kwarg):
            if star_arg:
                names.append(arg_stringname(star_arg))
        return names

    def _visit_function(self, node, islambda):
        """Visit the body of a function or lambda with its own local scope.

        :param node: the ``FunctionDef`` or ``Lambda`` node.
        :param islambda: True when ``node.body`` is a single expression
         rather than a list of statements.
        """

        # push function state onto stack.  dont log any more
        # identifiers as "declared" until outside of the function,
        # but keep logging identifiers as "undeclared". track
        # argument names in each function header so they arent
        # counted as "undeclared"

        inf = self.in_function
        self.in_function = True

        local_ident_stack = self.local_ident_stack
        # union() builds a new set, so names bound inside this function
        # are discarded when the previous set is restored below
        self.local_ident_stack = local_ident_stack.union(
            self._argument_names(node.args)
        )
        if islambda:
            self.visit(node.body)
        else:
            for n in node.body:
                self.visit(n)
        self.in_function = inf
        self.local_ident_stack = local_ident_stack

    def visit_For(self, node):
        """Visit the iterable before the loop target, then both bodies."""

        # flip around visit

        self.visit(node.iter)
        self.visit(node.target)
        for statement in node.body:
            self.visit(statement)
        for statement in node.orelse:
            self.visit(statement)

    def visit_Name(self, node):
        """Record a stored name as declared, or a read name as undeclared
        when it is not reserved, declared, or function-local."""
        if isinstance(node.ctx, _ast.Store):
            # this is equivalent to visit_AssName in
            # compiler
            self._add_declared(node.id)
        elif (
            node.id not in reserved
            and node.id not in self.listener.declared_identifiers
            and node.id not in self.local_ident_stack
        ):
            self.listener.undeclared_identifiers.add(node.id)

    def visit_Import(self, node):
        """Register the name bound by each ``import`` clause."""
        for name in node.names:
            if name.asname is not None:
                self._add_declared(name.asname)
            else:
                # "import a.b.c" binds only the top-level package "a"
                self._add_declared(name.name.split(".")[0])

    def visit_ImportFrom(self, node):
        """Register the names bound by ``from ... import ...``.

        :raises exceptions.CompileException: for ``from x import *``,
         because the imported names cannot be determined here.
        """
        for name in node.names:
            if name.asname is not None:
                self._add_declared(name.asname)
            else:
                if name.name == "*":
                    raise exceptions.CompileException(
                        "'import *' is not supported, since all identifier "
                        "names must be explicitly declared.  Please use the "
                        "form 'from <modulename> import <name1>, <name2>, "
                        "...' instead.",
                        **self.exception_kwargs
                    )
                self._add_declared(name.name)
183
184
class FindTuple(_ast_util.NodeVisitor):
    """Visitor that breaks a tuple expression into its elements.

    For every element of a visited tuple, a code object is built with
    ``code_factory`` and appended to ``listener.codeargs``, the element's
    generated source text is appended to ``listener.args``, and the
    element's declared / undeclared identifiers are merged into the
    listener's own sets.
    """

    def __init__(self, listener, code_factory, **exception_kwargs):
        """
        :param listener: object receiving ``codeargs``, ``args``,
         ``declared_identifiers`` and ``undeclared_identifiers`` updates.
        :param code_factory: callable invoked as
         ``code_factory(element_node, **exception_kwargs)``.
        :param exception_kwargs: keyword arguments passed through to
         ``code_factory``.
        """
        self.listener = listener
        self.exception_kwargs = exception_kwargs
        self.code_factory = code_factory

    def visit_Tuple(self, node):
        """Process each element of the tuple ``node`` in order."""
        for element in node.elts:
            parsed = self.code_factory(element, **self.exception_kwargs)
            self.listener.codeargs.append(parsed)

            element_source = ExpressionGenerator(element).value()
            self.listener.args.append(element_source)

            # union() returns a new collection, which is rebound onto
            # the listener rather than mutating the existing one
            declared = self.listener.declared_identifiers
            self.listener.declared_identifiers = declared.union(
                parsed.declared_identifiers
            )
            undeclared = self.listener.undeclared_identifiers
            self.listener.undeclared_identifiers = undeclared.union(
                parsed.undeclared_identifiers
            )
204
205
class ParseFunc(_ast_util.NodeVisitor):
    """Visitor that copies a function definition's signature onto a listener.

    After visiting a ``FunctionDef`` the listener carries ``funcname``,
    ``argnames``, ``defaults``, ``kwargnames``, ``kwdefaults``,
    ``varargs`` and ``kwargs`` attributes.
    """

    def __init__(self, listener, **exception_kwargs):
        """
        :param listener: object that receives the signature attributes.
        :param exception_kwargs: stored on the visitor; not used by
         ``visit_FunctionDef`` itself.
        """
        self.listener = listener
        self.exception_kwargs = exception_kwargs

    def visit_FunctionDef(self, node):
        """Record the name and argument layout of the function ``node``."""
        signature = node.args
        self.listener.funcname = node.name

        # positional names, with the *args name (if any) appended last
        positional = [arg_id(arg) for arg in signature.args]
        if signature.vararg:
            positional.append(arg_stringname(signature.vararg))

        if compat.py2k:
            # kw-only args don't exist in Python 2
            keyword_only = []
            keyword_defaults = []
        else:
            keyword_only = [arg_id(arg) for arg in signature.kwonlyargs]
            keyword_defaults = signature.kw_defaults
        # keyword-only names, with the **kwargs name (if any) appended last
        if signature.kwarg:
            keyword_only.append(arg_stringname(signature.kwarg))

        self.listener.argnames = positional
        self.listener.defaults = signature.defaults  # ast
        self.listener.kwargnames = keyword_only
        self.listener.kwdefaults = keyword_defaults
        # the raw vararg / kwarg values from the AST are stored as-is
        self.listener.varargs = signature.vararg
        self.listener.kwargs = signature.kwarg
234
235
class ExpressionGenerator(object):
    """Turns an AST node back into source text.

    The node is fed to an ``_ast_util.SourceGenerator`` configured with a
    four-space indent string; ``value()`` returns the generated pieces
    joined into one string.
    """

    def __init__(self, astnode):
        """Create the source generator and run it over ``astnode``."""
        indent_with = " " * 4
        self.generator = _ast_util.SourceGenerator(indent_with)
        self.generator.visit(astnode)

    def value(self):
        """Return the generator's ``result`` pieces concatenated."""
        pieces = self.generator.result
        return "".join(pieces)
243