1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Tensor intrinsics on CUDA."""
18#pylint: disable=invalid-name
19import tvm
20
21
def dp4a(x_scope='local', y_scope='local', z_scope='local'):
    """
    Declare a tensor intrinsic for a 4-element int8 dot product via __dp4a.

    The declared computation multiplies two length-4 int8 vectors element
    by element (after casting to int32) and sums the products into a single
    int32 output element. The generated statements load each operand as one
    'int8x4' vector and call the extern function ``__dp4a``.

    Parameters
    ----------
    x_scope : str, optional
        Storage scope of the buffer bound to the lhs operand
    y_scope : str, optional
        Storage scope of the buffer bound to the rhs operand
    z_scope : str, optional
        Storage scope of the buffer bound to the result

    Returns
    -------
    intrin : TensorIntrin
        The dp4a TensorIntrin that can be used in tensorizing schedule.
    """
    lanes = 4  # __dp4a consumes its int8 operands in packs of four

    x = tvm.placeholder((lanes,), name='x', dtype='int8')
    y = tvm.placeholder((lanes,), name='y', dtype='int8')
    k = tvm.reduce_axis((0, lanes), name='rc')

    def _int32_dot(i):  # pylint: disable=unused-argument
        # The single output index is unused; the reduction runs over k only.
        lhs = x[k].astype('int32')
        rhs = y[k].astype('int32')
        return tvm.sum(lhs * rhs, axis=[k])

    z = tvm.compute((1,), _int32_dot)

    def _intrin_func(ins, outs):
        lhs_buf, rhs_buf = ins
        out_buf = outs[0]

        def _reset():
            # Clear the accumulator.
            return out_buf.vstore(0, 0)

        def _accumulate(keep_previous):
            builder = tvm.ir_builder.create()

            packed_lhs = lhs_buf.vload(0, dtype='int8x4')
            packed_rhs = rhs_buf.vload(0, dtype='int8x4')
            # The third __dp4a argument is the value the dot product is
            # added to: the stored result when updating, zero otherwise.
            if keep_previous:
                addend = out_buf.vload(0)
            else:
                addend = 0

            result = tvm.call_pure_extern(
                'int32', '__dp4a', packed_lhs, packed_rhs, addend)
            builder.emit(out_buf.vstore(0, result))

            return builder.get()

        body_stmt = _accumulate(False)
        reset_stmt = _reset()
        update_stmt = _accumulate(True)
        return body_stmt, reset_stmt, update_stmt

    with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
        # Bind every tensor to an explicit buffer in its requested scope.
        binds = {}
        for tensor, scope in ((x, x_scope), (y, y_scope), (z, z_scope)):
            binds[tensor] = tvm.decl_buffer(
                tensor.shape, tensor.dtype, tensor.op.name,
                data_alignment=cfg.data_alignment,
                offset_factor=cfg.offset_factor,
                scope=scope)

        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds=binds)
79