# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for merge composite."""
import pytest
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard
from tvm.relay.testing import run_opt_pass


"""
The merge composite pass is designed to merge multiple relay operators that
match a given pattern into a single relay function.

For example, suppose we have the graph:

    conv2d
      |          (merge composite pass)
   bias_add      ====>      conv2d_bias_relu
      |                       (our target)
     relu

Our Relay IR before the pass:
    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
            /* ty=Tensor[(1, 256, 28, 28), float32] */;
        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
    }

Our Relay IR after the pass:
    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
        %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
                %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
                Tensor[(1, 256, 28, 28), float32] {
            %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
            %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
            nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
        };
        %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
    }

As the second example shows, the matched pattern has been wrapped in a function.
The function is then called, producing the same result as the first example.

One convenient use for this pass is to offload multiple operators to a single
external codegen function.
"""
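

# A minimal usage sketch (illustrative only; the helper name and `func` are
# hypothetical). The pass runs over an IRModule, and the pattern table is a
# list of ("name", pattern) pairs, optionally extended with a check function
# (see test_pattern_with_check below).
def _example_run_merge_composite(func, pattern_table):
    """Apply MergeComposite to a standalone function via an IRModule (sketch)."""
    mod = tvm.IRModule.from_expr(func)
    return relay.transform.MergeComposite(pattern_table)(mod)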


def make_add_sub_mul_pattern():
    r"""Create a pattern to match the following graph.

     add  sub
      \   /
       \ /
       mul
    """
    x = wildcard()
    y = wildcard()
    return (x + y) * (x - y)


def make_add_relu_pattern():
    r"""Create a pattern to match the following graph.

     add
      |
     relu
    """
    add_node = wildcard() + wildcard()
    r = is_op("nn.relu")(add_node)
    return r


def make_conv_bias_relu_pattern():
    r"""Create a pattern to match the following graph.

     conv2d
       |
    bias_add
       |
      relu
    """
    x = wildcard()
    y = wildcard()
    z = wildcard()
    conv_node = is_op("nn.conv2d")(x, y)
    bias_node = is_op("nn.bias_add")(conv_node, z)
    r = is_op("nn.relu")(bias_node)
    return r


def make_pattern_with_optional():
    r"""Create a pattern to match the following graph. Note that relu is optional.

     conv2d
       |
    bias_add
       |
     (relu)
    """
    x = wildcard()
    y = wildcard()
    z = wildcard()
    conv_node = is_op("nn.conv2d")(x, y)
    bias_node = is_op("nn.bias_add")(conv_node, z)
    r = bias_node.optional(lambda x: is_op("nn.relu")(x))
    return r


def make_add_add_add_pattern():
    r"""Create a pattern to match the following graph.
    Useful for testing re-using a call node.

       x    y
      / \  /
      |  add
       \  |  \
       add    |
         |   /
        add
    """
    x = wildcard()
    y = wildcard()
    add_node = is_op("add")(x, y)
    add_node_1 = is_op("add")(x, add_node)
    r = is_op("add")(add_node_1, add_node)
    return r


def make_bn_relu_pattern():
    r"""Create a pattern to match the following graph.

     batch_norm
         |
    TupleGetItem(0)
         |
        relu
    """
    x = wildcard()
    gamma = wildcard()
    beta = wildcard()
    moving_mean = wildcard()
    moving_var = wildcard()
    bn_node = is_op("nn.batch_norm")(x, gamma, beta, moving_mean, moving_var)
    tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
    r = is_op("nn.relu")(tuple_get_item_node)
    return r
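

# The helpers above are built with the Relay dataflow pattern DSL. As a side
# note, a pattern can also be matched against an expression directly,
# independently of the merge composite pass (a small sketch):
#
#   pat = make_add_relu_pattern()
#   a, b = relay.var("a"), relay.var("b")
#   assert pat.match(relay.nn.relu(relay.add(a, b)))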


def check_result(pattern_table, graph, expected_graph, import_prelude=False):
    """Utility function to check merge composite results."""
    result = run_opt_pass(
        graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude
    )
    assert not relay.analysis.free_vars(result), "Found free vars in the result graph: {0}".format(
        str(result)
    )
    expected = run_opt_pass(expected_graph, relay.transform.InferType())
    assert tvm.ir.structural_equal(
        result, expected, map_free_vars=True
    ), "Graph mismatch: output vs. expected\n{0}\n=====\n{1}".format(str(result), str(expected))
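

# Note: `map_free_vars=True` above lets structural equality unify free
# variables positionally rather than by name, which is why the expected()
# graphs in the tests below can use arbitrary parameter names. A tiny sketch:
#
#   assert tvm.ir.structural_equal(relay.var("a"), relay.var("b"), map_free_vars=True)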


def test_simple_merge():
    r"""Test composite function is correctly produced from a simple graph.

    We would expect the pattern `make_add_relu_pattern` to be merged
    into a single op `add_relu`.

        a  b
        \ /               a  b
        add    ====>      \ /
         |              add_relu
        relu
    """
    pattern_table = [("add_relu", make_add_relu_pattern())]

    def before():
        a = relay.var("a", shape=(10, 10))
        b = relay.var("b", shape=(10, 10))
        add_node = relay.add(a, b)
        r = relay.nn.relu(add_node)
        return relay.Function([a, b], r)

    def expected():
        a = relay.var("a", shape=(10, 10))
        b = relay.var("b", shape=(10, 10))

        # add_relu function
        in_1 = relay.var("in_1", shape=(10, 10))
        in_2 = relay.var("in_2", shape=(10, 10))
        add_node = relay.add(in_1, in_2)
        relu_node = relay.nn.relu(add_node)
        add_relu = relay.Function([in_1, in_2], relu_node)
        add_relu = add_relu.with_attr("Composite", "add_relu")
        add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_")

        # merged function
        r = relay.Call(add_relu, [a, b])
        return relay.Function([a, b], r)

    check_result(pattern_table, before(), expected())


def test_branch_merge():
    r"""Test composite function is correctly produced from a branching graph.

    We would expect the pattern `make_add_sub_mul_pattern` to be merged
    into a single op `add_sub_mul`.

       a  b  a  b
        \/    \/
        add  sub                   a  b
         \   /                      \/
          \ /                  add_sub_mul
          mul                 c     |
         /   \                 \    |
      c /  c  |     ====>      add_sub_mul
        \/   \/                     |
        add  sub                    |
         \   /                     relu
          \ /
          mul
           |
           |
          relu
    """

    pattern_table = [("add_sub_mul", make_add_sub_mul_pattern())]

    def before():
        a = relay.var("a", shape=(10, 10))
        b = relay.var("b", shape=(10, 10))
        c = relay.var("c", shape=(10, 10))
        add_node = relay.add(a, b)
        sub_node = relay.subtract(a, b)
        mul_node = relay.multiply(add_node, sub_node)
        add_node_2 = relay.add(c, mul_node)
        sub_node_2 = relay.subtract(c, mul_node)
        mul_node_2 = relay.multiply(add_node_2, sub_node_2)
        r = relay.nn.relu(mul_node_2)
        return relay.Function([a, b, c], r)

    def expected():
        a = relay.var("a", shape=(10, 10))
        b = relay.var("b", shape=(10, 10))
        c = relay.var("c", shape=(10, 10))

        # add_sub_mul function
        in_1 = relay.var("in_1", shape=(10, 10))
        in_2 = relay.var("in_2", shape=(10, 10))
        add_node = relay.add(in_1, in_2)
        sub_node = relay.subtract(in_1, in_2)
        mul_node = relay.multiply(add_node, sub_node)
        add_sub_mul = relay.Function([in_1, in_2], mul_node)
        add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul")
        add_sub_mul = add_sub_mul.with_attr("PartitionedFromPattern", "add_subtract_multiply_")

        # add_sub_mul1 function
        in_3 = relay.var("in_3", shape=(10, 10))
        in_4 = relay.var("in_4", shape=(10, 10))
        add_node_1 = relay.add(in_3, in_4)
        sub_node_1 = relay.subtract(in_3, in_4)
        mul_node_1 = relay.multiply(add_node_1, sub_node_1)
        add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
        add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul")
        add_sub_mul_1 = add_sub_mul_1.with_attr("PartitionedFromPattern", "add_subtract_multiply_")

        # merged function
        m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
        m_add_sub_mul_2 = relay.Call(add_sub_mul_1, [c, m_add_sub_mul_1])
        r = relay.nn.relu(m_add_sub_mul_2)
        return relay.Function([a, b, c], r)

    check_result(pattern_table, before(), expected())


def test_reuse_call_merge():
    r"""Test composite function is correctly produced from a simple graph
    which re-uses call nodes.

    We would expect the pattern `make_add_add_add_pattern` to be merged
    into a single op `add_add_add`.

        x  y
        \ / \
        sub  |                x  y
       / |  /                  \ / |
      |  add       ====>       sub |
       \ |  \                   | /
       add   |            add_add_add
         |  /
        add
    """
    pattern_table = [("add_add_add", make_add_add_add_pattern())]

    def before():
        a = relay.var("a", shape=(10, 10))
        b = relay.var("b", shape=(10, 10))
        sub_node = relay.subtract(a, b)

        # pattern
        add_node = relay.add(sub_node, b)
        add_node_1 = relay.add(sub_node, add_node)
        r = relay.add(add_node_1, add_node)

        return relay.Function([a, b], r)

    def expected():
        a = relay.var("a", shape=(10, 10))
        b = relay.var("b", shape=(10, 10))

        # add_add_add function
        in_1 = relay.var("in_1", shape=(10, 10))
        in_2 = relay.var("in_2", shape=(10, 10))
        add_node = relay.add(in_1, in_2)
        add_node_1 = relay.add(in_1, add_node)
        add_node_2 = relay.add(add_node_1, add_node)
        add_add_add = relay.Function([in_1, in_2], add_node_2)
        add_add_add = add_add_add.with_attr("Composite", "add_add_add")
        add_add_add = add_add_add.with_attr("PartitionedFromPattern", "add_add_add_")

        # merged function
        sub_node = relay.subtract(a, b)
        call = relay.Call(add_add_add, [sub_node, b])
        return relay.Function([a, b], call)

    check_result(pattern_table, before(), expected())


def test_multiple_patterns():
    r"""Test different patterns are merged correctly in the graph.

    We would expect the pattern `make_conv_bias_relu_pattern` to be merged
    into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern`
    to be merged into a single op `add_relu`.

        data  kernel
          \    /
           \  /
          conv2d                  data  kernel  bias
            |   \                    \     |    /
            |   bias              conv2d_bias_relu
            |   /                        |
         bias_add      ====>             |   a
            |                            |  /
           relu   a                   add_relu
             \   /                       |
             add                         |  b
              |                          | /
             relu  b                    mul
               |  /
              mul
    """
    pattern_table = [
        ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
        ("add_relu", make_add_relu_pattern()),
    ]

    def before():
        data = relay.var("data", shape=(1, 512, 28, 28))
        kernel = relay.var("kernel", shape=(256, 512, 1, 1))
        bias = relay.var("bias", shape=(256,))
        a = relay.var("a", shape=(1, 256, 28, 28))
        b = relay.var("b", shape=(1, 256, 28, 28))

        conv_node = relay.nn.conv2d(
            data, kernel, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1)
        )

        bias_node = relay.nn.bias_add(conv_node, bias)
        relu_node = relay.nn.relu(bias_node)
        add_node = relay.add(relu_node, a)
        relu_node_2 = relay.nn.relu(add_node)
        r = relay.multiply(relu_node_2, b)
        return relay.Function([data, kernel, bias, a, b], r)

    def expected():
        data = relay.var("data", shape=(1, 512, 28, 28))
        kernel = relay.var("kernel", shape=(256, 512, 1, 1))
        bias = relay.var("bias", shape=(256,))
        a = relay.var("a", shape=(1, 256, 28, 28))
        b = relay.var("b", shape=(1, 256, 28, 28))

        # conv_bias_relu function
        in_1 = relay.var("in_1", shape=(1, 512, 28, 28))
        in_2 = relay.var("in_2", shape=(256, 512, 1, 1))
        in_3 = relay.var("in_3", shape=(256,))

        conv_node = relay.nn.conv2d(in_1, in_2, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1))

        bias_node = relay.nn.bias_add(conv_node, in_3)
        r = relay.nn.relu(bias_node)
        conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
        conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite", "conv2d_bias_relu")
        conv_bias_add_relu = conv_bias_add_relu.with_attr(
            "PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_"
        )

        # add_relu function
        in_4 = relay.var("in_4", shape=(1, 256, 28, 28))
        in_5 = relay.var("in_5", shape=(1, 256, 28, 28))
        add_node = relay.add(in_4, in_5)
        r = relay.nn.relu(add_node)
        add_relu = relay.Function([in_4, in_5], r)
        add_relu = add_relu.with_attr("Composite", "add_relu")
        add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_")

        # merged function
        conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
        add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
        r = relay.multiply(add_relu_1, b)
        return relay.Function([data, kernel, bias, a, b], r)

    check_result(pattern_table, before(), expected())


def test_optional_pattern():
    r"""Test a pattern with optional operators. We can define a pattern where some
    operators are optional. The merge composite pass will create composite functions
    for all matched patterns, but with different "PartitionedFromPattern" attributes.
    We expect the backend codegen to analyze that attribute and determine the
    corresponding action.

    Pattern:       Matched Case A:    Matched Case B:

     conv2d            conv2d             conv2d
       |                 |                  |
    bias_add          bias_add           bias_add
       |                 |
     (relu)             relu

    In the above example, the composite function for matched case A would have
    PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_" while the one for
    matched case B would be "nn.conv2d_nn.bias_add_".
    """
    pattern_table = [("layer", make_pattern_with_optional())]

    def before():
        x = relay.var("x", shape=(1, 3, 7, 7))
        w1 = relay.var("w", shape=(3, 3, 1, 1))
        b1 = relay.var("b", shape=(3,))
        w2 = relay.var("w", shape=(3, 3, 1, 1))
        b2 = relay.var("b", shape=(3,))
        conv = relay.nn.conv2d(x, w1, kernel_size=(1, 1))
        bias = relay.nn.bias_add(conv, b1)
        relu = relay.nn.relu(bias)
        conv = relay.nn.conv2d(relu, w2, kernel_size=(1, 1))
        bias = relay.nn.bias_add(conv, b2)
        return relay.Function([x, w1, w2, b1, b2], bias)

    def expected():
        # Matched composite function A
        x = relay.var("x")
        w = relay.var("w")
        b = relay.var("b")
        conv = relay.nn.conv2d(x, w, kernel_size=(1, 1))
        bias = relay.nn.bias_add(conv, b)
        relu = relay.nn.relu(bias)
        func1 = relay.Function([x, w, b], relu)
        func1 = func1.with_attr("Composite", "layer")
        func1 = func1.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")

        # Matched composite function B
        x = relay.var("x")
        w = relay.var("w")
        b = relay.var("b")
        conv = relay.nn.conv2d(x, w, kernel_size=(1, 1))
        bias = relay.nn.bias_add(conv, b)
        func2 = relay.Function([x, w, b], bias)
        func2 = func2.with_attr("Composite", "layer")
        func2 = func2.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_")

        # Main function
        x = relay.var("x", shape=(1, 3, 7, 7))
        w1 = relay.var("w", shape=(3, 3, 1, 1))
        b1 = relay.var("b", shape=(3,))
        w2 = relay.var("w", shape=(3, 3, 1, 1))
        b2 = relay.var("b", shape=(3,))
        out1 = func1(x, w1, b1)
        out2 = func2(out1, w2, b2)
        return relay.Function([x, w1, w2, b1, b2], out2)

    check_result(pattern_table, before(), expected())
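

# A hedged sketch of how a backend might act on that attribute (the dispatch
# logic and the `composite_func` variable are hypothetical, not part of these
# tests):
#
#   pattern = composite_func.attrs["PartitionedFromPattern"]
#   fuse_relu = pattern == "nn.conv2d_nn.bias_add_nn.relu_"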


def test_merge_order():
    r"""Test that patterns are merged in the order they exist in the pattern table.

    There can be cases where one pattern is a subgraph of another, in which case
    it is not clear which match should take priority. The priority should come
    from the order in which the patterns are declared in the pattern table:
    patterns declared earlier are merged with higher priority than those
    declared later.

     A:       B:      C:
     add      add     abs
      |        |       |
     abs      abs     relu
      |
     relu
    """

    def pattern_A():
        x = wildcard()
        y = wildcard()
        out = is_op("add")(x, y)
        out = is_op("abs")(out)
        out = is_op("nn.relu")(out)
        return out

    def pattern_B():
        x = wildcard()
        y = wildcard()
        out = is_op("add")(x, y)
        out = is_op("abs")(out)
        return out

    def pattern_C():
        x = wildcard()
        out = is_op("abs")(x)
        out = is_op("nn.relu")(out)
        return out

    def before():
        input_1 = relay.var("input_1", shape=(10, 10))
        input_2 = relay.var("input_2", shape=(10, 10))
        out = relay.add(input_1, input_2)
        out = relay.abs(out)
        out = relay.nn.relu(out)
        return relay.Function([input_1, input_2], out)

    def after_A_priority():
        input_1 = relay.var("input_1", shape=(10, 10))
        input_2 = relay.var("input_2", shape=(10, 10))
        x = relay.var("x")
        y = relay.var("y")
        out = relay.add(x, y)
        out = relay.abs(out)
        out = relay.nn.relu(out)
        merged_func = relay.Function([x, y], out)
        merged_func = merged_func.with_attr("Composite", "A")
        merged_func = merged_func.with_attr("PartitionedFromPattern", "add_abs_nn.relu_")
        ret = relay.Call(merged_func, [input_1, input_2])
        return relay.Function([input_1, input_2], ret)

    def after_B_priority():
        input_1 = relay.var("input_1", shape=(10, 10))
        input_2 = relay.var("input_2", shape=(10, 10))
        x = relay.var("x")
        y = relay.var("y")
        out = relay.add(x, y)
        out = relay.abs(out)
        merged_func = relay.Function([x, y], out)
        merged_func = merged_func.with_attr("Composite", "B")
        merged_func = merged_func.with_attr("PartitionedFromPattern", "add_abs_")
        out = relay.Call(merged_func, [input_1, input_2])
        ret = relay.nn.relu(out)
        return relay.Function([input_1, input_2], ret)

    def after_C_priority():
        input_1 = relay.var("input_1", shape=(10, 10))
        input_2 = relay.var("input_2", shape=(10, 10))
        x = relay.var("x")
        out = relay.abs(x)
        out = relay.nn.relu(out)
        merged_func = relay.Function([x], out)
        merged_func = merged_func.with_attr("Composite", "C")
        merged_func = merged_func.with_attr("PartitionedFromPattern", "abs_nn.relu_")
        out = relay.add(input_1, input_2)
        ret = relay.Call(merged_func, [out])
        return relay.Function([input_1, input_2], ret)

    # check A highest priority
    pattern_table = [
        ("A", pattern_A()),
        ("B", pattern_B()),
        ("C", pattern_C()),
    ]
    check_result(pattern_table, before(), after_A_priority())

    # check B highest priority
    pattern_table = [
        ("B", pattern_B()),
        ("C", pattern_C()),
        ("A", pattern_A()),
    ]
    check_result(pattern_table, before(), after_B_priority())

    # check C highest priority
    pattern_table = [
        ("C", pattern_C()),
        ("A", pattern_A()),
        ("B", pattern_B()),
    ]
    check_result(pattern_table, before(), after_C_priority())


def test_parallel_merge():
    r"""Test that parallel patterns relying on the same inputs are correctly merged.

    The test graph is difficult to draw out as ASCII art. It is essentially two
    parallel add-sub-mul units which both consume input_1 and input_2, with their
    results being multiplied to give the output. We expect both parallel branches
    to be merged, and both should still consume the same input variables, input_1
    and input_2."""

    def before():
        input_1 = relay.var("input_1", shape=(10, 10))
        input_2 = relay.var("input_2", shape=(10, 10))
        branch_1_add = relay.add(input_1, input_2)
        branch_1_sub = relay.subtract(input_1, input_2)
        branch_1 = relay.multiply(branch_1_add, branch_1_sub)
        branch_2_add = relay.add(input_1, input_2)
        branch_2_sub = relay.subtract(input_1, input_2)
        branch_2 = relay.multiply(branch_2_add, branch_2_sub)
        out = relay.multiply(branch_1, branch_2)
        return relay.Function([input_1, input_2], out)

    def expected():
        input_1 = relay.var("input_1", shape=(10, 10))
        input_2 = relay.var("input_2", shape=(10, 10))
        x = relay.var("x")
        y = relay.var("y")
        branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
        func_1 = relay.Function([x, y], branch_1)
        func_1 = func_1.with_attr("Composite", "add_sub_mul")
        func_1 = func_1.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
        call_1 = relay.Call(func_1, [input_1, input_2])
        x1 = relay.var("x1")
        y1 = relay.var("y1")
        branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
        func_2 = relay.Function([x1, y1], branch_2)
        func_2 = func_2.with_attr("Composite", "add_sub_mul")
        func_2 = func_2.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
        call_2 = relay.Call(func_2, [input_1, input_2])
        out = relay.multiply(call_1, call_2)
        return relay.Function([input_1, input_2], out)

    pattern_table = [("add_sub_mul", make_add_sub_mul_pattern())]
    check_result(pattern_table, before(), expected())
699 """ 700 701 def before(): 702 before_funcs = {} 703 inputs = [relay.var("input_" + str(i), shape=(10, 10)) for i in range(8)] 704 add_relu_1 = relay.add(inputs[0], inputs[1]) 705 add_relu_1 = relay.nn.relu(add_relu_1) 706 add_relu_2 = relay.add(inputs[2], inputs[3]) 707 add_relu_2 = relay.nn.relu(add_relu_2) 708 add_relu_3 = relay.add(inputs[4], inputs[5]) 709 add_relu_3 = relay.nn.relu(add_relu_3) 710 add_relu_4 = relay.add(inputs[6], inputs[7]) 711 add_relu_4 = relay.nn.relu(add_relu_4) 712 add = relay.add(add_relu_1, add_relu_2) 713 sub = relay.subtract(add_relu_3, add_relu_4) 714 out = relay.multiply(add, sub) 715 before_funcs["B"] = relay.Function(inputs, out) 716 sub = relay.subtract(add_relu_1, add_relu_2) 717 out = relay.multiply(add, sub) 718 before_funcs["A"] = relay.Function(inputs[:4], out) 719 return before_funcs 720 721 def after_A(): 722 inputs = [relay.var("input_" + str(i), shape=(10, 10)) for i in range(4)] 723 x = relay.var("x") 724 y = relay.var("y") 725 add_relu_1 = relay.add(x, y) 726 add_relu_1 = relay.nn.relu(add_relu_1) 727 add_relu_1 = relay.Function([x, y], add_relu_1) 728 add_relu_1 = add_relu_1.with_attr("Composite", "add_relu") 729 add_relu_1 = add_relu_1.with_attr("PartitionedFromPattern", "add_nn.relu_") 730 add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]]) 731 x1 = relay.var("x1") 732 y1 = relay.var("y1") 733 add_relu_2 = relay.add(x1, y1) 734 add_relu_2 = relay.nn.relu(add_relu_2) 735 add_relu_2 = relay.Function([x1, y1], add_relu_2) 736 add_relu_2 = add_relu_2.with_attr("Composite", "add_relu") 737 add_relu_2 = add_relu_2.with_attr("PartitionedFromPattern", "add_nn.relu_") 738 add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]]) 739 x2 = relay.var("x2") 740 y2 = relay.var("y2") 741 add = relay.add(x2, y2) 742 sub = relay.subtract(x2, y2) 743 add_sub_mul = relay.multiply(add, sub) 744 add_sub_mul = relay.Function([x2, y2], add_sub_mul) 745 add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul") 746 add_sub_mul = add_sub_mul.with_attr("PartitionedFromPattern", "add_subtract_multiply_") 747 add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2]) 748 return relay.Function(inputs, add_sub_mul_call) 749 750 def after_B(): 751 inputs = [relay.var("input_" + str(i), shape=(10, 10)) for i in range(8)] 752 add_relu_calls = [] 753 for i in range(4): 754 x = relay.var("x" + str(i)) 755 y = relay.var("x" + str(i)) 756 add_relu = relay.add(x, y) 757 add_relu = relay.nn.relu(add_relu) 758 add_relu = relay.Function([x, y], add_relu) 759 add_relu = add_relu.with_attr("Composite", "add_relu") 760 add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_") 761 add_relu_call = relay.Call(add_relu, [inputs[i * 2], inputs[i * 2 + 1]]) 762 add_relu_calls.append(add_relu_call) 763 764 add = relay.add(add_relu_calls[0], add_relu_calls[1]) 765 sub = relay.subtract(add_relu_calls[2], add_relu_calls[3]) 766 out = relay.multiply(add, sub) 767 return relay.Function(inputs, out) 768 769 pattern_table = [ 770 ("add_sub_mul", make_add_sub_mul_pattern()), 771 ("add_relu", make_add_relu_pattern()), 772 ] 773 check_result(pattern_table, before()["A"], after_A()) 774 check_result(pattern_table, before()["B"], after_B()) 775 776 777def test_tuple_get_item_merge(): 778 """Test composite function can be merged from pattern containing TupleGetItem nodes.""" 779 pattern_table = [("bn_relu", make_bn_relu_pattern())] 780 781 def before(): 782 x = relay.var("x", shape=(1, 8)) 783 gamma = relay.var("gamma", shape=(8,)) 784 beta 
= relay.var("beta", shape=(8,)) 785 moving_mean = relay.var("moving_mean", shape=(8,)) 786 moving_var = relay.var("moving_var", shape=(8,)) 787 bn_node = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var) 788 tuple_get_item_node = bn_node[0] 789 r = relay.nn.relu(tuple_get_item_node) 790 return relay.Function([x, gamma, beta, moving_mean, moving_var], r) 791 792 def expected(): 793 x = relay.var("x", shape=(1, 8)) 794 beta = relay.var("beta", shape=(8,)) 795 gamma = relay.var("gamma", shape=(8,)) 796 moving_mean = relay.var("moving_mean", shape=(8,)) 797 moving_var = relay.var("moving_var", shape=(8,)) 798 799 # bn_relu function 800 in_1 = relay.var("x1", shape=(1, 8)) 801 in_2 = relay.var("gamma1", shape=(8,)) 802 in_3 = relay.var("beta1", shape=(8,)) 803 in_4 = relay.var("moving_mean1", shape=(8,)) 804 in_5 = relay.var("moving_var1", shape=(8,)) 805 bn_node = relay.nn.batch_norm(in_1, in_2, in_3, in_4, in_5) 806 tuple_get_item_node = bn_node[0] 807 relu_node = relay.nn.relu(tuple_get_item_node) 808 bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node) 809 bn_relu = bn_relu.with_attr("Composite", "bn_relu") 810 bn_relu = bn_relu.with_attr( 811 "PartitionedFromPattern", "nn.batch_norm_TupleGetItem0_nn.relu_" 812 ) 813 814 # merged function 815 r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var]) 816 return relay.Function([x, gamma, beta, moving_mean, moving_var], r) 817 818 check_result(pattern_table, before(), expected()) 819 820 821def test_pattern_with_check(): 822 def before(): 823 x = relay.var("x", shape=(1, 10, 10, 10)) 824 w = relay.var("w", shape=(10, 10, 3, 3)) 825 b = relay.var("b", shape=(8,)) 826 conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC") 827 bias = relay.nn.bias_add(conv, b) 828 relu = relay.nn.relu(bias) 829 return relay.Function([x, w, b], relu) 830 831 def _check_true(extract): 832 conv = extract.args[0].args[0] 833 return conv.attrs.data_layout == "NHWC" 834 835 def _check_false(extract): 836 conv = extract.args[0].args[0] 837 return conv.attrs.data_layout == "NCHW" 838 839 def expected(): 840 x = relay.var("x") 841 w = relay.var("w") 842 b = relay.var("b") 843 conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC") 844 bias = relay.nn.bias_add(conv, b) 845 relu = relay.nn.relu(bias) 846 func = relay.Function([x, w, b], relu) 847 func = func.with_attr("Composite", "conv_bias_relu") 848 func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") 849 850 x = relay.var("x", shape=(1, 10, 10, 10)) 851 w = relay.var("w", shape=(10, 10, 3, 3)) 852 b = relay.var("b", shape=(8,)) 853 return relay.Function([x, w, b], func(x, w, b)) 854 855 pattern_table_false = [("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false)] 856 check_result(pattern_table_false, before(), before()) 857 858 pattern_table_true = [("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true)] 859 check_result(pattern_table_true, before(), expected()) 860 861 862def test_diamond_not_merge(): 863 r""" 864 The pattern on the left shouldn't match the structure on the right 865 866 relu relu 867 | \ | \ 868 | clip | add 869 | / | | 870 mul | clip 871 | / 872 mul 873 """ 874 875 def get_pattern(): 876 conv = make_conv_bias_relu_pattern() 877 clip = is_op("clip")(conv, wildcard(), wildcard()) 878 return is_op("multiply")(conv, clip) 879 880 def get_net(): 881 data = relay.var("data", shape=(1, 512, 28, 28)) 882 kernel = relay.var("kernel", shape=(256, 512, 1, 1)) 


def test_diamond_not_merge():
    r"""
    The pattern on the left shouldn't match the structure on the right.

    relu             relu
     |  \             |  \
     |  clip          |  add
     |  /             |   |
    mul               |  clip
                      |  /
                      mul
    """

    def get_pattern():
        conv = make_conv_bias_relu_pattern()
        clip = is_op("clip")(conv, wildcard(), wildcard())
        return is_op("multiply")(conv, clip)

    def get_net():
        data = relay.var("data", shape=(1, 512, 28, 28))
        kernel = relay.var("kernel", shape=(256, 512, 1, 1))
        conv = relay.nn.conv2d(data, kernel, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1))
        bias = relay.nn.bias_add(conv, relay.var("bias", shape=(256,)))
        relu = relay.nn.relu(bias)
        add = relay.op.add(relu, relay.const(1.0))
        clip2 = relay.op.clip(add, 0, 255)
        mul = relay.op.multiply(relu, clip2)
        return relay.Function(relay.analysis.free_vars(mul), mul)

    pattern_table = [("pat", get_pattern())]
    net = get_net()
    check_result(pattern_table, net, net)


def test_type_check():
    """Test that we can query tensor types in the 'check' function."""

    def before():
        x = relay.var("x", shape=(1, 10, 10, 10))
        w = relay.var("w", shape=(10, 10, 3, 3))
        b = relay.var("b", shape=(8,))
        add = relay.op.add(x, x)
        relu = relay.nn.relu(add)
        conv = relay.nn.conv2d(
            relu, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC"
        )
        bias = relay.nn.bias_add(conv, b)
        relu2 = relay.nn.relu(bias)
        return run_opt_pass(relay.Function([x, w, b], relu2), relay.transform.InferType())

    def expected_false():
        x = relay.var("x", shape=(1, 10, 10, 10))
        w = relay.var("w", shape=(10, 10, 3, 3))
        b = relay.var("b", shape=(8,))

        x0 = relay.var("x")
        y0 = relay.var("y")

        add = relay.op.add(y0, y0)
        relu = relay.nn.relu(add)
        func = relay.Function([x0, y0], relu)
        func = func.with_attr("PartitionedFromPattern", "add_nn.relu_")
        func = func.with_attr("Composite", "add_relu")
        call = relay.Call(func, [x, x])

        conv = relay.nn.conv2d(
            call, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC"
        )
        bias = relay.nn.bias_add(conv, b)
        relu2 = relay.nn.relu(bias)
        return relay.Function([x, w, b], relu2)

    def expected_true():
        x = relay.var("x", shape=(1, 10, 10, 10))
        w = relay.var("w", shape=(10, 10, 3, 3))
        b = relay.var("b", shape=(8,))

        x0 = relay.var("x")
        y0 = relay.var("y")

        add = relay.op.add(y0, y0)
        relu = relay.nn.relu(add)
        func = relay.Function([x0, y0], relu)
        func = func.with_attr("PartitionedFromPattern", "add_nn.relu_")
        func = func.with_attr("Composite", "add_relu")
        call = relay.Call(func, [x, x])

        x2 = relay.var("x")
        w1 = relay.var("w")
        b1 = relay.var("b")
        conv = relay.nn.conv2d(x2, w1, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
        bias = relay.nn.bias_add(conv, b1)
        relu2 = relay.nn.relu(bias)
        func = relay.Function([x2, w1, b1], relu2)
        func = func.with_attr("Composite", "conv_bias_relu")
        func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")
        call = relay.Call(func, [call, w, b])
        return relay.Function([x, w, b], call)

    def _check_type_true(extract):
        conv = extract.args[0].args[0]
        typ = conv.checked_type
        return bool(typ.shape[0] == 1)

    def _check_type_false(extract):
        conv = extract.args[0].args[0]
        typ = conv.checked_type
        return bool(typ.shape[0] != 1)

    pattern_table_false = [
        ("add_relu", make_add_relu_pattern()),
        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false),
    ]
    check_result(pattern_table_false, before(), expected_false())

    pattern_table_true = [
        ("add_relu", make_add_relu_pattern()),
        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true),
    ]
    check_result(pattern_table_true, before(), expected_true())


if __name__ == "__main__":
    pytest.main([__file__])