# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"Example code to perform int8 GEMM"
import logging
import sys
import numpy as np
import tvm
from tvm import autotvm
from topi.cuda.tensor_intrin import dp4a

# When True, run a fresh XGBoost-based tuning session; when False, replay the
# configuration at PRETUNED_INDEX from the task's config space.
DO_TUNING = True
PRETUNED_INDEX = 75333

# dp4a tensor intrinsic (4-element int8 dot product accumulated into int32);
# all three buffer scopes are 'local' here, matching the local-scope cache
# stages created in the schedule below.
intrin_dp4a = dp4a('local', 'local', 'local')

@autotvm.template
def gemm_int8(n, m, l):
    """AutoTVM template for an int8 GEMM on CUDA.

    Computes C[i, j] = sum_k A[i, k] * B[j, k] with int8 inputs accumulated
    in int32 (note B is indexed [j, k], i.e. the reduction runs along B's
    second axis, so numerically this is A @ B.T).

    Parameters
    ----------
    n, m : output matrix dimensions (C is n x m)
    l : reduction (inner) dimension

    Returns
    -------
    (schedule, [A, B, C]) as required by autotvm templates.
    """
    A = tvm.placeholder((n, l), name='A', dtype='int8')
    B = tvm.placeholder((m, l), name='B', dtype='int8')

    k = tvm.reduce_axis((0, l), name='k')
    C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k].astype('int32') * B[j, k].astype(
        'int32'), axis=k), name='C')

    cfg = autotvm.get_config()
    s = tvm.create_schedule(C.op)
    y, x = C.op.axis

    # Memory hierarchy: global -> shared (AA/BB) -> local/registers (AL/BL),
    # with the accumulator CC written in local scope and copied out to C.
    AA = s.cache_read(A, 'shared', [C])
    BB = s.cache_read(B, 'shared', [C])
    AL = s.cache_read(AA, 'local', [C])
    BL = s.cache_read(BB, 'local', [C])
    CC = s.cache_write(C, 'local')

    k = CC.op.reduce_axis[0]

    # Split the reduction axis into (ko, kt, ki).  The filter pins the
    # innermost factor to exactly 4 because the dp4a intrinsic consumes 4
    # int8 values per instruction; the second condition constrains the
    # relative sizes of the outer two factors to prune the search space.
    cfg.define_split('tile_k', cfg.axis(k), num_outputs=3,
                     filter=lambda entity: entity.size[2] == 4 and \
                     entity.size[0] * 2 >= entity.size[1])

    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)

    # Replace the innermost 4-wide reduction with the dp4a intrinsic.
    s[CC].tensorize(ki, intrin_dp4a)

    block_x = tvm.thread_axis('blockIdx.x')
    block_y = tvm.thread_axis('blockIdx.y')
    thread_x = tvm.thread_axis('threadIdx.x')
    thread_y = tvm.thread_axis('threadIdx.y')

    def block_size_filter(entity):
        # NOTE(review): the first clause multiplies BOTH sides by 2, so it
        # reduces to size[0] >= size[1]; compare with the tile_k filter above
        # where only the left side is scaled.  Possibly one '* 2' is a typo —
        # confirm against the intended search-space constraint before changing.
        return entity.size[0] * 2 >= entity.size[1] * 2 and \
                entity.size[1] <= 16 and entity.size[3] <= 4
    # 4-way splits of the output axes: block / vthread / thread / inner.
    cfg.define_split('tile_y', cfg.axis(y), num_outputs=4, filter=block_size_filter)
    cfg.define_split('tile_x', cfg.axis(x), num_outputs=4, filter=block_size_filter)
    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, y)
    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, x)

    # Bind the tiled axes onto the CUDA grid: blocks, virtual threads, threads.
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].bind(tyz, tvm.thread_axis('vthread'))
    s[C].bind(txz, tvm.thread_axis('vthread'))
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)

    # Compute the local accumulator inside each thread's tile.
    s[CC].compute_at(s[C], tx)

    yo, xo = CC.op.axis
    s[CC].reorder(ko, kt, yo, xo, ki)
    s[CC].unroll(kt)

    # Register-level operand stages: vectorized 4-wide loads (matching the
    # dp4a operand width) and double buffering to overlap load with compute.
    for stage in [AL, BL]:
        s[stage].compute_at(s[CC], kt)
        _, xi = s[stage].split(stage.op.axis[1], factor=4)
        s[stage].vectorize(xi)
        s[stage].double_buffer()

    # Shared-memory stages: storage_align pads rows by a tunable factor
    # (presumably to avoid shared-memory bank conflicts — confirm), and the
    # copy-in is cooperatively fused/split across the thread block.
    cfg.define_knob('storage_align', [16, 48])
    for stage in [AA, BB]:
        s[stage].storage_align(s[stage].op.axis[0],
                               cfg['storage_align'].val, 0)
        s[stage].compute_at(s[CC], ko)

        fused = s[stage].fuse(*s[stage].op.axis)
        ty, tx = s[stage].split(fused, nparts=cfg['tile_y'].size[2])
        tx, xi = s[stage].split(tx, nparts=cfg['tile_x'].size[2])
        _, xi = s[stage].split(xi, factor=16)

        s[stage].bind(ty, thread_y)
        s[stage].bind(tx, thread_x)
        s[stage].vectorize(xi)

    cfg.define_knob('auto_unroll_max_step', [512, 1500])
    s[C].pragma(by, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[C].pragma(by, 'unroll_explicit', False)

    # 2 ops (multiply + add) per element of the n*m*l reduction, used by
    # autotvm to report GFLOPS during tuning.
    cfg.add_flop(n*m*l*2)
    return s, [A, B, C]


if __name__ == '__main__':
    N = 2048
    n = m = l = N

    logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
    task = autotvm.task.create(gemm_int8, args=(n, m, l), target='cuda')
    print(task.config_space)

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
    )

    log_name = 'gemm_int8.log'
    if DO_TUNING:
        # Search the config space with the XGBoost cost-model tuner and log
        # every measured config; then pick the best entry from the log.
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(n_trial=1000, measure_option=measure_option,
                   callbacks=[autotvm.callback.log_to_file(log_name)])

        dispatch_context = autotvm.apply_history_best(log_name)
        best_config = dispatch_context.query(task.target, task.workload)
        print('\nBest config:')
        print(best_config)
    else:
        # Skip tuning and replay a known-good config by its index in the space.
        config = task.config_space.get(PRETUNED_INDEX)
        dispatch_context = autotvm.task.ApplyConfig(config)
        print("Using pretuned config:")
        print(config)

    # Instantiate the template under the chosen config and compile for CUDA.
    with dispatch_context:
        with tvm.target.create('cuda'):
            s, arg_bufs = gemm_int8(n, m, l)
            f = tvm.build(s, arg_bufs, 'cuda', name='gemm_int8')

    ctx = tvm.context('cuda', 0)

    # NOTE(review): np.random.randint's `high` bound is exclusive, so these
    # inputs span [-128, 126] and never exercise the value 127 — presumably
    # high=128 was intended to cover the full int8 range; harmless for the
    # correctness check either way.
    a_np = np.random.randint(size=(n, l), low=-128, high=127, dtype='int8')
    b_np = np.random.randint(size=(m, l), low=-128, high=127, dtype='int8')

    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros((n, m), dtype='int32'), ctx)
    f(a, b, c)

    # Reference check: the template computes A @ B.T (B indexed [j, k]).
    tvm.testing.assert_allclose(
        c.asnumpy(),
        np.dot(
            a_np.astype('int32'),
            b_np.T.astype('int32')),
        rtol=1e-5)

    # Benchmark: t is in seconds, so ops/t / 1e9 = GOPS (written here as
    # /1e3/1e6 to reuse the millisecond figure printed below).
    num_ops = 2 * l * m * n
    num_runs = 1000
    timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
    t = timer_f(a, b, c).mean
    GOPS = num_ops / (t * 1e3) / 1e6
    print("average time cost of %d runs = %g ms, %g GOPS." %
          (num_runs, t * 1e3, GOPS))