# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"Example code to perform int8 GEMM"
import logging
import sys
import numpy as np
import tvm
from tvm import te
from tvm import autotvm
from tvm.topi.cuda.tensor_intrin import dp4a

DO_TUNING = True
# Index of a previously-tuned configuration in the config space,
# used when DO_TUNING is False to skip the (long) tuning run.
PRETUNED_INDEX = 75333

# dp4a tensor intrinsic (4-element int8 dot product accumulated into int32),
# with all operands and the result in "local" (register) scope.
intrin_dp4a = dp4a("local", "local", "local")


@autotvm.template
def gemm_int8(n, m, l):
    """Build a tunable CUDA schedule computing C = A @ B.T with int8 inputs.

    A is an (n, l) int8 placeholder, B is an (m, l) int8 placeholder, and
    C is the (n, m) int32 result, reduced over the shared axis l and
    accumulated via the dp4a intrinsic.

    Returns (schedule, [A, B, C]) as required by the autotvm template API.
    """
    A = te.placeholder((n, l), name="A", dtype="int8")
    B = te.placeholder((m, l), name="B", dtype="int8")

    k = te.reduce_axis((0, l), name="k")
    C = te.compute(
        (n, m),
        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
        name="C",
    )

    cfg = autotvm.get_config()
    s = te.create_schedule(C.op)
    y, x = C.op.axis

    # Memory hierarchy: stage the operands through shared memory, then into
    # per-thread "local" (register) buffers, and accumulate C in registers.
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    k = CC.op.reduce_axis[0]

    # The innermost k factor must be exactly 4 so that the dp4a intrinsic
    # can be tensorized onto it.
    cfg.define_split(
        "tile_k",
        cfg.axis(k),
        num_outputs=3,
        filter=lambda entity: entity.size[2] == 4 and entity.size[0] * 2 >= entity.size[1],
    )

    ko, kt, ki = cfg["tile_k"].apply(s, CC, k)

    s[CC].tensorize(ki, intrin_dp4a)

    block_x = te.thread_axis("blockIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    thread_x = te.thread_axis("threadIdx.x")
    thread_y = te.thread_axis("threadIdx.y")

    def block_size_filter(entity):
        # Prune the tuning space: bound the thread (size[1]) and inner
        # vector (size[3]) tile factors to reasonable GPU block shapes.
        return (
            entity.size[0] * 2 >= entity.size[1] * 2
            and entity.size[1] <= 16
            and entity.size[3] <= 4
        )

    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4, filter=block_size_filter)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4, filter=block_size_filter)
    by, tyz, ty, yi = cfg["tile_y"].apply(s, C, y)
    bx, txz, tx, xi = cfg["tile_x"].apply(s, C, x)

    # Bind the outer tiles to blocks, the middle tiles to virtual threads
    # and real threads; the innermost yi/xi remain serial per thread.
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].bind(tyz, te.thread_axis("vthread"))
    s[C].bind(txz, te.thread_axis("vthread"))
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)

    s[CC].compute_at(s[C], tx)

    yo, xo = CC.op.axis
    s[CC].reorder(ko, kt, yo, xo, ki)
    s[CC].unroll(kt)

    # Register-level operand buffers: vectorize 4-wide loads and
    # double-buffer to overlap loads with compute.
    for stage in [AL, BL]:
        s[stage].compute_at(s[CC], kt)
        _, xi = s[stage].split(stage.op.axis[1], factor=4)
        s[stage].vectorize(xi)
        s[stage].double_buffer()

    # storage_align pads shared-memory rows to reduce bank conflicts;
    # the best padding factor is left to the tuner.
    cfg.define_knob("storage_align", [16, 48])
    for stage in [AA, BB]:
        s[stage].storage_align(s[stage].op.axis[0], cfg["storage_align"].val, 0)
        s[stage].compute_at(s[CC], ko)

        # Cooperative fetch: all threads of the block load the shared tile,
        # with 16-wide (int8x16) vectorized accesses.
        fused = s[stage].fuse(*s[stage].op.axis)
        ty, tx = s[stage].split(fused, nparts=cfg["tile_y"].size[2])
        tx, xi = s[stage].split(tx, nparts=cfg["tile_x"].size[2])
        _, xi = s[stage].split(xi, factor=16)

        s[stage].bind(ty, thread_y)
        s[stage].bind(tx, thread_x)
        s[stage].vectorize(xi)

    cfg.define_knob("auto_unroll_max_step", [512, 1500])
    s[C].pragma(by, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[C].pragma(by, "unroll_explicit", False)

    # 2 flops (multiply + add) per element of the reduction.
    cfg.add_flop(n * m * l * 2)
    return s, [A, B, C]


if __name__ == "__main__":
    N = 2048
    n = m = l = N

    logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
    task = autotvm.task.create(gemm_int8, args=(n, m, l), target="cuda")
    print(task.config_space)

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4),
    )

    log_name = "gemm_int8.log"
    if DO_TUNING:
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(
            n_trial=1000,
            measure_option=measure_option,
            callbacks=[autotvm.callback.log_to_file(log_name)],
        )

        dispatch_context = autotvm.apply_history_best(log_name)
        best_config = dispatch_context.query(task.target, task.workload)
        print("\nBest config:")
        print(best_config)
    else:
        config = task.config_space.get(PRETUNED_INDEX)
        dispatch_context = autotvm.task.ApplyConfig(config)
        print("Using pretuned config:")
        print(config)

    # Build the kernel under the chosen (tuned or pretuned) configuration.
    with dispatch_context:
        with tvm.target.Target("cuda"):
            s, arg_bufs = gemm_int8(n, m, l)
            f = tvm.build(s, arg_bufs, "cuda", name="gemm_int8")

    ctx = tvm.context("cuda", 0)

    # NOTE(review): np.random.randint's `high` bound is exclusive, so values
    # span [-128, 126] rather than the full int8 range — harmless for a
    # correctness/perf demo, but worth confirming if full-range coverage
    # is intended.
    a_np = np.random.randint(size=(n, l), low=-128, high=127, dtype="int8")
    b_np = np.random.randint(size=(m, l), low=-128, high=127, dtype="int8")

    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros((n, m), dtype="int32"), ctx)
    f(a, b, c)

    # Verify against the NumPy reference: C = A @ B.T in int32.
    tvm.testing.assert_allclose(
        c.asnumpy(), np.dot(a_np.astype("int32"), b_np.T.astype("int32")), rtol=1e-5
    )

    num_ops = 2 * l * m * n
    num_runs = 1000
    timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
    t = timer_f(a, b, c).mean
    GOPS = num_ops / (t * 1e3) / 1e6
    print("average time cost of %d runs = %g ms, %g GOPS." % (num_runs, t * 1e3, GOPS))