from collections import defaultdict

from numba.core import config


class Rewrite(object):
    '''Defines the abstract base class for Numba rewrites.

    Subclasses override match() to recognise a pattern in an IR block and
    apply() to produce the replacement block.
    '''

    def __init__(self, state=None):
        '''Constructor for the Rewrite class.

        *state* is whatever object was handed to RewriteRegistry.apply();
        the base class does not use it.
        '''
        pass

    def match(self, func_ir, block, typemap, calltypes):
        '''Overload this method to check an IR block for matching terms in the
        rewrite.

        Return a true value when the block should be rewritten.  The base
        implementation never matches and returns False.
        '''
        return False

    def apply(self):
        '''Overload this method to return a rewritten IR basic block when a
        match has been found.

        The base implementation always raises NotImplementedError.
        '''
        raise NotImplementedError("Abstract Rewrite.apply() called!")


class RewriteRegistry(object):
    '''Defines a registry for Numba rewrites.
    '''
    # The only rewrite kinds that register() and apply() accept.
    _kinds = frozenset(['before-inference', 'after-inference'])

    def __init__(self):
        '''Constructor for the rewrite registry.  Initializes the rewrites
        member to an empty mapping from kind to a list of rewrite classes.
        '''
        self.rewrites = defaultdict(list)

    def register(self, kind):
        '''Decorator adding a subclass of Rewrite to the registry for
        the given *kind*.

        Raises KeyError if *kind* is not a known kind, and TypeError (when
        the decorator is applied) if the decorated class is not a subclass
        of Rewrite.  The decorated class is returned unchanged.
        '''
        if kind not in self._kinds:
            raise KeyError("invalid kind %r" % (kind,))

        def do_register(rewrite_cls):
            if not issubclass(rewrite_cls, Rewrite):
                raise TypeError('{0} is not a subclass of Rewrite'.format(
                    rewrite_cls))
            self.rewrites[kind].append(rewrite_cls)
            return rewrite_cls
        return do_register

    def apply(self, kind, state):
        '''Exhaustively attempt to apply all rewrites registered for *kind*
        to all basic blocks of state.func_ir.

        *state* must provide the func_ir, typemap and calltypes attributes;
        it is also passed to each rewrite class constructor.  Blocks in
        state.func_ir.blocks are replaced in place, changed blocks are
        verified and the IR is post-processed afterwards.

        Raises KeyError if *kind* is not a known kind.
        '''
        # Raise explicitly instead of asserting: asserts are stripped under
        # "python -O", and an unknown kind would then silently create an
        # empty entry in the defaultdict and apply no rewrites at all.
        if kind not in self._kinds:
            raise KeyError("invalid kind %r" % (kind,))
        blocks = state.func_ir.blocks
        # Shallow snapshot used below to detect which blocks were replaced.
        old_blocks = blocks.copy()
        for rewrite_cls in self.rewrites[kind]:
            # Exhaustively apply a rewrite until it stops matching.
            rewrite = rewrite_cls(state)
            work_list = list(blocks.items())
            while work_list:
                key, block = work_list.pop()
                matches = rewrite.match(state.func_ir, block, state.typemap,
                                        state.calltypes)
                if matches:
                    if config.DEBUG or config.DUMP_IR:
                        print("_" * 70)
                        print("REWRITING (%s):" % rewrite_cls.__name__)
                        block.dump()
                        print("_" * 60)
                    new_block = rewrite.apply()
                    blocks[key] = new_block
                    # Re-queue the rewritten block so the same rewrite gets
                    # another chance to match on its own output.
                    work_list.append((key, new_block))
                    if config.DEBUG or config.DUMP_IR:
                        new_block.dump()
                        print("_" * 70)
        # If any blocks were changed, perform a sanity check.  A label that
        # was not present before rewriting has no prior version to compare
        # against, so it is verified as well rather than raising KeyError.
        for key, block in blocks.items():
            if key not in old_blocks or block != old_blocks[key]:
                block.verify()

        # Some passes, e.g. _inline_const_arraycall are known to occasionally
        # do invalid things WRT ir.Del, others, e.g. RewriteArrayExprs do valid
        # things with ir.Del, but the placement is not optimal. The lines below
        # fix-up the IR so that ref counts are valid and optimally placed,
        # see #4093 for context. This has to be run here as opposed to in
        # apply() as the CFG needs computing so full IR is needed.
        from numba.core import postproc
        post_proc = postproc.PostProcessor(state.func_ir)
        post_proc.run()


rewrite_registry = RewriteRegistry()
register_rewrite = rewrite_registry.register