1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Unit tests for merge composite."""
18import pytest
19import tvm
20from tvm import relay, tir
21from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard
22from tvm.relay.testing import run_opt_pass
23
24
25"""
26The merge composite pass is designed to merge multiple relay operators, that
27match a given pattern, and combine them into a single relay function.
28
29For example suppose we have the graph:
30
31    conv2d
32      |       (merge composite pass)
33   bias_add            ====>           conv2d_bias_relu
34      |            (our target)
35     relu
36
37Our Relay IR before the pass:
38    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
39            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
40        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
41            /* ty=Tensor[(1, 256, 28, 28), float32] */;
42        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
43        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
44    }
45
46Our Relay IR after the pass:
47    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
48            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
49      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
50            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
51            Tensor[(1, 256, 28, 28), float32] {
52        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
53        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
54        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
55      };
56      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
57    }
58
59As you can see in the second relay example, the pattern we specified has been wrapped
60in a function. The function is then called, producing the same result as the first relay
61example.
62
63One convenient use for this pass is to offload multiple operators to a single external
64codegen function.
65"""
66
67
68def make_add_sub_mul_pattern():
69    r"""Create a pattern to match the following graph.
70
71    add  sub
72     \   /
73      \ /
74      mul
75    """
76    x = wildcard()
77    y = wildcard()
78    return (x + y) * (x - y)
79
80
81def make_add_relu_pattern():
82    r"""Create a pattern to match the following graph.
83
84     add
85      |
86    relu
87    """
88    add_node = wildcard() + wildcard()
89    r = is_op("nn.relu")(add_node)
90    return r
91
92
93def make_conv_bias_relu_pattern():
94    r"""Create a pattern to match the following graph.
95
96     conv2d
97       |
98    bias_add
99       |
100     relu
101    """
102    x = wildcard()
103    y = wildcard()
104    z = wildcard()
105    conv_node = is_op("nn.conv2d")(x, y)
106    bias_node = is_op("nn.bias_add")(conv_node, z)
107    r = is_op("nn.relu")(bias_node)
108    return r
109
110
111def make_pattern_with_optional():
112    r"""Create a pattern to match the following graph. Note that relu is optinal.
113
114     conv2d
115       |
116    bias_add
117       |
118     (relu)
119    """
120    x = wildcard()
121    y = wildcard()
122    z = wildcard()
123    conv_node = is_op("nn.conv2d")(x, y)
124    bias_node = is_op("nn.bias_add")(conv_node, z)
125    r = bias_node.optional(lambda x: is_op("nn.relu")(x))
126    return r
127
128
129def make_add_add_add_pattern():
130    r"""Create a pattern to match the following graph.
131       Useful for testing re-using a call node.
132
133        x    y
134      /  \  /
135      |  add
136       \  |  \
137         add |
138          | /
139         add
140    """
141    x = wildcard()
142    y = wildcard()
143    add_node = is_op("add")(x, y)
144    add_node_1 = is_op("add")(x, add_node)
145    r = is_op("add")(add_node_1, add_node)
146    return r
147
148
149def make_bn_relu_pattern():
150    r"""Create a pattern to match the following graph.
151
152     batch_norm
153         |
154    TupleGetItem(0)
155         |
156       relu
157    """
158    x = wildcard()
159    gamma = wildcard()
160    beta = wildcard()
161    moving_mean = wildcard()
162    moving_var = wildcard()
163    bn_node = is_op("nn.batch_norm")(x, gamma, beta, moving_mean, moving_var)
164    tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
165    r = is_op("nn.relu")(tuple_get_item_node)
166    return r
167
168
169def check_result(pattern_table, graph, expected_graph, import_prelude=False):
170    """Utility function to check merge composite results."""
171    result = run_opt_pass(
172        graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude
173    )
174    assert not relay.analysis.free_vars(result), "Found free vars in the result graph: {0}".format(
175        str(result)
176    )
177    expected = run_opt_pass(expected_graph, relay.transform.InferType())
178    assert tvm.ir.structural_equal(
179        result, expected, map_free_vars=True
180    ), "Graph mismatch: output vs. expected\n{0}\n=====\n{1}".format(str(result), str(expected))
181
182
183def test_simple_merge():
184    r"""Test composite function is correctly produced from simple graph.
185
186    We could expect the pattern `make_add_relu_pattern` to be merged
187    into a single op `add_relu`.
188
189        a  b
190        \ /               a  b
191        add    ====>      \ /
192         |             add_relu
193       relu
194
195    """
196    pattern_table = [("add_relu", make_add_relu_pattern())]
197
198    def before():
199        a = relay.var("a", shape=(10, 10))
200        b = relay.var("b", shape=(10, 10))
201        add_node = relay.add(a, b)
202        r = relay.nn.relu(add_node)
203        return relay.Function([a, b], r)
204
205    def expected():
206        a = relay.var("a", shape=(10, 10))
207        b = relay.var("b", shape=(10, 10))
208
209        # add_relu function
210        in_1 = relay.var("in_1", shape=(10, 10))
211        in_2 = relay.var("in_2", shape=(10, 10))
212        add_node = relay.add(in_1, in_2)
213        relu_node = relay.nn.relu(add_node)
214        add_relu = relay.Function([in_1, in_2], relu_node)
215        add_relu = add_relu.with_attr("Composite", "add_relu")
216        add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_")
217
218        # merged function
219        r = relay.Call(add_relu, [a, b])
220        return relay.Function([a, b], r)
221
222    check_result(pattern_table, before(), expected())
223
224
225def test_branch_merge():
226    r"""Test composite function is correctly produced from branching graph.
227
228    We would expect the pattern `make_add_sub_mul_pattern` to be merged
229    into a single op `add_sub_mul`.
230
231       a  b  a  b
232        \/    \/
233        add  sub                       a  b
234         \   /                          \/
235          \ /                      add_sub_mul
236          mul                     c     |
237          /  \                     \    |
238       c /  c |       ====>        add_sub_mul
239       \/   \/                          |
240       add  sub                         |
241        \   /                         relu
242         \ /
243         mul
244          |
245          |
246        relu
247    """
248
249    pattern_table = [("add_sub_mul", make_add_sub_mul_pattern())]
250
251    def before():
252        a = relay.var("a", shape=(10, 10))
253        b = relay.var("b", shape=(10, 10))
254        c = relay.var("c", shape=(10, 10))
255        add_node = relay.add(a, b)
256        sub_node = relay.subtract(a, b)
257        mul_node = relay.multiply(add_node, sub_node)
258        add_node_2 = relay.add(c, mul_node)
259        sub_node_2 = relay.subtract(c, mul_node)
260        mul_node_2 = relay.multiply(add_node_2, sub_node_2)
261        r = relay.nn.relu(mul_node_2)
262        return relay.Function([a, b, c], r)
263
264    def expected():
265        a = relay.var("a", shape=(10, 10))
266        b = relay.var("b", shape=(10, 10))
267        c = relay.var("c", shape=(10, 10))
268
269        # add_sub_mul function
270        in_1 = relay.var("in_1", shape=(10, 10))
271        in_2 = relay.var("in_2", shape=(10, 10))
272        add_node = relay.add(in_1, in_2)
273        sub_node = relay.subtract(in_1, in_2)
274        mul_node = relay.multiply(add_node, sub_node)
275        add_sub_mul = relay.Function([in_1, in_2], mul_node)
276        add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul")
277        add_sub_mul = add_sub_mul.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
278
279        # add_sub_mul1 function
280        in_3 = relay.var("in_3", shape=(10, 10))
281        in_4 = relay.var("in_4", shape=(10, 10))
282        add_node_1 = relay.add(in_3, in_4)
283        sub_node_1 = relay.subtract(in_3, in_4)
284        mul_node_1 = relay.multiply(add_node_1, sub_node_1)
285        add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
286        add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul")
287        add_sub_mul_1 = add_sub_mul_1.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
288
289        # merged function
290        m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
291        m_add_sub_mul_2 = relay.Call(add_sub_mul_1, [c, m_add_sub_mul_1])
292        r = relay.nn.relu(m_add_sub_mul_2)
293        return relay.Function([a, b, c], r)
294
295    check_result(pattern_table, before(), expected())
296
297
298def test_reuse_call_merge():
299    r"""Test composite function is correctly produced from simple graph
300       which re-uses call nodes.
301
302    We could expect the pattern `make_add_add_add` to be merged
303    into a single op `add_add_add`.
304
305        x     y
306         \   / \
307          sub  |           x     y
308        /  |  /             \   / |
309        | add      ====>     sub  |
310         \ |  \               |  /
311          add |           add_add_add
312           | /
313          add
314
315    """
316    pattern_table = [("add_add_add", make_add_add_add_pattern())]
317
318    def before():
319        a = relay.var("a", shape=(10, 10))
320        b = relay.var("b", shape=(10, 10))
321        sub_node = relay.subtract(a, b)
322
323        # pattern
324        add_node = relay.add(sub_node, b)
325        add_node_1 = relay.add(sub_node, add_node)
326        r = relay.add(add_node_1, add_node)
327
328        return relay.Function([a, b], r)
329
330    def expected():
331        a = relay.var("a", shape=(10, 10))
332        b = relay.var("b", shape=(10, 10))
333
334        # add_relu_add function
335        in_1 = relay.var("in_1", shape=(10, 10))
336        in_2 = relay.var("in_2", shape=(10, 10))
337        add_node = relay.add(in_1, in_2)
338        add_node_1 = relay.add(in_1, add_node)
339        add_node_2 = relay.add(add_node_1, add_node)
340        add_add_add = relay.Function([in_1, in_2], add_node_2)
341        add_add_add = add_add_add.with_attr("Composite", "add_add_add")
342        add_add_add = add_add_add.with_attr("PartitionedFromPattern", "add_add_add_")
343
344        # merged function
345        sub_node = relay.subtract(a, b)
346        call = relay.Call(add_add_add, [sub_node, b])
347        return relay.Function([a, b], call)
348
349    check_result(pattern_table, before(), expected())
350
351
352def test_multiple_patterns():
353    r"""Test different patterns are merged correctly in the graph.
354
355    We would expect the pattern `make_conv_bias_relu_pattern` to be merged
356    into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern`
357    to be merged into a single op `add_relu`.
358
359        data   kernel
360          \      /
361           \    /
362           conv2d                   data   kernel   bias
363             |                         \      |      /
364             |   bias                 conv2d_bias_relu
365             |   /                            |
366          bias_add        ====>               |    a
367             |                                |   /
368           relu  a                        add_relu
369             \  /                             |
370             add                              |  b
371              |                               | /
372            relu  b                          mul
373              |  /
374             mul
375    """
376    pattern_table = [
377        ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
378        ("add_relu", make_add_relu_pattern()),
379    ]
380
381    def before():
382        data = relay.var("data", shape=(1, 512, 28, 28))
383        kernel = relay.var("kernel", shape=(256, 512, 1, 1))
384        bias = relay.var("bias", shape=(256,))
385        a = relay.var("a", shape=(1, 256, 28, 28))
386        b = relay.var("b", shape=(1, 256, 28, 28))
387
388        conv_node = relay.nn.conv2d(
389            data, kernel, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1)
390        )
391
392        bias_node = relay.nn.bias_add(conv_node, bias)
393        relu_node = relay.nn.relu(bias_node)
394        add_node = relay.add(relu_node, a)
395        relu_node_2 = relay.nn.relu(add_node)
396        r = relay.multiply(relu_node_2, b)
397        return relay.Function([data, kernel, bias, a, b], r)
398
399    def expected():
400        data = relay.var("data", shape=(1, 512, 28, 28))
401        kernel = relay.var("kernel", shape=(256, 512, 1, 1))
402        bias = relay.var("bias", shape=(256,))
403        a = relay.var("a", shape=(1, 256, 28, 28))
404        b = relay.var("b", shape=(1, 256, 28, 28))
405
406        # conv_bias_relu function
407        in_1 = relay.var("in_1", shape=(1, 512, 28, 28))
408        in_2 = relay.var("in_2", shape=(256, 512, 1, 1))
409        in_3 = relay.var("in_3", shape=(256,))
410
411        conv_node = relay.nn.conv2d(in_1, in_2, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1))
412
413        bias_node = relay.nn.bias_add(conv_node, in_3)
414        r = relay.nn.relu(bias_node)
415        conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
416        conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite", "conv2d_bias_relu")
417        conv_bias_add_relu = conv_bias_add_relu.with_attr(
418            "PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_"
419        )
420
421        # add_relu function
422        in_4 = relay.var("in_4", shape=(1, 256, 28, 28))
423        in_5 = relay.var("in_5", shape=(1, 256, 28, 28))
424        add_node = relay.add(in_4, in_5)
425        r = relay.nn.relu(add_node)
426        add_relu = relay.Function([in_4, in_5], r)
427        add_relu = add_relu.with_attr("Composite", "add_relu")
428        add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_")
429
430        # merged function
431        conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
432        add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
433        r = relay.multiply(add_relu_1, b)
434        return relay.Function([data, kernel, bias, a, b], r)
435
436    check_result(pattern_table, before(), expected())
437
438
439def test_optional_pattern():
440    r"""Test the pattern with optional operators. We can define a pattern with some operators
441    optional. The merge composite pass will create composite functions for all matched patterns,
442    but with different "PartitionedFromPattern" attribute. We expect the backend codegen to
443    analyze that attribute and determine the corresponding action.
444
445    Pattern:    Matched Case A:    Matched Case B:
446
447     conv2d        conv2d             conv2d
448       |             |                  |
449    bias_add      bias_add           bias_add
450       |             |
451     (relu)         relu
452
453    In the above example, the composite function for matched case A would have
454    PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_" while the one for matched case B
455    woud be "nn.conv2d_nn.bias_add_".
456    """
457    pattern_table = [("layer", make_pattern_with_optional())]
458
459    def before():
460        x = relay.var("x", shape=(1, 3, 7, 7))
461        w1 = relay.var("w", shape=(3, 3, 1, 1))
462        b1 = relay.var("b", shape=(3,))
463        w2 = relay.var("w", shape=(3, 3, 1, 1))
464        b2 = relay.var("b", shape=(3,))
465        conv = relay.nn.conv2d(x, w1, kernel_size=(1, 1))
466        bias = relay.nn.bias_add(conv, b1)
467        relu = relay.nn.relu(bias)
468        conv = relay.nn.conv2d(relu, w2, kernel_size=(1, 1))
469        bias = relay.nn.bias_add(conv, b2)
470        return relay.Function([x, w1, w2, b1, b2], bias)
471
472    def expected():
473        # Matched composite function A
474        x = relay.var("x")
475        w = relay.var("w")
476        b = relay.var("b")
477        conv = relay.nn.conv2d(x, w, kernel_size=(1, 1))
478        bias = relay.nn.bias_add(conv, b)
479        relu = relay.nn.relu(bias)
480        func1 = relay.Function([x, w, b], relu)
481        func1 = func1.with_attr("Composite", "layer")
482        func1 = func1.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")
483
484        # Matched composite function B
485        x = relay.var("x")
486        w = relay.var("w")
487        b = relay.var("b")
488        conv = relay.nn.conv2d(x, w, kernel_size=(1, 1))
489        bias = relay.nn.bias_add(conv, b)
490        func2 = relay.Function([x, w, b], bias)
491        func2 = func2.with_attr("Composite", "layer")
492        func2 = func2.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_")
493
494        # Main function
495        x = relay.var("x", shape=(1, 3, 7, 7))
496        w1 = relay.var("w", shape=(3, 3, 1, 1))
497        b1 = relay.var("b", shape=(3,))
498        w2 = relay.var("w", shape=(3, 3, 1, 1))
499        b2 = relay.var("b", shape=(3,))
500        out1 = func1(x, w1, b1)
501        out2 = func2(out1, w2, b2)
502        return relay.Function([x, w1, w2, b1, b2], out2)
503
504    check_result(pattern_table, before(), expected())
505
506
507def test_merge_order():
508    r"""Test that patterns are merged in the order they exist in the pattern table.
509
510    There can be cases where one pattern is a subgraph of another, in which case
511    it is not clear which match should take priority. The priority should come
512    from the order in which the patterns are declared in the pattern table. The
513    first patterns will be merged with highest priority and the last with lowest.
514
515    A:       B:       C:
516    add      add      abs
517     |        |        |
518    abs      abs      relu
519     |
520    relu
521
522    """
523
524    def pattern_A():
525        x = wildcard()
526        y = wildcard()
527        out = is_op("add")(x, y)
528        out = is_op("abs")(out)
529        out = is_op("nn.relu")(out)
530        return out
531
532    def pattern_B():
533        x = wildcard()
534        y = wildcard()
535        out = is_op("add")(x, y)
536        out = is_op("abs")(out)
537        return out
538
539    def pattern_C():
540        x = wildcard()
541        out = is_op("abs")(x)
542        out = is_op("nn.relu")(out)
543        return out
544
545    def before():
546        input_1 = relay.var("input_1", shape=(10, 10))
547        input_2 = relay.var("input_2", shape=(10, 10))
548        out = relay.add(input_1, input_2)
549        out = relay.abs(out)
550        out = relay.nn.relu(out)
551        return relay.Function([input_1, input_2], out)
552
553    def after_A_priority():
554        input_1 = relay.var("input_1", shape=(10, 10))
555        input_2 = relay.var("input_2", shape=(10, 10))
556        x = relay.var("x")
557        y = relay.var("y")
558        out = relay.add(x, y)
559        out = relay.abs(out)
560        out = relay.nn.relu(out)
561        merged_func = relay.Function([x, y], out)
562        merged_func = merged_func.with_attr("Composite", "A")
563        merged_func = merged_func.with_attr("PartitionedFromPattern", "add_abs_nn.relu_")
564        ret = relay.Call(merged_func, [input_1, input_2])
565        return relay.Function([input_1, input_2], ret)
566
567    def after_B_priority():
568        input_1 = relay.var("input_1", shape=(10, 10))
569        input_2 = relay.var("input_2", shape=(10, 10))
570        x = relay.var("x")
571        y = relay.var("y")
572        out = relay.add(x, y)
573        out = relay.abs(out)
574        merged_func = relay.Function([x, y], out)
575        merged_func = merged_func.with_attr("Composite", "B")
576        merged_func = merged_func.with_attr("PartitionedFromPattern", "add_abs_")
577        out = relay.Call(merged_func, [input_1, input_2])
578        ret = relay.nn.relu(out)
579        return relay.Function([input_1, input_2], ret)
580
581    def after_C_priority():
582        input_1 = relay.var("input_1", shape=(10, 10))
583        input_2 = relay.var("input_2", shape=(10, 10))
584        x = relay.var("x")
585        out = relay.abs(x)
586        out = relay.nn.relu(out)
587        merged_func = relay.Function([x], out)
588        merged_func = merged_func.with_attr("Composite", "C")
589        merged_func = merged_func.with_attr("PartitionedFromPattern", "abs_nn.relu_")
590        out = relay.add(input_1, input_2)
591        ret = relay.Call(merged_func, [out])
592        return relay.Function([input_1, input_2], ret)
593
594    # check A highest priority
595    pattern_table = [
596        ("A", pattern_A()),
597        ("B", pattern_B()),
598        ("C", pattern_C()),
599    ]
600    check_result(pattern_table, before(), after_A_priority())
601
602    # check B highest priority
603    pattern_table = [
604        ("B", pattern_B()),
605        ("C", pattern_C()),
606        ("A", pattern_A()),
607    ]
608    check_result(pattern_table, before(), after_B_priority())
609
610    # check C highest priority
611    pattern_table = [
612        ("C", pattern_C()),
613        ("A", pattern_A()),
614        ("B", pattern_B()),
615    ]
616    check_result(pattern_table, before(), after_C_priority())
617
618
619def test_parallel_merge():
620    r"""Tests that parallel patterns relying on the same inputs are correctly merged.
621
622    The test graph is difficult to draw out as ascii art. It is essentially two parallel
623    add-sub-mul units which both consume input_1 and input_2 with their results being multiplied
624    to give the output. We expect both parallel branches should get merged and both should still
625    consume the same input variables, input_1 and input_2."""
626
627    def before():
628        input_1 = relay.var("input_1", shape=(10, 10))
629        input_2 = relay.var("input_2", shape=(10, 10))
630        branch_1_add = relay.add(input_1, input_2)
631        branch_1_sub = relay.subtract(input_1, input_2)
632        branch_1 = relay.multiply(branch_1_add, branch_1_sub)
633        branch_2_add = relay.add(input_1, input_2)
634        branch_2_sub = relay.subtract(input_1, input_2)
635        branch_2 = relay.multiply(branch_2_add, branch_2_sub)
636        out = relay.multiply(branch_1, branch_2)
637        return relay.Function([input_1, input_2], out)
638
639    def expected():
640        input_1 = relay.var("input_1", shape=(10, 10))
641        input_2 = relay.var("input_2", shape=(10, 10))
642        x = relay.var("x")
643        y = relay.var("y")
644        branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
645        func_1 = relay.Function([x, y], branch_1)
646        func_1 = func_1.with_attr("Composite", "add_sub_mul")
647        func_1 = func_1.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
648        call_1 = relay.Call(func_1, [input_1, input_2])
649        x1 = relay.var("x1")
650        y1 = relay.var("y1")
651        branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
652        func_2 = relay.Function([x1, y1], branch_2)
653        func_2 = func_2.with_attr("Composite", "add_sub_mul")
654        func_2 = func_2.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
655        call_2 = relay.Call(func_2, [input_1, input_2])
656        out = relay.multiply(call_1, call_2)
657        return relay.Function([input_1, input_2], out)
658
659    pattern_table = [("add_sub_mul", make_add_sub_mul_pattern())]
660    check_result(pattern_table, before(), expected())
661
662
663def test_multiple_input_subgraphs():
664    r"""Test the case when multiple input subgraphs feed into another subgraph.
665
666     (1)    (2)    (3)    (4)
667    add    add    add    add
668     |      |      |      |
669    relu   relu   relu   relu
670     \      /      \      /
671      \   /         \   /
672       add           sub
673        \            /
674          \        /
675            \    /
676              mul
677
678    ----> When 1=3 and 2=4 (Case 'A')
679
680    add_relu  add_relu
681       \         /
682        \      /
683       add_sub_mul
684
685    ----> When 1!=3 and 2!=4 (Case 'B')
686
687    add_relu  add_relu  add_relu  add_relu
688       \       /           \       /
689         \   /               \   /
690          add                 sub
691           \                  /
692            --------     -----
693                   \    /
694                    mul
695
696    The difference in behaviour comes from the fact that add_sub_mul expects that the
697    inputs to add and sub are identical (the same two relay expressions). So when you
698    have 4 independent inputs, the pattern should not be merged.
699    """
700
701    def before():
702        before_funcs = {}
703        inputs = [relay.var("input_" + str(i), shape=(10, 10)) for i in range(8)]
704        add_relu_1 = relay.add(inputs[0], inputs[1])
705        add_relu_1 = relay.nn.relu(add_relu_1)
706        add_relu_2 = relay.add(inputs[2], inputs[3])
707        add_relu_2 = relay.nn.relu(add_relu_2)
708        add_relu_3 = relay.add(inputs[4], inputs[5])
709        add_relu_3 = relay.nn.relu(add_relu_3)
710        add_relu_4 = relay.add(inputs[6], inputs[7])
711        add_relu_4 = relay.nn.relu(add_relu_4)
712        add = relay.add(add_relu_1, add_relu_2)
713        sub = relay.subtract(add_relu_3, add_relu_4)
714        out = relay.multiply(add, sub)
715        before_funcs["B"] = relay.Function(inputs, out)
716        sub = relay.subtract(add_relu_1, add_relu_2)
717        out = relay.multiply(add, sub)
718        before_funcs["A"] = relay.Function(inputs[:4], out)
719        return before_funcs
720
721    def after_A():
722        inputs = [relay.var("input_" + str(i), shape=(10, 10)) for i in range(4)]
723        x = relay.var("x")
724        y = relay.var("y")
725        add_relu_1 = relay.add(x, y)
726        add_relu_1 = relay.nn.relu(add_relu_1)
727        add_relu_1 = relay.Function([x, y], add_relu_1)
728        add_relu_1 = add_relu_1.with_attr("Composite", "add_relu")
729        add_relu_1 = add_relu_1.with_attr("PartitionedFromPattern", "add_nn.relu_")
730        add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
731        x1 = relay.var("x1")
732        y1 = relay.var("y1")
733        add_relu_2 = relay.add(x1, y1)
734        add_relu_2 = relay.nn.relu(add_relu_2)
735        add_relu_2 = relay.Function([x1, y1], add_relu_2)
736        add_relu_2 = add_relu_2.with_attr("Composite", "add_relu")
737        add_relu_2 = add_relu_2.with_attr("PartitionedFromPattern", "add_nn.relu_")
738        add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
739        x2 = relay.var("x2")
740        y2 = relay.var("y2")
741        add = relay.add(x2, y2)
742        sub = relay.subtract(x2, y2)
743        add_sub_mul = relay.multiply(add, sub)
744        add_sub_mul = relay.Function([x2, y2], add_sub_mul)
745        add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul")
746        add_sub_mul = add_sub_mul.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
747        add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
748        return relay.Function(inputs, add_sub_mul_call)
749
750    def after_B():
751        inputs = [relay.var("input_" + str(i), shape=(10, 10)) for i in range(8)]
752        add_relu_calls = []
753        for i in range(4):
754            x = relay.var("x" + str(i))
755            y = relay.var("x" + str(i))
756            add_relu = relay.add(x, y)
757            add_relu = relay.nn.relu(add_relu)
758            add_relu = relay.Function([x, y], add_relu)
759            add_relu = add_relu.with_attr("Composite", "add_relu")
760            add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_")
761            add_relu_call = relay.Call(add_relu, [inputs[i * 2], inputs[i * 2 + 1]])
762            add_relu_calls.append(add_relu_call)
763
764        add = relay.add(add_relu_calls[0], add_relu_calls[1])
765        sub = relay.subtract(add_relu_calls[2], add_relu_calls[3])
766        out = relay.multiply(add, sub)
767        return relay.Function(inputs, out)
768
769    pattern_table = [
770        ("add_sub_mul", make_add_sub_mul_pattern()),
771        ("add_relu", make_add_relu_pattern()),
772    ]
773    check_result(pattern_table, before()["A"], after_A())
774    check_result(pattern_table, before()["B"], after_B())
775
776
777def test_tuple_get_item_merge():
778    """Test composite function can be merged from pattern containing TupleGetItem nodes."""
779    pattern_table = [("bn_relu", make_bn_relu_pattern())]
780
781    def before():
782        x = relay.var("x", shape=(1, 8))
783        gamma = relay.var("gamma", shape=(8,))
784        beta = relay.var("beta", shape=(8,))
785        moving_mean = relay.var("moving_mean", shape=(8,))
786        moving_var = relay.var("moving_var", shape=(8,))
787        bn_node = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var)
788        tuple_get_item_node = bn_node[0]
789        r = relay.nn.relu(tuple_get_item_node)
790        return relay.Function([x, gamma, beta, moving_mean, moving_var], r)
791
792    def expected():
793        x = relay.var("x", shape=(1, 8))
794        beta = relay.var("beta", shape=(8,))
795        gamma = relay.var("gamma", shape=(8,))
796        moving_mean = relay.var("moving_mean", shape=(8,))
797        moving_var = relay.var("moving_var", shape=(8,))
798
799        # bn_relu function
800        in_1 = relay.var("x1", shape=(1, 8))
801        in_2 = relay.var("gamma1", shape=(8,))
802        in_3 = relay.var("beta1", shape=(8,))
803        in_4 = relay.var("moving_mean1", shape=(8,))
804        in_5 = relay.var("moving_var1", shape=(8,))
805        bn_node = relay.nn.batch_norm(in_1, in_2, in_3, in_4, in_5)
806        tuple_get_item_node = bn_node[0]
807        relu_node = relay.nn.relu(tuple_get_item_node)
808        bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node)
809        bn_relu = bn_relu.with_attr("Composite", "bn_relu")
810        bn_relu = bn_relu.with_attr(
811            "PartitionedFromPattern", "nn.batch_norm_TupleGetItem0_nn.relu_"
812        )
813
814        # merged function
815        r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var])
816        return relay.Function([x, gamma, beta, moving_mean, moving_var], r)
817
818    check_result(pattern_table, before(), expected())
819
820
821def test_pattern_with_check():
822    def before():
823        x = relay.var("x", shape=(1, 10, 10, 10))
824        w = relay.var("w", shape=(10, 10, 3, 3))
825        b = relay.var("b", shape=(8,))
826        conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
827        bias = relay.nn.bias_add(conv, b)
828        relu = relay.nn.relu(bias)
829        return relay.Function([x, w, b], relu)
830
831    def _check_true(extract):
832        conv = extract.args[0].args[0]
833        return conv.attrs.data_layout == "NHWC"
834
835    def _check_false(extract):
836        conv = extract.args[0].args[0]
837        return conv.attrs.data_layout == "NCHW"
838
839    def expected():
840        x = relay.var("x")
841        w = relay.var("w")
842        b = relay.var("b")
843        conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
844        bias = relay.nn.bias_add(conv, b)
845        relu = relay.nn.relu(bias)
846        func = relay.Function([x, w, b], relu)
847        func = func.with_attr("Composite", "conv_bias_relu")
848        func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")
849
850        x = relay.var("x", shape=(1, 10, 10, 10))
851        w = relay.var("w", shape=(10, 10, 3, 3))
852        b = relay.var("b", shape=(8,))
853        return relay.Function([x, w, b], func(x, w, b))
854
855    pattern_table_false = [("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false)]
856    check_result(pattern_table_false, before(), before())
857
858    pattern_table_true = [("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true)]
859    check_result(pattern_table_true, before(), expected())
860
861
862def test_diamond_not_merge():
863    r"""
864    The pattern on the left shouldn't match the structure on the right
865
866    relu             relu
867     | \              | \
868     | clip           | add
869     |  /             |  |
870     mul              | clip
871                      |  /
872                      mul
873    """
874
875    def get_pattern():
876        conv = make_conv_bias_relu_pattern()
877        clip = is_op("clip")(conv, wildcard(), wildcard())
878        return is_op("multiply")(conv, clip)
879
880    def get_net():
881        data = relay.var("data", shape=(1, 512, 28, 28))
882        kernel = relay.var("kernel", shape=(256, 512, 1, 1))
883        conv = relay.nn.conv2d(data, kernel, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1))
884        bias = relay.nn.bias_add(conv, relay.var("bias", shape=(256,)))
885        relu = relay.nn.relu(bias)
886        add = relay.op.add(relu, relay.const(1.0))
887        clip2 = relay.op.clip(add, 0, 255)
888        mul = relay.op.multiply(relu, clip2)
889        return relay.Function(relay.analysis.free_vars(mul), mul)
890
891    pattern_table = [("pat", get_pattern())]
892    net = get_net()
893    check_result(pattern_table, net, net)
894
895
896def test_type_check():
897    """Test that we can query tensor types in the 'check' function."""
898
899    def before():
900        x = relay.var("x", shape=(1, 10, 10, 10))
901        w = relay.var("w", shape=(10, 10, 3, 3))
902        b = relay.var("b", shape=(8,))
903        add = relay.op.add(x, x)
904        relu = relay.nn.relu(add)
905        conv = relay.nn.conv2d(
906            relu, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC"
907        )
908        bias = relay.nn.bias_add(conv, b)
909        relu2 = relay.nn.relu(bias)
910        return run_opt_pass(relay.Function([x, w, b], relu2), relay.transform.InferType())
911
912    def expected_false():
913        x = relay.var("x", shape=(1, 10, 10, 10))
914        w = relay.var("w", shape=(10, 10, 3, 3))
915        b = relay.var("b", shape=(8,))
916
917        x0 = relay.var("x")
918        y0 = relay.var("y")
919
920        add = relay.op.add(y0, y0)
921        relu = relay.nn.relu(add)
922        func = relay.Function([x0, y0], relu)
923        func = func.with_attr("PartitionedFromPattern", "add_nn.relu_")
924        func = func.with_attr("Composite", "add_relu")
925        call = relay.Call(func, [x, x])
926
927        conv = relay.nn.conv2d(
928            call, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC"
929        )
930        bias = relay.nn.bias_add(conv, b)
931        relu2 = relay.nn.relu(bias)
932        return relay.Function([x, w, b], relu2)
933
934    def expected_true():
935        x = relay.var("x", shape=(1, 10, 10, 10))
936        w = relay.var("w", shape=(10, 10, 3, 3))
937        b = relay.var("b", shape=(8,))
938
939        x0 = relay.var("x")
940        y0 = relay.var("y")
941
942        add = relay.op.add(y0, y0)
943        relu = relay.nn.relu(add)
944        func = relay.Function([x0, y0], relu)
945        func = func.with_attr("PartitionedFromPattern", "add_nn.relu_")
946        func = func.with_attr("Composite", "add_relu")
947        call = relay.Call(func, [x, x])
948
949        x2 = relay.var("x")
950        w1 = relay.var("w")
951        b1 = relay.var("b")
952        conv = relay.nn.conv2d(x2, w1, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
953        bias = relay.nn.bias_add(conv, b1)
954        relu2 = relay.nn.relu(bias)
955        func = relay.Function([x2, w1, b1], relu2)
956        func = func.with_attr("Composite", "conv_bias_relu")
957        func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")
958        call = relay.Call(func, [call, w, b])
959        return relay.Function([x, w, b], call)
960
961    def _check_type_true(extract):
962        conv = extract.args[0].args[0]
963        typ = conv.checked_type
964        return bool(typ.shape[0] == 1)
965
966    def _check_type_false(extract):
967        conv = extract.args[0].args[0]
968        typ = conv.checked_type
969        return bool(typ.shape[0] != 1)
970
971    pattern_table_false = [
972        ("add_relu", make_add_relu_pattern()),
973        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false),
974    ]
975    check_result(pattern_table_false, before(), expected_false())
976
977    pattern_table_true = [
978        ("add_relu", make_add_relu_pattern()),
979        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true),
980    ]
981    check_result(pattern_table_true, before(), expected_true())
982
983
984if __name__ == "__main__":
985    pytest.main([__file__])
986