# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals,unused-variable
"""cuda batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm
from tvm.contrib import cublas
from topi.nn import batch_matmul, batch_matmul_default
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor


@batch_matmul.register(["cuda", "gpu"])
def batch_matmul_cuda(x, y):
    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
    data in batch.

    Dispatches to the cuBLAS batched GEMM when the current target is "cuda"
    and was built with the "cublas" library; otherwise falls back to the
    generic TOPI compute definition.

    Parameters
    ----------
    x : tvm.Tensor
        3-D with shape [batch, M, K]

    y : tvm.Tensor
        3-D with shape [batch, N, K]

    Returns
    -------
    output : tvm.Tensor
        3-D with shape [batch, M, N]
    """
    target = tvm.target.current_target()
    if target.target_name == "cuda" and "cublas" in target.libs:
        # transa=False, transb=True: y is stored [batch, N, K], so its
        # inner (K) axis is contracted by transposing the second operand.
        return cublas.batch_matmul(x, y, False, True)
    return batch_matmul_default(x, y)


@generic.schedule_batch_matmul.register(["cuda", "gpu"])
def schedule_batch_matmul(outs):
    """Schedule for batch_matmul

    Builds a tiled GPU schedule: each thread block computes one
    (up to) 64x64 output tile of one batch, staging A and B tiles
    through shared memory and accumulating in thread-local registers.
    When the target uses cuBLAS, the extern schedule is returned instead.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of batch_matmul
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    target = tvm.target.current_target()
    if target.target_name == "cuda" and "cublas" in target.libs:
        # cuBLAS path: the compute above is an extern call; nothing to tile.
        return generic.schedule_extern(outs)

    # Normalize a single tensor to a one-element list.
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(op):
        # Tile and bind one batch_matmul op. C = A @ B^T per batch,
        # with C shaped [batch, M, N] (see batch_matmul_cuda above).
        C = op.output(0)
        A, B = s[C].op.input_tensors
        _, M, N = get_const_tuple(C.shape)
        # Stage operand tiles in shared memory, then per-thread locals.
        AA = s.cache_read(A, "shared", [C])
        AL = s.cache_read(AA, "local", [C])
        BB = s.cache_read(B, "shared", [C])
        BL = s.cache_read(BB, "local", [C])
        # Accumulate into registers; write back to C at the end.
        CC = s.cache_write(C, "local")
        if op not in s.outputs:
            # The matmul was fused into a downstream op: inline it and
            # schedule the final output stage instead.
            s[C].compute_inline()
            C = s.outputs[0].output(0)

        b, y, x = s[C].op.axis
        # Block tile sizes: largest power-of-two factor of M/N, capped at 64,
        # so the tiling always divides the output exactly.
        y_bn = get_max_power2_factor(M, 64)
        x_bn = get_max_power2_factor(N, 64)
        by, y = s[C].split(y, y_bn)
        bx, x = s[C].split(x, x_bn)
        # Up to 8x8 threads per block; each thread owns a
        # (y_bn/y_nthreads) x (x_bn/x_nthreads) sub-tile.
        y_nthreads = min(y_bn, 8)
        x_nthreads = min(x_bn, 8)
        ty, yi = s[C].split(y, nparts=y_nthreads)
        tx, xi = s[C].split(x, nparts=x_nthreads)
        thread_x = tvm.thread_axis((0, x_nthreads), "threadIdx.x")
        thread_y = tvm.thread_axis((0, y_nthreads), "threadIdx.y")

        # Grid: z = batch, y/x = output tile; threads cover the tile interior.
        s[C].reorder(b, by, bx, ty, tx, yi, xi)
        s[C].bind(b, tvm.thread_axis("blockIdx.z"))
        s[C].bind(by, tvm.thread_axis("blockIdx.y"))
        s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[C].bind(ty, thread_y)
        s[C].bind(tx, thread_x)
        s[C].pragma(yi, "auto_unroll_max_step", 16)

        # Compute the register accumulator inside each thread's loop nest.
        s[CC].compute_at(s[C], tx)
        _, yi, xi = s[CC].op.axis
        k, = s[CC].op.reduce_axis
        # Split the reduction so shared-memory tiles are reloaded every
        # 8 k-steps (at ko) while locals refresh each step (at ki).
        ko, ki = s[CC].split(k, 8)
        s[CC].reorder(ko, ki, yi, xi)
        s[CC].pragma(ki, "auto_unroll_max_step", 16)

        s[AA].compute_at(s[CC], ko)
        s[AL].compute_at(s[CC], ki)
        s[BB].compute_at(s[CC], ko)
        s[BL].compute_at(s[CC], ki)
        # Cooperative load of the A tile: spread rows/cols over the
        # same thread grid used for compute.
        _, y, k = s[AA].op.axis
        ty, yi = s[AA].split(y, nparts=y_nthreads)
        tx, ki = s[AA].split(k, nparts=x_nthreads)
        s[AA].reorder(ty, tx, yi, ki)
        s[AA].bind(ty, thread_y)
        s[AA].bind(tx, thread_x)
        s[AA].pragma(yi, "auto_unroll_max_step", 16)

        # Cooperative load of the B tile (B is [batch, N, K], so its
        # second axis is the output's x/N dimension).
        _, x, k = s[BB].op.axis
        ty, xi = s[BB].split(x, nparts=y_nthreads)
        tx, ki = s[BB].split(k, nparts=x_nthreads)
        s[BB].bind(ty, thread_y)
        s[BB].bind(tx, thread_x)
        s[BB].reorder(ty, tx, xi, ki)
        s[BB].pragma(xi, "auto_unroll_max_step", 16)

    def _callback(op):
        # traverse_inline inlines injective ops and visits each op once;
        # schedule only the ops tagged as batch_matmul.
        if "batch_matmul" in op.tag:
            _schedule(op)

    traverse_inline(s, outs[0].op, _callback)
    return s