# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Compute definition for conv2d with cuda backend"""
import tvm
from tvm import autotvm
from tvm.contrib import cudnn

from .. import nn, generic
from ..util import get_const_tuple, traverse_inline

from .conv2d_direct import schedule_direct_cuda
from .conv2d_winograd import winograd_cuda, schedule_winograd_cuda
from .conv2d_int8 import conv2d_NCHWc_int8, schedule_conv2d_NCHWc_int8


@autotvm.register_topi_compute(nn.conv2d, ['cuda', 'gpu'], ['direct', 'winograd', 'int8'])
def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', out_dtype='float32'):
    """Conv2D operator for cuda backend.

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    data : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width] or
        5-D with shape [batch, ic_chunk, in_height, in_width, ic_block]

    kernel : tvm.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width] or
        6-D with shape [num_filter_chunk, in_channel_chunk, filter_height,
        filter_width, num_filter_block, in_channel_block]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        layout of data

    out_dtype: str
        The output type. This is used for mixed precision.

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
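
    Examples
    --------
    A minimal usage sketch; the placeholder shapes below are illustrative
    assumptions, not requirements of this operator. Under a cuda target the
    generic entry point dispatches to this compute::

        import tvm
        import topi

        data = tvm.placeholder((1, 64, 56, 56), name='data')
        kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')
        with tvm.target.create('cuda'):
            # dispatched to this compute through the registered TOPI entry
            out = topi.nn.conv2d(data, kernel, strides=1, padding=1,
                                 dilation=1, layout='NCHW', out_dtype='float32')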
68    """
69    target = tvm.target.current_target()
70
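    # Offload to cuDNN when the target was built with the cudnn contrib library
    # (e.g. a target string containing "-libs=cudnn").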
71    if "cudnn" in target.libs:
72        if layout == 'NCHW':
73            tensor_format = 0 # CUDNN_TENSOR_NCHW
74            N, _, H, W = get_const_tuple(data.shape)
75        elif layout == 'NHWC':
76            tensor_format = 1 # CUDNN_TENSOR_NHWC
77            N, H, W, _ = get_const_tuple(data.shape)
78        else:
79            raise ValueError("Unsupported layout %s in cudnn" % layout)
80        CO, CI, KH, KW = get_const_tuple(kernel.shape)
81
82        # handle dilation
83        stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
84        pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
85        dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
86
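        # Infer the output spatial size and record the FLOP count so autotvm
        # can report throughput for the cuDNN offload.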
        OH = (H + 2 * pad_h - KH) // stride_h + 1
        OW = (W + 2 * pad_w - KW) // stride_w + 1
        cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *
                     ((KW - 1) * dilation_w + 1))

        return cudnn.conv2d_forward(data,
                                    kernel,
                                    stride_h,
                                    stride_w,
                                    pad_h,
                                    pad_w,
                                    dilation_h,
                                    dilation_w,
                                    conv_mode=1,
                                    tensor_format=tensor_format,
                                    algo=-1)  # let CUDNN choose the best algo

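    # Otherwise dispatch on the autotvm template key chosen for this workload.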
    if cfg.template_key == 'winograd':
        return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
                             pre_computed=False)
    if cfg.template_key == 'int8':
        if data.dtype in ('int8', 'uint8'):
            return conv2d_NCHWc_int8(
                cfg, data, kernel, strides, padding, dilation, layout, out_dtype)

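    # Fall back to the generic direct compute for the requested layout.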
    if layout == 'NCHW':
        return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
    if layout == 'HWCN':
        return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
    raise ValueError("Layout {} is not supported yet".format(layout))


@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, ["cuda", "gpu"],
                                ["direct", "winograd", "int8"])
def schedule_conv2d_nchw_cuda(cfg, outs):
    """TOPI schedule callback of conv2d for cuda gpu

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    outs: Array of Tensor
        The computation graph description of conv2d
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for conv2d.
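
    Examples
    --------
    A minimal usage sketch (illustrative; the placeholder shapes are
    assumptions, and normally this schedule is reached through the generic
    dispatcher rather than being called directly)::

        data = tvm.placeholder((1, 64, 56, 56), name='data')
        kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')
        with tvm.target.create('cuda'):
            out = topi.nn.conv2d(data, kernel, strides=1, padding=1,
                                 dilation=1, layout='NCHW')
            s = topi.generic.schedule_conv2d_nchw([out])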
137    """
138    target = tvm.target.current_target()
139    if 'cudnn' in target.libs:
140        return generic.schedule_extern(outs)
141
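    # Wrap a single output tensor in a list and create a fresh schedule over all outputs.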
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _callback(op):
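        # Select the sub-schedule that matches the tag attached by the corresponding compute.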
        if op.tag == 'conv2d_nchw':
            schedule_direct_cuda(cfg, s, op.output(0))
        if op.tag == 'conv2d_nchw_winograd':
            schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
        if op.tag == 'conv2d_NCHWc_int8':
            schedule_conv2d_NCHWc_int8(cfg, s, op.output(0))

    traverse_inline(s, outs[0].op, _callback)
    return s