# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for the Relay FuseOps pass.

Each test builds a Relay program in ``before()``, runs ``transform.FuseOps``
on it, and compares the result (via ``tvm.ir.structural_equal``) against a
hand-constructed ``expected()`` program in which the fused groups appear as
Relay functions tagged with the ``Primitive`` attribute.
"""
import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import tvm.testing


def test_fuse_simple():
    """Simple testcase."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        return relay.Function([x], w)

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        f1 = relay.Function([x], w)
        # "Primitive" == 1 marks the function as a fused unit produced by FuseOps.
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    z = before()
    # NOTE(review): the first result is overwritten below; only the
    # default-configuration FuseOps result is actually asserted. The
    # opt_level=2 run presumably just exercises that configuration.
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_conv2d_fuse():
    """Test fusion case of conv2d"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        x = relay.add(x, relay.const(1, "float32"))
        y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=16)
        # this is the next dominator.
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        # second path
        z2 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(1, 1), padding=(0, 0), channels=16)
        z3 = relay.nn.conv2d(y, relay.var("w3"), kernel_size=(3, 3), padding=(1, 1), channels=16)
        # add can only be fused to z1
        z = relay.add(z2, z3)
        return relay.Function(relay.analysis.free_vars(z), z)

    def expected(dshape):
        # segment 0
        x = relay.var("p0", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        f0 = relay.Function([x], y)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # segment 1
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        y = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=16)
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        f1 = relay.Function([x, w], y)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # segment 2
        # NOTE(review): relative to `before`, the kernel sizes and weight
        # variables of the two parallel convs are swapped here (segment 2 is
        # the 3x3 conv called with "w3", segment 3 the 1x1 conv with "w2");
        # presumably this mirrors the traversal order of the fusion pass —
        # confirm against the pass implementation if this assertion breaks.
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        z2 = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=16)
        f2 = relay.Function([x, w], z2)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # segment 3
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        offset = relay.var("p2", shape=dshape)
        z3 = relay.nn.conv2d(x, w, kernel_size=(1, 1), padding=(0, 0), channels=16)
        z3 = relay.add(z3, offset)
        f3 = relay.Function([x, w, offset], z3)
        f3 = f3.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # compose
        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        y = relay.Call(f1, [y, relay.var("w1")])
        z2 = relay.Call(f2, [y, relay.var("w3")])
        z3 = relay.Call(f3, [y, relay.var("w2"), z2])
        z = z3
        return relay.Function(relay.analysis.free_vars(z), z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_concatenate():
    """Test fusion case involving concat op and Tuple node"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        concat = relay.concatenate((upsampled, x), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(dshape):
        # max_pool2d stays in its own group; upsampling + concat + add fuse.
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        f0 = relay.Function([x], pooled)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # pooled output has halved spatial dims (2x2 pool, stride 2).
        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
        p1 = relay.var("p1", shape=dshape)
        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
        concat = relay.concatenate((upsampled, p1), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        f1 = relay.Function([p0, p1], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y, x])
        return relay.Function([x], z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_tuple_root():
    """Test fusion case where Tuple node is the root in its group"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        out = relay.Tuple((upsampled, x))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        f0 = relay.Function([x], pooled)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
        f1 = relay.Function([p0], upsampled)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # The root Tuple itself is not pulled into either fused function.
        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y])
        tup = relay.Tuple((z, x))
        return relay.Function([x], tup)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_stop_fusion():
    """Test that a stop_fusion annotation splits the group at the boundary."""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        y = relay.annotation.stop_fusion(y)
        z = relay.exp(y)
        return relay.Function([x], z)

    def expected(dshape):
        # add and exp end up in two separate primitive functions because of
        # the stop_fusion annotation between them.
        x = relay.var("p0", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        f1 = relay.Function([x], y)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("p01", shape=dshape)
        y = relay.exp(x)
        f2 = relay.Function([x], y)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x])
        z = relay.Call(f2, [y])
        return relay.Function([x], z)

    dshape = (10, 20)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_fuse_myia_regression():
    """Regression test: fusion of an op used as an if-condition (from Myia)."""

    def before(dshape, dtype):
        x = relay.var("x", shape=dshape, dtype=dtype)
        y = relay.var("y", shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        with sb.if_scope(relay.op.greater(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y], relay.Call(sb.get(), []))

    def expected(dshape, dtype):
        x = relay.var("x", shape=dshape, dtype=dtype)
        y = relay.var("y", shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        # Only the greater() comparison is fused; the if/else structure and
        # the closures returned from each branch are untouched.
        p1 = relay.var("p1", shape=dshape, dtype=dtype)
        p2 = relay.var("p2", shape=dshape, dtype=dtype)
        fused_gt = relay.Function([p1, p2], relay.op.greater(p1, p2))
        fused_gt = fused_gt.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        with sb.if_scope(fused_gt(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y], relay.Call(sb.get(), []))

    dshape = ()
    dtype = "int64"
    f = before(dshape, dtype)
    zz = run_opt_pass(f, transform.FuseOps())
    after = run_opt_pass(expected(dshape, dtype), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_fuse_tuple_get_elemwise():
    """Test fusing TupleGetItem results (from split) into elemwise ops."""

    def before(dim):
        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        return relay.Function([X, W], out)

    def expected(dim):
        # dense stays alone; split + the elemwise combination fuse together.
        p0 = relay.var("p0", shape=(1, dim))
        p1 = relay.var("p1", shape=(3 * dim, dim))
        matmul = relay.nn.dense(p0, p1)
        f0 = relay.Function([p0, p1], matmul)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p01 = relay.var("p01", shape=(1, 3 * dim))
        splitted = relay.split(p01, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        f1 = relay.Function([p01], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        y = relay.Call(f0, [X, W])
        z = relay.Call(f1, [y])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dim), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_tuple_get_root():
    """Test fusion when a TupleGetItem is the root of its group."""

    def before(dim):
        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        splitted = relay.split(X, indices_or_sections=3, axis=1)
        out = relay.nn.dense(splitted[0], W)
        return relay.Function([X, W], out)

    def expected(dim):
        # split + getitem form one primitive function whose result feeds dense.
        p0 = relay.var("p0", shape=(1, 3 * dim))
        splitted = relay.split(p0, indices_or_sections=3, axis=1)
        out = splitted[0]
        f0 = relay.Function([p0], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p01 = relay.var("p01", shape=(1, dim))
        p1 = relay.var("p1", shape=(dim, dim))
        out = relay.nn.dense(p01, p1)
        f1 = relay.Function([p01, p1], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        y = relay.Call(f0, [X])
        z = relay.Call(f1, [y, W])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dim), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


# Shared pass instances used by the module-level tests below.
fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
fuse2 = relay.transform.FuseOps(fuse_opt_level=2)


def test_tuple_intermediate():
    """Test fusion across an intermediate Tuple feeding a concatenate."""

    def before(x):
        inj = relay.squeeze(x)
        y1 = relay.add(inj, relay.const(1, "float32"))
        tmp = relay.squeeze(inj)
        tmp = relay.add(tmp, relay.const(1, "float32"))
        y2 = relay.add(tmp, relay.const(1, "float32"))
        y3 = relay.add(inj, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out_inj = relay.squeeze(concat)
        out = relay.add(out_inj, relay.const(1, "float32"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(p0):
        # The whole injective graph fuses into a single primitive function.
        # Note: `dshape` is read from the enclosing test scope.
        f0 = before(p0)
        f1 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    orig = before(x)
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    # Also check the fused module actually compiles.
    relay.build(m, "llvm")
    after = run_opt_pass(expected(x), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_tuple_consecutive():
    """Test fusion with consecutive tuple/concatenate constructs."""

    def gen_intermediate_tuple(x):
        y1 = relay.add(x, relay.const(1, "float32"))
        y2 = relay.add(x, relay.const(1, "float32"))
        y3 = relay.add(x, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        return out

    def gen_consecutive_tuple(x):
        y1 = gen_intermediate_tuple(x)
        y2 = gen_intermediate_tuple(x)
        y3 = gen_intermediate_tuple(x)
        concat = relay.concatenate((y1, y2, y3), axis=1)
        return concat

    def before(x):
        concat = gen_consecutive_tuple(x)
        pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        out = relay.add(pooled, relay.const(1, "float32"))
        out2 = relay.add(out, relay.const(1, "float32"))
        out_tup = relay.Tuple((out, out2))
        return relay.Function(relay.analysis.free_vars(out_tup), out_tup)

    def expected(dshape):
        p0 = relay.var("p0", shape=dshape)
        concat = gen_consecutive_tuple(p0)
        f0 = relay.Function([p0], concat)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # Each of the three inner concats triples the channel dim: 16 * 9.
        p01 = relay.var("p01", shape=(1, dshape[1] * 9, dshape[2], dshape[3]))
        pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        out = relay.add(pooled, relay.const(1, "float32"))
        f1 = relay.Function([p01], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p02 = relay.var("p02", shape=(1, dshape[1] * 9, dshape[2] // 2, dshape[3] // 2))
        out = relay.add(p02, relay.const(1, "float32"))
        f2 = relay.Function([p02], out)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y])
        z2 = relay.Call(f2, [z])

        return relay.Function([x], relay.Tuple((z, z2)))

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    orig = before(x)
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_inception_like():
    """Test fusion on an inception-like graph (parallel convs + concat)."""

    def conv(data):
        y = relay.nn.conv2d(data, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=16)
        return relay.nn.relu(data=y)

    def inception_like(data):
        c0 = conv(data)
        c1 = conv(data)
        return relay.concatenate((c0, c1), axis=1)

    def before(dshape):
        x = relay.var("x", shape=dshape)
        in1 = inception_like(x)
        in2 = inception_like(in1)
        return relay.Function(relay.analysis.free_vars(in2), in2)

    def expected(dshape):
        # Each conv+relu pair becomes its own primitive function, and each
        # concatenate becomes a separate primitive function.
        p0 = relay.var("p0", shape=dshape)
        c = conv(p0)
        f0 = relay.Function(relay.analysis.free_vars(c), c)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p01 = relay.var("p01", shape=dshape)
        c = conv(p01)
        f1 = relay.Function(relay.analysis.free_vars(c), c)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p02 = relay.var("p02", shape=dshape)
        p12 = relay.var("p12", shape=dshape)
        concat1 = relay.concatenate((p02, p12), axis=1)
        f_concat1 = relay.Function([p02, p12], concat1)
        f_concat1 = f_concat1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # First concat doubles the channel dimension for the second stage.
        dshape2 = (dshape[0], dshape[1] * 2, dshape[2], dshape[3])

        p03 = relay.var("p03", shape=dshape2)
        c = conv(p03)
        f2 = relay.Function(relay.analysis.free_vars(c), c)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p04 = relay.var("p04", shape=dshape2)
        c = conv(p04)
        f3 = relay.Function(relay.analysis.free_vars(c), c)
        f3 = f3.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p05 = relay.var("p05", shape=dshape)
        p15 = relay.var("p15", shape=dshape)
        concat2 = relay.concatenate((p05, p15), axis=1)
        f_concat2 = relay.Function([p05, p15], concat2)
        f_concat2 = f_concat2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        c1 = relay.Call(f0, [x, relay.var("w1")])
        c2 = relay.Call(f1, [x, relay.var("w2")])
        concat = relay.Call(f_concat1, [c1, c2])
        c3 = relay.Call(f2, [concat, relay.var("w3")])
        c4 = relay.Call(f3, [concat, relay.var("w4")])
        out = relay.Call(f_concat2, [c3, c4])

        return relay.Function(relay.analysis.free_vars(out), out)

    dshape = (1, 16, 64, 64)
    orig = before(dshape)
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_parallel_injective():
    """Test fusing parallel injective ops to an elemwise op."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.squeeze(y)
        u = relay.transpose(y, axes=[0, 1])
        w = relay.left_shift(z, u)
        return relay.Function([x], w)

    def expected():
        # The whole diamond (add feeding both squeeze and transpose, joined
        # by left_shift) fuses into one primitive function.
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.squeeze(y)
        u = relay.transpose(y, axes=[0, 1])
        w = relay.left_shift(z, u)
        f1 = relay.Function([x], w)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    z = before()
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_immutable():
    """Verify the fusion pass won't change original module."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        mod = tvm.IRModule()
        mod["main"] = relay.Function([x], w)
        return mod

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        f1 = relay.Function([x], w)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        mod = tvm.IRModule()
        mod["main"] = relay.Function([x], y)
        return mod

    mod = before()
    new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
    # The input module must be left untouched by the pass.
    assert tvm.ir.structural_equal(mod, before())
    assert tvm.ir.structural_equal(new_mod, expected())


def test_split():
    """Test that the result is well formed."""
    x = relay.var("x", shape=(6, 9))
    y = relay.split(x, 3).astuple()
    a = relay.TupleGetItem(y, 0)
    b = relay.TupleGetItem(y, 1)
    c = relay.TupleGetItem(y, 2)
    mod = tvm.IRModule()
    mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
    # No assertion: the test passes if the pass runs without error on a
    # program mixing TupleGetItem with Ref ops.
    mod = transform.FuseOps()(mod)


def test_fuse_max():
    """Test the constraint of number of nodes in op fusion."""

    def before(n):
        x = relay.var("x", shape=(10, 20))
        y = x
        for i in range(n):
            y = relay.exp(y)
        return relay.Function([x], y)

    def expected(n, max_fused_ops):
        x = relay.var("p", shape=(10, 20))
        y = x
        for i in range(max_fused_ops):
            y = relay.exp(y)
        f1 = relay.Function([x], y)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        z = relay.Call(f1, [x])
        xx = relay.var("pp", shape=(10, 20))
        yy = xx
        # it is assumed that there are two fused functions
        for i in range(n - max_fused_ops):
            yy = relay.exp(yy)
        f2 = relay.Function([xx], yy)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        zz = relay.Call(f2, [z])
        return relay.Function([x], zz)

    # Default limit: 256 fused ops per group.
    max_fused_ops = 256
    n = 300
    z = before(n)
    # NOTE(review): the first result is overwritten; only the default
    # FuseOps() result is compared against `expected`.
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)

    # The limit can be lowered through the PassContext config.
    max_fused_ops = 10
    n = 20
    z = before(n)
    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())

    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
        zz = run_opt_pass(z, transform.FuseOps())

    assert tvm.ir.structural_equal(zz, after)


def test_fuse_take():
    """Test fusion case involving concat and take"""

    def before():
        shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        x = relay.var("x", shape=shape)
        concat = relay.concatenate([x, x], axis=-1)
        out = relay.op.take(concat, indices=relay.const([0], dtype="int64"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected():
        shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        shape2 = (tvm.tir.const(1, "int64"),)
        x = relay.var("x", shape=shape1)
        p0 = relay.var("p0", shape=shape1)
        # The constant indices become an explicit parameter of the fused
        # function and are passed in as an argument at the call site.
        p1 = relay.var("p1", shape=shape2, dtype="int64")
        c = relay.const([0], dtype="int64")
        concat = relay.concatenate([p0, p0], axis=-1)
        out = relay.op.take(concat, indices=p1)

        f0 = relay.Function([p0, p1], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        y = relay.Call(f0, [x, c])
        return relay.Function([x], y)

    orig = before()
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_gather_nd():
    """Test fusion case involving concat and gather_nd"""

    def before():
        shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        x = relay.var("x", shape=shape)
        concat = relay.concatenate([x, x], axis=-1)
        out = relay.gather_nd(concat, indices=relay.expr.const([[0, 1], [1, 0]], dtype="int64"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected():
        shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        shape2 = (tvm.tir.const(2, "int64"), tvm.tir.const(2, "int64"))
        x = relay.var("x", shape=shape1)
        p0 = relay.var("p0", shape=shape1)
        # As in test_fuse_take, the constant indices are lifted to a
        # parameter of the fused function.
        p1 = relay.var("p1", shape=shape2, dtype="int64")
        c = relay.const([[0, 1], [1, 0]], dtype="int64")
        concat = relay.concatenate([p0, p0], axis=-1)
        out = relay.gather_nd(concat, indices=p1)

        f0 = relay.Function([p0, p1], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        y = relay.Call(f0, [x, c])
        return relay.Function([x], y)

    orig = before()
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


@tvm.testing.uses_gpu
def test_fuse_bcast_reduce_scalar():
    """Test fusion case with broadcast and reduction involving scalar"""

    def before():
        x = relay.var("x", shape=(), dtype="int32")
        less = relay.less(x, relay.const(10, dtype="int32"))
        z = relay.min(less)
        return relay.Function([x], z)

    def expected():
        p0 = relay.var("p0", shape=(), dtype="int32")
        less = relay.less(p0, relay.const(10, dtype="int32"))
        z0 = relay.min(less)
        f0 = relay.Function([p0], z0)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=(), dtype="int32")
        f = relay.Call(f0, [x])
        return relay.Function([x], f)

    orig = before()
    m = fuse2(tvm.IRModule.from_expr(orig))
    # Build on every enabled target to make sure codegen handles the
    # scalar broadcast + reduction group.
    for tgt, ctx in tvm.testing.enabled_targets():
        relay.build(m, tgt)
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_max_diamond():
    """Test the fused-op limit on chained diamond-shaped subgraphs."""

    def create_diamond(x, branch_len):
        x1 = x
        x2 = x
        for _ in range(branch_len):
            x1 = relay.exp(x1)
            x2 = relay.exp(x2)
        return relay.add(x1, x2)

    def before(branch_len, num_diamond):
        x = relay.var("x", shape=(10, 20))
        out = x
        for _ in range(num_diamond):
            out = create_diamond(out, branch_len)
        return relay.Function([x], out)

    def after(branch_len, num_diamond):
        def create_diamond_func(inp):
            inp_var = relay.var("p", shape=(10, 20))
            d = create_diamond(inp_var, branch_len)
            f = relay.Function([inp_var], d)
            f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
            return relay.Call(f, [inp])

        inp = relay.var("x", shape=(10, 20))
        out = inp
        for _ in range(num_diamond):
            out = create_diamond_func(out)
        return relay.Function([inp], out)

    branch_len = 5
    max_fused_ops = branch_len * 2 + 1  # the number of ops in one diamond
    num_diamond = 3

    # With max_depth set to exactly one diamond's op count, each diamond
    # should become its own fused function.
    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
        fused = run_opt_pass(before(branch_len, num_diamond), transform.FuseOps())

    expected = run_opt_pass(after(branch_len, num_diamond), transform.InferType())
    assert tvm.ir.structural_equal(fused, expected)


if __name__ == "__main__":
    test_fuse_simple()
    test_conv2d_fuse()
    test_concatenate()
    test_tuple_root()
    test_stop_fusion()
    test_fuse_myia_regression()
    test_fuse_tuple_get_elemwise()
    test_tuple_get_root()
    test_tuple_intermediate()
    test_tuple_consecutive()
    test_inception_like()
    test_fuse_parallel_injective()
    test_immutable()
    test_split()
    test_fuse_max()
    test_fuse_take()
    test_fuse_gather_nd()
    test_fuse_bcast_reduce_scalar()
    test_fuse_max_diamond()