# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example code to do square matrix multiplication."""
import tvm
import os
from tvm.contrib import nvcc
from tvm.contrib import spirv
import numpy as np

TASK = "gemm"
USE_MANUAL_CODE = False


@tvm.register_func
def tvm_callback_cuda_compile(code):
    ptx = nvcc.compile_cuda(code, target="ptx")
    return ptx


def write_code(code, fname):
    with open(fname, "w") as f:
        f.write(code)


@tvm.register_func
def tvm_callback_cuda_postproc(code):
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    if USE_MANUAL_CODE:
        code = open("perf/%s_manual.cu" % TASK).read()
    return code


def test_gemm():
    # graph
    nn = 2048
    n = tvm.var('n')
    n = tvm.convert(nn)
    m, l = n, n
    A = tvm.placeholder((l, n), name='A')
    B = tvm.placeholder((l, m), name='B')
    k = tvm.reduce_axis((0, l), name='k')
    C = tvm.compute(
        (m, n),
        lambda ii, jj: tvm.sum(A[k, jj] * B[k, ii], axis=k),
        name='C')

    # schedule
    s = tvm.create_schedule(C.op)
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = tvm.thread_axis((0, 2), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, 2), "vthread", name="vy")

    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].reorder(by, bx, yi, xi)

    tyz, yi = s[C].split(yi, nparts=2)
    ty, yi = s[C].split(yi, nparts=num_thread)
    txz, xi = s[C].split(xi, nparts=2)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].bind(tyz, thread_yz)
    s[C].bind(txz, thread_xz)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    yo, xo = CC.op.axis
    ko, ki = s[CC].split(k, factor=8)
    kt, ki = s[CC].split(ki, factor=1)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)
    # Schedule for A's shared memory load
    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)
    s[AA].vectorize(xi)
    # Schedule for B's shared memory load
    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)
    s[BB].vectorize(xi)
    s[AA].double_buffer()
    s[BB].double_buffer()

    # correctness
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Device %s" % device)
        f = tvm.build(s, [A, B, C], device)
        # launch the kernel.
        n, m, l = nn, nn, nn
        a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
        b_np = np.random.uniform(size=(m, l)).astype(B.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
        for i in range(2):
            f(a, b, c)
        tvm.testing.assert_allclose(
            c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)

        num_flops = 2 * nn * nn * nn
        num_runs = 10
        timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
        t = timer_f(a, b, c).mean
        GFLOPS = num_flops / (t * 1e3) / 1e6
        print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))

    for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
        with tvm.build_config(auto_unroll_max_step=128,
                              unroll_explicit=(device != "cuda")):
            check_device(device)


if __name__ == "__main__":
    test_gemm()