# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from collections import namedtuple

import mxnet as mx

from stt_layer_batchnorm import batchnorm

GRUState = namedtuple("GRUState", ["h"])
GRUParam = namedtuple("GRUParam", ["gates_i2h_weight", "gates_i2h_bias",
                                   "gates_h2h_weight", "gates_h2h_bias",
                                   "trans_i2h_weight", "trans_i2h_bias",
                                   "trans_h2h_weight", "trans_h2h_bias"])
GRUModel = namedtuple("GRUModel", ["rnn_exec", "symbol",
                                   "init_states", "last_states",
                                   "seq_data", "seq_labels", "seq_outputs",
                                   "param_blocks"])


def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.,
        is_batchnorm=False, gamma=None, beta=None, name=None):
    """
    GRU Cell symbol
    Reference:
    * Chung, Junyoung, et al. "Empirical evaluation of gated recurrent neural
        networks on sequence modeling." arXiv preprint arXiv:1412.3555 (2014).
    """
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.gates_i2h_weight,
                                bias=param.gates_i2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_i2h" % (seqidx, layeridx))

    if is_batchnorm:
        if name is not None:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
        else:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.gates_h2h_weight,
                                bias=param.gates_h2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=2,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    update_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    reset_gate = mx.sym.Activation(slice_gates[1], act_type="sigmoid")
    # Candidate hidden state: the reset gate scales the previous state before the h2h transform
    htrans_i2h = mx.sym.FullyConnected(data=indata,
                                       weight=param.trans_i2h_weight,
                                       bias=param.trans_i2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_i2h" % (seqidx, layeridx))
    h_after_reset = prev_state.h * reset_gate
    htrans_h2h = mx.sym.FullyConnected(data=h_after_reset,
                                       weight=param.trans_h2h_weight,
                                       bias=param.trans_h2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_h2h" % (seqidx, layeridx))
    h_trans = htrans_i2h + htrans_h2h
    h_trans_active = mx.sym.Activation(h_trans, act_type="tanh")
    next_h = prev_state.h + update_gate * (h_trans_active - prev_state.h)
    return GRUState(h=next_h)


def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False, prefix="",
               direction="forward", is_bucketing=False):
    """
    Unrolls a stack of GRU layers over seq_len timesteps in the given direction
    and returns the list of per-timestep outputs of the top layer.
    """
    if num_gru_layer > 0:
        param_cells = []
        last_states = []
        for i in range(num_gru_layer):
            param_cells.append(GRUParam(gates_i2h_weight=mx.sym.Variable(prefix + "l%d_i2h_gates_weight" % i),
                                        gates_i2h_bias=mx.sym.Variable(prefix + "l%d_i2h_gates_bias" % i),
                                        gates_h2h_weight=mx.sym.Variable(prefix + "l%d_h2h_gates_weight" % i),
                                        gates_h2h_bias=mx.sym.Variable(prefix + "l%d_h2h_gates_bias" % i),
                                        trans_i2h_weight=mx.sym.Variable(prefix + "l%d_i2h_trans_weight" % i),
                                        trans_i2h_bias=mx.sym.Variable(prefix + "l%d_i2h_trans_bias" % i),
                                        trans_h2h_weight=mx.sym.Variable(prefix + "l%d_h2h_trans_weight" % i),
                                        trans_h2h_bias=mx.sym.Variable(prefix + "l%d_h2h_trans_bias" % i)))
            state = GRUState(h=mx.sym.Variable(prefix + "l%d_init_h" % i))
            last_states.append(state)
        assert (len(last_states) == num_gru_layer)
        # declare batchnorm parameters (gamma, beta): one pair per layer when bucketing, otherwise one pair per timestep
        if is_batchnorm:
            batchnorm_gamma = []
            batchnorm_beta = []
            if is_bucketing:
                for l in range(num_gru_layer):
                    batchnorm_gamma.append(mx.sym.Variable(prefix + "l%d_i2h_gamma" % l))
                    batchnorm_beta.append(mx.sym.Variable(prefix + "l%d_i2h_beta" % l))
            else:
                for seqidx in range(seq_len):
                    batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx))
                    batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx))

        hidden_all = []
        for seqidx in range(seq_len):
            if direction == "forward":
                k = seqidx
                hidden = net[k]
            elif direction == "backward":
                k = seq_len - seqidx - 1
                hidden = net[k]
            else:
                raise Exception("direction should be either forward or backward")

            # stack GRU
            for i in range(num_gru_layer):
                if i == 0:
                    dp_ratio = 0.
                else:
                    dp_ratio = dropout
                if is_batchnorm:
                    if is_bucketing:
                        next_state = gru(num_hidden_gru_list[i], indata=hidden,
                                         prev_state=last_states[i],
                                         param=param_cells[i],
                                         seqidx=k, layeridx=i, dropout=dp_ratio,
                                         is_batchnorm=is_batchnorm,
                                         gamma=batchnorm_gamma[i],
                                         beta=batchnorm_beta[i],
                                         name=prefix + ("t%d_l%d" % (seqidx, i))
                                         )
                    else:
                        next_state = gru(num_hidden_gru_list[i], indata=hidden,
                                         prev_state=last_states[i],
                                         param=param_cells[i],
                                         seqidx=k, layeridx=i, dropout=dp_ratio,
                                         is_batchnorm=is_batchnorm,
                                         gamma=batchnorm_gamma[k],
                                         beta=batchnorm_beta[k],
                                         name=prefix + ("t%d_l%d" % (seqidx, i))
                                         )
                else:
                    next_state = gru(num_hidden_gru_list[i], indata=hidden,
                                     prev_state=last_states[i],
                                     param=param_cells[i],
                                     seqidx=k, layeridx=i, dropout=dp_ratio,
                                     is_batchnorm=is_batchnorm,
                                     name=prefix)
                hidden = next_state.h
                last_states[i] = next_state
            # dropout on the output of the top GRU layer
            if dropout > 0.:
                hidden = mx.sym.Dropout(data=hidden, p=dropout)

            if direction == "forward":
                hidden_all.append(hidden)
            elif direction == "backward":
                hidden_all.insert(0, hidden)
            else:
                raise Exception("direction should be either forward or backward")
        net = hidden_all

    return net


def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False,
                  is_bucketing=False):
    """
    Runs forward and backward GRU stacks over the sequence and concatenates
    the two outputs at every timestep along the feature dimension.
    """
    if num_gru_layer > 0:
        net_forward = gru_unroll(net=net,
                                 num_gru_layer=num_gru_layer,
                                 seq_len=seq_len,
                                 num_hidden_gru_list=num_hidden_gru_list,
                                 dropout=dropout,
                                 is_batchnorm=is_batchnorm,
                                 prefix="forward_",
                                 direction="forward",
                                 is_bucketing=is_bucketing)
        net_backward = gru_unroll(net=net,
                                  num_gru_layer=num_gru_layer,
                                  seq_len=seq_len,
                                  num_hidden_gru_list=num_hidden_gru_list,
                                  dropout=dropout,
                                  is_batchnorm=is_batchnorm,
                                  prefix="backward_",
                                  direction="backward",
                                  is_bucketing=is_bucketing)
        hidden_all = []
        for i in range(seq_len):
            hidden_all.append(mx.sym.Concat(*[net_forward[i], net_backward[i]], dim=1))
        net = hidden_all
    return net


def bi_gru_unroll_two_input_two_output(net1, net2, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0.,
                                       is_batchnorm=False, is_bucketing=False):
    """
    Same as bi_gru_unroll, but reads the forward and backward stacks from two
    separate inputs and returns their per-timestep outputs without concatenation.
    """
    if num_gru_layer > 0:
        net_forward = gru_unroll(net=net1,
                                 num_gru_layer=num_gru_layer,
                                 seq_len=seq_len,
                                 num_hidden_gru_list=num_hidden_gru_list,
                                 dropout=dropout,
                                 is_batchnorm=is_batchnorm,
                                 prefix="forward_",
                                 direction="forward",
                                 is_bucketing=is_bucketing)
        net_backward = gru_unroll(net=net2,
                                  num_gru_layer=num_gru_layer,
                                  seq_len=seq_len,
                                  num_hidden_gru_list=num_hidden_gru_list,
                                  dropout=dropout,
                                  is_batchnorm=is_batchnorm,
                                  prefix="backward_",
                                  direction="backward",
                                  is_bucketing=is_bucketing)
        return net_forward, net_backward
    else:
        return net1, net2
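

# A minimal usage sketch (illustrative only, not used elsewhere in this module):
# it builds a 2-layer bidirectional GRU over a short symbolic sequence. The
# variable name "data" and the sizes below are assumptions chosen for the example.
if __name__ == '__main__':
    seq_len = 5
    data = mx.sym.Variable('data')  # expected shape: (batch, seq_len, feat_dim)
    # split the sequence axis into seq_len per-timestep symbols
    slices = mx.sym.SliceChannel(data, num_outputs=seq_len, axis=1, squeeze_axis=True)
    net = [slices[t] for t in range(seq_len)]
    outputs = bi_gru_unroll(net=net,
                            num_gru_layer=2,
                            seq_len=seq_len,
                            num_hidden_gru_list=[256, 256],
                            dropout=0.3,
                            is_batchnorm=False,
                            is_bucketing=False)
    # each element of outputs is the concatenated forward/backward hidden state
    # for one timestep; grouping them lists every learnable parameter symbol
    print(mx.sym.Group(outputs).list_arguments())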