# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import logging
import tvm
import tvm.autotvm as autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cublas
from .tensor_intrin import dp4a
from ..nn.dense import dense, dense_default
from .. import tag
from .. import generic
from ..util import traverse_inline, get_const_tuple

logger = logging.getLogger('topi')


@autotvm.register_topi_compute(dense, ["cuda", "gpu"], "direct")
def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
    """Dense operator for cuda backend.

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [batch, in_dim]

    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim]

    bias : tvm.Tensor, optional
        1-D with shape [out_dim]

    out_dtype : str, optional
        Output data type. Defaults to the data type of `data`.

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    # pylint: disable=unused-argument
    assert len(data.shape) == 2 and len(weight.shape) == 2, \
        "only support 2-dim dense"
    if bias is not None:
        assert len(bias.shape) == 1
    if out_dtype is None:
        out_dtype = data.dtype
    batch, in_dim = data.shape
    out_dim, _ = weight.shape
    target = tvm.target.current_target()
    if "cublas" in target.libs:
        assert out_dtype == data.dtype, "Mixed precision not supported."
        matmul = cublas.matmul(data, weight, False, True)
        if bias is not None:
            matmul = tvm.compute((batch, out_dim), \
                                 lambda i, j: matmul[i, j] + bias[j], \
                                 tag=tag.BROADCAST)
        return matmul
    return dense_default(data, weight, bias, out_dtype)

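
# For reference, `dense_default` (imported from ..nn.dense above) provides the
# fallback compute definition that the schedules below operate on: a matrix
# product with the weight stored as [out_dim, in_dim]. A simplified sketch of
# that definition (bias handling omitted; illustrative, not the verbatim
# topi.nn implementation):
#
#   k = tvm.reduce_axis((0, in_dim), name='k')
#   matmul = tvm.compute((batch, out_dim),
#                        lambda i, j: tvm.sum(data[i, k].astype(out_dtype) *
#                                             weight[j, k].astype(out_dtype),
#                                             axis=k),
#                        tag='dense')
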
@autotvm.register_topi_schedule(generic.schedule_dense, ["cuda", "gpu"], "direct")
def schedule_dense(cfg, outs):
    """Schedule for dense operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of dense
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for dense.
    """
    # pylint: disable=unused-argument
    target = tvm.target.current_target()

    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    if target.target_name == "cuda" and "cublas" in target.libs:
        return generic.schedule_extern(outs)

    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(C):
        A, _ = C.op.input_tensors
        batch, _ = get_const_tuple(A.shape)
        if batch < 32:
            return schedule_dense_small_batch(cfg, s, C)
        return schedule_dense_large_batch(cfg, s, C)

    scheduled_ops = []

    def traverse(OP):
        """Internal traverse function"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(OP.tag):
            if OP not in s.outputs:
                s[OP].compute_inline()
            for tensor in OP.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        # schedule dense
        elif OP.tag == 'dense':
            Dense = OP.output(0)
            _schedule(Dense)
        else:
            raise RuntimeError("Unsupported operator: %s" % OP.tag)

        scheduled_ops.append(OP)

    traverse(outs[0].op)
    return s


def schedule_dense_small_batch(cfg, s, C):
    """Schedule float32/64 dense with small batch size"""
    A, _ = C.op.input_tensors
    _, in_dim = get_const_tuple(A.shape)
    cfg.define_split('tile_k', in_dim, num_outputs=2)
    if cfg.is_fallback:
        cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])

    # split the reduction axis and rfactor so each thread computes a partial sum
    _, kf = cfg['tile_k'].apply(s, C, C.op.reduce_axis[0])
    CF = s.rfactor(C, kf)

    if C.op in s.outputs:
        Out = C
    else:
        Out = s.outputs[0].output(0)
        s[C].compute_at(s[Out], s[Out].op.axis[1])
    s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
    s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))

    # cross-thread reduction over threadIdx.x; only thread 0 stores the result
    tx = s[C].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[C].bind(tx, thread_x)
    s[CF].compute_at(s[C], tx)
    s[C].set_store_predicate(thread_x.var.equal(0))
    s[Out].set_store_predicate(thread_x.var.equal(0))

def schedule_dense_large_batch(cfg, s, C):
    """Schedule float32/64 dense with large batch size"""
    A, B = C.op.input_tensors
    batch, in_dim = get_const_tuple(A.shape)
    out_dim, _ = get_const_tuple(B.shape)
    k = C.op.reduce_axis[0]

    # create tuning space
    try:
        block_cand = [64, 128]
        vthread_cand = [2**x for x in range(1, 7)]
        n_thread_cand = [2**x for x in range(3, 7)]
        cfg.define_split('tile_x', batch, num_outputs=4,
                         filter=lambda x: (x.size[1] in vthread_cand and
                                           x.size[2] in n_thread_cand and
                                           (x.size[1] * x.size[2] * x.size[3]) in block_cand))
        cfg.define_split('tile_y', out_dim, num_outputs=4,
                         filter=lambda x: (x.size[1] in vthread_cand and
                                           x.size[2] in n_thread_cand and
                                           (x.size[1] * x.size[2] * x.size[3]) in block_cand))
        cfg.define_split('tile_k', in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
    except IndexError:
        # An IndexError is raised when no entities are left after filtering;
        # the filters are designed to prune the tuning space for better search efficiency.
        logger.debug(
            'Tuning space was created without pruning due to unfit shapes')
        cfg.define_split('tile_x', batch, num_outputs=4)
        cfg.define_split('tile_y', out_dim, num_outputs=4)
        cfg.define_split('tile_k', in_dim, num_outputs=3)

    if cfg.is_fallback:
        if batch > 1:
            cfg['tile_x'] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg['tile_x'] = SplitEntity([1, 1, 1, 1])
        if out_dim > 1:
            cfg['tile_y'] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg['tile_y'] = SplitEntity([1, 1, 1, 1])
        if in_dim > 8:
            cfg['tile_k'] = SplitEntity([-1, 8, 1])
        else:
            cfg['tile_k'] = SplitEntity([-1, 1, 1])

    # Explicit memory access
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    # Deal with op fusion
    if C.op not in s.outputs:
        s[C].compute_inline()
        C = s.outputs[0].output(0)

    # Split and reorder computation
    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, C.op.axis[0])
    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, C.op.axis[1])
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    # Binding
    s[C].bind(by, tvm.thread_axis("blockIdx.y"))
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tyz, tvm.thread_axis("vthread"))
    s[C].bind(txz, tvm.thread_axis("vthread"))
    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

    # Split reduction
    yo, xo = CC.op.axis
    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)

    # Schedule for A's shared memory load
    num_thread_x = cfg['tile_x'].size[2]
    ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread_x)
    s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[AA].double_buffer()

    # Schedule for B's shared memory load
    num_thread_y = cfg['tile_y'].size[2]
    ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread_y)
    s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[BB].double_buffer()

@autotvm.register_topi_compute(dense, ['cuda'], ['int8'])
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
    """Dense operator for int8 on CUDA"""
    if out_dtype is None:
        out_dtype = data.dtype
    batch, in_dim = get_const_tuple(data.shape)
    out_dim, _ = get_const_tuple(weight.shape)

    k = tvm.reduce_axis((0, in_dim), name='k')

    matmul = tvm.compute((batch, out_dim),
                         lambda i, j: tvm.sum(data[i, k].astype(out_dtype) *
                                              weight[j, k].astype(out_dtype), axis=[k]),
                         tag="dense_int8")

    cfg.add_flop(batch * in_dim * out_dim * 2)

    if bias is not None:
        matmul = tvm.compute((batch, out_dim),
                             lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
                             tag=tag.BROADCAST)
        cfg.add_flop(batch * out_dim)

    return matmul


@autotvm.register_topi_schedule(generic.schedule_dense, ['cuda', 'gpu'], ['int8'])
def schedule_dense_int8(cfg, outs):
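    """Schedule for int8 dense operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of dense
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for dense.
    """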
    s = tvm.create_schedule([x.op for x in outs])

    def _callback(op):
        if "dense_int8" in op.tag:
            _schedule_dense_int8(cfg, s, op.output(0))

    traverse_inline(s, outs[0].op, _callback)
    return s


_dp4a = dp4a('shared', 'shared', 'local')

def _schedule_dense_int8(cfg, s, output):
    data, weight = s[output].op.input_tensors

    batch, in_dim = get_const_tuple(data.shape)
    out_dim, _ = get_const_tuple(weight.shape)

    in_dim_factor = 4
    assert in_dim % in_dim_factor == 0, \
        "Input dimension must be a multiple of {}".format(in_dim_factor)
    if in_dim % 16 == 0:
        in_dim_factor = 16

    # create tuning space
    cfg.define_split("tile_y", batch, num_outputs=4)
    cfg.define_split("tile_x", out_dim, num_outputs=4)
    cfg.define_split("tile_k", in_dim // in_dim_factor, num_outputs=2)
    cfg.define_knob('auto_unroll_max_step', [0, 512, 1500])

    # create cache stage
    AA = s.cache_read(data, 'shared', [output])
    WW = s.cache_read(weight, 'shared', [output])
    CC = s.cache_write(output, 'local')

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    n, x = s[output].op.axis

    # this is the scope to attach global config inside this kernel
    kernel_scope, n = s[output].split(n, nparts=1)

    ko = CC.op.reduce_axis[0]
    ko, ki = s[CC].split(ko, factor=4)
    ko, kt = cfg['tile_k'].apply(s, CC, ko)
    s[CC].tensorize(ki, _dp4a)
    by, vy, ty, yi = cfg['tile_y'].apply(s, output, n)
    bx, vx, tx, xi = cfg['tile_x'].apply(s, output, x)

    s[output].reorder(by, bx, vy, vx, ty, tx, yi, xi)
    s[output].bind(by, tvm.thread_axis('blockIdx.y'))
    s[output].bind(bx, tvm.thread_axis('blockIdx.x'))
    s[output].bind(vy, tvm.thread_axis('vthread'))
    s[output].bind(vx, tvm.thread_axis('vthread'))
    s[output].bind(ty, tvm.thread_axis('threadIdx.y'))
    s[output].bind(tx, tvm.thread_axis('threadIdx.x'))
    n_ty = cfg['tile_y'].size[2]
    n_tx = cfg['tile_x'].size[2]

    s[CC].compute_at(s[output], tx)
    yo, xo = CC.op.axis[:2]
    s[CC].reorder(ko, kt, yo, xo, ki)

    for load in [AA, WW]:
        s[load].compute_at(s[CC], ko)

        outer, inner = s[load].split(s[load].op.axis[-1], factor=in_dim_factor)
        s[load].vectorize(inner)
        fused = s[load].op.axis[:-1] + [outer]
        fused = s[load].fuse(*fused)

        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        s[load].bind(tx, tvm.thread_axis('threadIdx.x'))
        s[load].bind(ty, tvm.thread_axis('threadIdx.y'))

    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', False)
    return s
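
# Usage sketch (illustrative only, not part of this module): under a CUDA
# target, the generic entry points dispatch to the compute and schedule
# templates registered above. `batch`, `in_dim` and `out_dim` stand for
# concrete integer shapes supplied by the caller.
#
#   import tvm
#   import topi
#
#   data = tvm.placeholder((batch, in_dim), name='data')
#   weight = tvm.placeholder((out_dim, in_dim), name='weight')
#   with tvm.target.create("cuda"):
#       out = topi.nn.dense(data, weight)
#       s = topi.generic.schedule_dense([out])
#   func = tvm.build(s, [data, weight, out], "cuda")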