# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test legalize pass"""
import numpy as np
import tvm

from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr


def run_opt_pass(expr, passes):
    """Run one or more passes over ``expr`` and return the transformed result.

    Parameters
    ----------
    expr : relay expression or relay.Function
        Wrapped in a module via ``relay.Module.from_expr`` before the passes
        are applied.
    passes : pass or list of passes
        A single pass is accepted and treated as a one-element list.

    Returns
    -------
    The module's ``"main"`` function when ``expr`` was a ``relay.Function``;
    otherwise the body of that function.
    """
    passes = passes if isinstance(passes, list) else [passes]
    mod = relay.Module.from_expr(expr)
    seq = transform.Sequential(passes)
    with transform.PassContext(opt_level=3):
        mod = seq(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_legalize():
    """Test directly replacing an operator with a new one"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def legalize_conv2d(attrs, inputs, types):
        # Replacement conv2d: same attrs, but the weight is scaled by 2.0.
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_legalize_none():
    """Test doing nothing by returning 'None' """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.nn.global_max_pool2d(x)
        y = relay.Function([x], y)
        return y

    # One-element list so the nested hook can record that it ran.
    called = [False]

    def legalize_global_max_pool2d(attrs, inputs, types):
        called[0] = True
        return None

    with TempOpAttr("nn.global_max_pool2d", "FTVMLegalize",
                    legalize_global_max_pool2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(before(), transform.InferType())

    # The legalized graph is expected to match the untouched, type-inferred one.
    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
    assert called[0], "FTVMLegalize hook was never invoked"


def test_legalize_multiple_ops():
    """Test legalizing two different operators (conv2d and relu) in one pass"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def legalize_conv2d(attrs, inputs, types):
        # Replacement conv2d: same attrs, but the weight is scaled by 2.0.
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)

    def legalize_relu(attrs, inputs, types):
        # Replacement relu: add a zero constant in front of the relu.
        data = inputs[0]
        add = relay.add(relay.const(0, "float32"), data)
        return relay.nn.relu(add)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.add(relay.const(0, "float32"), y)
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        with TempOpAttr("nn.relu", "FTVMLegalize", legalize_relu):
            a = before()
            a = run_opt_pass(a, transform.Legalize())
            b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_legalize_multi_input():
    """Test the arguments the legalize hook receives for a multi-input op"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    def legalize_concatenate(attrs, inputs, types):
        # Check that the correct multi-input case is handled.
        assert len(inputs) == 1
        assert isinstance(inputs[0], relay.expr.Tuple)
        assert len(types) == 2
        assert isinstance(types[0], relay.ty.TupleType)
        assert isinstance(types[1], relay.ty.TensorType)
        return None

    def expected():
        # The hook returns None, so the expected graph is the input graph.
        return before()

    with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


if __name__ == "__main__":
    test_legalize()
    test_legalize_none()
    test_legalize_multiple_ops()
    test_legalize_multi_input()