# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the Relay FoldConstant transformation pass."""
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.testing import run_infer_type, create_workload


def run_opt_pass(expr, opt_pass):
    """Apply *opt_pass* to *expr* and return the optimized expression.

    The expression is wrapped into an IRModule so module-level passes can
    run.  If *expr* was already a Function the optimized Function is
    returned; otherwise only the body of the resulting "main" is returned,
    mirroring the shape of the input.
    """
    assert isinstance(opt_pass, tvm.transform.Pass)

    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_concatenate_const():
    """concatenate over two constants folds into one constant."""

    def before():
        data = tvm.nd.array(np.array([1.0, 2.0, 3.0]))
        const = relay.const(data)
        concat = relay.op.concatenate([const, const], axis=0)
        func = relay.Function([], concat)
        return func

    def expected():
        data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
        const = relay.const(data)
        func = relay.Function([], const)
        return func

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_const():
    """Constant subexpressions fold; expressions involving vars do not."""
    c_data = np.array([1, 2, 3]).astype("float32")
    t = relay.TensorType([1, 2, 3], "float32")

    def before():
        c = relay.const(c_data)
        x = relay.var("x", t)
        y = relay.add(c, c)
        y = relay.multiply(y, relay.const(2, "float32"))
        y = relay.add(x, y)
        z = relay.add(y, c)
        return relay.Function([x], z)

    def expected():
        x = relay.var("x", t)
        # (c + c) * 2 is fully constant and must be pre-computed.
        c_folded = (c_data + c_data) * 2
        y = relay.add(x, relay.const(c_folded))
        z = relay.add(y, relay.const(c_data))
        return relay.Function([x], z)

    # The fold constant pass should work regardless of the target context.
    with tvm.target.Target("cuda"):
        zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_let():
    """Constant let-bindings are folded away; var-dependent ones remain."""
    c_data = np.array(1).astype("float32")
    t = relay.TensorType([1], "float32")

    def before():
        sb = relay.ScopeBuilder()
        x = relay.var("x", t)
        t1 = sb.let("t1", relay.const(c_data))
        t2 = sb.let("t2", relay.add(t1, t1))
        t3 = sb.let("t3", relay.add(t2, x))
        sb.ret(t3)
        return relay.Function([x], sb.get())

    def expected():
        sb = relay.ScopeBuilder()
        x = relay.var("x", t)
        c_folded = c_data + c_data
        t3 = sb.let("t3", relay.add(relay.const(c_folded), x))
        sb.ret(t3)
        return relay.Function([x], sb.get())

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_tuple():
    """Constant tuple fields are folded through TupleGetItem."""
    c_data = np.array(1).astype("float32")
    t = relay.TensorType([1], "float32")

    def before():
        c = relay.const(c_data)
        x = relay.var("x", t)
        y = relay.Tuple([x, c])
        z = relay.add(y[1], c)
        z = relay.add(z, y[0])
        return relay.Function([x], z)

    def expected():
        c = relay.const(c_data + c_data)
        x = relay.var("x", t)
        z = relay.add(c, x)
        return relay.Function([x], z)

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_concat():
    """concatenate of constant operands folds to a single constant."""
    c_data = np.array([[1, 2, 3]]).astype("float32")

    def before():
        a = relay.const(c_data)
        b = relay.const(c_data)
        y = relay.concatenate((a, b), axis=0)
        return relay.Function([], y)

    def expected():
        y_data = np.concatenate((c_data, c_data), axis=0)
        y = relay.const(y_data)
        return relay.Function([], y)

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_shape_of():
    """shape_of on a statically-shaped input folds to a constant shape."""
    c_shape = (8, 9, 10)

    def before(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        z = relay.shape_of(x + y, dtype)
        return relay.Function([x, y], z)

    def expected(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        z = relay.const(np.array(c_shape).astype(dtype), dtype=dtype)
        func = relay.Function([x, y], z)
        return func

    for dtype in ["int32", "float32"]:
        zz = run_opt_pass(before(dtype), transform.FoldConstant())
        zexpected = run_opt_pass(expected(dtype), transform.InferType())
        assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_ndarray_size():
    """ndarray_size on a statically-shaped input folds to a constant."""
    c_shape = (8, 9, 10)

    def before(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        z = relay.ndarray_size(x + y, dtype)
        return relay.Function([x, y], z)

    def expected(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        z = relay.const(np.size(np.zeros(c_shape)), dtype=dtype)
        func = relay.Function([x, y], z)
        return func

    for dtype in ["int32", "float32"]:
        zz = run_opt_pass(before(dtype), transform.FoldConstant())
        zexpected = run_opt_pass(expected(dtype), transform.InferType())
        assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_full():
    """full() with constant fill value is intentionally NOT folded."""
    c_shape = (8, 9, 10)

    def before():
        dtype = "float32"
        return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)

    def expected():
        # Expect no changes: FoldConstant leaves `full` alone.
        return before()

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_batch_norm():
    """SimplifyInference + FoldConstant folds batch_norm into conv + add."""

    def expected():
        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
        weight = relay.const(np.zeros((16, 3, 3, 3)))
        bias = relay.const(np.zeros((16, 1, 1)))
        conv = relay.nn.conv2d(
            data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
        )
        add = relay.add(conv, bias)
        return relay.Function(relay.analysis.free_vars(add), add)

    remove_bn_pass = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.SimplifyInference(),
            relay.transform.FoldConstant(),
            relay.transform.FoldScaleAxis(),
        ]
    )

    data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
    weight = relay.var("weight")
    bn_gamma = relay.var("bn_gamma")
    bn_beta = relay.var("bn_beta")
    bn_mmean = relay.var("bn_mean")
    bn_mvar = relay.var("bn_var")

    conv = relay.nn.conv2d(
        data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
    )
    bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean, bn_mvar)

    def initializer(_, param):
        # Zero every parameter IN PLACE so the folded constants match the
        # np.zeros values used in expected().  (The original code rebound the
        # local name instead of writing into the array — a no-op that only
        # worked because create_workload pre-fills parameters with zeros.)
        param[...] = 0.0

    mod, params = create_workload(bn_output[0], initializer)
    mod["main"] = bind_params_by_name(mod["main"], params)

    with tvm.transform.PassContext(opt_level=3):
        mod = remove_bn_pass(mod)

    expect = run_infer_type(expected())
    assert tvm.ir.structural_equal(mod["main"], expect)


if __name__ == "__main__":
    # Fix: test_concatenate_const was defined but never invoked here.
    test_concatenate_const()
    test_fold_const()
    test_fold_let()
    test_fold_tuple()
    test_fold_concat()
    test_fold_shape_of()
    test_fold_full()
    test_fold_batch_norm()
    test_fold_ndarray_size()