1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import numpy as np
18import tvm
19from tvm import relay
20from tvm.relay import transform
21
22
def run_opt_pass(expr, opt_pass):
    """Run a single Relay pass over ``expr`` and return the transformed result.

    ``expr`` is placed in a fresh module via ``relay.Module.from_expr`` and
    ``opt_pass`` is applied to that module. The module's ``main`` function is
    then read back.

    Parameters
    ----------
    expr : relay.Expr
        Expression or function to optimize.
    opt_pass : transform.Pass
        The pass to apply; anything else trips the assertion.

    Returns
    -------
    relay.Expr
        The whole transformed ``main`` function when ``expr`` was a
        ``relay.Function``, otherwise only the body of ``main``.
    """
    assert isinstance(opt_pass, transform.Pass)

    optimized_mod = opt_pass(relay.Module.from_expr(expr))
    main_fn = optimized_mod["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    # A bare expression was passed in, so hand back just the function body.
    return main_fn.body
30
31
def test_concatenate_const():
    """FoldConstant replaces concatenating a constant with itself by one constant.

    The expected graph is a parameterless function returning the
    pre-concatenated data.
    """
    def before():
        tensor = relay.const(tvm.nd.array(np.array([1.0, 2.0, 3.0])))
        joined = relay.op.concatenate([tensor, tensor], axis=0)
        return relay.Function([], joined)

    def expected():
        folded = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
        return relay.Function([], relay.const(folded))

    optimized = run_opt_pass(before(), transform.FoldConstant())
    reference = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.graph_equal(optimized, reference)
49
50
def test_fold_const():
    """Fold the purely constant subexpression and keep the parts that use ``x``.

    ``(c + c) * 2`` depends only on constants and should collapse into a
    single constant. The additions involving the free variable ``x`` must
    survive the pass.
    """
    c_data = np.array([1, 2, 3]).astype("float32")
    t = relay.TensorType([1, 2, 3], "float32")

    def before():
        c = relay.const(c_data)
        x = relay.var("x", t)
        const_part = relay.multiply(relay.add(c, c), relay.const(2, "float32"))
        with_x = relay.add(x, const_part)
        return relay.Function([x], relay.add(with_x, c))

    def expected():
        x = relay.var("x", t)
        precomputed = relay.const((c_data + c_data) * 2)
        body = relay.add(relay.add(x, precomputed), relay.const(c_data))
        return relay.Function([x], body)

    def fail(_arg):
        raise RuntimeError()

    # The fold constant should work on any context: ``fail`` is installed as a
    # lower pass that raises if it is ever invoked, and a CUDA target is active
    # while FoldConstant runs.
    with tvm.build_config(add_lower_pass=[(0, fail)]), tvm.target.create("cuda"):
        optimized = run_opt_pass(before(), transform.FoldConstant())
    reference = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.alpha_equal(optimized, reference)
79
80
def test_fold_let():
    """Constant let-bindings are folded away, leaving only the binding on ``x``.

    ``t1`` (a constant) and ``t2 = t1 + t1`` are both compile-time known. The
    expected result keeps a single ``t3`` binding whose constant operand is
    the precomputed ``c_data + c_data``.
    """
    c_data = np.array(1).astype("float32")
    t = relay.TensorType([1], "float32")

    def before():
        builder = relay.ScopeBuilder()
        x = relay.var("x", t)
        # Binding order matters here: each let refers to the previous one.
        first = builder.let("t1", relay.const(c_data))
        doubled = builder.let("t2", relay.add(first, first))
        result = builder.let("t3", relay.add(doubled, x))
        builder.ret(result)
        return relay.Function([x], builder.get())

    def expected():
        builder = relay.ScopeBuilder()
        x = relay.var("x", t)
        folded = relay.const(c_data + c_data)
        builder.ret(builder.let("t3", relay.add(folded, x)))
        return relay.Function([x], builder.get())

    optimized = run_opt_pass(before(), transform.FoldConstant())
    reference = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.graph_equal(optimized, reference)
104
105
def test_fold_tuple():
    """Projections out of a tuple are resolved so the constant sum can be folded.

    ``tup[1]`` is the constant ``c`` and ``tup[0]`` is ``x``. After folding,
    the expected graph is ``(c + c) + x`` with ``c + c`` precomputed.
    """
    c_data = np.array(1).astype("float32")
    t = relay.TensorType([1], "float32")

    def before():
        c = relay.const(c_data)
        x = relay.var("x", t)
        tup = relay.Tuple([x, c])
        const_sum = relay.add(tup[1], c)
        return relay.Function([x], relay.add(const_sum, tup[0]))

    def expected():
        doubled = relay.const(c_data + c_data)
        x = relay.var("x", t)
        return relay.Function([x], relay.add(doubled, x))

    optimized = run_opt_pass(before(), transform.FoldConstant())
    reference = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.graph_equal(optimized, reference)
126
127
def test_fold_concat():
    """Concatenating two distinct constant nodes folds into one constant.

    The reference constant is produced with ``np.concatenate`` along the same
    axis.
    """
    c_data = np.array([[1, 2, 3]]).astype("float32")

    def before():
        lhs = relay.const(c_data)
        rhs = relay.const(c_data)
        return relay.Function([], relay.concatenate((lhs, rhs), axis=0))

    def expected():
        stacked = np.concatenate((c_data, c_data), axis=0)
        return relay.Function([], relay.const(stacked))

    optimized = run_opt_pass(before(), transform.FoldConstant())
    reference = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.graph_equal(optimized, reference)
145
146
def test_fold_shape_of():
    """``shape_of`` over statically shaped inputs folds to a constant.

    The test is run for an integer and a float output dtype. In each case
    the folded constant holds ``c_shape`` cast to that dtype.
    """
    c_shape = (8, 9, 10)

    def make_inputs():
        # Two float32 variables that share the static shape ``c_shape``.
        lhs = relay.var("x", shape=c_shape, dtype="float32")
        rhs = relay.var("y", shape=c_shape, dtype="float32")
        return lhs, rhs

    def before(dtype):
        x, y = make_inputs()
        return relay.Function([x, y], relay.shape_of(x + y, dtype))

    def expected(dtype):
        x, y = make_inputs()
        shape_const = relay.const(np.array(c_shape).astype(dtype), dtype=dtype)
        return relay.Function([x, y], shape_const)

    for out_dtype in ["int32", "float32"]:
        optimized = run_opt_pass(before(out_dtype), transform.FoldConstant())
        reference = run_opt_pass(expected(out_dtype), transform.InferType())
        assert relay.analysis.graph_equal(optimized, reference)
166
167
def test_fold_full():
    """FoldConstant leaves ``full`` with a constant fill value unchanged.

    The reference graph is the input expression itself, type-inferred. The
    pass is expected to be a no-op here.
    """
    c_shape = (8, 9, 10)

    def before():
        dtype = "float32"
        fill_value = relay.const(1.0, dtype)
        return relay.full(fill_value, c_shape, dtype=dtype)

    # Expect no changes: the reference is simply a fresh copy of ``before``.
    expected = before

    optimized = run_opt_pass(before(), transform.FoldConstant())
    reference = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.graph_equal(optimized, reference)
181
182
if __name__ == "__main__":
    # Run every test in this module when executed as a script. Keep this list
    # in sync with the test_* functions in the file; test_concatenate_const
    # was previously missing, so it silently never ran in script mode.
    test_concatenate_const()
    test_fold_const()
    test_fold_let()
    test_fold_tuple()
    test_fold_concat()
    test_fold_shape_of()
    test_fold_full()
190