"""Data-flow analyses."""

from abc import abstractmethod

from typing import Dict, Tuple, List, Set, TypeVar, Iterator, Generic, Optional, Iterable, Union

from mypyc.ir.ops import (
    Value, ControlOp,
    BasicBlock, OpVisitor, Assign, AssignMulti, Integer, LoadErrorValue, RegisterOp, Goto, Branch,
    Return, Call, Box, Unbox, Cast, Op, Unreachable, TupleGet, TupleSet, GetAttr, SetAttr,
    LoadLiteral, LoadStatic, InitStatic, MethodCall, RaiseStandardError, CallC, LoadGlobal,
    Truncate, IntOp, LoadMem, GetElementPtr, LoadAddress, ComparisonOp, SetMem, KeepAlive
)
from mypyc.ir.func_ir import all_values


class CFG:
    """Control-flow graph.

    Node 0 is always assumed to be the entry point. There must be a
    non-empty set of exits.

    Attributes:
        succ: Maps each block to its successor blocks
        pred: Maps each block to its predecessor blocks
        exits: Blocks whose terminator is not a Goto or Branch (as computed
            by get_cfg)
    """

    def __init__(self,
                 succ: Dict[BasicBlock, List[BasicBlock]],
                 pred: Dict[BasicBlock, List[BasicBlock]],
                 exits: Set[BasicBlock]) -> None:
        assert exits
        self.succ = succ
        self.pred = pred
        self.exits = exits

    def __str__(self) -> str:
        lines = []
        lines.append('exits: %s' % sorted(self.exits, key=lambda e: e.label))
        lines.append('succ: %s' % self.succ)
        lines.append('pred: %s' % self.pred)
        return '\n'.join(lines)


def get_cfg(blocks: List[BasicBlock]) -> CFG:
    """Calculate basic block control-flow graph.

    Returns a CFG whose 'succ' and 'pred' maps have an entry for every block
    in 'blocks'. The successors of a block are the targets of its final Goto
    or Branch op, followed by the error handlers of the block itself and of
    those targets (see the comment in the loop below). No de-duplication is
    done, so a block may appear more than once in a successor or predecessor
    list (e.g. when a block and its target share an error handler).

    Every block whose last op is not a Goto or Branch is recorded as an exit.
    """
    succ_map = {}  # type: Dict[BasicBlock, List[BasicBlock]]
    pred_map = {}  # type: Dict[BasicBlock, List[BasicBlock]]
    exits = set()  # type: Set[BasicBlock]
    for block in blocks:

        assert not any(isinstance(op, ControlOp) for op in block.ops[:-1]), (
            "Control-flow ops must be at the end of blocks")

        last = block.ops[-1]
        if isinstance(last, Branch):
            succ = [last.true, last.false]
        elif isinstance(last, Goto):
            succ = [last.label]
        else:
            succ = []
            exits.add(block)

        # Errors can occur anywhere inside a block, which means that
        # we can't assume that the entire block has executed before
        # jumping to the error handler. In our CFG construction, we
        # model this as saying that a block can jump to its error
        # handler or the error handlers of any of its normal
        # successors (to represent an error before that next block
        # completes). This works well for analyses like "must
        # defined", where it implies that registers assigned in a
        # block may be undefined in its error handler, but is in
        # general not a precise representation of reality; any
        # analyses that require more fidelity must wait until after
        # exception insertion.
        # Iterate over a copy ([block] + succ) because succ grows in the loop.
        for error_point in [block] + succ:
            if error_point.error_handler is not None:
                succ.append(error_point.error_handler)

        succ_map[block] = succ
        pred_map[block] = []
    # Invert the successor map to get predecessors.
    for prev, nxt in succ_map.items():
        for label in nxt:
            pred_map[label].append(prev)
    return CFG(succ_map, pred_map, exits)


def get_real_target(label: BasicBlock) -> BasicBlock:
    """Return the block that a jump to 'label' effectively lands in.

    If 'label' contains nothing but a single Goto, that Goto's target is
    returned; otherwise 'label' itself is returned. Only one level of
    indirection is followed per call.
    """
    if len(label.ops) == 1 and isinstance(label.ops[-1], Goto):
        label = label.ops[-1].label
    return label


def cleanup_cfg(blocks: List[BasicBlock]) -> None:
    """Cleanup the control flow graph.

    This eliminates obviously dead basic blocks and eliminates blocks that contain
    nothing but a single jump.

    'blocks' is modified in place. Jumps to goto-only blocks are redirected
    to the goto's target, then every block other than the first that has no
    predecessors is dropped. This repeats until no block is removed, since
    removing a block can leave its successors without predecessors.

    There is a lot more that could be done.
    """
    changed = True
    while changed:
        # First collapse any jumps to basic block that only contain a goto
        for block in blocks:
            term = block.ops[-1]
            if isinstance(term, Goto):
                term.label = get_real_target(term.label)
            elif isinstance(term, Branch):
                term.true = get_real_target(term.true)
                term.false = get_real_target(term.false)

        # Then delete any blocks that have no predecessors
        changed = False
        cfg = get_cfg(blocks)
        orig_blocks = blocks[:]
        blocks.clear()
        for i, block in enumerate(orig_blocks):
            # Block 0 is the entry point and must be kept even without predecessors.
            if i == 0 or cfg.pred[block]:
                blocks.append(block)
            else:
                changed = True


T = TypeVar('T')

# Maps (basic block, index of an op within that block) to a set of analysis facts.
AnalysisDict = Dict[Tuple[BasicBlock, int], Set[T]]


class AnalysisResult(Generic[T]):
    """Per-op results of a data-flow analysis.

    'before' and 'after' hold the facts immediately before and after each
    op, keyed by (block, op index).
    """

    def __init__(self, before: 'AnalysisDict[T]', after: 'AnalysisDict[T]') -> None:
        self.before = before
        self.after = after

    def __str__(self) -> str:
        return 'before: %s\nafter: %s\n' % (self.before, self.after)


# (gen, kill) sets of values for a single op, as returned by the analysis visitors.
GenAndKill = Tuple[Set[Value], Set[Value]]


class BaseAnalysisVisitor(OpVisitor[GenAndKill]):
    """Base class for the gen/kill visitors passed to run_analysis.

    Goto generates and kills nothing. Most other op kinds are forwarded to
    visit_register_op. Subclasses implement visit_register_op, visit_assign,
    visit_assign_multi and visit_set_mem. They also define visit_branch,
    visit_return and visit_unreachable, which are not handled here.
    """

    def visit_goto(self, op: Goto) -> GenAndKill:
        return set(), set()

    @abstractmethod
    def visit_register_op(self, op: RegisterOp) -> GenAndKill:
        raise NotImplementedError

    @abstractmethod
    def visit_assign(self, op: Assign) -> GenAndKill:
        raise NotImplementedError

    @abstractmethod
    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        raise NotImplementedError

    @abstractmethod
    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        raise NotImplementedError

    def visit_call(self, op: Call) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_method_call(self, op: MethodCall) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_error_value(self, op: LoadErrorValue) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_literal(self, op: LoadLiteral) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_get_attr(self, op: GetAttr) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_set_attr(self, op: SetAttr) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_static(self, op: LoadStatic) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_init_static(self, op: InitStatic) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_tuple_get(self, op: TupleGet) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_tuple_set(self, op: TupleSet) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_box(self, op: Box) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_unbox(self, op: Unbox) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_cast(self, op: Cast) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_call_c(self, op: CallC) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_truncate(self, op: Truncate) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_global(self, op: LoadGlobal) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_int_op(self, op: IntOp) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_mem(self, op: LoadMem) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_load_address(self, op: LoadAddress) -> GenAndKill:
        return self.visit_register_op(op)

    def visit_keep_alive(self, op: KeepAlive) -> GenAndKill:
        return self.visit_register_op(op)


class DefinedVisitor(BaseAnalysisVisitor):
    """Visitor for finding defined registers.

    Note that this only deals with registers and not temporaries, on
    the assumption that we never access temporaries when they might be
    undefined.

    An Assign defines its destination, unless the source is an undefining
    LoadErrorValue, in which case the destination is killed.
    """

    def visit_branch(self, op: Branch) -> GenAndKill:
        return set(), set()

    def visit_return(self, op: Return) -> GenAndKill:
        return set(), set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill:
        return set(), set()

    def visit_register_op(self, op: RegisterOp) -> GenAndKill:
        return set(), set()

    def visit_assign(self, op: Assign) -> GenAndKill:
        # Loading an error value may undefine the register.
        if isinstance(op.src, LoadErrorValue) and op.src.undefines:
            return set(), {op.dest}
        else:
            return {op.dest}, set()

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        # Array registers are special and we don't track the definedness of them.
        return set(), set()

    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        return set(), set()


def analyze_maybe_defined_regs(blocks: List[BasicBlock],
                               cfg: CFG,
                               initial_defined: Set[Value]) -> AnalysisResult[Value]:
    """Calculate potentially defined registers at each CFG location.

    A register is defined if it has a value along some path from the initial location.
    """
    return run_analysis(blocks=blocks,
                        cfg=cfg,
                        gen_and_kill=DefinedVisitor(),
                        initial=initial_defined,
                        backward=False,
                        kind=MAYBE_ANALYSIS)


def analyze_must_defined_regs(
        blocks: List[BasicBlock],
        cfg: CFG,
        initial_defined: Set[Value],
        regs: Iterable[Value]) -> AnalysisResult[Value]:
    """Calculate always defined registers at each CFG location.

    This analysis can work before exception insertion, since it is a
    sound assumption that registers defined in a block might not be
    initialized in its error handler.

    A register is defined if it has a value along all paths from the
    initial location.

    'regs' is the universe of registers the must analysis narrows down from.
    """
    return run_analysis(blocks=blocks,
                        cfg=cfg,
                        gen_and_kill=DefinedVisitor(),
                        initial=initial_defined,
                        backward=False,
                        kind=MUST_ANALYSIS,
                        universe=set(regs))


class BorrowedArgumentsVisitor(BaseAnalysisVisitor):
    """Gen/kill visitor for analyze_borrowed_arguments.

    Nothing is ever generated. Assigning to one of 'args' kills it, since
    after the assignment the argument is no longer borrowed.
    """

    def __init__(self, args: Set[Value]) -> None:
        self.args = args

    def visit_branch(self, op: Branch) -> GenAndKill:
        return set(), set()

    def visit_return(self, op: Return) -> GenAndKill:
        return set(), set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill:
        return set(), set()

    def visit_register_op(self, op: RegisterOp) -> GenAndKill:
        return set(), set()

    def visit_assign(self, op: Assign) -> GenAndKill:
        if op.dest in self.args:
            return set(), {op.dest}
        return set(), set()

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        return set(), set()

    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        return set(), set()


def analyze_borrowed_arguments(
        blocks: List[BasicBlock],
        cfg: CFG,
        borrowed: Set[Value]) -> AnalysisResult[Value]:
    """Calculate arguments that can use references borrowed from the caller.

    When assigning to an argument, it no longer is borrowed.
    """
    return run_analysis(blocks=blocks,
                        cfg=cfg,
                        gen_and_kill=BorrowedArgumentsVisitor(borrowed),
                        initial=borrowed,
                        backward=False,
                        kind=MUST_ANALYSIS,
                        universe=borrowed)


class UndefinedVisitor(BaseAnalysisVisitor):
    """Gen/kill visitor for analyze_undefined_regs.

    Nothing is ever generated. Any op that produces a value (a non-void
    RegisterOp, or the destination of an Assign/AssignMulti) is killed from
    the set of undefined values.
    """

    def visit_branch(self, op: Branch) -> GenAndKill:
        return set(), set()

    def visit_return(self, op: Return) -> GenAndKill:
        return set(), set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill:
        return set(), set()

    def visit_register_op(self, op: RegisterOp) -> GenAndKill:
        return set(), {op} if not op.is_void else set()

    def visit_assign(self, op: Assign) -> GenAndKill:
        return set(), {op.dest}

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        return set(), {op.dest}

    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        return set(), set()


def analyze_undefined_regs(blocks: List[BasicBlock],
                           cfg: CFG,
                           initial_defined: Set[Value]) -> AnalysisResult[Value]:
    """Calculate potentially undefined registers at each CFG location.

    A register is undefined if there is some path from initial block
    where it has an undefined value.

    Function arguments are assumed to be always defined.
    """
    # Everything not defined on entry starts out as undefined.
    initial_undefined = set(all_values([], blocks)) - initial_defined
    return run_analysis(blocks=blocks,
                        cfg=cfg,
                        gen_and_kill=UndefinedVisitor(),
                        initial=initial_undefined,
                        backward=False,
                        kind=MAYBE_ANALYSIS)


def non_trivial_sources(op: Op) -> Set[Value]:
    """Return the sources of 'op', leaving out Integer constants."""
    result = set()
    for source in op.sources():
        if not isinstance(source, Integer):
            result.add(source)
    return result


class LivenessVisitor(BaseAnalysisVisitor):
    """Gen/kill visitor for analyze_live_regs.

    An op generates the non-Integer values it reads. It kills the value it
    defines: a non-void RegisterOp itself, or the destination of an Assign
    or AssignMulti.
    """

    def visit_branch(self, op: Branch) -> GenAndKill:
        return non_trivial_sources(op), set()

    def visit_return(self, op: Return) -> GenAndKill:
        if not isinstance(op.value, Integer):
            return {op.value}, set()
        else:
            return set(), set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill:
        return set(), set()

    def visit_register_op(self, op: RegisterOp) -> GenAndKill:
        gen = non_trivial_sources(op)
        if not op.is_void:
            return gen, {op}
        else:
            return gen, set()

    def visit_assign(self, op: Assign) -> GenAndKill:
        return non_trivial_sources(op), {op.dest}

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        return non_trivial_sources(op), {op.dest}

    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        return non_trivial_sources(op), set()


def analyze_live_regs(blocks: List[BasicBlock],
                      cfg: CFG) -> AnalysisResult[Value]:
    """Calculate live registers at each CFG location.

    A register is live at a location if it can be read along some CFG path starting
    from the location.
    """
    return run_analysis(blocks=blocks,
                        cfg=cfg,
                        gen_and_kill=LivenessVisitor(),
                        initial=set(),
                        backward=True,
                        kind=MAYBE_ANALYSIS)


# Analysis kinds
MUST_ANALYSIS = 0  # Fact holds along all paths; predecessors merged by intersection
MAYBE_ANALYSIS = 1  # Fact holds along some path; predecessors merged by union


# TODO the return type of this function is too complicated. Abstract it into its
# own class.
def run_analysis(blocks: List[BasicBlock],
                 cfg: CFG,
                 gen_and_kill: OpVisitor[Tuple[Set[T], Set[T]]],
                 initial: Set[T],
                 kind: int,
                 backward: bool,
                 universe: Optional[Set[T]] = None) -> AnalysisResult[T]:
    """Run a general set-based data flow analysis.

    Args:
        blocks: All basic blocks. For a forward analysis blocks[0] is the entry
            point, following the CFG convention that node 0 is the entry.
        cfg: Control-flow graph for the code
        gen_and_kill: Implementation of gen and kill functions for each op
        initial: Value of analysis for the entry points (for a forward analysis) or the
            exit points (cfg.exits, for a backward analysis). These points see
            'initial' even when they also have incoming edges (e.g. a loop back
            edge to the entry block). Any other block without predecessors in the
            direction of the analysis also starts from 'initial'.
        kind: MUST_ANALYSIS or MAYBE_ANALYSIS
        backward: If False, the analysis is a forward analysis; it's backward otherwise
        universe: For a must analysis, the set of all possible values. This is the starting
            value for the work list algorithm, which will narrow this down until reaching a
            fixed point. For a maybe analysis the iteration always starts from an empty set
            and this argument is ignored.

    Returns:
        An AnalysisResult whose 'before' and 'after' dicts map (block, op index) to
        the facts holding immediately before and after that op in program order.
        For a backward analysis the roles are swapped back before returning.
        Adjacent entries may share the same set object, so callers should not
        mutate the returned sets.
    """
    block_gen = {}  # type: Dict[BasicBlock, Set[T]]
    block_kill = {}  # type: Dict[BasicBlock, Set[T]]

    # Calculate kill and gen sets for entire basic blocks by composing the
    # per-op transfer functions in the direction of the analysis.
    for block in blocks:
        gen = set()  # type: Set[T]
        kill = set()  # type: Set[T]
        ops = block.ops
        if backward:
            ops = list(reversed(ops))
        for op in ops:
            opgen, opkill = op.accept(gen_and_kill)
            gen = ((gen - opkill) | opgen)
            kill = ((kill - opgen) | opkill)
        block_gen[block] = gen
        block_kill[block] = kill

    # Set up initial state for worklist algorithm.
    worklist = list(blocks)
    if not backward:
        # pop() takes from the end, so reversing makes the entry block go first.
        worklist = worklist[::-1]  # Reverse for a small performance improvement
    workset = set(worklist)
    before = {}  # type: Dict[BasicBlock, Set[T]]
    after = {}  # type: Dict[BasicBlock, Set[T]]
    for block in blocks:
        if kind == MAYBE_ANALYSIS:
            before[block] = set()
            after[block] = set()
        else:
            assert universe is not None, "Universe must be defined for a must analysis"
            before[block] = set(universe)
            after[block] = set(universe)

    # In a backward analysis, information flows against the CFG edges.
    if backward:
        pred_map = cfg.succ
        succ_map = cfg.pred
    else:
        pred_map = cfg.pred
        succ_map = cfg.succ

    # Blocks where 'initial' enters the analysis along an implicit edge from
    # outside the function. Without merging 'initial' here, an entry block with
    # a loop back edge would only see in-function edges, and a must analysis
    # would start it from the un-narrowed universe.
    if backward:
        boundary = cfg.exits  # type: Set[BasicBlock]
    else:
        boundary = {blocks[0]} if blocks else set()

    # Run work list algorithm to generate in and out sets for each basic block.
    while worklist:
        label = worklist.pop()
        workset.remove(label)
        if pred_map[label]:
            new_before = None  # type: Union[Set[T], None]
            for pred in pred_map[label]:
                if new_before is None:
                    new_before = set(after[pred])
                elif kind == MAYBE_ANALYSIS:
                    new_before |= after[pred]
                else:
                    new_before &= after[pred]
            assert new_before is not None
            if label in boundary:
                if kind == MAYBE_ANALYSIS:
                    new_before |= initial
                else:
                    new_before &= initial
        else:
            new_before = set(initial)
        before[label] = new_before
        new_after = (new_before - block_kill[label]) | block_gen[label]
        if new_after != after[label]:
            for succ in succ_map[label]:
                if succ not in workset:
                    worklist.append(succ)
                    workset.add(succ)
        after[label] = new_after

    # Run algorithm for each basic block to generate opcode-level sets.
    op_before = {}  # type: Dict[Tuple[BasicBlock, int], Set[T]]
    op_after = {}  # type: Dict[Tuple[BasicBlock, int], Set[T]]
    for block in blocks:
        label = block
        cur = before[label]
        ops_enum = enumerate(block.ops)  # type: Iterator[Tuple[int, Op]]
        if backward:
            ops_enum = reversed(list(ops_enum))
        for idx, op in ops_enum:
            op_before[label, idx] = cur
            opgen, opkill = op.accept(gen_and_kill)
            cur = (cur - opkill) | opgen
            op_after[label, idx] = cur
    if backward:
        # Convert from analysis order back to program order.
        op_after, op_before = op_before, op_after

    return AnalysisResult(op_before, op_after)