1#
2# Copyright (c) 2017 Intel Corporation
3# SPDX-License-Identifier: BSD-2-Clause
4#
5
6from numba.core import types, typing, ir, config, compiler, cpu
7from numba.core.registry import cpu_target
8from numba.core.annotations import type_annotations
9from numba.core.ir_utils import (copy_propagate, apply_copy_propagate,
10                            get_name_var_table)
11from numba.core.typed_passes import type_inference_stage
12import unittest
13
def test_will_propagate(b, z, w):
    # Copy-propagation fixture: TestCopyPropagate.test1 builds its IR with
    # compiler.run_frontend and asserts that, after copy propagation, no
    # assignment other than x's own still reads ``x``.  ``x`` holds the
    # constant 3 on every path, so its read in ``a = 2 * x`` is replaceable.
    # NOTE: the statements and variable names here shape the IR that the
    # test inspects -- keep them unchanged.
    # NOTE(review): the ``test_`` prefix means a pytest run would collect this
    # module-level function as a test and fail on the missing ``b``/``z``/``w``
    # fixtures; renaming would also require updating TestCopyPropagate.
    x = 3
    if b > 0:
        # ``y`` is never read afterwards; presumably the branch exists only to
        # put a control-flow split/join around the use of ``x``.
        y = z + w
    else:
        y = 0
    a = 2 * x
    return a < b
22
def test_wont_propagate(b, z, w):
    # Copy-propagation fixture: TestCopyPropagate.test2 builds its IR with
    # compiler.run_frontend and asserts that ``x`` is STILL read by some other
    # assignment after copy propagation.  ``x`` is 3 on the else path but is
    # reassigned to 1 on the ``b > 0`` path, so at ``a = 2 * x`` it has no
    # single source copy and the read must survive.
    # NOTE: the statements and variable names here shape the IR that the
    # test inspects -- keep them unchanged.
    # NOTE(review): the ``test_`` prefix means a pytest run would collect this
    # module-level function as a test and fail on the missing ``b``/``z``/``w``
    # fixtures; renaming would also require updating TestCopyPropagate.
    x = 3
    if b > 0:
        # ``y`` is never read afterwards.
        y = z + w
        # The reassignment that blocks propagation of ``x = 3``.
        x = 1
    else:
        y = 0
    a = 2 * x
    return a < b
32
def null_func(a, b, c, d):
    """Accept four arguments, ignore every one of them, and return None."""
    return None
35
def inListVar(list_var, var):
    """Return True if some element of *list_var* is named *var*.

    :param list_var: iterable of objects exposing a ``name`` attribute; in
        this file ``findAssign`` passes the result of ``inst.list_vars()``.
    :param var: the variable name to look for.
    :returns: ``True`` on the first element whose ``name`` equals *var*,
        otherwise ``False`` (including for an empty iterable).
    """
    return any(item.name == var for item in list_var)
41
def findAssign(func_ir, var):
    """Return True if *var* is used by any assignment other than its own.

    Despite the name, this looks for *reads* of *var*.  It scans every
    ``ir.Assign`` in every block of *func_ir*. Assignments whose target is
    *var* itself are skipped, so only occurrences of *var* inside some other
    assignment count. Non-assignment statements are not inspected.

    The TestCopyPropagate cases use it to check whether copy propagation
    removed the reads of ``x``.
    """
    return any(
        isinstance(inst, ir.Assign)
        and inst.target.name != var
        and inListVar(inst.list_vars(), var)
        for block in func_ir.blocks.values()
        for inst in block.body
    )
51
class TestCopyPropagate(unittest.TestCase):
    """Exercise ``copy_propagate`` / ``apply_copy_propagate`` on small fixtures.

    Each test runs the numba frontend and type inference on a fixture
    function. It then applies copy propagation to the resulting IR and checks
    whether the variable ``x`` is still read by any other assignment.
    """

    def _copy_propagated_ir(self, func):
        """Return the IR of *func* after type inference and copy propagation.

        *func* is typed with three ``int64`` arguments, matching the
        ``(b, z, w)`` signatures of the fixture functions.
        """
        typingctx = typing.Context()
        targetctx = cpu.CPUContext(typingctx)
        test_ir = compiler.run_frontend(func)
        with cpu_target.nested_context(typingctx, targetctx):
            typingctx.refresh()
            targetctx.refresh()
            args = (types.int64, types.int64, types.int64)
            typemap, return_type, calltypes = type_inference_stage(
                typingctx, test_ir, args, None)
            # The annotation object is never inspected. The constructor call
            # is kept so it still runs on the freshly typed IR.
            type_annotations.TypeAnnotation(
                func_ir=test_ir,
                typemap=typemap,
                calltypes=calltypes,
                lifted=(),
                lifted_from=None,
                args=args,
                return_type=return_type,
                html_output=config.HTML)
            # Only the first result of copy_propagate is fed to
            # apply_copy_propagate; the second one is unused here.
            in_cps, _ = copy_propagate(test_ir.blocks, typemap)
            apply_copy_propagate(test_ir.blocks, in_cps,
                                 get_name_var_table(test_ir.blocks),
                                 typemap, calltypes)
        return test_ir

    def test1(self):
        """``x`` is the constant 3 on every path, so all its reads vanish."""
        test_ir = self._copy_propagated_ir(test_will_propagate)
        self.assertFalse(findAssign(test_ir, "x"))

    def test2(self):
        """``x`` is reassigned on one branch, so its read must survive."""
        test_ir = self._copy_propagated_ir(test_wont_propagate)
        self.assertTrue(findAssign(test_ir, "x"))
104
if __name__ == "__main__":
    # Executed directly (not imported): run this module's unittest cases.
    unittest.main(module="__main__")
107