1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import tvm
18from tvm import relay
19from tvm.relay import transform
20from tvm.relay.testing import run_opt_pass
21
22
def test_fuse_simple():
    """Simple testcase: an add -> exp -> squeeze chain of elemwise/injective
    ops should be fused into a single primitive function."""
    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        return relay.Function([x], w)

    def expected():
        # The whole chain becomes one function marked "Primitive", called
        # from a thin outer wrapper.
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        f1 = relay.Function([x], w)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    z = before()
    # NOTE: a previous version also ran FuseOps(fuse_opt_level=2) here, but
    # immediately overwrote the result without checking it — that dead call
    # has been removed.
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)
48
49
def test_conv2d_fuse():
    """Test fusion case of conv2d.

    A diamond-shaped graph (add -> conv -> adds -> two parallel convs -> add)
    should be split into four fused segments: the leading add, the first conv
    plus its following adds, one of the parallel convs alone, and the other
    parallel conv together with the final add.
    """
    def before(dshape):
        x = relay.var("x", shape=dshape)
        x = relay.add(x, relay.const(1, "float32"))
        y = relay.nn.conv2d(x, relay.var("w1"),
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        # this is the next dominator.
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        # second path
        z2 = relay.nn.conv2d(y, relay.var("w2"),
                             kernel_size=(1, 1),
                             padding=(0,0),
                             channels=16)
        z3 = relay.nn.conv2d(y, relay.var("w3"),
                             kernel_size=(3, 3),
                             padding=(1,1),
                             channels=16)
        # add can only be fused to z1
        z = relay.add(z2, z3)
        return relay.Function(relay.analysis.free_vars(z), z)

    def expected(dshape):
        # segment 0: the leading elemwise add, fused on its own.
        x = relay.var("p0", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        f0 = relay.Function([x], y)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # segment 1: first conv2d plus the two adds that follow it.
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        y = relay.nn.conv2d(x, w,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        f1 = relay.Function([x, w], y)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # segment 2: the 3x3 conv from the second path, fused alone.
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        z2 = relay.nn.conv2d(x, w,
                             kernel_size=(3, 3),
                             padding=(1,1),
                             channels=16)
        f2 = relay.Function([x, w], z2)
        f2 = f2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # segment 3: the 1x1 conv fused together with the final add; the
        # other branch's result enters as the extra parameter "p2".
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        offset = relay.var("p2", shape=dshape)
        z3 = relay.nn.conv2d(x, w,
                             kernel_size=(1, 1),
                             padding=(0, 0),
                             channels=16)
        z3 = relay.add(z3, offset)
        f3 = relay.Function([x, w, offset], z3)
        f3 = f3.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # compose. Note the weight wiring: w3 feeds f2 (the 3x3 conv) and
        # w2 feeds f3 (the 1x1 conv), matching the segment definitions above.
        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        y = relay.Call(f1, [y, relay.var("w1")])
        z2 = relay.Call(f2, [y, relay.var("w3")])
        z3 = relay.Call(f3, [y, relay.var("w2"), z2])
        z = z3
        return relay.Function(relay.analysis.free_vars(z), z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)
130
131
def test_concatenate():
    """Test fusion case involving concat op and Tuple node.

    The max_pool2d is fused alone; upsampling, the concat's implicit Tuple,
    and the trailing add are fused into a second group.
    """

    def before(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        concat = relay.concatenate((upsampled, x), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(dshape):
        # group 0: max_pool2d alone.
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        f0 = relay.Function([x], pooled)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # group 1: upsampling + concatenate + add. The pooled input has its
        # spatial dims halved by the (2, 2)-stride pooling above.
        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
        p1 = relay.var("p1", shape=dshape)
        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
        concat = relay.concatenate((upsampled, p1), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        f1 = relay.Function([p0, p1], out)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y, x])
        return relay.Function([x], z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    # opt level 0 must still produce a closed (no free vars) expression.
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)
170
171
def test_tuple_root():
    """Test fusion case where Tuple node is the root in its group."""

    def before(dshape):
        data = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(data, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        result = relay.Tuple((upsampled, data))
        return relay.Function(relay.analysis.free_vars(result), result)

    def expected(dshape):
        # First fused group: the pooling by itself.
        inp = relay.var("x", shape=dshape)
        f0 = relay.Function([inp], relay.nn.max_pool2d(
            inp, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)))
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # Second fused group: upsampling; its input has halved spatial dims.
        half = (dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)
        p0 = relay.var("p0", shape=half)
        f1 = relay.Function([p0], relay.nn.upsampling(
            p0, scale_h=2, scale_w=2, layout="NCHW"))
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # The Tuple root stays outside the primitive calls.
        root = relay.var("x", shape=dshape)
        chained = relay.Call(f1, [relay.Call(f0, [root])])
        return relay.Function([root], relay.Tuple((chained, root)))

    dshape = (1, 16, 64, 64)
    orig = before(dshape)
    for level in (0, 2):
        fused = run_opt_pass(orig, transform.FuseOps(fuse_opt_level=level))
        assert not relay.analysis.free_vars(fused)
    golden = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(fused, golden)
207
208
def test_stop_fusion():
    """The stop_fusion annotation must keep add and exp in separate groups."""
    def before(dshape):
        inp = relay.var("x", shape=dshape)
        added = relay.add(inp, relay.const(1, "float32"))
        added = relay.annotation.stop_fusion(added)
        return relay.Function([inp], relay.exp(added))

    def expected(dshape):
        # Group 1: only the add (fusion is cut after it).
        p0 = relay.var("p0", shape=dshape)
        f1 = relay.Function([p0], relay.add(p0, relay.const(1, "float32")))
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # Group 2: only the exp.
        p01 = relay.var("p01", shape=dshape)
        f2 = relay.Function([p01], relay.exp(p01))
        f2 = f2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        root = relay.var("x", shape=dshape)
        out = relay.Call(f2, [relay.Call(f1, [root])])
        return relay.Function([root], out)

    dshape = (10, 20)
    fused = run_opt_pass(before(dshape), transform.FuseOps())
    golden = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(fused, golden)
238
239
def test_fuse_myia_regression():
    """Regression test: fusion must handle an op (greater) used as an
    if-condition inside a ScopeBuilder without breaking the control flow."""
    def before(dshape, dtype):
        x = relay.var('x', shape=dshape, dtype=dtype)
        y = relay.var('y', shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        with sb.if_scope(relay.op.greater(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y],
            relay.Call(sb.get(), []))

    def expected(dshape, dtype):
        x = relay.var('x', shape=dshape, dtype=dtype)
        y = relay.var('y', shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        # The comparison is lifted into its own primitive function; the
        # branches themselves stay unfused.
        p1 = relay.var('p1', shape=dshape, dtype=dtype)
        p2 = relay.var('p2', shape=dshape, dtype=dtype)
        fused_gt = relay.Function([p1, p2],
            relay.op.greater(p1, p2))
        fused_gt = fused_gt.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        with sb.if_scope(fused_gt(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y],
            relay.Call(sb.get(), []))

    # Scalar (rank-0) tensors of int64.
    dshape = ()
    dtype = 'int64'
    f = before(dshape, dtype)
    zz = run_opt_pass(f, transform.FuseOps())
    after = run_opt_pass(expected(dshape, dtype), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)
274
275
def test_fuse_tuple_get_elemwise():
    """Fuse TupleGetItem consumers of a split with the elemwise ops that
    follow them, keeping the dense op in its own group."""
    def before(dim):
        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        return relay.Function([X, W], out)

    def expected(dim):
        # group 0: dense alone (out-elemwise-fusable producer).
        p0 = relay.var("p0", shape=(1, dim))
        p1 = relay.var("p1", shape=(3 * dim, dim))
        matmul = relay.nn.dense(p0, p1)
        f0 = relay.Function([p0, p1], matmul)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # group 1: split + tuple-get + elemwise combination, all fused.
        p01 = relay.var("p01", shape=(1, 3 * dim))
        splitted = relay.split(p01, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        f1 = relay.Function([p01], out)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        y = relay.Call(f0, [X, W])
        z = relay.Call(f1, [y])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    # Both opt levels must yield a closed expression; level 2 is compared
    # against the expected graph.
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dim), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)
312
313
def test_tuple_get_root():
    """Fusion case where a TupleGetItem is the root of its fused group."""
    def before(dim):
        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        splitted = relay.split(X, indices_or_sections=3, axis=1)
        out = relay.nn.dense(splitted[0], W)
        return relay.Function([X, W], out)

    def expected(dim):
        # group 0: split + tuple-get, with the TupleGetItem as group root.
        p0 = relay.var("p0", shape=(1, 3 * dim))
        splitted = relay.split(p0, indices_or_sections=3, axis=1)
        out = splitted[0]
        f0 = relay.Function([p0], out)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # group 1: the dense op by itself.
        p01 = relay.var("p01", shape=(1, dim))
        p1 = relay.var("p1", shape=(dim, dim))
        out = relay.nn.dense(p01, p1)
        f1 = relay.Function([p01, p1], out)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        y = relay.Call(f0, [X])
        z = relay.Call(f1, [y, W])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    # Both opt levels must yield a closed expression; level 2 is compared
    # against the expected graph.
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dim), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)
349
350
# FuseOps pass instances at opt levels 0 and 2, shared by the module-level
# tests below.
fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
fuse2 = relay.transform.FuseOps(fuse_opt_level=2)
353
def test_tuple_intermediate():
    """A Tuple appearing only as an intermediate (feeding concatenate) should
    not stop fusion: the entire graph fuses into one primitive function."""
    def before(x):
        inj = relay.squeeze(x)
        y1 = relay.add(inj, relay.const(1, "float32"))
        tmp = relay.squeeze(inj)
        tmp = relay.add(tmp, relay.const(1, "float32"))
        y2 = relay.add(tmp, relay.const(1, "float32"))
        y3 = relay.add(inj, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out_inj = relay.squeeze(concat)
        out = relay.add(out_inj, relay.const(1, "float32"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(p0):
        # The whole of before() becomes one primitive function.
        f0 = before(p0)
        f1 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    orig = before(x)
    # Level 0 is run as a smoke test; level 2 is built and compared.
    fuse0(relay.Module.from_expr(orig))
    m = fuse2(relay.Module.from_expr(orig))
    relay.build(m, 'llvm')
    after = run_opt_pass(expected(x), transform.InferType())
    assert relay.analysis.alpha_equal(m["main"], after)
382
383
def test_tuple_consecutive():
    """Consecutive Tuple/concatenate layers fuse into one injective group;
    pooling and the trailing adds form further groups, with a Tuple at the
    final root."""
    def gen_intermediate_tuple(x):
        y1 = relay.add(x, relay.const(1, "float32"))
        y2 = relay.add(x, relay.const(1, "float32"))
        y3 = relay.add(x, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        return out

    def gen_consecutive_tuple(x):
        y1 = gen_intermediate_tuple(x)
        y2 = gen_intermediate_tuple(x)
        y3 = gen_intermediate_tuple(x)
        concat = relay.concatenate((y1, y2, y3), axis=1)
        return concat

    def before(x):
        concat = gen_consecutive_tuple(x)
        pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        out = relay.add(pooled, relay.const(1, "float32"))
        out2 = relay.add(out, relay.const(1, "float32"))
        out_tup = relay.Tuple((out, out2))
        return relay.Function(relay.analysis.free_vars(out_tup), out_tup)

    def expected(dshape):
        # group 0: all the adds/concats up to (but excluding) the pooling.
        p0 = relay.var("p0", shape=dshape)
        concat = gen_consecutive_tuple(p0)
        f0 = relay.Function([p0], concat)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # group 1: max_pool2d + the first add. Each inner concat triples the
        # channel dim and the outer one triples it again: 9x channels.
        p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
        pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        out = relay.add(pooled, relay.const(1, "float32"))
        f1 = relay.Function([p01], out)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # group 2: the second add; spatial dims halved by the pooling.
        p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
        out = relay.add(p02, relay.const(1, "float32"))
        f2 = relay.Function([p02], out)
        f2 = f2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y])
        z2 = relay.Call(f2, [z])

        return relay.Function([x], relay.Tuple((z, z2)))

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    orig = before(x)
    # Level 0 is run as a smoke test; level 2 is built and compared.
    fuse0(relay.Module.from_expr(orig))
    m = fuse2(relay.Module.from_expr(orig))
    relay.build(m, 'llvm')
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(m["main"], after)
440
441
def test_inception_like():
    """Two stacked inception-like blocks (parallel convs feeding a concat).

    Each conv+relu becomes its own fused function and each concat fuses
    separately, for six primitive functions in total.
    """
    def conv(data):
        y = relay.nn.conv2d(data, relay.var("w"),
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        return relay.nn.relu(data=y)

    def inception_like(data):
        c0 = conv(data)
        c1 = conv(data)
        return relay.concatenate((c0, c1), axis=1)

    def before(dshape):
        x = relay.var("x", shape=dshape)
        in1 = inception_like(x)
        in2 = inception_like(in1)
        return relay.Function(relay.analysis.free_vars(in2), in2)

    def expected(dshape):
        # First block: two conv+relu groups ...
        p0 = relay.var("p0", shape=dshape)
        c = conv(p0)
        f0 = relay.Function(relay.analysis.free_vars(c), c)
        f0 = f0.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p01 = relay.var("p01", shape=dshape)
        c = conv(p01)
        f1 = relay.Function(relay.analysis.free_vars(c), c)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # ... plus their concat as its own group.
        p02 = relay.var("p02", shape=dshape)
        p12 = relay.var("p12", shape=dshape)
        concat1 = relay.concatenate((p02, p12), axis=1)
        f_concat1 = relay.Function([p02, p12], concat1)
        f_concat1 = f_concat1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # The first concat doubles the channel dimension.
        dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])

        # Second block: two more conv+relu groups over the widened input.
        p03 = relay.var("p03", shape=dshape2)
        c = conv(p03)
        f2 = relay.Function(relay.analysis.free_vars(c), c)
        f2 = f2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        p04 = relay.var("p04", shape=dshape2)
        c = conv(p04)
        f3 = relay.Function(relay.analysis.free_vars(c), c)
        f3 = f3.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # Final concat group. The convs output 16 channels, so their results
        # are back to the original dshape.
        p05 = relay.var("p05", shape=dshape)
        p15 = relay.var("p15", shape=dshape)
        concat2 = relay.concatenate((p05, p15), axis=1)
        f_concat2 = relay.Function([p05, p15], concat2)
        f_concat2 = f_concat2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))

        # Compose the six primitive calls.
        x = relay.var("x", shape=dshape)
        c1 = relay.Call(f0, [x, relay.var("w1")])
        c2 = relay.Call(f1, [x, relay.var("w2")])
        concat = relay.Call(f_concat1, [c1, c2])
        c3 = relay.Call(f2, [concat, relay.var("w3")])
        c4 = relay.Call(f3, [concat, relay.var("w4")])
        out = relay.Call(f_concat2, [c3, c4])

        return relay.Function(relay.analysis.free_vars(out), out)

    dshape = (1, 16, 64, 64)
    orig = before(dshape)
    # Level 0 is run as a smoke test; level 2 is built and compared.
    fuse0(relay.Module.from_expr(orig))
    m = fuse2(relay.Module.from_expr(orig))
    relay.build(m, 'llvm')
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert relay.analysis.alpha_equal(m["main"], after)
513
514
def test_fuse_parallel_injective():
    """Test fusing parallel injective ops to an elemwise op."""
    def before():
        inp = relay.var("x", shape=(10, 20))
        added = relay.add(inp, relay.const(1, "float32"))
        squeezed = relay.squeeze(added)
        transposed = relay.transpose(added, axes=[0, 1])
        shifted = relay.left_shift(squeezed, transposed)
        return relay.Function([inp], shifted)

    def expected():
        # Everything collapses into one primitive function: both injective
        # branches and the elemwise op that joins them.
        p = relay.var("p", shape=(10, 20))
        added = relay.add(p, relay.const(1, "float32"))
        body = relay.left_shift(relay.squeeze(added),
                                relay.transpose(added, axes=[0, 1]))
        prim = relay.Function([p], body)
        prim = prim.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        root = relay.var("x", shape=(10, 20))
        return relay.Function([root], relay.Call(prim, [root]))

    orig = before()
    for level in (0, 2):
        fused = run_opt_pass(orig, transform.FuseOps(fuse_opt_level=level))
        assert not relay.analysis.free_vars(fused)
    golden = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.alpha_equal(fused, golden)
544
545
def test_immutable():
    """Verify the fusion pass won't change original module."""
    def before():
        inp = relay.var("x", shape=(10, 20))
        added = relay.add(inp, relay.const(1, "float32"))
        squeezed = relay.squeeze(relay.exp(added))
        mod = relay.module.Module()
        mod["main"] = relay.Function([inp], squeezed)
        return mod

    def expected():
        # The fused form: one primitive function wrapping the whole chain.
        p = relay.var("p", shape=(10, 20))
        chain = relay.squeeze(relay.exp(relay.add(p, relay.const(1, "float32"))))
        prim = relay.Function([p], chain)
        prim = prim.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        root = relay.var("x", shape=(10, 20))
        mod = relay.module.Module()
        mod["main"] = relay.Function([root], relay.Call(prim, [root]))
        return mod

    mod = before()
    new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
    # The input module must be untouched; only the returned module is fused.
    assert relay.analysis.alpha_equal(mod, before())
    assert relay.analysis.alpha_equal(new_mod, expected())
574
575
def test_split():
    """Test that the result is well formed."""
    x = relay.var("x", shape=(6, 9))
    parts = relay.split(x, 3).astuple()
    a, b, c = (relay.TupleGetItem(parts, i) for i in range(3))
    mod = relay.module.Module()
    # Mix a ref-wrapped element into the sum so fusion must handle refs too.
    mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
    mod = transform.FuseOps()(mod)
586
def test_fuse_max():
    """Test the constraint of number of nodes in op fusion.

    A chain of n exp ops must be split at the pass's node limit: the first
    max_fused_ops ops form one primitive function and the remainder another.
    """
    max_fused_ops = 256
    # n is the number of nodes to be fused, should be less than 2*max_fused_ops
    n = 300
    def before():
        x = relay.var("x", shape=(10, 20))
        y = x
        for i in range(n):
            y = relay.exp(y)
        return relay.Function([x], y)

    def expected():
        # First group: the initial max_fused_ops exp nodes.
        x = relay.var("p", shape=(10, 20))
        y = x
        for i in range(max_fused_ops):
            y = relay.exp(y)
        f1 = relay.Function([x], y)
        f1 = f1.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        z = relay.Call(f1, [x])
        # Second group: the remaining n - max_fused_ops exp nodes.
        xx = relay.var("pp", shape=(10, 20))
        yy = xx
        for i in range(n-max_fused_ops):
            yy = relay.exp(yy)
        f2 = relay.Function([xx], yy)
        f2 = f2.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
        zz = relay.Call(f2, [z])
        return relay.Function([x], zz)

    z = before()
    # NOTE: a previous version also ran FuseOps(fuse_opt_level=2) here, but
    # immediately overwrote the result without checking it — that dead call
    # has been removed.
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.alpha_equal(zz, after)
622
# Run all fusion tests when executed directly as a script.
if __name__ == "__main__":
    test_fuse_simple()
    test_conv2d_fuse()
    test_concatenate()
    test_tuple_root()
    test_stop_fusion()
    test_fuse_myia_regression()
    test_fuse_tuple_get_elemwise()
    test_tuple_get_root()
    test_tuple_intermediate()
    test_tuple_consecutive()
    test_inception_like()
    test_fuse_parallel_injective()
    test_immutable()
    test_split()
    test_fuse_max()
639