1from collections import defaultdict
2
3from numba.core import config
4
5
class Rewrite:
    '''Abstract base class from which Numba rewrites derive.

    Subclasses are expected to override ``match`` and ``apply``.
    '''

    def __init__(self, state=None):
        '''Create the rewrite.  The *state* argument is accepted but the
        base class does not use it.
        '''

    def match(self, func_ir, block, typemap, calltypes):
        '''Override this method to report whether an IR block contains terms
        handled by the rewrite.  The base implementation ignores its
        arguments and never matches.
        '''
        return False

    def apply(self):
        '''Override this method to return the rewritten IR basic block once
        a match has been found.  The base implementation always raises
        NotImplementedError.
        '''
        message = "Abstract Rewrite.apply() called!"
        raise NotImplementedError(message)
26
27
class RewriteRegistry(object):
    '''Defines a registry for Numba rewrites.

    Rewrite classes are grouped by *kind*; the accepted kinds are the
    members of ``_kinds``.
    '''
    _kinds = frozenset(['before-inference', 'after-inference'])

    def __init__(self):
        '''Constructor for the rewrite registry.  Initializes the rewrites
        member to an empty mapping from kind to a list of rewrite classes.
        '''
        self.rewrites = defaultdict(list)

    def _check_kind(self, kind):
        '''Raise KeyError unless *kind* is one of the supported kinds.
        '''
        if kind not in self._kinds:
            raise KeyError("invalid kind %r" % (kind,))

    def register(self, kind):
        """
        Decorator adding a subclass of Rewrite to the registry for
        the given *kind*.

        Raises KeyError immediately if *kind* is not a supported kind, and
        TypeError at decoration time if the decorated class is not a
        subclass of Rewrite.  The decorated class is returned unchanged.
        """
        self._check_kind(kind)

        def do_register(rewrite_cls):
            if not issubclass(rewrite_cls, Rewrite):
                raise TypeError('{0} is not a subclass of Rewrite'.format(
                    rewrite_cls))
            self.rewrites[kind].append(rewrite_cls)
            return rewrite_cls
        return do_register

    def apply(self, kind, state):
        '''Exhaustively attempt to apply all rewrites registered for *kind*
        to all basic blocks of ``state.func_ir``.

        *state* must provide ``func_ir`` (with a ``blocks`` mapping),
        ``typemap`` and ``calltypes``; these are passed to each rewrite's
        ``match()``.  Blocks are replaced in place in ``func_ir.blocks``,
        changed blocks are verified, and the IR is post-processed at the
        end.

        Raises KeyError if *kind* is not a supported kind.
        '''
        # Explicit check rather than ``assert`` so that it is not stripped
        # under ``python -O`` and matches the behaviour of register().
        self._check_kind(kind)
        blocks = state.func_ir.blocks
        old_blocks = blocks.copy()
        for rewrite_cls in self.rewrites[kind]:
            # Exhaustively apply a rewrite until it stops matching.
            rewrite = rewrite_cls(state)
            work_list = list(blocks.items())
            while work_list:
                key, block = work_list.pop()
                matches = rewrite.match(state.func_ir, block, state.typemap,
                                        state.calltypes)
                if not matches:
                    continue
                dump_ir = config.DEBUG or config.DUMP_IR
                if dump_ir:
                    print("_" * 70)
                    print("REWRITING (%s):" % rewrite_cls.__name__)
                    block.dump()
                    print("_" * 60)
                new_block = rewrite.apply()
                blocks[key] = new_block
                # Re-queue the rewritten block: it may match again.
                work_list.append((key, new_block))
                if dump_ir:
                    new_block.dump()
                    print("_" * 70)
        # If any blocks were changed, perform a sanity check.  A key that
        # was not present before rewriting counts as changed.
        for key, block in blocks.items():
            if key not in old_blocks or block != old_blocks[key]:
                block.verify()

        # Some passes, e.g. _inline_const_arraycall are known to occasionally
        # do invalid things WRT ir.Del, others, e.g. RewriteArrayExprs do valid
        # things with ir.Del, but the placement is not optimal. The lines below
        # fix-up the IR so that ref counts are valid and optimally placed,
        # see #4093 for context. This has to be run here opposed to in
        # apply() as the CFG needs computing so full IR is needed.
        from numba.core import postproc
        post_proc = postproc.PostProcessor(state.func_ir)
        post_proc.run()
95
96
# Module-level registry instance shared by users of this module.
rewrite_registry = RewriteRegistry()
# Shorthand for the shared registry's register() decorator factory.
register_rewrite = rewrite_registry.register
99