1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import mxnet as mx
19from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
20from benchmark.opperf.utils.op_registry_utils import get_all_indexing_routines
21
22"""Performance benchmark tests for MXNet Indexing routines.
23
241. slice
252. slice_axis
263. slice_like
274. take
285. pick
296. where
307. ravel_multi_index
318. unravel_index [to do]
329. gather_nd
3310. scatter_nd [to do]
3411. one_hot
35"""
36
37
def run_indexing_routines_benchmarks(ctx=mx.cpu(),
                                     dtype='float32',
                                     profiler='native',
                                     warmup=25,
                                     runs=100):
    """Benchmark every MXNet indexing routine under a single context and precision.

    Parameters
    ----------
    ctx: mx.Context, default mx.cpu()
        Context (device) on which the benchmarks are run.
    dtype: str, default 'float32'
        Precision (data type) used for the benchmarks.
    profiler: str, default 'native'
        Which profiler collects the timings ('native' or 'python').
    warmup: int, default 25
        Number of warm-up runs performed before measuring.
    runs: int, default 100
        Number of runs whose results are captured.

    Returns
    -------
    Dictionary of results, keyed by operator name, whose values are that
    operator's benchmark results (as produced by ``run_op_benchmarks``).

    """
    # Gather the indexing operators known to the registry helper.
    indexing_ops = get_all_indexing_routines()

    # Delegate to the shared runner. Note its positional order puts dtype
    # before ctx, which differs from this function's signature.
    return run_op_benchmarks(indexing_ops, dtype, ctx, profiler, warmup, runs)
66