1# 2# Copyright (c) 2017 Intel Corporation 3# SPDX-License-Identifier: BSD-2-Clause 4# 5 6import numpy 7 8import types as pytypes 9import collections 10import operator 11import warnings 12 13from llvmlite import ir as lir 14 15import numba 16from numba.core.extending import _Intrinsic 17from numba.core import types, utils, typing, ir, analysis, postproc, rewrites, config, cgutils 18from numba.core.typing.templates import (signature, infer_global, 19 AbstractTemplate) 20from numba.core.imputils import impl_ret_untracked 21from numba.core.analysis import (compute_live_map, compute_use_defs, 22 compute_cfg_from_blocks) 23from numba.core.errors import (TypingError, UnsupportedError, 24 NumbaPendingDeprecationWarning, NumbaWarning, 25 feedback_details, CompilerError) 26 27import copy 28 29_unique_var_count = 0 30 31 32def mk_unique_var(prefix): 33 global _unique_var_count 34 var = prefix + "." + str(_unique_var_count) 35 _unique_var_count = _unique_var_count + 1 36 return var 37 38 39_max_label = 0 40 41 42def get_unused_var_name(prefix, var_table): 43 """ Get a new var name with a given prefix and 44 make sure it is unused in the given variable table. 45 """ 46 cur = 0 47 while True: 48 var = prefix + str(cur) 49 if var not in var_table: 50 return var 51 cur += 1 52 53 54def next_label(): 55 global _max_label 56 _max_label += 1 57 return _max_label 58 59 60def mk_alloc(typemap, calltypes, lhs, size_var, dtype, scope, loc): 61 """generate an array allocation with np.empty() and return list of nodes. 62 size_var can be an int variable or tuple of int variables. 63 """ 64 out = [] 65 ndims = 1 66 size_typ = types.intp 67 if isinstance(size_var, tuple): 68 if len(size_var) == 1: 69 size_var = size_var[0] 70 size_var = convert_size_to_var(size_var, typemap, scope, loc, out) 71 else: 72 # tuple_var = build_tuple([size_var...]) 73 ndims = len(size_var) 74 tuple_var = ir.Var(scope, mk_unique_var("$tuple_var"), loc) 75 if typemap: 76 typemap[tuple_var.name] = types.containers.UniTuple( 77 types.intp, ndims) 78 # constant sizes need to be assigned to vars 79 new_sizes = [convert_size_to_var(s, typemap, scope, loc, out) 80 for s in size_var] 81 tuple_call = ir.Expr.build_tuple(new_sizes, loc) 82 tuple_assign = ir.Assign(tuple_call, tuple_var, loc) 83 out.append(tuple_assign) 84 size_var = tuple_var 85 size_typ = types.containers.UniTuple(types.intp, ndims) 86 # g_np_var = Global(numpy) 87 g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc) 88 if typemap: 89 typemap[g_np_var.name] = types.misc.Module(numpy) 90 g_np = ir.Global('np', numpy, loc) 91 g_np_assign = ir.Assign(g_np, g_np_var, loc) 92 # attr call: empty_attr = getattr(g_np_var, empty) 93 empty_attr_call = ir.Expr.getattr(g_np_var, "empty", loc) 94 attr_var = ir.Var(scope, mk_unique_var("$empty_attr_attr"), loc) 95 if typemap: 96 typemap[attr_var.name] = get_np_ufunc_typ(numpy.empty) 97 attr_assign = ir.Assign(empty_attr_call, attr_var, loc) 98 # alloc call: lhs = empty_attr(size_var, typ_var) 99 typ_var = ir.Var(scope, mk_unique_var("$np_typ_var"), loc) 100 if typemap: 101 typemap[typ_var.name] = types.functions.NumberClass(dtype) 102 # assuming str(dtype) returns valid np dtype string 103 dtype_str = str(dtype) 104 if dtype_str=='bool': 105 # empty doesn't like 'bool' sometimes (e.g. kmeans example) 106 dtype_str = 'bool_' 107 np_typ_getattr = ir.Expr.getattr(g_np_var, dtype_str, loc) 108 typ_var_assign = ir.Assign(np_typ_getattr, typ_var, loc) 109 alloc_call = ir.Expr.call(attr_var, [size_var, typ_var], (), loc) 110 if calltypes: 111 calltypes[alloc_call] = typemap[attr_var.name].get_call_type( 112 typing.Context(), [size_typ, types.functions.NumberClass(dtype)], {}) 113 # signature( 114 # types.npytypes.Array(dtype, ndims, 'C'), size_typ, 115 # types.functions.NumberClass(dtype)) 116 alloc_assign = ir.Assign(alloc_call, lhs, loc) 117 118 out.extend([g_np_assign, attr_assign, typ_var_assign, alloc_assign]) 119 return out 120 121 122def convert_size_to_var(size_var, typemap, scope, loc, nodes): 123 if isinstance(size_var, int): 124 new_size = ir.Var(scope, mk_unique_var("$alloc_size"), loc) 125 if typemap: 126 typemap[new_size.name] = types.intp 127 size_assign = ir.Assign(ir.Const(size_var, loc), new_size, loc) 128 nodes.append(size_assign) 129 return new_size 130 assert isinstance(size_var, ir.Var) 131 return size_var 132 133 134def get_np_ufunc_typ(func): 135 """get type of the incoming function from builtin registry""" 136 for (k, v) in typing.npydecl.registry.globals: 137 if k == func: 138 return v 139 raise RuntimeError("type for func ", func, " not found") 140 141 142def mk_range_block(typemap, start, stop, step, calltypes, scope, loc): 143 """make a block that initializes loop range and iteration variables. 144 target label in jump needs to be set. 145 """ 146 # g_range_var = Global(range) 147 g_range_var = ir.Var(scope, mk_unique_var("$range_g_var"), loc) 148 typemap[g_range_var.name] = get_global_func_typ(range) 149 g_range = ir.Global('range', range, loc) 150 g_range_assign = ir.Assign(g_range, g_range_var, loc) 151 arg_nodes, args = _mk_range_args(typemap, start, stop, step, scope, loc) 152 # range_call_var = call g_range_var(start, stop, step) 153 range_call = ir.Expr.call(g_range_var, args, (), loc) 154 calltypes[range_call] = typemap[g_range_var.name].get_call_type( 155 typing.Context(), [types.intp] * len(args), {}) 156 #signature(types.range_state64_type, types.intp) 157 range_call_var = ir.Var(scope, mk_unique_var("$range_c_var"), loc) 158 typemap[range_call_var.name] = types.iterators.RangeType(types.intp) 159 range_call_assign = ir.Assign(range_call, range_call_var, loc) 160 # iter_var = getiter(range_call_var) 161 iter_call = ir.Expr.getiter(range_call_var, loc) 162 calltypes[iter_call] = signature(types.range_iter64_type, 163 types.range_state64_type) 164 iter_var = ir.Var(scope, mk_unique_var("$iter_var"), loc) 165 typemap[iter_var.name] = types.iterators.RangeIteratorType(types.intp) 166 iter_call_assign = ir.Assign(iter_call, iter_var, loc) 167 # $phi = iter_var 168 phi_var = ir.Var(scope, mk_unique_var("$phi"), loc) 169 typemap[phi_var.name] = types.iterators.RangeIteratorType(types.intp) 170 phi_assign = ir.Assign(iter_var, phi_var, loc) 171 # jump to header 172 jump_header = ir.Jump(-1, loc) 173 range_block = ir.Block(scope, loc) 174 range_block.body = arg_nodes + [g_range_assign, range_call_assign, 175 iter_call_assign, phi_assign, jump_header] 176 return range_block 177 178 179def _mk_range_args(typemap, start, stop, step, scope, loc): 180 nodes = [] 181 if isinstance(stop, ir.Var): 182 g_stop_var = stop 183 else: 184 assert isinstance(stop, int) 185 g_stop_var = ir.Var(scope, mk_unique_var("$range_stop"), loc) 186 if typemap: 187 typemap[g_stop_var.name] = types.intp 188 stop_assign = ir.Assign(ir.Const(stop, loc), g_stop_var, loc) 189 nodes.append(stop_assign) 190 if start == 0 and step == 1: 191 return nodes, [g_stop_var] 192 193 if isinstance(start, ir.Var): 194 g_start_var = start 195 else: 196 assert isinstance(start, int) 197 g_start_var = ir.Var(scope, mk_unique_var("$range_start"), loc) 198 if typemap: 199 typemap[g_start_var.name] = types.intp 200 start_assign = ir.Assign(ir.Const(start, loc), g_start_var, loc) 201 nodes.append(start_assign) 202 if step == 1: 203 return nodes, [g_start_var, g_stop_var] 204 205 if isinstance(step, ir.Var): 206 g_step_var = step 207 else: 208 assert isinstance(step, int) 209 g_step_var = ir.Var(scope, mk_unique_var("$range_step"), loc) 210 if typemap: 211 typemap[g_step_var.name] = types.intp 212 step_assign = ir.Assign(ir.Const(step, loc), g_step_var, loc) 213 nodes.append(step_assign) 214 215 return nodes, [g_start_var, g_stop_var, g_step_var] 216 217 218def get_global_func_typ(func): 219 """get type variable for func() from builtin registry""" 220 for (k, v) in typing.templates.builtin_registry.globals: 221 if k == func: 222 return v 223 raise RuntimeError("func type not found {}".format(func)) 224 225 226def mk_loop_header(typemap, phi_var, calltypes, scope, loc): 227 """make a block that is a loop header updating iteration variables. 228 target labels in branch need to be set. 229 """ 230 # iternext_var = iternext(phi_var) 231 iternext_var = ir.Var(scope, mk_unique_var("$iternext_var"), loc) 232 typemap[iternext_var.name] = types.containers.Pair( 233 types.intp, types.boolean) 234 iternext_call = ir.Expr.iternext(phi_var, loc) 235 calltypes[iternext_call] = signature( 236 types.containers.Pair( 237 types.intp, 238 types.boolean), 239 types.range_iter64_type) 240 iternext_assign = ir.Assign(iternext_call, iternext_var, loc) 241 # pair_first_var = pair_first(iternext_var) 242 pair_first_var = ir.Var(scope, mk_unique_var("$pair_first_var"), loc) 243 typemap[pair_first_var.name] = types.intp 244 pair_first_call = ir.Expr.pair_first(iternext_var, loc) 245 pair_first_assign = ir.Assign(pair_first_call, pair_first_var, loc) 246 # pair_second_var = pair_second(iternext_var) 247 pair_second_var = ir.Var(scope, mk_unique_var("$pair_second_var"), loc) 248 typemap[pair_second_var.name] = types.boolean 249 pair_second_call = ir.Expr.pair_second(iternext_var, loc) 250 pair_second_assign = ir.Assign(pair_second_call, pair_second_var, loc) 251 # phi_b_var = pair_first_var 252 phi_b_var = ir.Var(scope, mk_unique_var("$phi"), loc) 253 typemap[phi_b_var.name] = types.intp 254 phi_b_assign = ir.Assign(pair_first_var, phi_b_var, loc) 255 # branch pair_second_var body_block out_block 256 branch = ir.Branch(pair_second_var, -1, -1, loc) 257 header_block = ir.Block(scope, loc) 258 header_block.body = [iternext_assign, pair_first_assign, 259 pair_second_assign, phi_b_assign, branch] 260 return header_block 261 262 263def legalize_names(varnames): 264 """returns a dictionary for conversion of variable names to legal 265 parameter names. 266 """ 267 var_map = {} 268 for var in varnames: 269 new_name = var.replace("_", "__").replace("$", "_").replace(".", "_") 270 assert new_name not in var_map 271 var_map[var] = new_name 272 return var_map 273 274 275def get_name_var_table(blocks): 276 """create a mapping from variable names to their ir.Var objects""" 277 def get_name_var_visit(var, namevar): 278 namevar[var.name] = var 279 return var 280 namevar = {} 281 visit_vars(blocks, get_name_var_visit, namevar) 282 return namevar 283 284 285def replace_var_names(blocks, namedict): 286 """replace variables (ir.Var to ir.Var) from dictionary (name -> name)""" 287 # remove identity values to avoid infinite loop 288 new_namedict = {} 289 for l, r in namedict.items(): 290 if l != r: 291 new_namedict[l] = r 292 293 def replace_name(var, namedict): 294 assert isinstance(var, ir.Var) 295 while var.name in namedict: 296 var = ir.Var(var.scope, namedict[var.name], var.loc) 297 return var 298 visit_vars(blocks, replace_name, new_namedict) 299 300 301def replace_var_callback(var, vardict): 302 assert isinstance(var, ir.Var) 303 while var.name in vardict.keys(): 304 assert(vardict[var.name].name != var.name) 305 new_var = vardict[var.name] 306 var = ir.Var(new_var.scope, new_var.name, new_var.loc) 307 return var 308 309 310def replace_vars(blocks, vardict): 311 """replace variables (ir.Var to ir.Var) from dictionary (name -> ir.Var)""" 312 # remove identity values to avoid infinite loop 313 new_vardict = {} 314 for l, r in vardict.items(): 315 if l != r.name: 316 new_vardict[l] = r 317 visit_vars(blocks, replace_var_callback, new_vardict) 318 319 320def replace_vars_stmt(stmt, vardict): 321 visit_vars_stmt(stmt, replace_var_callback, vardict) 322 323 324def replace_vars_inner(node, vardict): 325 return visit_vars_inner(node, replace_var_callback, vardict) 326 327 328# other packages that define new nodes add calls to visit variables in them 329# format: {type:function} 330visit_vars_extensions = {} 331 332 333def visit_vars(blocks, callback, cbdata): 334 """go over statements of block bodies and replace variable names with 335 dictionary. 336 """ 337 for block in blocks.values(): 338 for stmt in block.body: 339 visit_vars_stmt(stmt, callback, cbdata) 340 return 341 342 343def visit_vars_stmt(stmt, callback, cbdata): 344 # let external calls handle stmt if type matches 345 for t, f in visit_vars_extensions.items(): 346 if isinstance(stmt, t): 347 f(stmt, callback, cbdata) 348 return 349 if isinstance(stmt, ir.Assign): 350 stmt.target = visit_vars_inner(stmt.target, callback, cbdata) 351 stmt.value = visit_vars_inner(stmt.value, callback, cbdata) 352 elif isinstance(stmt, ir.Arg): 353 stmt.name = visit_vars_inner(stmt.name, callback, cbdata) 354 elif isinstance(stmt, ir.Return): 355 stmt.value = visit_vars_inner(stmt.value, callback, cbdata) 356 elif isinstance(stmt, ir.Raise): 357 stmt.exception = visit_vars_inner(stmt.exception, callback, cbdata) 358 elif isinstance(stmt, ir.Branch): 359 stmt.cond = visit_vars_inner(stmt.cond, callback, cbdata) 360 elif isinstance(stmt, ir.Jump): 361 stmt.target = visit_vars_inner(stmt.target, callback, cbdata) 362 elif isinstance(stmt, ir.Del): 363 # Because Del takes only a var name, we make up by 364 # constructing a temporary variable. 365 var = ir.Var(None, stmt.value, stmt.loc) 366 var = visit_vars_inner(var, callback, cbdata) 367 stmt.value = var.name 368 elif isinstance(stmt, ir.DelAttr): 369 stmt.target = visit_vars_inner(stmt.target, callback, cbdata) 370 stmt.attr = visit_vars_inner(stmt.attr, callback, cbdata) 371 elif isinstance(stmt, ir.SetAttr): 372 stmt.target = visit_vars_inner(stmt.target, callback, cbdata) 373 stmt.attr = visit_vars_inner(stmt.attr, callback, cbdata) 374 stmt.value = visit_vars_inner(stmt.value, callback, cbdata) 375 elif isinstance(stmt, ir.DelItem): 376 stmt.target = visit_vars_inner(stmt.target, callback, cbdata) 377 stmt.index = visit_vars_inner(stmt.index, callback, cbdata) 378 elif isinstance(stmt, ir.StaticSetItem): 379 stmt.target = visit_vars_inner(stmt.target, callback, cbdata) 380 stmt.index_var = visit_vars_inner(stmt.index_var, callback, cbdata) 381 stmt.value = visit_vars_inner(stmt.value, callback, cbdata) 382 elif isinstance(stmt, ir.SetItem): 383 stmt.target = visit_vars_inner(stmt.target, callback, cbdata) 384 stmt.index = visit_vars_inner(stmt.index, callback, cbdata) 385 stmt.value = visit_vars_inner(stmt.value, callback, cbdata) 386 elif isinstance(stmt, ir.Print): 387 stmt.args = [visit_vars_inner(x, callback, cbdata) for x in stmt.args] 388 else: 389 # TODO: raise NotImplementedError("no replacement for IR node: ", stmt) 390 pass 391 return 392 393 394def visit_vars_inner(node, callback, cbdata): 395 if isinstance(node, ir.Var): 396 return callback(node, cbdata) 397 elif isinstance(node, list): 398 return [visit_vars_inner(n, callback, cbdata) for n in node] 399 elif isinstance(node, tuple): 400 return tuple([visit_vars_inner(n, callback, cbdata) for n in node]) 401 elif isinstance(node, ir.Expr): 402 # if node.op in ['binop', 'inplace_binop']: 403 # lhs = node.lhs.name 404 # rhs = node.rhs.name 405 # node.lhs.name = callback, cbdata.get(lhs, lhs) 406 # node.rhs.name = callback, cbdata.get(rhs, rhs) 407 for arg in node._kws.keys(): 408 node._kws[arg] = visit_vars_inner(node._kws[arg], callback, cbdata) 409 elif isinstance(node, ir.Yield): 410 node.value = visit_vars_inner(node.value, callback, cbdata) 411 return node 412 413 414add_offset_to_labels_extensions = {} 415 416 417def add_offset_to_labels(blocks, offset): 418 """add an offset to all block labels and jump/branch targets 419 """ 420 new_blocks = {} 421 for l, b in blocks.items(): 422 # some parfor last blocks might be empty 423 term = None 424 if b.body: 425 term = b.body[-1] 426 for inst in b.body: 427 for T, f in add_offset_to_labels_extensions.items(): 428 if isinstance(inst, T): 429 f_max = f(inst, offset) 430 if isinstance(term, ir.Jump): 431 b.body[-1] = ir.Jump(term.target + offset, term.loc) 432 if isinstance(term, ir.Branch): 433 b.body[-1] = ir.Branch(term.cond, term.truebr + offset, 434 term.falsebr + offset, term.loc) 435 new_blocks[l + offset] = b 436 return new_blocks 437 438 439find_max_label_extensions = {} 440 441 442def find_max_label(blocks): 443 max_label = 0 444 for l, b in blocks.items(): 445 term = None 446 if b.body: 447 term = b.body[-1] 448 for inst in b.body: 449 for T, f in find_max_label_extensions.items(): 450 if isinstance(inst, T): 451 f_max = f(inst) 452 if f_max > max_label: 453 max_label = f_max 454 if l > max_label: 455 max_label = l 456 return max_label 457 458 459def flatten_labels(blocks): 460 """makes the labels in range(0, len(blocks)), useful to compare CFGs 461 """ 462 # first bulk move the labels out of the rewrite range 463 blocks = add_offset_to_labels(blocks, find_max_label(blocks) + 1) 464 # order them in topo order because it's easier to read 465 new_blocks = {} 466 topo_order = find_topo_order(blocks) 467 l_map = dict() 468 idx = 0 469 for x in topo_order: 470 l_map[x] = idx 471 idx += 1 472 473 for t_node in topo_order: 474 b = blocks[t_node] 475 # some parfor last blocks might be empty 476 term = None 477 if b.body: 478 term = b.body[-1] 479 if isinstance(term, ir.Jump): 480 b.body[-1] = ir.Jump(l_map[term.target], term.loc) 481 if isinstance(term, ir.Branch): 482 b.body[-1] = ir.Branch(term.cond, l_map[term.truebr], 483 l_map[term.falsebr], term.loc) 484 new_blocks[l_map[t_node]] = b 485 return new_blocks 486 487 488def remove_dels(blocks): 489 """remove ir.Del nodes""" 490 for block in blocks.values(): 491 new_body = [] 492 for stmt in block.body: 493 if not isinstance(stmt, ir.Del): 494 new_body.append(stmt) 495 block.body = new_body 496 return 497 498 499def remove_args(blocks): 500 """remove ir.Arg nodes""" 501 for block in blocks.values(): 502 new_body = [] 503 for stmt in block.body: 504 if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Arg): 505 continue 506 new_body.append(stmt) 507 block.body = new_body 508 return 509 510 511def dead_code_elimination(func_ir, typemap=None, alias_map=None, 512 arg_aliases=None): 513 """ Performs dead code elimination and leaves the IR in a valid state on 514 exit 515 """ 516 do_post_proc = False 517 while (remove_dead(func_ir.blocks, func_ir.arg_names, func_ir, typemap, 518 alias_map, arg_aliases)): 519 do_post_proc = True 520 521 if do_post_proc: 522 post_proc = postproc.PostProcessor(func_ir) 523 post_proc.run() 524 525 526def remove_dead(blocks, args, func_ir, typemap=None, alias_map=None, arg_aliases=None): 527 """dead code elimination using liveness and CFG info. 528 Returns True if something has been removed, or False if nothing is removed. 529 """ 530 cfg = compute_cfg_from_blocks(blocks) 531 usedefs = compute_use_defs(blocks) 532 live_map = compute_live_map(cfg, blocks, usedefs.usemap, usedefs.defmap) 533 call_table, _ = get_call_table(blocks) 534 if alias_map is None or arg_aliases is None: 535 alias_map, arg_aliases = find_potential_aliases(blocks, args, typemap, 536 func_ir) 537 if config.DEBUG_ARRAY_OPT >= 1: 538 print("remove_dead alias map:", alias_map) 539 print("live_map:", live_map) 540 print("usemap:", usedefs.usemap) 541 print("defmap:", usedefs.defmap) 542 # keep set for easier search 543 alias_set = set(alias_map.keys()) 544 545 removed = False 546 for label, block in blocks.items(): 547 # find live variables at each statement to delete dead assignment 548 lives = {v.name for v in block.terminator.list_vars()} 549 if config.DEBUG_ARRAY_OPT >= 2: 550 print("remove_dead processing block", label, lives) 551 # find live variables at the end of block 552 for out_blk, _data in cfg.successors(label): 553 if config.DEBUG_ARRAY_OPT >= 2: 554 print("succ live_map", out_blk, live_map[out_blk]) 555 lives |= live_map[out_blk] 556 removed |= remove_dead_block(block, lives, call_table, arg_aliases, 557 alias_map, alias_set, func_ir, typemap) 558 559 return removed 560 561 562# other packages that define new nodes add calls to remove dead code in them 563# format: {type:function} 564remove_dead_extensions = {} 565 566 567def remove_dead_block(block, lives, call_table, arg_aliases, alias_map, 568 alias_set, func_ir, typemap): 569 """remove dead code using liveness info. 570 Mutable arguments (e.g. arrays) that are not definitely assigned are live 571 after return of function. 572 """ 573 # TODO: find mutable args that are not definitely assigned instead of 574 # assuming all args are live after return 575 removed = False 576 577 # add statements in reverse order 578 new_body = [block.terminator] 579 # for each statement in reverse order, excluding terminator 580 for stmt in reversed(block.body[:-1]): 581 if config.DEBUG_ARRAY_OPT >= 2: 582 print("remove_dead_block", stmt) 583 # aliases of lives are also live 584 alias_lives = set() 585 init_alias_lives = lives & alias_set 586 for v in init_alias_lives: 587 alias_lives |= alias_map[v] 588 lives_n_aliases = lives | alias_lives | arg_aliases 589 590 # let external calls handle stmt if type matches 591 if type(stmt) in remove_dead_extensions: 592 f = remove_dead_extensions[type(stmt)] 593 stmt = f(stmt, lives, lives_n_aliases, arg_aliases, alias_map, func_ir, 594 typemap) 595 if stmt is None: 596 if config.DEBUG_ARRAY_OPT >= 2: 597 print("Statement was removed.") 598 removed = True 599 continue 600 601 # ignore assignments that their lhs is not live or lhs==rhs 602 if isinstance(stmt, ir.Assign): 603 lhs = stmt.target 604 rhs = stmt.value 605 if lhs.name not in lives and has_no_side_effect( 606 rhs, lives_n_aliases, call_table): 607 if config.DEBUG_ARRAY_OPT >= 2: 608 print("Statement was removed.") 609 removed = True 610 continue 611 if isinstance(rhs, ir.Var) and lhs.name == rhs.name: 612 if config.DEBUG_ARRAY_OPT >= 2: 613 print("Statement was removed.") 614 removed = True 615 continue 616 # TODO: remove other nodes like SetItem etc. 617 618 if isinstance(stmt, ir.Del): 619 if stmt.value not in lives: 620 if config.DEBUG_ARRAY_OPT >= 2: 621 print("Statement was removed.") 622 removed = True 623 continue 624 625 if isinstance(stmt, ir.SetItem): 626 name = stmt.target.name 627 if name not in lives_n_aliases: 628 if config.DEBUG_ARRAY_OPT >= 2: 629 print("Statement was removed.") 630 continue 631 632 if type(stmt) in analysis.ir_extension_usedefs: 633 def_func = analysis.ir_extension_usedefs[type(stmt)] 634 uses, defs = def_func(stmt) 635 lives -= defs 636 lives |= uses 637 else: 638 lives |= {v.name for v in stmt.list_vars()} 639 if isinstance(stmt, ir.Assign): 640 lives.remove(lhs.name) 641 642 new_body.append(stmt) 643 new_body.reverse() 644 block.body = new_body 645 return removed 646 647# list of functions 648remove_call_handlers = [] 649 650def remove_dead_random_call(rhs, lives, call_list): 651 if len(call_list) == 3 and call_list[1:] == ['random', numpy]: 652 return call_list[0] not in {'seed', 'shuffle'} 653 return False 654 655remove_call_handlers.append(remove_dead_random_call) 656 657def has_no_side_effect(rhs, lives, call_table): 658 """ Returns True if this expression has no side effects that 659 would prevent re-ordering. 660 """ 661 from numba.parfors import array_analysis, parfor 662 from numba.misc.special import prange 663 if isinstance(rhs, ir.Expr) and rhs.op == 'call': 664 func_name = rhs.func.name 665 if func_name not in call_table or call_table[func_name] == []: 666 return False 667 call_list = call_table[func_name] 668 if (call_list == ['empty', numpy] or 669 call_list == [slice] or 670 call_list == ['stencil', numba] or 671 call_list == ['log', numpy] or 672 call_list == ['dtype', numpy] or 673 call_list == [array_analysis.wrap_index] or 674 call_list == [prange] or 675 call_list == ['prange', numba] or 676 call_list == [parfor.internal_prange]): 677 return True 678 elif (isinstance(call_list[0], _Intrinsic) and 679 (call_list[0]._name == 'empty_inferred' or 680 call_list[0]._name == 'unsafe_empty_inferred')): 681 return True 682 from numba.core.registry import CPUDispatcher 683 from numba.np.linalg import dot_3_mv_check_args 684 if isinstance(call_list[0], CPUDispatcher): 685 py_func = call_list[0].py_func 686 if py_func == dot_3_mv_check_args: 687 return True 688 for f in remove_call_handlers: 689 if f(rhs, lives, call_list): 690 return True 691 return False 692 if isinstance(rhs, ir.Expr) and rhs.op == 'inplace_binop': 693 return rhs.lhs.name not in lives 694 if isinstance(rhs, ir.Yield): 695 return False 696 if isinstance(rhs, ir.Expr) and rhs.op == 'pair_first': 697 # don't remove pair_first since prange looks for it 698 return False 699 return True 700 701is_pure_extensions = [] 702 703def is_pure(rhs, lives, call_table): 704 """ Returns True if every time this expression is evaluated it 705 returns the same result. This is not the case for things 706 like calls to numpy.random. 707 """ 708 if isinstance(rhs, ir.Expr): 709 if rhs.op == 'call': 710 func_name = rhs.func.name 711 if func_name not in call_table or call_table[func_name] == []: 712 return False 713 call_list = call_table[func_name] 714 if (call_list == [slice] or 715 call_list == ['log', numpy] or 716 call_list == ['empty', numpy]): 717 return True 718 for f in is_pure_extensions: 719 if f(rhs, lives, call_list): 720 return True 721 return False 722 elif rhs.op == 'getiter' or rhs.op == 'iternext': 723 return False 724 if isinstance(rhs, ir.Yield): 725 return False 726 return True 727 728def is_const_call(module_name, func_name): 729 # Returns True if there is no state in the given module changed by the given function. 730 if module_name == 'numpy': 731 if func_name in ['empty']: 732 return True 733 return False 734 735alias_analysis_extensions = {} 736alias_func_extensions = {} 737 738def find_potential_aliases(blocks, args, typemap, func_ir, alias_map=None, 739 arg_aliases=None): 740 "find all array aliases and argument aliases to avoid remove as dead" 741 if alias_map is None: 742 alias_map = {} 743 if arg_aliases is None: 744 arg_aliases = set(a for a in args if not is_immutable_type(a, typemap)) 745 746 # update definitions since they are not guaranteed to be up-to-date 747 # FIXME keep definitions up-to-date to avoid the need for rebuilding 748 func_ir._definitions = build_definitions(func_ir.blocks) 749 np_alias_funcs = ['ravel', 'transpose', 'reshape'] 750 751 for bl in blocks.values(): 752 for instr in bl.body: 753 if type(instr) in alias_analysis_extensions: 754 f = alias_analysis_extensions[type(instr)] 755 f(instr, args, typemap, func_ir, alias_map, arg_aliases) 756 if isinstance(instr, ir.Assign): 757 expr = instr.value 758 lhs = instr.target.name 759 # only mutable types can alias 760 if is_immutable_type(lhs, typemap): 761 continue 762 if isinstance(expr, ir.Var) and lhs!=expr.name: 763 _add_alias(lhs, expr.name, alias_map, arg_aliases) 764 # subarrays like A = B[0] for 2D B 765 if (isinstance(expr, ir.Expr) and (expr.op == 'cast' or 766 expr.op in ['getitem', 'static_getitem'])): 767 _add_alias(lhs, expr.value.name, alias_map, arg_aliases) 768 # array attributes like A.T 769 if (isinstance(expr, ir.Expr) and expr.op == 'getattr' 770 and expr.attr in ['T', 'ctypes', 'flat']): 771 _add_alias(lhs, expr.value.name, alias_map, arg_aliases) 772 # a = b.c. a should alias b 773 if (isinstance(expr, ir.Expr) and expr.op == 'getattr' 774 and expr.value.name in arg_aliases): 775 _add_alias(lhs, expr.value.name, alias_map, arg_aliases) 776 # calls that can create aliases such as B = A.ravel() 777 if isinstance(expr, ir.Expr) and expr.op == 'call': 778 fdef = guard(find_callname, func_ir, expr, typemap) 779 # TODO: sometimes gufunc backend creates duplicate code 780 # causing find_callname to fail. Example: test_argmax 781 # ignored here since those cases don't create aliases 782 # but should be fixed in general 783 if fdef is None: 784 continue 785 fname, fmod = fdef 786 if fdef in alias_func_extensions: 787 alias_func = alias_func_extensions[fdef] 788 alias_func(lhs, expr.args, alias_map, arg_aliases) 789 if fmod == 'numpy' and fname in np_alias_funcs: 790 _add_alias(lhs, expr.args[0].name, alias_map, arg_aliases) 791 if isinstance(fmod, ir.Var) and fname in np_alias_funcs: 792 _add_alias(lhs, fmod.name, alias_map, arg_aliases) 793 794 # copy to avoid changing size during iteration 795 old_alias_map = copy.deepcopy(alias_map) 796 # combine all aliases transitively 797 for v in old_alias_map: 798 for w in old_alias_map[v]: 799 alias_map[v] |= alias_map[w] 800 for w in old_alias_map[v]: 801 alias_map[w] = alias_map[v] 802 803 return alias_map, arg_aliases 804 805def _add_alias(lhs, rhs, alias_map, arg_aliases): 806 if rhs in arg_aliases: 807 arg_aliases.add(lhs) 808 else: 809 if rhs not in alias_map: 810 alias_map[rhs] = set() 811 if lhs not in alias_map: 812 alias_map[lhs] = set() 813 alias_map[rhs].add(lhs) 814 alias_map[lhs].add(rhs) 815 return 816 817def is_immutable_type(var, typemap): 818 # Conservatively, assume mutable if type not available 819 if typemap is None or var not in typemap: 820 return False 821 typ = typemap[var] 822 # TODO: add more immutable types 823 if isinstance(typ, (types.Number, types.scalars._NPDatetimeBase, 824 types.containers.BaseTuple, 825 types.iterators.RangeType)): 826 return True 827 if typ==types.string: 828 return True 829 # consevatively, assume mutable 830 return False 831 832def copy_propagate(blocks, typemap): 833 """compute copy propagation information for each block using fixed-point 834 iteration on data flow equations: 835 in_b = intersect(predec(B)) 836 out_b = gen_b | (in_b - kill_b) 837 """ 838 cfg = compute_cfg_from_blocks(blocks) 839 entry = cfg.entry_point() 840 841 # format: dict of block labels to copies as tuples 842 # label -> (l,r) 843 c_data = init_copy_propagate_data(blocks, entry, typemap) 844 (gen_copies, all_copies, kill_copies, in_copies, out_copies) = c_data 845 846 old_point = None 847 new_point = copy.deepcopy(out_copies) 848 # comparison works since dictionary of built-in types 849 while old_point != new_point: 850 for label in blocks.keys(): 851 if label == entry: 852 continue 853 predecs = [i for i, _d in cfg.predecessors(label)] 854 # in_b = intersect(predec(B)) 855 in_copies[label] = out_copies[predecs[0]].copy() 856 for p in predecs: 857 in_copies[label] &= out_copies[p] 858 859 # out_b = gen_b | (in_b - kill_b) 860 out_copies[label] = (gen_copies[label] 861 | (in_copies[label] - kill_copies[label])) 862 old_point = new_point 863 new_point = copy.deepcopy(out_copies) 864 if config.DEBUG_ARRAY_OPT >= 1: 865 print("copy propagate out_copies:", out_copies) 866 return in_copies, out_copies 867 868 869def init_copy_propagate_data(blocks, entry, typemap): 870 """get initial condition of copy propagation data flow for each block. 871 """ 872 # gen is all definite copies, extra_kill is additional ones that may hit 873 # for example, parfors can have control flow so they may hit extra copies 874 gen_copies, extra_kill = get_block_copies(blocks, typemap) 875 # set of all program copies 876 all_copies = set() 877 for l, s in gen_copies.items(): 878 all_copies |= gen_copies[l] 879 kill_copies = {} 880 for label, gen_set in gen_copies.items(): 881 kill_copies[label] = set() 882 for lhs, rhs in all_copies: 883 if lhs in extra_kill[label] or rhs in extra_kill[label]: 884 kill_copies[label].add((lhs, rhs)) 885 # a copy is killed if it is not in this block and lhs or rhs are 886 # assigned in this block 887 assigned = {lhs for lhs, rhs in gen_set} 888 if ((lhs, rhs) not in gen_set 889 and (lhs in assigned or rhs in assigned)): 890 kill_copies[label].add((lhs, rhs)) 891 # set initial values 892 # all copies are in for all blocks except entry 893 in_copies = {l: all_copies.copy() for l in blocks.keys()} 894 in_copies[entry] = set() 895 out_copies = {} 896 for label in blocks.keys(): 897 # out_b = gen_b | (in_b - kill_b) 898 out_copies[label] = (gen_copies[label] 899 | (in_copies[label] - kill_copies[label])) 900 out_copies[entry] = gen_copies[entry] 901 return (gen_copies, all_copies, kill_copies, in_copies, out_copies) 902 903 904# other packages that define new nodes add calls to get copies in them 905# format: {type:function} 906copy_propagate_extensions = {} 907 908 909def get_block_copies(blocks, typemap): 910 """get copies generated and killed by each block 911 """ 912 block_copies = {} 913 extra_kill = {} 914 for label, block in blocks.items(): 915 assign_dict = {} 916 extra_kill[label] = set() 917 # assignments as dict to replace with latest value 918 for stmt in block.body: 919 for T, f in copy_propagate_extensions.items(): 920 if isinstance(stmt, T): 921 gen_set, kill_set = f(stmt, typemap) 922 for lhs, rhs in gen_set: 923 assign_dict[lhs] = rhs 924 # if a=b is in dict and b is killed, a is also killed 925 new_assign_dict = {} 926 for l, r in assign_dict.items(): 927 if l not in kill_set and r not in kill_set: 928 new_assign_dict[l] = r 929 if r in kill_set: 930 extra_kill[label].add(l) 931 assign_dict = new_assign_dict 932 extra_kill[label] |= kill_set 933 if isinstance(stmt, ir.Assign): 934 lhs = stmt.target.name 935 if isinstance(stmt.value, ir.Var): 936 rhs = stmt.value.name 937 # copy is valid only if same type (see 938 # TestCFunc.test_locals) 939 # Some transformations can produce assignments of the 940 # form A = A. We don't put these mapping in the 941 # copy propagation set because then you get cycles and 942 # infinite loops in the replacement phase. 943 if typemap[lhs] == typemap[rhs] and lhs != rhs: 944 assign_dict[lhs] = rhs 945 continue 946 if isinstance(stmt.value, 947 ir.Expr) and stmt.value.op == 'inplace_binop': 948 in1_var = stmt.value.lhs.name 949 in1_typ = typemap[in1_var] 950 # inplace_binop assigns first operand if mutable 951 if not (isinstance(in1_typ, types.Number) 952 or in1_typ == types.string): 953 extra_kill[label].add(in1_var) 954 # if a=b is in dict and b is killed, a is also killed 955 new_assign_dict = {} 956 for l, r in assign_dict.items(): 957 if l != in1_var and r != in1_var: 958 new_assign_dict[l] = r 959 if r == in1_var: 960 extra_kill[label].add(l) 961 assign_dict = new_assign_dict 962 extra_kill[label].add(lhs) 963 block_cps = set(assign_dict.items()) 964 block_copies[label] = block_cps 965 return block_copies, extra_kill 966 967 968# other packages that define new nodes add calls to apply copy propagate in them 969# format: {type:function} 970apply_copy_propagate_extensions = {} 971 972 973def apply_copy_propagate(blocks, in_copies, name_var_table, typemap, calltypes, 974 save_copies=None): 975 """apply copy propagation to IR: replace variables when copies available""" 976 # save_copies keeps an approximation of the copies that were applied, so 977 # that the variable names of removed user variables can be recovered to some 978 # extent. 979 if save_copies is None: 980 save_copies = [] 981 982 for label, block in blocks.items(): 983 var_dict = {l: name_var_table[r] for l, r in in_copies[label]} 984 # assignments as dict to replace with latest value 985 for stmt in block.body: 986 if type(stmt) in apply_copy_propagate_extensions: 987 f = apply_copy_propagate_extensions[type(stmt)] 988 f(stmt, var_dict, name_var_table, 989 typemap, calltypes, save_copies) 990 # only rhs of assignments should be replaced 991 # e.g. if x=y is available, x in x=z shouldn't be replaced 992 elif isinstance(stmt, ir.Assign): 993 stmt.value = replace_vars_inner(stmt.value, var_dict) 994 else: 995 replace_vars_stmt(stmt, var_dict) 996 fix_setitem_type(stmt, typemap, calltypes) 997 for T, f in copy_propagate_extensions.items(): 998 if isinstance(stmt, T): 999 gen_set, kill_set = f(stmt, typemap) 1000 for lhs, rhs in gen_set: 1001 if rhs in name_var_table: 1002 var_dict[lhs] = name_var_table[rhs] 1003 for l, r in var_dict.copy().items(): 1004 if l in kill_set or r.name in kill_set: 1005 var_dict.pop(l) 1006 if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Var): 1007 lhs = stmt.target.name 1008 rhs = stmt.value.name 1009 # rhs could be replaced with lhs from previous copies 1010 if lhs != rhs: 1011 # copy is valid only if same type (see 1012 # TestCFunc.test_locals) 1013 if typemap[lhs] == typemap[rhs] and rhs in name_var_table: 1014 var_dict[lhs] = name_var_table[rhs] 1015 else: 1016 var_dict.pop(lhs, None) 1017 # a=b kills previous t=a 1018 lhs_kill = [] 1019 for k, v in var_dict.items(): 1020 if v.name == lhs: 1021 lhs_kill.append(k) 1022 for k in lhs_kill: 1023 var_dict.pop(k, None) 1024 if (isinstance(stmt, ir.Assign) 1025 and not isinstance(stmt.value, ir.Var)): 1026 lhs = stmt.target.name 1027 var_dict.pop(lhs, None) 1028 # previous t=a is killed if a is killed 1029 lhs_kill = [] 1030 for k, v in var_dict.items(): 1031 if v.name == lhs: 1032 lhs_kill.append(k) 1033 for k in lhs_kill: 1034 var_dict.pop(k, None) 1035 save_copies.extend(var_dict.items()) 1036 1037 return save_copies 1038 1039def fix_setitem_type(stmt, typemap, calltypes): 1040 """Copy propagation can replace setitem target variable, which can be array 1041 with 'A' layout. The replaced variable can be 'C' or 'F', so we update 1042 setitem call type reflect this (from matrix power test) 1043 """ 1044 if not isinstance(stmt, (ir.SetItem, ir.StaticSetItem)): 1045 return 1046 t_typ = typemap[stmt.target.name] 1047 s_typ = calltypes[stmt].args[0] 1048 # test_optional t_typ can be Optional with array 1049 if not isinstance( 1050 s_typ, 1051 types.npytypes.Array) or not isinstance( 1052 t_typ, 1053 types.npytypes.Array): 1054 return 1055 if s_typ.layout == 'A' and t_typ.layout != 'A': 1056 new_s_typ = s_typ.copy(layout=t_typ.layout) 1057 calltypes[stmt].args = ( 1058 new_s_typ, 1059 calltypes[stmt].args[1], 1060 calltypes[stmt].args[2]) 1061 return 1062 1063 1064def dprint_func_ir(func_ir, title, blocks=None): 1065 """Debug print function IR, with an optional blocks argument 1066 that may differ from the IR's original blocks. 1067 """ 1068 if config.DEBUG_ARRAY_OPT >= 1: 1069 ir_blocks = func_ir.blocks 1070 func_ir.blocks = ir_blocks if blocks == None else blocks 1071 name = func_ir.func_id.func_qualname 1072 print(("IR %s: %s" % (title, name)).center(80, "-")) 1073 func_ir.dump() 1074 print("-" * 40) 1075 func_ir.blocks = ir_blocks 1076 1077 1078def find_topo_order(blocks, cfg = None): 1079 """find topological order of blocks such that true branches are visited 1080 first (e.g. for_break test in test_dataflow). 1081 """ 1082 if cfg is None: 1083 cfg = compute_cfg_from_blocks(blocks) 1084 post_order = [] 1085 seen = set() 1086 1087 def _dfs_rec(node): 1088 if node not in seen: 1089 seen.add(node) 1090 succs = cfg._succs[node] 1091 last_inst = blocks[node].body[-1] 1092 if isinstance(last_inst, ir.Branch): 1093 succs = [last_inst.falsebr, last_inst.truebr] 1094 for dest in succs: 1095 if (node, dest) not in cfg._back_edges: 1096 _dfs_rec(dest) 1097 post_order.append(node) 1098 1099 _dfs_rec(cfg.entry_point()) 1100 post_order.reverse() 1101 return post_order 1102 1103 1104# other packages that define new nodes add calls to get call table 1105# format: {type:function} 1106call_table_extensions = {} 1107 1108 1109def get_call_table(blocks, call_table=None, reverse_call_table=None, topological_ordering=True): 1110 """returns a dictionary of call variables and their references. 1111 """ 1112 # call_table example: c = np.zeros becomes c:["zeroes", np] 1113 # reverse_call_table example: c = np.zeros becomes np_var:c 1114 if call_table is None: 1115 call_table = {} 1116 if reverse_call_table is None: 1117 reverse_call_table = {} 1118 1119 if topological_ordering: 1120 order = find_topo_order(blocks) 1121 else: 1122 order = list(blocks.keys()) 1123 1124 for label in reversed(order): 1125 for inst in reversed(blocks[label].body): 1126 if isinstance(inst, ir.Assign): 1127 lhs = inst.target.name 1128 rhs = inst.value 1129 if isinstance(rhs, ir.Expr) and rhs.op == 'call': 1130 call_table[rhs.func.name] = [] 1131 if isinstance(rhs, ir.Expr) and rhs.op == 'getattr': 1132 if lhs in call_table: 1133 call_table[lhs].append(rhs.attr) 1134 reverse_call_table[rhs.value.name] = lhs 1135 if lhs in reverse_call_table: 1136 call_var = reverse_call_table[lhs] 1137 call_table[call_var].append(rhs.attr) 1138 reverse_call_table[rhs.value.name] = call_var 1139 if isinstance(rhs, ir.Global): 1140 if lhs in call_table: 1141 call_table[lhs].append(rhs.value) 1142 if lhs in reverse_call_table: 1143 call_var = reverse_call_table[lhs] 1144 call_table[call_var].append(rhs.value) 1145 if isinstance(rhs, ir.FreeVar): 1146 if lhs in call_table: 1147 call_table[lhs].append(rhs.value) 1148 if lhs in reverse_call_table: 1149 call_var = reverse_call_table[lhs] 1150 call_table[call_var].append(rhs.value) 1151 if isinstance(rhs, ir.Var): 1152 if lhs in call_table: 1153 call_table[lhs].append(rhs.name) 1154 reverse_call_table[rhs.name] = lhs 1155 if lhs in reverse_call_table: 1156 call_var = reverse_call_table[lhs] 1157 call_table[call_var].append(rhs.name) 1158 for T, f in call_table_extensions.items(): 1159 if isinstance(inst, T): 1160 f(inst, call_table, reverse_call_table) 1161 return call_table, reverse_call_table 1162 1163 1164# other packages that define new nodes add calls to get tuple table 1165# format: {type:function} 1166tuple_table_extensions = {} 1167 1168 1169def get_tuple_table(blocks, tuple_table=None): 1170 """returns a dictionary of tuple variables and their values. 1171 """ 1172 if tuple_table is None: 1173 tuple_table = {} 1174 1175 for block in blocks.values(): 1176 for inst in block.body: 1177 if isinstance(inst, ir.Assign): 1178 lhs = inst.target.name 1179 rhs = inst.value 1180 if isinstance(rhs, ir.Expr) and rhs.op == 'build_tuple': 1181 tuple_table[lhs] = rhs.items 1182 if isinstance(rhs, ir.Const) and isinstance(rhs.value, tuple): 1183 tuple_table[lhs] = rhs.value 1184 for T, f in tuple_table_extensions.items(): 1185 if isinstance(inst, T): 1186 f(inst, tuple_table) 1187 return tuple_table 1188 1189 1190def get_stmt_writes(stmt): 1191 writes = set() 1192 if isinstance(stmt, (ir.Assign, ir.SetItem, ir.StaticSetItem)): 1193 writes.add(stmt.target.name) 1194 return writes 1195 1196 1197def rename_labels(blocks): 1198 """rename labels of function body blocks according to topological sort. 1199 The set of labels of these blocks will remain unchanged. 1200 """ 1201 topo_order = find_topo_order(blocks) 1202 1203 # make a block with return last if available (just for readability) 1204 return_label = -1 1205 for l, b in blocks.items(): 1206 if isinstance(b.body[-1], ir.Return): 1207 return_label = l 1208 # some cases like generators can have no return blocks 1209 if return_label != -1: 1210 topo_order.remove(return_label) 1211 topo_order.append(return_label) 1212 1213 label_map = {} 1214 all_labels = sorted(topo_order, reverse=True) 1215 for label in topo_order: 1216 label_map[label] = all_labels.pop() 1217 # update target labels in jumps/branches 1218 for b in blocks.values(): 1219 term = b.terminator 1220 if isinstance(term, ir.Jump): 1221 term.target = label_map[term.target] 1222 if isinstance(term, ir.Branch): 1223 term.truebr = label_map[term.truebr] 1224 term.falsebr = label_map[term.falsebr] 1225 # update blocks dictionary keys 1226 new_blocks = {} 1227 for k, b in blocks.items(): 1228 new_label = label_map[k] 1229 new_blocks[new_label] = b 1230 1231 return new_blocks 1232 1233 1234def simplify_CFG(blocks): 1235 """transform chains of blocks that have no loop into a single block""" 1236 # first, inline single-branch-block to its predecessors 1237 cfg = compute_cfg_from_blocks(blocks) 1238 def find_single_branch(label): 1239 block = blocks[label] 1240 return len(block.body) == 1 and isinstance(block.body[0], ir.Branch) 1241 single_branch_blocks = list(filter(find_single_branch, blocks.keys())) 1242 marked_for_del = set() 1243 for label in single_branch_blocks: 1244 inst = blocks[label].body[0] 1245 predecessors = cfg.predecessors(label) 1246 delete_block = True 1247 for (p, q) in predecessors: 1248 block = blocks[p] 1249 if isinstance(block.body[-1], ir.Jump): 1250 block.body[-1] = copy.copy(inst) 1251 else: 1252 delete_block = False 1253 if delete_block: 1254 marked_for_del.add(label) 1255 # Delete marked labels 1256 for label in marked_for_del: 1257 del blocks[label] 1258 merge_adjacent_blocks(blocks) 1259 return rename_labels(blocks) 1260 1261 1262arr_math = ['min', 'max', 'sum', 'prod', 'mean', 'var', 'std', 1263 'cumsum', 'cumprod', 'argmin', 'argmax', 'argsort', 1264 'nonzero', 'ravel'] 1265 1266 1267def canonicalize_array_math(func_ir, typemap, calltypes, typingctx): 1268 # save array arg to call 1269 # call_varname -> array 1270 blocks = func_ir.blocks 1271 saved_arr_arg = {} 1272 topo_order = find_topo_order(blocks) 1273 for label in topo_order: 1274 block = blocks[label] 1275 new_body = [] 1276 for stmt in block.body: 1277 if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr): 1278 lhs = stmt.target.name 1279 rhs = stmt.value 1280 # replace A.func with np.func, and save A in saved_arr_arg 1281 if (rhs.op == 'getattr' and rhs.attr in arr_math 1282 and isinstance( 1283 typemap[rhs.value.name], types.npytypes.Array)): 1284 rhs = stmt.value 1285 arr = rhs.value 1286 saved_arr_arg[lhs] = arr 1287 scope = arr.scope 1288 loc = arr.loc 1289 # g_np_var = Global(numpy) 1290 g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc) 1291 typemap[g_np_var.name] = types.misc.Module(numpy) 1292 g_np = ir.Global('np', numpy, loc) 1293 g_np_assign = ir.Assign(g_np, g_np_var, loc) 1294 rhs.value = g_np_var 1295 new_body.append(g_np_assign) 1296 func_ir._definitions[g_np_var.name] = [g_np] 1297 # update func var type 1298 func = getattr(numpy, rhs.attr) 1299 func_typ = get_np_ufunc_typ(func) 1300 typemap.pop(lhs) 1301 typemap[lhs] = func_typ 1302 if rhs.op == 'call' and rhs.func.name in saved_arr_arg: 1303 # add array as first arg 1304 arr = saved_arr_arg[rhs.func.name] 1305 # update call type signature to include array arg 1306 old_sig = calltypes.pop(rhs) 1307 # argsort requires kws for typing so sig.args can't be used 1308 # reusing sig.args since some types become Const in sig 1309 argtyps = old_sig.args[:len(rhs.args)] 1310 kwtyps = {name: typemap[v.name] for name, v in rhs.kws} 1311 calltypes[rhs] = typemap[rhs.func.name].get_call_type( 1312 typingctx, [typemap[arr.name]] + list(argtyps), kwtyps) 1313 rhs.args = [arr] + rhs.args 1314 1315 new_body.append(stmt) 1316 block.body = new_body 1317 return 1318 1319 1320# format: {type:function} 1321array_accesses_extensions = {} 1322 1323 1324def get_array_accesses(blocks, accesses=None): 1325 """returns a set of arrays accessed and their indices. 1326 """ 1327 if accesses is None: 1328 accesses = set() 1329 1330 for block in blocks.values(): 1331 for inst in block.body: 1332 if isinstance(inst, ir.SetItem): 1333 accesses.add((inst.target.name, inst.index.name)) 1334 if isinstance(inst, ir.StaticSetItem): 1335 accesses.add((inst.target.name, inst.index_var.name)) 1336 if isinstance(inst, ir.Assign): 1337 lhs = inst.target.name 1338 rhs = inst.value 1339 if isinstance(rhs, ir.Expr) and rhs.op == 'getitem': 1340 accesses.add((rhs.value.name, rhs.index.name)) 1341 if isinstance(rhs, ir.Expr) and rhs.op == 'static_getitem': 1342 index = rhs.index 1343 # slice is unhashable, so just keep the variable 1344 if index is None or is_slice_index(index): 1345 index = rhs.index_var.name 1346 accesses.add((rhs.value.name, index)) 1347 for T, f in array_accesses_extensions.items(): 1348 if isinstance(inst, T): 1349 f(inst, accesses) 1350 return accesses 1351 1352def is_slice_index(index): 1353 """see if index is a slice index or has slice in it""" 1354 if isinstance(index, slice): 1355 return True 1356 if isinstance(index, tuple): 1357 for i in index: 1358 if isinstance(i, slice): 1359 return True 1360 return False 1361 1362def merge_adjacent_blocks(blocks): 1363 cfg = compute_cfg_from_blocks(blocks) 1364 # merge adjacent blocks 1365 removed = set() 1366 for label in list(blocks.keys()): 1367 if label in removed: 1368 continue 1369 block = blocks[label] 1370 succs = list(cfg.successors(label)) 1371 while True: 1372 if len(succs) != 1: 1373 break 1374 next_label = succs[0][0] 1375 if next_label in removed: 1376 break 1377 preds = list(cfg.predecessors(next_label)) 1378 succs = list(cfg.successors(next_label)) 1379 if len(preds) != 1 or preds[0][0] != label: 1380 break 1381 next_block = blocks[next_label] 1382 # XXX: commented out since scope objects are not consistent 1383 # throughout the compiler. for example, pieces of code are compiled 1384 # and inlined on the fly without proper scope merge. 1385 # if block.scope != next_block.scope: 1386 # break 1387 # merge 1388 block.body.pop() # remove Jump 1389 block.body += next_block.body 1390 del blocks[next_label] 1391 removed.add(next_label) 1392 label = next_label 1393 1394def restore_copy_var_names(blocks, save_copies, typemap): 1395 """ 1396 restores variable names of user variables after applying copy propagation 1397 """ 1398 rename_dict = {} 1399 for (a, b) in save_copies: 1400 # a is string name, b is variable 1401 # if a is user variable and b is generated temporary and b is not 1402 # already renamed 1403 if (not a.startswith('$') and b.name.startswith('$') 1404 and b.name not in rename_dict): 1405 new_name = mk_unique_var('${}'.format(a)); 1406 rename_dict[b.name] = new_name 1407 typ = typemap.pop(b.name) 1408 typemap[new_name] = typ 1409 1410 replace_var_names(blocks, rename_dict) 1411 1412def simplify(func_ir, typemap, calltypes): 1413 remove_dels(func_ir.blocks) 1414 # get copies in to blocks and out from blocks 1415 in_cps, out_cps = copy_propagate(func_ir.blocks, typemap) 1416 # table mapping variable names to ir.Var objects to help replacement 1417 name_var_table = get_name_var_table(func_ir.blocks) 1418 save_copies = apply_copy_propagate( 1419 func_ir.blocks, 1420 in_cps, 1421 name_var_table, 1422 typemap, 1423 calltypes) 1424 restore_copy_var_names(func_ir.blocks, save_copies, typemap) 1425 # remove dead code to enable fusion 1426 if config.DEBUG_ARRAY_OPT >= 1: 1427 dprint_func_ir(func_ir, "after copy prop") 1428 remove_dead(func_ir.blocks, func_ir.arg_names, func_ir, typemap) 1429 func_ir.blocks = simplify_CFG(func_ir.blocks) 1430 if config.DEBUG_ARRAY_OPT >= 1: 1431 dprint_func_ir(func_ir, "after simplify") 1432 1433class GuardException(Exception): 1434 pass 1435 1436def require(cond): 1437 """ 1438 Raise GuardException if the given condition is False. 1439 """ 1440 if not cond: 1441 raise GuardException 1442 1443def guard(func, *args, **kwargs): 1444 """ 1445 Run a function with given set of arguments, and guard against 1446 any GuardException raised by the function by returning None, 1447 or the expected return results if no such exception was raised. 1448 """ 1449 try: 1450 return func(*args, **kwargs) 1451 except GuardException: 1452 return None 1453 1454def get_definition(func_ir, name, **kwargs): 1455 """ 1456 Same as func_ir.get_definition(name), but raise GuardException if 1457 exception KeyError is caught. 1458 """ 1459 try: 1460 return func_ir.get_definition(name, **kwargs) 1461 except KeyError: 1462 raise GuardException 1463 1464def build_definitions(blocks, definitions=None): 1465 """Build the definitions table of the given blocks by scanning 1466 through all blocks and instructions, useful when the definitions 1467 table is out-of-sync. 1468 Will return a new definition table if one is not passed. 1469 """ 1470 if definitions is None: 1471 definitions = collections.defaultdict(list) 1472 1473 for block in blocks.values(): 1474 for inst in block.body: 1475 if isinstance(inst, ir.Assign): 1476 name = inst.target.name 1477 definition = definitions.get(name, []) 1478 if definition == []: 1479 definitions[name] = definition 1480 definition.append(inst.value) 1481 if type(inst) in build_defs_extensions: 1482 f = build_defs_extensions[type(inst)] 1483 f(inst, definitions) 1484 1485 return definitions 1486 1487build_defs_extensions = {} 1488 1489def find_callname(func_ir, expr, typemap=None, definition_finder=get_definition): 1490 """Try to find a call expression's function and module names and return 1491 them as strings for unbounded calls. If the call is a bounded call, return 1492 the self object instead of module name. Raise GuardException if failed. 1493 1494 Providing typemap can make the call matching more accurate in corner cases 1495 such as bounded call on an object which is inside another object. 1496 """ 1497 require(isinstance(expr, ir.Expr) and expr.op == 'call') 1498 callee = expr.func 1499 callee_def = definition_finder(func_ir, callee) 1500 attrs = [] 1501 obj = None 1502 while True: 1503 if isinstance(callee_def, (ir.Global, ir.FreeVar)): 1504 # require(callee_def.value == numpy) 1505 # these checks support modules like numpy, numpy.random as well as 1506 # calls like len() and intrinsitcs like assertEquiv 1507 keys = ['name', '_name', '__name__'] 1508 value = None 1509 for key in keys: 1510 if hasattr(callee_def.value, key): 1511 value = getattr(callee_def.value, key) 1512 break 1513 if not value or not isinstance(value, str): 1514 raise GuardException 1515 attrs.append(value) 1516 def_val = callee_def.value 1517 # get the underlying definition of Intrinsic object to be able to 1518 # find the module effectively. 1519 # Otherwise, it will return numba.extending 1520 if isinstance(def_val, _Intrinsic): 1521 def_val = def_val._defn 1522 if hasattr(def_val, '__module__'): 1523 mod_name = def_val.__module__ 1524 # it might be a numpy function imported directly 1525 if (hasattr(numpy, value) 1526 and def_val == getattr(numpy, value)): 1527 attrs += ['numpy'] 1528 # it might be a np.random function imported directly 1529 elif (hasattr(numpy.random, value) 1530 and def_val == getattr(numpy.random, value)): 1531 attrs += ['random', 'numpy'] 1532 elif mod_name is not None: 1533 attrs.append(mod_name) 1534 else: 1535 class_name = def_val.__class__.__name__ 1536 if class_name == 'builtin_function_or_method': 1537 class_name = 'builtin' 1538 if class_name != 'module': 1539 attrs.append(class_name) 1540 break 1541 elif isinstance(callee_def, ir.Expr) and callee_def.op == 'getattr': 1542 obj = callee_def.value 1543 attrs.append(callee_def.attr) 1544 if typemap and obj.name in typemap: 1545 typ = typemap[obj.name] 1546 if not isinstance(typ, types.Module): 1547 return attrs[0], obj 1548 callee_def = definition_finder(func_ir, obj) 1549 else: 1550 # obj.func calls where obj is not np array 1551 if obj is not None: 1552 return '.'.join(reversed(attrs)), obj 1553 raise GuardException 1554 return attrs[0], '.'.join(reversed(attrs[1:])) 1555 1556def find_build_sequence(func_ir, var): 1557 """Check if a variable is constructed via build_tuple or 1558 build_list or build_set, and return the sequence and the 1559 operator, or raise GuardException otherwise. 1560 Note: only build_tuple is immutable, so use with care. 1561 """ 1562 require(isinstance(var, ir.Var)) 1563 var_def = get_definition(func_ir, var) 1564 require(isinstance(var_def, ir.Expr)) 1565 build_ops = ['build_tuple', 'build_list', 'build_set'] 1566 require(var_def.op in build_ops) 1567 return var_def.items, var_def.op 1568 1569def find_const(func_ir, var): 1570 """Check if a variable is defined as constant, and return 1571 the constant value, or raise GuardException otherwise. 1572 """ 1573 require(isinstance(var, ir.Var)) 1574 var_def = get_definition(func_ir, var) 1575 require(isinstance(var_def, (ir.Const, ir.Global, ir.FreeVar))) 1576 return var_def.value 1577 1578def compile_to_numba_ir(mk_func, glbls, typingctx=None, arg_typs=None, 1579 typemap=None, calltypes=None): 1580 """ 1581 Compile a function or a make_function node to Numba IR. 1582 1583 Rename variables and 1584 labels to avoid conflict if inlined somewhere else. Perform type inference 1585 if typingctx and other typing inputs are available and update typemap and 1586 calltypes. 1587 """ 1588 from numba.core import typed_passes 1589 # mk_func can be actual function or make_function node, or a njit function 1590 if hasattr(mk_func, 'code'): 1591 code = mk_func.code 1592 elif hasattr(mk_func, '__code__'): 1593 code = mk_func.__code__ 1594 else: 1595 raise NotImplementedError("function type not recognized {}".format(mk_func)) 1596 f_ir = get_ir_of_code(glbls, code) 1597 remove_dels(f_ir.blocks) 1598 1599 # relabel by adding an offset 1600 global _max_label 1601 f_ir.blocks = add_offset_to_labels(f_ir.blocks, _max_label + 1) 1602 max_label = max(f_ir.blocks.keys()) 1603 _max_label = max_label 1604 1605 # rename all variables to avoid conflict 1606 var_table = get_name_var_table(f_ir.blocks) 1607 new_var_dict = {} 1608 for name, var in var_table.items(): 1609 new_var_dict[name] = mk_unique_var(name) 1610 replace_var_names(f_ir.blocks, new_var_dict) 1611 1612 # perform type inference if typingctx is available and update type 1613 # data structures typemap and calltypes 1614 if typingctx: 1615 f_typemap, f_return_type, f_calltypes = typed_passes.type_inference_stage( 1616 typingctx, f_ir, arg_typs, None) 1617 # remove argument entries like arg.a from typemap 1618 arg_names = [vname for vname in f_typemap if vname.startswith("arg.")] 1619 for a in arg_names: 1620 f_typemap.pop(a) 1621 typemap.update(f_typemap) 1622 calltypes.update(f_calltypes) 1623 return f_ir 1624 1625def _create_function_from_code_obj(fcode, func_env, func_arg, func_clo, glbls): 1626 """ 1627 Creates a function from a code object. Args: 1628 * fcode - the code object 1629 * func_env - string for the freevar placeholders 1630 * func_arg - string for the function args (e.g. "a, b, c, d=None") 1631 * func_clo - string for the closure args 1632 * glbls - the function globals 1633 """ 1634 func_text = "def g():\n%s\n def f(%s):\n return (%s)\n return f" % ( 1635 func_env, func_arg, func_clo) 1636 loc = {} 1637 exec(func_text, glbls, loc) 1638 1639 f = loc['g']() 1640 # replace the code body 1641 f.__code__ = fcode 1642 f.__name__ = fcode.co_name 1643 return f 1644 1645def get_ir_of_code(glbls, fcode): 1646 """ 1647 Compile a code object to get its IR, ir.Del nodes are emitted 1648 """ 1649 nfree = len(fcode.co_freevars) 1650 func_env = "\n".join([" c_%d = None" % i for i in range(nfree)]) 1651 func_clo = ",".join(["c_%d" % i for i in range(nfree)]) 1652 func_arg = ",".join(["x_%d" % i for i in range(fcode.co_argcount)]) 1653 1654 f = _create_function_from_code_obj(fcode, func_env, func_arg, func_clo, 1655 glbls) 1656 1657 from numba.core import compiler 1658 ir = compiler.run_frontend(f) 1659 # we need to run the before inference rewrite pass to normalize the IR 1660 # XXX: check rewrite pass flag? 1661 # for example, Raise nodes need to become StaticRaise before type inference 1662 class DummyPipeline(object): 1663 def __init__(self, f_ir): 1664 self.state = compiler.StateDict() 1665 self.state.typingctx = None 1666 self.state.targetctx = None 1667 self.state.args = None 1668 self.state.func_ir = f_ir 1669 self.state.typemap = None 1670 self.state.return_type = None 1671 self.state.calltypes = None 1672 rewrites.rewrite_registry.apply('before-inference', DummyPipeline(ir).state) 1673 # call inline pass to handle cases like stencils and comprehensions 1674 swapped = {} # TODO: get this from diagnostics store 1675 inline_pass = numba.core.inline_closurecall.InlineClosureCallPass( 1676 ir, numba.core.cpu.ParallelOptions(False), swapped) 1677 inline_pass.run() 1678 post_proc = postproc.PostProcessor(ir) 1679 post_proc.run(True) 1680 return ir 1681 1682def replace_arg_nodes(block, args): 1683 """ 1684 Replace ir.Arg(...) with variables 1685 """ 1686 for stmt in block.body: 1687 if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Arg): 1688 idx = stmt.value.index 1689 assert(idx < len(args)) 1690 stmt.value = args[idx] 1691 return 1692 1693def replace_returns(blocks, target, return_label): 1694 """ 1695 Return return statement by assigning directly to target, and a jump. 1696 """ 1697 for block in blocks.values(): 1698 casts = [] 1699 for i, stmt in enumerate(block.body): 1700 if isinstance(stmt, ir.Return): 1701 assert(i + 1 == len(block.body)) 1702 block.body[i] = ir.Assign(stmt.value, target, stmt.loc) 1703 block.body.append(ir.Jump(return_label, stmt.loc)) 1704 # remove cast of the returned value 1705 for cast in casts: 1706 if cast.target.name == stmt.value.name: 1707 cast.value = cast.value.value 1708 elif isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr) and stmt.value.op == 'cast': 1709 casts.append(stmt) 1710 1711def gen_np_call(func_as_str, func, lhs, args, typingctx, typemap, calltypes): 1712 scope = args[0].scope 1713 loc = args[0].loc 1714 1715 # g_np_var = Global(numpy) 1716 g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc) 1717 typemap[g_np_var.name] = types.misc.Module(numpy) 1718 g_np = ir.Global('np', numpy, loc) 1719 g_np_assign = ir.Assign(g_np, g_np_var, loc) 1720 # attr call: <something>_attr = getattr(g_np_var, func_as_str) 1721 np_attr_call = ir.Expr.getattr(g_np_var, func_as_str, loc) 1722 attr_var = ir.Var(scope, mk_unique_var("$np_attr_attr"), loc) 1723 func_var_typ = get_np_ufunc_typ(func) 1724 typemap[attr_var.name] = func_var_typ 1725 attr_assign = ir.Assign(np_attr_call, attr_var, loc) 1726 # np call: lhs = np_attr(*args) 1727 np_call = ir.Expr.call(attr_var, args, (), loc) 1728 arg_types = [typemap[x.name] for x in args] 1729 func_typ = func_var_typ.get_call_type(typingctx, arg_types, {}) 1730 calltypes[np_call] = func_typ 1731 np_assign = ir.Assign(np_call, lhs, loc) 1732 return [g_np_assign, attr_assign, np_assign] 1733 1734def dump_blocks(blocks): 1735 for label, block in blocks.items(): 1736 print(label, ":") 1737 for stmt in block.body: 1738 print(" ", stmt) 1739 1740def is_get_setitem(stmt): 1741 """stmt is getitem assignment or setitem (and static cases)""" 1742 return is_getitem(stmt) or is_setitem(stmt) 1743 1744 1745def is_getitem(stmt): 1746 """true if stmt is a getitem or static_getitem assignment""" 1747 return (isinstance(stmt, ir.Assign) 1748 and isinstance(stmt.value, ir.Expr) 1749 and stmt.value.op in ['getitem', 'static_getitem']) 1750 1751def is_setitem(stmt): 1752 """true if stmt is a SetItem or StaticSetItem node""" 1753 return isinstance(stmt, (ir.SetItem, ir.StaticSetItem)) 1754 1755def index_var_of_get_setitem(stmt): 1756 """get index variable for getitem/setitem nodes (and static cases)""" 1757 if is_getitem(stmt): 1758 if stmt.value.op == 'getitem': 1759 return stmt.value.index 1760 else: 1761 return stmt.value.index_var 1762 1763 if is_setitem(stmt): 1764 if isinstance(stmt, ir.SetItem): 1765 return stmt.index 1766 else: 1767 return stmt.index_var 1768 1769 return None 1770 1771def set_index_var_of_get_setitem(stmt, new_index): 1772 if is_getitem(stmt): 1773 if stmt.value.op == 'getitem': 1774 stmt.value.index = new_index 1775 else: 1776 stmt.value.index_var = new_index 1777 elif is_setitem(stmt): 1778 if isinstance(stmt, ir.SetItem): 1779 stmt.index = new_index 1780 else: 1781 stmt.index_var = new_index 1782 else: 1783 raise ValueError("getitem or setitem node expected but received {}".format( 1784 stmt)) 1785 1786 1787def is_namedtuple_class(c): 1788 """check if c is a namedtuple class""" 1789 if not isinstance(c, type): 1790 return False 1791 # should have only tuple as superclass 1792 bases = c.__bases__ 1793 if len(bases) != 1 or bases[0] != tuple: 1794 return False 1795 # should have _make method 1796 if not hasattr(c, '_make'): 1797 return False 1798 # should have _fields that is all string 1799 fields = getattr(c, '_fields', None) 1800 if not isinstance(fields, tuple): 1801 return False 1802 return all(isinstance(f, str) for f in fields) 1803 1804 1805def fill_block_with_call(newblock, callee, label_next, inputs, outputs): 1806 """Fill *newblock* to call *callee* with arguments listed in *inputs*. 1807 The returned values are unwraped into variables in *outputs*. 1808 The block would then jump to *label_next*. 1809 """ 1810 scope = newblock.scope 1811 loc = newblock.loc 1812 1813 fn = ir.Const(value=callee, loc=loc) 1814 fnvar = scope.make_temp(loc=loc) 1815 newblock.append(ir.Assign(target=fnvar, value=fn, loc=loc)) 1816 # call 1817 args = [scope.get_exact(name) for name in inputs] 1818 callexpr = ir.Expr.call(func=fnvar, args=args, kws=(), loc=loc) 1819 callres = scope.make_temp(loc=loc) 1820 newblock.append(ir.Assign(target=callres, value=callexpr, loc=loc)) 1821 # unpack return value 1822 for i, out in enumerate(outputs): 1823 target = scope.get_exact(out) 1824 getitem = ir.Expr.static_getitem(value=callres, index=i, 1825 index_var=None, loc=loc) 1826 newblock.append(ir.Assign(target=target, value=getitem, loc=loc)) 1827 # jump to next block 1828 newblock.append(ir.Jump(target=label_next, loc=loc)) 1829 return newblock 1830 1831 1832def fill_callee_prologue(block, inputs, label_next): 1833 """ 1834 Fill a new block *block* that unwraps arguments using names in *inputs* and 1835 then jumps to *label_next*. 1836 1837 Expected to use with *fill_block_with_call()* 1838 """ 1839 scope = block.scope 1840 loc = block.loc 1841 # load args 1842 args = [ir.Arg(name=k, index=i, loc=loc) 1843 for i, k in enumerate(inputs)] 1844 for aname, aval in zip(inputs, args): 1845 tmp = ir.Var(scope=scope, name=aname, loc=loc) 1846 block.append(ir.Assign(target=tmp, value=aval, loc=loc)) 1847 # jump to loop entry 1848 block.append(ir.Jump(target=label_next, loc=loc)) 1849 return block 1850 1851 1852def fill_callee_epilogue(block, outputs): 1853 """ 1854 Fill a new block *block* to prepare the return values. 1855 This block is the last block of the function. 1856 1857 Expected to use with *fill_block_with_call()* 1858 """ 1859 scope = block.scope 1860 loc = block.loc 1861 # prepare tuples to return 1862 vals = [scope.get_exact(name=name) for name in outputs] 1863 tupexpr = ir.Expr.build_tuple(items=vals, loc=loc) 1864 tup = scope.make_temp(loc=loc) 1865 block.append(ir.Assign(target=tup, value=tupexpr, loc=loc)) 1866 # return 1867 block.append(ir.Return(value=tup, loc=loc)) 1868 return block 1869 1870 1871def find_global_value(func_ir, var): 1872 """Check if a variable is a global value, and return the value, 1873 or raise GuardException otherwise. 1874 """ 1875 dfn = get_definition(func_ir, var) 1876 if isinstance(dfn, ir.Global): 1877 return dfn.value 1878 1879 if isinstance(dfn, ir.Expr) and dfn.op == 'getattr': 1880 prev_val = find_global_value(func_ir, dfn.value) 1881 try: 1882 val = getattr(prev_val, dfn.attr) 1883 return val 1884 except AttributeError: 1885 raise GuardException 1886 1887 raise GuardException 1888 1889 1890def raise_on_unsupported_feature(func_ir, typemap): 1891 """ 1892 Helper function to walk IR and raise if it finds op codes 1893 that are unsupported. Could be extended to cover IR sequences 1894 as well as op codes. Intended use is to call it as a pipeline 1895 stage just prior to lowering to prevent LoweringErrors for known 1896 unsupported features. 1897 """ 1898 gdb_calls = [] # accumulate calls to gdb/gdb_init 1899 1900 # issue 2195: check for excessively large tuples 1901 for arg_name in func_ir.arg_names: 1902 if arg_name in typemap and \ 1903 isinstance(typemap[arg_name], types.containers.UniTuple) and \ 1904 typemap[arg_name].count > 1000: 1905 # Raise an exception when len(tuple) > 1000. The choice of this number (1000) 1906 # was entirely arbitrary 1907 msg = ("Tuple '{}' length must be smaller than 1000.\n" 1908 "Large tuples lead to the generation of a prohibitively large " 1909 "LLVM IR which causes excessive memory pressure " 1910 "and large compile times.\n" 1911 "As an alternative, the use of a 'list' is recommended in " 1912 "place of a 'tuple' as lists do not suffer from this problem.".format(arg_name)) 1913 raise UnsupportedError(msg, func_ir.loc) 1914 1915 for blk in func_ir.blocks.values(): 1916 for stmt in blk.find_insts(ir.Assign): 1917 # This raises on finding `make_function` 1918 if isinstance(stmt.value, ir.Expr): 1919 if stmt.value.op == 'make_function': 1920 val = stmt.value 1921 1922 # See if the construct name can be refined 1923 code = getattr(val, 'code', None) 1924 if code is not None: 1925 # check if this is a closure, the co_name will 1926 # be the captured function name which is not 1927 # useful so be explicit 1928 if getattr(val, 'closure', None) is not None: 1929 use = '<creating a function from a closure>' 1930 expr = '' 1931 else: 1932 use = code.co_name 1933 expr = '(%s) ' % use 1934 else: 1935 use = '<could not ascertain use case>' 1936 expr = '' 1937 1938 msg = ("Numba encountered the use of a language " 1939 "feature it does not support in this context: " 1940 "%s (op code: make_function not supported). If " 1941 "the feature is explicitly supported it is " 1942 "likely that the result of the expression %s" 1943 "is being used in an unsupported manner.") % \ 1944 (use, expr) 1945 raise UnsupportedError(msg, stmt.value.loc) 1946 1947 # this checks for gdb initialization calls, only one is permitted 1948 if isinstance(stmt.value, (ir.Global, ir.FreeVar)): 1949 val = stmt.value 1950 val = getattr(val, 'value', None) 1951 if val is None: 1952 continue 1953 1954 # check global function 1955 found = False 1956 if isinstance(val, pytypes.FunctionType): 1957 found = val in {numba.gdb, numba.gdb_init} 1958 if not found: # freevar bind to intrinsic 1959 found = getattr(val, '_name', "") == "gdb_internal" 1960 if found: 1961 gdb_calls.append(stmt.loc) # report last seen location 1962 1963 # this checks that np.<type> was called if view is called 1964 if isinstance(stmt.value, ir.Expr): 1965 if stmt.value.op == 'getattr' and stmt.value.attr == 'view': 1966 var = stmt.value.value.name 1967 if isinstance(typemap[var], types.Array): 1968 continue 1969 df = func_ir.get_definition(var) 1970 cn = guard(find_callname, func_ir, df) 1971 if cn and cn[1] == 'numpy': 1972 ty = getattr(numpy, cn[0]) 1973 if (numpy.issubdtype(ty, numpy.integer) or 1974 numpy.issubdtype(ty, numpy.floating)): 1975 continue 1976 1977 vardescr = '' if var.startswith('$') else "'{}' ".format(var) 1978 raise TypingError( 1979 "'view' can only be called on NumPy dtypes, " 1980 "try wrapping the variable {}with 'np.<dtype>()'". 1981 format(vardescr), loc=stmt.loc) 1982 1983 # checks for globals that are also reflected 1984 if isinstance(stmt.value, ir.Global): 1985 ty = typemap[stmt.target.name] 1986 msg = ("The use of a %s type, assigned to variable '%s' in " 1987 "globals, is not supported as globals are considered " 1988 "compile-time constants and there is no known way to " 1989 "compile a %s type as a constant.") 1990 if (getattr(ty, 'reflected', False) or 1991 isinstance(ty, (types.DictType, types.ListType))): 1992 raise TypingError(msg % (ty, stmt.value.name, ty), loc=stmt.loc) 1993 1994 # checks for generator expressions (yield in use when func_ir has 1995 # not been identified as a generator). 1996 if isinstance(stmt.value, ir.Yield) and not func_ir.is_generator: 1997 msg = "The use of generator expressions is unsupported." 1998 raise UnsupportedError(msg, loc=stmt.loc) 1999 2000 # There is more than one call to function gdb/gdb_init 2001 if len(gdb_calls) > 1: 2002 msg = ("Calling either numba.gdb() or numba.gdb_init() more than once " 2003 "in a function is unsupported (strange things happen!), use " 2004 "numba.gdb_breakpoint() to create additional breakpoints " 2005 "instead.\n\nRelevant documentation is available here:\n" 2006 "https://numba.pydata.org/numba-doc/latest/user/troubleshoot.html" 2007 "/troubleshoot.html#using-numba-s-direct-gdb-bindings-in-" 2008 "nopython-mode\n\nConflicting calls found at:\n %s") 2009 buf = '\n'.join([x.strformat() for x in gdb_calls]) 2010 raise UnsupportedError(msg % buf) 2011 2012 2013def warn_deprecated(func_ir, typemap): 2014 # first pass, just walk the type map 2015 for name, ty in typemap.items(): 2016 # the Type Metaclass has a reflected member 2017 if ty.reflected: 2018 # if its an arg, report function call 2019 if name.startswith('arg.'): 2020 loc = func_ir.loc 2021 arg = name.split('.')[1] 2022 fname = func_ir.func_id.func_qualname 2023 tyname = 'list' if isinstance(ty, types.List) else 'set' 2024 url = ("https://numba.pydata.org/numba-doc/latest/reference/" 2025 "deprecation.html#deprecation-of-reflection-for-list-and" 2026 "-set-types") 2027 msg = ("\nEncountered the use of a type that is scheduled for " 2028 "deprecation: type 'reflected %s' found for argument " 2029 "'%s' of function '%s'.\n\nFor more information visit " 2030 "%s" % (tyname, arg, fname, url)) 2031 warnings.warn(NumbaPendingDeprecationWarning(msg, loc=loc)) 2032 2033 2034def resolve_func_from_module(func_ir, node): 2035 """ 2036 This returns the python function that is being getattr'd from a module in 2037 some IR, it resolves import chains/submodules recursively. Should it not be 2038 possible to find the python function being called None will be returned. 2039 2040 func_ir - the FunctionIR object 2041 node - the IR node from which to start resolving (should be a `getattr`). 2042 """ 2043 getattr_chain = [] 2044 def resolve_mod(mod): 2045 if getattr(mod, 'op', False) == 'getattr': 2046 getattr_chain.insert(0, mod.attr) 2047 try: 2048 mod = func_ir.get_definition(mod.value) 2049 except KeyError: # multiple definitions 2050 return None 2051 return resolve_mod(mod) 2052 elif isinstance(mod, (ir.Global, ir.FreeVar)): 2053 if isinstance(mod.value, pytypes.ModuleType): 2054 return mod 2055 return None 2056 2057 mod = resolve_mod(node) 2058 if mod is not None: 2059 defn = mod.value 2060 for x in getattr_chain: 2061 defn = getattr(defn, x, False) 2062 if not defn: 2063 break 2064 else: 2065 return defn 2066 else: 2067 return None 2068 2069 2070def enforce_no_dels(func_ir): 2071 """ 2072 Enforce there being no ir.Del nodes in the IR. 2073 """ 2074 for blk in func_ir.blocks.values(): 2075 dels = [x for x in blk.find_insts(ir.Del)] 2076 if dels: 2077 msg = "Illegal IR, del found at: %s" % dels[0] 2078 raise CompilerError(msg, loc=dels[0].loc) 2079 2080def enforce_no_phis(func_ir): 2081 """ 2082 Enforce there being no ir.Expr.phi nodes in the IR. 2083 """ 2084 for blk in func_ir.blocks.values(): 2085 phis = [x for x in blk.find_exprs(op='phi')] 2086 if phis: 2087 msg = "Illegal IR, phi found at: %s" % phis[0] 2088 raise CompilerError(msg, loc=phis[0].loc) 2089 2090 2091def check_and_legalize_ir(func_ir): 2092 """ 2093 This checks that the IR presented is legal 2094 """ 2095 enforce_no_phis(func_ir) 2096 enforce_no_dels(func_ir) 2097 # postprocess and emit ir.Dels 2098 post_proc = postproc.PostProcessor(func_ir) 2099 post_proc.run(True) 2100 2101 2102def convert_code_obj_to_function(code_obj, caller_ir): 2103 """ 2104 Converts a code object from a `make_function.code` attr in the IR into a 2105 python function, caller_ir is the FunctionIR of the caller and is used for 2106 the resolution of freevars. 2107 """ 2108 fcode = code_obj.code 2109 nfree = len(fcode.co_freevars) 2110 2111 # try and resolve freevars if they are consts in the caller's IR 2112 # these can be baked into the new function 2113 freevars = [] 2114 for x in fcode.co_freevars: 2115 # not using guard here to differentiate between multiple definition and 2116 # non-const variable 2117 try: 2118 freevar_def = caller_ir.get_definition(x) 2119 except KeyError: 2120 msg = ("Cannot capture a constant value for variable '%s' as there " 2121 "are multiple definitions present." % x) 2122 raise TypingError(msg, loc=code_obj.loc) 2123 if isinstance(freevar_def, ir.Const): 2124 freevars.append(freevar_def.value) 2125 else: 2126 msg = ("Cannot capture the non-constant value associated with " 2127 "variable '%s' in a function that will escape." % x) 2128 raise TypingError(msg, loc=code_obj.loc) 2129 2130 func_env = "\n".join([" c_%d = %s" % (i, x) for i, x in enumerate(freevars)]) 2131 func_clo = ",".join(["c_%d" % i for i in range(nfree)]) 2132 co_varnames = list(fcode.co_varnames) 2133 2134 # This is horrible. The code object knows about the number of args present 2135 # it also knows the name of the args but these are bundled in with other 2136 # vars in `co_varnames`. The make_function IR node knows what the defaults 2137 # are, they are defined in the IR as consts. The following finds the total 2138 # number of args (args + kwargs with defaults), finds the default values 2139 # and infers the number of "kwargs with defaults" from this and then infers 2140 # the number of actual arguments from that. 2141 n_kwargs = 0 2142 n_allargs = fcode.co_argcount 2143 kwarg_defaults = caller_ir.get_definition(code_obj.defaults) 2144 if kwarg_defaults is not None: 2145 if isinstance(kwarg_defaults, tuple): 2146 d = [caller_ir.get_definition(x).value for x in kwarg_defaults] 2147 kwarg_defaults_tup = tuple(d) 2148 else: 2149 d = [caller_ir.get_definition(x).value 2150 for x in kwarg_defaults.items] 2151 kwarg_defaults_tup = tuple(d) 2152 n_kwargs = len(kwarg_defaults_tup) 2153 nargs = n_allargs - n_kwargs 2154 2155 func_arg = ",".join(["%s" % (co_varnames[i]) for i in range(nargs)]) 2156 if n_kwargs: 2157 kw_const = ["%s = %s" % (co_varnames[i + nargs], kwarg_defaults_tup[i]) 2158 for i in range(n_kwargs)] 2159 func_arg += ", " 2160 func_arg += ", ".join(kw_const) 2161 2162 # globals are the same as those in the caller 2163 glbls = caller_ir.func_id.func.__globals__ 2164 2165 # create the function and return it 2166 return _create_function_from_code_obj(fcode, func_env, func_arg, func_clo, 2167 glbls) 2168