# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass


def test_fuse_simple():
    """Simple testcase."""
    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        return relay.Function([x], w)

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        f1 = relay.Function([x], w)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    z = before()
    after = run_opt_pass(expected(), transform.InferType())
    # Check both an explicit fuse_opt_level=2 and the default FuseOps()
    # (presumably the level inferred from the PassContext); each result is
    # asserted rather than the first being silently overwritten.
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert relay.analysis.alpha_equal(zz, after)
    zz = run_opt_pass(z, transform.FuseOps())
    assert relay.analysis.alpha_equal(zz, after)


def test_conv2d_fuse():
    """Test fusion case of conv2d"""
    def before(dshape):
        x = relay.var("x", shape=dshape)
        x = relay.add(x, relay.const(1, "float32"))
        y = relay.nn.conv2d(x, relay.var("w1"),
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        # this is the next dominator.
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        # second path
        z2 = relay.nn.conv2d(y, relay.var("w2"),
                             kernel_size=(1, 1),
                             padding=(0, 0),
                             channels=16)
        z3 = relay.nn.conv2d(y, relay.var("w3"),
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             channels=16)
        # add can only be fused to z2 (the 1x1 conv); see segment 3 of expected
        z = relay.add(z2, z3)
        return relay.Function(relay.analysis.free_vars(z), z)

    def expected(dshape):
        # segment 0
        x = relay.var("p0", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        f0 = relay.Function([x], y)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # segment 1
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        y = relay.nn.conv2d(x, w,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        f1 = relay.Function([x, w], y)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # segment 2
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        z2 = relay.nn.conv2d(x, w,
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             channels=16)
        f2 = relay.Function([x, w], z2)
        f2 = f2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # segment 3
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        offset = relay.var("p2", shape=dshape)
        z3 = relay.nn.conv2d(x, w,
                             kernel_size=(1, 1),
                             padding=(0, 0),
                             channels=16)
        z3 = relay.add(z3, offset)
        f3 = relay.Function([x, w, offset], z3)
        f3 = f3.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # compose
        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        y = relay.Call(f1, [y, relay.var("w1")])
        z2 = relay.Call(f2, [y, relay.var("w3")])
        z3 = relay.Call(f3, [y, relay.var("w2"), z2])
        z = z3
        return relay.Function(relay.analysis.free_vars(z), z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)


def test_concatenate():
    """Test fusion case involving concat op and Tuple node"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        concat = relay.concatenate((upsampled, x), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        f0 = relay.Function([x], pooled)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
        p1 = relay.var("p1", shape=dshape)
        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
        concat = relay.concatenate((upsampled, p1), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        f1 = relay.Function([p0, p1], out)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y, x])
        return relay.Function([x], z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)


def test_tuple_root():
    """Test fusion case where Tuple node is the root in its group"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        out = relay.Tuple((upsampled, x))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        f0 = relay.Function([x], pooled)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
        f1 = relay.Function([p0], upsampled)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y])
        tup = relay.Tuple((z, x))
        return relay.Function([x], tup)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)


def test_stop_fusion():
    """Test that stop_fusion keeps add and exp in separate primitive functions."""
    def before(dshape):
        x = relay.var("x", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        y = relay.annotation.stop_fusion(y)
        z = relay.exp(y)
        return relay.Function([x], z)

    def expected(dshape):
        x = relay.var("p0", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        f1 = relay.Function([x], y)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        x = relay.var("p01", shape=dshape)
        y = relay.exp(x)
        f2 = relay.Function([x], y)
        f2 = f2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x])
        z = relay.Call(f2, [y])
        return relay.Function([x], z)

    dshape = (10, 20)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)


def test_fuse_myia_regression():
    """Test that only the comparison feeding an if/else is fused.

    The branches return closures and are left untouched.
    """
    def before(dshape, dtype):
        x = relay.var('x', shape=dshape, dtype=dtype)
        y = relay.var('y', shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        with sb.if_scope(relay.op.greater(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y],
                              relay.Call(sb.get(), []))

    def expected(dshape, dtype):
        x = relay.var('x', shape=dshape, dtype=dtype)
        y = relay.var('y', shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        p1 = relay.var('p1', shape=dshape, dtype=dtype)
        p2 = relay.var('p2', shape=dshape, dtype=dtype)
        fused_gt = relay.Function([p1, p2],
                                  relay.op.greater(p1, p2))
        fused_gt = fused_gt.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        with sb.if_scope(fused_gt(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y],
                              relay.Call(sb.get(), []))

    dshape = ()
    dtype = 'int64'
    f = before(dshape, dtype)
    zz = run_opt_pass(f, transform.FuseOps())
    after = run_opt_pass(expected(dshape, dtype), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)


def test_fuse_tuple_get_elemwise():
    """Test that split outputs fuse with the elemwise ops that consume them.

    The items are read via TupleGetItem, and the dense op stays in its own
    group.
    """
    def before(dim):
        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        return relay.Function([X, W], out)

    def expected(dim):
        p0 = relay.var("p0", shape=(1, dim))
        p1 = relay.var("p1", shape=(3 * dim, dim))
        matmul = relay.nn.dense(p0, p1)
        f0 = relay.Function([p0, p1], matmul)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p01 = relay.var("p01", shape=(1, 3 * dim))
        splitted = relay.split(p01, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        f1 = relay.Function([p01], out)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        y = relay.Call(f0, [X, W])
        z = relay.Call(f1, [y])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dim), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)


def test_tuple_get_root():
    """Test fusion where a TupleGetItem of split is the root of its group.

    The dense op that consumes the item is not fused with the split.
    """
    def before(dim):
        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        splitted = relay.split(X, indices_or_sections=3, axis=1)
        out = relay.nn.dense(splitted[0], W)
        return relay.Function([X, W], out)

    def expected(dim):
        p0 = relay.var("p0", shape=(1, 3 * dim))
        splitted = relay.split(p0, indices_or_sections=3, axis=1)
        out = splitted[0]
        f0 = relay.Function([p0], out)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p01 = relay.var("p01", shape=(1, dim))
        p1 = relay.var("p1", shape=(dim, dim))
        out = relay.nn.dense(p01, p1)
        f1 = relay.Function([p01, p1], out)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        y = relay.Call(f0, [X])
        z = relay.Call(f1, [y, W])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dim), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)


# Module-level passes shared by the tuple/inception tests below.
fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
fuse2 = relay.transform.FuseOps(fuse_opt_level=2)


def test_tuple_intermediate():
    """Test that a concatenate in the middle of an injective chain is fused.

    The whole function is expected to become a single primitive function.
    """
    def before(x):
        inj = relay.squeeze(x)
        y1 = relay.add(inj, relay.const(1, "float32"))
        tmp = relay.squeeze(inj)
        tmp = relay.add(tmp, relay.const(1, "float32"))
        y2 = relay.add(tmp, relay.const(1, "float32"))
        y3 = relay.add(inj, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out_inj = relay.squeeze(concat)
        out = relay.add(out_inj, relay.const(1, "float32"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(p0):
        f0 = before(p0)
        f1 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    orig = before(x)
    # Opt level 0 is only checked to run; its result is not compared.
    fuse0(relay.Module.from_expr(orig))
    m = fuse2(relay.Module.from_expr(orig))
    relay.build(m, 'llvm')
    after = run_opt_pass(expected(x), transform.InferType())
    assert relay.analysis.alpha_equal(m["main"], after)


def test_tuple_consecutive():
    """Test fusion of nested concatenates that feed a max_pool2d."""
    def gen_intermediate_tuple(x):
        y1 = relay.add(x, relay.const(1, "float32"))
        y2 = relay.add(x, relay.const(1, "float32"))
        y3 = relay.add(x, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        return out

    def gen_consecutive_tuple(x):
        y1 = gen_intermediate_tuple(x)
        y2 = gen_intermediate_tuple(x)
        y3 = gen_intermediate_tuple(x)
        concat = relay.concatenate((y1, y2, y3), axis=1)
        return concat

    def before(x):
        concat = gen_consecutive_tuple(x)
        pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        out = relay.add(pooled, relay.const(1, "float32"))
        out2 = relay.add(out, relay.const(1, "float32"))
        out_tup = relay.Tuple((out, out2))
        return relay.Function(relay.analysis.free_vars(out_tup), out_tup)

    def expected(dshape):
        p0 = relay.var("p0", shape=dshape)
        concat = gen_consecutive_tuple(p0)
        f0 = relay.Function([p0], concat)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p01 = relay.var("p01", shape=(1, dshape[1] * 9, dshape[2], dshape[3]))
        pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        out = relay.add(pooled, relay.const(1, "float32"))
        f1 = relay.Function([p01], out)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p02 = relay.var("p02", shape=(1, dshape[1] * 9, dshape[2] // 2, dshape[3] // 2))
        out = relay.add(p02, relay.const(1, "float32"))
        f2 = relay.Function([p02], out)
        f2 = f2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y])
        z2 = relay.Call(f2, [z])

        return relay.Function([x], relay.Tuple((z, z2)))

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    orig = before(x)
    # Opt level 0 is only checked to run; its result is not compared.
    fuse0(relay.Module.from_expr(orig))
    m = fuse2(relay.Module.from_expr(orig))
    relay.build(m, 'llvm')
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(m["main"], after)


def test_inception_like():
    """Test fusion on parallel conv+relu branches joined by concatenate."""
    def conv(data):
        y = relay.nn.conv2d(data, relay.var("w"),
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        return relay.nn.relu(data=y)

    def inception_like(data):
        c0 = conv(data)
        c1 = conv(data)
        return relay.concatenate((c0, c1), axis=1)

    def before(dshape):
        x = relay.var("x", shape=dshape)
        in1 = inception_like(x)
        in2 = inception_like(in1)
        return relay.Function(relay.analysis.free_vars(in2), in2)

    def expected(dshape):
        p0 = relay.var("p0", shape=dshape)
        c = conv(p0)
        f0 = relay.Function(relay.analysis.free_vars(c), c)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p01 = relay.var("p01", shape=dshape)
        c = conv(p01)
        f1 = relay.Function(relay.analysis.free_vars(c), c)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p02 = relay.var("p02", shape=dshape)
        p12 = relay.var("p12", shape=dshape)
        concat1 = relay.concatenate((p02, p12), axis=1)
        f_concat1 = relay.Function([p02, p12], concat1)
        f_concat1 = f_concat1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        dshape2 = (dshape[0], dshape[1] * 2, dshape[2], dshape[3])

        p03 = relay.var("p03", shape=dshape2)
        c = conv(p03)
        f2 = relay.Function(relay.analysis.free_vars(c), c)
        f2 = f2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p04 = relay.var("p04", shape=dshape2)
        c = conv(p04)
        f3 = relay.Function(relay.analysis.free_vars(c), c)
        f3 = f3.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # conv always emits 16 channels with preserved spatial size, so the
        # second-stage branch outputs have shape dshape again.
        p05 = relay.var("p05", shape=dshape)
        p15 = relay.var("p15", shape=dshape)
        concat2 = relay.concatenate((p05, p15), axis=1)
        f_concat2 = relay.Function([p05, p15], concat2)
        f_concat2 = f_concat2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        c1 = relay.Call(f0, [x, relay.var("w1")])
        c2 = relay.Call(f1, [x, relay.var("w2")])
        concat = relay.Call(f_concat1, [c1, c2])
        c3 = relay.Call(f2, [concat, relay.var("w3")])
        c4 = relay.Call(f3, [concat, relay.var("w4")])
        out = relay.Call(f_concat2, [c3, c4])

        return relay.Function(relay.analysis.free_vars(out), out)

    dshape = (1, 16, 64, 64)
    orig = before(dshape)
    # Opt level 0 is only checked to run; its result is not compared.
    fuse0(relay.Module.from_expr(orig))
    m = fuse2(relay.Module.from_expr(orig))
    relay.build(m, 'llvm')
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(m["main"], after)


def test_fuse_parallel_injective():
    """Test fusing parallel injective ops to an elemwise op."""
    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.squeeze(y)
        u = relay.transpose(y, axes=[0, 1])
        w = relay.left_shift(z, u)
        return relay.Function([x], w)

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.squeeze(y)
        u = relay.transpose(y, axes=[0, 1])
        w = relay.left_shift(z, u)
        f1 = relay.Function([x], w)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    z = before()
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)


def test_immutable():
    """Verify the fusion pass won't change original module."""
    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        mod = relay.module.Module()
        mod["main"] = relay.Function([x], w)
        return mod

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        f1 = relay.Function([x], w)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        mod = relay.module.Module()
        mod["main"] = relay.Function([x], y)
        return mod

    mod = before()
    new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
    assert relay.analysis.alpha_equal(mod, before())
    assert relay.analysis.alpha_equal(new_mod, expected())


def test_split():
    """Test that the result is well formed."""
    x = relay.var("x", shape=(6, 9))
    y = relay.split(x, 3).astuple()
    a = relay.TupleGetItem(y, 0)
    b = relay.TupleGetItem(y, 1)
    c = relay.TupleGetItem(y, 2)
    mod = relay.module.Module()
    mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
    mod = transform.FuseOps()(mod)
    # Same well-formedness check the other tests apply to fused output:
    # fusion must not leave any unbound variable behind.
    assert not relay.analysis.free_vars(mod["main"])


def test_fuse_max():
    """Test the constraint of number of nodes in op fusion."""
    # Presumably mirrors the maximum group size hard-coded in the FuseOps
    # pass; expected() splits the chain into groups of this size.
    max_fused_ops = 256
    # n is the number of nodes to be fused, should be less than 2*max_fused_ops
    # so that the chain splits into exactly the two groups built in expected().
    n = 300

    def before():
        x = relay.var("x", shape=(10, 20))
        y = x
        for _ in range(n):
            y = relay.exp(y)
        return relay.Function([x], y)

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = x
        for _ in range(max_fused_ops):
            y = relay.exp(y)
        f1 = relay.Function([x], y)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        z = relay.Call(f1, [x])
        xx = relay.var("pp", shape=(10, 20))
        yy = xx
        for _ in range(n - max_fused_ops):
            yy = relay.exp(yy)
        f2 = relay.Function([xx], yy)
        f2 = f2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        zz = relay.Call(f2, [z])
        return relay.Function([x], zz)

    z = before()
    after = run_opt_pass(expected(), transform.InferType())
    # Check both an explicit fuse_opt_level=2 and the default FuseOps();
    # each result is asserted.
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert relay.analysis.alpha_equal(zz, after)
    zz = run_opt_pass(z, transform.FuseOps())
    assert relay.analysis.alpha_equal(zz, after)


if __name__ == "__main__":
    test_fuse_simple()
    test_conv2d_fuse()
    test_concatenate()
    test_tuple_root()
    test_stop_fusion()
    test_fuse_myia_regression()
    test_fuse_tuple_get_elemwise()
    test_tuple_get_root()
    test_tuple_intermediate()
    test_tuple_consecutive()
    test_inception_like()
    test_fuse_parallel_injective()
    test_immutable()
    test_split()
    test_fuse_max()