# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""measure bandwidth and compute peak"""

import logging

import tvm
from . import util
from .. import rpc


def _convert_to_remote(func, remote):
    """convert module function to remote rpc function

    Exports the built module to a temporary .tar, uploads it over the
    RPC session, and loads it back as a remote module.
    """
    temp = util.tempdir()
    path_dso = temp.relpath("tmp_func.tar")
    func.export_library(path_dso)

    remote.upload(path_dso)
    func = remote.load_module("tmp_func.tar")
    return func


def measure_bandwidth_sum(total_item, item_per_thread, stride,
                          base_type, bits, lanes,
                          target, target_host, remote, ctx, n_times):
    """measure memory bandwidth of gpu by product reduction for a given type

    The IR for measurement is

    for each thread
        for i in 1..num_per_thread:
            y[global_id] = y[global_id] * x[base + i * stride]

    Parameters
    ----------
    total_item: int
        number of elements in input array
    item_per_thread: int
        number of elements each thread accumulates
    stride: int
        stride in memory access
    base_type: str
        can be "int", "float"
    bits: int
        can be 16, 32
    lanes: int
        lane of the vector type, can be 1, 2, 4, 8, 16
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of array
    n_times: int
        number of runs for taking mean

    Returns
    -------
    GBPS: float
        gigabyte per second, or -1 if the kernel fails to build
        (e.g. the device does not support the requested type)
    """
    n, m = total_item, item_per_thread
    # a vectorized load moves `lanes` scalars at once, so fewer elements
    # are indexed while total bytes stay the same
    n //= lanes

    base_type = str(base_type) + str(bits)
    dtype = base_type if lanes == 1 else base_type + "x" + str(lanes)

    k = tvm.reduce_axis((0, m), name="k")

    x = tvm.placeholder((n,), dtype=dtype, name="x")
    # product reduction so the compiler cannot fold the loads away;
    # lambda params renamed to avoid shadowing the placeholder `x`
    op = tvm.comm_reducer(lambda a, b: a * b, lambda t: tvm.const(1, dtype=t), name="sum")
    y = tvm.compute((n // m,),
                    lambda i: op(x[i // stride * stride * m + i % stride + k * stride], axis=k))
    s = tvm.create_schedule(y.op)

    yo, yi = s[y].split(y.op.axis[0], target.max_num_threads)
    s[y].bind(yo, tvm.thread_axis("blockIdx.x"))
    s[y].bind(yi, tvm.thread_axis("threadIdx.x"))
    s[y].unroll(k)

    try:
        func = tvm.build(s, [x, y], target, target_host=target_host)

        # contents are irrelevant for a bandwidth measurement, so
        # uninitialized buffers are fine
        x = tvm.nd.empty((n,), dtype=dtype, ctx=ctx)
        y = tvm.nd.empty((n // m,), dtype=dtype, ctx=ctx)

        func = _convert_to_remote(func, remote)
        time_f = func.time_evaluator(func.entry_name, ctx, number=n_times)
        time = time_f(x, y).mean
    except tvm._ffi.base.TVMError:
        # build error (occur when device does not support half)
        return -1

    # bytes moved = total_item elements of `bits` bits each
    return 1.0 * (total_item * bits / 8) / 1e9 / time


def measure_bandwidth_all_types(total_item, item_per_thread, n_times,
                                target, target_host, remote, ctx, verbose=True):
    """measure memory bandwidth for all types

    Parameters
    ----------
    total_item: int
        number of elements in input array
    item_per_thread: int
        number of elements each thread accumulates
    n_times: int
        number of runs for averaging
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of array
    verbose: bool
        whether outputs immediate result

    Returns
    -------
    result: list
        a list of (type_name, GBPS) pairs
    """
    max_threads = target.max_num_threads

    result = []
    for base_type in ["float"]:
        for bits in [32]:
            for lanes in [1, 2, 4, 8, 16]:
                max_speed = -1e9
                # try different strides and keep the best
                for stride in [max_threads, total_item // (lanes * item_per_thread)]:
                    speed = measure_bandwidth_sum(total_item, item_per_thread, stride,
                                                  base_type, bits, lanes, target,
                                                  target_host, remote, ctx, n_times)
                    max_speed = max(max_speed, speed)
                type_name = base_type + str(bits)
                result.append(["%sx%d" % (type_name, lanes), max_speed])
                if verbose:
                    logging.info("\t%-10s %.2f GBPS", result[-1][0], result[-1][1])
    return result


def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
                        target, target_host, remote, ctx, n_times):
    """measure peak compute speed by computing mad for a type

    The IR for measurement is

    for each thread
        for i in 1..item_per_thread
            x = mad(x, x, y)
            y = mad(y, y, x)

    Parameters
    ----------
    total_item: int
        number of elements in input array
    item_per_thread: int
        number of operations each thread does
    base_type: str
        can be "int", "float"
    bits: int
        can be 16, 32
    lanes: int
        lane of the vector type, can be 1, 2, 4, 8, 16
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of array
    n_times: int
        number of runs for taking mean

    Returns
    -------
    GOPS: float
        giga operation per second, or -1 if the kernel fails to build
    """

    n = total_item

    # halve the launch size for wide types so the kernel still fits
    if bits >= 64 or lanes >= 16:
        n //= 2

    max_threads = target.max_num_threads

    base_type = str(base_type) + str(bits)
    dtype = base_type if lanes == 1 else base_type + "x" + str(lanes)

    def extern(ins, outs):
        # pylint: disable=unused-argument
        """construct measurement function by building IR directly"""
        ib = tvm.ir_builder.create()

        bx = tvm.thread_axis("blockIdx.x")
        tx = tvm.thread_axis("threadIdx.x")

        ib.scope_attr(bx, "thread_extent", n // max_threads)
        ib.scope_attr(tx, "thread_extent", max_threads)

        idx = bx.var * max_threads + tx.var

        a = ib.allocate(dtype, (1), name='a', scope='local')
        b = ib.allocate(dtype, (1), name='b', scope='local')

        a[0] = outs[0].vload(idx, dtype)
        b[0] = outs[0].vload(idx, dtype)

        if base_type.find('float') != -1:
            mad_func = lambda x, y: (x * x + y)
        else:
            mad_func = lambda x, y: y * y + x

        # each iteration issues 2 mads = 4 scalar ops per lane,
        # hence the division by 4 and by lanes
        for _ in range(item_per_thread // 4 // lanes):
            a[0] = mad_func(a[0], b[0])
            b[0] = mad_func(b[0], a[0])

        ib.emit(outs[0].vstore(idx, b[0]))
        return ib.get()

    y = tvm.extern((n,), [], extern, name="y", dtype=dtype)
    s = tvm.create_schedule(y.op)

    try:
        func = tvm.build(s, [y], target, target_host=target_host)
        func = _convert_to_remote(func, remote)
        time_f = func.time_evaluator(func.entry_name, ctx, number=n_times)
        y = tvm.nd.empty((n,), dtype=dtype, ctx=ctx)
        time = time_f(y).mean
    except tvm._ffi.base.TVMError:
        # build error (occur when device does not support half)
        return -1

    return 1.0 * (n * item_per_thread) / 1e9 / time


def measure_compute_all_types(total_item, item_per_thread, n_times,
                              target, target_host, remote, ctx, verbose=True):
    """measure peak flops for all types

    Parameters
    ----------
    total_item: int
        number of elements in input array
    item_per_thread: int
        number of elements each thread accumulates
    n_times: int
        number of runs for averaging
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of array
    verbose: bool
        whether outputs immediate result

    Returns
    -------
    result: list
        a list of (type_name, GFLOPS/GIOPS) pairs
    """
    result = []
    for base_type in ["float", "int"]:
        for bits in [16, 32, 64]:
            for lanes in [1, 2, 4, 8, 16]:
                if base_type == 'int' and bits != 32:  # only measure int32
                    continue

                max_speed = -1e9
                # try several unroll factors and keep the best
                for per_thread in [item_per_thread//2, item_per_thread, item_per_thread*2]:
                    speed = measure_compute_mad(total_item, per_thread,
                                                base_type, bits, lanes, target,
                                                target_host, remote, ctx, n_times)
                    max_speed = max(max_speed, speed)
                type_name = base_type + str(bits)
                result.append(["%sx%d" % (type_name, lanes), max_speed])

                unit = "GFLOPS" if base_type == "float" else "GIOPS"

                if verbose:
                    logging.info("\t%-10s %.2f %s", result[-1][0], result[-1][1], unit)

    return result


def measure_peak_all(target, target_host, host, port):
    """measure memory bandwidth and peak compute for gpu devices

    Parameters
    ----------
    target: str or :any:`tvm.target.Target`
        device compilation target
    target_host: str
        host compilation target
    host: str
        address of the RPC tracker/server
    port: int
        port of the RPC tracker/server
    """

    target = tvm.target.create(target)
    remote = rpc.connect(host, port)
    n_times = 20

    bandwidth_total_item = 1 << 25
    bandwidth_item_per_thread = 32

    compute_total_item = 1 << 21
    compute_item_per_thread = 4096

    if str(target).startswith("opencl"):
        ctx = remote.cl()
    elif str(target).startswith("cuda"):
        ctx = remote.gpu()
    elif str(target).startswith("metal"):
        ctx = remote.metal()
    else:
        raise RuntimeError("Unsupported target")

    logging.info("========== measure memory bandwidth ==========")
    measure_bandwidth_all_types(bandwidth_total_item, bandwidth_item_per_thread,
                                n_times, target, target_host, remote, ctx)

    logging.info("========== measure peak compute ==========")
    measure_compute_all_types(compute_total_item, compute_item_per_thread,
                              n_times, target, target_host, remote, ctx)