# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Test measurement and log serialization."""

import tempfile

import tvm
import tvm.testing
from tvm import te, topi, auto_scheduler

from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul


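# Each test_record_* function below builds a ComputeDAG, applies a set of schedule
# primitives to its initial state, and passes the result to record_common(), which
# round-trips the record through a file and checks that the deserialized state
# matches the original schedule.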
def record_common(dag, s):
    """Save a measure record for state `s` of `dag`, read it back, and check that
    the deserialized state is equivalent to the original schedule."""
    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(dag, "test", target)

    inp = auto_scheduler.measure.MeasureInput(task, s)
    res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)

    with tempfile.NamedTemporaryFile() as fp:
        auto_scheduler.save_records(fp.name, [inp], [res])

        log_reader = auto_scheduler.RecordReader(fp.name)
        inputs, results = log_reader.read_lines()
        assert len(inputs) == 1

        s1 = dag.infer_bound_from_state(s)
        s2 = dag.infer_bound_from_state(inputs[0].state)

        # The restored state must match the scheduled state and differ from the
        # unscheduled initial state.
        assert s1 == s2
        assert not (s1 == dag.get_init_state())


def test_record_split_reorder_fuse_annotation():
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Split
    its0 = s.split(C, s[C].iters[0], [4, 8, 8])
    its1 = s.split(C, s[C].iters[4], [8, 4, 4])
    # Reorder
    s.reorder(
        C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8], its1[3]]
    )
    # Fuse
    s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
    # Parallel
    s.parallel(C, s[C].iters[0])
    # Thread bind (blockIdx & threadIdx are GPU concepts; they are used here only to
    # exercise record serialization)
    s.bind(C, s[C].iters[1], "blockIdx.x")
    s.bind(C, s[C].iters[2], "threadIdx.z")
    s.bind(C, s[C].iters[3], "vthread")
    # Unroll
    s.unroll(C, s[C].iters[4])
    # Vectorize
    s.vectorize(C, s[C].iters[6])

    record_common(dag, s)


def test_record_compute_at_root_inline_cache_read_write():
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    AA = topi.nn.relu(A)
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(AA[i][k] * B[k][j], axis=[k]), name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Cache Write
    C_shared = s.cache_write(C, "shared")
    # Compute At
    s.compute_at(C_shared, C, s[C].iters[0])
    # Cache Read
    B_global = s.cache_read(B, "global", [C_shared])
    s.compute_at(B_global, C_shared, s[C_shared].iters[2])
    # Compute Inline
    s.compute_inline(AA)
    # Compute Root
    s.compute_root(C_shared)

    record_common(dag, s)


def test_record_follow_split_follow_fused_split():
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")
    D = topi.nn.relu(C)
    E = topi.nn.relu(D)

    dag = auto_scheduler.ComputeDAG([A, B, E])
    s = dag.get_init_state()

    # Follow Split: remember the index of the split step so follow_split can refer to it.
    s.split(C, s[C].iters[0], [4, 2, 8, 4], True)
    split_step0 = len(s.transform_steps) - 1
    s.follow_split(C, s[C].iters[5], split_step0, 4)
    # Follow Fused Split
    its0 = s.split(E, s[E].iters[0], [4, 2, 8, 4], True)
    split_step1 = len(s.transform_steps) - 1
    its1 = s.split(E, s[E].iters[5], [2, 4, 2, 4], True)
    split_step2 = len(s.transform_steps) - 1
    its = []
    for i0, i1 in zip(its0, its1):
        its.append(i0)
        its.append(i1)
    for i in range(0, 5):
        s.fuse(E, [s[E].iters[i], s[E].iters[i + 1]])
    s.follow_fused_split(D, s[D].iters[0], [split_step1, split_step2], 2, True)

    record_common(dag, s)


def test_record_pragma_storage_align_rfactor():
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Rfactor
    ko, _ = s.split(C, s[C].iters[2], [16])
    s.rfactor(C, ko, 2)
    # Pragma
    s.pragma(C, s[C].iters[0], "auto_unroll_max_step$64")
    # StorageAlign
    s.storage_align(C, s[C].iters[-1], 8, 4)

    record_common(dag, s)


def test_measure_local_builder_runner(enable_cpu_cache_flush=False):
    if not tvm.testing.device_enabled("llvm"):
        return

    dag, s0 = get_tiled_matmul()
    tgt = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(dag, "test", tgt)

    minp = auto_scheduler.MeasureInput(task, s0)
    local_builder = auto_scheduler.LocalBuilder()
    # enable_cpu_cache_flush flushes the CPU cache between repeated runs for more
    # stable timing measurements.
    local_runner = auto_scheduler.LocalRunner(
        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
    )

    bress = local_builder.build([minp])
    assert bress[0].error_no == 0
    mress = local_runner.run([minp], bress)
    assert mress[0].error_no == 0


def test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False):
    if not tvm.testing.device_enabled("llvm"):
        return

    dag, s0 = get_tiled_matmul()
    tgt = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(dag, "test", tgt)

    minp = auto_scheduler.MeasureInput(task, s0)
    local_builder = auto_scheduler.LocalBuilder()
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
    )
    rpc_runner = measure_ctx.runner

    bress = local_builder.build([minp])
    assert bress[0].error_no == 0
    mress = rpc_runner.run([minp], bress)
    assert mress[0].error_no == 0


if __name__ == "__main__":
    test_record_split_reorder_fuse_annotation()
    test_record_compute_at_root_inline_cache_read_write()
    test_record_follow_split_follow_fused_split()
    test_record_pragma_storage_align_rfactor()
    test_measure_local_builder_runner(enable_cpu_cache_flush=True)
    test_measure_local_builder_runner(enable_cpu_cache_flush=False)
    test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=True)
    test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False)