# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
18"""Conv2d transpose template for cuda backend"""
19
20import tvm
21from tvm import autotvm
22from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
23from .. import nn, generic
24from ..util import equal_const_int, get_const_tuple, traverse_inline
25

@autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct")
def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
    """Transposed 2D convolution NCHW forward operator.

    Parameters
    ----------
    cfg : ConfigEntity
        The config for this template
    Input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]
    Filter : tvm.Tensor
        4-D with shape [in_channel, num_filter, filter_height, filter_width]
    strides : tuple of two ints
        The spatial stride along height and width
    padding : int or str
        Padding size, or 'VALID' or 'SAME'
    out_dtype : str
        The output data type, used for mixed-precision computation

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    batch, in_c, in_h, in_w = get_const_tuple(Input.shape)
    _, out_c, filter_h, filter_w = get_const_tuple(Filter.shape)
    stride_h, stride_w = strides

    # attach stride info to the config; it is used in the schedule space definition
    cfg.stride = strides
    # compute the equivalent backward padding
    fpad_top, fpad_left, fpad_bottom, fpad_right = nn.get_pad_tuple(padding, (filter_h, filter_w))
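    # A transposed convolution is realised below as an ordinary convolution
    # over the zero-dilated input with a spatially flipped kernel, so each
    # forward padding p turns into a backward padding of (filter - 1 - p).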
    bpad_top = filter_h - 1 - fpad_top
    bpad_bottom = filter_h - 1 - fpad_bottom
    bpad_left = filter_w - 1 - fpad_left
    bpad_right = filter_w - 1 - fpad_right

    # padding stage
    FirstPad = nn.pad(Input,
                      [0, 0, (bpad_top + stride_h - 1) // stride_h,
                       (bpad_left + stride_w - 1) // stride_w],
                      [0, 0, (bpad_bottom + stride_h - 1) // stride_h,
                       (bpad_right + stride_w - 1) // stride_w], name='FirstPad')
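    # bpad is measured on the dilated input, but nn.pad runs before the
    # dilation stage, so pad by ceil(bpad / stride) elements of real input;
    # the surplus is skipped via the border_h / border_w offsets below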

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod
    # remove extra padding introduced by dilation
    border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
    border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)

    # dilation stage
    data = FirstPad
    strides = [1, 1, stride_h, stride_w]
    n = len(data.shape)

    def _dilate(*indices):
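        # index into the conceptually dilated input: positions that fall
        # between original elements (index % stride != 0) read as zero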
        not_zero = []
        index_tuple = []
        for i in range(n):
            if not equal_const_int(strides[i], 1):
                index_tuple.append(idxdiv(indices[i], strides[i]))
                not_zero.append(idxmod(indices[i], strides[i]).equal(0))
            else:
                index_tuple.append(indices[i])
        if not_zero:
            not_zero = tvm.all(*not_zero)
            return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
        return data(*index_tuple)

    # convolution stage
    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
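    # e.g. in_h = 4, stride_h = 2, fpad_top = fpad_bottom = 1, filter_h = 3
    # gives out_h = (4 - 1) * 2 - 1 - 1 + 3 = 7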
    dc = tvm.reduce_axis((0, in_c), name='dc')
    dh = tvm.reduce_axis((0, filter_h), name='dh')
    dw = tvm.reduce_axis((0, filter_w), name='dw')

    Output = tvm.compute(
        (batch, out_c, out_h, out_w),
        lambda b, c, h, w: tvm.sum(
            _dilate(b, dc, h + dh + border_h, w + dw + border_w).astype(out_dtype) *
            Filter[dc, c, filter_h - 1 - dh, filter_w - 1 - dw].astype(out_dtype),
            axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")

    return Output


@autotvm.task.register_topi_schedule(generic.schedule_conv2d_transpose_nchw,
                                     ['cuda', 'gpu'], 'direct')
def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
    """TOPI Schedule callback for conv2d transpose operator.

    Parameters
    ----------
    cfg : ConfigEntity
        The parameters for this template

    outs : Array of Tensor
        The computation graph description of conv2d transpose
        in the format of an array of tensors.

    Returns
    -------
    s : Schedule
        The computation schedule for conv2d transpose.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _fallback_schedule(N, F, Y, X):
        # pylint: disable=unused-argument
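        # a SplitEntity holds the tiling factors applied below, outermost
        # first: (block, vthread, thread, inner) for the spatial axes and
        # three levels for the rc reduction; a -1 entry is inferred from
        # the axis extent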
        # split N (batch dimension)
        if N > 1:
            cfg["tile_n"] = SplitEntity([-1, 1, 1, 4])
        else:
            cfg["tile_n"] = SplitEntity([1, 1, 1, 1])
        # split F (output channel dimension)
        cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
        # split Y (height dimension)
        y_split_factor = 1
        for candidate in range(5, 17):
            if Y % candidate == 0:
                y_split_factor = candidate
                break
        cfg["tile_y"] = SplitEntity([-1, 1, 1, y_split_factor])
        # split X (width dimension)
        x_split_factor = 1
        for candidate in range(5, 17):
            if X % candidate == 0:
                x_split_factor = candidate
                break
        cfg["tile_x"] = SplitEntity([-1, x_split_factor, 1, 1])
        # split RC (input channel dimension, which is a reduction axis)
        cfg["tile_rc"] = SplitEntity([-1, 1, 16])
        # other configurations
        cfg["fuse_yx"] = OtherOptionEntity(False)
        cfg["unroll_explicit"] = OtherOptionEntity(True)
        cfg["auto_unroll_max_step"] = OtherOptionEntity(1500)

    def _callback(op):
        if op.tag == 'conv2d_transpose_nchw':
            pad_data = op.input_tensors[0]
            kernel = op.input_tensors[1]
            conv = op.output(0)

            ##### space definition begin #####
            n, f, y, x = s[conv].op.axis
            rc = s[conv].op.reduce_axis[0]
            cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
            cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
            cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
            cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
            cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
            cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])

            target = tvm.target.current_target()
            if target.target_name in ['nvptx', 'rocm']:
                cfg.define_knob("unroll_explicit", [1])
            else:
                cfg.define_knob("unroll_explicit", [0, 1])

            if cfg.is_fallback:
                N, F, Y, X = get_const_tuple(conv.shape)
                _fallback_schedule(N, F, Y, X)

            ##### space definition end #####

            if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
                s[kernel].compute_inline()

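            # accumulate the convolution in a register-level (local) cache;
            # reuse conv itself as that stage when it is not a final output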
            if conv.op in s.outputs:
                output = conv
                OL = s.cache_write(conv, 'local')
            else:
                output = s.outputs[0].output(0)
                s[conv].set_scope('local')
                OL = conv

            # create cache stage
            s[pad_data].set_scope('shared')
            AA = pad_data
            WW = s.cache_read(kernel, 'shared', [OL])

            # tile and bind spatial axes
            n, f, y, x = s[output].op.axis
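            # dummy split: kernel_scope is only a handle used to attach the
            # unroll pragmas to the whole kernel (see the end of _callback)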
            kernel_scope, n = s[output].split(n, nparts=1)
            bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
            bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
            by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
            bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

            s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
            s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
            s[output].bind(bf, tvm.thread_axis("blockIdx.y"))
            s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
            s[output].bind(vn, tvm.thread_axis("vthread"))
            s[output].bind(vf, tvm.thread_axis("vthread"))
            s[output].bind(vy, tvm.thread_axis("vthread"))
            s[output].bind(vx, tvm.thread_axis("vthread"))

            cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf
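            # the two layouts map the (n, f, y, x) thread axes onto the 3-D
            # CUDA thread block as (tn, tf, ty*tx) or as (tn*tf, ty, tx)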

            if cfg["fuse_yx"].val:
                s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
                s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
                tyx = s[output].fuse(ty, tx)
                s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
                s[OL].compute_at(s[output], tyx)

                # number of threads
                n_tz = cfg["tile_n"].size[2]
                n_ty = cfg["tile_f"].size[2]
                n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
            else:
                s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
                s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
                s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
                s[OL].compute_at(s[output], tx)

                # number of threads
                n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
                n_ty = cfg["tile_y"].size[2]
                n_tx = cfg["tile_x"].size[2]

            # tile reduction axes
            n, f, y, x = s[OL].op.axis
            rc, ry, rx = s[OL].op.reduce_axis
            rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
            s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)

            s[AA].compute_at(s[OL], rx)
            s[WW].compute_at(s[OL], rx)

            # cooperative fetching
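            # all threads of a block jointly copy the shared-memory tiles of
            # the padded input (AA) and the kernel (WW) from global memory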
            for load in [AA, WW]:
                n, f, y, x = s[load].op.axis
                fused = s[load].fuse(f, y, x)
                tz, fused = s[load].split(fused, nparts=n_tz)
                ty, fused = s[load].split(fused, nparts=n_ty)
                tx, fused = s[load].split(fused, nparts=n_tx)
                s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
                s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
                s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

            s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
            s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

    traverse_inline(s, outs[0].op, _callback)

    return s
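
# Minimal usage sketch (a hedged example, not executed as part of this
# module): under a cuda target the generic topi entry points dispatch to the
# template registered above. Assumes a CUDA-enabled build of this TVM/topi
# version; without a tuned log, autotvm falls back to _fallback_schedule.
#
#   import tvm, topi
#   data = tvm.placeholder((1, 64, 32, 32), name='data')
#   kernel = tvm.placeholder((64, 32, 3, 3), name='kernel')
#   with tvm.target.cuda():
#       out = topi.nn.conv2d_transpose_nchw(data, kernel, (2, 2), (1, 1),
#                                           out_dtype='float32')
#       s = topi.generic.schedule_conv2d_transpose_nchw([out])
#       func = tvm.build(s, [data, kernel, out], target='cuda')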