import numba
from numba.tests.support import TestCase, unittest
from numba.core.registry import cpu_target
from numba.core.compiler import CompilerBase, Flags
from numba.core.compiler_machinery import PassManager
from numba.core import types, ir, bytecode, compiler, ir_utils, registry
from numba.core.untyped_passes import (ExtractByteCode, TranslateByteCode,
                                       FixupArgs, IRProcessing,)

from numba.core.typed_passes import (NopythonTypeInference,
                                     type_inference_stage, DeadCodeElimination)
from numba.experimental import jitclass

# global constant for testing find_const
GLOBAL_B = 11


# Minimal jitclass holding a typed list of intp; used by test_obj_func_match
# to exercise find_callname on a method of a non-array object.
@jitclass([('val', numba.core.types.List(numba.intp))])
class Dummy(object):
    def __init__(self, val):
        self.val = val


class TestIrUtils(TestCase):
    """
    Tests ir handling utility functions like find_callname.
    """

    def test_obj_func_match(self):
        """Test matching of an object method (other than Array see #3449)
        """

        def test_func():
            d = Dummy([1])
            d.val.append(2)

        test_ir = compiler.run_frontend(test_func)
        typingctx = cpu_target.typing_context
        typemap, _, _ = type_inference_stage(
            typingctx, test_ir, (), None)
        # NOTE(review): index 8 is the position of the `append` call in the
        # first block of the generated IR; this depends on the frontend's
        # exact IR layout.
        matched_call = ir_utils.find_callname(
            test_ir, test_ir.blocks[0].body[8].value, typemap)
        # Separate assertions so a failure reports which property broke.
        self.assertIsInstance(matched_call, tuple)
        self.assertEqual(len(matched_call), 2)
        self.assertEqual(matched_call[0], 'append')

    def test_dead_code_elimination(self):
        """
        Check that DeadCodeElimination run after typing removes the unused
        constant assignments (and their constants) from a simple function,
        and nothing else.
        """

        class Tester(CompilerBase):

            @classmethod
            def mk_pipeline(cls, args, return_type=None, flags=None,
                            locals=None, library=None, typing_context=None,
                            target_context=None):
                """
                Build a Tester compiler instance for argument types `args`,
                filling in default flags (with NRT enabled), an empty locals
                mapping and the CPU typing/target contexts when not given.
                """
                # Avoid a shared mutable default for `locals`.
                if locals is None:
                    locals = {}
                if not flags:
                    flags = Flags()
                flags.nrt = True
                if typing_context is None:
                    typing_context = registry.cpu_target.typing_context
                if target_context is None:
                    target_context = registry.cpu_target.target_context
                return cls(typing_context, target_context, library, args,
                           return_type, flags, locals)

            def compile_to_ir(self, func, DCE=False):
                """
                Compile and return IR

                Runs bytecode extraction/translation, IR processing and
                nopython type inference on `func`; when `DCE` is truthy a
                DeadCodeElimination pass is appended after typing.
                """
                func_id = bytecode.FunctionIdentity.from_function(func)
                self.state.func_id = func_id
                ExtractByteCode().run_pass(self.state)
                state = self.state

                name = "DCE_testing"
                pm = PassManager(name)
                pm.add_pass(TranslateByteCode, "analyzing bytecode")
                pm.add_pass(FixupArgs, "fix up args")
                pm.add_pass(IRProcessing, "processing IR")
                pm.add_pass(NopythonTypeInference, "nopython frontend")
                if DCE:
                    pm.add_pass(DeadCodeElimination, "DCE after typing")
                pm.finalize()
                pm.run(state)
                return state.func_ir

        def check_initial_ir(the_ir):
            """
            Verify the un-optimised IR still contains the dead code and
            return the list of IR items DCE is expected to remove.
            """
            # dead stuff:
            # a const int value 0xdead
            # an assign of the above into variable `dead`
            # a const int value 0xdeaddead
            # an assign of said int to variable `deaddead`
            # this is 4 things to remove

            self.assertEqual(len(the_ir.blocks), 1)
            block = the_ir.blocks[0]
            deads = []
            for x in block.find_insts(ir.Assign):
                if isinstance(getattr(x, 'target', None), ir.Var):
                    if 'dead' in getattr(x.target, 'name', ''):
                        deads.append(x)

            expect_removed = []
            self.assertEqual(len(deads), 2)
            expect_removed.extend(deads)
            for d in deads:
                # check the ir.Const is the definition and the value is
                # expected: the variable name spells its hex value
                # (`dead` -> 0xdead, `deaddead` -> 0xdeaddead).
                const_val = the_ir.get_definition(d.value)
                self.assertEqual(int('0x%s' % d.target.name, 16),
                                 const_val.value)
                expect_removed.append(const_val)

            self.assertEqual(len(expect_removed), 4)
            return expect_removed

        def check_dce_ir(the_ir):
            """
            Verify the DCE'd IR has no `dead*` assignments and no constant
            holding 0xdead or 0xdeaddead.
            """
            self.assertEqual(len(the_ir.blocks), 1)
            block = the_ir.blocks[0]
            deads = []
            consts = []
            for x in block.find_insts(ir.Assign):
                if isinstance(getattr(x, 'target', None), ir.Var):
                    if 'dead' in getattr(x.target, 'name', ''):
                        deads.append(x)
                if isinstance(getattr(x, 'value', None), ir.Const):
                    consts.append(x)
            self.assertEqual(len(deads), 0)

            # check the consts to make sure there's no reference to 0xdead or
            # 0xdeaddead
            for x in consts:
                self.assertNotIn(x.value.value, (0xdead, 0xdeaddead))

        def foo(x):
            y = x + 1
            dead = 0xdead  # noqa
            z = y + 2
            deaddead = 0xdeaddead  # noqa
            ret = z * z
            return ret

        test_pipeline = Tester.mk_pipeline((types.intp,))
        no_dce = test_pipeline.compile_to_ir(foo)
        removed = check_initial_ir(no_dce)

        test_pipeline = Tester.mk_pipeline((types.intp,))
        w_dce = test_pipeline.compile_to_ir(foo, DCE=True)
        check_dce_ir(w_dce)

        # check that the count of initial - removed = dce
        self.assertEqual(len(no_dce.blocks[0].body) - len(removed),
                         len(w_dce.blocks[0].body))

    def test_find_const_global(self):
        """
        Test find_const() for values in globals (ir.Global) and freevars
        (ir.FreeVar) that are considered constants for compilation.
        """
        FREEVAR_C = 12

        def foo(a):
            b = GLOBAL_B
            c = FREEVAR_C
            return a + b + c

        f_ir = compiler.run_frontend(foo)
        block = f_ir.blocks[0]
        const_b = None
        const_c = None

        for inst in block.body:
            if isinstance(inst, ir.Assign) and inst.target.name == 'b':
                const_b = ir_utils.guard(
                    ir_utils.find_const, f_ir, inst.target)
            if isinstance(inst, ir.Assign) and inst.target.name == 'c':
                const_c = ir_utils.guard(
                    ir_utils.find_const, f_ir, inst.target)

        self.assertEqual(const_b, GLOBAL_B)
        self.assertEqual(const_c, FREEVAR_C)

    def test_flatten_labels(self):
        """
        Tests flatten_labels: functions that differ only in straight-line
        statements (foo vs bar) must flatten to the same CFG, while extra
        control flow (baz) must produce a different CFG.
        """
        def foo(a):
            acc = 0
            if a > 3:
                acc += 1
                if a > 19:
                    return 53
            elif a < 1000:
                if a >= 12:
                    acc += 1
                for x in range(10):
                    acc -= 1
                    if acc < 2:
                        break
                else:
                    acc += 7
            else:
                raise ValueError("some string")
            return acc

        def bar(a):
            acc = 0
            z = 12
            if a > 3:
                acc += 1
                z += 12
                if a > 19:
                    z += 12
                    return 53
            elif a < 1000:
                if a >= 12:
                    z += 12
                    acc += 1
                for x in range(10):
                    z += 12
                    acc -= 1
                    if acc < 2:
                        break
                else:
                    z += 12
                    acc += 7
            else:
                raise ValueError("some string")
            return acc

        def baz(a):
            acc = 0
            if a > 3:
                acc += 1
                if a > 19:
                    return 53
                else:  # extra control flow in comparison to foo
                    return 55
            elif a < 1000:
                if a >= 12:
                    acc += 1
                for x in range(10):
                    acc -= 1
                    if acc < 2:
                        break
                else:
                    acc += 7
            else:
                raise ValueError("some string")
            return acc

        def get_flat_cfg(func):
            """
            Compile `func` to numba IR, flatten its block labels, check the
            labels are dense (0..n-1) and return the resulting CFG.
            """
            func_ir = ir_utils.compile_to_numba_ir(func, {})
            flat_blocks = ir_utils.flatten_labels(func_ir.blocks)
            self.assertEqual(max(flat_blocks.keys()) + 1, len(func_ir.blocks))
            return ir_utils.compute_cfg_from_blocks(flat_blocks)

        foo_cfg = get_flat_cfg(foo)
        bar_cfg = get_flat_cfg(bar)
        baz_cfg = get_flat_cfg(baz)

        self.assertEqual(foo_cfg, bar_cfg)
        self.assertNotEqual(foo_cfg, baz_cfg)


if __name__ == "__main__":
    unittest.main()