1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import mxnet as mx
19
20from stt_layer_batchnorm import batchnorm
21
22
def fc(net,
       num_hidden,
       act_type,
       weight=None,
       bias=None,
       no_bias=False,
       name=None
       ):
    """Append a FullyConnected layer (plus optional activation) to ``net``.

    Parameters
    ----------
    net : mx.symbol.Symbol
        Input symbol the layer is stacked onto.
    num_hidden : int
        Number of hidden units of the FullyConnected layer.
    act_type : str or None
        Activation type passed to ``mx.sym.Activation``; ``None`` skips
        the activation entirely.
    weight : mx.symbol.Symbol, optional
        Externally created weight variable (for weight sharing/tying).
        When ``None``, MXNet auto-creates one.
    bias : mx.symbol.Symbol, optional
        Externally created bias variable. Ignored when ``no_bias`` is True.
    no_bias : bool
        Whether the layer is built without a bias term.
    name : str, optional
        Base name for the created symbols.

    Returns
    -------
    mx.symbol.Symbol
        The network with the new layer(s) appended.
    """
    # Only forward the parameter symbols that were actually supplied:
    # MXNet auto-creates any omitted ones. A bias symbol is never passed
    # when no_bias is set, since the layer would not use it.
    fc_kwargs = dict(data=net, num_hidden=num_hidden, no_bias=no_bias, name=name)
    if weight is not None:
        fc_kwargs['weight'] = weight
    if bias is not None and not no_bias:
        fc_kwargs['bias'] = bias
    net = mx.sym.FullyConnected(**fc_kwargs)

    # optional activation
    if act_type is not None:
        # NOTE(review): if name is None this yields "None_activation";
        # callers in this file always pass a name — confirm before relying on it.
        net = mx.sym.Activation(data=net, act_type=act_type, name="%s_activation" % name)
    return net
53
54
def sequence_fc(net,
                seq_len,
                num_layer,
                prefix,
                num_hidden_list=None,
                act_type_list=None,
                is_batchnorm=False,
                dropout_rate=0,
                ):
    """Apply a stack of ``num_layer`` fully connected layers to each of the
    ``seq_len`` time steps of ``net``, sharing the FC weights across time.

    Parameters
    ----------
    net : mx.symbol.Symbol or list of mx.symbol.Symbol
        Either a single symbol (split along axis 1 into ``seq_len`` slices)
        or a list of per-time-step symbols.
    seq_len : int
        Number of time steps.
    num_layer : int
        Number of FC layers stacked per time step; must equal
        ``len(num_hidden_list)`` and ``len(act_type_list)``.
    prefix : str
        Name prefix for all created parameter variables.
    num_hidden_list : list of int, optional
        Hidden size of each FC layer.
    act_type_list : list of str, optional
        Activation type of each FC layer.
    is_batchnorm : bool
        If True, insert batchnorm between each FC layer and its activation
        (the FC bias is dropped, since batchnorm's beta replaces it).
    dropout_rate : float
        Dropout probability applied before each FC layer when > 0.

    Returns
    -------
    list of mx.symbol.Symbol
        One output symbol per time step (or ``net`` unchanged when
        ``num_layer`` is 0).

    Raises
    ------
    Exception
        If list lengths don't match ``num_layer`` or ``net`` has a bad type.
    """
    # Guard against the shared-mutable-default-argument pitfall.
    num_hidden_list = [] if num_hidden_list is None else num_hidden_list
    act_type_list = [] if act_type_list is None else act_type_list

    if not (num_layer == len(num_hidden_list) == len(act_type_list)):
        # One formatted message instead of a tuple-args Exception.
        raise Exception("length doesn't met - num_layer:%d ,len(num_hidden_list):%d ,len(act_type_list):%d"
                        % (num_layer, len(num_hidden_list), len(act_type_list)))

    if num_layer <= 0:
        return net

    # FC parameters, shared across all time steps (weight tying).
    weight_list = [mx.sym.Variable(name='%s_sequence_fc%d_weight' % (prefix, layer_index))
                   for layer_index in range(num_layer)]
    # When batchnorm is used the FC bias has no effect, so it is omitted.
    bias_list = []
    if not is_batchnorm:
        bias_list = [mx.sym.Variable(name='%s_sequence_fc%d_bias' % (prefix, layer_index))
                     for layer_index in range(num_layer)]

    # Batch normalization parameters, also shared across time steps.
    gamma_list = []
    beta_list = []
    if is_batchnorm:
        gamma_list = [mx.sym.Variable(name='%s_sequence_fc%d_gamma' % (prefix, layer_index))
                      for layer_index in range(num_layer)]
        beta_list = [mx.sym.Variable(name='%s_sequence_fc%d_beta' % (prefix, layer_index))
                     for layer_index in range(num_layer)]

    # Normalize `net` into a list of per-time-step symbols.
    if isinstance(net, mx.symbol.Symbol):
        net = mx.sym.SliceChannel(data=net, num_outputs=seq_len, axis=1, squeeze_axis=1)
    elif isinstance(net, list):
        for net_index, one_net in enumerate(net):
            if not isinstance(one_net, mx.symbol.Symbol):
                raise Exception('%d th elements of the net should be mx.symbol.Symbol' % net_index)
    else:
        raise Exception('type of net should be whether mx.symbol.Symbol or list of mx.symbol.Symbol')

    hidden_all = []
    for seq_index in range(seq_len):
        hidden = net[seq_index]
        for layer_index in range(num_layer):
            if dropout_rate > 0:
                hidden = mx.sym.Dropout(data=hidden, p=dropout_rate)

            if is_batchnorm:
                # FC (no bias) -> batchnorm -> activation; batchnorm is
                # applied after every FC layer, including the last one.
                hidden = fc(net=hidden,
                            num_hidden=num_hidden_list[layer_index],
                            act_type=None,
                            weight=weight_list[layer_index],
                            no_bias=is_batchnorm,
                            name="%s_t%d_l%d_fc" % (prefix, seq_index, layer_index)
                            )
                hidden = batchnorm(net=hidden,
                                   gamma=gamma_list[layer_index],
                                   beta=beta_list[layer_index],
                                   name="%s_t%d_l%d_batchnorm" % (prefix, seq_index, layer_index))
                hidden = mx.sym.Activation(data=hidden, act_type=act_type_list[layer_index],
                                           name="%s_t%d_l%d_activation" % (prefix, seq_index, layer_index))
            else:
                hidden = fc(net=hidden,
                            num_hidden=num_hidden_list[layer_index],
                            act_type=act_type_list[layer_index],
                            weight=weight_list[layer_index],
                            bias=bias_list[layer_index]
                            )
        hidden_all.append(hidden)
    return hidden_all
129