from typing import Dict, List
from typing_extensions import Final

from mypy.nodes import (
    Block, AssignmentStmt, NameExpr, MypyFile, FuncDef, Lvalue, ListExpr, TupleExpr,
    WhileStmt, ForStmt, BreakStmt, ContinueStmt, TryStmt, WithStmt, StarExpr, ImportFrom,
    MemberExpr, IndexExpr, Import, ClassDef, FuncItem, LambdaExpr
)
from mypy.traverser import TraverserVisitor

# Scope kinds, pushed onto VariableRenameVisitor.scope_kinds for each nested scope.
# Module top level:
FILE = 0  # type: Final
# Body of a function:
FUNCTION = 1  # type: Final
# Body of a class:
CLASS = 2  # type: Final


class VariableRenameVisitor(TraverserVisitor):
    """Rename variables to allow redefinition of variables.

    For example, consider this code:

      x = 0
      f(x)

      x = "a"
      g(x)

    It will be transformed like this:

      x' = 0
      f(x')

      x = "a"
      g(x)

    There will be two independent variables (x' and x) that will have separate
    inferred types. The publicly exposed variant will get the non-suffixed name.
    This is the last definition at module top level (and in a class body) and the
    first definition (argument) within a function.

    Renaming only happens for assignments within the same block. Renaming is
    performed before semantic analysis, immediately after parsing.

    The implementation performs a rudimentary static analysis. The analysis is
    overly conservative to keep things simple.
    """

    def __init__(self) -> None:
        # Counter for labeling new blocks
        self.block_id = 0
        # Number of surrounding try statements that disallow variable redefinition
        self.disallow_redef_depth = 0
        # Number of surrounding loop statements
        self.loop_depth = 0
        # Map block id to loop depth.
        self.block_loop_depth = {}  # type: Dict[int, int]
        # Stack of block ids being processed.
        self.blocks = []  # type: List[int]
        # List of scopes; each scope maps short (unqualified) name to block id
        # (-1 means the name can no longer be redefined).
        self.var_blocks = []  # type: List[Dict[str, int]]

        # References to variables that we may need to rename. List of
        # scopes; each scope is a mapping from name to list of collections
        # of names that refer to the same logical variable.
        self.refs = []  # type: List[Dict[str, List[List[NameExpr]]]]
        # Number of reads of the most recent definition of a variable (per scope)
        self.num_reads = []  # type: List[Dict[str, int]]
        # Kinds of nested scopes (FILE, FUNCTION or CLASS)
        self.scope_kinds = []  # type: List[int]

    def visit_mypy_file(self, file_node: MypyFile) -> None:
        """Rename variables within a file.

        This is the main entry point to this class.
        """
        self.clear()
        self.enter_scope(FILE)
        self.enter_block()

        for d in file_node.defs:
            d.accept(self)

        self.leave_block()
        self.leave_scope()

    def visit_func_def(self, fdef: FuncDef) -> None:
        self.analyze_func_item(fdef)

    def visit_lambda_expr(self, expr: LambdaExpr) -> None:
        # A lambda introduces a function scope just like 'def'. Without this,
        # references to lambda parameters would be recorded in (and renamed with)
        # same-named variables of the enclosing scope, while the parameters
        # themselves would keep their original names.
        self.analyze_func_item(expr)

    def analyze_func_item(self, func: FuncItem) -> None:
        """Process a function or lambda: its arguments and body form a new scope."""
        # Default argument values are evaluated where the function is defined,
        # so they read (and must be renamed with) variables of the enclosing scope.
        for arg in func.arguments:
            if arg.initializer is not None:
                arg.initializer.accept(self)

        # Conservatively do not allow variable defined before a function to
        # be redefined later, since function could refer to either definition.
        self.reject_redefinition_of_vars_in_scope()

        self.enter_scope(FUNCTION)
        self.enter_block()

        for arg in func.arguments:
            name = arg.variable.name
            # 'self' can't be redefined since it's special as it allows definition of
            # attributes. 'cls' can't be used to define attributes so we can ignore it.
            can_be_redefined = name != 'self'  # TODO: Proper check
            self.record_assignment(name, can_be_redefined)
            self.handle_arg(name)

        for stmt in func.body.body:
            stmt.accept(self)

        self.leave_block()
        self.leave_scope()

    def visit_class_def(self, cdef: ClassDef) -> None:
        # Like a function, a class body may refer to variables defined before it.
        self.reject_redefinition_of_vars_in_scope()
        self.enter_scope(CLASS)
        super().visit_class_def(cdef)
        self.leave_scope()

    def visit_block(self, block: Block) -> None:
        self.enter_block()
        super().visit_block(block)
        self.leave_block()

    def visit_while_stmt(self, stmt: WhileStmt) -> None:
        # Only the body is inside the loop. A 'break'/'continue' in the 'else'
        # clause belongs to the enclosing loop (if any), matching visit_for_stmt.
        stmt.expr.accept(self)
        self.enter_loop()
        stmt.body.accept(self)
        self.leave_loop()
        if stmt.else_body:
            stmt.else_body.accept(self)

    def visit_for_stmt(self, stmt: ForStmt) -> None:
        stmt.expr.accept(self)
        self.analyze_lvalue(stmt.index, True)
        # Also analyze as non-lvalue so that every for loop index variable is assumed to be read.
        stmt.index.accept(self)
        self.enter_loop()
        stmt.body.accept(self)
        self.leave_loop()
        if stmt.else_body:
            stmt.else_body.accept(self)

    def visit_break_stmt(self, stmt: BreakStmt) -> None:
        self.reject_redefinition_of_vars_in_loop()

    def visit_continue_stmt(self, stmt: ContinueStmt) -> None:
        self.reject_redefinition_of_vars_in_loop()

    def visit_try_stmt(self, stmt: TryStmt) -> None:
        # Variables defined by a try statement get special treatment in the
        # type checker which allows them to be always redefined, so no need to
        # do renaming here.
        self.enter_try()
        super().visit_try_stmt(stmt)
        self.leave_try()

    def visit_with_stmt(self, stmt: WithStmt) -> None:
        for expr in stmt.expr:
            expr.accept(self)
        for target in stmt.target:
            if target is not None:
                self.analyze_lvalue(target)
        # We allow redefinitions in the body of a with statement for
        # convenience.  This is unsafe since with statements can affect control
        # flow by catching exceptions, but this is rare except for
        # assertRaises() and other similar functions, where the exception is
        # raised by the last statement in the body, which usually isn't a
        # problem.
        stmt.body.accept(self)

    def visit_import(self, imp: Import) -> None:
        for id, as_id in imp.ids:
            self.record_assignment(as_id or id, False)

    def visit_import_from(self, imp: ImportFrom) -> None:
        for id, as_id in imp.names:
            self.record_assignment(as_id or id, False)

    def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
        s.rvalue.accept(self)
        for lvalue in s.lvalues:
            self.analyze_lvalue(lvalue)

    def analyze_lvalue(self, lvalue: Lvalue, is_nested: bool = False) -> None:
        """Process assignment; in particular, keep track of (re)defined names.

        Args:
            is_nested: True for non-outermost Lvalue in a multiple assignment such as
                "x, y = ..."
        """
        if isinstance(lvalue, NameExpr):
            name = lvalue.name
            is_new = self.record_assignment(name, True)
            if is_new:
                self.handle_def(lvalue)
            else:
                self.handle_refine(lvalue)
            if is_nested:
                # This allows these to be redefined freely even if never read. Multiple
                # assignment like "x, _ = y" defines dummy variables that are never read.
                self.handle_ref(lvalue)
        elif isinstance(lvalue, (ListExpr, TupleExpr)):
            for item in lvalue.items:
                self.analyze_lvalue(item, is_nested=True)
        elif isinstance(lvalue, MemberExpr):
            lvalue.expr.accept(self)
        elif isinstance(lvalue, IndexExpr):
            lvalue.base.accept(self)
            lvalue.index.accept(self)
        elif isinstance(lvalue, StarExpr):
            # Propagate is_nested since in a typical use case like "x, *rest = ..." 'rest' may
            # be freely reused.
            self.analyze_lvalue(lvalue.expr, is_nested=is_nested)

    def visit_name_expr(self, expr: NameExpr) -> None:
        self.handle_ref(expr)

    # Helpers for renaming references

    def handle_arg(self, name: str) -> None:
        """Store function argument.

        The argument itself is not a NameExpr, so its reference group starts empty.
        """
        self.refs[-1][name] = [[]]
        self.num_reads[-1][name] = 0

    def handle_def(self, expr: NameExpr) -> None:
        """Store new name definition, starting a new reference group for it."""
        name = expr.name
        names = self.refs[-1].setdefault(name, [])
        names.append([expr])
        self.num_reads[-1][name] = 0

    def handle_refine(self, expr: NameExpr) -> None:
        """Store assignment to an existing name (that replaces previous value, if any).

        The expression joins the reference group of the most recent definition of
        the name in the current scope; names not defined in this scope are ignored.
        """
        name = expr.name
        if name in self.refs[-1]:
            names = self.refs[-1][name]
            if not names:
                names.append([])
            names[-1].append(expr)

    def handle_ref(self, expr: NameExpr) -> None:
        """Store reference to defined name and count it as a read."""
        self.handle_refine(expr)
        num_reads = self.num_reads[-1]
        num_reads[expr.name] = num_reads.get(expr.name, 0) + 1

    def flush_refs(self) -> None:
        """Rename all references within the current scope and pop its reference map.

        This will be called at the end of a scope.
        """
        is_func = self.scope_kinds[-1] == FUNCTION
        for name, refs in self.refs[-1].items():
            if len(refs) == 1:
                # Only one definition -- no renaming needed.
                continue
            if is_func:
                # In a function, don't rename the first definition, as it
                # may be an argument that must preserve the name.
                to_rename = refs[1:]
            else:
                # At module top level, don't rename the final definition,
                # as it will be publicly visible outside the module.
                to_rename = refs[:-1]
            for i, item in enumerate(to_rename):
                self.rename_refs(item, i)
        self.refs.pop()

    def rename_refs(self, names: List[NameExpr], index: int) -> None:
        """Rename every expression in names by appending (index + 1) primes."""
        if not names:
            return
        name = names[0].name
        new_name = name + "'" * (index + 1)
        for expr in names:
            expr.name = new_name

    # Helpers for determining which assignments define new variables

    def clear(self) -> None:
        """Reset all per-file analysis state.

        block_id keeps counting so that block ids stay unique across files.
        """
        self.blocks = []
        self.var_blocks = []
        self.refs = []
        self.num_reads = []
        self.scope_kinds = []
        self.block_loop_depth = {}
        self.loop_depth = 0
        self.disallow_redef_depth = 0

    def enter_block(self) -> None:
        self.block_id += 1
        self.blocks.append(self.block_id)
        self.block_loop_depth[self.block_id] = self.loop_depth

    def leave_block(self) -> None:
        self.blocks.pop()

    def enter_try(self) -> None:
        self.disallow_redef_depth += 1

    def leave_try(self) -> None:
        self.disallow_redef_depth -= 1

    def enter_loop(self) -> None:
        self.loop_depth += 1

    def leave_loop(self) -> None:
        self.loop_depth -= 1

    def current_block(self) -> int:
        return self.blocks[-1]

    def enter_scope(self, kind: int) -> None:
        self.var_blocks.append({})
        self.refs.append({})
        self.num_reads.append({})
        self.scope_kinds.append(kind)

    def leave_scope(self) -> None:
        # flush_refs pops self.refs; the remaining per-scope stacks are popped here.
        self.flush_refs()
        self.var_blocks.pop()
        self.num_reads.pop()
        self.scope_kinds.pop()

    def is_nested(self) -> bool:
        """Return True if we are inside a function or class scope."""
        return len(self.var_blocks) > 1

    def reject_redefinition_of_vars_in_scope(self) -> None:
        """Make it impossible to redefine defined variables in the current scope.

        This is used if we encounter a function definition (or lambda) that
        can make it ambiguous which definition is live. Example:

          x = 0

          def f() -> int:
              return x

          x = ''  # Error -- cannot redefine x across function definition
        """
        var_blocks = self.var_blocks[-1]
        for key in var_blocks:
            var_blocks[key] = -1

    def reject_redefinition_of_vars_in_loop(self) -> None:
        """Reject redefinition of variables in the innermost loop.

        If there is an early exit from a loop, there may be ambiguity about which
        value may escape the loop. Example where this matters:

          while f():
              x = 0
              if g():
                  break
              x = ''  # Error -- not a redefinition
          reveal_type(x)  # int

        This method ensures that the second assignment to 'x' doesn't introduce a new
        variable.
        """
        var_blocks = self.var_blocks[-1]
        for key, block in var_blocks.items():
            if self.block_loop_depth.get(block) == self.loop_depth:
                var_blocks[key] = -1

    def record_assignment(self, name: str, can_be_redefined: bool) -> bool:
        """Record assignment to given name and return True if it defines a new variable.

        Args:
            can_be_redefined: If True, allows assignment in the same block to redefine
                this name (if this is a new definition)
        """
        if self.num_reads[-1].get(name, -1) == 0:
            # Only set, not read, so no reason to redefine
            return False
        if self.disallow_redef_depth > 0:
            # Can't redefine within a try statement.
            can_be_redefined = False
        block = self.current_block()
        var_blocks = self.var_blocks[-1]
        if name not in var_blocks:
            # New definition in this scope.
            if can_be_redefined:
                # Store the block where this was defined to allow redefinition in
                # the same block only.
                var_blocks[name] = block
            else:
                # This doesn't support arbitrary redefinition.
                var_blocks[name] = -1
            return True
        elif var_blocks[name] == block:
            # Redefinition -- defines a new variable with the same name.
            return True
        else:
            # Assigns to an existing variable.
            return False