# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
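"""Unit tests for the relay.transform.FuseOps operator-fusion pass.

Most tests build a Relay program in a ``before`` helper, run fusion, and
compare the result against a hand-constructed fused form (``expected``).
"""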
import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import tvm.testing


def test_fuse_simple():
    """Simple testcase."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        return relay.Function([x], w)

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        f1 = relay.Function([x], w)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    z = before()
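    # fuse with an explicit opt level, then with the default;
    # only the second result is compared against the expected form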
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_conv2d_fuse():
    """Test fusion case of conv2d"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        x = relay.add(x, relay.const(1, "float32"))
        y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=16)
        # this is the next dominator.
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        # second path
        z2 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(1, 1), padding=(0, 0), channels=16)
        z3 = relay.nn.conv2d(y, relay.var("w3"), kernel_size=(3, 3), padding=(1, 1), channels=16)
        # the add can only be fused into one of the two conv groups (z2 here)
        z = relay.add(z2, z3)
        return relay.Function(relay.analysis.free_vars(z), z)

    def expected(dshape):
        # segment 0
        x = relay.var("p0", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        f0 = relay.Function([x], y)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # segment 1
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        y = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=16)
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        f1 = relay.Function([x, w], y)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # segment 2
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        z2 = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=16)
        f2 = relay.Function([x, w], z2)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # segment 3
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        offset = relay.var("p2", shape=dshape)
        z3 = relay.nn.conv2d(x, w, kernel_size=(1, 1), padding=(0, 0), channels=16)
        z3 = relay.add(z3, offset)
        f3 = relay.Function([x, w, offset], z3)
        f3 = f3.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # compose
        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        y = relay.Call(f1, [y, relay.var("w1")])
        z2 = relay.Call(f2, [y, relay.var("w3")])
        z3 = relay.Call(f3, [y, relay.var("w2"), z2])
        z = z3
        return relay.Function(relay.analysis.free_vars(z), z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_concatenate():
    """Test fusion case involving concat op and Tuple node"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        concat = relay.concatenate((upsampled, x), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        f0 = relay.Function([x], pooled)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
        p1 = relay.var("p1", shape=dshape)
        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
        concat = relay.concatenate((upsampled, p1), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        f1 = relay.Function([p0, p1], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y, x])
        return relay.Function([x], z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_tuple_root():
    """Test fusion case where Tuple node is the root in its group"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        out = relay.Tuple((upsampled, x))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        f0 = relay.Function([x], pooled)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
        f1 = relay.Function([p0], upsampled)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y])
        tup = relay.Tuple((z, x))
        return relay.Function([x], tup)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_stop_fusion():
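    """Test that a stop_fusion annotation keeps producer and consumer in separate fused functions."""
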
    def before(dshape):
        x = relay.var("x", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        y = relay.annotation.stop_fusion(y)
        z = relay.exp(y)
        return relay.Function([x], z)

    def expected(dshape):
        x = relay.var("p0", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        f1 = relay.Function([x], y)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("p01", shape=dshape)
        y = relay.exp(x)
        f2 = relay.Function([x], y)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x])
        z = relay.Call(f2, [y])
        return relay.Function([x], z)

    dshape = (10, 20)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_fuse_myia_regression():
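    """Regression test for fusing an op used as the condition of an if expression."""
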
    def before(dshape, dtype):
        x = relay.var("x", shape=dshape, dtype=dtype)
        y = relay.var("y", shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        with sb.if_scope(relay.op.greater(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y], relay.Call(sb.get(), []))

    def expected(dshape, dtype):
        x = relay.var("x", shape=dshape, dtype=dtype)
        y = relay.var("y", shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        p1 = relay.var("p1", shape=dshape, dtype=dtype)
        p2 = relay.var("p2", shape=dshape, dtype=dtype)
        fused_gt = relay.Function([p1, p2], relay.op.greater(p1, p2))
        fused_gt = fused_gt.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        with sb.if_scope(fused_gt(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y], relay.Call(sb.get(), []))

    dshape = ()
    dtype = "int64"
    f = before(dshape, dtype)
    zz = run_opt_pass(f, transform.FuseOps())
    after = run_opt_pass(expected(dshape, dtype), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_fuse_tuple_get_elemwise():
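    """Test fusing TupleGetItem of a split together with the following elemwise ops."""
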
    def before(dim):
        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        return relay.Function([X, W], out)

    def expected(dim):
        p0 = relay.var("p0", shape=(1, dim))
        p1 = relay.var("p1", shape=(3 * dim, dim))
        matmul = relay.nn.dense(p0, p1)
        f0 = relay.Function([p0, p1], matmul)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p01 = relay.var("p01", shape=(1, 3 * dim))
        splitted = relay.split(p01, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        f1 = relay.Function([p01], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        y = relay.Call(f0, [X, W])
        z = relay.Call(f1, [y])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dim), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_tuple_get_root():
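    """Test fusion case where TupleGetItem is the root of its group."""
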
    def before(dim):
        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        splitted = relay.split(X, indices_or_sections=3, axis=1)
        out = relay.nn.dense(splitted[0], W)
        return relay.Function([X, W], out)

    def expected(dim):
        p0 = relay.var("p0", shape=(1, 3 * dim))
        splitted = relay.split(p0, indices_or_sections=3, axis=1)
        out = splitted[0]
        f0 = relay.Function([p0], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p01 = relay.var("p01", shape=(1, dim))
        p1 = relay.var("p1", shape=(dim, dim))
        out = relay.nn.dense(p01, p1)
        f1 = relay.Function([p01, p1], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        y = relay.Call(f0, [X])
        z = relay.Call(f1, [y, W])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dim), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


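# Shared pass instances for tests that run fusion on a whole IRModule.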
fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
fuse2 = relay.transform.FuseOps(fuse_opt_level=2)


def test_tuple_intermediate():
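    """Test fusion when a tuple/concatenate appears as an intermediate node."""
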
    def before(x):
        inj = relay.squeeze(x)
        y1 = relay.add(inj, relay.const(1, "float32"))
        tmp = relay.squeeze(inj)
        tmp = relay.add(tmp, relay.const(1, "float32"))
        y2 = relay.add(tmp, relay.const(1, "float32"))
        y3 = relay.add(inj, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out_inj = relay.squeeze(concat)
        out = relay.add(out_inj, relay.const(1, "float32"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(p0):
        f0 = before(p0)
        f1 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    orig = before(x)
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(x), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_tuple_consecutive():
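    """Test fusion over consecutive tuple/concatenate constructs."""
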
    def gen_intermediate_tuple(x):
        y1 = relay.add(x, relay.const(1, "float32"))
        y2 = relay.add(x, relay.const(1, "float32"))
        y3 = relay.add(x, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        return out

    def gen_consecutive_tuple(x):
        y1 = gen_intermediate_tuple(x)
        y2 = gen_intermediate_tuple(x)
        y3 = gen_intermediate_tuple(x)
        concat = relay.concatenate((y1, y2, y3), axis=1)
        return concat

    def before(x):
        concat = gen_consecutive_tuple(x)
        pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        out = relay.add(pooled, relay.const(1, "float32"))
        out2 = relay.add(out, relay.const(1, "float32"))
        out_tup = relay.Tuple((out, out2))
        return relay.Function(relay.analysis.free_vars(out_tup), out_tup)

    def expected(dshape):
        p0 = relay.var("p0", shape=dshape)
        concat = gen_consecutive_tuple(p0)
        f0 = relay.Function([p0], concat)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p01 = relay.var("p01", shape=(1, dshape[1] * 9, dshape[2], dshape[3]))
        pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        out = relay.add(pooled, relay.const(1, "float32"))
        f1 = relay.Function([p01], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p02 = relay.var("p02", shape=(1, dshape[1] * 9, dshape[2] // 2, dshape[3] // 2))
        out = relay.add(p02, relay.const(1, "float32"))
        f2 = relay.Function([p02], out)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y])
        z2 = relay.Call(f2, [z])

        return relay.Function([x], relay.Tuple((z, z2)))

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    orig = before(x)
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_inception_like():
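    """Test fusion on an inception-like graph: parallel convolutions joined by concatenate."""
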
    def conv(data):
        y = relay.nn.conv2d(data, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=16)
        return relay.nn.relu(data=y)

    def inception_like(data):
        c0 = conv(data)
        c1 = conv(data)
        return relay.concatenate((c0, c1), axis=1)

    def before(dshape):
        x = relay.var("x", shape=dshape)
        in1 = inception_like(x)
        in2 = inception_like(in1)
        return relay.Function(relay.analysis.free_vars(in2), in2)

    def expected(dshape):
        p0 = relay.var("p0", shape=dshape)
        c = conv(p0)
        f0 = relay.Function(relay.analysis.free_vars(c), c)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p01 = relay.var("p01", shape=dshape)
        c = conv(p01)
        f1 = relay.Function(relay.analysis.free_vars(c), c)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p02 = relay.var("p02", shape=dshape)
        p12 = relay.var("p12", shape=dshape)
        concat1 = relay.concatenate((p02, p12), axis=1)
        f_concat1 = relay.Function([p02, p12], concat1)
        f_concat1 = f_concat1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        dshape2 = (dshape[0], dshape[1] * 2, dshape[2], dshape[3])

        p03 = relay.var("p03", shape=dshape2)
        c = conv(p03)
        f2 = relay.Function(relay.analysis.free_vars(c), c)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p04 = relay.var("p04", shape=dshape2)
        c = conv(p04)
        f3 = relay.Function(relay.analysis.free_vars(c), c)
        f3 = f3.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p05 = relay.var("p05", shape=dshape)
        p15 = relay.var("p15", shape=dshape)
        concat2 = relay.concatenate((p05, p15), axis=1)
        f_concat2 = relay.Function([p05, p15], concat2)
        f_concat2 = f_concat2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        c1 = relay.Call(f0, [x, relay.var("w1")])
        c2 = relay.Call(f1, [x, relay.var("w2")])
        concat = relay.Call(f_concat1, [c1, c2])
        c3 = relay.Call(f2, [concat, relay.var("w3")])
        c4 = relay.Call(f3, [concat, relay.var("w4")])
        out = relay.Call(f_concat2, [c3, c4])

        return relay.Function(relay.analysis.free_vars(out), out)

    dshape = (1, 16, 64, 64)
    orig = before(dshape)
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_parallel_injective():
    """Test fusing parallel injective ops to an elemwise op."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.squeeze(y)
        u = relay.transpose(y, axes=[0, 1])
        w = relay.left_shift(z, u)
        return relay.Function([x], w)

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.squeeze(y)
        u = relay.transpose(y, axes=[0, 1])
        w = relay.left_shift(z, u)
        f1 = relay.Function([x], w)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    z = before()
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_immutable():
    """Verify that the fusion pass does not mutate the original module."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        mod = tvm.IRModule()
        mod["main"] = relay.Function([x], w)
        return mod

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        f1 = relay.Function([x], w)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        mod = tvm.IRModule()
        mod["main"] = relay.Function([x], y)
        return mod

    mod = before()
    new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
    assert tvm.ir.structural_equal(mod, before())
    assert tvm.ir.structural_equal(new_mod, expected())


def test_split():
    """Test that the result is well formed."""
    x = relay.var("x", shape=(6, 9))
    y = relay.split(x, 3).astuple()
    a = relay.TupleGetItem(y, 0)
    b = relay.TupleGetItem(y, 1)
    c = relay.TupleGetItem(y, 2)
    mod = tvm.IRModule()
    mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
    mod = transform.FuseOps()(mod)


def test_fuse_max():
    """Test the limit on the number of ops fused into a single function."""

    def before(n):
        x = relay.var("x", shape=(10, 20))
        y = x
        for i in range(n):
            y = relay.exp(y)
        return relay.Function([x], y)

    def expected(n, max_fused_ops):
        x = relay.var("p", shape=(10, 20))
        y = x
        for i in range(max_fused_ops):
            y = relay.exp(y)
        f1 = relay.Function([x], y)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        z = relay.Call(f1, [x])
        xx = relay.var("pp", shape=(10, 20))
        yy = xx
        # the remaining n - max_fused_ops ops spill into a second fused function
        for i in range(n - max_fused_ops):
            yy = relay.exp(yy)
        f2 = relay.Function([xx], yy)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        zz = relay.Call(f2, [z])
        return relay.Function([x], zz)

    max_fused_ops = 256
    n = 300
    z = before(n)
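    # fuse with an explicit opt level, then with the default;
    # only the second result is compared against the expected form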
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)

    max_fused_ops = 10
    n = 20
    z = before(n)
    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())

    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
        zz = run_opt_pass(z, transform.FuseOps())

    assert tvm.ir.structural_equal(zz, after)


def test_fuse_take():
    """Test fusion case involving concat and take"""

    def before():
        shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        x = relay.var("x", shape=shape)
        concat = relay.concatenate([x, x], axis=-1)
        out = relay.op.take(concat, indices=relay.const([0], dtype="int64"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected():
        shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        shape2 = (tvm.tir.const(1, "int64"),)
        x = relay.var("x", shape=shape1)
        p0 = relay.var("p0", shape=shape1)
        p1 = relay.var("p1", shape=shape2, dtype="int64")
        c = relay.const([0], dtype="int64")
        concat = relay.concatenate([p0, p0], axis=-1)
        out = relay.op.take(concat, indices=p1)

        f0 = relay.Function([p0, p1], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        y = relay.Call(f0, [x, c])
        return relay.Function([x], y)

    orig = before()
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_gather_nd():
    """Test fusion case involving concat and gather_nd"""

    def before():
        shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        x = relay.var("x", shape=shape)
        concat = relay.concatenate([x, x], axis=-1)
        out = relay.gather_nd(concat, indices=relay.expr.const([[0, 1], [1, 0]], dtype="int64"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected():
        shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        shape2 = (tvm.tir.const(2, "int64"), tvm.tir.const(2, "int64"))
        x = relay.var("x", shape=shape1)
        p0 = relay.var("p0", shape=shape1)
        p1 = relay.var("p1", shape=shape2, dtype="int64")
        c = relay.const([[0, 1], [1, 0]], dtype="int64")
        concat = relay.concatenate([p0, p0], axis=-1)
        out = relay.gather_nd(concat, indices=p1)

        f0 = relay.Function([p0, p1], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        y = relay.Call(f0, [x, c])
        return relay.Function([x], y)

    orig = before()
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


@tvm.testing.uses_gpu
def test_fuse_bcast_reduce_scalar():
    """Test fusion case with broadcast and reduction involving scalar"""

    def before():
        x = relay.var("x", shape=(), dtype="int32")
        less = relay.less(x, relay.const(10, dtype="int32"))
        z = relay.min(less)
        return relay.Function([x], z)

    def expected():
        p0 = relay.var("p0", shape=(), dtype="int32")
        less = relay.less(p0, relay.const(10, dtype="int32"))
        z0 = relay.min(less)
        f0 = relay.Function([p0], z0)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=(), dtype="int32")
        f = relay.Call(f0, [x])
        return relay.Function([x], f)

    orig = before()
    m = fuse2(tvm.IRModule.from_expr(orig))
    for tgt, ctx in tvm.testing.enabled_targets():
        relay.build(m, tgt)
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_max_diamond():
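    """Test the relay.FuseOps.max_depth limit on diamond-shaped graphs."""
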
    def create_diamond(x, branch_len):
        x1 = x
        x2 = x
        for _ in range(branch_len):
            x1 = relay.exp(x1)
            x2 = relay.exp(x2)
        return relay.add(x1, x2)

    def before(branch_len, num_diamond):
        x = relay.var("x", shape=(10, 20))
        out = x
        for _ in range(num_diamond):
            out = create_diamond(out, branch_len)
        return relay.Function([x], out)

    def after(branch_len, num_diamond):
        def create_diamond_func(inp):
            inp_var = relay.var("p", shape=(10, 20))
            d = create_diamond(inp_var, branch_len)
            f = relay.Function([inp_var], d)
            f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
            return relay.Call(f, [inp])

        inp = relay.var("x", shape=(10, 20))
        out = inp
        for _ in range(num_diamond):
            out = create_diamond_func(out)
        return relay.Function([inp], out)

    branch_len = 5
    max_fused_ops = branch_len * 2 + 1  # the number of ops in one diamond
    num_diamond = 3

    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
        fused = run_opt_pass(before(branch_len, num_diamond), transform.FuseOps())

    expected = run_opt_pass(after(branch_len, num_diamond), transform.InferType())
    assert tvm.ir.structural_equal(fused, expected)


if __name__ == "__main__":
    test_fuse_simple()
    test_conv2d_fuse()
    test_concatenate()
    test_tuple_root()
    test_stop_fusion()
    test_fuse_myia_regression()
    test_fuse_tuple_get_elemwise()
    test_tuple_get_root()
    test_tuple_intermediate()
    test_tuple_consecutive()
    test_inception_like()
    test_fuse_parallel_injective()
    test_immutable()
    test_split()
    test_fuse_max()
    test_fuse_take()
    test_fuse_gather_nd()
    test_fuse_bcast_reduce_scalar()
    test_fuse_max_diamond()