"""
Utils for IR analysis
"""
import operator
from functools import reduce
from collections import namedtuple, defaultdict

from .controlflow import CFGraph
from numba.core import types, errors, ir, consts
from numba.misc import special

#
# Analysis related to variable lifetime
#

_use_defs_result = namedtuple('use_defs_result', 'usemap,defmap')

# other packages that define new nodes add calls for finding defs
# format: {type:function}
ir_extension_usedefs = {}


def compute_use_defs(blocks):
    """
    Find variable use/def per block.

    Parameters
    ----------
    blocks : dict
        Mapping of block offset (label) -> ``ir.Block``.

    Returns
    -------
    use_defs_result
        Namedtuple ``(usemap, defmap)``; each field maps a block offset to
        a set of variable names.  ``usemap`` holds variables read in the
        block before any local definition, ``defmap`` holds variables
        defined in the block.

    Raises
    ------
    AssertionError
        If an ``ir.Assign`` has a right-hand side of an unexpected kind.
    """
    var_use_map = {}   # { block offset -> set of vars }
    var_def_map = {}   # { block offset -> set of vars }
    for offset, ir_block in blocks.items():
        var_use_map[offset] = use_set = set()
        var_def_map[offset] = def_set = set()
        for stmt in ir_block.body:
            # Statement types registered by extensions compute their own
            # use/def sets.
            if type(stmt) in ir_extension_usedefs:
                func = ir_extension_usedefs[type(stmt)]
                func(stmt, use_set, def_set)
                continue
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Inst):
                    rhs_set = set(var.name for var in stmt.value.list_vars())
                elif isinstance(stmt.value, ir.Var):
                    rhs_set = set([stmt.value.name])
                elif isinstance(stmt.value, (ir.Arg, ir.Const, ir.Global,
                                             ir.FreeVar)):
                    rhs_set = ()
                else:
                    raise AssertionError('unreachable', type(stmt.value))
                # If lhs not in rhs of the assignment.  An assignment that
                # reads its own target (e.g. ``x = x + 1``) is not a local
                # definition, so the target is recorded as a use below.
                if stmt.target.name not in rhs_set:
                    def_set.add(stmt.target.name)

            for var in stmt.list_vars():
                # do not include locally defined vars to use-map
                if var.name not in def_set:
                    use_set.add(var.name)

    return _use_defs_result(usemap=var_use_map, defmap=var_def_map)


def compute_live_map(cfg, blocks, var_use_map, var_def_map):
    """
    Find variables that must be alive at the ENTRY of each block.
    We use a simple fix-point algorithm that iterates until the set of
    live variables is unchanged for each block.

    Parameters
    ----------
    cfg : CFGraph
        Control-flow graph of ``blocks``.
    blocks : dict
        Mapping of block offset -> ``ir.Block``.
    var_use_map, var_def_map : dict
        The ``usemap`` / ``defmap`` produced by ``compute_use_defs``.

    Returns
    -------
    dict
        Mapping of block offset -> set of variable names live at entry.
    """
    def fix_point_progress(dct):
        """Helper function to determine if a fix-point has been reached.
        """
        # Sets only ever grow, so comparing sizes detects a fix-point.
        return tuple(len(v) for v in dct.values())

    def fix_point(fn, dct):
        """Helper function to run fix-point algorithm.
        """
        old_point = None
        new_point = fix_point_progress(dct)
        while old_point != new_point:
            fn(dct)
            old_point = new_point
            new_point = fix_point_progress(dct)

    def def_reach(dct):
        """Find all variable definition reachable at the entry of a block
        """
        for offset in var_def_map:
            used_or_defined = var_def_map[offset] | var_use_map[offset]
            dct[offset] |= used_or_defined
            # Propagate to outgoing nodes
            for out_blk, _ in cfg.successors(offset):
                dct[out_blk] |= dct[offset]

    def liveness(dct):
        """Find live variables.

        Push var usage backward.
        """
        for offset in dct:
            # Live vars here
            live_vars = dct[offset]
            for inc_blk, _data in cfg.predecessors(offset):
                # Reachable at the predecessor
                reachable = live_vars & def_reach_map[inc_blk]
                # But not defined in the predecessor
                dct[inc_blk] |= reachable - var_def_map[inc_blk]

    live_map = {}
    for offset in blocks.keys():
        live_map[offset] = set(var_use_map[offset])

    # ``def_reach_map`` is read by ``liveness`` through its closure.
    def_reach_map = defaultdict(set)
    fix_point(def_reach, def_reach_map)
    fix_point(liveness, live_map)
    return live_map


_dead_maps_result = namedtuple('dead_maps_result', 'internal,escaping,combined')


def compute_dead_maps(cfg, blocks, live_map, var_def_map):
    """
    Compute the end-of-live information for variables.
    `live_map` contains a mapping of block offset to all the living
    variables at the ENTRY of the block.

    Returns
    -------
    dead_maps_result
        Namedtuple ``(internal, escaping, combined)``, each mapping a block
        offset to a set of variable names: variables to delete within the
        block, variables to delete at the start of the block (escaped from
        a predecessor but unused here), and the union of the two.

    Raises
    ------
    RuntimeError
        If some live variable is never deleted and the CFG has exit points.
    """
    # The following three dictionaries will be
    # { block offset -> set of variables to delete }
    # all vars that should be deleted at the start of the successors
    escaping_dead_map = defaultdict(set)
    # all vars that should be deleted within this block
    internal_dead_map = defaultdict(set)
    # all vars that should be deleted after the function exit
    exit_dead_map = defaultdict(set)

    for offset, ir_block in blocks.items():
        # live vars WITHIN the block will include all the locally
        # defined variables
        cur_live_set = live_map[offset] | var_def_map[offset]
        # vars alive in the outgoing blocks
        outgoing_live_map = {out_blk: live_map[out_blk]
                             for out_blk, _data in cfg.successors(offset)}
        # vars to keep alive for the terminator
        terminator_liveset = set(v.name
                                 for v in ir_block.terminator.list_vars())
        # vars to keep alive in the successors
        combined_liveset = reduce(operator.or_, outgoing_live_map.values(),
                                  set())
        # include variables used in terminator
        combined_liveset |= terminator_liveset
        # vars that are dead within the block because they are not
        # propagated to any outgoing blocks
        internal_set = cur_live_set - combined_liveset
        internal_dead_map[offset] = internal_set
        # vars that escape this block
        escaping_live_set = cur_live_set - internal_set
        for out_blk, new_live_set in outgoing_live_map.items():
            # successor should delete the unused escaped vars
            new_live_set = new_live_set | var_def_map[out_blk]
            escaping_dead_map[out_blk] |= escaping_live_set - new_live_set

        # if no outgoing blocks
        if not outgoing_live_map:
            # insert var used by terminator
            exit_dead_map[offset] = terminator_liveset

    # Verify that the dead maps cover all live variables
    all_vars = reduce(operator.or_, live_map.values(), set())
    internal_dead_vars = reduce(operator.or_, internal_dead_map.values(),
                                set())
    escaping_dead_vars = reduce(operator.or_, escaping_dead_map.values(),
                                set())
    exit_dead_vars = reduce(operator.or_, exit_dead_map.values(), set())
    dead_vars = (internal_dead_vars | escaping_dead_vars | exit_dead_vars)
    missing_vars = all_vars - dead_vars
    if missing_vars:
        # There are no exit points
        if not cfg.exit_points():
            # We won't be able to verify this
            pass
        else:
            msg = 'liveness info missing for vars: {0}'.format(missing_vars)
            raise RuntimeError(msg)

    combined = dict((k, internal_dead_map[k] | escaping_dead_map[k])
                    for k in blocks)

    return _dead_maps_result(internal=internal_dead_map,
                             escaping=escaping_dead_map,
                             combined=combined)


def compute_live_variables(cfg, blocks, var_def_map, var_dead_map):
    """
    Compute the live variables at the beginning of each block
    and at each yield point.
    The ``var_def_map`` and ``var_dead_map`` indicates the variable defined
    and deleted at each block, respectively.

    Returns
    -------
    defaultdict
        Mapping of block offset -> set of variable names available at the
        entry of the block (the entry block gets an empty set).
    """
    # live var at the entry per block
    block_entry_vars = defaultdict(set)

    def fix_point_progress():
        return tuple(map(len, block_entry_vars.values()))

    old_point = None
    new_point = fix_point_progress()

    # Propagate defined variables and still live the successors.
    # (note the entry block automatically gets an empty set)

    # Note: This is finding the actual available variables at the entry
    # of each block.  The algorithm in compute_live_map() is finding
    # the variable that must be available at the entry of each block.
    # This is top-down in the dataflow.  The other one is bottom-up.
    while old_point != new_point:
        # We iterate until the result stabilizes.  This is necessary
        # because of loops in the graph.
        for offset in blocks:
            # vars available + variable defined
            avail = block_entry_vars[offset] | var_def_map[offset]
            # subtract variables deleted
            avail -= var_dead_map[offset]
            # add ``avail`` to each successors
            for succ, _data in cfg.successors(offset):
                block_entry_vars[succ] |= avail

        old_point = new_point
        new_point = fix_point_progress()

    return block_entry_vars


#
# Analysis related to controlflow
#

def compute_cfg_from_blocks(blocks):
    """
    Build and process a ``CFGraph`` from a mapping of block offset ->
    ``ir.Block``.  Edges come from each block terminator's targets and the
    entry point is the smallest offset.
    """
    cfg = CFGraph()
    for k in blocks:
        cfg.add_node(k)

    for k, b in blocks.items():
        term = b.terminator
        for target in term.get_targets():
            cfg.add_edge(k, target)

    cfg.set_entry_point(min(blocks))
    cfg.process()
    return cfg


def find_top_level_loops(cfg):
    """
    A generator that yields toplevel loops given a control-flow-graph

    A loop is top-level when its header is not part of the body, entries
    or exits of any other loop.  Each yielded loop has been passed through
    ``_fix_loop_exit``.
    """
    blocks_in_loop = set()
    # get loop bodies
    for loop in cfg.loops().values():
        insiders = set(loop.body) | set(loop.entries) | set(loop.exits)
        insiders.discard(loop.header)
        blocks_in_loop |= insiders
    # find loop that is not part of other loops
    for loop in cfg.loops().values():
        if loop.header not in blocks_in_loop:
            yield _fix_loop_exit(cfg, loop)


def _fix_loop_exit(cfg, loop):
    """
    Fixes loop.exits for Py3.8 bytecode CFG changes.
    This is to handle `break` inside loops.

    The new exits are those original exits that post-dominate every exit;
    the remaining former exits are moved into the loop body.  If no such
    common exit exists the loop is returned unchanged.
    """
    # Computes the common postdoms of exit nodes
    postdoms = cfg.post_dominators()
    exits = reduce(
        operator.and_,
        [postdoms[b] for b in loop.exits],
        loop.exits,
    )
    if exits:
        # Put the non-common-exits as body nodes
        body = loop.body | (loop.exits - exits)
        return loop._replace(exits=exits, body=body)
    else:
        return loop


# Used to describe a nullified condition in dead branch pruning
nullified = namedtuple('nullified', 'condition, taken_br, rewrite_stmt')


# Functions to manipulate IR
def dead_branch_prune(func_ir, called_args):
    """
    Removes dead branches based on constant inference from function args.
    This directly mutates the IR.

    func_ir is the IR
    called_args are the actual arguments with which the function is called
    """
    from numba.core.ir_utils import (get_definition, guard, find_const,
                                     GuardException)

    DEBUG = 0

    def find_branches(func_ir):
        # find *all* branches whose condition is ``bool(<condition>)``;
        # returns a list of (branch, condition definition, block) tuples.
        branches = []
        for blk in func_ir.blocks.values():
            branch_or_jump = blk.body[-1]
            if isinstance(branch_or_jump, ir.Branch):
                branch = branch_or_jump
                pred = guard(get_definition, func_ir, branch.cond.name)
                # The definition need not be an ir.Expr (e.g. ir.Arg or
                # ir.Const carry no ``op``), so look the attribute up safely.
                if pred is not None and getattr(pred, 'op', None) == "call":
                    function = guard(get_definition, func_ir, pred.func)
                    if (function is not None and
                            isinstance(function, ir.Global) and
                            function.value is bool):
                        condition = guard(get_definition, func_ir,
                                          pred.args[0])
                        if condition is not None:
                            branches.append((branch, condition, blk))
        return branches

    def do_prune(branch, take_truebr, blk):
        # Replace ``branch`` (the terminator of ``blk``) with a direct jump
        # to the kept target.  Returns 1 if the kept target is the true
        # branch target, else 0.
        keep = branch.truebr if take_truebr else branch.falsebr
        jmp = ir.Jump(keep, loc=branch.loc)
        blk.body[-1] = jmp
        return 1 if keep == branch.truebr else 0

    def prune_by_type(branch, condition, blk, *conds):
        # this prunes a given branch and fixes up the IR
        # at least one needs to be a NoneType
        lhs_cond, rhs_cond = conds
        lhs_none = isinstance(lhs_cond, types.NoneType)
        rhs_none = isinstance(rhs_cond, types.NoneType)
        if lhs_none or rhs_none:
            try:
                take_truebr = condition.fn(lhs_cond, rhs_cond)
            except Exception:
                return False, None
            if DEBUG > 0:
                kill = branch.falsebr if take_truebr else branch.truebr
                print("Pruning %s" % kill, branch, lhs_cond, rhs_cond,
                      condition.fn)
            taken = do_prune(branch, take_truebr, blk)
            return True, taken
        return False, None

    def prune_by_value(branch, condition, blk, *conds):
        # prune when both operands of the binop resolved to constants
        lhs_cond, rhs_cond = conds
        try:
            take_truebr = condition.fn(lhs_cond, rhs_cond)
        except Exception:
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, lhs_cond, rhs_cond, condition.fn)
        taken = do_prune(branch, take_truebr, blk)
        return True, taken

    def prune_by_predicate(branch, pred, blk):
        # prune a branch on a constant predicate (e.g. ``if GLOBAL:``)
        try:
            # Just to prevent accidents, whilst already guarded, ensure this
            # is an ir.Const
            if not isinstance(pred, (ir.Const, ir.FreeVar, ir.Global)):
                raise TypeError('Expected constant Numba IR node')
            take_truebr = bool(pred.value)
        except Exception:
            # Truth-testing arbitrary constants can fail in many ways (e.g.
            # ValueError for multi-element arrays); in that case simply do
            # not prune, matching prune_by_value/prune_by_type.
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, pred)
        taken = do_prune(branch, take_truebr, blk)
        return True, taken

    class Unknown(object):
        # Sentinel for "could not resolve to a constant" (None is a valid
        # resolved value, so it cannot be used as the sentinel).
        pass

    def resolve_input_arg_const(input_arg_idx):
        """
        Resolves an input arg to a constant (if possible)
        """
        input_arg_ty = called_args[input_arg_idx]

        # comparing to None?
        if isinstance(input_arg_ty, types.NoneType):
            return input_arg_ty

        # is it a kwarg default
        if isinstance(input_arg_ty, types.Omitted):
            val = input_arg_ty.value
            if isinstance(val, types.NoneType):
                return val
            elif val is None:
                return types.NoneType('none')

        # literal type, return the type itself so comparisons like `x == None`
        # still work as e.g. x = types.int64 will never be None/NoneType so
        # the branch can still be pruned
        return getattr(input_arg_ty, 'literal_type', Unknown())

    if DEBUG > 1:
        print("before".center(80, '-'))
        func_ir.dump()

    # This looks for branches where:
    # at least one arg of the condition is in input args and const
    # at least one an arg of the condition is a const
    # if the condition is met it will replace the branch with a jump
    branch_info = find_branches(func_ir)
    # stores conditions that have no impact post prune
    nullified_conditions = []

    for branch, condition, blk in branch_info:
        const_conds = []
        if isinstance(condition, ir.Expr) and condition.op == 'binop':
            prune = prune_by_value
            for arg in [condition.lhs, condition.rhs]:
                resolved_const = Unknown()
                arg_def = guard(get_definition, func_ir, arg)
                if isinstance(arg_def, ir.Arg):
                    # it's an e.g. literal argument to the function
                    resolved_const = resolve_input_arg_const(arg_def.index)
                    prune = prune_by_type
                else:
                    # it's some const argument to the function, cannot use guard
                    # here as the const itself may be None
                    try:
                        resolved_const = find_const(func_ir, arg)
                        if resolved_const is None:
                            resolved_const = types.NoneType('none')
                    except GuardException:
                        pass

                if not isinstance(resolved_const, Unknown):
                    const_conds.append(resolved_const)

            # lhs/rhs are consts
            if len(const_conds) == 2:
                # prune the branch, switch the branch for an unconditional jump
                prune_stat, taken = prune(branch, condition, blk, *const_conds)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append(nullified(condition, taken,
                                                          True))
        else:
            # see if this is a branch on a constant value predicate
            resolved_const = Unknown()
            try:
                pred_call = get_definition(func_ir, branch.cond)
                resolved_const = find_const(func_ir, pred_call.args[0])
                if resolved_const is None:
                    resolved_const = types.NoneType('none')
            except GuardException:
                pass

            if not isinstance(resolved_const, Unknown):
                prune_stat, taken = prune_by_predicate(branch, condition, blk)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append(nullified(condition, taken,
                                                          False))

    # 'ERE BE DRAGONS...
    # It is the evaluation of the condition expression that often trips up type
    # inference, so ideally it would be removed as it is effectively rendered
    # dead by the unconditional jump if a branch was pruned. However, there may
    # be references to the condition that exist in multiple places (e.g. dels)
    # and we cannot run DCE here as typing has not taken place to give enough
    # information to run DCE safely. Upshot of all this is the condition gets
    # rewritten below into a benign const that typing will be happy with and DCE
    # can remove it and its reference post typing when it is safe to do so
    # (if desired). It is required that the const is assigned a value that
    # indicates the branch taken as its mutated value would be read in the case
    # of object mode fall back in place of the condition itself. For
    # completeness the func_ir._definitions and ._consts are also updated to
    # make the IR state self consistent.

    deadcond = [x.condition for x in nullified_conditions]
    for _, cond, blk in branch_info:
        if cond in deadcond:
            for x in blk.body:
                if isinstance(x, ir.Assign) and x.value is cond:
                    # rewrite the condition as a true/false bit
                    nullified_info = nullified_conditions[deadcond.index(cond)]
                    # only do a rewrite of conditions, predicates need to retain
                    # their value as they may be used later.
                    if nullified_info.rewrite_stmt:
                        branch_bit = nullified_info.taken_br
                        x.value = ir.Const(branch_bit, loc=x.loc)
                        # update the specific definition to the new const
                        defns = func_ir._definitions[x.target.name]
                        repl_idx = defns.index(cond)
                        defns[repl_idx] = x.value

    # Remove dead blocks, this is safe as it relies on the CFG only.
    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for dead in cfg.dead_nodes():
        del func_ir.blocks[dead]

    # if conditions were nullified then consts were rewritten, update
    if nullified_conditions:
        func_ir._consts = consts.ConstantInference(func_ir)

    if DEBUG > 1:
        print("after".center(80, '-'))
        func_ir.dump()


def rewrite_semantic_constants(func_ir, called_args):
    """
    This rewrites values known to be constant by their semantics as ir.Const
    nodes, this is to give branch pruning the best chance possible of killing
    branches. An example might be rewriting len(tuple) as the literal length.

    func_ir is the IR
    called_args are the actual arguments with which the function is called
    """
    from numba.core.ir_utils import get_definition, guard

    DEBUG = 0

    if DEBUG > 1:
        print(("rewrite_semantic_constants: " +
               func_ir.func_id.func_name).center(80, '-'))
        print("before".center(80, '*'))
        func_ir.dump()

    def rewrite_statement(func_ir, stmt, new_val):
        """
        Rewrites the stmt as a ir.Const new_val and fixes up the entries in
        func_ir._definitions
        """
        old_val = stmt.value
        stmt.value = ir.Const(new_val, stmt.loc)
        defns = func_ir._definitions[stmt.target.name]
        repl_idx = defns.index(old_val)
        defns[repl_idx] = stmt.value

    def rewrite_array_ndim(stmt, func_ir, called_args):
        # rewrite Array.ndim as const(ndim)
        val = stmt.value
        if getattr(val, 'op', None) == 'getattr':
            if val.attr == 'ndim':
                arg_def = guard(get_definition, func_ir, val.value)
                if isinstance(arg_def, ir.Arg):
                    argty = called_args[arg_def.index]
                    if isinstance(argty, types.Array):
                        rewrite_statement(func_ir, stmt, argty.ndim)

    def rewrite_tuple_len(stmt, func_ir, called_args):
        # rewrite len(tuple) as const(len(tuple))
        val = stmt.value
        if getattr(val, 'op', None) == 'call':
            func = guard(get_definition, func_ir, val.func)
            if (func is not None and isinstance(func, ir.Global) and
                    getattr(func, 'value', None) is len):
                # Leave malformed calls such as ``len()`` or ``len(a, b)``
                # alone so that typing reports a proper error.
                if len(val.args) != 1:
                    return
                (arg,) = val.args
                arg_def = guard(get_definition, func_ir, arg)
                if isinstance(arg_def, ir.Arg):
                    argty = called_args[arg_def.index]
                    if isinstance(argty, types.BaseTuple):
                        rewrite_statement(func_ir, stmt, argty.count)

    for blk in func_ir.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr):
                # Each helper re-reads ``stmt.value``; once one rewrites it
                # to an ir.Const the later helper no longer matches.
                rewrite_array_ndim(stmt, func_ir, called_args)
                rewrite_tuple_len(stmt, func_ir, called_args)

    if DEBUG > 1:
        print("after".center(80, '*'))
        func_ir.dump()
        print('-' * 80)


def find_literally_calls(func_ir, argtypes):
    """An analysis to find `numba.literally` call inside the given IR.
    When an unsatisfied literal typing request is found, a `ForceLiteralArg`
    exception is raised.

    Parameters
    ----------

    func_ir : numba.ir.FunctionIR

    argtypes : Sequence[numba.types.Type]
        The argument types.
    """
    from numba.core import ir_utils

    marked_args = set()
    first_loc = {}
    # Scan for literally calls
    for blk in func_ir.blocks.values():
        for assign in blk.find_exprs(op='call'):
            var = ir_utils.guard(ir_utils.get_definition, func_ir, assign.func)
            if isinstance(var, (ir.Global, ir.FreeVar)):
                fnobj = var.value
            else:
                fnobj = ir_utils.guard(ir_utils.resolve_func_from_module,
                                       func_ir, var)
            if fnobj is special.literally:
                # Found
                [arg] = assign.args
                defarg = func_ir.get_definition(arg)
                if isinstance(defarg, ir.Arg):
                    argindex = defarg.index
                    marked_args.add(argindex)
                    # remember the first call site for error reporting
                    first_loc.setdefault(argindex, assign.loc)
    # Signal the dispatcher to force literal typing
    for pos in marked_args:
        query_arg = argtypes[pos]
        do_raise = (isinstance(query_arg, types.InitialValue) and
                    query_arg.initial_value is None)
        if do_raise:
            loc = first_loc[pos]
            raise errors.ForceLiteralArg(marked_args, loc=loc)

        if not isinstance(query_arg, (types.Literal, types.InitialValue)):
            loc = first_loc[pos]
            raise errors.ForceLiteralArg(marked_args, loc=loc)