1"""
2Utils for IR analysis
3"""
4import operator
5from functools import reduce
6from collections import namedtuple, defaultdict
7
8from .controlflow import CFGraph
9from numba.core import types, errors, ir, consts
10from numba.misc import special
11
12#
13# Analysis related to variable lifetime
14#
15
# Result type of ``compute_use_defs``: two ``{block offset -> set of names}``
# dicts, one for variable uses and one for variable definitions.
_use_defs_result = namedtuple('use_defs_result', ['usemap', 'defmap'])

# Registry for packages that introduce their own IR statement types. It maps
# a statement type to a callable ``fn(stmt, use_set, def_set)`` that records
# the variables the statement uses and defines (consumed by
# ``compute_use_defs``).
ir_extension_usedefs = dict()
21
22
def compute_use_defs(blocks):
    """
    Find variable use/def per block.

    ``blocks`` maps a block offset to an IR block.  For every block two sets
    of variable names are built:

    * uses: names listed by ``stmt.list_vars()`` that are not in the block's
      def set at that point (an assignment's own target is added to the def
      set before its variables are scanned);
    * defs: targets of ``ir.Assign`` statements whose right-hand side does
      not mention the target itself.

    Statements whose exact type is registered in ``ir_extension_usedefs``
    are handed to the registered callable as ``fn(stmt, use_set, def_set)``
    and are otherwise skipped.

    Returns a ``use_defs_result(usemap, defmap)`` namedtuple of
    ``{offset: set of names}`` dicts.

    Raises ``AssertionError`` for an ``ir.Assign`` whose value is of an
    unexpected IR node type.
    """
    usemap = {}
    defmap = {}
    for label, block in blocks.items():
        uses = usemap[label] = set()
        defs = defmap[label] = set()
        for stmt in block.body:
            stmt_type = type(stmt)
            if stmt_type in ir_extension_usedefs:
                # extension statement: its registered hook fills the sets
                ir_extension_usedefs[stmt_type](stmt, uses, defs)
                continue

            if isinstance(stmt, ir.Assign):
                rhs = stmt.value
                if isinstance(rhs, ir.Inst):
                    rhs_names = {v.name for v in rhs.list_vars()}
                elif isinstance(rhs, ir.Var):
                    rhs_names = {rhs.name}
                elif isinstance(rhs, (ir.Arg, ir.Const, ir.Global,
                                      ir.FreeVar)):
                    rhs_names = frozenset()
                else:
                    raise AssertionError('unreachable', type(rhs))
                # e.g. ``x = x + 1`` is not recorded as a definition of x
                if stmt.target.name not in rhs_names:
                    defs.add(stmt.target.name)

            # names already defined earlier in this block are local, not uses
            uses.update(v.name for v in stmt.list_vars()
                        if v.name not in defs)

    return _use_defs_result(usemap=usemap, defmap=defmap)
58
59
def compute_live_map(cfg, blocks, var_use_map, var_def_map):
    """
    Find variables that must be alive at the ENTRY of each block.

    Two fix-point iterations are run, each repeating a full pass until no
    set grows any more:

    1. a reachability map: for every block, the variables used or defined in
       it or in any block from which it can be reached;
    2. liveness: starting from each block's own uses, a variable live at the
       entry of a block is pushed back into each predecessor in which it is
       reachable and which does not define it.

    Returns a dict mapping every key of ``blocks`` to a set of variable
    names.
    """
    def run_to_fix_point(step, mapping):
        # The sets only ever grow, so unchanged sizes for every entry mean a
        # full pass made no progress.
        def sizes():
            return tuple(len(s) for s in mapping.values())

        previous, current = None, sizes()
        while previous != current:
            step(mapping)
            previous, current = current, sizes()

    def propagate_reach(reach):
        # push used-or-defined variables forward along CFG edges
        for label in var_def_map:
            reach[label] |= var_def_map[label] | var_use_map[label]
            for succ, _edge in cfg.successors(label):
                reach[succ] |= reach[label]

    def propagate_liveness(live):
        # push live variables backward along CFG edges
        for label in live:
            live_here = live[label]
            for pred, _edge in cfg.predecessors(label):
                reachable = live_here & def_reach_map[pred]
                # a predecessor that defines the variable kills it
                live[pred] |= reachable - var_def_map[pred]

    live_map = {label: set(var_use_map[label]) for label in blocks.keys()}

    def_reach_map = defaultdict(set)
    run_to_fix_point(propagate_reach, def_reach_map)
    run_to_fix_point(propagate_liveness, live_map)
    return live_map
113
114
# Result type of ``compute_dead_maps``: three ``{block offset -> set}`` maps.
_dead_maps_result = namedtuple('dead_maps_result',
                               ['internal', 'escaping', 'combined'])
116
117
def compute_dead_maps(cfg, blocks, live_map, var_def_map):
    """
    Compute the end-of-live information for variables.

    ``live_map`` maps each block offset to the variables alive at the ENTRY
    of that block, and ``var_def_map`` maps each offset to the variables
    defined in it.

    Returns a ``dead_maps_result(internal, escaping, combined)`` namedtuple
    of ``{offset: set of names}`` maps:

    * ``internal``: variables that die inside the block because neither
      its terminator nor any successor needs them;
    * ``escaping``: variables that leave a predecessor alive but are neither
      live at entry to, nor defined in, this block, so they are deleted at
      its start;
    * ``combined``: the union of the two above, for every block.

    Raises ``RuntimeError`` if some live variable is never marked dead and
    the CFG has exit points (without exit points the check is skipped).
    """
    internal_dead = defaultdict(set)
    escaping_dead = defaultdict(set)
    # Variables used by the terminator of a block with no successors. Only
    # used by the consistency check below.
    exit_dead = defaultdict(set)

    for label, block in blocks.items():
        # live inside the block = live on entry + defined locally
        live_here = live_map[label] | var_def_map[label]
        succ_live = {succ: live_map[succ]
                     for succ, _edge in cfg.successors(label)}
        term_vars = {var.name for var in block.terminator.list_vars()}

        # What must survive past the block body: anything any successor
        # needs on entry, plus the terminator's own operands.
        needed_after = set()
        for live_in_succ in succ_live.values():
            needed_after |= live_in_succ
        needed_after |= term_vars

        dies_inside = live_here - needed_after
        internal_dead[label] = dies_inside

        escapes = live_here - dies_inside
        for succ, live_in_succ in succ_live.items():
            still_wanted = live_in_succ | var_def_map[succ]
            escaping_dead[succ] |= escapes - still_wanted

        if not succ_live:
            exit_dead[label] = term_vars

    def union_of(mapping):
        return reduce(operator.or_, mapping.values(), set())

    # Every live variable must be killed somewhere.
    all_vars = union_of(live_map)
    dead_vars = (union_of(internal_dead) | union_of(escaping_dead) |
                 union_of(exit_dead))
    missing_vars = all_vars - dead_vars
    # With no exit points the check cannot be performed, so it is skipped.
    if missing_vars and cfg.exit_points():
        msg = 'liveness info missing for vars: {0}'.format(missing_vars)
        raise RuntimeError(msg)

    combined = {label: internal_dead[label] | escaping_dead[label]
                for label in blocks}

    return _dead_maps_result(internal=internal_dead,
                             escaping=escaping_dead,
                             combined=combined)
188
189
def compute_live_variables(cfg, blocks, var_def_map, var_dead_map):
    """
    Compute the live variables at the beginning of each block
    and at each yield point.

    ``var_def_map`` and ``var_dead_map`` give, per block, the variables
    defined in and deleted from that block.

    This is a forward (top-down) dataflow pass. It computes the variables
    that are actually *available* at the entry of each block. In contrast,
    ``compute_live_map()`` works backward and finds the variables that
    *must* be available there. A block passes on to each successor
    ``(entry vars | defined vars) - deleted vars``. Passes over all blocks
    are repeated until the sets stop growing, which is needed because loops
    in the CFG feed information back.

    Returns a ``defaultdict(set)`` mapping block offsets to variable names.
    A block that receives nothing (e.g. the entry block) maps to an empty
    set.
    """
    entry_vars = defaultdict(set)

    def snapshot():
        return tuple(map(len, entry_vars.values()))

    before = snapshot()
    while True:
        for label in blocks:
            outgoing = ((entry_vars[label] | var_def_map[label]) -
                        var_dead_map[label])
            for succ, _edge in cfg.successors(label):
                entry_vars[succ] |= outgoing
        after = snapshot()
        if after == before:
            break
        before = after

    return entry_vars
229
230
231#
232# Analysis related to controlflow
233#
234
def compute_cfg_from_blocks(blocks):
    """
    Build a control-flow graph for ``blocks`` (a mapping of label -> block).

    Every label becomes a node. Every target returned by a block
    terminator's ``get_targets()`` becomes an edge from that block. The
    smallest label is the entry point, and ``process()`` is run on the graph
    before it is returned. ``min()`` raises ``ValueError`` if ``blocks`` is
    empty.
    """
    graph = CFGraph()
    for label in blocks:
        graph.add_node(label)

    edges = ((src, dst)
             for src, block in blocks.items()
             for dst in block.terminator.get_targets())
    for src, dst in edges:
        graph.add_edge(src, dst)

    graph.set_entry_point(min(blocks))
    graph.process()
    return graph
248
249
def find_top_level_loops(cfg):
    """
    A generator that yields toplevel loops given a control-flow-graph.

    A loop is top level when its header is not among the body, entry or exit
    nodes of any loop. Each loop's own header is excluded from that loop's
    member set. Every yielded loop is first passed through
    ``_fix_loop_exit``.
    """
    # nodes that belong to some loop, other than that loop's own header
    nested_nodes = set()
    for loop in cfg.loops().values():
        members = set().union(loop.body, loop.entries, loop.exits)
        members.discard(loop.header)
        nested_nodes.update(members)

    for loop in cfg.loops().values():
        if loop.header in nested_nodes:
            continue
        yield _fix_loop_exit(cfg, loop)
264
265
def _fix_loop_exit(cfg, loop):
    """
    Fixes loop.exits for Py3.8 bytecode CFG changes.
    This is to handle `break` inside loops.

    The exits common to all exit paths are the members of ``loop.exits``
    that post-dominate every exit node. If that set is non-empty, it becomes
    the new ``exits``, and the remaining former exits are moved into
    ``body``. Otherwise ``loop`` is returned unchanged (same object).
    """
    postdoms = cfg.post_dominators()

    common = loop.exits
    for exit_node in loop.exits:
        common = common & postdoms[exit_node]

    if not common:
        return loop

    # exits not shared by every exit path are really part of the body
    demoted = loop.exits - common
    return loop._replace(exits=common, body=loop.body | demoted)
284
285
# Record of a condition rendered dead by dead branch pruning:
#   condition    - the condition's IR definition
#   taken_br     - 1 if the true branch was kept, 0 otherwise
#   rewrite_stmt - whether the condition's assignment should be rewritten
#                  into a constant
nullified = namedtuple('nullified', ['condition', 'taken_br', 'rewrite_stmt'])
288
289
290# Functions to manipulate IR
def dead_branch_prune(func_ir, called_args):
    """
    Removes dead branches based on constant inference from function args.
    This directly mutates the IR.

    func_ir is the IR
    called_args are the actual arguments with which the function is called

    Only branches whose predicate is defined as ``bool(<condition>)`` are
    considered. Such a branch is a candidate for pruning (being replaced by
    an ``ir.Jump`` to the taken target) when either:

    * ``<condition>`` is a binop whose two operands both resolve to constants
      (operands that are function arguments are resolved via
      ``called_args``), or
    * the argument of the ``bool`` call resolves to a constant.

    Blocks that are no longer reachable in the CFG are then deleted.
    """
    from numba.core.ir_utils import (get_definition, guard, find_const,
                                     GuardException)

    DEBUG = 0

    def find_branches(func_ir):
        # Find *all* branches whose predicate is ``bool(<condition>)`` and
        # return (branch, condition definition, block) triples.
        branches = []
        for blk in func_ir.blocks.values():
            branch_or_jump = blk.body[-1]
            if isinstance(branch_or_jump, ir.Branch):
                branch = branch_or_jump
                pred = guard(get_definition, func_ir, branch.cond.name)
                # The predicate's definition need not be an ir.Expr (it may
                # be e.g. an ir.Const, which has no ``op``), so don't assume
                # the attribute exists.
                if pred is not None and getattr(pred, "op", None) == "call":
                    function = guard(get_definition, func_ir, pred.func)
                    if (function is not None and
                        isinstance(function, ir.Global) and
                            function.value is bool):
                        condition = guard(get_definition, func_ir, pred.args[0])
                        if condition is not None:
                            branches.append((branch, condition, blk))
        return branches

    def do_prune(branch, take_truebr, blk):
        # Replace the branch terminating ``blk`` with a direct jump to the
        # kept target. Return 1 if the true branch was kept, else 0.
        keep = branch.truebr if take_truebr else branch.falsebr
        jmp = ir.Jump(keep, loc=branch.loc)
        blk.body[-1] = jmp
        return 1 if keep == branch.truebr else 0

    def prune_by_type(branch, condition, blk, *conds):
        # This prunes a given branch and fixes up the IR. At least one of
        # the operands needs to be a NoneType.
        lhs_cond, rhs_cond = conds
        lhs_none = isinstance(lhs_cond, types.NoneType)
        rhs_none = isinstance(rhs_cond, types.NoneType)
        if lhs_none or rhs_none:
            try:
                take_truebr = condition.fn(lhs_cond, rhs_cond)
            except Exception:
                # the operator cannot be evaluated on these types: keep
                # the branch
                return False, None
            if DEBUG > 0:
                kill = branch.falsebr if take_truebr else branch.truebr
                print("Pruning %s" % kill, branch, lhs_cond, rhs_cond,
                      condition.fn)
            taken = do_prune(branch, take_truebr, blk)
            return True, taken
        return False, None

    def prune_by_value(branch, condition, blk, *conds):
        # Evaluate the binop on the two resolved constants and prune.
        lhs_cond, rhs_cond = conds
        try:
            take_truebr = condition.fn(lhs_cond, rhs_cond)
        except Exception:
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, lhs_cond, rhs_cond, condition.fn)
        taken = do_prune(branch, take_truebr, blk)
        return True, taken

    def prune_by_predicate(branch, pred, blk):
        # Prune on the truthiness of a constant predicate value.
        try:
            # Just to prevent accidents, whilst already guarded, ensure this
            # is a constant IR node.
            if not isinstance(pred, (ir.Const, ir.FreeVar, ir.Global)):
                raise TypeError('Expected constant Numba IR node')
            take_truebr = bool(pred.value)
        except TypeError:
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, pred)
        taken = do_prune(branch, take_truebr, blk)
        return True, taken

    class Unknown(object):
        # sentinel for "could not resolve to a constant" (None is a valid
        # resolved value, so it cannot be used)
        pass

    def resolve_input_arg_const(input_arg_idx):
        """
        Resolves an input arg to a constant (if possible)
        """
        input_arg_ty = called_args[input_arg_idx]

        # comparing to None?
        if isinstance(input_arg_ty, types.NoneType):
            return input_arg_ty

        # is it a kwarg default
        if isinstance(input_arg_ty, types.Omitted):
            val = input_arg_ty.value
            if isinstance(val, types.NoneType):
                return val
            elif val is None:
                return types.NoneType('none')

        # literal type, return the type itself so comparisons like `x == None`
        # still work as e.g. x = types.int64 will never be None/NoneType so
        # the branch can still be pruned
        return getattr(input_arg_ty, 'literal_type', Unknown())

    if DEBUG > 1:
        print("before".center(80, '-'))
        func_ir.dump()

    # This looks for branches where:
    # at least one arg of the condition is in input args and const
    # at least one arg of the condition is a const
    # if the condition is met it will replace the branch with a jump
    branch_info = find_branches(func_ir)
    # stores conditions that have no impact post prune
    nullified_conditions = []

    for branch, condition, blk in branch_info:
        const_conds = []
        if isinstance(condition, ir.Expr) and condition.op == 'binop':
            prune = prune_by_value
            for arg in [condition.lhs, condition.rhs]:
                resolved_const = Unknown()
                arg_def = guard(get_definition, func_ir, arg)
                if isinstance(arg_def, ir.Arg):
                    # it's an e.g. literal argument to the function
                    resolved_const = resolve_input_arg_const(arg_def.index)
                    prune = prune_by_type
                else:
                    # it's some const argument to the function, cannot use guard
                    # here as the const itself may be None
                    try:
                        resolved_const = find_const(func_ir, arg)
                        if resolved_const is None:
                            resolved_const = types.NoneType('none')
                    except GuardException:
                        pass

                if not isinstance(resolved_const, Unknown):
                    const_conds.append(resolved_const)

            # lhs/rhs are consts
            if len(const_conds) == 2:
                # prune the branch, switch the branch for an unconditional jump
                prune_stat, taken = prune(branch, condition, blk, *const_conds)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append(nullified(condition, taken,
                                                          True))
        else:
            # see if this is a branch on a constant value predicate
            resolved_const = Unknown()
            try:
                pred_call = get_definition(func_ir, branch.cond)
                resolved_const = find_const(func_ir, pred_call.args[0])
                if resolved_const is None:
                    resolved_const = types.NoneType('none')
            except GuardException:
                pass

            if not isinstance(resolved_const, Unknown):
                prune_stat, taken = prune_by_predicate(branch, condition, blk)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append(nullified(condition, taken,
                                                          False))

    # 'ERE BE DRAGONS...
    # It is the evaluation of the condition expression that often trips up type
    # inference, so ideally it would be removed as it is effectively rendered
    # dead by the unconditional jump if a branch was pruned. However, there may
    # be references to the condition that exist in multiple places (e.g. dels)
    # and we cannot run DCE here as typing has not taken place to give enough
    # information to run DCE safely. Upshot of all this is the condition gets
    # rewritten below into a benign const that typing will be happy with and DCE
    # can remove it and its reference post typing when it is safe to do so
    # (if desired). It is required that the const is assigned a value that
    # indicates the branch taken as its mutated value would be read in the case
    # of object mode fall back in place of the condition itself. For
    # completeness the func_ir._definitions and ._consts are also updated to
    # make the IR state self consistent.

    deadcond = [x.condition for x in nullified_conditions]
    for _, cond, blk in branch_info:
        if cond in deadcond:
            for x in blk.body:
                if isinstance(x, ir.Assign) and x.value is cond:
                    # rewrite the condition as a true/false bit
                    nullified_info = nullified_conditions[deadcond.index(cond)]
                    # only do a rewrite of conditions, predicates need to retain
                    # their value as they may be used later.
                    if nullified_info.rewrite_stmt:
                        branch_bit = nullified_info.taken_br
                        x.value = ir.Const(branch_bit, loc=x.loc)
                        # update the specific definition to the new const
                        defns = func_ir._definitions[x.target.name]
                        repl_idx = defns.index(cond)
                        defns[repl_idx] = x.value

    # Remove dead blocks, this is safe as it relies on the CFG only.
    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for dead in cfg.dead_nodes():
        del func_ir.blocks[dead]

    # if conditions were nullified then consts were rewritten, update
    if nullified_conditions:
        func_ir._consts = consts.ConstantInference(func_ir)

    if DEBUG > 1:
        print("after".center(80, '-'))
        func_ir.dump()
507
508
def rewrite_semantic_constants(func_ir, called_args):
    """
    This rewrites values known to be constant by their semantics as ir.Const
    nodes, this is to give branch pruning the best chance possible of killing
    branches. An example might be rewriting len(tuple) as the literal length.

    func_ir is the IR
    called_args are the actual arguments with which the function is called

    Two rewrites are applied to ``ir.Assign`` statements whose value is an
    ``ir.Expr``:

    * ``<arg>.ndim``, where ``<arg>`` is a function argument typed as
      ``types.Array``, becomes ``ir.Const(ndim)``;
    * ``len(<arg>)``, where ``<arg>`` is a function argument typed as
      ``types.BaseTuple``, becomes ``ir.Const(count)``.

    The IR is mutated in place and ``func_ir._definitions`` is updated to
    match.
    """
    from numba.core.ir_utils import get_definition, guard

    DEBUG = 0

    if DEBUG > 1:
        print(("rewrite_semantic_constants: " +
               func_ir.func_id.func_name).center(80, '-'))
        print("before".center(80, '*'))
        func_ir.dump()

    def rewrite_statement(func_ir, stmt, new_val):
        """
        Rewrites the stmt as a ir.Const new_val and fixes up the entries in
        func_ir._definitions
        """
        old_val = stmt.value
        stmt.value = ir.Const(new_val, stmt.loc)
        # swap the old definition for the new const in place
        defns = func_ir._definitions[stmt.target.name]
        repl_idx = defns.index(old_val)
        defns[repl_idx] = stmt.value

    def rewrite_array_ndim(stmt, val, func_ir, called_args):
        # rewrite Array.ndim as const(ndim)
        if getattr(val, 'op', None) == 'getattr':
            if val.attr == 'ndim':
                arg_def = guard(get_definition, func_ir, val.value)
                if isinstance(arg_def, ir.Arg):
                    argty = called_args[arg_def.index]
                    if isinstance(argty, types.Array):
                        rewrite_statement(func_ir, stmt, argty.ndim)

    def rewrite_tuple_len(stmt, val, func_ir, called_args):
        # rewrite len(tuple) as const(len(tuple))
        if getattr(val, 'op', None) == 'call':
            func = guard(get_definition, func_ir, val.func)
            if (func is not None and isinstance(func, ir.Global) and
                    getattr(func, 'value', None) is len):
                # Skip calls that do not pass exactly one positional argument
                # (e.g. ``len()``) instead of failing with an unpacking
                # ValueError here.
                if len(val.args) != 1:
                    return
                (arg,) = val.args
                arg_def = guard(get_definition, func_ir, arg)
                if isinstance(arg_def, ir.Arg):
                    argty = called_args[arg_def.index]
                    if isinstance(argty, types.BaseTuple):
                        rewrite_statement(func_ir, stmt, argty.count)

    for blk in func_ir.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                val = stmt.value
                if isinstance(val, ir.Expr):
                    rewrite_array_ndim(stmt, val, func_ir, called_args)
                    rewrite_tuple_len(stmt, val, func_ir, called_args)

    if DEBUG > 1:
        print("after".center(80, '*'))
        func_ir.dump()
        print('-' * 80)
573
574
def find_literally_calls(func_ir, argtypes):
    """An analysis to find `numba.literally` call inside the given IR.
    When an unsatisfied literal typing request is found, a `ForceLiteralArg`
    exception is raised.

    Parameters
    ----------

    func_ir : numba.ir.FunctionIR

    argtypes : Sequence[numba.types.Type]
        The argument types.

    Raises
    ------
    errors.ForceLiteralArg
        If a function argument passed to ``literally`` is typed as an
        ``InitialValue`` with no initial value, or as neither a ``Literal``
        nor an ``InitialValue``. The exception carries the set of all marked
        argument indices and the location of the first ``literally`` call
        seen for the offending argument.
    """
    from numba.core import ir_utils

    marked_args = set()
    first_loc = {}
    # Scan for literally calls
    for blk in func_ir.blocks.values():
        for call_expr in blk.find_exprs(op='call'):
            var = ir_utils.guard(ir_utils.get_definition, func_ir,
                                 call_expr.func)
            if isinstance(var, (ir.Global, ir.FreeVar)):
                fnobj = var.value
            else:
                fnobj = ir_utils.guard(ir_utils.resolve_func_from_module,
                                       func_ir, var)
            if fnobj is not special.literally:
                continue
            # Skip calls that do not pass exactly one positional argument
            # (e.g. ``literally()``) instead of failing with an unpacking
            # ValueError here.
            if len(call_expr.args) != 1:
                continue
            [arg] = call_expr.args
            defarg = func_ir.get_definition(arg)
            if isinstance(defarg, ir.Arg):
                argindex = defarg.index
                marked_args.add(argindex)
                first_loc.setdefault(argindex, call_expr.loc)

    # Signal the dispatcher to force literal typing
    for pos in marked_args:
        query_arg = argtypes[pos]
        uninitialized = (isinstance(query_arg, types.InitialValue) and
                         query_arg.initial_value is None)
        not_literal = not isinstance(query_arg,
                                     (types.Literal, types.InitialValue))
        if uninitialized or not_literal:
            raise errors.ForceLiteralArg(marked_args, loc=first_loc[pos])
621