# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import logging
import tvm
import tvm.autotvm as autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cublas
from .tensor_intrin import dp4a
from ..nn.dense import dense, dense_default
from .. import tag
from .. import generic
from ..util import traverse_inline, get_const_tuple

logger = logging.getLogger('topi')


@autotvm.register_topi_compute(dense, ["cuda", "gpu"], "direct")
def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
36    """Dense operator for cuda backend.
37
38    Parameters
39    ----------
40    data : tvm.Tensor
41        2-D with shape [batch, in_dim]
42
43    weight : tvm.Tensor
44        2-D with shape [out_dim, in_dim]
45
46    bias : tvm.Tensor, optional
47        1-D with shape [out_dim]
48
49    Returns
50    -------
51    output : tvm.Tensor
52        2-D with shape [batch, out_dim]
53    """
    # pylint: disable=unused-argument
    assert len(data.shape) == 2 and len(weight.shape) == 2, \
        "only support 2-dim dense"
    if bias is not None:
        assert len(bias.shape) == 1
    if out_dtype is None:
        out_dtype = data.dtype
    batch, in_dim = data.shape
    out_dim, _ = weight.shape
    target = tvm.target.current_target()
    if "cublas" in target.libs:
        assert out_dtype == data.dtype, "Mixed precision not supported."
        matmul = cublas.matmul(data, weight, False, True)
        if bias is not None:
            matmul = tvm.compute((batch, out_dim), \
                                 lambda i, j: matmul[i, j] + bias[j], \
                                 tag=tag.BROADCAST)
        return matmul
    return dense_default(data, weight, bias, out_dtype)

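# A minimal usage sketch (illustrative only, not used by this module): with a
# target string such as "cuda -libs=cublas", dense_cuda above offloads to
# cublas.matmul; without the cublas lib it falls back to dense_default. The
# shapes and the `topi` import are assumptions for the example.
#
#   import tvm
#   import topi
#   A = tvm.placeholder((1024, 512), name='A')
#   W = tvm.placeholder((256, 512), name='W')
#   with tvm.target.create("cuda -libs=cublas"):
#       C = topi.nn.dense(A, W)   # dispatches to dense_cuda and takes the cublas path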

@autotvm.register_topi_schedule(generic.schedule_dense, ["cuda", "gpu"], "direct")
def schedule_dense(cfg, outs):
    """Schedule for dense operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of dense
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for dense.
    """
    # pylint: disable=unused-argument
    target = tvm.target.current_target()

    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    if target.target_name == "cuda" and "cublas" in target.libs:
        return generic.schedule_extern(outs)

    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(C):
        A, _ = C.op.input_tensors
        batch, _ = get_const_tuple(A.shape)
        if batch < 32:
            return schedule_dense_small_batch(cfg, s, C)
        return schedule_dense_large_batch(cfg, s, C)

    scheduled_ops = []

    def traverse(OP):
109        """Internal travserse function"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(OP.tag):
            if OP not in s.outputs:
                s[OP].compute_inline()
            for tensor in OP.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        # schedule dense
        elif OP.tag == 'dense':
            Dense = OP.output(0)
            _schedule(Dense)
        else:
            raise RuntimeError("Unsupported operator: %s" % OP.tag)

        scheduled_ops.append(OP)

    traverse(outs[0].op)
    return s

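# A minimal sketch of exercising this schedule end to end (illustrative only;
# shapes, dtypes, and the `topi` import are assumptions, and an untuned run
# simply picks up the fallback config):
#
#   import tvm
#   import topi
#   A = tvm.placeholder((16, 512), name='A')    # batch < 32 -> small-batch schedule
#   W = tvm.placeholder((256, 512), name='W')
#   with tvm.target.cuda():
#       C = topi.nn.dense(A, W)
#       s = topi.generic.schedule_dense([C])
#       func = tvm.build(s, [A, W, C], "cuda")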

def schedule_dense_small_batch(cfg, s, C):
    """Schedule float32/64 dense with small batch size"""
    A, _ = C.op.input_tensors
    _, in_dim = get_const_tuple(A.shape)
    cfg.define_split('tile_k', in_dim, num_outputs=2)
    if cfg.is_fallback:
        cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])

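    # Split the reduction axis and rfactor it: partial sums along the inner
    # factor kf are computed in CF, turning the dot product into a two-stage
    # reduction (parallel partial sums followed by a final combine).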
    _, kf = cfg['tile_k'].apply(s, C, C.op.reduce_axis[0])
    CF = s.rfactor(C, kf)

    if C.op in s.outputs:
        Out = C
    else:
        Out = s.outputs[0].output(0)
        s[C].compute_at(s[Out], s[Out].op.axis[1])
    s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
    s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))

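    # Bind the remaining reduction axis to threadIdx.x so the partial sums in
    # CF are combined across threads; the store predicates ensure only the
    # thread with threadIdx.x == 0 writes the final result.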
    tx = s[C].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[C].bind(tx, thread_x)
    s[CF].compute_at(s[C], tx)
    s[C].set_store_predicate(thread_x.var.equal(0))
    s[Out].set_store_predicate(thread_x.var.equal(0))

def schedule_dense_large_batch(cfg, s, C):
    """Schedule float32/64 dense with large batch size"""
    A, B = C.op.input_tensors
    batch, in_dim = get_const_tuple(A.shape)
    out_dim, _ = get_const_tuple(B.shape)
    k = C.op.reduce_axis[0]

    # create tuning space
    try:
        block_cand = [64, 128]
        vthread_cand = [2**x for x in range(1, 7)]
        n_thread_cand = [2**x for x in range(3, 7)]
        cfg.define_split('tile_x', batch, num_outputs=4,
                         filter=lambda x: (x.size[1] in vthread_cand and
                                           x.size[2] in n_thread_cand and
                                           (x.size[1] * x.size[2] * x.size[3]) in block_cand))
        cfg.define_split('tile_y', out_dim, num_outputs=4,
                         filter=lambda x: (x.size[1] in vthread_cand and
                                           x.size[2] in n_thread_cand and
                                           (x.size[1] * x.size[2] * x.size[3]) in block_cand))
        cfg.define_split('tile_k', in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
    except IndexError:
        # An IndexError is raised when no split entities are left after
        # filtering; the filters above prune the tuning space for better
        # search efficiency, so fall back to unfiltered splits here.
        logger.debug(
            'Tuning space was created without pruning due to unfit shapes')
        cfg.define_split('tile_x', batch, num_outputs=4)
        cfg.define_split('tile_y', out_dim, num_outputs=4)
        cfg.define_split('tile_k', in_dim, num_outputs=3)

    if cfg.is_fallback:
        if batch > 1:
            cfg['tile_x'] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg['tile_x'] = SplitEntity([1, 1, 1, 1])
        if out_dim > 1:
            cfg['tile_y'] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg['tile_y'] = SplitEntity([1, 1, 1, 1])
        if in_dim > 8:
            cfg['tile_k'] = SplitEntity([-1, 8, 1])
        else:
            cfg['tile_k'] = SplitEntity([-1, 1, 1])
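    # Worked example of the 4-way split semantics (for orientation only): with
    # tile_x = [-1, 2, 16, 2], each block covers 2 vthreads x 16 threads x 2
    # inner elements = 64 rows, which matches block_cand above; the leading -1
    # lets the number of blocks be inferred from the batch size.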

    # Explicit memory access: stage tiles of A and B in shared memory (AA, BB)
    # and registers (AL, BL), and accumulate into a local write cache (CC)
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    # Deal with op fusion
    if C.op not in s.outputs:
        s[C].compute_inline()
        C = s.outputs[0].output(0)

    # Split and reorder computation
    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, C.op.axis[0])
    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, C.op.axis[1])
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    # Binding
    s[C].bind(by, tvm.thread_axis("blockIdx.y"))
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tyz, tvm.thread_axis("vthread"))
    s[C].bind(txz, tvm.thread_axis("vthread"))
    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

    # Split reduction: shared tiles (AA, BB) are loaded once per outer step ko,
    # while register tiles (AL, BL) are refreshed in the unrolled kt loop
    yo, xo = CC.op.axis
    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)

    # Schedule for A's shared memory load
    num_thread_x = cfg['tile_x'].size[2]
    ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread_x)
    s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[AA].double_buffer()

    # Schedule for B's shared memory load
    num_thread_y = cfg['tile_y'].size[2]
    ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread_y)
    s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[BB].double_buffer()

@autotvm.register_topi_compute(dense, ['cuda'], ['int8'])
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
    """Dense operator for int8 on CUDA"""
    if out_dtype is None:
        out_dtype = data.dtype
    batch, in_dim = get_const_tuple(data.shape)
    out_dim, _ = get_const_tuple(weight.shape)

    k = tvm.reduce_axis((0, in_dim), name='k')

    matmul = tvm.compute((batch, out_dim),
                         lambda i, j: tvm.sum(data[i, k].astype(out_dtype) *
                                              weight[j, k].astype(out_dtype), axis=[k]),
                         tag="dense_int8")

    cfg.add_flop(batch * in_dim * out_dim * 2)

    if bias is not None:
        matmul = tvm.compute((batch, out_dim),
                             lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
                             tag=tag.BROADCAST)
        cfg.add_flop(batch * out_dim)

    return matmul

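# A minimal sketch of invoking the int8 path (illustrative only; shapes,
# dtypes, and the `topi` import are assumptions, and whether the 'int8'
# template is selected depends on the autotvm dispatch context / tuned config).
# Note that the int8 schedule below requires in_dim to be a multiple of 4 for
# the dp4a intrinsic.
#
#   import tvm
#   import topi
#   A = tvm.placeholder((128, 512), name='A', dtype='int8')
#   W = tvm.placeholder((256, 512), name='W', dtype='int8')
#   with tvm.target.cuda():
#       C = topi.nn.dense(A, W, out_dtype='int32')
#       s = topi.generic.schedule_dense([C])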

@autotvm.register_topi_schedule(generic.schedule_dense, ['cuda', 'gpu'], ['int8'])
def schedule_dense_int8(cfg, outs):
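    """Schedule for int8 dense operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of dense
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for dense.
    """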
    s = tvm.create_schedule([x.op for x in outs])
    def _callback(op):
        if "dense_int8" in op.tag:
            _schedule_dense_int8(cfg, s, op.output(0))
    traverse_inline(s, outs[0].op, _callback)
    return s


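# dp4a tensor intrinsic: a 4-element int8 dot product accumulated into int32,
# with operands read from shared memory and the result kept in a local
# accumulator; the reduction below is split by a factor of 4 so its innermost
# axis can be tensorized with this intrinsic.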
_dp4a = dp4a('shared', 'shared', 'local')

def _schedule_dense_int8(cfg, s, output):
    data, weight = s[output].op.input_tensors

    batch, in_dim = get_const_tuple(data.shape)
    out_dim, _ = get_const_tuple(weight.shape)

    in_dim_factor = 4
    assert in_dim % in_dim_factor == 0, "in_dim must be a multiple of {}".format(in_dim_factor)
    if in_dim % 16 == 0:
        in_dim_factor = 16

    # create tuning space
    cfg.define_split("tile_y", batch, num_outputs=4)
    cfg.define_split("tile_x", out_dim, num_outputs=4)
    cfg.define_split("tile_k", in_dim // in_dim_factor, num_outputs=2)
    cfg.define_knob('auto_unroll_max_step', [0, 512, 1500])

    # create cache stage
    AA = s.cache_read(data, 'shared', [output])
    WW = s.cache_read(weight, 'shared', [output])
    CC = s.cache_write(output, 'local')

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    n, x = s[output].op.axis

    # kernel_scope is a dummy outer axis used to attach kernel-wide pragmas
    # (e.g. auto_unroll_max_step) to this kernel
    kernel_scope, n = s[output].split(n, nparts=1)

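    # Split the reduction so its innermost 4 int8 elements are consumed by the
    # dp4a intrinsic; the remaining outer part is tiled by the tile_k config.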
    ko = CC.op.reduce_axis[0]
    ko, ki = s[CC].split(ko, factor=4)
    ko, kt = cfg['tile_k'].apply(s, CC, ko)
    s[CC].tensorize(ki, _dp4a)
    by, vy, ty, yi = cfg['tile_y'].apply(s, output, n)
    bx, vx, tx, xi = cfg['tile_x'].apply(s, output, x)

    s[output].reorder(by, bx, vy, vx, ty, tx, yi, xi)
    s[output].bind(by, tvm.thread_axis('blockIdx.y'))
    s[output].bind(bx, tvm.thread_axis('blockIdx.x'))
    s[output].bind(vy, tvm.thread_axis('vthread'))
    s[output].bind(vx, tvm.thread_axis('vthread'))
    s[output].bind(ty, tvm.thread_axis('threadIdx.y'))
    s[output].bind(tx, tvm.thread_axis('threadIdx.x'))
    n_ty = cfg['tile_y'].size[2]
    n_tx = cfg['tile_x'].size[2]

    s[CC].compute_at(s[output], tx)
    yo, xo = CC.op.axis[:2]
    s[CC].reorder(ko, kt, yo, xo, ki)

    for load in [AA, WW]:
        s[load].compute_at(s[CC], ko)

        outer, inner = s[load].split(s[load].op.axis[-1], factor=in_dim_factor)
        s[load].vectorize(inner)
        fused = s[load].op.axis[:-1] + [outer]
        fused = s[load].fuse(*fused)

        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        s[load].bind(tx, tvm.thread_axis('threadIdx.x'))
        s[load].bind(ty, tvm.thread_axis('threadIdx.y'))

    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', False)
    return s