1import types as pytypes  # avoid confusion with numba.types
2import copy
3import ctypes
4import numba.core.analysis
5from numba.core import utils, types, typing, errors, ir, rewrites, config, ir_utils
6from numba import prange
7from numba.parfors.parfor import internal_prange
8from numba.core.ir_utils import (
9    mk_unique_var,
10    next_label,
11    add_offset_to_labels,
12    replace_vars,
13    remove_dels,
14    rename_labels,
15    find_topo_order,
16    merge_adjacent_blocks,
17    GuardException,
18    require,
19    guard,
20    get_definition,
21    find_callname,
22    find_build_sequence,
23    get_np_ufunc_typ,
24    get_ir_of_code,
25    simplify_CFG,
26    canonicalize_array_math,
27    dead_code_elimination,
28    )
29
30from numba.core.analysis import (
31    compute_cfg_from_blocks,
32    compute_use_defs,
33    compute_live_variables)
34from numba.core import postproc
35from numba.cpython.rangeobj import range_iter_len
36from numba.np.unsafe.ndarray import empty_inferred as unsafe_empty_inferred
37import numpy as np
38import operator
39import numba.misc.special
40
41"""
Variable enable_inline_arraycall is only used for testing purposes.
43"""
# Consulted by InlineClosureCallPass.run(): the array-call inlining step is
# only attempted while this switch is true.
enable_inline_arraycall = True
45
46
def callee_ir_validator(func_ir):
    """Reject callee IR that the inliner cannot handle.

    Every assignment of every block in ``func_ir`` is inspected; the first
    one whose value is an ``ir.Yield`` causes ``errors.UnsupportedError`` to
    be raised at that statement's location. Returns ``None`` otherwise.
    """
    for block in func_ir.blocks.values():
        assignments = block.find_insts(ir.Assign)
        offending = next(
            (a for a in assignments if isinstance(a.value, ir.Yield)), None)
        if offending is None:
            continue
        raise errors.UnsupportedError(
            "The use of yield in a closure is unsupported.",
            loc=offending.loc)
55
56
class InlineClosureCallPass(object):
    """InlineClosureCallPass class looks for direct calls to locally defined
    closures, and inlines the body of the closure function to the call site.

    Besides closures, the pass also inlines sequential ``reduce`` calls,
    turns ``numba.stencil`` calls into ``StencilFunc`` globals and, while the
    module-level ``enable_inline_arraycall`` switch is true, runs the
    array-call inlining over the loops of the function.
    """

    def __init__(self, func_ir, parallel_options, swapped=None, typed=False):
        """
        func_ir: IR of the function to transform; it is mutated in place.
        parallel_options: object whose ``reduction`` and ``comprehension``
            attributes are read by the pass.
        swapped: dict forwarded to ``_inline_arraycall``. When omitted, each
            instance gets its own fresh dict.
        typed: flag forwarded to ``_inline_arraycall``.
        """
        self.func_ir = func_ir
        self.parallel_options = parallel_options
        # A literal ``{}`` default would be evaluated once and then shared by
        # every instance constructed without ``swapped``.
        self.swapped = {} if swapped is None else swapped
        self._processed_stencils = []
        self.typed = typed

    def run(self):
        """Run inline closure call pass.
        """
        # Analysis relies on ir.Del presence, strip out later
        pp = postproc.PostProcessor(self.func_ir)
        pp.run(True)

        modified = False
        work_list = list(self.func_ir.blocks.items())
        debug_print = _make_debug_print("InlineClosureCallPass")
        debug_print("START")
        while work_list:
            label, block = work_list.pop()
            for i, instr in enumerate(block.body):
                if isinstance(instr, ir.Assign):
                    expr = instr.value
                    if isinstance(expr, ir.Expr) and expr.op == 'call':
                        call_name = guard(find_callname, self.func_ir, expr)
                        func_def = guard(get_definition, self.func_ir, expr.func)

                        if guard(self._inline_reduction,
                                 work_list, block, i, expr, call_name):
                            modified = True
                            break # because block structure changed

                        if guard(self._inline_closure,
                                work_list, block, i, func_def):
                            modified = True
                            break # because block structure changed

                        # the stencil rewrite replaces the statement's value
                        # in place, so iteration over the block can continue
                        if guard(self._inline_stencil,
                                instr, call_name, func_def):
                            modified = True

        if enable_inline_arraycall:
            # Identify loop structure
            if modified:
                # Need to do some cleanups if closure inlining kicked in
                merge_adjacent_blocks(self.func_ir.blocks)
            cfg = compute_cfg_from_blocks(self.func_ir.blocks)
            debug_print("start inline arraycall")
            _debug_dump(cfg)
            loops = cfg.loops()
            sized_loops = [(k, len(loops[k].body)) for k in loops.keys()]
            visited = []
            # We go over all loops, bigger loops first (outer first)
            for k, s in sorted(sized_loops, key=lambda tup: tup[1], reverse=True):
                visited.append(k)
                if guard(_inline_arraycall, self.func_ir, cfg, visited, loops[k],
                         self.swapped, self.parallel_options.comprehension,
                         self.typed):
                    modified = True
            if modified:
                _fix_nested_array(self.func_ir)

        if modified:
            # clean up now dead/unreachable blocks, e.g. unconditionally raising
            # an exception in an inlined function would render some parts of the
            # inliner unreachable
            cfg = compute_cfg_from_blocks(self.func_ir.blocks)
            for dead in cfg.dead_nodes():
                del self.func_ir.blocks[dead]

            # run dead code elimination
            dead_code_elimination(self.func_ir)
            # do label renaming
            self.func_ir.blocks = rename_labels(self.func_ir.blocks)

        # inlining done, strip dels
        remove_dels(self.func_ir.blocks)

        debug_print("END")

    def _inline_reduction(self, work_list, block, i, expr, call_name):
        """Inline a ``reduce`` call found at statement ``i`` of ``block``.

        Bails out through ``require`` when reductions are handled in parallel
        or when ``call_name`` is not a ``reduce`` call. Raises ``TypeError``
        when the call does not have 2 or 3 positional arguments. Returns
        ``True`` once the call has been inlined.
        """
        # only inline reduction in sequential execution, parallel handling
        # is done in ParforPass.
        require(not self.parallel_options.reduction)
        require(call_name == ('reduce', 'builtins') or
                call_name == ('reduce', '_functools'))
        if len(expr.args) not in (2, 3):
            raise TypeError("invalid reduce call, "
                            "two arguments are required (optional initial "
                            "value can also be specified)")
        check_reduce_func(self.func_ir, expr.args[0])
        # NOTE: this function is handed to inline_closure_call as the callee,
        # so its body is what ends up inlined at the call site.
        def reduce_func(f, A, v=None):
            it = iter(A)
            if v is not None:
                s = v
            else:
                s = next(it)
            for a in it:
                s = f(s, a)
            return s
        inline_closure_call(self.func_ir,
                        self.func_ir.func_id.func.__globals__,
                        block, i, reduce_func, work_list=work_list,
                        callee_validator=callee_ir_validator)
        return True

    def _inline_stencil(self, instr, call_name, func_def):
        """Handle stencil-related calls in the assignment ``instr``.

        If the callee is already a ``StencilFunc`` global, its saved keyword
        arguments are appended to the call. Otherwise, for a call to
        ``numba.stencil`` whose argument is a ``make_function`` expression,
        the assignment's value is replaced by an ``ir.Global`` wrapping a new
        ``StencilFunc``. Raises ``ValueError`` for a wrong argument count or
        for non-constant ``neighborhood``/``index_offsets`` options. Returns
        ``True`` when the statement was updated.
        """
        from numba.stencils.stencil import StencilFunc
        lhs = instr.target
        expr = instr.value
        # We keep the escaping variables of the stencil kernel
        # alive by adding them to the actual kernel call as extra
        # keyword arguments, which is ignored anyway.
        if (isinstance(func_def, ir.Global) and
            func_def.name == 'stencil' and
            isinstance(func_def.value, StencilFunc)):
            if expr.kws:
                expr.kws += func_def.value.kws
            else:
                expr.kws = func_def.value.kws
            return True
        # Otherwise we proceed to check if it is a call to numba.stencil
        require(call_name == ('stencil', 'numba.stencils.stencil') or
                call_name == ('stencil', 'numba'))
        # each stencil call expression is converted at most once
        require(expr not in self._processed_stencils)
        self._processed_stencils.append(expr)
        if len(expr.args) != 1:
            raise ValueError("As a minimum Stencil requires"
                " a kernel as an argument")
        stencil_def = guard(get_definition, self.func_ir, expr.args[0])
        require(isinstance(stencil_def, ir.Expr) and
                stencil_def.op == "make_function")
        kernel_ir = get_ir_of_code(self.func_ir.func_id.func.__globals__,
                stencil_def.code)
        options = dict(expr.kws)
        if 'neighborhood' in options:
            fixed = guard(self._fix_stencil_neighborhood, options)
            if not fixed:
                raise ValueError("stencil neighborhood option should be a tuple"
                        " with constant structure such as ((-w, w),)")
        if 'index_offsets' in options:
            fixed = guard(self._fix_stencil_index_offsets, options)
            if not fixed:
                raise ValueError("stencil index_offsets option should be a tuple"
                        " with constant structure such as (offset, )")
        sf = StencilFunc(kernel_ir, 'constant', options)
        sf.kws = expr.kws # hack to keep variables live
        sf_global = ir.Global('stencil', sf, expr.loc)
        self.func_ir._definitions[lhs.name] = [sf_global]
        instr.value = sf_global
        return True

    def _fix_stencil_neighborhood(self, options):
        """
        Extract the two-level tuple representing the stencil neighborhood
        from the program IR to provide a tuple to StencilFunc.

        ``options['neighborhood']`` is overwritten with the extracted tuple.
        Returns ``True`` on success; bails out through ``require`` when a
        definition has no ``items`` attribute.
        """
        # build_tuple node with neighborhood for each dimension
        dims_build_tuple = get_definition(self.func_ir, options['neighborhood'])
        require(hasattr(dims_build_tuple, 'items'))
        res = []
        for window_var in dims_build_tuple.items:
            win_build_tuple = get_definition(self.func_ir, window_var)
            require(hasattr(win_build_tuple, 'items'))
            res.append(tuple(win_build_tuple.items))
        options['neighborhood'] = tuple(res)
        return True

    def _fix_stencil_index_offsets(self, options):
        """
        Extract the tuple representing the stencil index offsets
        from the program IR to provide to StencilFunc.

        ``options['index_offsets']`` is overwritten with the extracted tuple.
        Returns ``True`` on success; bails out through ``require`` when the
        definition has no ``items`` attribute.
        """
        offset_tuple = get_definition(self.func_ir, options['index_offsets'])
        require(hasattr(offset_tuple, 'items'))
        options['index_offsets'] = tuple(offset_tuple.items)
        return True

    def _inline_closure(self, work_list, block, i, func_def):
        """Inline the call at statement ``i`` of ``block`` when its callee
        definition ``func_def`` is a ``make_function`` expression.

        Bails out through ``require`` otherwise; returns ``True`` once the
        closure has been inlined.
        """
        require(isinstance(func_def, ir.Expr) and
                func_def.op == "make_function")
        inline_closure_call(self.func_ir,
                            self.func_ir.func_id.func.__globals__,
                            block, i, func_def, work_list=work_list,
                            callee_validator=callee_ir_validator)
        return True
249
def check_reduce_func(func_ir, func_var):
    """Checks the function at func_var in func_ir to make sure it's amenable
    for inlining. Returns the function itself.

    The definition of ``func_var`` is looked up in ``func_ir``. A free
    variable or global must hold a ``CPUDispatcher``, in which case its
    ``py_func`` is returned; any other definition must expose ``code`` or
    ``__code__`` and is returned as is.

    Raises ``ValueError`` when no definition can be found or the definition is
    not an acceptable function, and ``TypeError`` when the function does not
    take exactly two arguments.
    """
    reduce_func = guard(get_definition, func_ir, func_var)
    if reduce_func is None:
        # Adjacent literals are used instead of a backslash continuation
        # inside the string, which would embed the next line's indentation
        # in the message.
        raise ValueError("Reduce function cannot be found for njit "
                         "analysis")
    if isinstance(reduce_func, (ir.FreeVar, ir.Global)):
        if not isinstance(reduce_func.value,
                          numba.core.registry.CPUDispatcher):
            raise ValueError("Invalid reduction function")
        # pull out the python function for inlining
        reduce_func = reduce_func.value.py_func
    elif not (hasattr(reduce_func, 'code')
            or hasattr(reduce_func, '__code__')):
        raise ValueError("Invalid reduction function")
    # make_function expressions carry ``code``, Python functions ``__code__``
    f_code = (reduce_func.code if hasattr(reduce_func, 'code')
                                    else reduce_func.__code__)
    if f_code.co_argcount != 2:
        raise TypeError("Reduction function should take 2 arguments")
    return reduce_func
271
272
class InlineWorker(object):
    """ A worker class for inlining, this is a more advanced version of
    `inline_closure_call` in that it permits inlining from function type, Numba
    IR and code object. It also runs the entire untyped compiler pipeline on
    the inlinee to ensure that it is transformed as though it were compiled
    directly.
    """

    def __init__(self,
                 typingctx=None,
                 targetctx=None,
                 locals=None,
                 pipeline=None,
                 flags=None,
                 validator=callee_ir_validator,
                 typemap=None,
                 calltypes=None):
        """
        Instantiate a new InlineWorker, all arguments are optional though some
        must be supplied together for certain use cases. The methods will refuse
        to run if the object isn't configured in the manner needed. Args are the
        same as those in a numba.core.Compiler.state, except the validator which
        is a function taking Numba IR and validating it for use when inlining
        (this is optional and really to just provide better error messages about
        things which the inliner cannot handle like yield in closure).

        Raises TypeError when only some of (targetctx, locals, pipeline, flags)
        are given, when all of them are given but typingctx is None, or when
        exactly one of typemap and calltypes is None.
        """
        def check(arg, name):
            # raise if a required argument was left as None
            if arg is None:
                raise TypeError("{} must not be None".format(name))

        from numba.core.compiler import DefaultPassBuilder

        # check the stuff needed to run the more advanced compilation pipeline
        # is valid if any of it is provided
        compiler_args = (targetctx, locals, pipeline, flags)
        compiler_group = [x is not None for x in compiler_args]
        if any(compiler_group) and not all(compiler_group):
            # partially specified: the first missing one triggers the error
            check(targetctx, 'targetctx')
            check(locals, 'locals')
            check(pipeline, 'pipeline')
            check(flags, 'flags')
        elif all(compiler_group):
            # fully specified compiler group additionally needs a typingctx
            check(typingctx, 'typingctx')

        # pipeline factory used by run_untyped_passes
        self._compiler_pipeline = DefaultPassBuilder.define_untyped_pipeline

        self.typingctx = typingctx
        self.targetctx = targetctx
        self.locals = locals
        self.pipeline = pipeline
        self.flags = flags
        self.validator = validator
        self.debug_print = _make_debug_print("InlineWorker")

        # check whether this inliner can also support typemap and calltypes
        # update and if what's provided is valid
        pair = (typemap, calltypes)
        pair_is_none = [x is None for x in pair]
        if any(pair_is_none) and not all(pair_is_none):
            msg = ("typemap and calltypes must both be either None or have a "
                   "value, got: %s, %s")
            raise TypeError(msg % pair)
        # true only when both typemap and calltypes were supplied
        self._permit_update_type_and_call_maps = not all(pair_is_none)
        self.typemap = typemap
        self.calltypes = calltypes


    def inline_ir(self, caller_ir, block, i, callee_ir, callee_freevars,
                  arg_typs=None):
        """ Inlines the callee_ir in the caller_ir at statement index i of block
        `block`, callee_freevars are the free variables for the callee_ir. If
        the callee_ir is derived from a function `func` then this is
        `func.__code__.co_freevars`. If `arg_typs` is given and the InlineWorker
        instance was initialized with a typemap and calltypes then they will be
        appropriately updated based on the arg_typs.

        Returns a 4-tuple ``(callee_ir_original, callee_blocks, var_dict,
        new_blocks)``: the copy of the callee IR taken before relabelling, the
        relabelled callee blocks, the mapping from callee variable names to
        the new caller variables, and a list of ``(label, block)`` pairs that
        starts with the block holding the caller statements after the call
        site, followed by the inlined callee blocks in topological order.

        Raises TypeError if the instance was configured with a typemap and
        calltypes but `arg_typs` is None.
        """

        # Always copy the callee IR, it gets mutated
        def copy_ir(the_ir):
            # copy the IR object, then rebuild its blocks from deep copies of
            # each block and of each statement in the block
            kernel_copy = the_ir.copy()
            kernel_copy.blocks = {}
            for block_label, block in the_ir.blocks.items():
                new_block = copy.deepcopy(the_ir.blocks[block_label])
                new_block.body = []
                for stmt in the_ir.blocks[block_label].body:
                    scopy = copy.deepcopy(stmt)
                    new_block.body.append(scopy)
                kernel_copy.blocks[block_label] = new_block
            return kernel_copy

        callee_ir = copy_ir(callee_ir)

        # check that the contents of the callee IR is something that can be
        # inlined if a validator is present
        if self.validator is not None:
            self.validator(callee_ir)

        # save an unmutated copy of the callee_ir to return
        # NOTE(review): whether this copy really stays untouched by the
        # in-place edits below depends on how deep callee_ir.copy() is --
        # confirm.
        callee_ir_original = callee_ir.copy()
        scope = block.scope
        instr = block.body[i]
        call_expr = instr.value
        callee_blocks = callee_ir.blocks

        # 1. relabel callee_ir by adding an offset
        max_label = max(ir_utils._max_label, max(caller_ir.blocks.keys()))
        callee_blocks = add_offset_to_labels(callee_blocks, max_label + 1)
        callee_blocks = simplify_CFG(callee_blocks)
        callee_ir.blocks = callee_blocks
        min_label = min(callee_blocks.keys())
        max_label = max(callee_blocks.keys())
        #    reset globals in ir_utils before we use it
        ir_utils._max_label = max_label
        self.debug_print("After relabel")
        _debug_dump(callee_ir)

        # 2. rename all local variables in callee_ir with new locals created in
        # caller_ir
        callee_scopes = _get_all_scopes(callee_blocks)
        self.debug_print("callee_scopes = ", callee_scopes)
        #    one function should only have one local scope
        assert(len(callee_scopes) == 1)
        callee_scope = callee_scopes[0]
        var_dict = {}
        for var in callee_scope.localvars._con.values():
            # free variables keep their names, everything else is renamed
            if not (var.name in callee_freevars):
                new_var = scope.redefine(mk_unique_var(var.name), loc=var.loc)
                var_dict[var.name] = new_var
        self.debug_print("var_dict = ", var_dict)
        replace_vars(callee_blocks, var_dict)
        self.debug_print("After local var rename")
        _debug_dump(callee_ir)

        # 3. replace formal parameters with actual arguments
        callee_func = callee_ir.func_id.func
        args = _get_callee_args(call_expr, callee_func, block.body[i].loc,
                                caller_ir)

        # 4. Update typemap
        if self._permit_update_type_and_call_maps:
            if arg_typs is None:
                raise TypeError('arg_typs should have a value not None')
            self.update_type_and_call_maps(callee_ir, arg_typs)

        self.debug_print("After arguments rename: ")
        _debug_dump(callee_ir)

        _replace_args_with(callee_blocks, args)
        # 5. split caller blocks into two
        new_blocks = []
        new_block = ir.Block(scope, block.loc)
        # statements after the call site move to a fresh block
        new_block.body = block.body[i + 1:]
        new_label = next_label()
        caller_ir.blocks[new_label] = new_block
        new_blocks.append((new_label, new_block))
        # the original block now ends with a jump into the callee's first block
        block.body = block.body[:i]
        block.body.append(ir.Jump(min_label, instr.loc))

        # 6. replace Return with assignment to LHS
        topo_order = find_topo_order(callee_blocks)
        _replace_returns(callee_blocks, instr.target, new_label)

        # remove the old definition of instr.target too
        if (instr.target.name in caller_ir._definitions
                and call_expr in caller_ir._definitions[instr.target.name]):
            # NOTE: target can have multiple definitions due to control flow
            caller_ir._definitions[instr.target.name].remove(call_expr)

        # 7. insert all new blocks, and add back definitions
        for label in topo_order:
            # block scope must point to parent's
            block = callee_blocks[label]
            block.scope = scope
            _add_definitions(caller_ir, block)
            caller_ir.blocks[label] = block
            new_blocks.append((label, block))
        self.debug_print("After merge in")
        _debug_dump(caller_ir)

        return callee_ir_original, callee_blocks, var_dict, new_blocks

    def inline_function(self, caller_ir, block, i, function, arg_typs=None):
        """ Inlines the function in the caller_ir at statement index i of block
        `block`. If `arg_typs` is given and the InlineWorker instance was
        initialized with a typemap and calltypes then they will be appropriately
        updated based on the arg_typs.

        The function is first run through `run_untyped_passes`; the return
        value is that of `inline_ir`.
        """
        callee_ir = self.run_untyped_passes(function)
        freevars = function.__code__.co_freevars
        return self.inline_ir(caller_ir, block, i, callee_ir, freevars,
                              arg_typs=arg_typs)

    def run_untyped_passes(self, func):
        """
        Run the compiler frontend's untyped passes over the given Python
        function, and return the function's canonical Numba IR.
        """
        from numba.core.compiler import StateDict, _CompileStatus
        from numba.core.untyped_passes import ExtractByteCode, WithLifting
        from numba.core import bytecode
        from numba.parfors.parfor import ParforDiagnostics
        # build a compiler state from the values this worker was given
        state = StateDict()
        state.func_ir = None
        state.typingctx = self.typingctx
        state.targetctx = self.targetctx
        state.locals = self.locals
        state.pipeline = self.pipeline
        state.flags = self.flags

        # Disable SSA transformation, the call site won't be in SSA form and
        # self.inline_ir depends on this being the case.
        # NOTE(review): state.flags is the same object as self.flags, so this
        # also changes the flags object passed to the constructor -- confirm
        # that is intended.
        state.flags.enable_ssa = False

        state.func_id = bytecode.FunctionIdentity.from_function(func)

        state.typemap = None
        state.calltypes = None
        state.type_annotation = None
        state.status = _CompileStatus(False, False)
        state.return_type = None
        state.parfor_diagnostics = ParforDiagnostics()
        state.metadata = {}

        ExtractByteCode().run_pass(state)
        # This is a lie, just need *some* args for the case where an obj mode
        # with lift is needed
        state.args = len(state.bc.func_id.pysig.parameters) * (types.pyobject,)

        pm = self._compiler_pipeline(state)

        pm.finalize()
        pm.run(state)
        return state.func_ir

    def update_type_and_call_maps(self, callee_ir, arg_typs):
        """ Updates the type and call maps based on calling callee_ir with arguments
        from arg_typs

        Raises ValueError if the instance was not initialized with both a
        typemap and calltypes.
        """
        if not self._permit_update_type_and_call_maps:
            msg = ("InlineWorker instance not configured correctly, typemap or "
                   "calltypes missing in initialization.")
            raise ValueError(msg)
        from numba.core import typed_passes
        # call branch pruning to simplify IR and avoid inference errors
        callee_ir._definitions = ir_utils.build_definitions(callee_ir.blocks)
        numba.core.analysis.dead_branch_prune(callee_ir, arg_typs)
        f_typemap, f_return_type, f_calltypes = typed_passes.type_inference_stage(
                self.typingctx, callee_ir, arg_typs, None)
        canonicalize_array_math(callee_ir, f_typemap,
                                f_calltypes, self.typingctx)
        # remove argument entries like arg.a from typemap
        arg_names = [vname for vname in f_typemap if vname.startswith("arg.")]
        for a in arg_names:
            f_typemap.pop(a)
        # merge the callee's inferred types into the maps held by this worker
        self.typemap.update(f_typemap)
        self.calltypes.update(f_calltypes)
528
529
def inline_closure_call(func_ir, glbls, block, i, callee, typingctx=None,
                        arg_typs=None, typemap=None, calltypes=None,
                        work_list=None, callee_validator=None,
                        replace_freevars=True):
    """Inline the body of `callee` at its callsite (`i`-th instruction of `block`)

    `func_ir` is the func_ir object of the caller function and `glbls` is its
    global variable environment (func_ir.func_id.func.__globals__).
    `block` is the IR block of the callsite and `i` is the index of the
    callsite's node. `callee` is either the called function or a
    make_function node. `typingctx`, `typemap` and `calltypes` are typing
    data structures of the caller, available if we are in a typed pass.
    `arg_typs` includes the types of the arguments at the callsite.
    `callee_validator` is an optional callable which can be used to validate the
    IR of the callee to ensure that it contains IR supported for inlining, it
    takes one argument, the func_ir of the callee
    `work_list`, when not None, receives a `(label, block)` pair for every
    block created by the inlining. `replace_freevars` controls whether the
    callee's free variables are replaced by the items of its closure.

    Returns IR blocks of the callee and the variable renaming dictionary used
    for them to facilitate further processing of new blocks.
    """
    scope = block.scope
    instr = block.body[i]
    call_expr = instr.value
    debug_print = _make_debug_print("inline_closure_call")
    debug_print("Found closure call: ", instr, " with callee = ", callee)
    # support both function object and make_function Expr
    callee_code = callee.code if hasattr(callee, 'code') else callee.__code__
    callee_closure = callee.closure if hasattr(callee, 'closure') else callee.__closure__
    # first, get the IR of the callee
    if isinstance(callee, pytypes.FunctionType):
        from numba.core import compiler
        callee_ir = compiler.run_frontend(callee, inline_closures=True)
    else:
        callee_ir = get_ir_of_code(glbls, callee_code)

    # check that the contents of the callee IR is something that can be inlined
    # if a validator is supplied
    if callee_validator is not None:
        callee_validator(callee_ir)

    callee_blocks = callee_ir.blocks

    # 1. relabel callee_ir by adding an offset
    max_label = max(ir_utils._max_label, max(func_ir.blocks.keys()))
    callee_blocks = add_offset_to_labels(callee_blocks, max_label + 1)
    callee_blocks = simplify_CFG(callee_blocks)
    callee_ir.blocks = callee_blocks
    min_label = min(callee_blocks.keys())
    max_label = max(callee_blocks.keys())
    #    reset globals in ir_utils before we use it
    ir_utils._max_label = max_label
    debug_print("After relabel")
    _debug_dump(callee_ir)

    # 2. rename all local variables in callee_ir with new locals created in func_ir
    callee_scopes = _get_all_scopes(callee_blocks)
    debug_print("callee_scopes = ", callee_scopes)
    #    one function should only have one local scope
    assert(len(callee_scopes) == 1)
    callee_scope = callee_scopes[0]
    var_dict = {}
    for var in callee_scope.localvars._con.values():
        # free variables keep their names, everything else is renamed
        if not (var.name in callee_code.co_freevars):
            new_var = scope.redefine(mk_unique_var(var.name), loc=var.loc)
            var_dict[var.name] = new_var
    debug_print("var_dict = ", var_dict)
    replace_vars(callee_blocks, var_dict)
    debug_print("After local var rename")
    _debug_dump(callee_ir)

    # 3. replace formal parameters with actual arguments
    args = _get_callee_args(call_expr, callee, block.body[i].loc, func_ir)

    debug_print("After arguments rename: ")
    _debug_dump(callee_ir)

    # 4. replace freevar with actual closure var
    if callee_closure and replace_freevars:
        closure = func_ir.get_definition(callee_closure)
        debug_print("callee's closure = ", closure)
        if isinstance(closure, tuple):
            # a tuple of cell objects: read each cell's content through the
            # C API function PyCell_Get
            cellget = ctypes.pythonapi.PyCell_Get
            cellget.restype = ctypes.py_object
            cellget.argtypes = (ctypes.py_object,)
            items = tuple(cellget(x) for x in closure)
        else:
            assert(isinstance(closure, ir.Expr)
                   and closure.op == 'build_tuple')
            items = closure.items
        assert(len(callee_code.co_freevars) == len(items))
        _replace_freevars(callee_blocks, items)
        debug_print("After closure rename")
        _debug_dump(callee_ir)

    if typingctx:
        from numba.core import typed_passes
        # call branch pruning to simplify IR and avoid inference errors
        callee_ir._definitions = ir_utils.build_definitions(callee_ir.blocks)
        numba.core.analysis.dead_branch_prune(callee_ir, arg_typs)
        # NOTE(review): the except branch repeats the exact same call, so a
        # failure simply triggers one retry whose exception propagates; the
        # trailing `pass` has no effect. Confirm whether the retry is needed.
        try:
            f_typemap, f_return_type, f_calltypes = typed_passes.type_inference_stage(
                    typingctx, callee_ir, arg_typs, None)
        except Exception:
            f_typemap, f_return_type, f_calltypes = typed_passes.type_inference_stage(
                    typingctx, callee_ir, arg_typs, None)
            pass
        canonicalize_array_math(callee_ir, f_typemap,
                                f_calltypes, typingctx)
        # remove argument entries like arg.a from typemap
        arg_names = [vname for vname in f_typemap if vname.startswith("arg.")]
        for a in arg_names:
            f_typemap.pop(a)
        # merge the callee's inferred types into the caller's maps
        typemap.update(f_typemap)
        calltypes.update(f_calltypes)

    _replace_args_with(callee_blocks, args)
    # 5. split caller blocks into two
    new_blocks = []
    new_block = ir.Block(scope, block.loc)
    # statements after the call site move to a fresh block
    new_block.body = block.body[i + 1:]
    new_label = next_label()
    func_ir.blocks[new_label] = new_block
    new_blocks.append((new_label, new_block))
    # the original block now ends with a jump into the callee's first block
    block.body = block.body[:i]
    block.body.append(ir.Jump(min_label, instr.loc))

    # 6. replace Return with assignment to LHS
    topo_order = find_topo_order(callee_blocks)
    _replace_returns(callee_blocks, instr.target, new_label)

    # remove the old definition of instr.target too
    if (instr.target.name in func_ir._definitions
            and call_expr in func_ir._definitions[instr.target.name]):
        # NOTE: target can have multiple definitions due to control flow
        func_ir._definitions[instr.target.name].remove(call_expr)

    # 7. insert all new blocks, and add back definitions
    for label in topo_order:
        # block scope must point to parent's
        block = callee_blocks[label]
        block.scope = scope
        _add_definitions(func_ir, block)
        func_ir.blocks[label] = block
        new_blocks.append((label, block))
    debug_print("After merge in")
    _debug_dump(func_ir)

    # hand the new (label, block) pairs to the caller's work list, if any
    if work_list is not None:
        for block in new_blocks:
            work_list.append(block)
    return callee_blocks, var_dict
681
682
def _get_callee_args(call_expr, callee, loc, func_ir):
    """Get arguments for calling 'callee', including the default arguments.

    Keyword arguments are currently only handled when 'callee' is a function.

    Parameters
    ----------
    call_expr : ir.Expr
        The call-site expression.  For a 'call' the positional arguments are
        taken from ``call_expr.args``; for a 'getattr' the single argument is
        ``call_expr.value``.
    callee : function or object with ``defaults``/``__defaults__``
        The callee being inlined.
    loc : ir.Loc
        Location attached to any ``ir.Const`` created for a default value.
    func_ir : ir.FunctionIR
        IR used to look up the definition of defaults that are given as a
        variable (the make_function case).

    Returns
    -------
    The result of ``fold_arguments`` when 'callee' is a Python function,
    otherwise a list made of the call-site arguments followed by the
    callee's defaults.

    Raises
    ------
    TypeError
        If ``call_expr.op`` is neither 'call' nor 'getattr'.
    NotImplementedError
        If the callee takes a stararg, or if its defaults are of an
        unsupported kind.
    """
    if call_expr.op == 'call':
        args = list(call_expr.args)
    elif call_expr.op == 'getattr':
        args = [call_expr.value]
    else:
        raise TypeError("Unsupported ir.Expr.{}".format(call_expr.op))

    debug_print = _make_debug_print("inline_closure_call default handling")

    # handle defaults and kw arguments using pysignature if callee is function
    if isinstance(callee, pytypes.FunctionType):
        pysig = numba.core.utils.pysignature(callee)

        def normal_handler(index, param, default):
            # explicitly supplied argument: pass it through untouched
            return default

        def default_handler(index, param, default):
            # omitted argument: materialize the Python default as a constant
            return ir.Const(default, loc)

        # Throw error for stararg
        # TODO: handle stararg
        def stararg_handler(index, param, default):
            raise NotImplementedError(
                "Stararg not supported in inliner for arg {} {}".format(
                    index, param))

        if call_expr.op == 'call':
            kws = dict(call_expr.kws)
        else:
            kws = {}
        return numba.core.typing.fold_arguments(
            pysig, args, kws, normal_handler, default_handler,
            stararg_handler)
    else:
        # TODO: handle arguments for make_function case similar to function
        # case above
        callee_defaults = (callee.defaults if hasattr(callee, 'defaults')
                           else callee.__defaults__)
        if callee_defaults:
            debug_print("defaults = ", callee_defaults)
            if isinstance(callee_defaults, tuple):
                defaults_list = []
                for x in callee_defaults:
                    if isinstance(x, ir.Var):
                        defaults_list.append(x)
                    else:
                        # this branch is predominantly for kwargs from
                        # inlinable functions
                        defaults_list.append(ir.Const(value=x, loc=loc))
                args = args + defaults_list
            elif isinstance(callee_defaults, (ir.Var, str)):
                # defaults are held in a variable that must be defined by a
                # build_tuple; use the definitions of its items
                default_tuple = func_ir.get_definition(callee_defaults)
                assert(isinstance(default_tuple, ir.Expr))
                assert(default_tuple.op == "build_tuple")
                const_vals = [func_ir.get_definition(x) for
                              x in default_tuple.items]
                args = args + const_vals
            else:
                # NOTE: the message must format 'callee_defaults'; there is
                # no name 'defaults' in this scope.
                raise NotImplementedError(
                    "Unsupported defaults to make_function: {}".format(
                        callee_defaults))
        return args
744
745
def _make_debug_print(prefix):
    """Build a printer that tags every message with ``prefix``.

    The returned function accepts any number of values, converts each one
    with ``str`` and concatenates them without a separator.  It only prints
    when ``config.DEBUG_INLINE_CLOSURE`` is set.
    """
    def debug_print(*args):
        if not config.DEBUG_INLINE_CLOSURE:
            return
        head = prefix + ": "
        tail = "".join(map(str, args))
        print(head + tail)
    return debug_print
751
752
def _debug_dump(func_ir):
    """Dump ``func_ir`` when ``config.DEBUG_INLINE_CLOSURE`` is set."""
    if not config.DEBUG_INLINE_CLOSURE:
        return
    func_ir.dump()
756
757
def _get_all_scopes(blocks):
    """Collect the distinct scopes referenced by the blocks of an IR.

    Scopes are returned in order of first appearance.  Membership is tested
    against a list so that scopes do not need to be hashable.
    """
    distinct = []
    for blk in blocks.values():
        scope = blk.scope
        if scope in distinct:
            continue
        distinct.append(scope)
    return distinct
766
767
def _replace_args_with(blocks, args):
    """
    Substitute each ir.Arg placeholder with the matching call-site argument.

    Every assignment whose value is an ir.Arg has that value replaced, in
    place, by ``args[index]`` where ``index`` is the placeholder's index.
    """
    for block in blocks.values():
        for assign in block.find_insts(ir.Assign):
            placeholder = assign.value
            if not isinstance(placeholder, ir.Arg):
                continue
            position = placeholder.index
            assert(position < len(args))
            assign.value = args[position]
779
780
def _replace_freevars(blocks, args):
    """
    Substitute each ir.FreeVar with the matching value from the parent.

    When ``args[index]`` is an ir.Var it is used directly; any other value
    is wrapped in an ir.Const located at the assignment being rewritten.
    """
    for block in blocks.values():
        for assign in block.find_insts(ir.Assign):
            freevar = assign.value
            if not isinstance(freevar, ir.FreeVar):
                continue
            position = freevar.index
            assert(position < len(args))
            replacement = args[position]
            if not isinstance(replacement, ir.Var):
                replacement = ir.Const(replacement, assign.loc)
            assign.value = replacement
795
796
def _replace_returns(blocks, target, return_label):
    """
    Replace return statement by assigning directly to target, and a jump.

    A return must be the last statement of its block.  Any cast recorded
    earlier in the same block whose target is the returned variable is
    stripped, so that the variable receives the un-cast value.
    """
    def _is_cast(inst):
        return (isinstance(inst, ir.Assign)
                and isinstance(inst.value, ir.Expr)
                and inst.value.op == 'cast')

    for block in blocks.values():
        body = block.body
        seen_casts = []
        # iterate over a snapshot of the positions; the only mutation is the
        # jump appended after the (final) return statement
        for pos in range(len(body)):
            inst = body[pos]
            if isinstance(inst, ir.Return):
                assert(pos + 1 == len(body))
                body[pos] = ir.Assign(inst.value, target, inst.loc)
                body.append(ir.Jump(return_label, inst.loc))
                # remove cast of the returned value
                returned_name = inst.value.name
                for cast_stmt in seen_casts:
                    if cast_stmt.target.name == returned_name:
                        cast_stmt.value = cast_stmt.value.value
            elif _is_cast(inst):
                seen_casts.append(inst)
815
def _add_definitions(func_ir, block):
    """
    Register every assignment of ``block`` in the parent IR's definitions.

    The assigned value is appended to the list stored under the target
    variable's name in ``func_ir._definitions``.
    """
    known_defs = func_ir._definitions
    for assign in block.find_insts(ir.Assign):
        known_defs[assign.target.name].append(assign.value)
824
def _find_arraycall(func_ir, block):
    """Scan ``block`` for a statement that turns a list into an array.

    Two forms are recognized: "x = numpy.array(y)" and "x[..] = y" (the
    nested array case, where x must be defined by unsafe_empty_inferred).
    The scan walks over Del, Assign and such SetItem statements only, and
    succeeds once the list variable is deleted after its use.

    Returns the tuple (list variable, index of the statement in the block,
    dict of keyword arguments of the array call).  Raises GuardException
    when no suitable statement is found.
    """
    found_array = None
    found_list = None
    list_dies_after_use = False

    for position, inst in enumerate(block.body):
        if isinstance(inst, ir.Del):
            # Stop the process if the list becomes dead
            if found_list and found_array and inst.value == found_list.name:
                list_dies_after_use = True
                break
        elif isinstance(inst, ir.Assign):
            # Looking for found_array = array(found_list)
            rhs = inst.value
            is_np_array = (
                guard(find_callname, func_ir, rhs) == ('array', 'numpy'))
            if is_np_array and isinstance(rhs.args[0], ir.Var):
                found_list = rhs.args[0]
                found_array = inst.target
                array_stmt_index = position
                array_kws = dict(rhs.kws)
        elif (isinstance(inst, ir.SetItem)
                and isinstance(inst.value, ir.Var)
                and not found_list):
            # Looking for found_array[..] = found_list (nested array)
            found_list = inst.value
            found_array = inst.target
            array_def = get_definition(func_ir, found_array)
            require(guard(_find_unsafe_empty_inferred, func_ir, array_def))
            array_stmt_index = position
            array_kws = {}
        else:
            # Any other statement ends the scan
            break

    # the array statement must exist and the list must be dead after it
    require(found_array and list_dies_after_use)
    _make_debug_print("find_array_call")(block.body[array_stmt_index])
    return found_list, array_stmt_index, array_kws
873
874
def _find_iter_range(func_ir, range_iter_var, swapped):
    """Recover the bounds of the range that ``range_iter_var`` iterates.

    The iterator must come from getiter applied to a call of range (or
    prange) with one or two arguments.  Returns (start, stop, func_def)
    where start is the literal 0 for the one-argument form; raises
    GuardException in every other case.  On success the range function
    variable is also recorded in ``swapped``.
    """
    trace = _make_debug_print("find_iter_range")

    iter_expr = get_definition(func_ir, range_iter_var)
    trace("range_iter_var = ", range_iter_var, " def = ", iter_expr)
    require(isinstance(iter_expr, ir.Expr) and iter_expr.op == 'getiter')

    range_var = iter_expr.value
    range_call = get_definition(func_ir, range_var)
    trace("range_var = ", range_var, " range_def = ", range_call)
    require(isinstance(range_call, ir.Expr) and range_call.op == 'call')

    func_var = range_call.func
    func_def = get_definition(func_ir, func_var)
    trace("func_var = ", func_var, " func_def = ", func_def)
    require(isinstance(func_def, ir.Global) and
            (func_def.value == range or
             func_def.value == numba.misc.special.prange))

    call_args = range_call.args
    nargs = len(call_args)
    if nargs != 1 and nargs != 2:
        raise GuardException

    swapped[range_call.func.name] = [
        ('"array comprehension"', 'closure of'), range_call.func.loc]

    if nargs == 1:
        # the lookup result is not used, but it still raises
        # GuardException when the argument cannot be resolved
        get_definition(func_ir, call_args[0], lhs_only=True)
        return (0, call_args[0], func_def)

    lower = get_definition(func_ir, call_args[0], lhs_only=True)
    upper = get_definition(func_ir, call_args[1], lhs_only=True)
    return (lower, upper, func_def)
905
def _inline_arraycall(func_ir, cfg, visited, loop, swapped, enable_prange=False,
                      typed=False):
    """Look for array(list) call in the exit block of a given loop, and turn list operations into
    array operations in the loop if the following conditions are met:
      1. The exit block contains an array call on the list;
      2. The list variable is no longer live after array call;
      3. The list is created in the loop entry block;
      4. The loop is created from an range iterator whose length is known prior to the loop;
      5. There is only one list_append operation on the list variable in the loop body;
      6. The block that contains list_append dominates the loop head, which ensures list
         length is the same as loop length;

    Parameters
    ----------
    func_ir : the function IR being rewritten.
    cfg : control flow graph, queried for ``in_loops`` and ``predecessors``.
    visited : collection of loop headers that were already processed; a loop
        body block is only considered when all loops it belongs to have
        their header in this collection.
    loop : the loop to process (its ``exits``, ``body``, ``header`` and
        ``entries`` are used).
    swapped : dict handed to ``_find_iter_range``, which records the range
        function variable in it.
    enable_prange : when true and the iterator is a recognized range, the
        range function global is switched to ``internal_prange``.
    typed : when false, the code paths that need ``range_iter_len`` or
        ``unsafe_empty_inferred`` bail out instead.

    Returns True on success.  A failed condition raises GuardException (via
    ``require`` or directly).  Block bodies are only rewritten once all checks
    have passed, but ``swapped``, ``func_ir._definitions`` and the range
    function global may already have been updated when a late check fails.
    """
    debug_print = _make_debug_print("inline_arraycall")
    # There should only be one loop exit
    require(len(loop.exits) == 1)
    exit_block = next(iter(loop.exits))
    list_var, array_call_index, array_kws = _find_arraycall(func_ir, func_ir.blocks[exit_block])

    # check if dtype is present in array call
    dtype_def = None
    dtype_mod_def = None
    if 'dtype' in array_kws:
        require(isinstance(array_kws['dtype'], ir.Var))
        # We require that dtype argument to be a constant of getattr Expr, and we'll
        # remember its definition for later use.
        dtype_def = get_definition(func_ir, array_kws['dtype'])
        require(isinstance(dtype_def, ir.Expr) and dtype_def.op == 'getattr')
        dtype_mod_def = get_definition(func_ir, dtype_def.value)

    list_var_def = get_definition(func_ir, list_var)
    debug_print("list_var = ", list_var, " def = ", list_var_def)
    # Look through a cast to reach the underlying list construction
    if isinstance(list_var_def, ir.Expr) and list_var_def.op == 'cast':
        list_var_def = get_definition(func_ir, list_var_def.value)
    # Check if the definition is a build_list
    require(isinstance(list_var_def, ir.Expr) and list_var_def.op ==  'build_list')
    # The build_list must be empty
    require(len(list_var_def.items) == 0)

    # Look for list_append in "last" block in loop body, which should be a block that is
    # a post-dominator of the loop header.
    list_append_stmts = []
    for label in loop.body:
        # We have to consider blocks of this loop, but not sub-loops.
        # To achieve this, we require the set of "in_loops" of "label" to be visited loops.
        in_visited_loops = [l.header in visited for l in cfg.in_loops(label)]
        if not all(in_visited_loops):
            continue
        block = func_ir.blocks[label]
        debug_print("check loop body block ", label)
        for stmt in block.find_insts(ir.Assign):
            expr = stmt.value
            if isinstance(expr, ir.Expr) and expr.op == 'call':
                func_def = get_definition(func_ir, expr.func)
                if isinstance(func_def, ir.Expr) and func_def.op == 'getattr' \
                  and func_def.attr == 'append':
                    list_def = get_definition(func_ir, func_def.value)
                    debug_print("list_def = ", list_def, list_def is list_var_def)
                    if list_def is list_var_def:
                        # found matching append call
                        list_append_stmts.append((label, block, stmt))

    # Require only one list_append, otherwise we won't know the indices
    require(len(list_append_stmts) == 1)
    append_block_label, append_block, append_stmt = list_append_stmts[0]

    # Check if append_block (besides loop entry) dominates loop header.
    # Since CFG doesn't give us this info without loop entry, we approximate
    # by checking if the predecessor set of the header block is the same
    # as loop_entries plus append_block, which is certainly more restrictive
    # than necessary, and can be relaxed if needed.
    preds = set(l for l, b in cfg.predecessors(loop.header))
    debug_print("preds = ", preds, (loop.entries | set([append_block_label])))
    require(preds == (loop.entries | set([append_block_label])))

    # Find iterator in loop header
    iter_vars = []
    iter_first_vars = []
    loop_header = func_ir.blocks[loop.header]
    for stmt in loop_header.find_insts(ir.Assign):
        expr = stmt.value
        if isinstance(expr, ir.Expr):
            if expr.op == 'iternext':
                iter_def = get_definition(func_ir, expr.value)
                debug_print("iter_def = ", iter_def)
                iter_vars.append(expr.value)
            elif expr.op == 'pair_first':
                iter_first_vars.append(stmt.target)

    # Require only one iterator in loop header
    require(len(iter_vars) == 1 and len(iter_first_vars) == 1)
    iter_var = iter_vars[0] # variable that holds the iterator object
    iter_first_var = iter_first_vars[0] # variable that holds the value out of iterator

    # Final requirement: only one loop entry, and we're going to modify it by:
    # 1. replacing the list definition with an array definition;
    # 2. adding a counter for the array iteration.
    require(len(loop.entries) == 1)
    loop_entry = func_ir.blocks[next(iter(loop.entries))]
    terminator = loop_entry.terminator
    scope = loop_entry.scope
    loc = loop_entry.loc
    stmts = []
    removed = []
    def is_removed(val, removed):
        if isinstance(val, ir.Var):
            for x in removed:
                if x.name == val.name:
                    return True
        return False
    # Skip list construction and skip terminator, add the rest to stmts.
    # NOTE: the comparison must be against list_var_def (the build_list that
    # feeds the array call).  The loop variable 'list_def' from the append
    # search above holds whatever list the *last* inspected append call
    # targeted, which is not necessarily the list being converted.
    for i in range(len(loop_entry.body) - 1):
        stmt = loop_entry.body[i]
        if isinstance(stmt, ir.Assign) and (stmt.value is list_var_def or is_removed(stmt.value, removed)):
            removed.append(stmt.target)
        else:
            stmts.append(stmt)
    debug_print("removed variables: ", removed)

    # Define an index_var to index the array.
    # If the range happens to be single step ranges like range(n), or range(m, n),
    # then the index_var correlates to iterator index; otherwise we'll have to
    # define a new counter.
    range_def = guard(_find_iter_range, func_ir, iter_var, swapped)
    index_var = ir.Var(scope, mk_unique_var("index"), loc)
    if range_def and range_def[0] == 0:
        # iterator starts with 0, index_var can just be iter_first_var
        index_var = iter_first_var
    else:
        # index_var = -1 # starting the index with -1 since it will incremented in loop header
        stmts.append(_new_definition(func_ir, index_var, ir.Const(value=-1, loc=loc), loc))

    # Insert statement to get the size of the loop iterator
    size_var = ir.Var(scope, mk_unique_var("size"), loc)
    if range_def:
        start, stop, range_func_def = range_def
        if start == 0:
            size_val = stop
        else:
            size_val = ir.Expr.binop(fn=operator.sub, lhs=stop, rhs=start, loc=loc)
        # we can parallelize this loop if enable_prange = True, by changing
        # range function from range, to prange.
        if enable_prange and isinstance(range_func_def, ir.Global):
            range_func_def.name = 'internal_prange'
            range_func_def.value = internal_prange

    else:
        # this doesn't work in objmode as it's effectively untyped
        if typed:
            len_func_var = ir.Var(scope, mk_unique_var("len_func"), loc)
            stmts.append(_new_definition(func_ir, len_func_var,
                        ir.Global('range_iter_len', range_iter_len, loc=loc),
                        loc))
            size_val = ir.Expr.call(len_func_var, (iter_var,), (), loc=loc)
        else:
            raise GuardException


    stmts.append(_new_definition(func_ir, size_var, size_val, loc))

    size_tuple_var = ir.Var(scope, mk_unique_var("size_tuple"), loc)
    stmts.append(_new_definition(func_ir, size_tuple_var,
                 ir.Expr.build_tuple(items=[size_var], loc=loc), loc))

    # Insert array allocation
    array_var = ir.Var(scope, mk_unique_var("array"), loc)
    empty_func = ir.Var(scope, mk_unique_var("empty_func"), loc)
    if dtype_def and dtype_mod_def:
        # when dtype is present, we'll call empty with dtype
        dtype_mod_var = ir.Var(scope, mk_unique_var("dtype_mod"), loc)
        dtype_var = ir.Var(scope, mk_unique_var("dtype"), loc)
        stmts.append(_new_definition(func_ir, dtype_mod_var, dtype_mod_def, loc))
        stmts.append(_new_definition(func_ir, dtype_var,
                         ir.Expr.getattr(dtype_mod_var, dtype_def.attr, loc), loc))
        stmts.append(_new_definition(func_ir, empty_func,
                         ir.Global('empty', np.empty, loc=loc), loc))
        array_kws = [('dtype', dtype_var)]
    else:
        # this doesn't work in objmode as it's effectively untyped
        if typed:
            # otherwise we'll call unsafe_empty_inferred
            stmts.append(_new_definition(func_ir, empty_func,
                            ir.Global('unsafe_empty_inferred',
                                unsafe_empty_inferred, loc=loc), loc))
            array_kws = []
        else:
            raise GuardException

    # array_var = empty_func(size_tuple_var)
    stmts.append(_new_definition(func_ir, array_var,
                 ir.Expr.call(empty_func, (size_tuple_var,), list(array_kws), loc=loc), loc))

    # Add back removed just in case they are used by something else
    for var in removed:
        stmts.append(_new_definition(func_ir, var, array_var, loc))

    # Add back terminator
    stmts.append(terminator)
    # Modify loop_entry
    loop_entry.body = stmts

    if range_def:
        if range_def[0] != 0:
            # when range doesn't start from 0, index_var becomes loop index
            # (iter_first_var) minus an offset (range_def[0])
            terminator = loop_header.terminator
            assert(isinstance(terminator, ir.Branch))
            # find the block in the loop body that header jumps to
            block_id = terminator.truebr
            blk = func_ir.blocks[block_id]
            loc = blk.loc
            blk.body.insert(0, _new_definition(func_ir, index_var,
                ir.Expr.binop(fn=operator.sub, lhs=iter_first_var,
                                      rhs=range_def[0], loc=loc),
                loc))
    else:
        # Insert index_var increment to the end of loop header
        loc = loop_header.loc
        terminator = loop_header.terminator
        stmts = loop_header.body[0:-1]
        next_index_var = ir.Var(scope, mk_unique_var("next_index"), loc)
        one = ir.Var(scope, mk_unique_var("one"), loc)
        # one = 1
        stmts.append(_new_definition(func_ir, one,
                     ir.Const(value=1,loc=loc), loc))
        # next_index_var = index_var + 1
        stmts.append(_new_definition(func_ir, next_index_var,
                     ir.Expr.binop(fn=operator.add, lhs=index_var, rhs=one, loc=loc), loc))
        # index_var = next_index_var
        stmts.append(_new_definition(func_ir, index_var, next_index_var, loc))
        stmts.append(terminator)
        loop_header.body = stmts

    # In append_block, change list_append into array assign
    for i in range(len(append_block.body)):
        if append_block.body[i] is append_stmt:
            debug_print("Replace append with SetItem")
            append_block.body[i] = ir.SetItem(target=array_var, index=index_var,
                                              value=append_stmt.value.args[0], loc=append_stmt.loc)

    # replace array call, by changing "a = array(b)" to "a = b"
    stmt = func_ir.blocks[exit_block].body[array_call_index]
    # stmt can be either array call or SetItem, we only replace array call
    if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr):
        stmt.value = array_var
        func_ir._definitions[stmt.target.name] = [stmt.value]

    return True
1155
1156
def _find_unsafe_empty_inferred(func_ir, expr):
    """Tell whether ``expr`` is a call to ``unsafe_empty_inferred``.

    Raises GuardException when ``expr`` is not a call expression or when
    the called variable is not defined by an ir.Global; otherwise returns
    the result of comparing that global's value with
    ``unsafe_empty_inferred``.
    """
    require(isinstance(expr, ir.Expr))
    require(expr.op == 'call')
    func_var = expr.func
    func_global = get_definition(func_ir, func_var)
    require(isinstance(func_global, ir.Global))
    called = func_global.value
    _make_debug_print("_find_unsafe_empty_inferred")(called)
    return called == unsafe_empty_inferred
1165
1166
def _fix_nested_array(func_ir):
    """Look for assignment like: a[..] = b, where both a and b are numpy arrays, and
    try to eliminate array b by expanding a with an extra dimension.

    The IR is modified in place: for every SetItem that can be fixed, the
    size tuple used to allocate ``a`` is extended with the dimensions of
    ``b``, the definition of ``b`` is turned into ``a[..]``, and the SetItem
    itself is removed from its block.  Statements that cannot be fixed are
    left untouched.
    """
    blocks = func_ir.blocks
    cfg = compute_cfg_from_blocks(blocks)
    usedefs = compute_use_defs(blocks)
    empty_deadmap = dict([(label, set()) for label in blocks.keys()])
    livemap = compute_live_variables(cfg, blocks, usedefs.defmap, empty_deadmap)

    def find_array_def(arr):
        """Find numpy array definition such as
            arr = numba.unsafe.ndarray.empty_inferred(...).
        If it is arr = b[...], find array definition of b recursively.
        Raises GuardException when no such definition is found.
        """
        arr_def = get_definition(func_ir, arr)
        _make_debug_print("find_array_def")(arr, arr_def)
        if isinstance(arr_def, ir.Expr):
            if guard(_find_unsafe_empty_inferred, func_ir, arr_def):
                return arr_def
            elif arr_def.op == 'getitem':
                return find_array_def(arr_def.value)
        raise GuardException

    def fix_dependencies(expr, varlist):
        """Double check if all variables in varlist are defined before
        expr is used. Try to move constant definition when the check fails.
        Bails out by raising GuardException if it can't be moved.

        Returns the list of variables to use in place of ``varlist``; a
        variable that had to be moved is replaced by a fresh variable that
        is assigned the same constant right before the statement defining
        ``expr``.
        """
        debug_print = _make_debug_print("fix_dependencies")
        for label, block in blocks.items():
            scope = block.scope
            body = block.body
            defined = set()
            for i, inst in enumerate(body):
                if isinstance(inst, ir.Assign):
                    defined.add(inst.target.name)
                    if inst.value is expr:
                        new_varlist = []
                        # Constant definitions to insert before body[i].
                        # They are collected first and inserted together, so
                        # that several moved constants are all kept and the
                        # block is not touched when we have to bail out.
                        new_vardefs = []
                        for var in varlist:
                            # var must be defined before this inst, or live
                            # and not later defined.
                            if (var.name in defined or
                                (var.name in livemap[label] and
                                 not (var.name in usedefs.defmap[label]))):
                                debug_print(var.name, " already defined")
                                new_varlist.append(var)
                            else:
                                debug_print(var.name, " not yet defined")
                                var_def = get_definition(func_ir, var.name)
                                if not isinstance(var_def, ir.Const):
                                    raise GuardException
                                loc = var.loc
                                new_var = ir.Var(scope, mk_unique_var("new_var"), loc)
                                new_const = ir.Const(var_def.value, loc)
                                new_vardefs.append(_new_definition(
                                    func_ir, new_var, new_const, loc))
                                new_varlist.append(new_var)
                        if new_vardefs:
                            block.body = body[:i] + new_vardefs + body[i:]
                        return new_varlist
        # when expr is not found in block
        raise GuardException

    def fix_array_assign(stmt):
        """For assignment like lhs[idx] = rhs, where both lhs and rhs are arrays, do the
        following:
        1. find the definition of rhs, which has to be a call to numba.unsafe.ndarray.empty_inferred
        2. find the source array creation for lhs, insert an extra dimension of size of b.
        3. replace the definition of rhs = numba.unsafe.ndarray.empty_inferred(...) with rhs = lhs[idx]

        Returns True on success and raises GuardException otherwise.
        """
        require(isinstance(stmt, ir.SetItem))
        require(isinstance(stmt.value, ir.Var))
        debug_print = _make_debug_print("fix_array_assign")
        debug_print("found SetItem: ", stmt)
        lhs = stmt.target
        # Find the source array creation of lhs
        lhs_def = find_array_def(lhs)
        debug_print("found lhs_def: ", lhs_def)
        rhs_def = get_definition(func_ir, stmt.value)
        debug_print("found rhs_def: ", rhs_def)
        require(isinstance(rhs_def, ir.Expr))
        if rhs_def.op == 'cast':
            rhs_def = get_definition(func_ir, rhs_def.value)
            require(isinstance(rhs_def, ir.Expr))
        require(_find_unsafe_empty_inferred(func_ir, rhs_def))
        # Find the array dimension of rhs
        dim_def = get_definition(func_ir, rhs_def.args[0])
        require(isinstance(dim_def, ir.Expr) and dim_def.op == 'build_tuple')
        debug_print("dim_def = ", dim_def)
        extra_dims = [ get_definition(func_ir, x, lhs_only=True) for x in dim_def.items ]
        debug_print("extra_dims = ", extra_dims)
        # Expand size tuple when creating lhs_def with extra_dims
        size_tuple_def = get_definition(func_ir, lhs_def.args[0])
        require(isinstance(size_tuple_def, ir.Expr) and size_tuple_def.op == 'build_tuple')
        debug_print("size_tuple_def = ", size_tuple_def)
        extra_dims = fix_dependencies(size_tuple_def, extra_dims)
        size_tuple_def.items += extra_dims
        # In-place modify rhs_def to be getitem
        rhs_def.op = 'getitem'
        rhs_def.value = get_definition(func_ir, lhs, lhs_only=True)
        rhs_def.index = stmt.index
        # drop the attributes that only belong to a call expression
        del rhs_def._kws['func']
        del rhs_def._kws['args']
        del rhs_def._kws['vararg']
        del rhs_def._kws['kws']
        # success
        return True

    for label in find_topo_order(func_ir.blocks):
        block = func_ir.blocks[label]
        # Iterate over a snapshot: statements are removed from block.body
        # (and block.body may even be rebound by fix_dependencies) during the
        # walk, which would otherwise make the loop skip the statement that
        # follows each removed one.
        for stmt in list(block.body):
            if guard(fix_array_assign, stmt):
                block.body.remove(stmt)
1286
def _new_definition(func_ir, var, value, loc):
    """Create the assignment ``var = value`` and record its definition.

    The definitions stored for ``var`` in ``func_ir._definitions`` are
    replaced by the single entry ``value``; the new ir.Assign is returned
    and is not inserted into any block here.
    """
    func_ir._definitions[var.name] = [value]
    assignment = ir.Assign(target=var, value=value, loc=loc)
    return assignment
1290
@rewrites.register_rewrite('after-inference')
class RewriteArrayOfConsts(rewrites.Rewrite):
    '''Rewrite that looks for 1D arrays created from a constant list and
    replaces them with a direct initialization of the array elements, so
    that the intermediate list is never built.
    '''
    def __init__(self, state, *args, **kws):
        # keep the typing context; it is needed when matching
        self.typingctx = state.typingctx
        super().__init__(*args, **kws)

    def match(self, func_ir, block, typemap, calltypes):
        '''Try to build a replacement body for ``block``; return whether
        one was produced.
        '''
        if not len(calltypes):
            return False
        self.crnt_block = block
        rewritten = guard(_inline_const_arraycall, block, func_ir,
                          self.typingctx, typemap, calltypes)
        self.new_body = rewritten
        return rewritten is not None

    def apply(self):
        '''Install the body found by ``match`` and return the block.'''
        target_block = self.crnt_block
        target_block.body = self.new_body
        return target_block
1312
1313
1314def _inline_const_arraycall(block, func_ir, context, typemap, calltypes):
1315    """Look for array(list) call where list is a constant list created by build_list,
1316    and turn them into direct array creation and initialization, if the following
1317    conditions are met:
1318      1. The build_list call immediate precedes the array call;
1319      2. The list variable is no longer live after array call;
1320    If any condition check fails, no modification will be made.
1321    """
1322    debug_print = _make_debug_print("inline_const_arraycall")
1323    scope = block.scope
1324
1325    def inline_array(array_var, expr, stmts, list_vars, dels):
1326        """Check to see if the given "array_var" is created from a list
1327        of constants, and try to inline the list definition as array
1328        initialization.
1329
1330        Extra statements produced with be appended to "stmts".
1331        """
1332        callname = guard(find_callname, func_ir, expr)
1333        require(callname and callname[1] == 'numpy' and callname[0] == 'array')
1334        require(expr.args[0].name in list_vars)
1335        ret_type = calltypes[expr].return_type
1336        require(isinstance(ret_type, types.ArrayCompatible) and
1337                           ret_type.ndim == 1)
1338        loc = expr.loc
1339        list_var = expr.args[0]
1340        # Get the type of the array to be created.
1341        array_typ = typemap[array_var.name]
1342        debug_print("inline array_var = ", array_var, " list_var = ", list_var)
1343        # Get the element type of the array to be created.
1344        dtype = array_typ.dtype
1345        # Get the sequence of operations to provide values to the new array.
1346        seq, _ = find_build_sequence(func_ir, list_var)
1347        size = len(seq)
1348        # Create a tuple to pass to empty below to specify the new array size.
1349        size_var = ir.Var(scope, mk_unique_var("size"), loc)
1350        size_tuple_var = ir.Var(scope, mk_unique_var("size_tuple"), loc)
1351        size_typ = types.intp
1352        size_tuple_typ = types.UniTuple(size_typ, 1)
1353        typemap[size_var.name] = size_typ
1354        typemap[size_tuple_var.name] = size_tuple_typ
1355        stmts.append(_new_definition(func_ir, size_var,
1356                 ir.Const(size, loc=loc), loc))
1357        stmts.append(_new_definition(func_ir, size_tuple_var,
1358                 ir.Expr.build_tuple(items=[size_var], loc=loc), loc))
1359
1360        # The general approach is to create an empty array and then fill
1361        # the elements in one-by-one from their specificiation.
1362
1363        # Get the numpy type to pass to empty.
1364        nptype = types.DType(dtype)
1365
1366        # Create a variable to hold the numpy empty function.
1367        empty_func = ir.Var(scope, mk_unique_var("empty_func"), loc)
1368        fnty = get_np_ufunc_typ(np.empty)
1369        sig = context.resolve_function_type(fnty, (size_typ,), {'dtype':nptype})
1370
1371        typemap[empty_func.name] = fnty
1372
1373        stmts.append(_new_definition(func_ir, empty_func,
1374                         ir.Global('empty', np.empty, loc=loc), loc))
1375
        # We pass two arguments to empty, first the size and second
        # the dtype of the new array.  Here, we created typ_var which is
        # the dtype argument of the new array.  typ_var in turn is created
        # by getattr of the dtype string on the numpy module.
1380
1381        # Create var for numpy module.
1382        g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
1383        typemap[g_np_var.name] = types.misc.Module(np)
1384        g_np = ir.Global('np', np, loc)
1385        stmts.append(_new_definition(func_ir, g_np_var, g_np, loc))
1386
1387        # Create var for result of numpy.<dtype>.
1388        typ_var = ir.Var(scope, mk_unique_var("$np_typ_var"), loc)
1389        typemap[typ_var.name] = nptype
1390        dtype_str = str(dtype)
1391        if dtype_str == 'bool':
1392            dtype_str = 'bool_'
1393        # Get dtype attribute of numpy module.
1394        np_typ_getattr = ir.Expr.getattr(g_np_var, dtype_str, loc)
1395        stmts.append(_new_definition(func_ir, typ_var, np_typ_getattr, loc))
1396
        # Create the call to numpy.empty passing the size var and dtype var.
1398        empty_call = ir.Expr.call(empty_func, [size_var, typ_var], {}, loc=loc)
1399        calltypes[empty_call] = typing.signature(array_typ, size_typ, nptype)
1400        stmts.append(_new_definition(func_ir, array_var, empty_call, loc))
1401
1402        # Fill in the new empty array one-by-one.
1403        for i in range(size):
1404            index_var = ir.Var(scope, mk_unique_var("index"), loc)
1405            index_typ = types.intp
1406            typemap[index_var.name] = index_typ
1407            stmts.append(_new_definition(func_ir, index_var,
1408                    ir.Const(i, loc), loc))
1409            setitem = ir.SetItem(array_var, index_var, seq[i], loc)
1410            calltypes[setitem] = typing.signature(types.none, array_typ,
1411                                                  index_typ, dtype)
1412            stmts.append(setitem)
1413
1414        stmts.extend(dels)
1415        return True
1416
1417    class State(object):
1418        """
1419        This class is used to hold the state in the following loop so as to make
1420        it easy to reset the state of the variables tracking the various
1421        statement kinds
1422        """
1423
1424        def __init__(self):
1425            # list_vars keep track of the variable created from the latest
1426            # build_list instruction, as well as its synonyms.
1427            self.list_vars = []
1428            # dead_vars keep track of those in list_vars that are considered dead.
1429            self.dead_vars = []
1430            # list_items keep track of the elements used in build_list.
1431            self.list_items = []
1432            self.stmts = []
1433            # dels keep track of the deletion of list_items, which will need to be
1434            # moved after array initialization.
1435            self.dels = []
1436            # tracks if a modification has taken place
1437            self.modified = False
1438
1439        def reset(self):
1440            """
1441            Resets the internal state of the variables used for tracking
1442            """
1443            self.list_vars = []
1444            self.dead_vars = []
1445            self.list_items = []
1446            self.dels = []
1447
1448        def list_var_used(self, inst):
1449            """
1450            Returns True if the list being analysed is used between the
1451            build_list and the array call.
1452            """
1453            return any([x.name in self.list_vars for x in inst.list_vars()])
1454
1455    state = State()
1456
1457    for inst in block.body:
1458        if isinstance(inst, ir.Assign):
1459            if isinstance(inst.value, ir.Var):
1460                if inst.value.name in state.list_vars:
1461                    state.list_vars.append(inst.target.name)
1462                    state.stmts.append(inst)
1463                    continue
1464            elif isinstance(inst.value, ir.Expr):
1465                expr = inst.value
1466                if expr.op == 'build_list':
1467                    # new build_list encountered, reset state
1468                    state.reset()
1469                    state.list_items = [x.name for x in expr.items]
1470                    state.list_vars = [inst.target.name]
1471                    state.stmts.append(inst)
1472                    continue
1473                elif expr.op == 'call' and expr in calltypes:
1474                    arr_var = inst.target
1475                    if guard(inline_array, inst.target, expr,
1476                             state.stmts, state.list_vars, state.dels):
1477                        state.modified = True
1478                        continue
1479        elif isinstance(inst, ir.Del):
1480            removed_var = inst.value
1481            if removed_var in state.list_items:
1482                state.dels.append(inst)
1483                continue
1484            elif removed_var in state.list_vars:
1485                # one of the list_vars is considered dead.
1486                state.dead_vars.append(removed_var)
1487                state.list_vars.remove(removed_var)
1488                state.stmts.append(inst)
1489                if state.list_vars == []:
1490                    # if all list_vars are considered dead, we need to filter
1491                    # them out from existing stmts to completely remove
1492                    # build_list.
1493                    # Note that if a translation didn't take place, dead_vars
1494                    # will also be empty when we reach this point.
1495                    body = []
1496                    for inst in state.stmts:
1497                        if ((isinstance(inst, ir.Assign) and
1498                             inst.target.name in state.dead_vars) or
1499                             (isinstance(inst, ir.Del) and
1500                             inst.value in state.dead_vars)):
1501                            continue
1502                        body.append(inst)
1503                    state.stmts = body
1504                    state.dead_vars = []
1505                    state.modified = True
1506                    continue
1507        state.stmts.append(inst)
1508
1509        # If the list is used in any capacity between build_list and array
1510        # call, then we must call off the translation for this list because
1511        # it could be mutated and list_items would no longer be applicable.
1512        if state.list_var_used(inst):
1513            state.reset()
1514
1515    return state.stmts if state.modified else None
1516