1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import numpy as np
18import tvm
19from tvm import te
20from tvm import relay
21from tvm.relay import transform
22from tvm.relay.build_module import bind_params_by_name
23from tvm.relay.testing import run_infer_type, create_workload
24
25
def run_opt_pass(expr, opt_pass):
    """Apply *opt_pass* to *expr* wrapped in a fresh IRModule.

    Returns the transformed "main" function when *expr* is a Function,
    otherwise the body of the transformed "main".
    """
    assert isinstance(opt_pass, tvm.transform.Pass)
    transformed = opt_pass(tvm.IRModule.from_expr(expr))["main"]
    if isinstance(expr, relay.Function):
        return transformed
    return transformed.body
33
34
def test_concatenate_const():
    """Concatenating a constant with itself folds into one constant."""

    def before():
        vals = tvm.nd.array(np.array([1.0, 2.0, 3.0]))
        c = relay.const(vals)
        return relay.Function([], relay.op.concatenate([c, c], axis=0))

    def expected():
        folded = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
        return relay.Function([], relay.const(folded))

    actual = run_opt_pass(before(), transform.FoldConstant())
    desired = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(actual, desired)
52
53
def test_fold_const():
    """Constant subexpressions fold; var-dependent additions are kept."""
    c_data = np.array([1, 2, 3]).astype("float32")
    t = relay.TensorType([1, 2, 3], "float32")

    def before():
        const = relay.const(c_data)
        x = relay.var("x", t)
        doubled = relay.add(const, const)
        scaled = relay.multiply(doubled, relay.const(2, "float32"))
        summed = relay.add(x, scaled)
        return relay.Function([x], relay.add(summed, const))

    def expected():
        x = relay.var("x", t)
        folded = relay.const((c_data + c_data) * 2)
        out = relay.add(relay.add(x, folded), relay.const(c_data))
        return relay.Function([x], out)

    # FoldConstant should behave identically under any target scope.
    with tvm.target.Target("cuda"):
        actual = run_opt_pass(before(), transform.FoldConstant())
    desired = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(actual, desired)
79
80
def test_fold_let():
    """Let-bound purely-constant computations are folded away."""
    c_data = np.array(1).astype("float32")
    t = relay.TensorType([1], "float32")

    def before():
        builder = relay.ScopeBuilder()
        x = relay.var("x", t)
        c1 = builder.let("t1", relay.const(c_data))
        c2 = builder.let("t2", relay.add(c1, c1))
        out = builder.let("t3", relay.add(c2, x))
        builder.ret(out)
        return relay.Function([x], builder.get())

    def expected():
        builder = relay.ScopeBuilder()
        x = relay.var("x", t)
        out = builder.let("t3", relay.add(relay.const(c_data + c_data), x))
        builder.ret(out)
        return relay.Function([x], builder.get())

    actual = run_opt_pass(before(), transform.FoldConstant())
    desired = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(actual, desired)
105
106
def test_fold_tuple():
    """Constant tuple fields are folded through tuple projections."""
    c_data = np.array(1).astype("float32")
    t = relay.TensorType([1], "float32")

    def before():
        const = relay.const(c_data)
        x = relay.var("x", t)
        tup = relay.Tuple([x, const])
        partial = relay.add(tup[1], const)
        return relay.Function([x], relay.add(partial, tup[0]))

    def expected():
        x = relay.var("x", t)
        folded = relay.const(c_data + c_data)
        return relay.Function([x], relay.add(folded, x))

    actual = run_opt_pass(before(), transform.FoldConstant())
    desired = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(actual, desired)
128
129
def test_fold_concat():
    """concatenate over two constant operands folds to one constant."""
    c_data = np.array([[1, 2, 3]]).astype("float32")

    def before():
        lhs = relay.const(c_data)
        rhs = relay.const(c_data)
        return relay.Function([], relay.concatenate((lhs, rhs), axis=0))

    def expected():
        folded = np.concatenate((c_data, c_data), axis=0)
        return relay.Function([], relay.const(folded))

    actual = run_opt_pass(before(), transform.FoldConstant())
    desired = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(actual, desired)
147
148
def test_fold_shape_of():
    """shape_of over a statically-shaped expression folds to a constant."""
    c_shape = (8, 9, 10)

    def before(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        return relay.Function([x, y], relay.shape_of(x + y, dtype))

    def expected(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        shape_const = relay.const(np.array(c_shape).astype(dtype), dtype=dtype)
        return relay.Function([x, y], shape_const)

    for dtype in ("int32", "float32"):
        actual = run_opt_pass(before(dtype), transform.FoldConstant())
        desired = run_opt_pass(expected(dtype), transform.InferType())
        assert tvm.ir.structural_equal(actual, desired)
169
170
def test_fold_ndarray_size():
    """ndarray_size over a statically-shaped expression folds to a constant."""
    c_shape = (8, 9, 10)

    def before(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        return relay.Function([x, y], relay.ndarray_size(x + y, dtype))

    def expected(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        size_const = relay.const(np.size(np.zeros(c_shape)), dtype=dtype)
        return relay.Function([x, y], size_const)

    for dtype in ("int32", "float32"):
        actual = run_opt_pass(before(dtype), transform.FoldConstant())
        desired = run_opt_pass(expected(dtype), transform.InferType())
        assert tvm.ir.structural_equal(actual, desired)
191
192
def test_fold_full():
    """full() with a constant fill value is left untouched by FoldConstant."""
    c_shape = (8, 9, 10)

    def before():
        fill = relay.const(1.0, "float32")
        return relay.full(fill, c_shape, dtype="float32")

    def expected():
        # No folding is expected to occur for this graph.
        return before()

    actual = run_opt_pass(before(), transform.FoldConstant())
    desired = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(actual, desired)
207
208
def test_fold_batch_norm():
    """SimplifyInference + FoldConstant collapse a conv2d+batch_norm graph
    into a conv2d followed by a constant bias add (with zero-valued params)."""

    def expected():
        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
        weight = relay.const(np.zeros((16, 3, 3, 3)))
        bias = relay.const(np.zeros((16, 1, 1)))
        conv = relay.nn.conv2d(
            data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
        )
        add = relay.add(conv, bias)
        return relay.Function(relay.analysis.free_vars(add), add)

    remove_bn_pass = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.SimplifyInference(),
            relay.transform.FoldConstant(),
            relay.transform.FoldScaleAxis(),
        ]
    )

    data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
    weight = relay.var("weight")
    bn_gamma = relay.var("bn_gamma")
    bn_beta = relay.var("bn_beta")
    bn_mmean = relay.var("bn_mean")
    bn_mvar = relay.var("bn_var")

    conv = relay.nn.conv2d(
        data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
    )
    bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean, bn_mvar)

    def initializer(_, param):
        # Fill the provided array in place; a plain `param = np.zeros(...)`
        # would only rebind the local name and leave the array unchanged.
        param[:] = np.zeros(param.shape)

    mod, params = create_workload(bn_output[0], initializer)
    mod["main"] = bind_params_by_name(mod["main"], params)

    with tvm.transform.PassContext(opt_level=3):
        mod = remove_bn_pass(mod)

    expect = run_infer_type(expected())
    assert tvm.ir.structural_equal(mod["main"], expect)
252
253
if __name__ == "__main__":
    # Run every test defined in this file; test_concatenate_const was
    # previously defined but never invoked here.
    test_concatenate_const()
    test_fold_const()
    test_fold_let()
    test_fold_tuple()
    test_fold_concat()
    test_fold_shape_of()
    test_fold_full()
    test_fold_batch_norm()
    test_fold_ndarray_size()