1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18import tvm 19from tvm import te 20import numpy as np 21from tvm import relay 22from tvm.relay import transform 23from tvm.relay.testing import run_infer_type 24from tvm.contrib import graph_runtime 25from tvm.relay.testing.temp_op_attr import TempOpAttr 26 27# We use llvm target for testing functionality. `llvm` points to an older Intel 28# generation machine, that legalizes to a simple lowering. Therefore, the 29# legalization is overwritten such that it can be skipped and we use the 30# QNNCanonicalizeOps lowering for the testing. 
def legalize_qnn_conv2d(attrs, inputs, types):
    """Legalization override that returns None so no legalization is applied.

    Installed via TempOpAttr in each test so qnn.conv2d falls through to the
    QNNCanonicalizeOps lowering (see the comment at the top of this file).
    """
    return None


def get_ref_func(
    data,
    kernel,
    input_zero_point,
    kernel_zero_point,
    input_scale,
    kernel_scale,
    kernel_size,
    padding,
    strides,
    dilation,
    data_layout,
    kernel_layout,
    out_dtype,
    groups,
    channels=None,
):
    """Build the reference Relay function for qnn.conv2d.

    Emulates the quantized conv by casting both operands to int32,
    subtracting the zero points, and running a plain nn.conv2d.
    NOTE(review): input_scale/kernel_scale are accepted but unused here —
    presumably kept for signature symmetry with get_qnn_func; the tests
    below all use scale 1.0 so outputs still match exactly.
    """
    casted_data = relay.op.cast(data, "int32")
    casted_kernel = relay.op.cast(kernel, "int32")
    shifted_data = relay.op.subtract(casted_data, relay.const(input_zero_point, "int32"))
    shifted_kernel = relay.op.subtract(casted_kernel, relay.const(kernel_zero_point, "int32"))
    func = relay.op.nn.conv2d(
        shifted_data,
        shifted_kernel,
        padding=padding,
        strides=strides,
        dilation=dilation,
        groups=groups,
        channels=channels,
        kernel_size=kernel_size,
        out_dtype=out_dtype,
        data_layout=data_layout,
        kernel_layout=kernel_layout,
    )

    func = relay.Function(relay.analysis.free_vars(func), func)
    return func


def get_qnn_func(
    data,
    kernel,
    input_zero_point,
    kernel_zero_point,
    input_scale,
    kernel_scale,
    kernel_size,
    padding,
    strides,
    dilation,
    data_layout,
    kernel_layout,
    out_dtype,
    channels,
    groups,
):
    """Build an IRModule wrapping a single qnn.conv2d with the given params.

    Zero points and scales are wrapped as Relay constants (int32/float32).
    Returns a tvm.IRModule, not a bare relay.Function.
    """
    func = relay.qnn.op.conv2d(
        data,
        kernel,
        input_zero_point=relay.const(input_zero_point, "int32"),
        kernel_zero_point=relay.const(kernel_zero_point, "int32"),
        input_scale=relay.const(input_scale, "float32"),
        kernel_scale=relay.const(kernel_scale, "float32"),
        kernel_size=kernel_size,
        strides=strides,
        dilation=dilation,
        padding=padding,
        out_dtype=out_dtype,
        groups=groups,
        channels=channels,
        data_layout=data_layout,
        kernel_layout=kernel_layout,
    )

    mod = relay.Function(relay.analysis.free_vars(func), func)
    mod = tvm.IRModule.from_expr(mod)
    return mod


def get_funcs(
    data_shape,
    data_dtype,
    kernel_shape,
    kernel_dtype,
    input_zero_point,
    kernel_zero_point,
    input_scale,
    kernel_scale,
    kernel_size,
    padding,
    strides,
    dilation,
    data_layout,
    kernel_layout,
    out_dtype,
    groups=1,
    channels=None,
):
    """Return (ref_func, qnn_func) IRModules sharing the same data/kernel vars.

    ref_func is the int32 emulation from get_ref_func; qnn_func is the real
    qnn.conv2d. The two are compared for exact equality by verify().
    """
    data = relay.var("data", shape=data_shape, dtype=data_dtype)
    kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype)

    ref_func = get_ref_func(
        data,
        kernel,
        input_zero_point,
        kernel_zero_point,
        input_scale,
        kernel_scale,
        kernel_size,
        padding,
        strides,
        dilation,
        data_layout,
        kernel_layout,
        out_dtype,
        groups,
        channels,
    )
    ref_func = run_infer_type(ref_func)
    ref_func = tvm.IRModule.from_expr(ref_func)
    qnn_func = get_qnn_func(
        data,
        kernel,
        input_zero_point,
        kernel_zero_point,
        input_scale,
        kernel_scale,
        kernel_size,
        padding,
        strides,
        dilation,
        data_layout,
        kernel_layout,
        out_dtype,
        channels,
        groups,
    )

    return (ref_func, qnn_func)


def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype):
    """Run ref_func and qnn_func on identical random inputs and assert equality."""

    def get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype):
        # Generate random data/weight in the full range of the dtype
        # (signed by default, shifted to unsigned for uint8 inputs).
        # Keeping inputs multiple of 4 because of a bug in Average Pool2d
        # https://discuss.tvm.ai/t/pool2d-gives-bad-output-for-integer-inputs/3377
        low = -128
        high = 127
        if data_dtype == "uint8":
            low = 0
            high = 255
        golden_data = np.random.randint(low=low, high=high, size=data_shape).astype(data_dtype)
        low = -128
        high = 127
        if kernel_dtype == "uint8":
            low = 0
            high = 255
        golden_weight = np.random.randint(low=low, high=high, size=kernel_shape).astype(
            kernel_dtype
        )
        return (golden_data, golden_weight)

    def get_output(func, golden_inputs):
        # Build for llvm, bind the weight as a param, run on CPU, return output.
        with tvm.transform.PassContext(opt_level=2):
            golden_data, golden_weight = golden_inputs
            params = {"kernel": golden_weight}
            graph, lib, params = relay.build(func, "llvm", params=params)
            mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
            mod.set_input("data", golden_data)
            mod.set_input(**params)
            mod.run()
            res = mod.get_output(0).asnumpy()
            return res

    golden_inputs = get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype)
    golden_output = get_output(ref_func, golden_inputs)
    qnn_output = get_output(qnn_func, golden_inputs)
    # Integer arithmetic throughout, so results must match bit-exactly.
    np.testing.assert_equal(qnn_output, golden_output)


def test_no_zero_point():
    """qnn.conv2d with both zero points zero, uint8 and int8 inputs."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 1, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 1, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=0,
            kernel_zero_point=0,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # int8 input
        data_shape = (2, 1, 2, 4)
        data_dtype = "int8"
        kernel_shape = (3, 1, 2, 2)
        kernel_dtype = "int8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=0,
            kernel_zero_point=0,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_kernel_zero_point():
    """qnn.conv2d with a non-zero kernel zero point only."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=0,
            kernel_zero_point=1,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # int8 input
        data_shape = (2, 1, 2, 4)
        data_dtype = "int8"
        kernel_shape = (3, 1, 2, 2)
        kernel_dtype = "int8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=0,
            kernel_zero_point=5,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_input_zero_point():
    """qnn.conv2d with a non-zero input zero point only."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=0,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # int8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "int8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "int8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=0,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_both_zero_point():
    """qnn.conv2d with both input and kernel zero points non-zero."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # int8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "int8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "int8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_layout():
    """qnn.conv2d with NHWC data layout and HWIO / HWOI kernel layouts."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 2, 4, 4)  # NHWC
        data_dtype = "uint8"
        kernel_shape = (2, 2, 4, 3)  # HWIO
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # NHWC and HWOI layout. Used in depthwise conv.
        data_shape = (2, 2, 4, 3)  # NHWC
        data_dtype = "uint8"
        kernel_shape = (2, 2, 3, 1)  # HWOI
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            groups=3,
            data_layout="NHWC",
            kernel_layout="HWOI",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_padding():
    """qnn.conv2d with symmetric and asymmetric padding, NCHW and NHWC."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (1, 4, 2, 2)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=8,
            kernel_zero_point=5,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(1, 1),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # Try different layout
        data_shape = (2, 2, 4, 4)  # NHWC
        data_dtype = "uint8"
        kernel_shape = (2, 2, 4, 3)  # HWIO
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=8,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(1, 1),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # Try asymmetric padding
        data_shape = (2, 2, 4, 4)  # NHWC
        data_dtype = "uint8"
        kernel_shape = (2, 2, 4, 3)  # HWIO
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=8,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(1, 1, 2, 2),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_dilation():
    """qnn.conv2d with dilation, with and without zero points."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # Non-zero kernel point - fall back to simpler lowering.
        data_shape = (2, 4, 4, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(2, 2),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # Zero kernel point
        data_shape = (2, 4, 4, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=0,
            kernel_zero_point=0,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(2, 2),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_const_folding():
    """FoldConstant on a qnn.conv2d with a constant kernel must leave no reshape."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 2, 2)
        kernel_dtype = "uint8"

        golden_weight = np.random.randint(low=0, high=255, size=kernel_shape).astype(kernel_dtype)
        data = relay.var("data", shape=data_shape, dtype=data_dtype)
        kernel = relay.const(golden_weight)
        qnn_func = get_qnn_func(
            data,
            kernel,
            input_zero_point=8,
            kernel_zero_point=3,
            kernel_size=(2, 2),
            input_scale=1.0,
            kernel_scale=1.0,
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
            channels=kernel_shape[0],
            groups=1,
        )
        folded_mod = transform.FoldConstant()(qnn_func)
        folded_func = folded_mod["main"]
        assert "reshape" not in folded_func.astext()


def test_kernel_size_1x1():
    """1x1 kernel must not lower through avg_pool2d."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 1, 1)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(1, 1),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        assert "avg_pool2d" not in qnn_func.astext()
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_kernel_size_1x1_strides_2():
    """1x1 kernel with stride 2 must not lower through avg_pool2d."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 4, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 4, 1, 1)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(1, 1),
            padding=(0, 0),
            strides=(2, 2),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        assert "avg_pool2d" not in qnn_func.astext()
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_tflite_large_irregular():
    """TFLite-style large irregular shape: all-127 inputs with zp 127 give zeros."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (1, 1024, 1, 1)
        data_dtype = "uint8"
        kernel_shape = (1001, 1024, 1, 1)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=127,
            kernel_zero_point=127,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(1, 1),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        golden_data = np.full(data_shape, 127).astype("uint8")
        golden_weight = np.full(kernel_shape, 127).astype("uint8")

        with tvm.transform.PassContext(opt_level=2):
            params = {"kernel": golden_weight}
            graph, lib, params = relay.build(qnn_func, "llvm", params=params)
            mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
            mod.set_input("data", golden_data)
            mod.set_input(**params)
            mod.run()
            qnn_output = mod.get_output(0).asnumpy()
        # (data - 127) is all zeros, so every output element is zero.
        golden_output = np.full((1, 1001, 1, 1), 0).astype("uint8")
        np.testing.assert_equal(qnn_output, golden_output)


def test_tflite_output_multiplier_greater_than_one():
    """TFLite-style case with hand-computed golden output, zp 128 on both sides."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (2, 1, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 1, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_scale=1.0,
            kernel_scale=1.0,
            input_zero_point=128,
            kernel_zero_point=128,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(2, 2),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        # Inputs are stored offset by the zero point (128).
        golden_data = 128 + np.array((1, 1, 1, 1, 2, 2, 2, 2, 1, 2, 3, 4, 1, 2, 3, 4)).reshape(
            data_shape
        ).astype("uint8")
        golden_weight = 128 + np.array((1, 2, 3, 4, -1, 1, -1, 1, -1, -1, 1, 1)).reshape(
            kernel_shape
        )
        golden_weight = golden_weight.astype("uint8")

        with tvm.transform.PassContext(opt_level=2):
            params = {"kernel": golden_weight}
            graph, lib, params = relay.build(qnn_func, "llvm", params=params)
            mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
            mod.set_input("data", golden_data)
            mod.set_input(**params)
            mod.run()
            qnn_output = mod.get_output(0).asnumpy()
        golden_output = np.array((17, 17, 0, 0, 2, 2, 16, 36, 2, 2, 0, 0)).reshape(2, 3, 1, 2)
        np.testing.assert_equal(qnn_output, golden_output)


def test_tflite_anistropic_strides():
    """TFLite-style case with anisotropic strides (1, 3) and golden output."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input
        data_shape = (1, 1, 3, 6)
        data_dtype = "uint8"
        kernel_shape = (1, 1, 2, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=127,
            kernel_zero_point=127,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(2, 2),
            padding=(0, 0),
            strides=(1, 3),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )
        golden_data = np.array(
            (
                133,
                131,
                129,
                125,
                123,
                121,
                135,
                133,
                131,
                123,
                121,
                119,
                137,
                135,
                133,
                121,
                119,
                117,
            )
        ).reshape(data_shape)
        golden_data = golden_data.astype("uint8")
        golden_weight = np.array((129, 131, 133, 135)).reshape(kernel_shape)
        golden_weight = golden_weight.astype("uint8")

        with tvm.transform.PassContext(opt_level=2):
            params = {"kernel": golden_weight}
            graph, lib, params = relay.build(qnn_func, "llvm", params=params)
            mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
            mod.set_input("data", golden_data)
            mod.set_input(**params)
            mod.run()
            qnn_output = mod.get_output(0).asnumpy()
        golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
        np.testing.assert_equal(qnn_output, golden_output)


def test_broadcast_layout():
    """Broadcast adds around a NHWC qnn.conv2d must compile (lhs and rhs)."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # Test broadcast support for NHWC layout.
        data_shape = (1, 229, 229, 3)  # NHWC
        data_dtype = "uint8"
        kernel_shape = (7, 7, 3, 64)  # HWIO
        kernel_dtype = "int8"
        _, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=8,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(7, 7),
            padding=(1, 1),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
            out_dtype="int32",
        )
        func = qnn_func["main"].body
        bias = relay.var("bias", shape=(64,), dtype="int32")
        bias2 = relay.var("bias2", shape=(1, 225, 225, 1), dtype="int32")

        # Check broadcast support on both lhs and rhs
        func = relay.add(func, bias2)
        func = relay.add(bias2, func)
        func = relay.add(bias, func)
        func = relay.add(func, bias)
        func = relay.Function(relay.analysis.free_vars(func), func)
        mod = tvm.IRModule.from_expr(func)
        # Compile-only check; no execution.
        with tvm.transform.PassContext(opt_level=3):
            graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")


def test_depthwise_depth_multiplier():
    """Depthwise qnn.conv2d with depth multipliers 1 and 2, NCHW and NHWC."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

        # uint8 input, NCHW and OIHW
        # Depthwise multiplier = 1
        data_shape = (2, 4, 16, 16)
        data_dtype = "uint8"
        kernel_shape = (4, 1, 3, 3)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(3, 3),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
            groups=4,
        )

        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # Depthwise multiplier = 2
        data_shape = (10, 4, 16, 16)
        data_dtype = "uint8"
        kernel_shape = (4, 2, 3, 3)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(3, 3),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
            groups=4,
            channels=8,
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # uint8 input, NHWC and HWOI
        # Depthwise multiplier = 1
        data_shape = (2, 16, 16, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 3, 4, 1)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(3, 3),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWOI",
            out_dtype="int32",
            groups=4,
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)

        # Depthwise multiplier = 2
        data_shape = (2, 16, 16, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 3, 4, 2)
        kernel_dtype = "uint8"
        ref_func, qnn_func = get_funcs(
            data_shape=data_shape,
            data_dtype=data_dtype,
            kernel_shape=kernel_shape,
            kernel_dtype=kernel_dtype,
            input_zero_point=5,
            kernel_zero_point=3,
            input_scale=1.0,
            kernel_scale=1.0,
            kernel_size=(3, 3),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWOI",
            out_dtype="int32",
            groups=4,
            channels=8,
        )
        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)


def test_per_channel_kernel_scale():
    """qnn.conv2d with a per-channel kernel_scale vector must type-check/build IR."""
    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
        data_shape = (2, 1, 2, 4)
        data_dtype = "uint8"
        kernel_shape = (3, 1, 2, 2)
        kernel_dtype = "uint8"
        data = relay.var("data", shape=data_shape, dtype=data_dtype)
        kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype)
        # One scale per output channel (kernel_shape[0] == 3).
        kernel_scales = [2, 2, 2]
        kernel_scales = relay.const(np.array(kernel_scales).astype("float32"))
        func = relay.qnn.op.conv2d(
            data,
            kernel,
            input_zero_point=relay.const(0, "int32"),
            kernel_zero_point=relay.const(0, "int32"),
            input_scale=relay.const(2.0, "float32"),
            kernel_scale=kernel_scales,
            kernel_size=(2, 2),
            channels=kernel_shape[0],
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="int32",
        )

        mod = relay.Function(relay.analysis.free_vars(func), func)
        mod = tvm.IRModule.from_expr(mod)


if __name__ == "__main__":
    test_no_zero_point()
    test_input_zero_point()
    test_kernel_zero_point()
    test_both_zero_point()
    test_layout()
    test_padding()
    test_dilation()
    test_const_folding()
    test_kernel_size_1x1()
    test_kernel_size_1x1_strides_2()
    test_tflite_large_irregular()
    test_broadcast_layout()
    test_tflite_output_multiplier_greater_than_one()
    test_tflite_anistropic_strides()
    test_depthwise_depth_multiplier()
    test_per_channel_kernel_scale()