1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Test legalize pass"""
18import numpy as np
19import tvm
20from tvm import te
21
22from tvm import relay
23from tvm.contrib import graph_runtime
24from tvm.relay import transform, analysis
25from tvm.relay.testing.temp_op_attr import TempOpAttr
26
27
def run_opt_pass(expr, passes):
    """Run *passes* over *expr* at opt level 3 and return the result.

    The expression is wrapped in a fresh IRModule; the transformed
    ``main`` function is returned, or just its body when the input was
    not already a ``relay.Function``.
    """
    if not isinstance(passes, list):
        passes = [passes]
    mod = tvm.IRModule.from_expr(expr)
    with tvm.transform.PassContext(opt_level=3):
        mod = tvm.transform.Sequential(passes)(mod)
    main = mod["main"]
    if isinstance(expr, relay.Function):
        return main
    return main.body
36
37
def test_legalize():
    """Test directly replacing an operator with a new one"""

    def before():
        # conv2d -> relu, prior to legalization.
        data = relay.var("x", shape=(1, 64, 56, 56))
        kernel = relay.var("weight", shape=(64, 64, 3, 3))
        out = relay.nn.conv2d(data, kernel, channels=64, kernel_size=(3, 3), padding=(1, 1))
        out = relay.nn.relu(out)
        return relay.Function([data, kernel], out)

    def legalize_conv2d(attrs, inputs, types):
        # Rewrite the conv2d so its weight is doubled first.
        data, weight = inputs
        scaled = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, scaled, **attrs)

    def expected():
        # Same graph with the weight-doubling multiply inserted.
        data = relay.var("x", shape=(1, 64, 56, 56))
        kernel = relay.var("weight", shape=(64, 64, 3, 3))
        scaled = relay.multiply(kernel, relay.const(2.0, "float32"))
        out = relay.nn.conv2d(
            data,
            scaled,
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
        )
        out = relay.nn.relu(out)
        return relay.Function([data, kernel], out)

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        actual = run_opt_pass(before(), transform.Legalize())
        reference = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(actual, reference), "Actual = \n" + str(actual)
74
75
def test_legalize_none():
    """Test doing nothing by returning 'None' """

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.nn.global_max_pool2d(x)
        y = relay.Function([x], y)
        return y

    # Mutable cell so the nested callback can record that it ran.
    called = [False]

    # Renamed from the misleading `legalize_conv2d`: this hook is
    # registered on nn.global_max_pool2d, not nn.conv2d.
    def legalize_global_max_pool2d(attrs, inputs, types):
        """Record the invocation and return None, leaving the op unchanged."""
        called[0] = True
        return None

    with TempOpAttr("nn.global_max_pool2d", "FTVMLegalize", legalize_global_max_pool2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        # A None-returning legalization must leave the graph identical
        # to the input, so the reference is just the type-inferred input.
        b = run_opt_pass(before(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
    # The hook must still have been invoked even though it changed nothing.
    assert called[0]
98
99
def test_legalize_multiple_ops():
    """Test directly replacing an operator with a new one"""

    def before():
        # conv2d -> relu, prior to legalization.
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight", shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def legalize_conv2d(attrs, inputs, types):
        # Rewrite the conv2d so its weight is doubled first.
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)

    def legalize_relu(attrs, inputs, types):
        # Rewrite relu(x) as relu(0 + x).
        # (Uses `relay.const` consistently with legalize_conv2d above,
        # rather than the equivalent `tvm.relay.const` spelling.)
        data = inputs[0]
        add = relay.add(relay.const(0, "float32"), data)
        return relay.nn.relu(add)

    def expected():
        # Same graph with both rewrites applied.
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight", shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(
            x,
            relay.multiply(weight, relay.const(2.0, "float32")),
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
        )
        y = relay.add(relay.const(0, "float32"), y)
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        with TempOpAttr("nn.relu", "FTVMLegalize", legalize_relu):
            a = before()
            a = run_opt_pass(a, transform.Legalize())
            b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
143
144
def test_legalize_multi_input():
    """Test legalization of an op taking a tuple of inputs (concatenate).

    The legalization callback returns None, so the graph must come out
    unchanged; the reference is therefore the type-inferred input itself
    (same pattern as test_legalize_none — the previous byte-identical
    `expected()` duplicate of `before()` has been removed).
    """

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    def legalize_concatenate(attrs, inputs, types):
        # Check that the correct multi-input case is handled:
        # concatenate receives a single Tuple input, and types holds the
        # tuple type plus the output tensor type.
        assert len(inputs) == 1
        assert isinstance(inputs[0], tvm.relay.expr.Tuple)
        assert len(types) == 2
        assert isinstance(types[0], tvm.relay.ty.TupleType)
        assert isinstance(types[1], tvm.relay.ty.TensorType)
        return None

    with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(before(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
179
180
181if __name__ == "__main__":
182    test_legalize()
183    test_legalize_none()
184    test_legalize_multiple_ops()
185    test_legalize_multi_input()
186