# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

""" Test measurement and log serialization. """

import tempfile

import tvm
import tvm.testing
from tvm import auto_scheduler, te, topi

from test_auto_scheduler_common import get_tiled_matmul, matmul_auto_scheduler_test


def record_common(dag, s):
    """Round-trip a measurement record through the on-disk log format.

    Serializes one (MeasureInput, MeasureResult) pair for state ``s`` of
    compute DAG ``dag``, reads it back, and asserts the deserialized state
    is equivalent to the original (and distinct from the initial state).
    """
    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(dag, "test", target)

    inp = auto_scheduler.measure.MeasureInput(task, s)
    res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)

    # The reader must run while the temp file still exists, so everything
    # stays inside the context manager.
    with tempfile.NamedTemporaryFile() as fp:
        auto_scheduler.save_records(fp.name, [inp], [res])

        log_reader = auto_scheduler.RecordReader(fp.name)
        inputs, results = log_reader.read_lines()
        assert len(inputs) == 1

        s1 = dag.infer_bound_from_state(s)
        s2 = dag.infer_bound_from_state(inputs[0].state)

        assert s1 == s2
        assert not (s1 == dag.get_init_state())


def test_record_split_reorder_fuse_annotation():
    """Record/replay of split, reorder, fuse, and annotation steps."""
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Split
    its0 = s.split(C, s[C].iters[0], [4, 8, 8])
    its1 = s.split(C, s[C].iters[4], [8, 4, 4])
    # Reorder
    s.reorder(
        C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8], its1[3]]
    )
    # Fuse
    s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
    # Parallel
    s.parallel(C, s[C].iters[0])
    # Thread bind (the blockIdx & threadIdx are used in GPU, just for record testing here)
    s.bind(C, s[C].iters[1], "blockIdx.x")
    s.bind(C, s[C].iters[2], "threadIdx.z")
    s.bind(C, s[C].iters[3], "vthread")
    # Unroll
    s.unroll(C, s[C].iters[4])
    # Vectorize
    s.vectorize(C, s[C].iters[6])

    record_common(dag, s)


def test_record_compute_at_root_inline_cache_read_write():
    """Record/replay of compute_at/root, compute_inline, and cache read/write steps."""
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    AA = topi.nn.relu(A)
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(AA[i][k] * B[k][j], axis=[k]), name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Cache Write
    C_shared = s.cache_write(C, "shared")
    # Compute At
    s.compute_at(C_shared, C, s[C].iters[0])
    # Cache Read
    B_global = s.cache_read(B, "global", [C_shared])
    s.compute_at(B_global, C_shared, s[C_shared].iters[2])
    # Compute Inline
    s.compute_inline(AA)
    # Compute Root
    s.compute_root(C_shared)

    record_common(dag, s)


def test_record_follow_split_follow_fused_split():
    """Record/replay of follow_split and follow_fused_split steps."""
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")
    D = topi.nn.relu(C)
    E = topi.nn.relu(D)

    dag = auto_scheduler.ComputeDAG([A, B, E])
    s = dag.get_init_state()

    # Follow Split
    s.split(C, s[C].iters[0], [4, 2, 8, 4], True)
    split_step0 = len(s.transform_steps) - 1
    s.follow_split(C, s[C].iters[5], split_step0, 4)
    # Follow Fused Split
    its0 = s.split(E, s[E].iters[0], [4, 2, 8, 4], True)
    split_step1 = len(s.transform_steps) - 1
    its1 = s.split(E, s[E].iters[5], [2, 4, 2, 4], True)
    split_step2 = len(s.transform_steps) - 1
    its = []
    for i0, i1 in zip(its0, its1):
        its.append(i0)
        its.append(i1)
    for i in range(0, 5):
        s.fuse(E, [s[E].iters[i], s[E].iters[i + 1]])
    s.follow_fused_split(D, s[D].iters[0], [split_step1, split_step2], 2, True)

    record_common(dag, s)


def test_record_pragma_storage_align_rfactor():
    """Record/replay of pragma, storage_align, and rfactor steps."""
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Rfactor
    ko, _ = s.split(C, s[C].iters[2], [16])
    s.rfactor(C, ko, 2)
    # Pragma
    s.pragma(C, s[C].iters[0], "auto_unroll_max_step$64")
    # StorageAlign
    s.storage_align(C, s[C].iters[-1], 8, 4)

    record_common(dag, s)


def test_measure_local_builder_runner(enable_cpu_cache_flush=False):
    """Build and run one measurement with LocalBuilder + LocalRunner."""
    if not tvm.testing.device_enabled("llvm"):
        return

    dag, s0 = get_tiled_matmul()
    tgt = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(dag, "test", tgt)

    minp = auto_scheduler.MeasureInput(task, s0)
    local_builder = auto_scheduler.LocalBuilder()
    local_runner = auto_scheduler.LocalRunner(
        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
    )

    bress = local_builder.build([minp])
    assert bress[0].error_no == 0
    mress = local_runner.run([minp], bress)
    assert mress[0].error_no == 0


def test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False):
    """Build and run one measurement with LocalBuilder + the local RPC runner."""
    if not tvm.testing.device_enabled("llvm"):
        return

    dag, s0 = get_tiled_matmul()
    tgt = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(dag, "test", tgt)

    minp = auto_scheduler.MeasureInput(task, s0)
    local_builder = auto_scheduler.LocalBuilder()
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
    )
    rpc_runner = measure_ctx.runner

    bress = local_builder.build([minp])
    assert bress[0].error_no == 0
    mress = rpc_runner.run([minp], bress)
    assert mress[0].error_no == 0


if __name__ == "__main__":
    test_record_split_reorder_fuse_annotation()
    test_record_compute_at_root_inline_cache_read_write()
    test_record_follow_split_follow_fused_split()
    test_record_pragma_storage_align_rfactor()
    test_measure_local_builder_runner(enable_cpu_cache_flush=True)
    test_measure_local_builder_runner(enable_cpu_cache_flush=False)
    test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=True)
    test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False)