# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17"Example code to perform int8 GEMM"
18import logging
19import sys
20import numpy as np
21import tvm
22from tvm import te
23from tvm import autotvm
24from tvm.topi.cuda.tensor_intrin import dp4a

DO_TUNING = True
PRETUNED_INDEX = 75333

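# dp4a performs a 4-way int8 dot product with int32 accumulation in a single
# instruction (available on CUDA devices of compute capability 6.1+). All
# three scope arguments are "local", so the intrinsic reads its operands from
# and writes its result to registers.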
intrin_dp4a = dp4a("local", "local", "local")


@autotvm.template("gemm_int8")
def gemm_int8(n, m, l):
    A = te.placeholder((n, l), name="A", dtype="int8")
    B = te.placeholder((m, l), name="B", dtype="int8")

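    # C[i, j] = sum over k of A[i, k] * B[j, k]; B is stored transposed, so
    # both operands are read contiguously along the reduction axis.
    # Accumulation is widened to int32 to avoid int8 overflow.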
    k = te.reduce_axis((0, l), name="k")
    C = te.compute(
        (n, m),
        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
        name="C",
    )

    cfg = autotvm.get_config()
    s = te.create_schedule(C.op)
    y, x = C.op.axis

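    # Stage data through the GPU memory hierarchy: global -> shared (AA, BB)
    # -> registers (AL, BL), and accumulate the output tile in registers (CC)
    # before writing it back to global memory.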
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    k = CC.op.reduce_axis[0]

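    # Split the reduction axis into three levels; the filter pins the
    # innermost factor to 4 so that loop can be tensorized with dp4a.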
    cfg.define_split(
        "tile_k",
        cfg.axis(k),
        num_outputs=3,
        filter=lambda entity: entity.size[2] == 4 and entity.size[0] * 2 >= entity.size[1],
    )

    ko, kt, ki = cfg["tile_k"].apply(s, CC, k)

    s[CC].tensorize(ki, intrin_dp4a)

    block_x = te.thread_axis("blockIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    thread_x = te.thread_axis("threadIdx.x")
    thread_y = te.thread_axis("threadIdx.y")

    def block_size_filter(entity):
        return (
            entity.size[0] >= entity.size[1]
            and entity.size[1] <= 16
            and entity.size[3] <= 4
        )

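    # Tile each output axis into four levels and bind them to blocks, virtual
    # threads (vthread), threads, and a per-thread inner tile.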
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4, filter=block_size_filter)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4, filter=block_size_filter)
    by, tyz, ty, yi = cfg["tile_y"].apply(s, C, y)
    bx, txz, tx, xi = cfg["tile_x"].apply(s, C, x)

    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].bind(tyz, te.thread_axis("vthread"))
    s[C].bind(txz, te.thread_axis("vthread"))
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)

    s[CC].compute_at(s[C], tx)

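    # Within each thread, run the reduction as (ko, kt, yo, xo, ki) and unroll
    # the kt loop.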
    yo, xo = CC.op.axis
    s[CC].reorder(ko, kt, yo, xo, ki)
    s[CC].unroll(kt)

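    # Fetch the register tiles under the kt loop; vectorizing by 4 turns each
    # dp4a operand into a single 32-bit load, and double buffering overlaps
    # these loads with computation.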
    for stage in [AL, BL]:
        s[stage].compute_at(s[CC], kt)
        _, xi = s[stage].split(stage.op.axis[1], factor=4)
        s[stage].vectorize(xi)
        s[stage].double_buffer()

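    # Cooperative loading of the shared memory tiles: align the row stride to
    # a tunable multiple (16 or 48) to mitigate bank conflicts, fuse and
    # re-split the copy loop across the thread block, and issue 16-byte
    # (int8x16) vector loads.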
    cfg.define_knob("storage_align", [16, 48])
    for stage in [AA, BB]:
        s[stage].storage_align(s[stage].op.axis[0], cfg["storage_align"].val, 0)
        s[stage].compute_at(s[CC], ko)

        fused = s[stage].fuse(*s[stage].op.axis)
        ty, tx = s[stage].split(fused, nparts=cfg["tile_y"].size[2])
        tx, xi = s[stage].split(tx, nparts=cfg["tile_x"].size[2])
        _, xi = s[stage].split(xi, factor=16)

        s[stage].bind(ty, thread_y)
        s[stage].bind(tx, thread_x)
        s[stage].vectorize(xi)

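    # Let the tuner choose how aggressively the code generator auto-unrolls
    # the kernel loops.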
    cfg.define_knob("auto_unroll_max_step", [512, 1500])
    s[C].pragma(by, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[C].pragma(by, "unroll_explicit", False)

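    # One multiply and one add per (i, j, k) triple, for the tuner's GFLOPS
    # reporting.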
    cfg.add_flop(n * m * l * 2)
    return s, [A, B, C]


if __name__ == "__main__":
    N = 2048
    n = m = l = N

    logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
    task = autotvm.task.create("gemm_int8", args=(n, m, l), target="cuda")
    print(task.config_space)

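    # Build and run candidates locally; each of the 3 repeats runs the kernel
    # for at least 100 ms to stabilize the measurement.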
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4),
    )

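    # Tune with the XGBoost-based cost model for up to 1000 trials, logging
    # every measured config, then replay the best entry from the log.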
    log_name = "gemm_int8.log"
    if DO_TUNING:
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(
            n_trial=1000,
            measure_option=measure_option,
            callbacks=[autotvm.callback.log_to_file(log_name)],
        )

        dispatch_context = autotvm.apply_history_best(log_name)
        best_config = dispatch_context.query(task.target, task.workload)
        print("\nBest config:")
        print(best_config)
    else:
        config = task.config_space.get(PRETUNED_INDEX)
        dispatch_context = autotvm.task.ApplyConfig(config)
        print("Using pretuned config:")
        print(config)

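    # Instantiate the schedule under the chosen config and compile the CUDA
    # kernel.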
    with dispatch_context:
        with tvm.target.Target("cuda"):
            s, arg_bufs = gemm_int8(n, m, l)
            f = tvm.build(s, arg_bufs, "cuda", name="gemm_int8")

    ctx = tvm.context("cuda", 0)

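    # Random int8 inputs. np.random.randint's upper bound is exclusive, so
    # high=128 covers the full int8 range [-128, 127].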
    a_np = np.random.randint(size=(n, l), low=-128, high=128, dtype="int8")
    b_np = np.random.randint(size=(m, l), low=-128, high=128, dtype="int8")

    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros((n, m), dtype="int32"), ctx)
    f(a, b, c)

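    # Check the result against a NumPy reference (C = A @ B.T in int32).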
    tvm.testing.assert_allclose(
        c.asnumpy(), np.dot(a_np.astype("int32"), b_np.T.astype("int32")), rtol=1e-5
    )

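    # Benchmark: report the mean runtime over 1000 launches and the resulting
    # throughput in GOPS (num_ops / seconds / 1e9).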
    num_ops = 2 * l * m * n
    num_runs = 1000
    timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
    t = timer_f(a, b, c).mean
    GOPS = num_ops / t / 1e9
    print("average time cost of %d runs = %g ms, %g GOPS." % (num_runs, t * 1e3, GOPS))