1"""Data-flow analyses."""
2
3from abc import abstractmethod
4
5from typing import Dict, Tuple, List, Set, TypeVar, Iterator, Generic, Optional, Iterable, Union
6
7from mypyc.ir.ops import (
8    Value, ControlOp,
9    BasicBlock, OpVisitor, Assign, AssignMulti, Integer, LoadErrorValue, RegisterOp, Goto, Branch,
10    Return, Call, Box, Unbox, Cast, Op, Unreachable, TupleGet, TupleSet, GetAttr, SetAttr,
11    LoadLiteral, LoadStatic, InitStatic, MethodCall, RaiseStandardError, CallC, LoadGlobal,
12    Truncate, IntOp, LoadMem, GetElementPtr, LoadAddress, ComparisonOp, SetMem, KeepAlive
13)
14from mypyc.ir.func_ir import all_values
15
16
class CFG:
    """Control-flow graph over basic blocks.

    Node 0 is always assumed to be the entry point, and the set of exit
    blocks must be non-empty (this is asserted on construction).

    The constructor stores its arguments without copying them:
      * succ:  maps a block to the list of its successor blocks
      * pred:  maps a block to the list of its predecessor blocks
      * exits: the set of exit blocks
    """

    def __init__(self,
                 succ: Dict[BasicBlock, List[BasicBlock]],
                 pred: Dict[BasicBlock, List[BasicBlock]],
                 exits: Set[BasicBlock]) -> None:
        assert exits
        self.succ, self.pred, self.exits = succ, pred, exits

    def __str__(self) -> str:
        """Render the exits (sorted by block label) and both edge maps, one per line."""
        ordered_exits = sorted(self.exits, key=lambda blk: blk.label)
        return '\n'.join([
            'exits: {}'.format(ordered_exits),
            'succ: {}'.format(self.succ),
            'pred: {}'.format(self.pred),
        ])
39
40
def get_cfg(blocks: List[BasicBlock]) -> CFG:
    """Build the control-flow graph for a list of basic blocks.

    The returned CFG holds, for every block in `blocks`, its list of
    successor blocks and its list of predecessor blocks, plus the set of
    exit blocks (blocks whose last op is neither a Branch nor a Goto).

    Every block must be non-empty and may only contain a control-flow op
    as its final op; the latter is asserted.
    """
    successors = {}
    exit_blocks = set()
    for block in blocks:

        assert not any(isinstance(op, ControlOp) for op in block.ops[:-1]), (
            "Control-flow ops must be at the end of blocks")

        terminator = block.ops[-1]
        if isinstance(terminator, Branch):
            targets = [terminator.true, terminator.false]
        elif isinstance(terminator, Goto):
            targets = [terminator.label]
        else:
            targets = []
            exit_blocks.add(block)

        # Errors can occur anywhere inside a block, which means that
        # we can't assume that the entire block has executed before
        # jumping to the error handler. In our CFG construction, we
        # model this as saying that a block can jump to its error
        # handler or the error handlers of any of its normal
        # successors (to represent an error before that next block
        # completes). This works well for analyses like "must
        # defined", where it implies that registers assigned in a
        # block may be undefined in its error handler, but is in
        # general not a precise representation of reality; any
        # analyses that require more fidelity must wait until after
        # exception insertion.
        handlers = [point.error_handler
                    for point in [block] + targets
                    if point.error_handler]
        targets.extend(handlers)

        successors[block] = targets

    # Every block gets a predecessor list, even if it ends up empty.
    predecessors = {block: [] for block in blocks}  # type: Dict[BasicBlock, List[BasicBlock]]
    for source, targets in successors.items():
        for target in targets:
            predecessors[target].append(source)
    return CFG(successors, predecessors, exit_blocks)
87
88
def get_real_target(label: BasicBlock) -> BasicBlock:
    """Look through a block that consists of nothing but a jump.

    If `label` contains exactly one op and that op is a Goto, return the
    Goto's target. Otherwise return `label` itself. Only a single hop is
    followed.
    """
    is_bare_jump = len(label.ops) == 1 and isinstance(label.ops[-1], Goto)
    return label.ops[-1].label if is_bare_jump else label
93
94
def cleanup_cfg(blocks: List[BasicBlock]) -> None:
    """Simplify the control flow graph in place.

    Jumps into blocks that contain only a single Goto are redirected to
    that Goto's target, and blocks that are left without any predecessor
    (other than the first block, which is always kept) are dropped from
    `blocks`. The two steps repeat until a pass removes nothing.

    There is a lot more that could be done.
    """
    removed_any = True
    while removed_any:
        # Step 1: retarget every terminator past jump-only blocks.
        for block in blocks:
            terminator = block.ops[-1]
            if isinstance(terminator, Goto):
                terminator.label = get_real_target(terminator.label)
            elif isinstance(terminator, Branch):
                terminator.true = get_real_target(terminator.true)
                terminator.false = get_real_target(terminator.false)

        # Step 2: keep the entry block and every block that still has a
        # predecessor; mutate the caller's list rather than rebinding it.
        cfg = get_cfg(blocks)
        survivors = [block
                     for position, block in enumerate(blocks)
                     if position == 0 or cfg.pred[block]]
        removed_any = len(survivors) != len(blocks)
        blocks[:] = survivors
124
125
# Element type of the sets computed by an analysis.
T = TypeVar('T')

# Maps a (basic block, int) pair to a set of T. run_analysis uses the int as
# the index of an op within the block's op list.
AnalysisDict = Dict[Tuple[BasicBlock, int], Set[T]]
129
130
class AnalysisResult(Generic[T]):
    """Pair of analysis dictionaries produced by a data-flow analysis.

    `before` and `after` are stored as given; each maps a
    (basic block, op index) key to a set of analysis values.
    """

    def __init__(self, before: 'AnalysisDict[T]', after: 'AnalysisDict[T]') -> None:
        self.before = before
        self.after = after

    def __str__(self) -> str:
        """Show both dictionaries, one per line, with a trailing newline."""
        return 'before: {}\nafter: {}\n'.format(self.before, self.after)
138
139
# What an analysis visitor returns for a single op: a (gen set, kill set) pair.
GenAndKill = Tuple[Set[Value], Set[Value]]
141
142
class BaseAnalysisVisitor(OpVisitor[GenAndKill]):
    """Common base for the gen/kill visitors defined in this module.

    Each visit method returns a (gen, kill) pair of value sets for one op.
    A Goto contributes nothing. Most other op kinds are forwarded to
    visit_register_op, so a concrete analysis only needs to implement the
    abstract hooks below, plus any visit methods that this class does not
    define.
    """

    # --- Control flow handled here ---------------------------------------

    def visit_goto(self, op: Goto) -> GenAndKill:
        return set(), set()

    # --- Hooks every concrete analysis must provide ----------------------

    @abstractmethod
    def visit_register_op(self, op: RegisterOp) -> GenAndKill:
        raise NotImplementedError

    @abstractmethod
    def visit_assign(self, op: Assign) -> GenAndKill:
        raise NotImplementedError

    @abstractmethod
    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        raise NotImplementedError

    @abstractmethod
    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        raise NotImplementedError

    # --- Calls: all forwarded to visit_register_op -----------------------

    def visit_call(self, op: Call) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_method_call(self, op: MethodCall) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_call_c(self, op: CallC) -> GenAndKill:
        return self.visit_register_op(op)

    # --- Loads: all forwarded to visit_register_op -----------------------

    def visit_load_error_value(self, op: LoadErrorValue) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_literal(self, op: LoadLiteral) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_static(self, op: LoadStatic) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_global(self, op: LoadGlobal) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_mem(self, op: LoadMem) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_address(self, op: LoadAddress) -> GenAndKill:
        return self.visit_register_op(op)

    # --- Attributes and statics: forwarded to visit_register_op ----------

    def visit_get_attr(self, op: GetAttr) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_set_attr(self, op: SetAttr) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_init_static(self, op: InitStatic) -> GenAndKill:
        return self.visit_register_op(op)

    # --- Tuples: forwarded to visit_register_op --------------------------

    def visit_tuple_get(self, op: TupleGet) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_tuple_set(self, op: TupleSet) -> GenAndKill:
        return self.visit_register_op(op)

    # --- Conversions: forwarded to visit_register_op ---------------------

    def visit_box(self, op: Box) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_unbox(self, op: Unbox) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_cast(self, op: Cast) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_truncate(self, op: Truncate) -> GenAndKill:
        return self.visit_register_op(op)

    # --- Arithmetic and comparison: forwarded to visit_register_op -------

    def visit_int_op(self, op: IntOp) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill:
        return self.visit_register_op(op)

    # --- Remaining ops: forwarded to visit_register_op -------------------

    def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_keep_alive(self, op: KeepAlive) -> GenAndKill:
        return self.visit_register_op(op)
231
232
class DefinedVisitor(BaseAnalysisVisitor):
    """Gen/kill visitor for tracking which registers are defined.

    Only Assign ops affect the result: register ops (temporaries) are not
    tracked, on the assumption that we never access temporaries when they
    might be undefined.
    """

    def visit_assign(self, op: Assign) -> GenAndKill:
        # Assigning an error value that is flagged as "undefines" removes the
        # destination from the defined set; any other assignment adds it.
        source = op.src
        if isinstance(source, LoadErrorValue) and source.undefines:
            return set(), {op.dest}
        return {op.dest}, set()

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        # Array registers are special and we don't track the definedness of them.
        return set(), set()

    def visit_register_op(self, op: RegisterOp) -> GenAndKill:
        return set(), set()

    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        return set(), set()

    # Control-flow ops neither define nor undefine anything.

    def visit_branch(self, op: Branch) -> GenAndKill:
        return set(), set()

    def visit_return(self, op: Return) -> GenAndKill:
        return set(), set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill:
        return set(), set()
266
267
def analyze_maybe_defined_regs(blocks: List[BasicBlock],
                               cfg: CFG,
                               initial_defined: Set[Value]) -> AnalysisResult[Value]:
    """Compute the registers that may be defined at each CFG location.

    This is a forward "maybe" analysis using DefinedVisitor, seeded with
    `initial_defined`: a register counts as defined at a location if it
    has a value along at least one path from the initial location.
    """
    visitor = DefinedVisitor()
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=initial_defined,
        backward=False,
        kind=MAYBE_ANALYSIS,
    )
281
282
def analyze_must_defined_regs(
        blocks: List[BasicBlock],
        cfg: CFG,
        initial_defined: Set[Value],
        regs: Iterable[Value]) -> AnalysisResult[Value]:
    """Compute the registers that are always defined at each CFG location.

    This is a forward "must" analysis using DefinedVisitor, seeded with
    `initial_defined`; `regs` is collected into a set and used as the
    universe. A register counts as defined at a location only if it has a
    value along all paths from the initial location.

    This analysis can work before exception insertion, since it is a
    sound assumption that registers defined in a block might not be
    initialized in its error handler.
    """
    visitor = DefinedVisitor()
    every_reg = set(regs)
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=initial_defined,
        backward=False,
        kind=MUST_ANALYSIS,
        universe=every_reg,
    )
304
305
class BorrowedArgumentsVisitor(BaseAnalysisVisitor):
    """Gen/kill visitor for tracking which of a given set of arguments are still borrowed.

    Nothing is ever generated; the only kill comes from an Assign whose
    destination is one of the arguments passed to the constructor.
    """

    def __init__(self, args: Set[Value]) -> None:
        # The set is kept by reference, not copied.
        self.args = args

    def visit_assign(self, op: Assign) -> GenAndKill:
        killed = {op.dest} if op.dest in self.args else set()
        return set(), killed

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        return set(), set()

    def visit_register_op(self, op: RegisterOp) -> GenAndKill:
        return set(), set()

    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        return set(), set()

    # Control-flow ops have no effect on this analysis.

    def visit_branch(self, op: Branch) -> GenAndKill:
        return set(), set()

    def visit_return(self, op: Return) -> GenAndKill:
        return set(), set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill:
        return set(), set()
332
333
def analyze_borrowed_arguments(
        blocks: List[BasicBlock],
        cfg: CFG,
        borrowed: Set[Value]) -> AnalysisResult[Value]:
    """Compute which arguments can still use references borrowed from the caller.

    This is a forward "must" analysis in which `borrowed` serves as the
    initial state, the universe, and the set of arguments the visitor
    watches. Once an argument is assigned to, it is no longer borrowed.
    """
    visitor = BorrowedArgumentsVisitor(borrowed)
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=borrowed,
        backward=False,
        kind=MUST_ANALYSIS,
        universe=borrowed,
    )
349
350
class UndefinedVisitor(BaseAnalysisVisitor):
    """Gen/kill visitor for tracking potentially undefined values.

    Nothing is ever generated. A value is killed (stops being undefined)
    when a non-void register op produces it or when it is the destination
    of an Assign or AssignMulti.
    """

    def visit_register_op(self, op: RegisterOp) -> GenAndKill:
        # A void op produces no value, so it has nothing to kill.
        if op.is_void:
            return set(), set()
        return set(), {op}

    def visit_assign(self, op: Assign) -> GenAndKill:
        return set(), {op.dest}

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        return set(), {op.dest}

    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        return set(), set()

    # Control-flow ops have no effect on this analysis.

    def visit_branch(self, op: Branch) -> GenAndKill:
        return set(), set()

    def visit_return(self, op: Return) -> GenAndKill:
        return set(), set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill:
        return set(), set()
372
373
def analyze_undefined_regs(blocks: List[BasicBlock],
                           cfg: CFG,
                           initial_defined: Set[Value]) -> AnalysisResult[Value]:
    """Compute the potentially undefined registers at each CFG location.

    This is a forward "maybe" analysis using UndefinedVisitor: a register
    is undefined at a location if there is some path from the initial
    block along which it has an undefined value.

    The initial state is every value reported by all_values([], blocks)
    except those in `initial_defined`, which are taken to be defined on
    entry.
    """
    everything = set(all_values([], blocks))
    undefined_at_entry = everything - initial_defined
    visitor = UndefinedVisitor()
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=undefined_at_entry,
        backward=False,
        kind=MAYBE_ANALYSIS,
    )
391
392
def non_trivial_sources(op: Op) -> Set[Value]:
    """Return the set of `op`'s sources, leaving out any that are Integer instances."""
    return {source for source in op.sources() if not isinstance(source, Integer)}
399
400
class LivenessVisitor(BaseAnalysisVisitor):
    """Gen/kill visitor for liveness.

    An op generates the non-Integer values it reads and kills the value it
    writes (its own result for a non-void register op, or the destination
    of an assignment).
    """

    def visit_register_op(self, op: RegisterOp) -> GenAndKill:
        used = non_trivial_sources(op)
        # A void op produces no value, so there is nothing to kill.
        defined = set() if op.is_void else {op}
        return used, defined

    def visit_assign(self, op: Assign) -> GenAndKill:
        return non_trivial_sources(op), {op.dest}

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        return non_trivial_sources(op), {op.dest}

    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        return non_trivial_sources(op), set()

    def visit_branch(self, op: Branch) -> GenAndKill:
        return non_trivial_sources(op), set()

    def visit_return(self, op: Return) -> GenAndKill:
        # An Integer return value is not tracked as live.
        if isinstance(op.value, Integer):
            return set(), set()
        return {op.value}, set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill:
        return set(), set()
429
430
def analyze_live_regs(blocks: List[BasicBlock],
                      cfg: CFG) -> AnalysisResult[Value]:
    """Compute the live registers at each CFG location.

    This is a backward "maybe" analysis using LivenessVisitor, starting
    from an empty set: a register is live at a location if it can be read
    along some CFG path starting from that location.
    """
    visitor = LivenessVisitor()
    nothing_live = set()  # type: Set[Value]
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=nothing_live,
        backward=True,
        kind=MAYBE_ANALYSIS,
    )
444
445
# Analysis kinds (the `kind` argument of run_analysis)
# Must analysis: the states flowing into a block are combined by intersection.
MUST_ANALYSIS = 0
# Maybe analysis: the states flowing into a block are combined by union.
MAYBE_ANALYSIS = 1
449
450
# TODO the return type of this function is too complicated. Abstract it into its
# own class.
453
def run_analysis(blocks: List[BasicBlock],
                 cfg: CFG,
                 gen_and_kill: OpVisitor[Tuple[Set[T], Set[T]]],
                 initial: Set[T],
                 kind: int,
                 backward: bool,
                 universe: Optional[Set[T]] = None) -> AnalysisResult[T]:
    """Run a general set-based data flow analysis.

    Args:
        blocks: All basic blocks
        cfg: Control-flow graph for the code
        gen_and_kill: Implementation of gen and kill functions for each op
        initial: Value of analysis for the entry points (for a forward analysis) or the
            exit points (for a backward analysis)
        kind: MUST_ANALYSIS or MAYBE_ANALYSIS
        backward: If False, the analysis is a forward analysis; it's backward otherwise
        universe: For a must analysis, the set of all possible values. This is the starting
            value for the work list algorithm, which will narrow this down until reaching a
            fixed point. For a maybe analysis the iteration always starts from an empty set
            and this argument is ignored.

    Returns:
        An AnalysisResult whose `before` and `after` dictionaries are keyed by
        (block, op index) and hold the analysis state immediately before and
        after that op in program order.
    """
    is_maybe = kind == MAYBE_ANALYSIS

    # Phase 1: fold the per-op gen/kill sets into one summary per block,
    # walking the ops in the direction of the analysis.
    block_gen = {}
    block_kill = {}
    for block in blocks:
        gen_acc = set()  # type: Set[T]
        kill_acc = set()  # type: Set[T]
        ordered_ops = list(reversed(block.ops)) if backward else block.ops
        for op in ordered_ops:
            op_gen, op_kill = op.accept(gen_and_kill)
            gen_acc = (gen_acc - op_kill) | op_gen
            kill_acc = (kill_acc - op_gen) | op_kill
        block_gen[block] = gen_acc
        block_kill[block] = kill_acc

    # Phase 2: seed the worklist and the per-block states.
    if backward:
        pending = list(blocks)
    else:
        pending = list(blocks)[::-1]  # Reverse for a small performance improvement
    queued = set(pending)
    before = {}  # type: Dict[BasicBlock, Set[T]]
    after = {}  # type: Dict[BasicBlock, Set[T]]
    for block in blocks:
        if is_maybe:
            before[block] = set()
            after[block] = set()
        else:
            assert universe is not None, "Universe must be defined for a must analysis"
            before[block] = set(universe)
            after[block] = set(universe)

    # For a backward analysis information flows against the CFG edges, so
    # the roles of the successor and predecessor maps are exchanged.
    if backward:
        flows_from, flows_to = cfg.succ, cfg.pred
    else:
        flows_from, flows_to = cfg.pred, cfg.succ

    # Phase 3: iterate to a fixed point over whole blocks.
    while pending:
        current = pending.pop()
        queued.remove(current)
        inputs = flows_from[current]
        if inputs:
            # Start from a copy of the first input, then merge in the rest.
            remaining = iter(inputs)
            merged = set(after[next(remaining)])  # type: Set[T]
            for other in remaining:
                if is_maybe:
                    merged |= after[other]
                else:
                    merged &= after[other]
        else:
            # No incoming flow: this is where the initial state enters.
            merged = set(initial)
        before[current] = merged
        updated = (merged - block_kill[current]) | block_gen[current]
        if updated != after[current]:
            # The output changed, so everything downstream must be revisited.
            for downstream in flows_to[current]:
                if downstream not in queued:
                    pending.append(downstream)
                    queued.add(downstream)
        after[current] = updated

    # Phase 4: replay each block op by op to obtain opcode-level states.
    op_before = {}  # type: Dict[Tuple[BasicBlock, int], Set[T]]
    op_after = {}  # type: Dict[Tuple[BasicBlock, int], Set[T]]
    for block in blocks:
        state = before[block]
        indexed_ops = enumerate(block.ops)  # type: Iterator[Tuple[int, Op]]
        if backward:
            indexed_ops = reversed(list(indexed_ops))
        for idx, op in indexed_ops:
            op_before[block, idx] = state
            op_gen, op_kill = op.accept(gen_and_kill)
            state = (state - op_kill) | op_gen
            op_after[block, idx] = state

    if backward:
        # The replay ran against program order, so what was recorded as
        # "before" is really the state after the op, and vice versa.
        return AnalysisResult(op_after, op_before)
    return AnalysisResult(op_before, op_after)
561