# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# The module docstring must be the first statement in the file (comments do
# not count) for Python to bind it to ``__doc__``; placed after the imports it
# would be a discarded string expression.
"""Performance benchmark tests for MXNet NDArray basic NN Operators.

1. FullyConnected
2. Dropout
3. BatchNorm
4. SoftmaxOutput
5. LinearRegressionOutput
6. LogisticRegressionOutput
7. MAERegressionOutput
8. SVMOutput
9. L2Normalization
10. LayerNorm
11. InstanceNorm
12. Embedding
13. Correlation
14. SpatialTransformer
15. im2col
16. col2im
17. GroupNorm
18. RNN
19. LRN

"""

import mxnet as mx

from benchmark.opperf.utils.op_registry_utils import get_all_nn_basic_operators
from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks


def run_nn_basic_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native',
                                      warmup=25, runs=100):
    """Runs benchmarks with the given context and precision (dtype) for all the NN basic
    operators in MXNet.

    Parameters
    ----------
    ctx: mx.Context, default mx.cpu()
        Context to run benchmarks. Note that the default is created once, when
        this function is defined, and shared by every call that omits ``ctx``.
    dtype: str, default 'float32'
        Precision to use for benchmarks
    profiler: str, default 'native'
        Module to use for tracking benchmark execution time
    warmup: int, default 25
        Number of times to run for warmup
    runs: int, default 100
        Number of runs to capture benchmark results

    Returns
    -------
    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.

    """
    # Fetch all NN Basic Operators from the operator registry.
    mx_nn_basic_ops = get_all_nn_basic_operators()

    # Run benchmarks. run_op_benchmarks takes dtype before ctx, which is the
    # reverse of this function's own parameter order.
    mx_nn_basic_op_results = run_op_benchmarks(mx_nn_basic_ops, dtype, ctx, profiler, warmup, runs)
    return mx_nn_basic_op_results