# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
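"""Unit tests for the Relay ForwardFoldScaleAxis and BackwardFoldScaleAxis passes."""
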
import numpy as np

import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform


def _get_positive_scale(size):
    return np.random.uniform(0.5, 1, size=size).astype("float32")


def run_opt_pass(expr, opt_pass):
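    """Wrap expr in an IRModule, apply opt_pass, and return the transformed main
    function (or its body when expr was not already a relay.Function)."""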
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_fold_fwd_simple():
    """Simple testcase."""

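    # The input is scaled, then passed through relu and a bias add before the
    # conv2d. Forward folding should remove the multiply by scaling conv_weight
    # and dividing in_bias by the (positive, constant) scale.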
    def before(x, conv_weight, in_bias, in_scale, channels, blocking):
        args = [x, conv_weight, in_bias]
        x = relay.multiply(x, in_scale)
        x = relay.nn.relu(x)
        x = relay.add(x, in_bias)
        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW",
        )

        return relay.Function(args, y)

    def expected(x, conv_weight, in_bias, in_scale, in_channels, channels, blocking):
        # use a fixed order of args so alpha equal check can pass
        args = [x, conv_weight, in_bias]
        if blocking:
            squeezed_scale = relay.squeeze(in_scale, axis=[0, 2, 3])
            x = relay.nn.relu(x)
            in_bias = relay.divide(
                in_bias,
                relay.reshape(squeezed_scale, (1, in_channels // blocking[0], 1, 1, blocking[0])),
            )  # NCHWc
            x = relay.add(x, in_bias)
            conv_weight = relay.multiply(
                conv_weight, relay.reshape(squeezed_scale, (1, in_channels // 2, 1, 1, 2, 1))
            )  # OIHWio
        else:
            squeezed_scale = relay.squeeze(in_scale, axis=[1, 2])
            x = relay.nn.relu(x)
            in_bias = relay.divide(
                in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)
            )
            x = relay.add(x, in_bias)
            conv_weight = relay.multiply(
                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)
            )

        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        return relay.Function(args, y)

    def check(shape, channels, blocking):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        if blocking:
            in_channels = shape[1] * shape[4]
            in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0]))
            in_scale = relay.const(
                _get_positive_scale((1, in_channels // blocking[0], 1, 1, blocking[0]))
            )
        else:
            in_channels = shape[1]
            in_bias = relay.var("in_bias", shape=(in_channels, 1, 1))
            in_scale = relay.const(_get_positive_scale((in_channels, 1, 1)))
        y1 = before(x, weight, in_bias, in_scale, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        type_dict = {x.name_hint: x.checked_type for x in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
        y1_expected = expected(x, weight, in_bias, in_scale, in_channels, channels, blocking)

        y1_folded = run_opt_pass(y1_folded, transform.InferType())
        y1_expected = run_opt_pass(y1_expected, transform.InferType())
        assert tvm.ir.structural_equal(y1_folded, y1_expected)

    check((2, 4, 10, 10), 2, None)
    check((2, 2, 10, 10, 2), 8, (2, 4))


def test_fold_fwd_dual_path():
    """scale axis being consumed by two consumers"""

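    # The scaled input feeds two depthwise conv2d consumers; forward folding
    # should push the scale into both kernels and divide in_bias by the scale.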
    def before(x, conv_weight, in_bias, in_scale, channels, blocking):
        args = [x, conv_weight, in_bias]
        x = relay.multiply(in_scale, x)
        x = relay.nn.relu(x)
        x = relay.subtract(x, in_bias)
        y1 = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
            groups=channels,
            padding=(1, 1),
        )
        y2 = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
            groups=channels,
            padding=(1, 1),
        )
        z = relay.add(y1, y2)
        return relay.Function(args, z)

    def expected(x, conv_weight, in_bias, in_scale, channels, blocking):
        args = [x, conv_weight, in_bias]
        x = relay.nn.relu(x)
        if blocking:
            _in_scale = relay.reshape(
                in_scale, (1, 1, 1, channels // blocking[0], blocking[0])
            )  # NHWCc
        else:
            _in_scale = in_scale
        in_bias = relay.divide(in_bias, _in_scale)
        x = relay.subtract(x, in_bias)
        if blocking:
            _in_scale = relay.reshape(
                in_scale, (1, 1, 1, channels // blocking[0], 1, blocking[0])
            )  # HWIOio
        y1 = relay.nn.conv2d(
            x,
            relay.multiply(conv_weight, _in_scale),
            channels=channels,
            kernel_size=(3, 3),
            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
            groups=channels,
            padding=(1, 1),
        )
        if blocking:
            _in_scale = relay.reshape(
                in_scale, (1, 1, 1, channels // blocking[0], 1, blocking[0])
            )  # HWIOio
        y2 = relay.nn.conv2d(
            x,
            relay.multiply(conv_weight, _in_scale),
            channels=channels,
            kernel_size=(3, 3),
            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
            groups=channels,
            padding=(1, 1),
        )
        z = relay.add(y1, y2)
        return relay.Function(args, z)

    def check(dshape, channels, blocking):
        x = relay.var("x", shape=dshape)
        if blocking:
            in_channels = dshape[3] * dshape[4]
            wshape = (3, 3, 1, channels // blocking[1], 1, blocking[1])  # HWIOio
            weight = relay.var("weight", shape=wshape)
            in_bias = relay.var("in_bias", shape=(in_channels // blocking[0], blocking[0]))
            in_scale = relay.const(_get_positive_scale((in_channels // blocking[0], blocking[0])))
        else:
            in_channels = dshape[-1]
            wshape = (3, 3, 1, channels)  # HWIO
            weight = relay.var("weight", shape=wshape)
            in_bias = relay.var("in_bias", shape=(in_channels,))
            in_scale = relay.const(
                _get_positive_scale(
                    in_channels,
                )
            )

        # test depthwise
        assert in_channels == channels

        y1 = before(x, weight, in_bias, in_scale, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
        type_dict = {x.name_hint: x.checked_type for x in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_expected = expected(x, weight, in_bias, in_scale, channels, blocking)
        y1_expected = run_opt_pass(y1_expected, transform.InferType())
        assert tvm.ir.structural_equal(y1_folded, y1_expected)

    check((2, 4, 10, 3), 3, None)
    check((2, 4, 10, 2, 2), 4, (2, 2))


def test_fold_fwd_fail():
    """Testcase where we cannot fold."""

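    # The scaled value is consumed both by the conv2d path (through leaky_relu)
    # and directly by the final add, so the multiply cannot be folded away.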
    def before(x, conv_weight, in_bias, in_scale, channels, blocking):
        x = relay.multiply(x, in_scale)
        xx = relay.nn.leaky_relu(x, alpha=0.1)
        y1 = relay.nn.conv2d(
            xx,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
            padding=(1, 1),
        )
        z = relay.add(y1, x)
        return relay.Function(relay.analysis.free_vars(z), z)

    def check(shape, channels, blocking):
        x = relay.var("x", shape=shape)
        if blocking:
            in_channels = shape[3] * shape[4]
            in_bias = relay.var("in_bias", shape=(in_channels // blocking[0], blocking[0]))
            in_scale = relay.const(_get_positive_scale((in_channels // blocking[0], blocking[0])))
        else:
            in_channels = shape[-1]
            in_bias = relay.var("in_bias", shape=(in_channels,))
            in_scale = relay.const(_get_positive_scale(size=(in_channels,)))
        # test depthwise
        assert in_channels == channels
        weight = relay.var("weight")
        y1 = before(x, weight, in_bias, in_scale, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
        assert tvm.ir.structural_equal(y1, y1_folded)

    check((2, 11, 10, 4), 4, None)
    check((2, 11, 10, 2, 2), 4, (2, 2))


def test_fold_fwd_relu_fail():
    """Testcase where we cannot fold because the scale cannot pass through relu."""

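    # The scale is either a non-constant variable or a negative constant, so it
    # cannot be pushed forward through the relu ahead of the conv2d.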
    def before(x, conv_weight, in_bias, in_scale, channels, blocking):
        x = relay.multiply(x, in_scale)
        xx = relay.nn.relu(x)
        y1 = relay.nn.conv2d(
            xx,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
            padding=(1, 1),
        )
        z = relay.add(y1, x)
        return relay.Function(relay.analysis.free_vars(z), z)

    def check(shape, channels, blocking, in_scale):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        if blocking:
            in_channels = shape[3] * shape[4]
            in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0]))
        else:
            in_channels = shape[-1]
            in_bias = relay.var("in_bias", shape=(in_channels,))

        assert in_channels == channels
        y1 = before(x, weight, in_bias, in_scale, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
        assert tvm.ir.structural_equal(y1, y1_folded)

    in_scale = relay.var("in_scale", shape=(4,))
    check((2, 11, 10, 4), 4, None, in_scale)
    in_scale = relay.const(-_get_positive_scale((4,)))
    check((2, 11, 10, 4), 4, None, in_scale)

    in_scale = relay.var("in_scale", shape=(1, 1, 1, 2, 2))
    check((2, 11, 10, 2, 2), 4, (2, 2), in_scale)
    in_scale = relay.const(-_get_positive_scale((1, 1, 1, 2, 2)))
    check((2, 11, 10, 2, 2), 4, (2, 2), in_scale)


def test_fold_fwd_negative_scale():
    """Testcase of folding negative scale"""

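    # There is no relu between the multiply and the conv2d, so even a negative
    # scale can be folded forward into the kernel.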
    def before(x, conv_weight, in_scale, channels, blocking):
        args = [x, conv_weight]
        x = relay.multiply(x, in_scale)
        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        return relay.Function(args, y)

    def expected(x, conv_weight, in_scale, in_channels, channels, blocking):
        # use a fixed order of args so alpha equal check can pass
        args = [x, conv_weight]
        if blocking:
            squeezed_scale = relay.squeeze(in_scale, axis=[0, 2, 3])
            conv_weight = relay.multiply(
                conv_weight, relay.reshape(squeezed_scale, (1, in_channels // 4, 1, 1, 4, 1))
            )
            # blocking by "i" in OIHWio
        else:
            squeezed_scale = relay.squeeze(in_scale, axis=[1, 2])
            conv_weight = relay.multiply(
                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)
            )
        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        return relay.Function(args, y)

    def check(shape, channels, blocking):
        x = relay.var("x", shape=shape)
        if blocking:
            in_channels = shape[1] * shape[4]
            in_scale = relay.const(-_get_positive_scale((1, shape[1], 1, 1, shape[4])))
        else:
            in_channels = shape[1]
            in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1)))
        weight = relay.var("weight")
        y1 = before(x, weight, in_scale, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        type_dict = {x.name_hint: x.checked_type for x in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
        y1_expected = expected(x, weight, in_scale, in_channels, channels, blocking)
        y1_expected = run_opt_pass(y1_expected, transform.InferType())
        assert tvm.ir.structural_equal(y1_folded, y1_expected)

    check((2, 4, 10, 10), 4, None)
    check((2, 2, 10, 10, 2), 8, (2, 2))


def test_fold_bwd_simple():
    """Simple testcase."""

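    # The conv2d output goes through a bias add and relu before being scaled;
    # backward folding should remove the multiply by scaling both conv_weight
    # and out_bias.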
    def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
        args = [x, conv_weight, out_bias]
        if blocking:
            out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, blocking[1]))
        else:
            out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y = relay.add(y, out_bias)
        y = relay.nn.relu(y)
        if blocking:
            out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, blocking[1]))
        y = relay.multiply(y, out_scale)
        return relay.Function(args, y)

    def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
        # use a fixed order of args so alpha equal check can pass
        args = [x, conv_weight, out_bias]
        if blocking:
            out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, blocking[1]))
            out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, blocking[1]))
            squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3])
            conv_weight = relay.multiply(
                conv_weight,
                relay.reshape(squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])),
            )
        else:
            out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
            squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
            conv_weight = relay.multiply(
                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
            )

        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        if blocking:
            out_bias = relay.multiply(
                out_bias,
                relay.reshape(squeezed_scale, (1, channels // blocking[1], 1, 1, blocking[1])),
            )
        else:
            out_bias = relay.multiply(
                out_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)
            )
        y = relay.add(y, out_bias)
        y = relay.nn.relu(y)
        return relay.Function(args, y)

    def check(shape, in_channels, channels, blocking):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        out_bias = relay.var("out_bias", shape=(channels,))
        if blocking:
            out_scale = relay.const(_get_positive_scale((channels,)))
        else:
            out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
        y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        type_dict = {x.name_hint: x.checked_type for x in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
        y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
        y1_expected = run_opt_pass(y1_expected, transform.InferType())
        assert tvm.ir.structural_equal(y1_folded, y1_expected)

    check((2, 4, 10, 10), 4, 8, None)
    check((2, 2, 10, 10, 16), 32, 64, (16, 16))


def test_fold_bwd_dual_path():
    """Dual path testcase."""

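    # The scale sits after the add of two conv2d+relu branches that share the
    # same weight; backward folding must push it into both convolutions.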
    def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
        args = [x, conv_weight, out_bias]
        y1 = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y2 = relay.nn.relu(y2)
        y = relay.add(y1, y2)
        y = relay.multiply(y, out_scale)
        return relay.Function(args, y)

    def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
        # use a fixed order of args so alpha equal check can pass
        args = [x, conv_weight, out_bias]
        if not blocking:
            out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
        squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])

        def fold_conv_weight():
            if blocking:
                return relay.multiply(
                    conv_weight,
                    relay.reshape(
                        squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])
                    ),
                )
            else:
                return relay.multiply(
                    conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
                )

        y1 = relay.nn.conv2d(
            x,
            fold_conv_weight(),
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(
            x,
            fold_conv_weight(),
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y2 = relay.nn.relu(y2)
        y = relay.add(y1, y2)
        return relay.Function(args, y)

    def check(shape, in_channels, channels, blocking):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        if blocking:
            out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
            out_scale = relay.const(
                _get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))
            )
        else:
            out_bias = relay.var("out_bias", shape=(channels,))
            out_scale = relay.const(_get_positive_scale((channels, 1, 1)))

        y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        type_dict = {x.name_hint: x.checked_type for x in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
        y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
        y1_expected = run_opt_pass(y1_expected, transform.InferType())
        assert tvm.ir.structural_equal(y1_folded, y1_expected)

    check((2, 4, 10, 10), 4, 8, None)
    check((2, 2, 10, 10, 2), 4, 8, (2, 2))


def test_fold_bwd_dual_consumer():
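    """Backward folding when the scaled output of one conv2d feeds two more scaled conv2d consumers."""
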
    def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
        args = [x, conv_weight, out_bias]
        y0 = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y0 = relay.multiply(y0, out_scale)
        y0 = relay.nn.relu(y0)

        y1 = relay.nn.conv2d(
            y0,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y1 = relay.multiply(y1, out_scale)
        y1 = relay.nn.relu(y1)

        y2 = relay.nn.conv2d(
            y0,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y2 = relay.multiply(y2, out_scale)
        y2 = relay.nn.relu(y2)

        y = relay.add(y1, y2)
        return relay.Function(args, y)

    def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
        # use a fixed order of args so alpha equal check can pass
        args = [x, conv_weight, out_bias]

        def fold_conv_weight():
            squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
            if blocking:
                return relay.multiply(
                    conv_weight,
                    relay.reshape(
                        squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])
                    ),
                )
            else:
                return relay.multiply(
                    conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
                )

        y0 = relay.nn.conv2d(
            x,
            fold_conv_weight(),
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y0 = relay.nn.relu(y0)
        y1 = relay.nn.conv2d(
            y0,
            fold_conv_weight(),
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(
            y0,
            fold_conv_weight(),
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y2 = relay.nn.relu(y2)
        y = relay.add(y1, y2)
        return relay.Function(args, y)

    def check(shape, in_channels, channels, blocking):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        if blocking:
            out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
            out_scale = relay.const(
                _get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))
            )
        else:
            out_bias = relay.var("out_bias", shape=(channels,))
            out_scale = relay.const(_get_positive_scale((channels, 1, 1)))

        y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        type_dict = {x.name_hint: x.checked_type for x in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
        y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
        y1_expected = run_opt_pass(y1_expected, transform.InferType())
        assert tvm.ir.structural_equal(y1_folded, y1_expected)

    check((2, 4, 10, 10), 4, 4, None)
    check((2, 2, 10, 10, 2), 4, 4, (2, 2))


def test_fold_bwd_fail():
    """Testcases where backward folding cannot be applied."""

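    # fail1: the two conv2d branches produce different output layouts, so the
    # scale axis cannot be matched across them; fail2: the scaled value is also
    # consumed by another op, so the multiply cannot be removed.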
    def fail1(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
        args = [x, conv_weight, out_bias]
        y1 = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
            out_layout="CNHW{}c".format(blocking[1]) if blocking else "CNHW",
        )
        # folding will fail because the scale axes of the two paths
        # differ from each other (y2 uses a different out_layout)
        y2 = relay.nn.relu(y2)
        y = relay.add(y1, y2)
        y = relay.multiply(y, out_scale)
        return relay.Function(args, y)

    def fail2(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
        args = [x, conv_weight, out_bias]
        y1 = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y2 = relay.nn.relu(y1)
        # folding will fail because y1 is also referred to by y2
        y1 = relay.multiply(y1, out_scale)
        y = relay.add(y1, y2)
        return relay.Function(args, y)

    def check(shape, in_channels, channels, blocking, fbefore):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        if blocking:
            out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
            out_scale = relay.const(
                _get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))
            )
        else:
            out_bias = relay.var("out_bias", shape=(channels, 1, 1))
            out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
        y1 = fbefore(x, weight, out_bias, out_scale, in_channels, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
        assert tvm.ir.structural_equal(y1_folded, y1)

    check((4, 4, 10, 10), 4, 4, None, fail1)
    check((2, 2, 10, 10, 2), 4, 4, (2, 2), fail1)
    check((4, 4, 10, 10), 4, 4, None, fail2)
    check((4, 2, 10, 10, 2), 4, 4, (2, 2), fail2)


def test_fold_bwd_relu_fail():
    """Testcase where we cannot fold because the scale cannot pass through relu."""

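    # The scale applied after the relu is either a non-constant variable or
    # negative, so it cannot be pulled backward through the relu into conv2d.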
    def before(x, conv_weight, out_scale, channels, blocking):
        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y = relay.nn.relu(y)
        y = relay.multiply(y, out_scale)
        return relay.Function(relay.analysis.free_vars(y), y)

    def check(shape, channels, blocking, out_scale):
        x = relay.var("x", shape=shape)
        in_channels = shape[1]
        weight = relay.var("weight")
        y1 = before(x, weight, out_scale, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
        assert tvm.ir.structural_equal(y1, y1_folded)

    out_scale = relay.var("in_scale", shape=(4, 1, 1))
    check((4, 4, 10, 10), 4, None, out_scale)
    out_scale = relay.const(np.random.uniform(size=(4, 1, 1), low=-1.0, high=0.0).astype("float32"))
    check((4, 4, 10, 10), 4, None, out_scale)

    out_scale = relay.var("in_scale", shape=(1, 2, 1, 1, 2))
    check((4, 2, 10, 10, 2), 4, (2, 2), out_scale)
    out_scale = relay.const(
        np.random.uniform(size=(1, 2, 1, 1, 2), low=-1.0, high=0.0).astype("float32")
    )
    check((4, 2, 10, 10, 2), 4, (2, 2), out_scale)


def test_fold_bwd_negative_scale():
    """Testcase of folding negative scale"""

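    # No relu follows the conv2d, so a negative output scale can still be
    # folded backward into the kernel.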
    def before(x, conv_weight, out_scale, channels, blocking):
        args = [x, conv_weight]
        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        y = relay.multiply(y, out_scale)
        return relay.Function(args, y)

    def expected(x, conv_weight, out_scale, channels, blocking):
        # use a fixed order of args so alpha equal check can pass
        args = [x, conv_weight]
        if blocking:
            squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3])
            conv_weight = relay.multiply(
                conv_weight,
                relay.reshape(squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])),
            )
        else:
            squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
            conv_weight = relay.multiply(
                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
            )
        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        )
        return relay.Function(args, y)

    def check(shape, channels, blocking):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        if blocking:
            out_scale = relay.const(
                -_get_positive_scale((1, channels // blocking[1], 1, 1, blocking[1]))
            )
        else:
            out_scale = relay.const(-_get_positive_scale((channels, 1, 1)))
        y1 = before(x, weight, out_scale, channels, blocking)
        y1 = run_opt_pass(y1, transform.InferType())
        type_dict = {x.name_hint: x.checked_type for x in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
        y1_expected = expected(x, weight, out_scale, channels, blocking)
        y1_expected = run_opt_pass(y1_expected, transform.InferType())
        assert tvm.ir.structural_equal(y1_folded, y1_expected)

    check((2, 4, 10, 10), 8, None)
    check((2, 2, 10, 10, 2), 8, (2, 2))


if __name__ == "__main__":
    test_fold_fwd_simple()
    test_fold_fwd_dual_path()
    test_fold_fwd_fail()
    test_fold_fwd_relu_fail()
    test_fold_fwd_negative_scale()
    test_fold_bwd_simple()
    test_fold_bwd_dual_path()
    test_fold_bwd_dual_consumer()
    test_fold_bwd_fail()
    test_fold_bwd_relu_fail()
    test_fold_bwd_negative_scale()