1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name
18"""measure bandwidth and compute peak"""
19
20import logging
21import tvm
22from . import util
23from .. import rpc
24
def _convert_to_remote(func, remote):
    """Ship a built module to the remote side and load it there.

    The module is exported as ``tmp_func.tar`` into a temporary directory,
    uploaded through the rpc session and then loaded remotely by file name.

    Parameters
    ----------
    func:
        built module, must provide ``export_library``
    remote:
        rpc session, must provide ``upload`` and ``load_module``

    Returns
    -------
    func:
        whatever ``remote.load_module`` returns for the uploaded archive
    """
    archive_name = "tmp_func.tar"

    # keep the temp directory object referenced until the upload is done
    workdir = util.tempdir()
    local_archive = workdir.relpath(archive_name)

    func.export_library(local_archive)
    remote.upload(local_archive)

    # the remote side is addressed by bare file name, not by the local path
    return remote.load_module(archive_name)
34
def measure_bandwidth_sum(total_item, item_per_thread, stride,
                          base_type, bits, lanes,
                          target, target_host, remote, ctx, n_times):
    """ measure memory bandwidth with a strided product reduction for one type

    The measured kernel is equivalent to

    for each output element i
        for k in 0..item_per_thread:
            y[i] = y[i] * x[base(i) + k * stride]

    where base(i) = i // stride * stride * item_per_thread + i % stride

    Parameters
    ----------
    total_item: int
        number of scalar elements in input array (divided by lanes to get
        the number of vector elements)
    item_per_thread: int
        number of elements each thread accumulates
    stride: int
        stride in memory access
    base_type: str
        can be "int", "float"
    bits: int
        can be 16, 32
    lanes: int
       lane of the vector type, can be 1, 2, 4, 8, 16
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of array
    n_times: int
        number of runs for taking mean

    Returns
    -------
    GBPS: float
         gigabyte per second, or -1 if a TVMError is raised while building,
         uploading or running the kernel
    """
    num_vec = total_item // lanes
    num_out = num_vec // item_per_thread

    scalar_type = str(base_type) + str(bits)
    if lanes == 1:
        dtype = scalar_type
    else:
        dtype = scalar_type + "x" + str(lanes)

    k = tvm.reduce_axis((0, item_per_thread), name="k")
    data = tvm.placeholder((num_vec,), dtype=dtype, name="x")

    # product reducer with identity 1 (the registered name stays "sum")
    product = tvm.comm_reducer(lambda x, y: x*y,
                               lambda t: tvm.const(1, dtype=t),
                               name="sum")

    def reduce_at(i):
        # the argument is deliberately called `i`, matching the original
        # lambda, in case the compute API derives names from it
        base = i // stride * stride * item_per_thread + i % stride
        return product(data[base + k * stride], axis=k)

    out = tvm.compute((num_out,), reduce_at)
    sched = tvm.create_schedule(out.op)

    # one output element per thread, blocks of max_num_threads threads
    block_axis, thread_axis = sched[out].split(out.op.axis[0], target.max_num_threads)
    sched[out].bind(block_axis, tvm.thread_axis("blockIdx.x"))
    sched[out].bind(thread_axis, tvm.thread_axis("threadIdx.x"))
    sched[out].unroll(k)

    try:
        func = tvm.build(sched, [data, out], target, target_host=target_host)

        in_buf = tvm.nd.empty((num_vec,), dtype=dtype, ctx=ctx)
        out_buf = tvm.nd.empty((num_out,), dtype=dtype, ctx=ctx)

        func = _convert_to_remote(func, remote)
        timer = func.time_evaluator(func.entry_name, ctx, number=n_times)
        elapsed = timer(in_buf, out_buf).mean
    except tvm._ffi.base.TVMError:
        # build error (occur when device does not support half)
        return -1

    gigabytes = 1.0 * (total_item * bits / 8) / 1e9
    return gigabytes / elapsed
109
def measure_bandwidth_all_types(total_item, item_per_thread, n_times,
                                target, target_host, remote, ctx, verbose=True):
    """ measure memory bandwidth for float32 with every vector width

    For each lane count two strides are tried and the faster one is kept.

    Parameters
    ----------
    total_item: int
        number of elements in input array
    item_per_thread: int
        number of elements each thread accumulates
    n_times: int
        number of runs for averaging
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of array
    verbose: bool
        whether to log each result as soon as it is measured

    Returns
    -------
    result: list
        a list of [type_name, GBPS] entries, one per lane count
    """
    threads = target.max_num_threads
    base_type, bits = "float", 32
    type_name = base_type + str(bits)

    result = []
    for lanes in [1, 2, 4, 8, 16]:
        # try different strides and keep the best speed
        candidate_strides = [threads, total_item // (lanes * item_per_thread)]

        best = -1e9
        for stride in candidate_strides:
            speed = measure_bandwidth_sum(total_item, item_per_thread, stride,
                                          base_type, bits, lanes, target,
                                          target_host, remote, ctx, n_times)
            if speed > best:
                best = speed

        label = "%sx%d" % (type_name, lanes)
        result.append([label, best])
        if verbose:
            logging.info("\t%-10s %.2f GBPS", label, best)

    return result
156
def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
                        target, target_host, remote, ctx, n_times):
    """ measure peak compute speed of one type with chained multiply-adds

    The measured kernel is equivalent to

    for each thread
        a = b = y[global_id]
        repeat item_per_thread // 4 // lanes times
            a = mad(a, b)
            b = mad(b, a)
        y[global_id] = b

    where mad(p, q) is p * p + q for float types and q * q + p otherwise.

    Parameters
    ----------
    total_item: int
        number of elements in the array (halved when bits >= 64 or
        lanes >= 16)
    item_per_thread: int
        number of operations each thread does
    base_type: str
        can be "int", "float"
    bits: int
        can be 16, 32
    lanes: int
       lane of the vector type, can be 1, 2, 4, 8, 16
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session the built module is uploaded to
    ctx: TVMcontext
        the context of array
    n_times: int
        number of runs for taking mean

    Returns
    -------
    GOPS: float
         giga operation per second, or -1 if a TVMError is raised while
         building, uploading or running the kernel
    """
    # wide types run on half as many elements
    shrink = bits >= 64 or lanes >= 16
    num_item = total_item // 2 if shrink else total_item

    threads = target.max_num_threads

    scalar_type = str(base_type) + str(bits)
    if lanes == 1:
        dtype = scalar_type
    else:
        dtype = scalar_type + "x" + str(lanes)

    is_float = "float" in scalar_type
    # every round issues two multiply-adds on `lanes`-wide values
    num_rounds = item_per_thread // 4 // lanes

    def mad(lhs, rhs):
        """multiply-add whose operand roles depend on the base type"""
        if is_float:
            return lhs * lhs + rhs
        return rhs * rhs + lhs

    def extern(ins, outs):
        # pylint: disable=unused-argument
        """construct measurement function by building IR directly"""
        builder = tvm.ir_builder.create()

        block_idx = tvm.thread_axis("blockIdx.x")
        thread_idx = tvm.thread_axis("threadIdx.x")

        builder.scope_attr(block_idx, "thread_extent", num_item // threads)
        builder.scope_attr(thread_idx, "thread_extent", threads)

        global_id = block_idx.var * threads + thread_idx.var

        acc_a = builder.allocate(dtype, 1, name='a', scope='local')
        acc_b = builder.allocate(dtype, 1, name='b', scope='local')

        # both accumulators start from the same output element
        acc_a[0] = outs[0].vload(global_id, dtype)
        acc_b[0] = outs[0].vload(global_id, dtype)

        # python-level loop: the chain is fully unrolled into the emitted IR
        for _ in range(num_rounds):
            acc_a[0] = mad(acc_a[0], acc_b[0])
            acc_b[0] = mad(acc_b[0], acc_a[0])

        builder.emit(outs[0].vstore(global_id, acc_b[0]))
        return builder.get()

    out = tvm.extern((num_item,), [], extern, name="y", dtype=dtype)
    sched = tvm.create_schedule(out.op)

    try:
        func = tvm.build(sched, [out], target, target_host=target_host)
        func = _convert_to_remote(func, remote)
        timer = func.time_evaluator(func.entry_name, ctx, number=n_times)
        out_buf = tvm.nd.empty((num_item,), dtype=dtype, ctx=ctx)
        elapsed = timer(out_buf).mean
    except tvm._ffi.base.TVMError:
        # build error (occur when device does not support half)
        return -1

    giga_ops = 1.0 * (num_item * item_per_thread) / 1e9
    return giga_ops / elapsed
252
def measure_compute_all_types(total_item, item_per_thread, n_times,
                              target, target_host, remote, ctx, verbose=True):
    """ measure peak compute speed for all types

    Covers float16, float32, float64 and int32, each with every vector
    width. For each combination three per-thread workloads (half, equal
    and double of item_per_thread) are tried and the fastest is kept.

    Parameters
    ----------
    total_item: int
        number of elements in input array
    item_per_thread: int
        baseline number of operations each thread does
    n_times: int
        number of runs for averaging
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of array
    verbose: bool
        whether to log each result as soon as it is measured

    Returns
    -------
    result: list
        a list of [type_name, GFLOPS/GIOPS] entries
    """
    # measurement order; for integers only int32 is measured
    scalar_types = [("float", 16), ("float", 32), ("float", 64), ("int", 32)]
    workloads = [item_per_thread//2, item_per_thread, item_per_thread*2]

    result = []
    for base_type, bits in scalar_types:
        type_name = base_type + str(bits)
        unit = "GFLOPS" if base_type == "float" else "GIOPS"

        for lanes in [1, 2, 4, 8, 16]:
            best = -1e9
            for per_thread in workloads:
                speed = measure_compute_mad(total_item, per_thread,
                                            base_type, bits, lanes, target,
                                            target_host, remote, ctx, n_times)
                if speed > best:
                    best = speed

            label = "%sx%d" % (type_name, lanes)
            result.append([label, best])

            if verbose:
                logging.info("\t%-10s %.2f %s", label, best, unit)

    return result
303
304
def measure_peak_all(target, target_host, host, port):
    """measure memory bandwidth and peak compute for gpu devices

    Connects to the rpc server at (host, port), picks the device context
    matching the target and logs the results of the bandwidth and compute
    measurements. Nothing is returned.

    Parameters
    ----------
    target: str or :any:`tvm.target.Target`
        compilation target; its string form must start with "opencl",
        "cuda" or "metal"
    target_host: str
        host compilation target, forwarded to the measurement functions
    host: str
        address passed to rpc.connect
    port: int
        port passed to rpc.connect

    Raises
    ------
    RuntimeError
        if the target is none of the supported kinds
    """
    # workload sizes
    n_times = 20
    bandwidth_total_item, bandwidth_item_per_thread = 1 << 25, 32
    compute_total_item, compute_item_per_thread = 1 << 21, 4096

    target = tvm.target.create(target)
    remote = rpc.connect(host, port)

    # target prefix -> name of the session method that yields the context
    context_getters = (("opencl", "cl"), ("cuda", "gpu"), ("metal", "metal"))
    target_name = str(target)
    for prefix, getter in context_getters:
        if target_name.startswith(prefix):
            ctx = getattr(remote, getter)()
            break
    else:
        raise RuntimeError("Unsupported target")

    logging.info("========== measure memory bandwidth ==========")
    measure_bandwidth_all_types(bandwidth_total_item, bandwidth_item_per_thread,
                                n_times, target, target_host, remote, ctx)

    logging.info("========== measure peak compute ==========")
    measure_compute_all_types(compute_total_item, compute_item_per_thread,
                              n_times, target, target_host, remote, ctx)
342