1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import mxnet as mx
19
20
def batchnorm(net,
              gamma=None,
              beta=None,
              eps=0.001,
              momentum=0.9,
              fix_gamma=False,
              use_global_stats=False,
              output_mean_var=False,
              name=None):
    """Apply ``mx.sym.BatchNorm`` to ``net`` and return the resulting symbol.

    Parameters
    ----------
    net
        Input symbol passed as ``data`` to ``mx.sym.BatchNorm``.
    gamma, beta : optional
        Explicit scale / shift inputs. Each one is forwarded to
        ``mx.sym.BatchNorm`` only when it is not ``None``. An omitted one is
        left to ``mx.sym.BatchNorm``'s own default handling, exactly as when
        neither is given.
    eps, momentum, fix_gamma, use_global_stats, output_mean_var, name
        Forwarded unchanged to ``mx.sym.BatchNorm``.

    Returns
    -------
    The symbol returned by ``mx.sym.BatchNorm``.
    """
    # Forward gamma and beta independently. Requiring both to be set before
    # passing either would silently ignore a lone gamma or a lone beta.
    affine = {}
    if gamma is not None:
        affine['gamma'] = gamma
    if beta is not None:
        affine['beta'] = beta

    net = mx.sym.BatchNorm(data=net,
                           eps=eps,
                           momentum=momentum,
                           fix_gamma=fix_gamma,
                           use_global_stats=use_global_stats,
                           output_mean_var=output_mean_var,
                           name=name,
                           **affine
                           )
    return net
51