1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import mxnet as mx
19
20from benchmark.opperf.utils.op_registry_utils import get_all_nn_basic_operators
21from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
22
23"""Performance benchmark tests for MXNet NDArray basic NN Operators.
24
251. FullyConnected
262. Dropout
273. BatchNorm
284. SoftmaxOutput
295. LinearRegressionOutput
306. LogisticRegressionOutput
317. MAERegressionOutput
328. SVMOutput
339. L2Normalization
3410. LayerNorm
3511. InstanceNorm
3612. Embedding
3713. Correlation
3814. SpatialTransformer
3915. im2col
4016. col2im
4117. GroupNorm
4218. RNN
4319. LRN
44
45"""
46
47
def run_nn_basic_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', warmup=25, runs=100):
    """Run benchmarks for all the basic NN operators in MXNet with the given
    context and precision (dtype).

    The operators benchmarked are those returned by
    ``get_all_nn_basic_operators()``; all of them are handed to
    ``run_op_benchmarks`` in a single call.

    Parameters
    ----------
    ctx: mx.Context, default mx.cpu()
        Context to run benchmarks on.
    dtype: str, default 'float32'
        Precision to use for benchmarks.
    profiler: str, default 'native'
        Module to use for tracking benchmark execution time.
    warmup: int, default 25
        Number of times to run for warmup.
    runs: int, default 100
        Number of runs to capture benchmark results.

    Returns
    -------
    Dictionary of results (as produced by ``run_op_benchmarks``).
    Key -> Name of the operator, Value -> Benchmark results.

    """

    # Fetch all NN basic operators.
    mx_nn_basic_ops = get_all_nn_basic_operators()

    # Run benchmarks. The arguments are forwarded positionally, so this call
    # relies on run_op_benchmarks taking (ops, dtype, ctx, profiler, warmup, runs)
    # in exactly this order -- note dtype comes before ctx.
    mx_nn_basic_op_results = run_op_benchmarks(mx_nn_basic_ops, dtype, ctx, profiler, warmup, runs)
    return mx_nn_basic_op_results
77