# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import te
import numpy as np
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_infer_type
from tvm.contrib import graph_runtime
from tvm.relay.testing.temp_op_attr import TempOpAttr

# We use the llvm target for functionality testing. Plain `llvm` points to an
# older Intel generation machine, which legalizes to a simple lowering. The
# legalization is therefore overridden below so that it is skipped and the
# QNNCanonicalizeOps lowering is used for the testing.
def legalize_qnn_conv2d(attrs, inputs, types):
    return None


def get_ref_func(
    data,
    kernel,
    input_zero_point,
    kernel_zero_point,
    input_scale,
    kernel_scale,
    kernel_size,
    padding,
    strides,
    dilation,
    data_layout,
    kernel_layout,
    out_dtype,
    groups,
    channels=None,
):
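    """Return a reference function that mimics qnn.conv2d with plain Relay ops.

    The QNN convolution is computed as

        conv2d(cast(data, "int32") - input_zero_point,
               cast(kernel, "int32") - kernel_zero_point)

    The scales do not affect the int32 accumulator, so they are accepted but
    unused here. As a minimal worked example (numbers chosen purely for
    illustration): with input_zero_point=1 and kernel_zero_point=2, a 1x1
    convolution of data [[3]] with kernel [[4]] accumulates
    (3 - 1) * (4 - 2) = 4 in int32.
    """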
    casted_data = relay.op.cast(data, "int32")
    casted_kernel = relay.op.cast(kernel, "int32")
    shifted_data = relay.op.subtract(casted_data, relay.const(input_zero_point, "int32"))
    shifted_kernel = relay.op.subtract(casted_kernel, relay.const(kernel_zero_point, "int32"))
    func = relay.op.nn.conv2d(
        shifted_data,
        shifted_kernel,
        padding=padding,
        strides=strides,
        dilation=dilation,
        groups=groups,
        channels=channels,
        kernel_size=kernel_size,
        out_dtype=out_dtype,
        data_layout=data_layout,
        kernel_layout=kernel_layout,
    )

    func = relay.Function(relay.analysis.free_vars(func), func)
    return func


def get_qnn_func(
    data,
    kernel,
    input_zero_point,
    kernel_zero_point,
    input_scale,
    kernel_scale,
    kernel_size,
    padding,
    strides,
    dilation,
    data_layout,
    kernel_layout,
    out_dtype,
    channels,
    groups,
):
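    """Build an IRModule containing a single qnn.conv2d call with the given
    quantization parameters, wrapping the zero points and scales as Relay
    constants."""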
    func = relay.qnn.op.conv2d(
        data,
        kernel,
        input_zero_point=relay.const(input_zero_point, "int32"),
        kernel_zero_point=relay.const(kernel_zero_point, "int32"),
        input_scale=relay.const(input_scale, "float32"),
        kernel_scale=relay.const(kernel_scale, "float32"),
        kernel_size=kernel_size,
        strides=strides,
        dilation=dilation,
        padding=padding,
        out_dtype=out_dtype,
        groups=groups,
        channels=channels,
        data_layout=data_layout,
        kernel_layout=kernel_layout,
    )

    func = relay.Function(relay.analysis.free_vars(func), func)
    mod = tvm.IRModule.from_expr(func)
    return mod


def get_funcs(
    data_shape,
    data_dtype,
    kernel_shape,
    kernel_dtype,
    input_zero_point,
    kernel_zero_point,
    input_scale,
    kernel_scale,
    kernel_size,
    padding,
    strides,
    dilation,
    data_layout,
    kernel_layout,
    out_dtype,
    groups=1,
    channels=None,
):
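    """Create data/kernel variables for the given shapes and return the pair
    (reference module, qnn module); the two are expected to produce identical
    int32 outputs for the same inputs."""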
    data = relay.var("data", shape=data_shape, dtype=data_dtype)
    kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype)

    ref_func = get_ref_func(
        data,
        kernel,
        input_zero_point,
        kernel_zero_point,
        input_scale,
        kernel_scale,
        kernel_size,
        padding,
        strides,
        dilation,
        data_layout,
        kernel_layout,
        out_dtype,
        groups,
        channels,
    )
    ref_func = run_infer_type(ref_func)
    ref_func = tvm.IRModule.from_expr(ref_func)
    qnn_func = get_qnn_func(
        data,
        kernel,
        input_zero_point,
        kernel_zero_point,
        input_scale,
        kernel_scale,
        kernel_size,
        padding,
        strides,
        dilation,
        data_layout,
        kernel_layout,
        out_dtype,
        channels,
        groups,
    )

    return (ref_func, qnn_func)


def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype):
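    """Run both modules on the same random inputs and check that their outputs
    match exactly."""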
    def get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype):
        # Keeping inputs multiple of 4 because of a bug in Average Pool2d
        # https://discuss.tvm.ai/t/pool2d-gives-bad-output-for-integer-inputs/3377
        # np.random.randint's upper bound is exclusive, so pass one past the
        # largest representable value of the dtype.
        low = -128
        high = 128
        if data_dtype == "uint8":
            low = 0
            high = 256
        golden_data = np.random.randint(low=low, high=high, size=data_shape).astype(data_dtype)
        low = -128
        high = 128
        if kernel_dtype == "uint8":
            low = 0
            high = 256
        golden_weight = np.random.randint(low=low, high=high, size=kernel_shape).astype(
            kernel_dtype
        )
        return (golden_data, golden_weight)

    def get_output(func, golden_inputs):
        with tvm.transform.PassContext(opt_level=2):
            golden_data, golden_weight = golden_inputs
            params = {"kernel": golden_weight}
            graph, lib, params = relay.build(func, "llvm", params=params)
            mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
            mod.set_input("data", golden_data)
            mod.set_input(**params)
            mod.run()
            res = mod.get_output(0).asnumpy()
            return res

    golden_inputs = get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype)
    golden_output = get_output(ref_func, golden_inputs)
    qnn_output = get_output(qnn_func, golden_inputs)
    np.testing.assert_equal(qnn_output, golden_output)


def test_no_zero_point():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 1, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 1, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=0,
            kernel_zero_point=0,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # int8 input
        data_shape = (2, 1, 2, 4)
        data_dtype = "int8"
        kernel_shape = (3, 1, 2, 2)
        kernel_dtype = "int8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=0,
            kernel_zero_point=0,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_kernel_zero_point():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=0,
            kernel_zero_point=1,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # int8 input
        data_shape = (2, 1, 2, 4)
        data_dtype = "int8"
        kernel_shape = (3, 1, 2, 2)
        kernel_dtype = "int8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=0,
            kernel_zero_point=5,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_input_zero_point():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=0,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # int8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "int8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "int8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=0,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_both_zero_point():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # int8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "int8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "int8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_layout():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 2, 4, 4)  # NHWC
        data_dtype = "uint8"
        kernel_shape = (2, 2, 4, 3)  # HWIO
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # NHWC and HWOI layout. Used in depthwise conv.
        data_shape = (2, 2, 4, 3)  # NHWC
        data_dtype = "uint8"
        kernel_shape = (2, 2, 3, 1)  # HWOI
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            groups=3,
            data_layout="NHWC",
            kernel_layout="HWOI",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_padding():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (1, 4, 2, 2)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=8,
            kernel_zero_point=5,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(1, 1),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # Try different layout
        data_shape = (2, 2, 4, 4)  # NHWC
        data_dtype = "uint8"
        kernel_shape = (2, 2, 4, 3)  # HWIO
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=8,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(1, 1),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # Try asymmetric padding
        data_shape = (2, 2, 4, 4)  # NHWC
        data_dtype = "uint8"
        kernel_shape = (2, 2, 4, 3)  # HWIO
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=8,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(1, 1, 2, 2),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_dilation():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # Non-zero kernel zero point - falls back to the simpler lowering.
        data_shape = (2, 4, 4, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(2, 2),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # Zero kernel zero point
        data_shape = (2, 4, 4, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=0,
            kernel_zero_point=0,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(2, 2),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_const_folding():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"

        # randint's upper bound is exclusive, so 256 covers the full uint8 range.
        golden_weight = np.random.randint(low=0, high=256, size=kernel_shape).astype(kernel_dtype)
        data = relay.var("data", shape=data_shape, dtype=data_dtype)
        kernel = relay.const(golden_weight)
        qnn_func = get_qnn_func(
            data,
            kernel,
            input_zero_point=8,
            kernel_zero_point=3,
            kernel_size=(2, 2),
            input_scale=1.0,
            kernel_scale=1.0,
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
            channels=kernel_shape[0],
            groups=1,
        )
        folded_mod = transform.FoldConstant()(qnn_func)
        folded_func = folded_mod["main"]
        assert "reshape" not in folded_func.astext()


def test_kernel_size_1x1():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 1, 1)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(1, 1),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        assert "avg_pool2d" not in qnn_func.astext()
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_kernel_size_1x1_strides_2():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 1, 1)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(1, 1),
            padding=(0, 0),
            strides=(2, 2),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        assert "avg_pool2d" not in qnn_func.astext()
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_tflite_large_irregular():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (1, 1024, 1, 1)
        data_dtype = "uint8"
        kernel_shape = (1001, 1024, 1, 1)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=127,
            kernel_zero_point=127,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(1, 1),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
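        # Both tensors are filled with their zero point (127), so the shifted
        # values are all zero and the expected convolution output is all zeros.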
        golden_data = np.full(data_shape, 127).astype("uint8")
        golden_weight = np.full(kernel_shape, 127).astype("uint8")

        with tvm.transform.PassContext(opt_level=2):
            params = {"kernel": golden_weight}
            graph, lib, params = relay.build(qnn_func, "llvm", params=params)
            mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
            mod.set_input("data", golden_data)
            mod.set_input(**params)
            mod.run()
            qnn_output = mod.get_output(0).asnumpy()
        golden_output = np.full((1, 1001, 1, 1), 0).astype("int32")
        np.testing.assert_equal(qnn_output, golden_output)


def test_tflite_output_multiplier_greater_than_one():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 1, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 1, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_scale=1.0,
            kernel_scale=1.0,
            input_zero_point=128,
            kernel_zero_point=128,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(2, 2),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
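        # With a 2x2 kernel, stride 2 and no padding, the output spatial dims
        # are H_out = (2 - 2) // 2 + 1 = 1 and W_out = (4 - 2) // 2 + 1 = 2,
        # hence the (2, 3, 1, 2) golden output below.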
        golden_data = 128 + np.array((1, 1, 1, 1, 2, 2, 2, 2, 1, 2, 3, 4, 1, 2, 3, 4)).reshape(
            data_shape
        ).astype("uint8")
        golden_weight = 128 + np.array((1, 2, 3, 4, -1, 1, -1, 1, -1, -1, 1, 1)).reshape(
            kernel_shape
        )
        golden_weight = golden_weight.astype("uint8")

        with tvm.transform.PassContext(opt_level=2):
            params = {"kernel": golden_weight}
            graph, lib, params = relay.build(qnn_func, "llvm", params=params)
            mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
            mod.set_input("data", golden_data)
            mod.set_input(**params)
            mod.run()
            qnn_output = mod.get_output(0).asnumpy()
        golden_output = np.array((17, 17, 0, 0, 2, 2, 16, 36, 2, 2, 0, 0)).reshape(2, 3, 1, 2)
        np.testing.assert_equal(qnn_output, golden_output)


def test_tflite_anisotropic_strides():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (1, 1, 3, 6)
        data_dtype = "uint8"
        kernel_shape = (1, 1, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=127,
            kernel_zero_point=127,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 3),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
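        # Anisotropic strides (1, 3) give H_out = (3 - 2) // 1 + 1 = 2 and
        # W_out = (6 - 2) // 3 + 1 = 2, matching the (1, 1, 2, 2) output below.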
        golden_data = np.array(
            (
                133, 131, 129, 125, 123, 121,
                135, 133, 131, 123, 121, 119,
                137, 135, 133, 121, 119, 117,
            )
        ).reshape(data_shape)
        golden_data = golden_data.astype("uint8")
        golden_weight = np.array((129, 131, 133, 135)).reshape(kernel_shape)
        golden_weight = golden_weight.astype("uint8")

        with tvm.transform.PassContext(opt_level=2):
            params = {"kernel": golden_weight}
            graph, lib, params = relay.build(qnn_func, "llvm", params=params)
            mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
            mod.set_input("data", golden_data)
            mod.set_input(**params)
            mod.run()
            qnn_output = mod.get_output(0).asnumpy()
        golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
        np.testing.assert_equal(qnn_output, golden_output)


def test_broadcast_layout():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # Test broadcast support for NHWC layout.
        data_shape = (1, 229, 229, 3)  # NHWC
        data_dtype = "uint8"
        kernel_shape = (7, 7, 3, 64)  # HWIO
        kernel_dtype = "int8"
        _, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=8,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(7, 7),
            padding=(1, 1),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
            out_dtype="int32",
        )
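        # The conv output shape is (1, 225, 225, 64): with a 7x7 kernel,
        # padding (1, 1) and stride 1, H_out = W_out = (229 + 2 - 7) // 1 + 1
        # = 225. bias2 below is shaped to broadcast against that output.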
        func = qnn_func["main"].body
        bias = relay.var("bias", shape=(64,), dtype="int32")
        bias2 = relay.var("bias2", shape=(1, 225, 225, 1), dtype="int32")

        # Check broadcast support on both lhs and rhs
        func = relay.add(func, bias2)
        func = relay.add(bias2, func)
        func = relay.add(bias, func)
        func = relay.add(func, bias)
        func = relay.Function(relay.analysis.free_vars(func), func)
        mod = tvm.IRModule.from_expr(func)
        with tvm.transform.PassContext(opt_level=3):
            graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")


def test_depthwise_depth_multiplier():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

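        # In the depthwise cases below, channels = groups * depth multiplier;
        # the multiplier appears as the kernel's I dimension (second for OIHW,
        # last for HWOI).
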
        # uint8 input, NCHW and OIHW
        # Depthwise multiplier = 1
        data_shape = (2, 4, 16, 16)
        data_dtype = "uint8"
        kernel_shape = (4, 1, 3, 3)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(3, 3),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
            groups=4,
        )

        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # Depthwise multiplier = 2
        data_shape = (10, 4, 16, 16)
        data_dtype = "uint8"
        kernel_shape = (4, 2, 3, 3)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(3, 3),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
            groups=4,
            channels=8,
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # uint8 input, NHWC and HWOI
        # Depthwise multiplier = 1
        data_shape = (2, 16, 16, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 3, 4, 1)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(3, 3),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWOI",
            out_dtype="int32",
            groups=4,
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # Depthwise multiplier = 2
        data_shape = (2, 16, 16, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 3, 4, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(3, 3),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWOI",
            out_dtype="int32",
            groups=4,
            channels=8,
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_per_channel_kernel_scale():
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
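        # This test only checks that qnn.conv2d accepts a per-channel kernel
        # scale and that the expression converts to an IRModule; no numeric
        # result is verified.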
        data_shape = (2, 1, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 1, 2, 2)
        kernel_dtype = "uint8"
        data = relay.var("data", shape=data_shape, dtype=data_dtype)
        kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype)
        kernel_scales = [2, 2, 2]
        kernel_scales = relay.const(np.array(kernel_scales).astype("float32"))
        func = relay.qnn.op.conv2d(
            data,
            kernel,
            input_zero_point=relay.const(0, "int32"),
            kernel_zero_point=relay.const(0, "int32"),
            input_scale=relay.const(2.0, "float32"),
            kernel_scale=kernel_scales,
            kernel_size=(2, 2),
            channels=kernel_shape[0],
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )

        func = relay.Function(relay.analysis.free_vars(func), func)
        mod = tvm.IRModule.from_expr(func)


if __name__ == "__main__":
    test_no_zero_point()
    test_input_zero_point()
    test_kernel_zero_point()
    test_both_zero_point()
    test_layout()
    test_padding()
    test_dilation()
    test_const_folding()
    test_kernel_size_1x1()
    test_kernel_size_1x1_strides_2()
    test_tflite_large_irregular()
    test_broadcast_layout()
    test_tflite_output_multiplier_greater_than_one()
    test_tflite_anisotropic_strides()
    test_depthwise_depth_multiplier()
    test_per_channel_kernel_scale()