# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Conv2d transpose template for cuda backend"""

import tvm
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn, generic
from ..util import equal_const_int, get_const_tuple, traverse_inline


@autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct")
def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
    """Transposed 2D convolution nchw forward operator.

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template
    Input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]
    Filter : tvm.Tensor
        4-D with shape [in_channel, num_filter, filter_height, filter_width]
    strides : tuple of two ints
        The spatial stride along height and width
    padding : int or str
        Padding size, or ['VALID', 'SAME']
    out_dtype : str
        The output type. This is used for mixed precision.

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    batch, in_c, in_h, in_w = get_const_tuple(Input.shape)
    _, out_c, filter_h, filter_w = get_const_tuple(Filter.shape)
    stride_h, stride_w = strides

    # attach stride info to config, this is used in schedule space definition
    cfg.stride = strides

    # padding stage
    fpad_top, fpad_left, fpad_bottom, fpad_right = nn.get_pad_tuple(padding, (filter_h, filter_w))
    bpad_top = filter_h - 1 - fpad_top
    bpad_bottom = filter_h - 1 - fpad_bottom
    bpad_left = filter_w - 1 - fpad_left
    bpad_right = filter_w - 1 - fpad_right

    # padding stage
    FirstPad = nn.pad(Input,
                      [0, 0,
                       (bpad_top + stride_h - 1) // stride_h,
                       (bpad_left + stride_w - 1) // stride_w],
                      [0, 0,
                       (bpad_bottom + stride_h - 1) // stride_h,
                       (bpad_right + stride_w - 1) // stride_w],
                      name='FirstPad')

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod
    # remove extra padding introduced by dilation
    border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
    border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)

    # dilation stage
    data = FirstPad
    strides = [1, 1, stride_h, stride_w]
    n = len(data.shape)

    def _dilate(*indices):
        not_zero = []
        index_tuple = []
        for i in range(n):
            if not equal_const_int(strides[i], 1):
                index_tuple.append(idxdiv(indices[i], strides[i]))
                not_zero.append(idxmod(indices[i], strides[i]).equal(0))
            else:
                index_tuple.append(indices[i])
        if not_zero:
            not_zero = tvm.all(*not_zero)
            return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
        return data(*index_tuple)
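
    # The transposed convolution is lowered to a unit-stride direct
    # convolution: _dilate presents the padded input as if (stride - 1) zeros
    # were interleaved between its rows and columns, and the filter is read
    # spatially flipped in the compute below.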

    # convolution stage
    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
    dc = tvm.reduce_axis((0, in_c), name='dc')
    dh = tvm.reduce_axis((0, filter_h), name='dh')
    dw = tvm.reduce_axis((0, filter_w), name='dw')

    Output = tvm.compute(
        (batch, out_c, out_h, out_w),
        lambda b, c, h, w: tvm.sum(
            _dilate(b, dc, h + dh + border_h, w + dw + border_w).astype(out_dtype) *
            Filter[dc, c, filter_h - 1 - dh, filter_w - 1 - dw].astype(out_dtype),
            axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")

    return Output


@autotvm.task.register_topi_schedule(generic.schedule_conv2d_transpose_nchw,
                                     ['cuda', 'gpu'], 'direct')
def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
    """TOPI Schedule callback for conv2d transpose operator.

    Parameters
    ----------
    cfg: ConfigEntity
        The parameters for this template

    outs: Array of Tensor
        The computation graph description of conv2d transpose
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for conv2d transpose.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _fallback_schedule(N, F, Y, X):
        # pylint: disable=unused-argument
        # split N (batch dimension)
        if N > 1:
            cfg["tile_n"] = SplitEntity([-1, 1, 1, 4])
        else:
            cfg["tile_n"] = SplitEntity([1, 1, 1, 1])
        # split F (output channel dimension)
        cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
        # split Y (height dimension)
        y_split_factor = 1
        for candidate in range(5, 17):
            if Y % candidate == 0:
                y_split_factor = candidate
                break
        cfg["tile_y"] = SplitEntity([-1, 1, 1, y_split_factor])
        # split X (width dimension)
        x_split_factor = 1
        for candidate in range(5, 17):
            if X % candidate == 0:
                x_split_factor = candidate
                break
        cfg["tile_x"] = SplitEntity([-1, x_split_factor, 1, 1])
        # split RC (input channel dimension, which is a reduction axis)
        cfg["tile_rc"] = SplitEntity([-1, 1, 16])
        # other configurations
        cfg["fuse_yx"] = OtherOptionEntity(False)
        cfg["unroll_explicit"] = OtherOptionEntity(True)
        cfg["auto_unroll_max_step"] = OtherOptionEntity(1500)

    def _callback(op):
        if op.tag == 'conv2d_transpose_nchw':
            pad_data = op.input_tensors[0]
            kernel = op.input_tensors[1]
            conv = op.output(0)

            ##### space definition begin #####
            n, f, y, x = s[conv].op.axis
            rc = s[conv].op.reduce_axis[0]
            cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
            cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
            cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
            cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
            cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
            cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])

            target = tvm.target.current_target()
            if target.target_name in ['nvptx', 'rocm']:
                cfg.define_knob("unroll_explicit", [1])
            else:
                cfg.define_knob("unroll_explicit", [0, 1])

            if cfg.is_fallback:
                N, F, Y, X = get_const_tuple(conv.shape)
                _fallback_schedule(N, F, Y, X)

            ##### space definition end #####

            if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
                s[kernel].compute_inline()

            if conv.op in s.outputs:
                output = conv
                OL = s.cache_write(conv, 'local')
            else:
                output = s.outputs[0].output(0)
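                # conv feeds a fused downstream op, so keep its result in
                # registers ("local" scope) and write out through that op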
                s[conv].set_scope('local')
                OL = conv

            # create cache stage
            s[pad_data].set_scope('shared')
            AA = pad_data
            WW = s.cache_read(kernel, 'shared', [OL])

            # tile and bind spatial axes
            n, f, y, x = s[output].op.axis
            kernel_scope, n = s[output].split(n, nparts=1)
            bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
            bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
            by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
            bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

            s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
            s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
            s[output].bind(bf, tvm.thread_axis("blockIdx.y"))
            s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
            s[output].bind(vn, tvm.thread_axis("vthread"))
            s[output].bind(vf, tvm.thread_axis("vthread"))
            s[output].bind(vy, tvm.thread_axis("vthread"))
            s[output].bind(vx, tvm.thread_axis("vthread"))

            cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf

            if cfg["fuse_yx"].val:
                s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
                s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
                tyx = s[output].fuse(ty, tx)
                s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
                s[OL].compute_at(s[output], tyx)

                # number of threads
                n_tz = cfg["tile_n"].size[2]
                n_ty = cfg["tile_f"].size[2]
                n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
            else:
                s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
                s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
                s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
                s[OL].compute_at(s[output], tx)

                # number of threads
                n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
                n_ty = cfg["tile_y"].size[2]
                n_tx = cfg["tile_x"].size[2]

            # tile reduction axes
            n, f, y, x = s[OL].op.axis
            rc, ry, rx = s[OL].op.reduce_axis
            rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
            s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)

            s[AA].compute_at(s[OL], rx)
            s[WW].compute_at(s[OL], rx)

            # cooperative fetching
            for load in [AA, WW]:
                n, f, y, x = s[load].op.axis
                fused = s[load].fuse(f, y, x)
                tz, fused = s[load].split(fused, nparts=n_tz)
                ty, fused = s[load].split(fused, nparts=n_ty)
                tx, fused = s[load].split(fused, nparts=n_tx)
                s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
                s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
                s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

            s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
            s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

    traverse_inline(s, outs[0].op, _callback)

    return s
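
# ----------------------------------------------------------------------------
# Usage sketch (illustrative only; the shapes and target below are arbitrary
# assumptions, not part of this template).  The compute and schedule
# registered above are normally reached through the generic TOPI entry points
# under a CUDA target, roughly:
#
#     import tvm
#     import topi
#
#     data = tvm.placeholder((1, 8, 16, 16), name='data')
#     kernel = tvm.placeholder((8, 16, 4, 4), name='kernel')
#     with tvm.target.create('cuda'):
#         out = topi.nn.conv2d_transpose_nchw(data, kernel, (2, 2), (1, 1),
#                                             'float32')
#         s = topi.generic.schedule_conv2d_transpose_nchw([out])
#         func = tvm.build(s, [data, kernel, out], 'cuda')
#
# Without a tuned AutoTVM configuration in scope, dispatch falls back to the
# defaults set in _fallback_schedule above.
# ----------------------------------------------------------------------------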