# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
"""Conv2D operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm

from .pad import pad
from .util import get_pad_tuple
from ..util import simplify, get_const_tuple
from .winograd_util import winograd_transform_matrices

# workload description of conv2d
Workload = namedtuple('Workload',
                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups',
                       'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])


@tvm.target.generic_func
def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None):
    """Conv2D operator.

    Parameters
    ----------
    input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    filter : tvm.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        layout of data

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    # search platform specific declaration first
    # default declaration
    if layout == 'NCHW':
        return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
    elif layout == 'HWCN':
        return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
    elif layout == 'NHWC':
        return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
    raise ValueError("not support this layout {} yet".format(layout))


@tvm.target.generic_func
def conv2d_legalize(attrs, inputs, types):
    """Legalizes Conv2D op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    # not to change by default
    return None


@tvm.target.generic_func
def conv2d_alter_layout(attrs, inputs, tinfos, F):
    """Change Conv2D layout.

    Parameters
    ----------
    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
        Attributes of current convolution
    inputs : nnvm.symbol or tvm.relay.Expr
        Grouped input symbols
    tinfos : list
        Input shape and dtype
    F: symbol
        The context, can be either nnvm.sym or relay.op

    Note
    ----
    Unlike other TOPI functions, this function operates on both graph level and operator level,
    so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
    """
    # not to change by default
    return None


@tvm.target.generic_func
def conv2d_infer_layout(workload, cfg):
    """Infer input/output shapes and layouts from a workload and cfg.

    Parameters
    ----------
    workload : tuple
        conv2d workload

    cfg : tuple
        tvm.autotvm config

    Returns
    -------
    Output : [tuple of tuple and str, tuple of tuple and str]
        Input shapes and layouts, and output shapes and layouts
    """
    raise ValueError("missing register for topi.nn.conv2d_infer_layout")


def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
    """ Get the workload structure. """
    if data_layout == 'NCHW':
        _, CI, IH, IW = [x.value for x in data.shape]
    elif data_layout == 'NHWC':
        _, IH, IW, CI = [x.value for x in data.shape]
    elif data_layout == 'HWCN':
        IH, IW, CI, _ = [x.value for x in data.shape]
    else:
        raise ValueError("not support this layout {} yet".format(data_layout))

    if data_layout == 'NCHW':
        CO, CIG, KH, KW = [x.value for x in kernel.shape]
    else:
        KH, KW, CIG, CO = [x.value for x in kernel.shape]

    # get_pad_tuple expects the spatial kernel size (KH, KW), not the kernel
    # tensor itself; passing the tensor gave wrong results for 'SAME' padding.
    HPAD, WPAD, _, _ = get_pad_tuple(padding, (KH, KW))
    GRPS = CI // CIG
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride
    # uint8 data with int8 kernel is the one allowed mixed-dtype combination
    # (used by the int8 fast paths); anything else must match exactly.
    assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
        "Do not support inputs with different data types now. " \
        "{} vs. {}".format(data.dtype, kernel.dtype)
    return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)


def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
    """Convolution operator in NCHW layout.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    Filter : tvm.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_channel, in_height, in_width = Input.shape
    num_filter, channel, kernel_h, kernel_w = Filter.shape
    # compute the output shape; padding is derived from the dilated kernel
    # extent so 'SAME' behaves correctly when dilation > 1
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    # compute graph
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    temp = pad(Input, pad_before, pad_after, name="pad_temp")
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')

    return tvm.compute(
        (batch, out_channel, out_height, out_width),
        lambda nn, ff, yy, xx: tvm.sum(
            temp[nn, rc, yy * stride_h + ry * dilation_h,
                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
            Filter[ff, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx]), tag="conv2d_nchw")


def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
    """Convolution operator in HWCN layout.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [in_height, in_width, in_channel, batch]

    Filter : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [out_height, out_width, out_channel, batch]
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2

    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    in_height, in_width, in_channel, batch = Input.shape
    kernel_h, kernel_w, channel, num_filter = Filter.shape
    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    pad_before = [pad_top, pad_left, 0, 0]
    pad_after = [pad_down, pad_right, 0, 0]
    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    Output = tvm.compute(
        (out_height, out_width, out_channel, batch),
        lambda yy, xx, ff, nn: tvm.sum(
            PaddedInput[yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w,
                        rc, nn].astype(out_dtype) *
            Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
        name="Conv2dOutput", tag="conv2d_hwcn")
    return Output


def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
    """Convolution operator in NHWC layout.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    Filter : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    # NOTE(review): unlike the NCHW/HWCN variants, out_dtype defaults to
    # 'float32' rather than Input.dtype; kept as-is for backward compatibility.
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2

    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_height, in_width, in_channel = Input.shape
    kernel_h, kernel_w, channel, num_filter = Filter.shape
    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    pad_before = [0, pad_top, pad_left, 0]
    pad_after = [0, pad_down, pad_right, 0]
    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    Output = tvm.compute(
        (batch, out_height, out_width, out_channel),
        lambda nn, yy, xx, ff: tvm.sum(
            PaddedInput[nn, yy * stride_h + ry * dilation_h,
                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
            Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
        name="Conv2dOutput", tag="conv2d_nhwc")
    return Output


@tvm.target.generic_func
def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype='float32'):
    """Conv2D operator for nChw[x]c layout.

    Parameters
    ----------
    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        6-D with shape
        [num_filter_chunk, in_channel_chunk, filter_height, filter_width,
        in_channel_block, num_filter_block]

    stride : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    out_dtype : str
        output data type

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """

    return conv2d_NCHWc_compute(data,
                                kernel,
                                stride,
                                padding,
                                dilation,
                                layout,
                                out_layout,
                                out_dtype)


def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_layout, out_dtype):
    """Conv2D operator compute for nChw[x]c layout.

    Parameters
    ----------
    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        6-D with shape
        [num_filter_chunk, in_channel_chunk, filter_height, filter_width,
        in_channel_block, num_filter_block]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    out_dtype : str
        output data type

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """

    # layout and out_layout are not used here,
    # we keep them for debug convenience when dumping autotvm workload
    HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
        else (dilation, dilation)

    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
    in_channel = ic_chunk * ic_bn
    # side effect: raises if no target is set; the value itself is unused
    target = tvm.target.current_target(allow_none=False)
    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
        get_const_tuple(kernel.shape)
    num_filter = oc_chunk * oc_bn
    groups = ic_chunk // ic_chunk_group

    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1

    # output shape
    out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
    out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
    oshape = (n, oc_chunk, out_height, out_width, oc_bn)

    # DOPAD
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
    else:
        data_pad = data

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod

    # cast both operands to out_dtype so accumulation happens in out_dtype,
    # consistent with conv2d_NCHWc_int8_compute
    return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_pad[n,
                                        idxdiv(ic, ic_bn),
                                        oh * HSTR + kh * dilation_h,
                                        ow * WSTR + kw * dilation_w,
                                        idxmod(ic, ic_bn)].astype(out_dtype)
                               * kernel[oc_chunk,
                                        idxdiv(ic, ic_bn),
                                        kh,
                                        kw,
                                        idxmod(ic, ic_bn),
                                        oc_block].astype(out_dtype),
                               axis=[ic, kh, kw]),
                       name='conv2d_NCHWc', tag="conv2d_NCHWc")


@tvm.target.generic_func
def conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, layout, out_layout,
                      out_dtype='int32'):
    """Conv2D operator for nChw[x]c layout.

    Parameters
    ----------
    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        7-D with shape
        [num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4,
        num_filter_block, 4]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    out_dtype : str
        output data type

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """

    return conv2d_NCHWc_int8_compute(data,
                                     kernel,
                                     strides,
                                     padding,
                                     dilation,
                                     layout,
                                     out_layout,
                                     out_dtype)


def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, out_layout,
                              out_dtype='int32'):
    """Conv2D operator for nChw[x]c layout.

    Parameters
    ----------
    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        7-D with shape
        [num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4,
        num_filter_block, 4]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    out_dtype : str
        output data type

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """

    # layout and out_layout are not used here,
    # we keep them for debug convenience when dumping autotvm workload
    HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
        else (dilation, dilation)

    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
    in_channel = ic_chunk * ic_bn
    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
        get_const_tuple(kernel.shape)
    num_filter = oc_chunk * oc_bn
    groups = ic_chunk // ic_chunk_group

    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1

    # output shape
    out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
    out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
    oshape = (n, oc_chunk, out_height, out_width, oc_bn)

    # DOPAD
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
    else:
        data_pad = data

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    if groups == 1:
        # the innermost block of 4 input channels is reduced together so the
        # schedule can map it onto int8 dot-product instructions
        n_elems = 4
        ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
        ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
        ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
        return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                           tvm.sum(data_pad[n,
                                            ic_outer,
                                            oh * HSTR + kh * dilation_h,
                                            ow * WSTR + kw * dilation_w,
                                            ic_f_inner * n_elems + ic_s_inner].astype(out_dtype)
                                   * kernel[oc_chunk,
                                            ic_outer,
                                            kh,
                                            kw,
                                            ic_f_inner,
                                            oc_block,
                                            ic_s_inner].astype(out_dtype),
                                   axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                           name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
    # for int8 group conv support
    n_elems = 4
    ic_chunk = in_channel//ic_bn
    ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer')
    ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
    ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
    oshape = (n, oc_chunk, out_height, out_width, oc_bn)
    # apply dilation to the spatial index, matching the groups == 1 branch
    # (the output shape above is already computed from the dilated kernel)
    return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block:
                       tvm.sum(data_pad[n,
                                        # map the output channel to the first
                                        # input-channel chunk of its group
                                        (occ * oc_bn // (oc_chunk * oc_bn // groups))
                                        * (ic_chunk // groups) + ic_outer,
                                        oh * HSTR + kh * dilation_h,
                                        ow * WSTR + kw * dilation_w,
                                        ic_f_inner * n_elems + ic_s_inner].astype(out_dtype)
                               * kernel[occ,
                                        ic_outer,
                                        kh,
                                        kw,
                                        ic_f_inner,
                                        oc_block,
                                        ic_s_inner].astype(out_dtype),
                               axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                       name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")


def conv2d_winograd_weight_transform(kernel, tile_size):
    """Weight transformation for winograd

    Parameters
    ----------
    kernel: Tensor
        The raw kernel tensor with layout "NCHW".
    tile_size: int
        Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [alpha, alpha, CO, CI]
    """
    shape = get_const_tuple(kernel.shape)
    assert shape[2] == shape[3], "Only support NxN kernel"

    K = shape[3]
    r = tile_size + K - 1
    shape = (r, r) + shape[:2]

    _, _, G = winograd_transform_matrices(tile_size, K, kernel.dtype)

    r_kh = tvm.reduce_axis((0, K), name='r_kh')
    r_kw = tvm.reduce_axis((0, K), name='r_kw')
    # U = G * g * G^T, computed per (CO, CI) filter
    return tvm.compute(shape, lambda eps, nu, co, ci:
                       tvm.sum(kernel[co][ci][r_kh][r_kw] *
                               G[eps][r_kh] * G[nu][r_kw],
                               axis=[r_kh, r_kw]), name='transform_weight')


@tvm.target.generic_func
def conv2d_winograd_without_weight_transform(input, filter, strides, padding, dilation,
                                             layout, out_dtype, tile_size):
    """Compute convolution in winograd algorithm. The filter is supposed to be transformed
    in advance.

    Parameters
    ----------
    input : tvm.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]
    filter : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter]
    strides : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]
    padding : int or str
        Padding size, or ['VALID', 'SAME']
    tile_size: int
        Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")


def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm, out_dtype):
    """Weight transformation for winograd
    Parameters
    ----------
    kernel: Tensor
        The raw kernel tensor with layout "NCHW". Only 3x3 kernel is supported for now.
    convolution_algorithm: int
        The convolution algorithm for Winograd NNPACK.
    Returns
    -------
    output : tvm.Tensor
        4-D with shape [alpha, alpha, CO, CI]
    """
    # imported lazily so NNPACK is only required when this path is used
    from tvm.contrib import nnpack
    return nnpack.convolution_inference_weight_transform(
        kernel, algorithm=convolution_algorithm, dtype=out_dtype)


@tvm.target.generic_func
def conv2d_winograd_nnpack_without_weight_transform(
        input, filter, bias, strides, padding, dilation, layout, out_dtype):
    """Compute convolution in winograd algorithm. The filter is supposed to be transformed
    in advance.
    Parameters
    ----------
    input : tvm.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]
    filter : tvm.Tensor
        4-D with shape [num_filter, in_channel, 8, 8]
    bias : tvm.Tensor
        1-D with shape [num_filter]
    strides : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]
    padding : int or str
        Padding size, or ['VALID', 'SAME']
    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    # error message names this function (was a copy-paste of the non-nnpack one)
    raise ValueError(
        "missing register for topi.nn.conv2d_winograd_nnpack_without_weight_transform")


@tvm.target.generic_func
def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
    """Group convolution operator in NCHW layout.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    Filter : tvm.Tensor
        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    groups : int
        number of groups

    out_dtype : str
        The output type. This is used for mixed precision.

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_channel, in_height, in_width = get_const_tuple(Input.shape)
    num_filter, _, kernel_h, kernel_w = get_const_tuple(Filter.shape)

    assert in_channel % groups == 0, "input channels must divide group size"
    assert num_filter % groups == 0, "output channels must divide group size"

    # derive padding from the dilated kernel extent, consistent with
    # conv2d_nchw, so 'SAME' padding is correct when dilation > 1
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    # compute the output shape
    out_channel = num_filter
    out_height = simplify(
        (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify(
        (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    # compute graph
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    temp = pad(Input, pad_before, pad_after, name="pad_temp")
    rc = tvm.reduce_axis((0, in_channel // groups), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    return tvm.compute(
        (batch, out_channel, out_height, out_width),
        lambda nn, ff, yy, xx: tvm.sum(
            # ff // (num_filter // groups) selects the group; offset into that
            # group's slice of input channels
            temp[nn, ff // (num_filter//groups) * (in_channel//groups) + rc,
                 yy * stride_h + ry * dilation_h,
                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
            Filter[ff, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx]), tag='group_conv2d_nchw')