1from ._compat import iteritems
2from .visitor import NodeVisitor
3
# Load instructions recorded in ``Symbols.loads``.  Each entry maps a
# generated identifier to an ``(instruction, argument)`` pair whose
# instruction is one of the strings below.
VAR_LOAD_PARAMETER = "param"  # set by declare_parameter; argument is None
VAR_LOAD_RESOLVE = "resolve"  # argument is the template variable name
VAR_LOAD_ALIAS = "alias"  # argument is the enclosing scope's identifier
VAR_LOAD_UNDEFINED = "undefined"  # name starts out undefined; argument is None
8
9
def find_symbols(nodes, parent_symbols=None):
    """Build a symbol table by running a frame visitor over *nodes*.

    :param nodes: iterable of template nodes belonging to one frame.
    :param parent_symbols: optional enclosing ``Symbols`` table used as the
        parent of the new one.
    :returns: the populated ``Symbols`` instance.
    """
    symbols = Symbols(parent=parent_symbols)
    visit = FrameSymbolVisitor(symbols).visit
    for child in nodes:
        visit(child)
    return symbols
16
17
def symbols_for_node(node, parent_symbols=None):
    """Return a new ``Symbols`` table populated from *node*.

    The analysis is delegated to ``Symbols.analyze_node``; *parent_symbols*
    (if given) becomes the parent of the returned table.
    """
    table = Symbols(parent=parent_symbols)
    table.analyze_node(node)
    return table
22
23
class Symbols(object):
    """Symbol table for one scope level of a compiled template.

    Attributes:

    * ``level``  -- nesting depth; 0 for a root table, parent level + 1
      otherwise (unless given explicitly).
    * ``parent`` -- enclosing ``Symbols`` table or ``None``.
    * ``refs``   -- template variable name -> generated identifier of the
      form ``"l_<level>_<name>"``.
    * ``loads``  -- generated identifier -> ``(instruction, argument)`` pair,
      where the instruction is one of the ``VAR_LOAD_*`` constants.
    * ``stores`` -- set of names assigned (or declared as parameters) here.

    Reference and load lookups that miss in this table fall back to the
    parent chain.
    """

    def __init__(self, parent=None, level=None):
        if level is None:
            # Nested tables sit one level below their parent by default.
            level = 0 if parent is None else parent.level + 1
        self.level = level
        self.parent = parent
        self.refs = {}
        self.loads = {}
        self.stores = set()

    def analyze_node(self, node, **kwargs):
        """Populate this table from *node* using a ``RootVisitor``.

        Keyword arguments are forwarded to the visitor (e.g. ``for_branch``
        when *node* is a ``For`` node).
        """
        visitor = RootVisitor(self)
        visitor.visit(node, **kwargs)

    def _define_ref(self, name, load=None):
        """Bind *name* to a level-qualified identifier and return it.

        When *load* is not ``None`` it is recorded as the load instruction
        for the new identifier.
        """
        ident = "l_%d_%s" % (self.level, name)
        self.refs[name] = ident
        if load is not None:
            self.loads[ident] = load
        return ident

    def find_load(self, target):
        """Return the load instruction for identifier *target*.

        Searches this table, then the parent chain; returns ``None`` when no
        table records one.
        """
        if target in self.loads:
            return self.loads[target]
        if self.parent is not None:
            return self.parent.find_load(target)
        return None

    def find_ref(self, name):
        """Return the identifier bound to *name* in this table or a parent.

        Returns ``None`` when the name is unknown to the whole chain.
        """
        if name in self.refs:
            return self.refs[name]
        if self.parent is not None:
            return self.parent.find_ref(name)
        return None

    def ref(self, name):
        """Like :meth:`find_ref`, but raise ``AssertionError`` if unknown."""
        rv = self.find_ref(name)
        if rv is None:
            raise AssertionError(
                "Tried to resolve a name to a reference that "
                "was unknown to the frame (%r)" % name
            )
        return rv

    def copy(self):
        """Return a copy with independent ``refs``, ``loads`` and ``stores``.

        ``parent`` and ``level`` are shared with the original.
        """
        rv = object.__new__(self.__class__)
        rv.__dict__.update(self.__dict__)
        rv.refs = self.refs.copy()
        rv.loads = self.loads.copy()
        rv.stores = self.stores.copy()
        return rv

    def store(self, name):
        """Record an assignment to *name* in this table.

        If the name has no identifier here yet, one is created: aliased to
        the enclosing scope's identifier when a parent knows the name,
        otherwise marked as undefined.
        """
        self.stores.add(name)

        # A name already referenced in this table keeps its identifier.
        if name in self.refs:
            return

        # If an enclosing scope has a reference for the name we may have to
        # alias to the variable there.
        if self.parent is not None:
            outer_ref = self.parent.find_ref(name)
            if outer_ref is not None:
                self._define_ref(name, load=(VAR_LOAD_ALIAS, outer_ref))
                return

        # Otherwise the name simply starts out undefined.
        self._define_ref(name, load=(VAR_LOAD_UNDEFINED, None))

    def declare_parameter(self, name):
        """Declare *name* as a parameter here and return its identifier."""
        self.stores.add(name)
        return self._define_ref(name, load=(VAR_LOAD_PARAMETER, None))

    def load(self, name):
        """Record a read of *name*.

        A name unknown to the whole chain gets an identifier here with a
        ``VAR_LOAD_RESOLVE`` instruction carrying the name.
        """
        if self.find_ref(name) is None:
            self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))

    def branch_update(self, branch_symbols):
        """Merge the tables of alternative branches into this one.

        :param branch_symbols: list of tables, one per branch, each created
            with :meth:`copy` from this table.

        All refs, loads and stores of the branches are merged in.  A name
        newly stored in only some of the branches then has its load
        instruction replaced: aliased to the enclosing scope's identifier
        when a parent knows the name, otherwise resolved by name.
        """
        stores = {}
        for branch in branch_symbols:
            for target in branch.stores:
                if target in self.stores:
                    continue
                stores[target] = stores.get(target, 0) + 1

        for sym in branch_symbols:
            self.refs.update(sym.refs)
            self.loads.update(sym.loads)
            self.stores.update(sym.stores)

        for name, branch_count in iteritems(stores):
            if branch_count == len(branch_symbols):
                # Stored in every branch: keep the merged instruction.
                continue
            target = self.find_ref(name)
            assert target is not None, "should not happen"

            if self.parent is not None:
                outer_target = self.parent.find_ref(name)
                if outer_target is not None:
                    self.loads[target] = (VAR_LOAD_ALIAS, outer_target)
                    continue
            self.loads[target] = (VAR_LOAD_RESOLVE, name)

    def dump_stores(self):
        """Map every name stored in this table or a parent to the identifier
        visible from this table."""
        rv = {}
        node = self
        while node is not None:
            for name in node.stores:
                if name not in rv:
                    # Resolve from ``self`` so the innermost binding wins.
                    rv[name] = self.find_ref(name)
            node = node.parent
        return rv

    def dump_param_targets(self):
        """Return the identifiers declared as parameters in this table or
        any of its parents."""
        rv = set()
        node = self
        while node is not None:
            # Read each table's own ``loads`` (not ``self.loads``) so the
            # parameters of enclosing scopes are collected as well.
            for target, (instr, _) in iteritems(node.loads):
                if instr == VAR_LOAD_PARAMETER:
                    rv.add(target)
            node = node.parent
        return rv
149
150
class RootVisitor(NodeVisitor):
    """Entry visitor used by ``Symbols.analyze_node``.

    It is invoked on the node handed to ``analyze_node`` and forwards that
    node's relevant children to a ``FrameSymbolVisitor`` bound to the given
    ``Symbols`` table.  Node types without a ``visit_*`` method here raise
    ``NotImplementedError``.
    """

    def __init__(self, symbols):
        self.sym_visitor = FrameSymbolVisitor(symbols)

    def _visit_each(self, nodes):
        # Feed each node, in order, into the frame symbol visitor.
        visit = self.sym_visitor.visit
        for child in nodes:
            visit(child)

    def _simple_visit(self, node, **kwargs):
        self._visit_each(node.iter_child_nodes())

    # These node types contribute all of their direct children.
    visit_Template = _simple_visit
    visit_Block = _simple_visit
    visit_Macro = _simple_visit
    visit_FilterBlock = _simple_visit
    visit_Scope = _simple_visit
    visit_If = _simple_visit
    visit_ScopedEvalContextModifier = _simple_visit

    def visit_AssignBlock(self, node, **kwargs):
        self._visit_each(node.body)

    def visit_CallBlock(self, node, **kwargs):
        # The ``call`` expression is not part of the call block's own frame.
        self._visit_each(node.iter_child_nodes(exclude=("call",)))

    def visit_OverlayScope(self, node, **kwargs):
        self._visit_each(node.body)

    def visit_For(self, node, for_branch="body", **kwargs):
        """Analyze one part of a ``For`` node selected by *for_branch*.

        ``"test"`` covers the loop target (as parameters) plus the optional
        filter test; ``"body"`` covers the target (as parameters) plus the
        body; ``"else"`` covers the else block.  Anything else raises
        ``RuntimeError``.
        """
        if for_branch == "test":
            self.sym_visitor.visit(node.target, store_as_param=True)
            if node.test is not None:
                self.sym_visitor.visit(node.test)
            return

        if for_branch == "body":
            self.sym_visitor.visit(node.target, store_as_param=True)
            branch = node.body
        elif for_branch == "else":
            branch = node.else_
        else:
            raise RuntimeError("Unknown for branch")
        self._visit_each(branch or ())

    def visit_With(self, node, **kwargs):
        self._visit_each(node.targets)
        self._visit_each(node.body)

    def generic_visit(self, node, *args, **kwargs):
        raise NotImplementedError(
            "Cannot find symbols for %r" % node.__class__.__name__
        )
205
206
class FrameSymbolVisitor(NodeVisitor):
    """Record the loads, stores and parameters of one frame into a
    ``Symbols`` table.

    Nodes that open their own scope (loops, blocks, scopes, call/filter
    blocks, ...) are only visited for the parts evaluated in the current
    frame; their bodies are left alone.
    """

    def __init__(self, symbols):
        self.symbols = symbols

    def visit_Name(self, node, store_as_param=False, **kwargs):
        """Every name reference, including all assignment targets, lands here."""
        table = self.symbols
        if store_as_param or node.ctx == "param":
            table.declare_parameter(node.name)
        elif node.ctx == "store":
            table.store(node.name)
        elif node.ctx == "load":
            table.load(node.name)

    def visit_NSRef(self, node, **kwargs):
        self.symbols.load(node.name)

    def visit_If(self, node, **kwargs):
        """Visit the test, then each branch against its own copy of the table,
        and merge the copies back with ``branch_update``."""
        self.visit(node.test, **kwargs)
        outer = self.symbols

        def visit_branch(branch_nodes):
            # Analyze one branch on a private copy, then restore the outer table.
            self.symbols = branch_table = outer.copy()
            for child in branch_nodes:
                self.visit(child, **kwargs)
            self.symbols = outer
            return branch_table

        branches = [
            visit_branch(node.body),
            visit_branch(node.elif_),
            visit_branch(node.else_ or ()),
        ]
        self.symbols.branch_update(branches)

    def visit_Macro(self, node, **kwargs):
        # Only the macro's name is bound here; its body is a separate frame.
        self.symbols.store(node.name)

    def visit_Import(self, node, **kwargs):
        self.generic_visit(node, **kwargs)
        self.symbols.store(node.target)

    def visit_FromImport(self, node, **kwargs):
        self.generic_visit(node, **kwargs)
        for entry in node.names:
            # Tuple entries carry the bound name in their second slot.
            self.symbols.store(entry[1] if isinstance(entry, tuple) else entry)

    def visit_Assign(self, node, **kwargs):
        """Visit the value before the target, matching evaluation order."""
        self.visit(node.node, **kwargs)
        self.visit(node.target, **kwargs)

    def visit_For(self, node, **kwargs):
        """Only the iterable belongs to the outer frame; the loop itself is
        analyzed separately."""
        self.visit(node.iter, **kwargs)

    def visit_CallBlock(self, node, **kwargs):
        self.visit(node.call, **kwargs)

    def visit_FilterBlock(self, node, **kwargs):
        self.visit(node.filter, **kwargs)

    def visit_With(self, node, **kwargs):
        for value in node.values:
            self.visit(value)

    def visit_AssignBlock(self, node, **kwargs):
        """Only the assignment target is visited, not the block body."""
        self.visit(node.target, **kwargs)

    def visit_Scope(self, node, **kwargs):
        """Scopes are not descended into."""

    def visit_Block(self, node, **kwargs):
        """Blocks are not descended into."""

    def visit_OverlayScope(self, node, **kwargs):
        """Overlay scopes are not descended into."""
291