# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for the Relay FoldConstant transform pass.

Each test builds a small Relay program containing constant
sub-expressions, runs FoldConstant over it, and checks the result is
structurally equal to a hand-folded reference program.
"""
import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform


def run_opt_pass(expr, opt_pass):
    """Apply a single Relay pass to an expression and return the result.

    Parameters
    ----------
    expr : relay.Expr or relay.Function
        The expression to optimize.
    opt_pass : transform.Pass
        The pass to run (e.g. FoldConstant, InferType).

    Returns
    -------
    relay.Expr
        The optimized function if *expr* was a Function, otherwise the
        optimized function's body (the bare expression).
    """
    assert isinstance(opt_pass, transform.Pass)

    mod = relay.Module.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    # from_expr wraps a bare expression in a Function; unwrap the body
    # unless the caller passed a Function to begin with.
    return entry if isinstance(expr, relay.Function) else entry.body


def test_concatenate_const():
    """concatenate over two identical constants folds to one constant."""
    def before():
        data = tvm.nd.array(np.array([1.0, 2.0, 3.0]))
        const = relay.const(data)
        concat = relay.op.concatenate([const, const], axis=0)
        func = relay.Function([], concat)
        return func

    def expected():
        # The folded result: both inputs concatenated at build time.
        data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
        const = relay.const(data)
        func = relay.Function([], const)
        return func

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.graph_equal(zz, zexpected)


def test_fold_const():
    """Constant arithmetic folds; expressions involving vars remain."""
    c_data = np.array([1, 2, 3]).astype("float32")
    # NOTE(review): c_data has shape (3,) while t is [1, 2, 3]; the adds
    # below rely on broadcasting — presumably intentional, verify.
    t = relay.TensorType([1, 2, 3], "float32")

    def before():
        c = relay.const(c_data)
        x = relay.var("x", t)
        y = relay.add(c, c)
        y = relay.multiply(y, relay.const(2, "float32"))
        y = relay.add(x, y)
        z = relay.add(y, c)
        return relay.Function([x], z)

    def expected():
        x = relay.var("x", t)
        # (c + c) * 2 is fully constant and should be pre-computed.
        c_folded = (c_data + c_data) * 2
        y = relay.add(x, relay.const(c_folded))
        z = relay.add(y, relay.const(c_data))
        return relay.Function([x], z)

    def fail(x):
        raise RuntimeError()

    # The fold constant should work on any context: folding must not go
    # through the registered lowering passes (fail() would raise) and
    # must succeed even with a GPU target selected.
    with tvm.build_config(add_lower_pass=[(0, fail)]):
        with tvm.target.create("cuda"):
            zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.alpha_equal(zz, zexpected)


def test_fold_let():
    """Let-bound constants fold into their use sites."""
    c_data = np.array(1).astype("float32")
    t = relay.TensorType([1], "float32")

    def before():
        sb = relay.ScopeBuilder()
        x = relay.var("x", t)
        t1 = sb.let("t1", relay.const(c_data))
        t2 = sb.let("t2", relay.add(t1, t1))
        t3 = sb.let("t3", relay.add(t2, x))
        sb.ret(t3)
        return relay.Function([x], sb.get())

    def expected():
        sb = relay.ScopeBuilder()
        x = relay.var("x", t)
        # t1 and t2 are constant; only the binding that touches x survives.
        c_folded = (c_data + c_data)
        t3 = sb.let("t3", relay.add(relay.const(c_folded), x))
        sb.ret(t3)
        return relay.Function([x], sb.get())

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.graph_equal(zz, zexpected)


def test_fold_tuple():
    """Constant tuple fields fold through TupleGetItem."""
    c_data = np.array(1).astype("float32")
    t = relay.TensorType([1], "float32")

    def before():
        c = relay.const(c_data)
        x = relay.var("x", t)
        y = relay.Tuple([x, c])
        z = relay.add(y[1], c)
        z = relay.add(z, y[0])
        return relay.Function([x], z)

    def expected():
        # y[1] + c is constant; y[0] reduces to x.
        c = relay.const(c_data + c_data)
        x = relay.var("x", t)
        z = relay.add(c, x)
        return relay.Function([x], z)

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.graph_equal(zz, zexpected)


def test_fold_concat():
    """concatenate of constants folds to the numpy-computed constant."""
    c_data = np.array([[1, 2, 3]]).astype("float32")

    def before():
        a = relay.const(c_data)
        b = relay.const(c_data)
        y = relay.concatenate((a, b), axis=0)
        return relay.Function([], y)

    def expected():
        y_data = np.concatenate((c_data, c_data), axis=0)
        y = relay.const(y_data)
        return relay.Function([], y)

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.graph_equal(zz, zexpected)


def test_fold_shape_of():
    """shape_of over a statically-shaped input folds to a shape constant."""
    c_shape = (8, 9, 10)

    def before(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        z = relay.shape_of(x + y, dtype)
        return relay.Function([x, y], z)

    def expected(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        # The shape is fully static, so shape_of becomes a constant tensor.
        z = relay.const(np.array(c_shape).astype(dtype), dtype=dtype)
        func = relay.Function([x, y], z)
        return func

    for dtype in ["int32", "float32"]:
        zz = run_opt_pass(before(dtype), transform.FoldConstant())
        zexpected = run_opt_pass(expected(dtype), transform.InferType())
        assert relay.analysis.graph_equal(zz, zexpected)


def test_fold_full():
    """full() with a constant fill value is left alone by FoldConstant."""
    c_shape = (8, 9, 10)

    def before():
        dtype = 'float32'
        return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)

    def expected():
        # expect no changes
        return before()

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.graph_equal(zz, zexpected)


if __name__ == "__main__":
    # test_concatenate_const was defined but never invoked here; added so
    # that running this file as a script exercises every test.
    test_concatenate_const()
    test_fold_const()
    test_fold_let()
    test_fold_tuple()
    test_fold_concat()
    test_fold_shape_of()
    test_fold_full()