# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint:skip-file
from collections import namedtuple

import mxnet as mx

from stt_layer_batchnorm import batchnorm

LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias",
                                     "ph2h_weight",
                                     "c2i_bias", "c2f_bias", "c2o_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
                                     "init_states", "last_states",
                                     "seq_data", "seq_labels", "seq_outputs",
                                     "param_blocks"])


def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx,
                 is_batchnorm=False, gamma=None, beta=None, name=None):
    """Vanilla LSTM cell symbol: one timestep of a standard LSTM."""
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    if is_batchnorm:
        if name is not None:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
        else:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)
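
# For reference, vanilla_lstm computes the standard LSTM recurrence
#   i_t = sigmoid(W_i x_t + U_i h_{t-1} + b_i)   # in_gate
#   g_t = tanh(W_g x_t + U_g h_{t-1} + b_g)      # in_transform
#   f_t = sigmoid(W_f x_t + U_f h_{t-1} + b_f)   # forget_gate
#   o_t = sigmoid(W_o x_t + U_o h_{t-1} + b_o)   # out_gate
#   c_t = f_t * c_{t-1} + i_t * g_t
#   h_t = o_t * tanh(c_t)
# where the four W (input) and U (recurrent) blocks are packed into the single
# i2h and h2h FullyConnected layers (num_hidden * 4 outputs each, split into
# the four gates by SliceChannel).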


def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.,
         num_hidden_proj=0, is_batchnorm=False, gamma=None, beta=None, name=None):
    """LSTM cell symbol with peephole connections and optional output projection."""
    # dropout on the input
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)

    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    if is_batchnorm:
        if name is not None:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
        else:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)

    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                # the gate biases are already carried by
                                # i2h_bias, so h2h needs no bias of its own
                                no_bias=True,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))

    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))

    # in gate, with a peephole connection to the previous cell state
    Wcidc = mx.sym.broadcast_mul(param.c2i_bias, prev_state.c) + slice_gates[0]
    in_gate = mx.sym.Activation(Wcidc, act_type="sigmoid")

    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")

    # forget gate, with a peephole connection to the previous cell state
    Wcfdc = mx.sym.broadcast_mul(param.c2f_bias, prev_state.c) + slice_gates[2]
    forget_gate = mx.sym.Activation(Wcfdc, act_type="sigmoid")

    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)

    # out gate, with a peephole connection to the new cell state
    Wcoct = mx.sym.broadcast_mul(param.c2o_bias, next_c) + slice_gates[3]
    out_gate = mx.sym.Activation(Wcoct, act_type="sigmoid")

    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")

    if num_hidden_proj > 0:
        proj_next_h = mx.sym.FullyConnected(data=next_h,
                                            weight=param.ph2h_weight,
                                            no_bias=True,
                                            num_hidden=num_hidden_proj,
                                            name="t%d_l%d_ph2h" % (seqidx, layeridx))

        return LSTMState(c=next_c, h=proj_next_h)
    else:
        return LSTMState(c=next_c, h=next_h)
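
# For reference, lstm() extends the vanilla recurrence with diagonal peephole
# connections,
#   i_t = sigmoid(W_i x_t + U_i h_{t-1} + w_ci * c_{t-1} + b_i)
#   f_t = sigmoid(W_f x_t + U_f h_{t-1} + w_cf * c_{t-1} + b_f)
#   o_t = sigmoid(W_o x_t + U_o h_{t-1} + w_co * c_t + b_o)
# where the elementwise peephole weights w_ci/w_cf/w_co are the c2i_bias,
# c2f_bias and c2o_bias parameters. When num_hidden_proj > 0 the hidden state
# is additionally mapped through the ph2h_weight projection (an LSTMP-style
# projected cell), shrinking h_t to num_hidden_proj dimensions.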


def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
                num_hidden_proj=0, lstm_type='fc_lstm', is_batchnorm=False,
                prefix="", direction="forward", is_bucketing=False):
    if num_lstm_layer > 0:
        param_cells = []
        last_states = []
        for i in range(num_lstm_layer):
            param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable(prefix + "l%d_i2h_weight" % i),
                                         i2h_bias=mx.sym.Variable(prefix + "l%d_i2h_bias" % i),
                                         h2h_weight=mx.sym.Variable(prefix + "l%d_h2h_weight" % i),
                                         h2h_bias=mx.sym.Variable(prefix + "l%d_h2h_bias" % i),
                                         ph2h_weight=mx.sym.Variable(prefix + "l%d_ph2h_weight" % i),
                                         c2i_bias=mx.sym.Variable(prefix + "l%d_c2i_bias" % i,
                                                                  shape=(1, num_hidden_lstm_list[i])),
                                         c2f_bias=mx.sym.Variable(prefix + "l%d_c2f_bias" % i,
                                                                  shape=(1, num_hidden_lstm_list[i])),
                                         c2o_bias=mx.sym.Variable(prefix + "l%d_c2o_bias" % i,
                                                                  shape=(1, num_hidden_lstm_list[i]))
                                         ))
            state = LSTMState(c=mx.sym.Variable(prefix + "l%d_init_c" % i),
                              h=mx.sym.Variable(prefix + "l%d_init_h" % i))
            last_states.append(state)
        assert len(last_states) == num_lstm_layer
        # declare batchnorm parameters (gamma, beta): one pair per layer when
        # bucketing, otherwise one pair per timestep
        if is_batchnorm:
            batchnorm_gamma = []
            batchnorm_beta = []
            if is_bucketing:
                for l in range(num_lstm_layer):
                    batchnorm_gamma.append(mx.sym.Variable(prefix + "l%d_i2h_gamma" % l))
                    batchnorm_beta.append(mx.sym.Variable(prefix + "l%d_i2h_beta" % l))
            else:
                for seqidx in range(seq_len):
                    batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx))
                    batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx))

        hidden_all = []
        for seqidx in range(seq_len):
            if direction == "forward":
                k = seqidx
                hidden = net[k]
            elif direction == "backward":
                k = seq_len - seqidx - 1
                hidden = net[k]
            else:
                raise Exception("direction must be either 'forward' or 'backward'")

            # stack LSTM layers
            for i in range(num_lstm_layer):
                if i == 0:
                    dp = 0.
                else:
                    dp = dropout

                if is_batchnorm:
                    # gamma/beta are indexed per layer when bucketing and per
                    # timestep otherwise (see their declaration above)
                    bn_idx = i if is_bucketing else seqidx

                if lstm_type == 'fc_lstm':
                    if is_batchnorm:
                        next_state = lstm(num_hidden_lstm_list[i],
                                          indata=hidden,
                                          prev_state=last_states[i],
                                          param=param_cells[i],
                                          seqidx=k,
                                          layeridx=i,
                                          dropout=dp,
                                          num_hidden_proj=num_hidden_proj,
                                          is_batchnorm=is_batchnorm,
                                          gamma=batchnorm_gamma[bn_idx],
                                          beta=batchnorm_beta[bn_idx],
                                          name=prefix + ("t%d_l%d" % (seqidx, i)))
                    else:
                        next_state = lstm(num_hidden_lstm_list[i],
                                          indata=hidden,
                                          prev_state=last_states[i],
                                          param=param_cells[i],
                                          seqidx=k,
                                          layeridx=i,
                                          dropout=dp,
                                          num_hidden_proj=num_hidden_proj,
                                          is_batchnorm=is_batchnorm,
                                          name=prefix + ("t%d_l%d" % (seqidx, i)))
                elif lstm_type == 'vanilla_lstm':
                    if is_batchnorm:
                        next_state = vanilla_lstm(num_hidden_lstm_list[i], indata=hidden,
                                                  prev_state=last_states[i],
                                                  param=param_cells[i],
                                                  seqidx=k, layeridx=i,
                                                  is_batchnorm=is_batchnorm,
                                                  gamma=batchnorm_gamma[bn_idx],
                                                  beta=batchnorm_beta[bn_idx],
                                                  name=prefix + ("t%d_l%d" % (seqidx, i)))
                    else:
                        next_state = vanilla_lstm(num_hidden_lstm_list[i], indata=hidden,
                                                  prev_state=last_states[i],
                                                  param=param_cells[i],
                                                  seqidx=k, layeridx=i,
                                                  is_batchnorm=is_batchnorm,
                                                  name=prefix + ("t%d_l%d" % (seqidx, i)))
                else:
                    raise Exception("lstm_type '%s' is not supported" % lstm_type)

                hidden = next_state.h
                last_states[i] = next_state
            # decoder
            if dropout > 0.:
                hidden = mx.sym.Dropout(data=hidden, p=dropout)

            if direction == "forward":
                hidden_all.append(hidden)
            elif direction == "backward":
                hidden_all.insert(0, hidden)
            else:
                raise Exception("direction must be either 'forward' or 'backward'")
        net = hidden_all

    return net
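
# Usage sketch (illustrative names, assuming the symbol API above): `net` is a
# list of seq_len per-timestep symbols, and the return value is a list of the
# same length holding each timestep's top-layer hidden state, e.g.
#
#   frames = [mx.sym.Variable("t%d_data" % t) for t in range(100)]
#   hidden = lstm_unroll(net=frames, num_lstm_layer=2, seq_len=100,
#                        num_hidden_lstm_list=[256, 256])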


def bi_lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
                   num_hidden_proj=0, lstm_type='fc_lstm', is_batchnorm=False,
                   is_bucketing=False):
    if num_lstm_layer > 0:
        net_forward = lstm_unroll(net=net,
                                  num_lstm_layer=num_lstm_layer,
                                  seq_len=seq_len,
                                  num_hidden_lstm_list=num_hidden_lstm_list,
                                  dropout=dropout,
                                  num_hidden_proj=num_hidden_proj,
                                  lstm_type=lstm_type,
                                  is_batchnorm=is_batchnorm,
                                  prefix="forward_",
                                  direction="forward",
                                  is_bucketing=is_bucketing)

        net_backward = lstm_unroll(net=net,
                                   num_lstm_layer=num_lstm_layer,
                                   seq_len=seq_len,
                                   num_hidden_lstm_list=num_hidden_lstm_list,
                                   dropout=dropout,
                                   num_hidden_proj=num_hidden_proj,
                                   lstm_type=lstm_type,
                                   is_batchnorm=is_batchnorm,
                                   prefix="backward_",
                                   direction="backward",
                                   is_bucketing=is_bucketing)
        hidden_all = []
        for i in range(seq_len):
            hidden_all.append(mx.sym.Concat(net_forward[i], net_backward[i], dim=1))
        net = hidden_all
    return net
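
# Note: each returned timestep symbol is the channel-wise concatenation of the
# forward and backward hidden states, so its feature dimension is twice the
# top layer's num_hidden (or twice num_hidden_proj when projection is enabled).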


# bidirectional LSTM with two inputs and two outputs: unrolls a forward pass
# over net1 and a backward pass over net2, returned without concatenation
def bi_lstm_unroll_two_input_two_output(net1, net2, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
                                        num_hidden_proj=0,
                                        lstm_type='fc_lstm',
                                        is_batchnorm=False,
                                        is_bucketing=False):
    if num_lstm_layer > 0:
        net_forward = lstm_unroll(net=net1,
                                  num_lstm_layer=num_lstm_layer,
                                  seq_len=seq_len,
                                  num_hidden_lstm_list=num_hidden_lstm_list,
                                  dropout=dropout,
                                  num_hidden_proj=num_hidden_proj,
                                  lstm_type=lstm_type,
                                  is_batchnorm=is_batchnorm,
                                  prefix="forward_",
                                  direction="forward",
                                  is_bucketing=is_bucketing)

        net_backward = lstm_unroll(net=net2,
                                   num_lstm_layer=num_lstm_layer,
                                   seq_len=seq_len,
                                   num_hidden_lstm_list=num_hidden_lstm_list,
                                   dropout=dropout,
                                   num_hidden_proj=num_hidden_proj,
                                   lstm_type=lstm_type,
                                   is_batchnorm=is_batchnorm,
                                   prefix="backward_",
                                   direction="backward",
                                   is_bucketing=is_bucketing)
        return net_forward, net_backward
    else:
        return net1, net2
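

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the MXNet 1.x symbol API: unroll a
    # single-layer peephole LSTM over 3 timesteps of symbolic input and list
    # the learnable arguments of the resulting graph.
    seq_len = 3
    frames = [mx.sym.Variable("t%d_data" % t) for t in range(seq_len)]
    outputs = lstm_unroll(net=frames,
                          num_lstm_layer=1,
                          seq_len=seq_len,
                          num_hidden_lstm_list=[32],
                          lstm_type='fc_lstm')
    print(mx.sym.Group(outputs).list_arguments())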