1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import ctypes
19
20from mxnet.test_utils import *
21import os
22import time
23import argparse
24
25from mxnet.base import check_call, _LIB
26
# Command-line options, parsed at import time; `args.num_omp_threads` is read
# by run_cast_storage_synthetic() when configuring MXNet's OpenMP threads.
parser = argparse.ArgumentParser(
    description="Benchmark cast storage operators",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    '--num-omp-threads',
    type=int,
    default=1,
    help='number of omp threads to set in MXNet',
)
args = parser.parse_args()
31
def measure_cost(repeat, f, *args, **kwargs):
    """Return the average elapsed time, in seconds, of one call to ``f``.

    ``f(*args, **kwargs)`` is called ``repeat`` times. After each call,
    ``wait_to_read()`` is invoked on the returned object so that the result
    is ready (for an MXNet NDArray this blocks until the computation has
    finished) before the next call and before the clock is read.

    Args:
        repeat: number of timed calls; must be at least 1.
        f: callable whose return value provides ``wait_to_read()``.
        *args: positional arguments forwarded unchanged to ``f``.
        **kwargs: keyword arguments forwarded unchanged to ``f``.

    Returns:
        Total elapsed time divided by ``repeat``, in seconds.

    Raises:
        ValueError: if ``repeat`` is less than 1, since no meaningful
            average can be computed.
    """
    if repeat < 1:
        raise ValueError("repeat must be >= 1, got %r" % (repeat,))
    # perf_counter is monotonic and has the highest available resolution;
    # time.time() can jump if the system clock is adjusted mid-benchmark.
    start = time.perf_counter()
    for _ in range(repeat):
        f(*args, **kwargs).wait_to_read()
    return (time.perf_counter() - start) / repeat
40
41
def run_cast_storage_synthetic():
    """Benchmark dense -> sparse ``mx.nd.cast_storage`` on synthetic matrices.

    For each configured benchmark (dense to csr, dense to row_sparse), each
    (m, n) shape pair, each context and each density, one table row is printed
    containing the density in percent, the context, m, n and the average time
    per cast in milliseconds.

    Before any benchmark runs, MXNet's OpenMP thread count is set from the
    module-level ``args.num_omp_threads`` (the ``--num-omp-threads`` option).
    """
    def dense_to_sparse(m, n, density, ctx, repeat, stype):
        """Time casting one dense (m, n) array to ``stype`` and print a table row."""
        set_default_context(ctx)
        data_shape = (m, n)
        # Draw a random array of the target storage type at the requested
        # density (via rand_ndarray), then convert it to dense storage; that
        # dense array is the benchmark input.
        dns_data = rand_ndarray(data_shape, stype, density).tostype('default')
        dns_data.wait_to_read()

        # One warm-up run that also verifies the cast preserves the values.
        # Raised explicitly instead of using `assert`, so the check is not
        # stripped when Python runs with -O.
        casted = mx.nd.cast_storage(dns_data, stype)
        if not same(casted.asnumpy(), dns_data.asnumpy()):
            raise AssertionError("cast_storage to %s changed the values of a %s array"
                                 % (stype, data_shape))

        # start benchmarking
        cost = measure_cost(repeat, mx.nd.cast_storage, dns_data, stype)
        row = '{:10.1f} {:>10} {:8d} {:8d} {:10.2f}'.format(
            density * 100, str(ctx), m, n, cost * 1000)
        print(row)

    check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))

    # params
    # m           number of rows (paired element-wise with n)
    # n           number of columns
    # density     density of the matrix
    # num_repeat  number of benchmark runs to average over
    # contexts    mx.cpu(), mx.gpu()
    #             note: benchmark different contexts separately; to benchmark cpu, compile without CUDA
    # benchmarks  dns_to_csr, dns_to_rsp
    m = [512, 512]
    n = [50000, 100000]
    density = [1.00, 0.80, 0.60, 0.40, 0.20, 0.10, 0.05, 0.02, 0.01]
    num_repeat = 10
    contexts = [mx.gpu()]
    benchmarks = ["dns_to_csr", "dns_to_rsp"]

    # benchmark name -> (target storage type, banner title). Looked up by
    # value equality; comparing string literals with `is` is unreliable.
    benchmark_specs = {
        "dns_to_csr": ('csr', " cast_storage benchmark: dense to csr, size m x n "),
        "dns_to_rsp": ('row_sparse', " cast_storage benchmark: dense to rsp, size m x n "),
    }
    banner = "=================================================="

    # run benchmark
    for b in benchmarks:
        print(banner)
        spec = benchmark_specs.get(b)
        if spec is None:
            print("invalid benchmark: %s" % b)
            continue
        stype, title = spec
        print(title)
        print(banner)
        headline = '{:>10} {:>10} {:>8} {:>8} {:>10}'.format('density(%)', 'context', 'm', 'n', 'time(ms)')
        print(headline)
        for rows, cols in zip(m, n):
            for ctx in contexts:
                for den in density:
                    dense_to_sparse(rows, cols, den, ctx, num_repeat, stype)
            print("")
        print("")
96
97
if __name__ == '__main__':
    # Only run the benchmark when executed as a script, not when imported.
    run_cast_storage_synthetic()
100