1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name,too-many-locals,unused-variable
18"""cuda batch_matmul operators"""
19from __future__ import absolute_import as _abs
20import tvm
21from tvm.contrib import cublas
22from topi.nn import batch_matmul, batch_matmul_default
23from .. import generic
24from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
25
@batch_matmul.register(["cuda", "gpu"])
def batch_matmul_cuda(x, y):
    """Compute batch matrix multiplication of `x` and `y`.

    Dispatches to the cuBLAS extern implementation when the current
    target is CUDA with the cublas library enabled; otherwise falls
    back to the default TOPI compute.

    Parameters
    ----------
    x : tvm.Tensor
        3-D with shape [batch, M, K]

    y : tvm.Tensor
        3-D with shape [batch, N, K]

    Returns
    -------
    output : tvm.Tensor
        3-D with shape [batch, M, N]
    """
    target = tvm.target.current_target()
    use_cublas = target.target_name == "cuda" and "cublas" in target.libs
    if not use_cublas:
        return batch_matmul_default(x, y)
    # `y` is laid out [batch, N, K], so the second operand is transposed.
    return cublas.batch_matmul(x, y, False, True)
48
@generic.schedule_batch_matmul.register(["cuda", "gpu"])
def schedule_batch_matmul(outs):
    """Schedule for batch_matmul

    Parameters
    ----------
    outs: Array of Tensor
          The computation graph description of batch_matmul
          in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    target = tvm.target.current_target()
    # When cuBLAS handles the compute, the op is an extern call and only
    # needs the generic extern schedule -- nothing to tile or bind here.
    if target.target_name == "cuda" and "cublas" in target.libs:
        return generic.schedule_extern(outs)

    # Normalize a single tensor to a one-element list.
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(op):
        # Tile-and-bind schedule for one batch_matmul op.
        # Output C is [batch, M, N] (see batch_matmul_cuda docstring);
        # A and B are its two input tensors.
        C = op.output(0)
        A, B = s[C].op.input_tensors
        _, M, N = get_const_tuple(C.shape)
        # Stage operand tiles through shared memory (AA/BB) and then
        # registers (AL/BL); accumulate the reduction into a local
        # (register) copy of the output (CC).
        AA = s.cache_read(A, "shared", [C])
        AL = s.cache_read(AA, "local", [C])
        BB = s.cache_read(B, "shared", [C])
        BL = s.cache_read(BB, "local", [C])
        CC = s.cache_write(C, "local")
        # If this op is fused into a later op, inline it and schedule
        # the final output tensor instead.
        if op not in s.outputs:
            s[C].compute_inline()
            C = s.outputs[0].output(0)

        b, y, x = s[C].op.axis
        # Tile sizes: largest power of two dividing M (resp. N), capped
        # at 64, so the splits are always exact.
        y_bn = get_max_power2_factor(M, 64)
        x_bn = get_max_power2_factor(N, 64)
        by, y = s[C].split(y, y_bn)
        bx, x = s[C].split(x, x_bn)
        # At most 8x8 threads per block; each thread then computes a
        # (y_bn / y_nthreads) x (x_bn / x_nthreads) sub-tile of C.
        y_nthreads = min(y_bn, 8)
        x_nthreads = min(x_bn, 8)
        ty, yi = s[C].split(y, nparts=y_nthreads)
        tx, xi = s[C].split(x, nparts=x_nthreads)
        thread_x = tvm.thread_axis((0, x_nthreads), "threadIdx.x")
        thread_y = tvm.thread_axis((0, y_nthreads), "threadIdx.y")

        # Grid layout: batch -> blockIdx.z, output tiles -> blockIdx.y/x.
        s[C].reorder(b, by, bx, ty, tx, yi, xi)
        s[C].bind(b, tvm.thread_axis("blockIdx.z"))
        s[C].bind(by, tvm.thread_axis("blockIdx.y"))
        s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[C].bind(ty, thread_y)
        s[C].bind(tx, thread_x)
        s[C].pragma(yi, "auto_unroll_max_step", 16)

        # Per-thread accumulation: compute the register tile inside the
        # thread loop, splitting the reduction axis into chunks of 8.
        s[CC].compute_at(s[C], tx)
        _, yi, xi = s[CC].op.axis
        k, = s[CC].op.reduce_axis
        ko, ki = s[CC].split(k, 8)
        s[CC].reorder(ko, ki, yi, xi)
        s[CC].pragma(ki, "auto_unroll_max_step", 16)

        # Shared-memory loads refresh once per outer reduction chunk (ko);
        # register loads refresh per inner reduction step (ki).
        s[AA].compute_at(s[CC], ko)
        s[AL].compute_at(s[CC], ki)
        s[BB].compute_at(s[CC], ko)
        s[BL].compute_at(s[CC], ki)
        # Cooperative load of the A tile: distribute rows/cols of the
        # shared buffer across the same thread grid used for compute.
        _, y, k = s[AA].op.axis
        ty, yi = s[AA].split(y, nparts=y_nthreads)
        tx, ki = s[AA].split(k, nparts=x_nthreads)
        s[AA].reorder(ty, tx, yi, ki)
        s[AA].bind(ty, thread_y)
        s[AA].bind(tx, thread_x)
        s[AA].pragma(yi, "auto_unroll_max_step", 16)

        # Cooperative load of the B tile. NOTE(review): B's x axis is
        # split by y_nthreads and bound to threadIdx.y -- this just
        # reuses the block's thread extents for the load and looks
        # intentional, but confirm x_bn/y_nthreads divisibility holds
        # (both are powers of two, so the splits are exact).
        _, x, k = s[BB].op.axis
        ty, xi = s[BB].split(x, nparts=y_nthreads)
        tx, ki = s[BB].split(k, nparts=x_nthreads)
        s[BB].bind(ty, thread_y)
        s[BB].bind(tx, thread_x)
        s[BB].reorder(ty, tx, xi, ki)
        s[BB].pragma(xi, "auto_unroll_max_step", 16)

    def _callback(op):
        # Only schedule ops tagged as batch_matmul; traverse_inline
        # handles fusing elementwise ops around them.
        if "batch_matmul" in op.tag:
            _schedule(op)

    traverse_inline(s, outs[0].op, _callback)
    return s
137