# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17"""Example code to do square matrix multiplication."""
18import tvm
19import os
20from tvm.contrib import nvcc
21from tvm.contrib import spirv
22import numpy as np
23
24TASK="gemm"
25USE_MANUAL_CODE = False
26
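# The two registered callbacks below hook into TVM's CUDA build flow:
# tvm_callback_cuda_compile compiles the generated CUDA source to PTX with
# nvcc, and tvm_callback_cuda_postproc dumps the generated kernel to perf/
# and can swap in a hand-written kernel when USE_MANUAL_CODE is set.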
@tvm.register_func
def tvm_callback_cuda_compile(code):
    ptx = nvcc.compile_cuda(code, target="ptx")
    return ptx

def write_code(code, fname):
    with open(fname, "w") as f:
        f.write(code)

@tvm.register_func
def tvm_callback_cuda_postproc(code):
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    if USE_MANUAL_CODE:
        code = open("perf/%s_manual.cu" % TASK).read()
    return code


def test_gemm():
    # graph
    nn = 2048
    n = tvm.convert(nn)
    m, l = n, n
    A = tvm.placeholder((l, n), name='A')
    B = tvm.placeholder((l, m), name='B')
    k = tvm.reduce_axis((0, l), name='k')
    C = tvm.compute(
        (m, n),
        lambda ii, jj: tvm.sum(A[k, jj] * B[k, ii], axis=k),
        name='C')
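    # C[ii, jj] = sum_k A[k, jj] * B[k, ii]: both operands are laid out with
    # the reduction axis k leading, so the result matches the
    # np.dot(b_np.T, a_np) reference used in the correctness check below.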

    # schedule
    s = tvm.create_schedule(C.op)
    AA = s.cache_read(A, "shared", [C])   # stage A through shared memory
    BB = s.cache_read(B, "shared", [C])   # stage B through shared memory
    AL = s.cache_read(AA, "local", [C])   # then into per-thread local storage
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")        # accumulate C locally before writeback

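    # Tiling: each thread block computes a block_factor x block_factor (64x64)
    # tile of C with an 8x8 grid of threads, so each thread produces a
    # scale x scale (8x8) sub-tile, split into 4x4 pieces across the 2x2
    # virtual threads (vthread).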
    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = tvm.thread_axis((0, 2), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, 2), "vthread", name="vy")

    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].reorder(by, bx, yi, xi)

    tyz, yi = s[C].split(yi, nparts=2)
    ty, yi = s[C].split(yi, nparts=num_thread)
    txz, xi = s[C].split(xi, nparts=2)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].bind(tyz, thread_yz)
    s[C].bind(txz, thread_xz)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

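    # Split the reduction axis: ko (factor 8) hosts the shared-memory loads,
    # while the inner kt loop is unrolled and hosts the local (register) loads.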
    yo, xo = CC.op.axis
    ko, ki = s[CC].split(k, factor=8)
    kt, ki = s[CC].split(ki, factor=1)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)
    # Schedule for A's shared memory load
    ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)
    s[AA].vectorize(xi)
    # Schedule for B's shared memory load
    ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)
    s[BB].vectorize(xi)
    s[AA].double_buffer()
    s[BB].double_buffer()
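    # Double buffering lets the global->shared copy for the next ko iteration
    # overlap with computation on the current one.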
    # correctness
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Device %s" % device)
        f = tvm.build(s, [A, B, C], device)
        # launch the kernel.
        n, m, l = nn, nn, nn
        a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
        b_np = np.random.uniform(size=(m, l)).astype(B.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
        for _ in range(2):
            f(a, b, c)
        tvm.testing.assert_allclose(
            c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)

        num_flops = 2 * nn * nn * nn
        num_runs = 10
        timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
        t = timer_f(a, b, c).mean  # mean time per run, in seconds
        GFLOPS = num_flops / (t * 1e9)
        print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))

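    # unroll_explicit is disabled for CUDA so that unrolling is left to nvcc
    # via "#pragma unroll"; the other backends get explicitly unrolled loops.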
    for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
        with tvm.build_config(auto_unroll_max_step=128,
                              unroll_explicit=(device != "cuda")):
            check_device(device)

if __name__ == "__main__":
    test_gemm()