1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import mxnet as mx
19from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
20from benchmark.opperf.utils.op_registry_utils import get_all_rearrange_operators
21
# A bare string literal placed after the imports is NOT the module docstring
# (only the first statement of a module is); assign it explicitly so that
# ``help()`` / ``module.__doc__`` actually expose it.
__doc__ = """Performance benchmark tests for MXNet NDArray Rearrange Operators.

1. transpose
2. swapaxes
3. flip
4. depth_to_space
5. space_to_depth
"""
30
31
def run_rearrange_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', warmup=25, runs=100):
    """Runs benchmarks with the given context and precision (dtype) for all the
    rearrange operators in MXNet.

    The set of operators benchmarked is whatever ``get_all_rearrange_operators()``
    returns.

    Parameters
    ----------
    ctx: mx.Context, default mx.cpu()
        Context to run benchmarks on. Note that the default is evaluated once,
        when this module is imported.
    dtype: str, default 'float32'
        Precision to use for benchmarks
    profiler: str, default 'native'
        Type of Profiler to use (native/python)
    warmup: int, default 25
        Number of times to run for warmup
    runs: int, default 100
        Number of runs to capture benchmark results

    Returns
    -------
    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.

    """
    # Fetch all array rearrange operators
    mx_rearrange_ops = get_all_rearrange_operators()

    # Run benchmarks. Arguments are forwarded positionally with dtype BEFORE ctx
    # (unlike this function's own signature) - presumably matching
    # run_op_benchmarks' parameter order; verify before reordering.
    mx_rearrange_op_results = run_op_benchmarks(mx_rearrange_ops, dtype, ctx, profiler, warmup, runs)
    return mx_rearrange_op_results
60