1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"Example code to perform int8 GEMM"
18import logging
19import sys
20import numpy as np
21import tvm
22from tvm import autotvm
23from topi.cuda.tensor_intrin import dp4a
24
# Whether to run autotuning; when False, a pretuned config index is used instead.
DO_TUNING = True
# Index into the generated config space used when DO_TUNING is False.
# NOTE(review): only meaningful for the exact config space produced by
# gemm_int8(2048, 2048, 2048) — confirm before reusing with other shapes.
PRETUNED_INDEX = 75333

# 4-element int8 dot-product tensor intrinsic with all operands in 'local'
# (register) scope; presumably lowers to CUDA's dp4a instruction — see
# topi.cuda.tensor_intrin.dp4a for the definition.
intrin_dp4a = dp4a('local', 'local', 'local')
29
@autotvm.template
def gemm_int8(n, m, l):
    """AutoTVM template for an int8 GEMM on CUDA.

    Computes ``C[i, j] = sum_k A[i, k] * B[j, k]`` — i.e. ``A @ B.T``, since
    B is stored with the reduction axis innermost — with int8 inputs
    accumulated in int32 via the dp4a tensor intrinsic.

    Parameters
    ----------
    n, m : int
        Output dimensions; C has shape (n, m).
    l : int
        Reduction (inner-product) length.

    Returns
    -------
    tuple
        ``(schedule, [A, B, C])`` suitable for ``tvm.build`` and autotvm tuning.
    """
    A = tvm.placeholder((n, l), name='A', dtype='int8')
    B = tvm.placeholder((m, l), name='B', dtype='int8')

    k = tvm.reduce_axis((0, l), name='k')
    # Upcast operands to int32 before multiplying so products and the
    # accumulation do not overflow the int8 range.
    C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k].astype('int32') * B[j, k].astype(
        'int32'), axis=k), name='C')

    cfg = autotvm.get_config()
    s = tvm.create_schedule(C.op)
    y, x = C.op.axis

    # Memory staging: global -> shared (AA/BB) -> local/registers (AL/BL);
    # the accumulator CC lives in local scope and is written back to C at the end.
    AA = s.cache_read(A, 'shared', [C])
    BB = s.cache_read(B, 'shared', [C])
    AL = s.cache_read(AA, 'local', [C])
    BL = s.cache_read(BB, 'local', [C])
    CC = s.cache_write(C, 'local')

    # Re-fetch the reduce axis from the cached-write stage (the original k now
    # belongs to the final write-back stage, not to the accumulation).
    k = CC.op.reduce_axis[0]

    # The innermost k chunk must be exactly 4 so it can be tensorized with the
    # dp4a intrinsic (a 4-element int8 dot product).
    cfg.define_split('tile_k', cfg.axis(k), num_outputs=3,
                     filter=lambda entity: entity.size[2] == 4 and \
                     entity.size[0] * 2 >= entity.size[1])

    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)

    s[CC].tensorize(ki, intrin_dp4a)

    block_x = tvm.thread_axis('blockIdx.x')
    block_y = tvm.thread_axis('blockIdx.y')
    thread_x = tvm.thread_axis('threadIdx.x')
    thread_y = tvm.thread_axis('threadIdx.y')

    def block_size_filter(entity):
        # Prune the tuning space: cap the thread-block extent (size[1]) at 16
        # and the per-thread inner tile (size[3]) at 4.
        # NOTE(review): the '* 2' on both sides of the first comparison cancels
        # out (it is equivalent to size[0] >= size[1]); possibly one side was
        # meant to be unscaled, as in the tile_k filter above — confirm intent.
        return entity.size[0] * 2 >= entity.size[1] * 2 and \
                entity.size[1] <= 16 and entity.size[3] <= 4
    cfg.define_split('tile_y', cfg.axis(y), num_outputs=4, filter=block_size_filter)
    cfg.define_split('tile_x', cfg.axis(x), num_outputs=4, filter=block_size_filter)
    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, y)
    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, x)

    # Bind the tiled output axes: blocks over (by, bx), virtual threads over
    # (tyz, txz), threads over (ty, tx); (yi, xi) stay as per-thread loops.
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].bind(tyz, tvm.thread_axis('vthread'))
    s[C].bind(txz, tvm.thread_axis('vthread'))
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)

    # Accumulate inside each thread's tile.
    s[CC].compute_at(s[C], tx)

    yo, xo = CC.op.axis
    # Keep the tensorized ki innermost; unroll the kt sub-loop.
    s[CC].reorder(ko, kt, yo, xo, ki)
    s[CC].unroll(kt)

    # Register-level loads: vectorize in groups of 4 (matching the int8x4
    # granularity of dp4a) and double-buffer to overlap load with compute.
    for stage in [AL, BL]:
        s[stage].compute_at(s[CC], kt)
        _, xi = s[stage].split(stage.op.axis[1], factor=4)
        s[stage].vectorize(xi)
        s[stage].double_buffer()

    # Shared-memory loads: pad rows to a tunable alignment (presumably to
    # avoid shared-memory bank conflicts — confirm), then cooperatively load
    # with all threads of the block and vectorize the innermost 16 elements.
    cfg.define_knob('storage_align', [16, 48])
    for stage in [AA, BB]:
        s[stage].storage_align(s[stage].op.axis[0],
                               cfg['storage_align'].val, 0)
        s[stage].compute_at(s[CC], ko)

        fused = s[stage].fuse(*s[stage].op.axis)
        ty, tx = s[stage].split(fused, nparts=cfg['tile_y'].size[2])
        tx, xi = s[stage].split(tx, nparts=cfg['tile_x'].size[2])
        _, xi = s[stage].split(xi, factor=16)

        s[stage].bind(ty, thread_y)
        s[stage].bind(tx, thread_x)
        s[stage].vectorize(xi)

    cfg.define_knob('auto_unroll_max_step', [512, 1500])
    s[C].pragma(by, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[C].pragma(by, 'unroll_explicit', False)

    # 2 ops (multiply + add) per element of the n*m*l reduction, so the tuner
    # can report GFLOPS.
    cfg.add_flop(n*m*l*2)
    return s, [A, B, C]
113
114
if __name__ == '__main__':
    # Square problem: C (n x m) = A (n x l) @ B.T (l x m), all 2048.
    N = 2048
    n = m = l = N

    logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
    task = autotvm.task.create(gemm_int8, args=(n, m, l), target='cuda')
    print(task.config_space)

    # Each candidate is built locally and timed on the local GPU
    # (3 repeats, at least 100 ms of runs each, 4 s timeout).
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
    )

    log_name = 'gemm_int8.log'
    if DO_TUNING:
        # Explore up to 1000 configs with the XGBoost cost-model tuner,
        # logging every measurement, then pick the best config from the log.
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(n_trial=1000, measure_option=measure_option,
                   callbacks=[autotvm.callback.log_to_file(log_name)])

        dispatch_context = autotvm.apply_history_best(log_name)
        best_config = dispatch_context.query(task.target, task.workload)
        print('\nBest config:')
        print(best_config)
    else:
        # Skip tuning and apply a previously found config by its space index.
        config = task.config_space.get(PRETUNED_INDEX)
        dispatch_context = autotvm.task.ApplyConfig(config)
        print("Using pretuned config:")
        print(config)

    # Instantiate the template under the chosen config and compile for CUDA.
    with dispatch_context:
        with tvm.target.create('cuda'):
            s, arg_bufs = gemm_int8(n, m, l)
            f = tvm.build(s, arg_bufs, 'cuda', name='gemm_int8')

    ctx = tvm.context('cuda', 0)

    # NOTE(review): np.random.randint's `high` bound is exclusive, so these
    # draws cover [-128, 126] and never produce 127; covering the full int8
    # range would need high=128 — confirm whether that was intended.
    a_np = np.random.randint(size=(n, l), low=-128, high=127, dtype='int8')
    b_np = np.random.randint(size=(m, l), low=-128, high=127, dtype='int8')

    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros((n, m), dtype='int32'), ctx)
    f(a, b, c)

    # Reference result: the template computes A @ B.T (B has the reduction
    # axis innermost), accumulated in int32.
    tvm.testing.assert_allclose(
        c.asnumpy(),
        np.dot(
            a_np.astype('int32'),
            b_np.T.astype('int32')),
        rtol=1e-5)

    # Benchmark: 2*n*m*l ops per GEMM; report mean time in ms and
    # giga-ops/second (num_ops / seconds / 1e9).
    num_ops = 2 * l * m * n
    num_runs = 1000
    timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
    t = timer_f(a, b, c).mean
    GOPS = num_ops / (t * 1e3) / 1e6
    print("average time cost of %d runs = %g ms, %g GOPS." %
          (num_runs, t * 1e3, GOPS))
173