# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test legalize pass"""
import numpy as np
import tvm
from tvm import te

from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr


def run_opt_pass(expr, passes):
    """Run one or more passes over ``expr`` and return the transformed result.

    Parameters
    ----------
    expr : relay expression
        The expression to transform. It is wrapped in a fresh ``IRModule``
        via ``tvm.IRModule.from_expr``.
    passes : pass or list of passes
        A single pass or a list of passes; they are applied through
        ``tvm.transform.Sequential`` under ``PassContext(opt_level=3)``.

    Returns
    -------
    The module's ``"main"`` function when ``expr`` is a ``relay.Function``,
    otherwise the body of that function.
    """
    passes = passes if isinstance(passes, list) else [passes]
    mod = tvm.IRModule.from_expr(expr)
    seq = tvm.transform.Sequential(passes)
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_legalize():
    """Test directly replacing an operator with a new one"""

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight", shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def legalize_conv2d(attrs, inputs, types):
        # Rebuild the conv2d with the same attrs but a weight scaled by 2.0.
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight", shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(
            x,
            relay.multiply(weight, relay.const(2.0, "float32")),
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
        )
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_legalize_none():
    """Test that a legalize callback returning None leaves the op unchanged"""

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.nn.global_max_pool2d(x)
        y = relay.Function([x], y)
        return y

    called = False

    def legalize_global_max_pool2d(attrs, inputs, types):
        # Record the invocation so the test can tell "callback ran and
        # declined" apart from "callback never ran".
        nonlocal called
        called = True
        return None

    with TempOpAttr("nn.global_max_pool2d", "FTVMLegalize", legalize_global_max_pool2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(before(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
    assert called


def test_legalize_multiple_ops():
    """Test legalizing two different operators (conv2d and relu) in one pass"""

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight", shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def legalize_conv2d(attrs, inputs, types):
        # Rebuild the conv2d with the same attrs but a weight scaled by 2.0.
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)

    def legalize_relu(attrs, inputs, types):
        # Insert an add-with-zero in front of the relu.
        data = inputs[0]
        add = relay.add(tvm.relay.const(0, "float32"), data)
        return relay.nn.relu(add)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight", shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(
            x,
            relay.multiply(weight, relay.const(2.0, "float32")),
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
        )
        y = relay.add(tvm.relay.const(0, "float32"), y)
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        with TempOpAttr("nn.relu", "FTVMLegalize", legalize_relu):
            a = before()
            a = run_opt_pass(a, transform.Legalize())
            b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_legalize_multi_input():
    """Test the inputs and types a legalize callback receives for concatenate"""

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    called = False

    def legalize_concatenate(attrs, inputs, types):
        # Record the invocation: the assertions below verify nothing if the
        # callback is never run.
        nonlocal called
        called = True
        # Check that the correct multi-input case is handled.
        assert len(inputs) == 1
        assert isinstance(inputs[0], tvm.relay.expr.Tuple)
        assert len(types) == 2
        assert isinstance(types[0], tvm.relay.ty.TupleType)
        assert isinstance(types[1], tvm.relay.ty.TensorType)
        return None

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
    assert called


if __name__ == "__main__":
    test_legalize()
    test_legalize_none()
    test_legalize_multiple_ops()
    test_legalize_multi_input()