1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Performance benchmark tests for MXNet NDArray Reduction Operations.
191. Operators are automatically fetched from MXNet operator registry.
202. Default Inputs are generated. See rules/default_params.py. You can override the default values.
21
The following 10 reduction operators are covered:
23
24['max', 'max_axis', 'mean', 'min', 'min_axis', 'nanprod', 'nansum', 'prod', 'sum', 'sum_axis']
25
26"""
27
28import mxnet as mx
29
30from benchmark.opperf.utils.op_registry_utils import get_all_reduction_operators
31from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
32
33
def run_mx_reduction_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', warmup=25, runs=100):
    """Benchmark all MXNet reduction operators for a given context and precision.

    The list of operators comes from ``get_all_reduction_operators()`` and is
    passed, together with the settings below, to ``run_op_benchmarks``.

    Parameters
    ----------
    ctx: mx.ctx
        Context on which the benchmarks are run. Defaults to ``mx.cpu()``,
        which is evaluated once when this function is defined.
    dtype: str, default 'float32'
        Precision (data type) used for the benchmarks.
    profiler: str, default 'native'
        Which profiler to use (native/python).
    warmup: int, default 25
        Number of warmup runs performed before measuring.
    runs: int, default 100
        Number of runs used to capture the benchmark results.

    Returns
    -------
    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.

    """
    # Look up the reduction operators to benchmark.
    reduction_ops = get_all_reduction_operators()
    # Arguments are positional: run_op_benchmarks takes dtype before ctx.
    return run_op_benchmarks(reduction_ops, dtype, ctx, profiler, warmup, runs)
61