1import types as pytypes # avoid confusion with numba.types 2import copy 3import ctypes 4import numba.core.analysis 5from numba.core import utils, types, typing, errors, ir, rewrites, config, ir_utils 6from numba import prange 7from numba.parfors.parfor import internal_prange 8from numba.core.ir_utils import ( 9 mk_unique_var, 10 next_label, 11 add_offset_to_labels, 12 replace_vars, 13 remove_dels, 14 rename_labels, 15 find_topo_order, 16 merge_adjacent_blocks, 17 GuardException, 18 require, 19 guard, 20 get_definition, 21 find_callname, 22 find_build_sequence, 23 get_np_ufunc_typ, 24 get_ir_of_code, 25 simplify_CFG, 26 canonicalize_array_math, 27 dead_code_elimination, 28 ) 29 30from numba.core.analysis import ( 31 compute_cfg_from_blocks, 32 compute_use_defs, 33 compute_live_variables) 34from numba.core import postproc 35from numba.cpython.rangeobj import range_iter_len 36from numba.np.unsafe.ndarray import empty_inferred as unsafe_empty_inferred 37import numpy as np 38import operator 39import numba.misc.special 40 41""" 42Variable enable_inline_arraycall is only used for testing purpose. 43""" 44enable_inline_arraycall = True 45 46 47def callee_ir_validator(func_ir): 48 """Checks the IR of a callee is supported for inlining 49 """ 50 for blk in func_ir.blocks.values(): 51 for stmt in blk.find_insts(ir.Assign): 52 if isinstance(stmt.value, ir.Yield): 53 msg = "The use of yield in a closure is unsupported." 54 raise errors.UnsupportedError(msg, loc=stmt.loc) 55 56 57class InlineClosureCallPass(object): 58 """InlineClosureCallPass class looks for direct calls to locally defined 59 closures, and inlines the body of the closure function to the call site. 60 """ 61 62 def __init__(self, func_ir, parallel_options, swapped={}, typed=False): 63 self.func_ir = func_ir 64 self.parallel_options = parallel_options 65 self.swapped = swapped 66 self._processed_stencils = [] 67 self.typed = typed 68 69 def run(self): 70 """Run inline closure call pass. 71 """ 72 # Analysis relies on ir.Del presence, strip out later 73 pp = postproc.PostProcessor(self.func_ir) 74 pp.run(True) 75 76 modified = False 77 work_list = list(self.func_ir.blocks.items()) 78 debug_print = _make_debug_print("InlineClosureCallPass") 79 debug_print("START") 80 while work_list: 81 label, block = work_list.pop() 82 for i, instr in enumerate(block.body): 83 if isinstance(instr, ir.Assign): 84 lhs = instr.target 85 expr = instr.value 86 if isinstance(expr, ir.Expr) and expr.op == 'call': 87 call_name = guard(find_callname, self.func_ir, expr) 88 func_def = guard(get_definition, self.func_ir, expr.func) 89 90 if guard(self._inline_reduction, 91 work_list, block, i, expr, call_name): 92 modified = True 93 break # because block structure changed 94 95 if guard(self._inline_closure, 96 work_list, block, i, func_def): 97 modified = True 98 break # because block structure changed 99 100 if guard(self._inline_stencil, 101 instr, call_name, func_def): 102 modified = True 103 104 if enable_inline_arraycall: 105 # Identify loop structure 106 if modified: 107 # Need to do some cleanups if closure inlining kicked in 108 merge_adjacent_blocks(self.func_ir.blocks) 109 cfg = compute_cfg_from_blocks(self.func_ir.blocks) 110 debug_print("start inline arraycall") 111 _debug_dump(cfg) 112 loops = cfg.loops() 113 sized_loops = [(k, len(loops[k].body)) for k in loops.keys()] 114 visited = [] 115 # We go over all loops, bigger loops first (outer first) 116 for k, s in sorted(sized_loops, key=lambda tup: tup[1], reverse=True): 117 visited.append(k) 118 if guard(_inline_arraycall, self.func_ir, cfg, visited, loops[k], 119 self.swapped, self.parallel_options.comprehension, 120 self.typed): 121 modified = True 122 if modified: 123 _fix_nested_array(self.func_ir) 124 125 if modified: 126 # clean up now dead/unreachable blocks, e.g. unconditionally raising 127 # an exception in an inlined function would render some parts of the 128 # inliner unreachable 129 cfg = compute_cfg_from_blocks(self.func_ir.blocks) 130 for dead in cfg.dead_nodes(): 131 del self.func_ir.blocks[dead] 132 133 # run dead code elimination 134 dead_code_elimination(self.func_ir) 135 # do label renaming 136 self.func_ir.blocks = rename_labels(self.func_ir.blocks) 137 138 # inlining done, strip dels 139 remove_dels(self.func_ir.blocks) 140 141 debug_print("END") 142 143 def _inline_reduction(self, work_list, block, i, expr, call_name): 144 # only inline reduction in sequential execution, parallel handling 145 # is done in ParforPass. 146 require(not self.parallel_options.reduction) 147 require(call_name == ('reduce', 'builtins') or 148 call_name == ('reduce', '_functools')) 149 if len(expr.args) not in (2, 3): 150 raise TypeError("invalid reduce call, " 151 "two arguments are required (optional initial " 152 "value can also be specified)") 153 check_reduce_func(self.func_ir, expr.args[0]) 154 def reduce_func(f, A, v=None): 155 it = iter(A) 156 if v is not None: 157 s = v 158 else: 159 s = next(it) 160 for a in it: 161 s = f(s, a) 162 return s 163 inline_closure_call(self.func_ir, 164 self.func_ir.func_id.func.__globals__, 165 block, i, reduce_func, work_list=work_list, 166 callee_validator=callee_ir_validator) 167 return True 168 169 def _inline_stencil(self, instr, call_name, func_def): 170 from numba.stencils.stencil import StencilFunc 171 lhs = instr.target 172 expr = instr.value 173 # We keep the escaping variables of the stencil kernel 174 # alive by adding them to the actual kernel call as extra 175 # keyword arguments, which is ignored anyway. 176 if (isinstance(func_def, ir.Global) and 177 func_def.name == 'stencil' and 178 isinstance(func_def.value, StencilFunc)): 179 if expr.kws: 180 expr.kws += func_def.value.kws 181 else: 182 expr.kws = func_def.value.kws 183 return True 184 # Otherwise we proceed to check if it is a call to numba.stencil 185 require(call_name == ('stencil', 'numba.stencils.stencil') or 186 call_name == ('stencil', 'numba')) 187 require(expr not in self._processed_stencils) 188 self._processed_stencils.append(expr) 189 if not len(expr.args) == 1: 190 raise ValueError("As a minimum Stencil requires" 191 " a kernel as an argument") 192 stencil_def = guard(get_definition, self.func_ir, expr.args[0]) 193 require(isinstance(stencil_def, ir.Expr) and 194 stencil_def.op == "make_function") 195 kernel_ir = get_ir_of_code(self.func_ir.func_id.func.__globals__, 196 stencil_def.code) 197 options = dict(expr.kws) 198 if 'neighborhood' in options: 199 fixed = guard(self._fix_stencil_neighborhood, options) 200 if not fixed: 201 raise ValueError("stencil neighborhood option should be a tuple" 202 " with constant structure such as ((-w, w),)") 203 if 'index_offsets' in options: 204 fixed = guard(self._fix_stencil_index_offsets, options) 205 if not fixed: 206 raise ValueError("stencil index_offsets option should be a tuple" 207 " with constant structure such as (offset, )") 208 sf = StencilFunc(kernel_ir, 'constant', options) 209 sf.kws = expr.kws # hack to keep variables live 210 sf_global = ir.Global('stencil', sf, expr.loc) 211 self.func_ir._definitions[lhs.name] = [sf_global] 212 instr.value = sf_global 213 return True 214 215 def _fix_stencil_neighborhood(self, options): 216 """ 217 Extract the two-level tuple representing the stencil neighborhood 218 from the program IR to provide a tuple to StencilFunc. 219 """ 220 # build_tuple node with neighborhood for each dimension 221 dims_build_tuple = get_definition(self.func_ir, options['neighborhood']) 222 require(hasattr(dims_build_tuple, 'items')) 223 res = [] 224 for window_var in dims_build_tuple.items: 225 win_build_tuple = get_definition(self.func_ir, window_var) 226 require(hasattr(win_build_tuple, 'items')) 227 res.append(tuple(win_build_tuple.items)) 228 options['neighborhood'] = tuple(res) 229 return True 230 231 def _fix_stencil_index_offsets(self, options): 232 """ 233 Extract the tuple representing the stencil index offsets 234 from the program IR to provide to StencilFunc. 235 """ 236 offset_tuple = get_definition(self.func_ir, options['index_offsets']) 237 require(hasattr(offset_tuple, 'items')) 238 options['index_offsets'] = tuple(offset_tuple.items) 239 return True 240 241 def _inline_closure(self, work_list, block, i, func_def): 242 require(isinstance(func_def, ir.Expr) and 243 func_def.op == "make_function") 244 inline_closure_call(self.func_ir, 245 self.func_ir.func_id.func.__globals__, 246 block, i, func_def, work_list=work_list, 247 callee_validator=callee_ir_validator) 248 return True 249 250def check_reduce_func(func_ir, func_var): 251 """Checks the function at func_var in func_ir to make sure it's amenable 252 for inlining. Returns the function itself""" 253 reduce_func = guard(get_definition, func_ir, func_var) 254 if reduce_func is None: 255 raise ValueError("Reduce function cannot be found for njit \ 256 analysis") 257 if isinstance(reduce_func, (ir.FreeVar, ir.Global)): 258 if not isinstance(reduce_func.value, 259 numba.core.registry.CPUDispatcher): 260 raise ValueError("Invalid reduction function") 261 # pull out the python function for inlining 262 reduce_func = reduce_func.value.py_func 263 elif not (hasattr(reduce_func, 'code') 264 or hasattr(reduce_func, '__code__')): 265 raise ValueError("Invalid reduction function") 266 f_code = (reduce_func.code if hasattr(reduce_func, 'code') 267 else reduce_func.__code__) 268 if not f_code.co_argcount == 2: 269 raise TypeError("Reduction function should take 2 arguments") 270 return reduce_func 271 272 273class InlineWorker(object): 274 """ A worker class for inlining, this is a more advanced version of 275 `inline_closure_call` in that it permits inlining from function type, Numba 276 IR and code object. It also, runs the entire untyped compiler pipeline on 277 the inlinee to ensure that it is transformed as though it were compiled 278 directly. 279 """ 280 281 def __init__(self, 282 typingctx=None, 283 targetctx=None, 284 locals=None, 285 pipeline=None, 286 flags=None, 287 validator=callee_ir_validator, 288 typemap=None, 289 calltypes=None): 290 """ 291 Instantiate a new InlineWorker, all arguments are optional though some 292 must be supplied together for certain use cases. The methods will refuse 293 to run if the object isn't configured in the manner needed. Args are the 294 same as those in a numba.core.Compiler.state, except the validator which 295 is a function taking Numba IR and validating it for use when inlining 296 (this is optional and really to just provide better error messages about 297 things which the inliner cannot handle like yield in closure). 298 """ 299 def check(arg, name): 300 if arg is None: 301 raise TypeError("{} must not be None".format(name)) 302 303 from numba.core.compiler import DefaultPassBuilder 304 305 # check the stuff needed to run the more advanced compilation pipeline 306 # is valid if any of it is provided 307 compiler_args = (targetctx, locals, pipeline, flags) 308 compiler_group = [x is not None for x in compiler_args] 309 if any(compiler_group) and not all(compiler_group): 310 check(targetctx, 'targetctx') 311 check(locals, 'locals') 312 check(pipeline, 'pipeline') 313 check(flags, 'flags') 314 elif all(compiler_group): 315 check(typingctx, 'typingctx') 316 317 self._compiler_pipeline = DefaultPassBuilder.define_untyped_pipeline 318 319 self.typingctx = typingctx 320 self.targetctx = targetctx 321 self.locals = locals 322 self.pipeline = pipeline 323 self.flags = flags 324 self.validator = validator 325 self.debug_print = _make_debug_print("InlineWorker") 326 327 # check whether this inliner can also support typemap and calltypes 328 # update and if what's provided is valid 329 pair = (typemap, calltypes) 330 pair_is_none = [x is None for x in pair] 331 if any(pair_is_none) and not all(pair_is_none): 332 msg = ("typemap and calltypes must both be either None or have a " 333 "value, got: %s, %s") 334 raise TypeError(msg % pair) 335 self._permit_update_type_and_call_maps = not all(pair_is_none) 336 self.typemap = typemap 337 self.calltypes = calltypes 338 339 340 def inline_ir(self, caller_ir, block, i, callee_ir, callee_freevars, 341 arg_typs=None): 342 """ Inlines the callee_ir in the caller_ir at statement index i of block 343 `block`, callee_freevars are the free variables for the callee_ir. If 344 the callee_ir is derived from a function `func` then this is 345 `func.__code__.co_freevars`. If `arg_typs` is given and the InlineWorker 346 instance was initialized with a typemap and calltypes then they will be 347 appropriately updated based on the arg_typs. 348 """ 349 350 # Always copy the callee IR, it gets mutated 351 def copy_ir(the_ir): 352 kernel_copy = the_ir.copy() 353 kernel_copy.blocks = {} 354 for block_label, block in the_ir.blocks.items(): 355 new_block = copy.deepcopy(the_ir.blocks[block_label]) 356 new_block.body = [] 357 for stmt in the_ir.blocks[block_label].body: 358 scopy = copy.deepcopy(stmt) 359 new_block.body.append(scopy) 360 kernel_copy.blocks[block_label] = new_block 361 return kernel_copy 362 363 callee_ir = copy_ir(callee_ir) 364 365 # check that the contents of the callee IR is something that can be 366 # inlined if a validator is present 367 if self.validator is not None: 368 self.validator(callee_ir) 369 370 # save an unmutated copy of the callee_ir to return 371 callee_ir_original = callee_ir.copy() 372 scope = block.scope 373 instr = block.body[i] 374 call_expr = instr.value 375 callee_blocks = callee_ir.blocks 376 377 # 1. relabel callee_ir by adding an offset 378 max_label = max(ir_utils._max_label, max(caller_ir.blocks.keys())) 379 callee_blocks = add_offset_to_labels(callee_blocks, max_label + 1) 380 callee_blocks = simplify_CFG(callee_blocks) 381 callee_ir.blocks = callee_blocks 382 min_label = min(callee_blocks.keys()) 383 max_label = max(callee_blocks.keys()) 384 # reset globals in ir_utils before we use it 385 ir_utils._max_label = max_label 386 self.debug_print("After relabel") 387 _debug_dump(callee_ir) 388 389 # 2. rename all local variables in callee_ir with new locals created in 390 # caller_ir 391 callee_scopes = _get_all_scopes(callee_blocks) 392 self.debug_print("callee_scopes = ", callee_scopes) 393 # one function should only have one local scope 394 assert(len(callee_scopes) == 1) 395 callee_scope = callee_scopes[0] 396 var_dict = {} 397 for var in callee_scope.localvars._con.values(): 398 if not (var.name in callee_freevars): 399 new_var = scope.redefine(mk_unique_var(var.name), loc=var.loc) 400 var_dict[var.name] = new_var 401 self.debug_print("var_dict = ", var_dict) 402 replace_vars(callee_blocks, var_dict) 403 self.debug_print("After local var rename") 404 _debug_dump(callee_ir) 405 406 # 3. replace formal parameters with actual arguments 407 callee_func = callee_ir.func_id.func 408 args = _get_callee_args(call_expr, callee_func, block.body[i].loc, 409 caller_ir) 410 411 # 4. Update typemap 412 if self._permit_update_type_and_call_maps: 413 if arg_typs is None: 414 raise TypeError('arg_typs should have a value not None') 415 self.update_type_and_call_maps(callee_ir, arg_typs) 416 417 self.debug_print("After arguments rename: ") 418 _debug_dump(callee_ir) 419 420 _replace_args_with(callee_blocks, args) 421 # 5. split caller blocks into two 422 new_blocks = [] 423 new_block = ir.Block(scope, block.loc) 424 new_block.body = block.body[i + 1:] 425 new_label = next_label() 426 caller_ir.blocks[new_label] = new_block 427 new_blocks.append((new_label, new_block)) 428 block.body = block.body[:i] 429 block.body.append(ir.Jump(min_label, instr.loc)) 430 431 # 6. replace Return with assignment to LHS 432 topo_order = find_topo_order(callee_blocks) 433 _replace_returns(callee_blocks, instr.target, new_label) 434 435 # remove the old definition of instr.target too 436 if (instr.target.name in caller_ir._definitions 437 and call_expr in caller_ir._definitions[instr.target.name]): 438 # NOTE: target can have multiple definitions due to control flow 439 caller_ir._definitions[instr.target.name].remove(call_expr) 440 441 # 7. insert all new blocks, and add back definitions 442 for label in topo_order: 443 # block scope must point to parent's 444 block = callee_blocks[label] 445 block.scope = scope 446 _add_definitions(caller_ir, block) 447 caller_ir.blocks[label] = block 448 new_blocks.append((label, block)) 449 self.debug_print("After merge in") 450 _debug_dump(caller_ir) 451 452 return callee_ir_original, callee_blocks, var_dict, new_blocks 453 454 def inline_function(self, caller_ir, block, i, function, arg_typs=None): 455 """ Inlines the function in the caller_ir at statement index i of block 456 `block`. If `arg_typs` is given and the InlineWorker instance was 457 initialized with a typemap and calltypes then they will be appropriately 458 updated based on the arg_typs. 459 """ 460 callee_ir = self.run_untyped_passes(function) 461 freevars = function.__code__.co_freevars 462 return self.inline_ir(caller_ir, block, i, callee_ir, freevars, 463 arg_typs=arg_typs) 464 465 def run_untyped_passes(self, func): 466 """ 467 Run the compiler frontend's untyped passes over the given Python 468 function, and return the function's canonical Numba IR. 469 """ 470 from numba.core.compiler import StateDict, _CompileStatus 471 from numba.core.untyped_passes import ExtractByteCode, WithLifting 472 from numba.core import bytecode 473 from numba.parfors.parfor import ParforDiagnostics 474 state = StateDict() 475 state.func_ir = None 476 state.typingctx = self.typingctx 477 state.targetctx = self.targetctx 478 state.locals = self.locals 479 state.pipeline = self.pipeline 480 state.flags = self.flags 481 482 # Disable SSA transformation, the call site won't be in SSA form and 483 # self.inline_ir depends on this being the case. 484 state.flags.enable_ssa = False 485 486 state.func_id = bytecode.FunctionIdentity.from_function(func) 487 488 state.typemap = None 489 state.calltypes = None 490 state.type_annotation = None 491 state.status = _CompileStatus(False, False) 492 state.return_type = None 493 state.parfor_diagnostics = ParforDiagnostics() 494 state.metadata = {} 495 496 ExtractByteCode().run_pass(state) 497 # This is a lie, just need *some* args for the case where an obj mode 498 # with lift is needed 499 state.args = len(state.bc.func_id.pysig.parameters) * (types.pyobject,) 500 501 pm = self._compiler_pipeline(state) 502 503 pm.finalize() 504 pm.run(state) 505 return state.func_ir 506 507 def update_type_and_call_maps(self, callee_ir, arg_typs): 508 """ Updates the type and call maps based on calling callee_ir with arguments 509 from arg_typs""" 510 if not self._permit_update_type_and_call_maps: 511 msg = ("InlineWorker instance not configured correctly, typemap or " 512 "calltypes missing in initialization.") 513 raise ValueError(msg) 514 from numba.core import typed_passes 515 # call branch pruning to simplify IR and avoid inference errors 516 callee_ir._definitions = ir_utils.build_definitions(callee_ir.blocks) 517 numba.core.analysis.dead_branch_prune(callee_ir, arg_typs) 518 f_typemap, f_return_type, f_calltypes = typed_passes.type_inference_stage( 519 self.typingctx, callee_ir, arg_typs, None) 520 canonicalize_array_math(callee_ir, f_typemap, 521 f_calltypes, self.typingctx) 522 # remove argument entries like arg.a from typemap 523 arg_names = [vname for vname in f_typemap if vname.startswith("arg.")] 524 for a in arg_names: 525 f_typemap.pop(a) 526 self.typemap.update(f_typemap) 527 self.calltypes.update(f_calltypes) 528 529 530def inline_closure_call(func_ir, glbls, block, i, callee, typingctx=None, 531 arg_typs=None, typemap=None, calltypes=None, 532 work_list=None, callee_validator=None, 533 replace_freevars=True): 534 """Inline the body of `callee` at its callsite (`i`-th instruction of `block`) 535 536 `func_ir` is the func_ir object of the caller function and `glbls` is its 537 global variable environment (func_ir.func_id.func.__globals__). 538 `block` is the IR block of the callsite and `i` is the index of the 539 callsite's node. `callee` is either the called function or a 540 make_function node. `typingctx`, `typemap` and `calltypes` are typing 541 data structures of the caller, available if we are in a typed pass. 542 `arg_typs` includes the types of the arguments at the callsite. 543 `callee_validator` is an optional callable which can be used to validate the 544 IR of the callee to ensure that it contains IR supported for inlining, it 545 takes one argument, the func_ir of the callee 546 547 Returns IR blocks of the callee and the variable renaming dictionary used 548 for them to facilitate further processing of new blocks. 549 """ 550 scope = block.scope 551 instr = block.body[i] 552 call_expr = instr.value 553 debug_print = _make_debug_print("inline_closure_call") 554 debug_print("Found closure call: ", instr, " with callee = ", callee) 555 # support both function object and make_function Expr 556 callee_code = callee.code if hasattr(callee, 'code') else callee.__code__ 557 callee_closure = callee.closure if hasattr(callee, 'closure') else callee.__closure__ 558 # first, get the IR of the callee 559 if isinstance(callee, pytypes.FunctionType): 560 from numba.core import compiler 561 callee_ir = compiler.run_frontend(callee, inline_closures=True) 562 else: 563 callee_ir = get_ir_of_code(glbls, callee_code) 564 565 # check that the contents of the callee IR is something that can be inlined 566 # if a validator is supplied 567 if callee_validator is not None: 568 callee_validator(callee_ir) 569 570 callee_blocks = callee_ir.blocks 571 572 # 1. relabel callee_ir by adding an offset 573 max_label = max(ir_utils._max_label, max(func_ir.blocks.keys())) 574 callee_blocks = add_offset_to_labels(callee_blocks, max_label + 1) 575 callee_blocks = simplify_CFG(callee_blocks) 576 callee_ir.blocks = callee_blocks 577 min_label = min(callee_blocks.keys()) 578 max_label = max(callee_blocks.keys()) 579 # reset globals in ir_utils before we use it 580 ir_utils._max_label = max_label 581 debug_print("After relabel") 582 _debug_dump(callee_ir) 583 584 # 2. rename all local variables in callee_ir with new locals created in func_ir 585 callee_scopes = _get_all_scopes(callee_blocks) 586 debug_print("callee_scopes = ", callee_scopes) 587 # one function should only have one local scope 588 assert(len(callee_scopes) == 1) 589 callee_scope = callee_scopes[0] 590 var_dict = {} 591 for var in callee_scope.localvars._con.values(): 592 if not (var.name in callee_code.co_freevars): 593 new_var = scope.redefine(mk_unique_var(var.name), loc=var.loc) 594 var_dict[var.name] = new_var 595 debug_print("var_dict = ", var_dict) 596 replace_vars(callee_blocks, var_dict) 597 debug_print("After local var rename") 598 _debug_dump(callee_ir) 599 600 # 3. replace formal parameters with actual arguments 601 args = _get_callee_args(call_expr, callee, block.body[i].loc, func_ir) 602 603 debug_print("After arguments rename: ") 604 _debug_dump(callee_ir) 605 606 # 4. replace freevar with actual closure var 607 if callee_closure and replace_freevars: 608 closure = func_ir.get_definition(callee_closure) 609 debug_print("callee's closure = ", closure) 610 if isinstance(closure, tuple): 611 cellget = ctypes.pythonapi.PyCell_Get 612 cellget.restype = ctypes.py_object 613 cellget.argtypes = (ctypes.py_object,) 614 items = tuple(cellget(x) for x in closure) 615 else: 616 assert(isinstance(closure, ir.Expr) 617 and closure.op == 'build_tuple') 618 items = closure.items 619 assert(len(callee_code.co_freevars) == len(items)) 620 _replace_freevars(callee_blocks, items) 621 debug_print("After closure rename") 622 _debug_dump(callee_ir) 623 624 if typingctx: 625 from numba.core import typed_passes 626 # call branch pruning to simplify IR and avoid inference errors 627 callee_ir._definitions = ir_utils.build_definitions(callee_ir.blocks) 628 numba.core.analysis.dead_branch_prune(callee_ir, arg_typs) 629 try: 630 f_typemap, f_return_type, f_calltypes = typed_passes.type_inference_stage( 631 typingctx, callee_ir, arg_typs, None) 632 except Exception: 633 f_typemap, f_return_type, f_calltypes = typed_passes.type_inference_stage( 634 typingctx, callee_ir, arg_typs, None) 635 pass 636 canonicalize_array_math(callee_ir, f_typemap, 637 f_calltypes, typingctx) 638 # remove argument entries like arg.a from typemap 639 arg_names = [vname for vname in f_typemap if vname.startswith("arg.")] 640 for a in arg_names: 641 f_typemap.pop(a) 642 typemap.update(f_typemap) 643 calltypes.update(f_calltypes) 644 645 _replace_args_with(callee_blocks, args) 646 # 5. split caller blocks into two 647 new_blocks = [] 648 new_block = ir.Block(scope, block.loc) 649 new_block.body = block.body[i + 1:] 650 new_label = next_label() 651 func_ir.blocks[new_label] = new_block 652 new_blocks.append((new_label, new_block)) 653 block.body = block.body[:i] 654 block.body.append(ir.Jump(min_label, instr.loc)) 655 656 # 6. replace Return with assignment to LHS 657 topo_order = find_topo_order(callee_blocks) 658 _replace_returns(callee_blocks, instr.target, new_label) 659 660 # remove the old definition of instr.target too 661 if (instr.target.name in func_ir._definitions 662 and call_expr in func_ir._definitions[instr.target.name]): 663 # NOTE: target can have multiple definitions due to control flow 664 func_ir._definitions[instr.target.name].remove(call_expr) 665 666 # 7. insert all new blocks, and add back definitions 667 for label in topo_order: 668 # block scope must point to parent's 669 block = callee_blocks[label] 670 block.scope = scope 671 _add_definitions(func_ir, block) 672 func_ir.blocks[label] = block 673 new_blocks.append((label, block)) 674 debug_print("After merge in") 675 _debug_dump(func_ir) 676 677 if work_list is not None: 678 for block in new_blocks: 679 work_list.append(block) 680 return callee_blocks, var_dict 681 682 683def _get_callee_args(call_expr, callee, loc, func_ir): 684 """Get arguments for calling 'callee', including the default arguments. 685 keyword arguments are currently only handled when 'callee' is a function. 686 """ 687 if call_expr.op == 'call': 688 args = list(call_expr.args) 689 elif call_expr.op == 'getattr': 690 args = [call_expr.value] 691 else: 692 raise TypeError("Unsupported ir.Expr.{}".format(call_expr.op)) 693 694 debug_print = _make_debug_print("inline_closure_call default handling") 695 696 # handle defaults and kw arguments using pysignature if callee is function 697 if isinstance(callee, pytypes.FunctionType): 698 pysig = numba.core.utils.pysignature(callee) 699 normal_handler = lambda index, param, default: default 700 default_handler = lambda index, param, default: ir.Const(default, loc) 701 # Throw error for stararg 702 # TODO: handle stararg 703 def stararg_handler(index, param, default): 704 raise NotImplementedError( 705 "Stararg not supported in inliner for arg {} {}".format( 706 index, param)) 707 if call_expr.op == 'call': 708 kws = dict(call_expr.kws) 709 else: 710 kws = {} 711 return numba.core.typing.fold_arguments( 712 pysig, args, kws, normal_handler, default_handler, 713 stararg_handler) 714 else: 715 # TODO: handle arguments for make_function case similar to function 716 # case above 717 callee_defaults = (callee.defaults if hasattr(callee, 'defaults') 718 else callee.__defaults__) 719 if callee_defaults: 720 debug_print("defaults = ", callee_defaults) 721 if isinstance(callee_defaults, tuple): # Python 3.5 722 defaults_list = [] 723 for x in callee_defaults: 724 if isinstance(x, ir.Var): 725 defaults_list.append(x) 726 else: 727 # this branch is predominantly for kwargs from 728 # inlinable functions 729 defaults_list.append(ir.Const(value=x, loc=loc)) 730 args = args + defaults_list 731 elif (isinstance(callee_defaults, ir.Var) 732 or isinstance(callee_defaults, str)): 733 default_tuple = func_ir.get_definition(callee_defaults) 734 assert(isinstance(default_tuple, ir.Expr)) 735 assert(default_tuple.op == "build_tuple") 736 const_vals = [func_ir.get_definition(x) for 737 x in default_tuple.items] 738 args = args + const_vals 739 else: 740 raise NotImplementedError( 741 "Unsupported defaults to make_function: {}".format( 742 defaults)) 743 return args 744 745 746def _make_debug_print(prefix): 747 def debug_print(*args): 748 if config.DEBUG_INLINE_CLOSURE: 749 print(prefix + ": " + "".join(str(x) for x in args)) 750 return debug_print 751 752 753def _debug_dump(func_ir): 754 if config.DEBUG_INLINE_CLOSURE: 755 func_ir.dump() 756 757 758def _get_all_scopes(blocks): 759 """Get all block-local scopes from an IR. 760 """ 761 all_scopes = [] 762 for label, block in blocks.items(): 763 if not (block.scope in all_scopes): 764 all_scopes.append(block.scope) 765 return all_scopes 766 767 768def _replace_args_with(blocks, args): 769 """ 770 Replace ir.Arg(...) with real arguments from call site 771 """ 772 for label, block in blocks.items(): 773 assigns = block.find_insts(ir.Assign) 774 for stmt in assigns: 775 if isinstance(stmt.value, ir.Arg): 776 idx = stmt.value.index 777 assert(idx < len(args)) 778 stmt.value = args[idx] 779 780 781def _replace_freevars(blocks, args): 782 """ 783 Replace ir.FreeVar(...) with real variables from parent function 784 """ 785 for label, block in blocks.items(): 786 assigns = block.find_insts(ir.Assign) 787 for stmt in assigns: 788 if isinstance(stmt.value, ir.FreeVar): 789 idx = stmt.value.index 790 assert(idx < len(args)) 791 if isinstance(args[idx], ir.Var): 792 stmt.value = args[idx] 793 else: 794 stmt.value = ir.Const(args[idx], stmt.loc) 795 796 797def _replace_returns(blocks, target, return_label): 798 """ 799 Return return statement by assigning directly to target, and a jump. 800 """ 801 for label, block in blocks.items(): 802 casts = [] 803 for i in range(len(block.body)): 804 stmt = block.body[i] 805 if isinstance(stmt, ir.Return): 806 assert(i + 1 == len(block.body)) 807 block.body[i] = ir.Assign(stmt.value, target, stmt.loc) 808 block.body.append(ir.Jump(return_label, stmt.loc)) 809 # remove cast of the returned value 810 for cast in casts: 811 if cast.target.name == stmt.value.name: 812 cast.value = cast.value.value 813 elif isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr) and stmt.value.op == 'cast': 814 casts.append(stmt) 815 816def _add_definitions(func_ir, block): 817 """ 818 Add variable definitions found in a block to parent func_ir. 819 """ 820 definitions = func_ir._definitions 821 assigns = block.find_insts(ir.Assign) 822 for stmt in assigns: 823 definitions[stmt.target.name].append(stmt.value) 824 825def _find_arraycall(func_ir, block): 826 """Look for statement like "x = numpy.array(y)" or "x[..] = y" 827 immediately after the closure call that creates list y (the i-th 828 statement in block). Return the statement index if found, or 829 raise GuardException. 830 """ 831 array_var = None 832 array_call_index = None 833 list_var_dead_after_array_call = False 834 list_var = None 835 836 i = 0 837 while i < len(block.body): 838 instr = block.body[i] 839 if isinstance(instr, ir.Del): 840 # Stop the process if list_var becomes dead 841 if list_var and array_var and instr.value == list_var.name: 842 list_var_dead_after_array_call = True 843 break 844 pass 845 elif isinstance(instr, ir.Assign): 846 # Found array_var = array(list_var) 847 lhs = instr.target 848 expr = instr.value 849 if (guard(find_callname, func_ir, expr) == ('array', 'numpy') and 850 isinstance(expr.args[0], ir.Var)): 851 list_var = expr.args[0] 852 array_var = lhs 853 array_stmt_index = i 854 array_kws = dict(expr.kws) 855 elif (isinstance(instr, ir.SetItem) and 856 isinstance(instr.value, ir.Var) and 857 not list_var): 858 list_var = instr.value 859 # Found array_var[..] = list_var, the case for nested array 860 array_var = instr.target 861 array_def = get_definition(func_ir, array_var) 862 require(guard(_find_unsafe_empty_inferred, func_ir, array_def)) 863 array_stmt_index = i 864 array_kws = {} 865 else: 866 # Bail out otherwise 867 break 868 i = i + 1 869 # require array_var is found, and list_var is dead after array_call. 870 require(array_var and list_var_dead_after_array_call) 871 _make_debug_print("find_array_call")(block.body[array_stmt_index]) 872 return list_var, array_stmt_index, array_kws 873 874 875def _find_iter_range(func_ir, range_iter_var, swapped): 876 """Find the iterator's actual range if it is either range(n), or range(m, n), 877 otherwise return raise GuardException. 878 """ 879 debug_print = _make_debug_print("find_iter_range") 880 range_iter_def = get_definition(func_ir, range_iter_var) 881 debug_print("range_iter_var = ", range_iter_var, " def = ", range_iter_def) 882 require(isinstance(range_iter_def, ir.Expr) and range_iter_def.op == 'getiter') 883 range_var = range_iter_def.value 884 range_def = get_definition(func_ir, range_var) 885 debug_print("range_var = ", range_var, " range_def = ", range_def) 886 require(isinstance(range_def, ir.Expr) and range_def.op == 'call') 887 func_var = range_def.func 888 func_def = get_definition(func_ir, func_var) 889 debug_print("func_var = ", func_var, " func_def = ", func_def) 890 require(isinstance(func_def, ir.Global) and 891 (func_def.value == range or func_def.value == numba.misc.special.prange)) 892 nargs = len(range_def.args) 893 swapping = [('"array comprehension"', 'closure of'), range_def.func.loc] 894 if nargs == 1: 895 swapped[range_def.func.name] = swapping 896 stop = get_definition(func_ir, range_def.args[0], lhs_only=True) 897 return (0, range_def.args[0], func_def) 898 elif nargs == 2: 899 swapped[range_def.func.name] = swapping 900 start = get_definition(func_ir, range_def.args[0], lhs_only=True) 901 stop = get_definition(func_ir, range_def.args[1], lhs_only=True) 902 return (start, stop, func_def) 903 else: 904 raise GuardException 905 906def _inline_arraycall(func_ir, cfg, visited, loop, swapped, enable_prange=False, 907 typed=False): 908 """Look for array(list) call in the exit block of a given loop, and turn list operations into 909 array operations in the loop if the following conditions are met: 910 1. The exit block contains an array call on the list; 911 2. The list variable is no longer live after array call; 912 3. The list is created in the loop entry block; 913 4. The loop is created from an range iterator whose length is known prior to the loop; 914 5. There is only one list_append operation on the list variable in the loop body; 915 6. The block that contains list_append dominates the loop head, which ensures list 916 length is the same as loop length; 917 If any condition check fails, no modification will be made to the incoming IR. 918 """ 919 debug_print = _make_debug_print("inline_arraycall") 920 # There should only be one loop exit 921 require(len(loop.exits) == 1) 922 exit_block = next(iter(loop.exits)) 923 list_var, array_call_index, array_kws = _find_arraycall(func_ir, func_ir.blocks[exit_block]) 924 925 # check if dtype is present in array call 926 dtype_def = None 927 dtype_mod_def = None 928 if 'dtype' in array_kws: 929 require(isinstance(array_kws['dtype'], ir.Var)) 930 # We require that dtype argument to be a constant of getattr Expr, and we'll 931 # remember its definition for later use. 932 dtype_def = get_definition(func_ir, array_kws['dtype']) 933 require(isinstance(dtype_def, ir.Expr) and dtype_def.op == 'getattr') 934 dtype_mod_def = get_definition(func_ir, dtype_def.value) 935 936 list_var_def = get_definition(func_ir, list_var) 937 debug_print("list_var = ", list_var, " def = ", list_var_def) 938 if isinstance(list_var_def, ir.Expr) and list_var_def.op == 'cast': 939 list_var_def = get_definition(func_ir, list_var_def.value) 940 # Check if the definition is a build_list 941 require(isinstance(list_var_def, ir.Expr) and list_var_def.op == 'build_list') 942 # The build_list must be empty 943 require(len(list_var_def.items) == 0) 944 945 # Look for list_append in "last" block in loop body, which should be a block that is 946 # a post-dominator of the loop header. 947 list_append_stmts = [] 948 for label in loop.body: 949 # We have to consider blocks of this loop, but not sub-loops. 950 # To achieve this, we require the set of "in_loops" of "label" to be visited loops. 951 in_visited_loops = [l.header in visited for l in cfg.in_loops(label)] 952 if not all(in_visited_loops): 953 continue 954 block = func_ir.blocks[label] 955 debug_print("check loop body block ", label) 956 for stmt in block.find_insts(ir.Assign): 957 lhs = stmt.target 958 expr = stmt.value 959 if isinstance(expr, ir.Expr) and expr.op == 'call': 960 func_def = get_definition(func_ir, expr.func) 961 if isinstance(func_def, ir.Expr) and func_def.op == 'getattr' \ 962 and func_def.attr == 'append': 963 list_def = get_definition(func_ir, func_def.value) 964 debug_print("list_def = ", list_def, list_def is list_var_def) 965 if list_def is list_var_def: 966 # found matching append call 967 list_append_stmts.append((label, block, stmt)) 968 969 # Require only one list_append, otherwise we won't know the indices 970 require(len(list_append_stmts) == 1) 971 append_block_label, append_block, append_stmt = list_append_stmts[0] 972 973 # Check if append_block (besides loop entry) dominates loop header. 974 # Since CFG doesn't give us this info without loop entry, we approximate 975 # by checking if the predecessor set of the header block is the same 976 # as loop_entries plus append_block, which is certainly more restrictive 977 # than necessary, and can be relaxed if needed. 978 preds = set(l for l, b in cfg.predecessors(loop.header)) 979 debug_print("preds = ", preds, (loop.entries | set([append_block_label]))) 980 require(preds == (loop.entries | set([append_block_label]))) 981 982 # Find iterator in loop header 983 iter_vars = [] 984 iter_first_vars = [] 985 loop_header = func_ir.blocks[loop.header] 986 for stmt in loop_header.find_insts(ir.Assign): 987 expr = stmt.value 988 if isinstance(expr, ir.Expr): 989 if expr.op == 'iternext': 990 iter_def = get_definition(func_ir, expr.value) 991 debug_print("iter_def = ", iter_def) 992 iter_vars.append(expr.value) 993 elif expr.op == 'pair_first': 994 iter_first_vars.append(stmt.target) 995 996 # Require only one iterator in loop header 997 require(len(iter_vars) == 1 and len(iter_first_vars) == 1) 998 iter_var = iter_vars[0] # variable that holds the iterator object 999 iter_first_var = iter_first_vars[0] # variable that holds the value out of iterator 1000 1001 # Final requirement: only one loop entry, and we're going to modify it by: 1002 # 1. replacing the list definition with an array definition; 1003 # 2. adding a counter for the array iteration. 1004 require(len(loop.entries) == 1) 1005 loop_entry = func_ir.blocks[next(iter(loop.entries))] 1006 terminator = loop_entry.terminator 1007 scope = loop_entry.scope 1008 loc = loop_entry.loc 1009 stmts = [] 1010 removed = [] 1011 def is_removed(val, removed): 1012 if isinstance(val, ir.Var): 1013 for x in removed: 1014 if x.name == val.name: 1015 return True 1016 return False 1017 # Skip list construction and skip terminator, add the rest to stmts 1018 for i in range(len(loop_entry.body) - 1): 1019 stmt = loop_entry.body[i] 1020 if isinstance(stmt, ir.Assign) and (stmt.value is list_def or is_removed(stmt.value, removed)): 1021 removed.append(stmt.target) 1022 else: 1023 stmts.append(stmt) 1024 debug_print("removed variables: ", removed) 1025 1026 # Define an index_var to index the array. 1027 # If the range happens to be single step ranges like range(n), or range(m, n), 1028 # then the index_var correlates to iterator index; otherwise we'll have to 1029 # define a new counter. 1030 range_def = guard(_find_iter_range, func_ir, iter_var, swapped) 1031 index_var = ir.Var(scope, mk_unique_var("index"), loc) 1032 if range_def and range_def[0] == 0: 1033 # iterator starts with 0, index_var can just be iter_first_var 1034 index_var = iter_first_var 1035 else: 1036 # index_var = -1 # starting the index with -1 since it will incremented in loop header 1037 stmts.append(_new_definition(func_ir, index_var, ir.Const(value=-1, loc=loc), loc)) 1038 1039 # Insert statement to get the size of the loop iterator 1040 size_var = ir.Var(scope, mk_unique_var("size"), loc) 1041 if range_def: 1042 start, stop, range_func_def = range_def 1043 if start == 0: 1044 size_val = stop 1045 else: 1046 size_val = ir.Expr.binop(fn=operator.sub, lhs=stop, rhs=start, loc=loc) 1047 # we can parallelize this loop if enable_prange = True, by changing 1048 # range function from range, to prange. 1049 if enable_prange and isinstance(range_func_def, ir.Global): 1050 range_func_def.name = 'internal_prange' 1051 range_func_def.value = internal_prange 1052 1053 else: 1054 # this doesn't work in objmode as it's effectively untyped 1055 if typed: 1056 len_func_var = ir.Var(scope, mk_unique_var("len_func"), loc) 1057 stmts.append(_new_definition(func_ir, len_func_var, 1058 ir.Global('range_iter_len', range_iter_len, loc=loc), 1059 loc)) 1060 size_val = ir.Expr.call(len_func_var, (iter_var,), (), loc=loc) 1061 else: 1062 raise GuardException 1063 1064 1065 stmts.append(_new_definition(func_ir, size_var, size_val, loc)) 1066 1067 size_tuple_var = ir.Var(scope, mk_unique_var("size_tuple"), loc) 1068 stmts.append(_new_definition(func_ir, size_tuple_var, 1069 ir.Expr.build_tuple(items=[size_var], loc=loc), loc)) 1070 1071 # Insert array allocation 1072 array_var = ir.Var(scope, mk_unique_var("array"), loc) 1073 empty_func = ir.Var(scope, mk_unique_var("empty_func"), loc) 1074 if dtype_def and dtype_mod_def: 1075 # when dtype is present, we'll call empty with dtype 1076 dtype_mod_var = ir.Var(scope, mk_unique_var("dtype_mod"), loc) 1077 dtype_var = ir.Var(scope, mk_unique_var("dtype"), loc) 1078 stmts.append(_new_definition(func_ir, dtype_mod_var, dtype_mod_def, loc)) 1079 stmts.append(_new_definition(func_ir, dtype_var, 1080 ir.Expr.getattr(dtype_mod_var, dtype_def.attr, loc), loc)) 1081 stmts.append(_new_definition(func_ir, empty_func, 1082 ir.Global('empty', np.empty, loc=loc), loc)) 1083 array_kws = [('dtype', dtype_var)] 1084 else: 1085 # this doesn't work in objmode as it's effectively untyped 1086 if typed: 1087 # otherwise we'll call unsafe_empty_inferred 1088 stmts.append(_new_definition(func_ir, empty_func, 1089 ir.Global('unsafe_empty_inferred', 1090 unsafe_empty_inferred, loc=loc), loc)) 1091 array_kws = [] 1092 else: 1093 raise GuardException 1094 1095 # array_var = empty_func(size_tuple_var) 1096 stmts.append(_new_definition(func_ir, array_var, 1097 ir.Expr.call(empty_func, (size_tuple_var,), list(array_kws), loc=loc), loc)) 1098 1099 # Add back removed just in case they are used by something else 1100 for var in removed: 1101 stmts.append(_new_definition(func_ir, var, array_var, loc)) 1102 1103 # Add back terminator 1104 stmts.append(terminator) 1105 # Modify loop_entry 1106 loop_entry.body = stmts 1107 1108 if range_def: 1109 if range_def[0] != 0: 1110 # when range doesn't start from 0, index_var becomes loop index 1111 # (iter_first_var) minus an offset (range_def[0]) 1112 terminator = loop_header.terminator 1113 assert(isinstance(terminator, ir.Branch)) 1114 # find the block in the loop body that header jumps to 1115 block_id = terminator.truebr 1116 blk = func_ir.blocks[block_id] 1117 loc = blk.loc 1118 blk.body.insert(0, _new_definition(func_ir, index_var, 1119 ir.Expr.binop(fn=operator.sub, lhs=iter_first_var, 1120 rhs=range_def[0], loc=loc), 1121 loc)) 1122 else: 1123 # Insert index_var increment to the end of loop header 1124 loc = loop_header.loc 1125 terminator = loop_header.terminator 1126 stmts = loop_header.body[0:-1] 1127 next_index_var = ir.Var(scope, mk_unique_var("next_index"), loc) 1128 one = ir.Var(scope, mk_unique_var("one"), loc) 1129 # one = 1 1130 stmts.append(_new_definition(func_ir, one, 1131 ir.Const(value=1,loc=loc), loc)) 1132 # next_index_var = index_var + 1 1133 stmts.append(_new_definition(func_ir, next_index_var, 1134 ir.Expr.binop(fn=operator.add, lhs=index_var, rhs=one, loc=loc), loc)) 1135 # index_var = next_index_var 1136 stmts.append(_new_definition(func_ir, index_var, next_index_var, loc)) 1137 stmts.append(terminator) 1138 loop_header.body = stmts 1139 1140 # In append_block, change list_append into array assign 1141 for i in range(len(append_block.body)): 1142 if append_block.body[i] is append_stmt: 1143 debug_print("Replace append with SetItem") 1144 append_block.body[i] = ir.SetItem(target=array_var, index=index_var, 1145 value=append_stmt.value.args[0], loc=append_stmt.loc) 1146 1147 # replace array call, by changing "a = array(b)" to "a = b" 1148 stmt = func_ir.blocks[exit_block].body[array_call_index] 1149 # stmt can be either array call or SetItem, we only replace array call 1150 if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr): 1151 stmt.value = array_var 1152 func_ir._definitions[stmt.target.name] = [stmt.value] 1153 1154 return True 1155 1156 1157def _find_unsafe_empty_inferred(func_ir, expr): 1158 unsafe_empty_inferred 1159 require(isinstance(expr, ir.Expr) and expr.op == 'call') 1160 callee = expr.func 1161 callee_def = get_definition(func_ir, callee) 1162 require(isinstance(callee_def, ir.Global)) 1163 _make_debug_print("_find_unsafe_empty_inferred")(callee_def.value) 1164 return callee_def.value == unsafe_empty_inferred 1165 1166 1167def _fix_nested_array(func_ir): 1168 """Look for assignment like: a[..] = b, where both a and b are numpy arrays, and 1169 try to eliminate array b by expanding a with an extra dimension. 1170 """ 1171 blocks = func_ir.blocks 1172 cfg = compute_cfg_from_blocks(blocks) 1173 usedefs = compute_use_defs(blocks) 1174 empty_deadmap = dict([(label, set()) for label in blocks.keys()]) 1175 livemap = compute_live_variables(cfg, blocks, usedefs.defmap, empty_deadmap) 1176 1177 def find_array_def(arr): 1178 """Find numpy array definition such as 1179 arr = numba.unsafe.ndarray.empty_inferred(...). 1180 If it is arr = b[...], find array definition of b recursively. 1181 """ 1182 arr_def = get_definition(func_ir, arr) 1183 _make_debug_print("find_array_def")(arr, arr_def) 1184 if isinstance(arr_def, ir.Expr): 1185 if guard(_find_unsafe_empty_inferred, func_ir, arr_def): 1186 return arr_def 1187 elif arr_def.op == 'getitem': 1188 return find_array_def(arr_def.value) 1189 raise GuardException 1190 1191 def fix_dependencies(expr, varlist): 1192 """Double check if all variables in varlist are defined before 1193 expr is used. Try to move constant definition when the check fails. 1194 Bails out by raising GuardException if it can't be moved. 1195 """ 1196 debug_print = _make_debug_print("fix_dependencies") 1197 for label, block in blocks.items(): 1198 scope = block.scope 1199 body = block.body 1200 defined = set() 1201 for i in range(len(body)): 1202 inst = body[i] 1203 if isinstance(inst, ir.Assign): 1204 defined.add(inst.target.name) 1205 if inst.value is expr: 1206 new_varlist = [] 1207 for var in varlist: 1208 # var must be defined before this inst, or live 1209 # and not later defined. 1210 if (var.name in defined or 1211 (var.name in livemap[label] and 1212 not (var.name in usedefs.defmap[label]))): 1213 debug_print(var.name, " already defined") 1214 new_varlist.append(var) 1215 else: 1216 debug_print(var.name, " not yet defined") 1217 var_def = get_definition(func_ir, var.name) 1218 if isinstance(var_def, ir.Const): 1219 loc = var.loc 1220 new_var = ir.Var(scope, mk_unique_var("new_var"), loc) 1221 new_const = ir.Const(var_def.value, loc) 1222 new_vardef = _new_definition(func_ir, 1223 new_var, new_const, loc) 1224 new_body = [] 1225 new_body.extend(body[:i]) 1226 new_body.append(new_vardef) 1227 new_body.extend(body[i:]) 1228 block.body = new_body 1229 new_varlist.append(new_var) 1230 else: 1231 raise GuardException 1232 return new_varlist 1233 # when expr is not found in block 1234 raise GuardException 1235 1236 def fix_array_assign(stmt): 1237 """For assignment like lhs[idx] = rhs, where both lhs and rhs are arrays, do the 1238 following: 1239 1. find the definition of rhs, which has to be a call to numba.unsafe.ndarray.empty_inferred 1240 2. find the source array creation for lhs, insert an extra dimension of size of b. 1241 3. replace the definition of rhs = numba.unsafe.ndarray.empty_inferred(...) with rhs = lhs[idx] 1242 """ 1243 require(isinstance(stmt, ir.SetItem)) 1244 require(isinstance(stmt.value, ir.Var)) 1245 debug_print = _make_debug_print("fix_array_assign") 1246 debug_print("found SetItem: ", stmt) 1247 lhs = stmt.target 1248 # Find the source array creation of lhs 1249 lhs_def = find_array_def(lhs) 1250 debug_print("found lhs_def: ", lhs_def) 1251 rhs_def = get_definition(func_ir, stmt.value) 1252 debug_print("found rhs_def: ", rhs_def) 1253 require(isinstance(rhs_def, ir.Expr)) 1254 if rhs_def.op == 'cast': 1255 rhs_def = get_definition(func_ir, rhs_def.value) 1256 require(isinstance(rhs_def, ir.Expr)) 1257 require(_find_unsafe_empty_inferred(func_ir, rhs_def)) 1258 # Find the array dimension of rhs 1259 dim_def = get_definition(func_ir, rhs_def.args[0]) 1260 require(isinstance(dim_def, ir.Expr) and dim_def.op == 'build_tuple') 1261 debug_print("dim_def = ", dim_def) 1262 extra_dims = [ get_definition(func_ir, x, lhs_only=True) for x in dim_def.items ] 1263 debug_print("extra_dims = ", extra_dims) 1264 # Expand size tuple when creating lhs_def with extra_dims 1265 size_tuple_def = get_definition(func_ir, lhs_def.args[0]) 1266 require(isinstance(size_tuple_def, ir.Expr) and size_tuple_def.op == 'build_tuple') 1267 debug_print("size_tuple_def = ", size_tuple_def) 1268 extra_dims = fix_dependencies(size_tuple_def, extra_dims) 1269 size_tuple_def.items += extra_dims 1270 # In-place modify rhs_def to be getitem 1271 rhs_def.op = 'getitem' 1272 rhs_def.value = get_definition(func_ir, lhs, lhs_only=True) 1273 rhs_def.index = stmt.index 1274 del rhs_def._kws['func'] 1275 del rhs_def._kws['args'] 1276 del rhs_def._kws['vararg'] 1277 del rhs_def._kws['kws'] 1278 # success 1279 return True 1280 1281 for label in find_topo_order(func_ir.blocks): 1282 block = func_ir.blocks[label] 1283 for stmt in block.body: 1284 if guard(fix_array_assign, stmt): 1285 block.body.remove(stmt) 1286 1287def _new_definition(func_ir, var, value, loc): 1288 func_ir._definitions[var.name] = [value] 1289 return ir.Assign(value=value, target=var, loc=loc) 1290 1291@rewrites.register_rewrite('after-inference') 1292class RewriteArrayOfConsts(rewrites.Rewrite): 1293 '''The RewriteArrayOfConsts class is responsible for finding 1294 1D array creations from a constant list, and rewriting it into 1295 direct initialization of array elements without creating the list. 1296 ''' 1297 def __init__(self, state, *args, **kws): 1298 self.typingctx = state.typingctx 1299 super(RewriteArrayOfConsts, self).__init__(*args, **kws) 1300 1301 def match(self, func_ir, block, typemap, calltypes): 1302 if len(calltypes) == 0: 1303 return False 1304 self.crnt_block = block 1305 self.new_body = guard(_inline_const_arraycall, block, func_ir, 1306 self.typingctx, typemap, calltypes) 1307 return self.new_body is not None 1308 1309 def apply(self): 1310 self.crnt_block.body = self.new_body 1311 return self.crnt_block 1312 1313 1314def _inline_const_arraycall(block, func_ir, context, typemap, calltypes): 1315 """Look for array(list) call where list is a constant list created by build_list, 1316 and turn them into direct array creation and initialization, if the following 1317 conditions are met: 1318 1. The build_list call immediate precedes the array call; 1319 2. The list variable is no longer live after array call; 1320 If any condition check fails, no modification will be made. 1321 """ 1322 debug_print = _make_debug_print("inline_const_arraycall") 1323 scope = block.scope 1324 1325 def inline_array(array_var, expr, stmts, list_vars, dels): 1326 """Check to see if the given "array_var" is created from a list 1327 of constants, and try to inline the list definition as array 1328 initialization. 1329 1330 Extra statements produced with be appended to "stmts". 1331 """ 1332 callname = guard(find_callname, func_ir, expr) 1333 require(callname and callname[1] == 'numpy' and callname[0] == 'array') 1334 require(expr.args[0].name in list_vars) 1335 ret_type = calltypes[expr].return_type 1336 require(isinstance(ret_type, types.ArrayCompatible) and 1337 ret_type.ndim == 1) 1338 loc = expr.loc 1339 list_var = expr.args[0] 1340 # Get the type of the array to be created. 1341 array_typ = typemap[array_var.name] 1342 debug_print("inline array_var = ", array_var, " list_var = ", list_var) 1343 # Get the element type of the array to be created. 1344 dtype = array_typ.dtype 1345 # Get the sequence of operations to provide values to the new array. 1346 seq, _ = find_build_sequence(func_ir, list_var) 1347 size = len(seq) 1348 # Create a tuple to pass to empty below to specify the new array size. 1349 size_var = ir.Var(scope, mk_unique_var("size"), loc) 1350 size_tuple_var = ir.Var(scope, mk_unique_var("size_tuple"), loc) 1351 size_typ = types.intp 1352 size_tuple_typ = types.UniTuple(size_typ, 1) 1353 typemap[size_var.name] = size_typ 1354 typemap[size_tuple_var.name] = size_tuple_typ 1355 stmts.append(_new_definition(func_ir, size_var, 1356 ir.Const(size, loc=loc), loc)) 1357 stmts.append(_new_definition(func_ir, size_tuple_var, 1358 ir.Expr.build_tuple(items=[size_var], loc=loc), loc)) 1359 1360 # The general approach is to create an empty array and then fill 1361 # the elements in one-by-one from their specificiation. 1362 1363 # Get the numpy type to pass to empty. 1364 nptype = types.DType(dtype) 1365 1366 # Create a variable to hold the numpy empty function. 1367 empty_func = ir.Var(scope, mk_unique_var("empty_func"), loc) 1368 fnty = get_np_ufunc_typ(np.empty) 1369 sig = context.resolve_function_type(fnty, (size_typ,), {'dtype':nptype}) 1370 1371 typemap[empty_func.name] = fnty 1372 1373 stmts.append(_new_definition(func_ir, empty_func, 1374 ir.Global('empty', np.empty, loc=loc), loc)) 1375 1376 # We pass two arguments to empty, first the size tuple and second 1377 # the dtype of the new array. Here, we created typ_var which is 1378 # the dtype argument of the new array. typ_var in turn is created 1379 # by getattr of the dtype string on the numpy module. 1380 1381 # Create var for numpy module. 1382 g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc) 1383 typemap[g_np_var.name] = types.misc.Module(np) 1384 g_np = ir.Global('np', np, loc) 1385 stmts.append(_new_definition(func_ir, g_np_var, g_np, loc)) 1386 1387 # Create var for result of numpy.<dtype>. 1388 typ_var = ir.Var(scope, mk_unique_var("$np_typ_var"), loc) 1389 typemap[typ_var.name] = nptype 1390 dtype_str = str(dtype) 1391 if dtype_str == 'bool': 1392 dtype_str = 'bool_' 1393 # Get dtype attribute of numpy module. 1394 np_typ_getattr = ir.Expr.getattr(g_np_var, dtype_str, loc) 1395 stmts.append(_new_definition(func_ir, typ_var, np_typ_getattr, loc)) 1396 1397 # Create the call to numpy.empty passing the size tuple and dtype var. 1398 empty_call = ir.Expr.call(empty_func, [size_var, typ_var], {}, loc=loc) 1399 calltypes[empty_call] = typing.signature(array_typ, size_typ, nptype) 1400 stmts.append(_new_definition(func_ir, array_var, empty_call, loc)) 1401 1402 # Fill in the new empty array one-by-one. 1403 for i in range(size): 1404 index_var = ir.Var(scope, mk_unique_var("index"), loc) 1405 index_typ = types.intp 1406 typemap[index_var.name] = index_typ 1407 stmts.append(_new_definition(func_ir, index_var, 1408 ir.Const(i, loc), loc)) 1409 setitem = ir.SetItem(array_var, index_var, seq[i], loc) 1410 calltypes[setitem] = typing.signature(types.none, array_typ, 1411 index_typ, dtype) 1412 stmts.append(setitem) 1413 1414 stmts.extend(dels) 1415 return True 1416 1417 class State(object): 1418 """ 1419 This class is used to hold the state in the following loop so as to make 1420 it easy to reset the state of the variables tracking the various 1421 statement kinds 1422 """ 1423 1424 def __init__(self): 1425 # list_vars keep track of the variable created from the latest 1426 # build_list instruction, as well as its synonyms. 1427 self.list_vars = [] 1428 # dead_vars keep track of those in list_vars that are considered dead. 1429 self.dead_vars = [] 1430 # list_items keep track of the elements used in build_list. 1431 self.list_items = [] 1432 self.stmts = [] 1433 # dels keep track of the deletion of list_items, which will need to be 1434 # moved after array initialization. 1435 self.dels = [] 1436 # tracks if a modification has taken place 1437 self.modified = False 1438 1439 def reset(self): 1440 """ 1441 Resets the internal state of the variables used for tracking 1442 """ 1443 self.list_vars = [] 1444 self.dead_vars = [] 1445 self.list_items = [] 1446 self.dels = [] 1447 1448 def list_var_used(self, inst): 1449 """ 1450 Returns True if the list being analysed is used between the 1451 build_list and the array call. 1452 """ 1453 return any([x.name in self.list_vars for x in inst.list_vars()]) 1454 1455 state = State() 1456 1457 for inst in block.body: 1458 if isinstance(inst, ir.Assign): 1459 if isinstance(inst.value, ir.Var): 1460 if inst.value.name in state.list_vars: 1461 state.list_vars.append(inst.target.name) 1462 state.stmts.append(inst) 1463 continue 1464 elif isinstance(inst.value, ir.Expr): 1465 expr = inst.value 1466 if expr.op == 'build_list': 1467 # new build_list encountered, reset state 1468 state.reset() 1469 state.list_items = [x.name for x in expr.items] 1470 state.list_vars = [inst.target.name] 1471 state.stmts.append(inst) 1472 continue 1473 elif expr.op == 'call' and expr in calltypes: 1474 arr_var = inst.target 1475 if guard(inline_array, inst.target, expr, 1476 state.stmts, state.list_vars, state.dels): 1477 state.modified = True 1478 continue 1479 elif isinstance(inst, ir.Del): 1480 removed_var = inst.value 1481 if removed_var in state.list_items: 1482 state.dels.append(inst) 1483 continue 1484 elif removed_var in state.list_vars: 1485 # one of the list_vars is considered dead. 1486 state.dead_vars.append(removed_var) 1487 state.list_vars.remove(removed_var) 1488 state.stmts.append(inst) 1489 if state.list_vars == []: 1490 # if all list_vars are considered dead, we need to filter 1491 # them out from existing stmts to completely remove 1492 # build_list. 1493 # Note that if a translation didn't take place, dead_vars 1494 # will also be empty when we reach this point. 1495 body = [] 1496 for inst in state.stmts: 1497 if ((isinstance(inst, ir.Assign) and 1498 inst.target.name in state.dead_vars) or 1499 (isinstance(inst, ir.Del) and 1500 inst.value in state.dead_vars)): 1501 continue 1502 body.append(inst) 1503 state.stmts = body 1504 state.dead_vars = [] 1505 state.modified = True 1506 continue 1507 state.stmts.append(inst) 1508 1509 # If the list is used in any capacity between build_list and array 1510 # call, then we must call off the translation for this list because 1511 # it could be mutated and list_items would no longer be applicable. 1512 if state.list_var_used(inst): 1513 state.reset() 1514 1515 return state.stmts if state.modified else None 1516