1import numba
2from numba.tests.support import TestCase, unittest
3from numba.core.registry import cpu_target
4from numba.core.compiler import CompilerBase, Flags
5from numba.core.compiler_machinery import PassManager
6from numba.core import types, ir, bytecode, compiler, ir_utils, registry
7from numba.core.untyped_passes import (ExtractByteCode, TranslateByteCode,
8                                       FixupArgs, IRProcessing,)
9
10from numba.core.typed_passes import (NopythonTypeInference,
11                                     type_inference_stage, DeadCodeElimination)
12from numba.experimental import jitclass
13
# Module-level constant: test_find_const_global reads it inside a nested
# function (so it appears as an ir.Global) and expects find_const to
# resolve it to this value.
GLOBAL_B = 0xB  # == 11
16
17
# Minimal jitclass whose single member ``val`` is a typed list of intp.
# test_obj_func_match calls ``d.val.append(...)`` on an instance of it to
# exercise find_callname on a method of a non-Array object.
@jitclass([('val', types.List(types.intp))])
class Dummy:
    def __init__(self, val):
        self.val = val
22
23
class TestIrUtils(TestCase):
    """
    Tests ir handling utility functions like find_callname.
    """

    def test_obj_func_match(self):
        """Test matching of an object method (other than Array see #3449)
        """

        def test_func():
            d = Dummy([1])
            d.val.append(2)

        test_ir = compiler.run_frontend(test_func)
        typingctx = cpu_target.typing_context
        typemap, _, _ = type_inference_stage(
            typingctx, test_ir, (), None)
        # NOTE(review): statement index 8 of the entry block is assumed to
        # be the one holding the ``d.val.append(2)`` call; this depends on
        # the bytecode -> IR translation and may shift between versions.
        matched_call = ir_utils.find_callname(
            test_ir, test_ir.blocks[0].body[8].value, typemap)
        # Expect a 2-tuple whose first element is the method name. Separate
        # assertions report exactly which expectation failed.
        self.assertIsInstance(matched_call, tuple)
        self.assertEqual(len(matched_call), 2)
        self.assertEqual(matched_call[0], 'append')

    def test_dead_code_elimination(self):
        """
        Check that DeadCodeElimination run after type inference removes the
        unused constant assignments in ``foo`` and leaves the rest intact.
        """

        class Tester(CompilerBase):
            """Minimal compiler used to produce typed IR with or without DCE.
            """

            @classmethod
            def mk_pipeline(cls, args, return_type=None, flags=None,
                            locals=None, library=None, typing_context=None,
                            target_context=None):
                """
                Build a Tester for argument types ``args``.

                Missing flags default to a fresh ``Flags()``; NRT is always
                enabled on the flags used. Missing typing/target contexts
                default to the CPU target's contexts.
                """
                # A fresh mapping per call; a ``{}`` default would be shared
                # by every pipeline built through this method.
                if locals is None:
                    locals = {}
                if not flags:
                    flags = Flags()
                flags.nrt = True
                if typing_context is None:
                    typing_context = registry.cpu_target.typing_context
                if target_context is None:
                    target_context = registry.cpu_target.target_context
                return cls(typing_context, target_context, library, args,
                           return_type, flags, locals)

            def compile_to_ir(self, func, DCE=False):
                """
                Translate ``func`` to IR, run type inference (followed by
                DeadCodeElimination when ``DCE`` is true) and return the
                resulting function IR.
                """
                func_id = bytecode.FunctionIdentity.from_function(func)
                self.state.func_id = func_id
                ExtractByteCode().run_pass(self.state)
                state = self.state

                name = "DCE_testing"
                pm = PassManager(name)
                pm.add_pass(TranslateByteCode, "analyzing bytecode")
                pm.add_pass(FixupArgs, "fix up args")
                pm.add_pass(IRProcessing, "processing IR")
                pm.add_pass(NopythonTypeInference, "nopython frontend")
                if DCE:
                    pm.add_pass(DeadCodeElimination, "DCE after typing")
                pm.finalize()
                pm.run(state)
                return state.func_ir

        def check_initial_ir(the_ir):
            """
            Verify the un-optimized IR still holds the dead code and return
            the four IR nodes DCE is expected to remove.
            """
            # dead stuff:
            # a const int value 0xdead
            # an assignment of the above to variable `dead`
            # a const int value 0xdeaddead
            # an assignment of said int to variable `deaddead`
            # this is 4 things to remove

            self.assertEqual(len(the_ir.blocks), 1)
            block = the_ir.blocks[0]
            deads = []
            for x in block.find_insts(ir.Assign):
                if isinstance(getattr(x, 'target', None), ir.Var):
                    if 'dead' in getattr(x.target, 'name', ''):
                        deads.append(x)

            expect_removed = []
            self.assertEqual(len(deads), 2)
            expect_removed.extend(deads)
            for d in deads:
                # The definition must be the ir.Const whose value is spelled
                # by the variable name (dead -> 0xdead, deaddead ->
                # 0xdeaddead). assertEqual is required here: assertTrue would
                # treat the second argument as a message and compare nothing.
                const_val = the_ir.get_definition(d.value)
                self.assertEqual(int('0x%s' % d.target.name, 16),
                                 const_val.value)
                expect_removed.append(const_val)

            self.assertEqual(len(expect_removed), 4)
            return expect_removed

        def check_dce_ir(the_ir):
            """Verify the DCE'd IR has no trace of the dead assignments."""
            self.assertEqual(len(the_ir.blocks), 1)
            block = the_ir.blocks[0]
            deads = []
            consts = []
            for x in block.find_insts(ir.Assign):
                if isinstance(getattr(x, 'target', None), ir.Var):
                    if 'dead' in getattr(x.target, 'name', ''):
                        deads.append(x)
                if isinstance(getattr(x, 'value', None), ir.Const):
                    consts.append(x)
            self.assertEqual(len(deads), 0)

            # check the consts to make sure there's no reference to 0xdead or
            # 0xdeaddead
            for x in consts:
                self.assertNotIn(x.value.value, [0xdead, 0xdeaddead])

        def foo(x):
            y = x + 1
            dead = 0xdead  # noqa
            z = y + 2
            deaddead = 0xdeaddead  # noqa
            ret = z * z
            return ret

        test_pipeline = Tester.mk_pipeline((types.intp,))
        no_dce = test_pipeline.compile_to_ir(foo)
        removed = check_initial_ir(no_dce)

        test_pipeline = Tester.mk_pipeline((types.intp,))
        w_dce = test_pipeline.compile_to_ir(foo, DCE=True)
        check_dce_ir(w_dce)

        # check that the count of initial - removed = dce
        self.assertEqual(len(no_dce.blocks[0].body) - len(removed),
                         len(w_dce.blocks[0].body))

    def test_find_const_global(self):
        """
        Test find_const() for values in globals (ir.Global) and freevars
        (ir.FreeVar) that are considered constants for compilation.
        """
        FREEVAR_C = 12

        def foo(a):
            b = GLOBAL_B
            c = FREEVAR_C
            return a + b + c

        f_ir = compiler.run_frontend(foo)
        block = f_ir.blocks[0]
        const_b = None
        const_c = None

        for inst in block.body:
            if isinstance(inst, ir.Assign) and inst.target.name == 'b':
                const_b = ir_utils.guard(
                    ir_utils.find_const, f_ir, inst.target)
            if isinstance(inst, ir.Assign) and inst.target.name == 'c':
                const_c = ir_utils.guard(
                    ir_utils.find_const, f_ir, inst.target)

        self.assertEqual(const_b, GLOBAL_B)
        self.assertEqual(const_c, FREEVAR_C)

    def test_flatten_labels(self):
        """ tests flatten_labels """
        def foo(a):
            acc = 0
            if a > 3:
                acc += 1
                if a > 19:
                    return 53
            elif a < 1000:
                if a >= 12:
                    acc += 1
                for x in range(10):
                    acc -= 1
                    if acc < 2:
                        break
                else:
                    acc += 7
            else:
                raise ValueError("some string")
            return acc

        # Same control flow as foo, only extra straight-line statements:
        # its flattened CFG must equal foo's.
        def bar(a):
            acc = 0
            z = 12
            if a > 3:
                acc += 1
                z += 12
                if a > 19:
                    z += 12
                    return 53
            elif a < 1000:
                if a >= 12:
                    z += 12
                    acc += 1
                for x in range(10):
                    z += 12
                    acc -= 1
                    if acc < 2:
                        break
                else:
                    z += 12
                    acc += 7
            else:
                raise ValueError("some string")
            return acc

        def baz(a):
            acc = 0
            if a > 3:
                acc += 1
                if a > 19:
                    return 53
                else:  # extra control flow in comparison to foo
                    return 55
            elif a < 1000:
                if a >= 12:
                    acc += 1
                for x in range(10):
                    acc -= 1
                    if acc < 2:
                        break
                else:
                    acc += 7
            else:
                raise ValueError("some string")
            return acc

        def get_flat_cfg(func):
            """Return the CFG of ``func``'s IR after flattening its labels."""
            func_ir = ir_utils.compile_to_numba_ir(func, dict())
            flat_blocks = ir_utils.flatten_labels(func_ir.blocks)
            # Flattened labels are expected to be contiguous from 0.
            self.assertEqual(max(flat_blocks) + 1, len(func_ir.blocks))
            return ir_utils.compute_cfg_from_blocks(flat_blocks)

        foo_cfg = get_flat_cfg(foo)
        bar_cfg = get_flat_cfg(bar)
        baz_cfg = get_flat_cfg(baz)

        self.assertEqual(foo_cfg, bar_cfg)
        self.assertNotEqual(foo_cfg, baz_cfg)
260
261
# Allow running this test module directly as a script via unittest's runner.
if __name__ == "__main__":
    unittest.main()
264