# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
"""Conv2D operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm

from .pad import pad
from .util import get_pad_tuple
from ..util import simplify, get_const_tuple
from .winograd_util import winograd_transform_matrices

# workload description of conv2d
Workload = namedtuple('Workload',
                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups',
                       'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

@tvm.target.generic_func
def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None):
    """Conv2D operator.

    Parameters
    ----------
    input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    filter : tvm.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        layout of data

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    # search platform specific declaration first
    # default declaration
    if layout == 'NCHW':
        return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
    elif layout == 'HWCN':
        return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
    elif layout == 'NHWC':
        return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
    raise ValueError("Layout {} is not supported yet".format(layout))
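
# A minimal usage sketch of the generic entry point above (the placeholder shapes and
# names are illustrative, not part of this module):
#
#   data = tvm.placeholder((1, 3, 224, 224), name='data')     # NCHW input
#   weight = tvm.placeholder((64, 3, 7, 7), name='weight')    # OIHW filter
#   out = conv2d(data, weight, strides=2, padding=3, dilation=1, layout='NCHW')
#
# With the default NCHW layout this dispatches to conv2d_nchw below and produces a
# 4-D tensor of shape [1, 64, 112, 112].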


@tvm.target.generic_func
def conv2d_legalize(attrs, inputs, types):
    """Legalizes Conv2D op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    # do not change anything by default
    return None


@tvm.target.generic_func
def conv2d_alter_layout(attrs, inputs, tinfos, F):
    """Change Conv2D layout.

    Parameters
    ----------
    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
        Attributes of current convolution
    inputs : nnvm.symbol or tvm.relay.Expr
        Grouped input symbols
    tinfos : list
        Input shape and dtype
    F: symbol
        The context, can be either nnvm.sym or relay.op

    Note
    ----
    Unlike other TOPI functions, this function operates on both graph level and operator level,
    so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
    """
    # do not change anything by default
    return None

@tvm.target.generic_func
def conv2d_infer_layout(workload, cfg):
    """Infer input/output shapes and layouts from a workload and cfg.

    Parameters
    ----------
    workload : tuple
        conv2d workload

    cfg : tuple
        tvm.autotvm config

    Returns
    -------
    Output : [tuple of tuple and str, tuple of tuple and str]
        Input shapes and layouts, and output shapes and layouts
    """
    raise ValueError("missing register for topi.nn.conv2d_infer_layout")


def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
    """ Get the workload structure. """
    if data_layout == 'NCHW':
        _, CI, IH, IW = [x.value for x in data.shape]
    elif data_layout == 'NHWC':
        _, IH, IW, CI = [x.value for x in data.shape]
    elif data_layout == 'HWCN':
        IH, IW, CI, _ = [x.value for x in data.shape]
    else:
        raise ValueError("Layout {} is not supported yet".format(data_layout))

    if data_layout == 'NCHW':
        CO, CIG, KH, KW = [x.value for x in kernel.shape]
    else:
        KH, KW, CIG, CO = [x.value for x in kernel.shape]

    HPAD, WPAD, _, _ = get_pad_tuple(padding, (KH, KW))
    GRPS = CI // CIG
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride
    assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
        "Do not support inputs with different data types now. " \
        "{} vs. {}".format(data.dtype, kernel.dtype)
    return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)


def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
    """Convolution operator in NCHW layout.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    Filter : tvm.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_channel, in_height, in_width = Input.shape
    num_filter, channel, kernel_h, kernel_w = Filter.shape
    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    # compute graph
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    temp = pad(Input, pad_before, pad_after, name="pad_temp")
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')

    return tvm.compute(
        (batch, out_channel, out_height, out_width),
        lambda nn, ff, yy, xx: tvm.sum(
            temp[nn, rc, yy * stride_h + ry * dilation_h,
                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
            Filter[ff, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx]), tag="conv2d_nchw")
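
# A quick sanity check of the output-size arithmetic above (example values only): with
# in_height = in_width = 32, a 3x3 kernel, stride 1, dilation 2 and padding 2 per side,
# the dilated kernel spans 5 pixels, so out_height = (32 - 5 + 2 + 2) // 1 + 1 = 32 and
# the spatial size is preserved.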


def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
    """Convolution operator in HWCN layout.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [in_height, in_width, in_channel, batch]

    Filter : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [out_height, out_width, out_channel, batch]
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2

    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    in_height, in_width, in_channel, batch = Input.shape
    kernel_h, kernel_w, channel, num_filter = Filter.shape
    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    pad_before = [pad_top, pad_left, 0, 0]
    pad_after = [pad_down, pad_right, 0, 0]
    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    Output = tvm.compute(
        (out_height, out_width, out_channel, batch),
        lambda yy, xx, ff, nn: tvm.sum(
            PaddedInput[yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w,
                        rc, nn].astype(out_dtype) *
            Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
        name="Conv2dOutput", tag="conv2d_hwcn")
    return Output

def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
    """Convolution operator in NHWC layout.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    Filter : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2

    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_height, in_width, in_channel = Input.shape
    kernel_h, kernel_w, channel, num_filter = Filter.shape
    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    pad_before = [0, pad_top, pad_left, 0]
    pad_after = [0, pad_down, pad_right, 0]
    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    Output = tvm.compute(
        (batch, out_height, out_width, out_channel),
        lambda nn, yy, xx, ff: tvm.sum(
            PaddedInput[nn, yy * stride_h + ry * dilation_h,
                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
            Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
        name="Conv2dOutput", tag="conv2d_nhwc")
    return Output

@tvm.target.generic_func
def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype='float32'):
    """Conv2D operator for nChw[x]c layout.

    Parameters
    ----------
    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        6-D with shape
        [num_filter_chunk, in_channel_chunk, filter_height, filter_width,
        in_channel_block, num_filter_block]

    stride : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    out_dtype : str
        output data type

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """

    return conv2d_NCHWc_compute(data,
                                kernel,
                                stride,
                                padding,
                                dilation,
                                layout,
                                out_layout,
                                out_dtype)
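
# Shapes for the blocked layout, with illustrative NCHW16c numbers (example values, not
# defaults): an NCHW input of (1, 64, 56, 56) with ic_bn = oc_bn = 16 becomes the 5-D
# data tensor (1, 4, 56, 56, 16), an OIHW kernel of (128, 64, 3, 3) becomes the 6-D
# tensor (8, 4, 3, 3, 16, 16), and the result is (1, 8, out_height, out_width, 16).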


def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_layout, out_dtype):
    """Conv2D operator compute for nChw[x]c layout.

    Parameters
    ----------
    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        6-D with shape
        [num_filter_chunk, in_channel_chunk, filter_height, filter_width,
        in_channel_block, num_filter_block]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    out_dtype : str
        output data type

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """

    # layout and out_layout are not used here,
    # we keep them for debug convenience when dumping autotvm workload
    HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
        else (dilation, dilation)

    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
    in_channel = ic_chunk * ic_bn
    target = tvm.target.current_target(allow_none=False)
    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
        get_const_tuple(kernel.shape)
    num_filter = oc_chunk * oc_bn
    groups = ic_chunk // ic_chunk_group

    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1

    # output shape
    out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
    out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
    oshape = (n, oc_chunk, out_height, out_width, oc_bn)

    # DOPAD
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
    else:
        data_pad = data

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod

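    # The reduction axis `ic` runs over the logical input channel; idxdiv/idxmod recover
    # the (channel chunk, channel block) coordinates of the blocked layouts, e.g. with
    # ic_bn = 16 the logical channel 37 maps to chunk 2, block 5.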
    return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_pad[n,
                                        idxdiv(ic, ic_bn),
                                        oh * HSTR + kh * dilation_h,
                                        ow * WSTR + kw * dilation_w,
                                        idxmod(ic, ic_bn)].astype(out_dtype)
                               * kernel[oc_chunk,
                                        idxdiv(ic, ic_bn),
                                        kh,
                                        kw,
                                        idxmod(ic, ic_bn),
                                        oc_block],
                               axis=[ic, kh, kw]),
                       name='conv2d_NCHWc', tag="conv2d_NCHWc")


@tvm.target.generic_func
def conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, layout, out_layout,
                      out_dtype='int32'):
    """Conv2D operator for nChw[x]c layout.

    Parameters
    ----------
    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        7-D with shape
        [num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4,
        num_filter_block, 4]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    out_dtype : str
        output data type

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """

    return conv2d_NCHWc_int8_compute(data,
                                     kernel,
                                     strides,
                                     padding,
                                     dilation,
                                     layout,
                                     out_layout,
                                     out_dtype)


def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, out_layout,
                              out_dtype='int32'):
    """Conv2D operator for nChw[x]c layout.

    Parameters
    ----------
    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        7-D with shape
        [num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4,
        num_filter_block, 4]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    out_dtype : str
        output data type

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """

    # layout and out_layout are not used here,
    # we keep them for debug convenience when dumping autotvm workload
    HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
        else (dilation, dilation)

    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
    in_channel = ic_chunk * ic_bn
    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
        get_const_tuple(kernel.shape)
    num_filter = oc_chunk * oc_bn
    groups = ic_chunk // ic_chunk_group

    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1

    # output shape
    out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
    out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
    oshape = (n, oc_chunk, out_height, out_width, oc_bn)

    # DOPAD
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
    else:
        data_pad = data

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

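    # Inside each block of ic_bn input channels, the reduction is further split into
    # groups of n_elems = 4 lanes (ic_f_inner x ic_s_inner) so that schedules can map the
    # innermost 4-element int8 multiply-accumulate onto dot-product style instructions.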
    if groups == 1:
        n_elems = 4
        ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
        ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
        ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
        return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                           tvm.sum(data_pad[n,
                                            ic_outer,
                                            oh * HSTR + kh * dilation_h,
                                            ow * WSTR + kw * dilation_w,
                                            ic_f_inner * n_elems + ic_s_inner].astype(out_dtype)
                                   * kernel[oc_chunk,
                                            ic_outer,
                                            kh,
                                            kw,
                                            ic_f_inner,
                                            oc_block,
                                            ic_s_inner].astype(out_dtype),
                                   axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                           name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
    # for int8 group conv support
    n_elems = 4
    ic_chunk = in_channel//ic_bn
    ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer')
    ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
    ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
    oshape = (n, oc_chunk, out_height, out_width, oc_bn)
    return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block:
                       tvm.sum(data_pad[n,
                                        (occ * oc_bn // (oc_chunk * oc_bn // groups))
                                        * (ic_chunk // groups) + ic_outer,
                                        oh * HSTR + kh * dilation_h,
                                        ow * WSTR + kw * dilation_w,
                                        ic_f_inner * n_elems + ic_s_inner].astype(out_dtype)
                               * kernel[occ,
                                        ic_outer,
                                        kh,
                                        kw,
                                        ic_f_inner,
                                        oc_block,
                                        ic_s_inner].astype(out_dtype),
                               axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                       name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")


def conv2d_winograd_weight_transform(kernel, tile_size):
    """Weight transformation for winograd

    Parameters
    ----------
    kernel: Tensor
        The raw kernel tensor with layout "NCHW".
    tile_size: int
        Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [alpha, alpha, CO, CI]
    """
    shape = get_const_tuple(kernel.shape)
    assert shape[2] == shape[3], "Only support NxN kernel"

    K = shape[3]
    r = tile_size + K - 1
    shape = (r, r) + shape[:2]

    _, _, G = winograd_transform_matrices(tile_size, K, kernel.dtype)

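    # The compute below forms U[eps, nu, co, ci] = sum_{kh, kw} G[eps, kh] * g[co, ci, kh, kw]
    # * G[nu, kw], i.e. U = G * g * G^T applied independently to every (co, ci) slice,
    # where the transform size alpha equals r = tile_size + K - 1.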
    r_kh = tvm.reduce_axis((0, K), name='r_kh')
    r_kw = tvm.reduce_axis((0, K), name='r_kw')
    return tvm.compute(shape, lambda eps, nu, co, ci:
                       tvm.sum(kernel[co][ci][r_kh][r_kw] *
                               G[eps][r_kh] * G[nu][r_kw],
                               axis=[r_kh, r_kw]), name='transform_weight')


@tvm.target.generic_func
def conv2d_winograd_without_weight_transform(input, filter, strides, padding, dilation,
                                             layout, out_dtype, tile_size):
    """Compute convolution in winograd algorithm. The filter is supposed to be transformed
    in advance.

    Parameters
    ----------
    input : tvm.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]
    filter : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter]
    strides : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]
    padding : int or str
        Padding size, or ['VALID', 'SAME']
    tile_size: int
        Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")

def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm, out_dtype):
    """Weight transformation for winograd (NNPACK backend).

    Parameters
    ----------
    kernel: Tensor
        The raw kernel tensor with layout "NCHW". Only 3x3 kernels are supported for now.
    convolution_algorithm: int
        The convolution algorithm for Winograd NNPACK.
    out_dtype: str
        The data type of the transformed kernel.

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [alpha, alpha, CO, CI]
    """
    from tvm.contrib import nnpack
    return nnpack.convolution_inference_weight_transform(
        kernel, algorithm=convolution_algorithm, dtype=out_dtype)

@tvm.target.generic_func
def conv2d_winograd_nnpack_without_weight_transform(
        input, filter, bias, strides, padding, dilation, layout, out_dtype):
    """Compute convolution with the NNPACK winograd algorithm. The filter is supposed to be
    transformed in advance.

    Parameters
    ----------
    input : tvm.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]
    filter : tvm.Tensor
        4-D with shape [num_filter, in_channel, 8, 8]
    bias : tvm.Tensor
        1-D with shape [num_filter]
    strides : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]
    padding : int or str
        Padding size, or ['VALID', 'SAME']

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    raise ValueError(
        "missing register for topi.nn.conv2d_winograd_nnpack_without_weight_transform")


@tvm.target.generic_func
def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
    """Group convolution operator in NCHW layout.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    Filter : tvm.Tensor
        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    groups : int
        number of groups

    out_dtype : str
        The output type. This is used for mixed precision.

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_channel, in_height, in_width = get_const_tuple(Input.shape)
    num_filter, _, kernel_h, kernel_w = get_const_tuple(Filter.shape)

    assert in_channel % groups == 0, "input channels must be divisible by the number of groups"
    assert num_filter % groups == 0, "output channels must be divisible by the number of groups"

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (kernel_h, kernel_w))
    # compute the output shape
    out_channel = num_filter
    out_height = simplify(
        (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1)
    out_width = simplify(
        (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1)
    # compute graph
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    temp = pad(Input, pad_before, pad_after, name="pad_temp")
    rc = tvm.reduce_axis((0, in_channel // groups), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
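    # Each output channel ff reduces only over the input channels of its own group:
    # ff // (num_filter // groups) selects the group index, and scaling it by
    # in_channel // groups gives the offset of that group's slice of input channels.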
    return tvm.compute(
        (batch, out_channel, out_height, out_width),
        lambda nn, ff, yy, xx: tvm.sum(
            temp[nn, ff // (num_filter//groups) * (in_channel//groups) + rc,
                 yy * stride_h + ry * dilation_h,
                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
            Filter[ff, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx]), tag='group_conv2d_nchw')