# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx


def batchnorm(net,
              gamma=None,
              beta=None,
              eps=0.001,
              momentum=0.9,
              fix_gamma=False,
              use_global_stats=False,
              output_mean_var=False,
              name=None):
    """Apply ``mx.sym.BatchNorm`` to ``net`` and return the resulting symbol.

    Parameters
    ----------
    net
        Passed to ``mx.sym.BatchNorm`` as ``data`` (presumably an
        ``mx.sym.Symbol``; not checked here).
    gamma, beta : optional
        Explicit scale / shift inputs. Each is forwarded independently when
        it is not None. An omitted one is not passed at all, which leaves it
        to ``mx.sym.BatchNorm`` to provide, just as when neither is given.
    eps, momentum, fix_gamma, use_global_stats, output_mean_var, name
        Forwarded unchanged to ``mx.sym.BatchNorm``.

    Returns
    -------
    The symbol returned by ``mx.sym.BatchNorm``.
    """
    kwargs = dict(eps=eps,
                  momentum=momentum,
                  fix_gamma=fix_gamma,
                  use_global_stats=use_global_stats,
                  output_mean_var=output_mean_var,
                  name=name)
    # Forward gamma and beta independently, so that a caller who supplies
    # only one of them does not have it silently discarded.
    if gamma is not None:
        kwargs['gamma'] = gamma
    if beta is not None:
        kwargs['beta'] = beta
    return mx.sym.BatchNorm(data=net, **kwargs)