1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Test legalize pass"""
18import numpy as np
19import tvm
20
21from tvm import relay
22from tvm.contrib import graph_runtime
23from tvm.relay import transform, analysis
24from tvm.relay.testing.temp_op_attr import TempOpAttr
25
26
def run_opt_pass(expr, passes):
    """Run *passes* (a pass or list of passes) over *expr* at opt_level 3.

    The expression is wrapped in a fresh module before the passes run.
    Returns the optimized "main" function when *expr* was a Function,
    otherwise just its body expression.
    """
    pass_list = [passes] if not isinstance(passes, list) else passes
    mod = relay.Module.from_expr(expr)
    with transform.PassContext(opt_level=3):
        mod = transform.Sequential(pass_list)(mod)
    main_fn = mod["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body
35
def test_legalize():
    """Test directly replacing an operator with a new one"""

    def _build(transform_weight):
        # Shared graph builder: conv2d -> relu, with the conv weight
        # optionally rewritten by *transform_weight*.
        data = relay.var("x", shape=(1, 64, 56, 56))
        kernel = relay.var('weight', shape=(64, 64, 3, 3))
        conv = relay.nn.conv2d(data, transform_weight(kernel),
                               channels=64,
                               kernel_size=(3, 3),
                               padding=(1, 1))
        return relay.Function([data, kernel], relay.nn.relu(conv))

    def before():
        return _build(lambda w: w)

    def legalize_conv2d(attrs, inputs, types):
        # Legalization hook: double the weight before the convolution.
        data, kernel = inputs
        scaled = relay.multiply(kernel, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, scaled, **attrs)

    def expected():
        return _build(
            lambda w: relay.multiply(w, relay.const(2.0, "float32")))

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        actual = run_opt_pass(before(), transform.Legalize())
        reference = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(actual, reference), "Actual = \n" + str(actual)
71
def test_legalize_none():
    """Test doing nothing by returning 'None' """

    def before():
        inp = relay.var("x", shape=(1, 64, 56, 56))
        pooled = relay.nn.global_max_pool2d(inp)
        return relay.Function([inp], pooled)

    # Records whether the legalization hook was actually invoked.
    invoked = [False]

    def legalize_pool2d(attrs, inputs, types):
        invoked[0] = True
        # Returning None tells the pass to keep the original op.
        return None

    with TempOpAttr("nn.global_max_pool2d", "FTVMLegalize", legalize_pool2d):
        actual = run_opt_pass(before(), transform.Legalize())
        reference = run_opt_pass(before(), transform.InferType())

    assert analysis.alpha_equal(actual, reference), "Actual = \n" + str(actual)
    assert invoked[0]
93
def test_legalize_multiple_ops():
    """Test directly replacing an operator with a new one"""

    def before():
        data = relay.var("x", shape=(1, 64, 56, 56))
        kernel = relay.var('weight', shape=(64, 64, 3, 3))
        out = relay.nn.conv2d(data, kernel,
                              channels=64,
                              kernel_size=(3, 3),
                              padding=(1, 1))
        out = relay.nn.relu(out)
        return relay.Function([data, kernel], out)

    def legalize_conv2d(attrs, inputs, types):
        # Double the weight before the convolution.
        data, kernel = inputs
        doubled = relay.multiply(kernel, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, doubled, **attrs)

    def legalize_relu(attrs, inputs, types):
        # Insert a no-op add(0, x) in front of the relu.
        shifted = relay.add(tvm.relay.const(0, "float32"), inputs[0])
        return relay.nn.relu(shifted)

    def expected():
        data = relay.var("x", shape=(1, 64, 56, 56))
        kernel = relay.var('weight', shape=(64, 64, 3, 3))
        out = relay.nn.conv2d(
            data, relay.multiply(kernel, relay.const(2.0, "float32")),
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1))
        out = relay.add(tvm.relay.const(0, "float32"), out)
        out = relay.nn.relu(out)
        return relay.Function([data, kernel], out)

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d), \
         TempOpAttr("nn.relu", "FTVMLegalize", legalize_relu):
        actual = run_opt_pass(before(), transform.Legalize())
        reference = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(actual, reference), "Actual = \n" + str(actual)
137
138
def test_legalize_multi_input():
    """Test directly replacing an operator with a new one"""

    def make_func():
        # Both the pre-pass and reference graphs are identical, so a
        # single factory builds them.
        a = relay.var("x", shape=(1, 64, 56, 56))
        b = relay.var("y", shape=(1, 64, 56, 20))
        c = relay.var("z", shape=(1, 64, 56, 10))
        out = relay.concatenate([a, b, c], axis=3)
        return relay.Function([a, b, c], out)

    def legalize_concatenate(attrs, inputs, types):
        # Check that the correct multi-input case is handled.
        assert len(inputs) == 1
        assert isinstance(inputs[0], tvm.relay.expr.Tuple)
        assert len(types) == 2
        assert isinstance(types[0], tvm.relay.ty.TupleType)
        assert isinstance(types[1], tvm.relay.ty.TensorType)
        return None

    with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
        actual = run_opt_pass(make_func(), transform.Legalize())
        reference = run_opt_pass(make_func(), transform.InferType())

    assert analysis.alpha_equal(actual, reference), "Actual = \n" + str(actual)
173
174
if __name__ == "__main__":
    # Run every legalize test in declaration order.
    for _test in (test_legalize,
                  test_legalize_none,
                  test_legalize_multiple_ops,
                  test_legalize_multi_input):
        _test()
180