# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint:skip-file
from collections import namedtuple

import mxnet as mx

from stt_layer_batchnorm import batchnorm

LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias",
                                     "ph2h_weight",
                                     "c2i_bias", "c2f_bias", "c2o_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
                                     "init_states", "last_states",
                                     "seq_data", "seq_labels", "seq_outputs",
                                     "param_blocks"])


def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, is_batchnorm=False, gamma=None, beta=None, name=None):
    """Vanilla LSTM cell symbol."""
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    if is_batchnorm:
        if name is not None:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
        else:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)


def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., num_hidden_proj=0, is_batchnorm=False,
         gamma=None, beta=None, name=None):
    """LSTM cell symbol with peephole connections and an optional output projection."""
    # dropout on the input
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)

    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    if is_batchnorm:
        if name is not None:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
        else:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)

    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                # bias=param.h2h_bias,
                                no_bias=True,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))

    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))

    # peephole connection from the previous cell state to the input gate
    Wcidc = mx.sym.broadcast_mul(param.c2i_bias, prev_state.c) + slice_gates[0]
    in_gate = mx.sym.Activation(Wcidc, act_type="sigmoid")

    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")

    # peephole connection from the previous cell state to the forget gate
    Wcfdc = mx.sym.broadcast_mul(param.c2f_bias, prev_state.c) + slice_gates[2]
    forget_gate = mx.sym.Activation(Wcfdc, act_type="sigmoid")

    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)

    # peephole connection from the new cell state to the output gate
    Wcoct = mx.sym.broadcast_mul(param.c2o_bias, next_c) + slice_gates[3]
    out_gate = mx.sym.Activation(Wcoct, act_type="sigmoid")

    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")

    if num_hidden_proj > 0:
        proj_next_h = mx.sym.FullyConnected(data=next_h,
                                            weight=param.ph2h_weight,
                                            no_bias=True,
                                            num_hidden=num_hidden_proj,
                                            name="t%d_l%d_ph2h" % (seqidx, layeridx))

        return LSTMState(c=next_c, h=proj_next_h)
    else:
        return LSTMState(c=next_c, h=next_h)


def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0., num_hidden_proj=0,
                lstm_type='fc_lstm', is_batchnorm=False, prefix="", direction="forward", is_bucketing=False):
    if num_lstm_layer > 0:
        param_cells = []
        last_states = []
        for i in range(num_lstm_layer):
            param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable(prefix + "l%d_i2h_weight" % i),
                                         i2h_bias=mx.sym.Variable(prefix + "l%d_i2h_bias" % i),
                                         h2h_weight=mx.sym.Variable(prefix + "l%d_h2h_weight" % i),
                                         h2h_bias=mx.sym.Variable(prefix + "l%d_h2h_bias" % i),
                                         ph2h_weight=mx.sym.Variable(prefix + "l%d_ph2h_weight" % i),
                                         c2i_bias=mx.sym.Variable(prefix + "l%d_c2i_bias" % i,
                                                                  shape=(1, num_hidden_lstm_list[i])),
                                         c2f_bias=mx.sym.Variable(prefix + "l%d_c2f_bias" % i,
                                                                  shape=(1, num_hidden_lstm_list[i])),
                                         c2o_bias=mx.sym.Variable(prefix + "l%d_c2o_bias" % i,
                                                                  shape=(1, num_hidden_lstm_list[i]))
                                         ))
            state = LSTMState(c=mx.sym.Variable(prefix + "l%d_init_c" % i),
                              h=mx.sym.Variable(prefix + "l%d_init_h" % i))
            last_states.append(state)
        assert len(last_states) == num_lstm_layer
        # declare batchnorm params (gamma, beta): per layer when bucketing, per timestep otherwise
        if is_batchnorm:
            batchnorm_gamma = []
            batchnorm_beta = []
            if is_bucketing:
                for l in range(num_lstm_layer):
                    batchnorm_gamma.append(mx.sym.Variable(prefix + "l%d_i2h_gamma" % l))
                    batchnorm_beta.append(mx.sym.Variable(prefix + "l%d_i2h_beta" % l))
            else:
                for seqidx in range(seq_len):
                    batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx))
                    batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx))

        hidden_all = []
        for seqidx in range(seq_len):
            if direction == "forward":
                k = seqidx
                hidden = net[k]
            elif direction == "backward":
                k = seq_len - seqidx - 1
                hidden = net[k]
            else:
                raise Exception("direction should be either forward or backward")

            # stack LSTM layers
            for i in range(num_lstm_layer):
                if i == 0:
                    dp = 0.
                else:
                    dp = dropout

                if lstm_type == 'fc_lstm':
                    if is_batchnorm:
                        if is_bucketing:
                            next_state = lstm(num_hidden_lstm_list[i],
                                              indata=hidden,
                                              prev_state=last_states[i],
                                              param=param_cells[i],
                                              seqidx=k,
                                              layeridx=i,
                                              dropout=dp,
                                              num_hidden_proj=num_hidden_proj,
                                              is_batchnorm=is_batchnorm,
                                              gamma=batchnorm_gamma[i],
                                              beta=batchnorm_beta[i],
                                              name=prefix + ("t%d_l%d" % (seqidx, i))
                                              )
                    else:
                        next_state = lstm(num_hidden_lstm_list[i],
                                          indata=hidden,
                                          prev_state=last_states[i],
                                          param=param_cells[i],
                                          seqidx=k,
                                          layeridx=i,
                                          dropout=dp,
                                          num_hidden_proj=num_hidden_proj,
                                          is_batchnorm=is_batchnorm,
                                          name=prefix + ("t%d_l%d" % (seqidx, i))
                                          )
                elif lstm_type == 'vanilla_lstm':
                    if is_batchnorm:
                        next_state = vanilla_lstm(num_hidden_lstm_list[i], indata=hidden,
                                                  prev_state=last_states[i],
                                                  param=param_cells[i],
                                                  seqidx=k, layeridx=i,
                                                  is_batchnorm=is_batchnorm,
                                                  gamma=batchnorm_gamma[i],
                                                  beta=batchnorm_beta[i],
                                                  name=prefix + ("t%d_l%d" % (seqidx, i))
                                                  )
                    else:
                        next_state = vanilla_lstm(num_hidden_lstm_list[i], indata=hidden,
                                                  prev_state=last_states[i],
                                                  param=param_cells[i],
                                                  seqidx=k, layeridx=i,
                                                  is_batchnorm=is_batchnorm,
                                                  name=prefix + ("t%d_l%d" % (seqidx, i))
                                                  )
                else:
                    raise Exception("unsupported lstm type: %s" % lstm_type)

                hidden = next_state.h
                last_states[i] = next_state
            # decoder
            if dropout > 0.:
                hidden = mx.sym.Dropout(data=hidden, p=dropout)

            if direction == "forward":
                hidden_all.append(hidden)
            elif direction == "backward":
                hidden_all.insert(0, hidden)
            else:
                raise Exception("direction should be either forward or backward")
        net = hidden_all

    return net


def bi_lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0., num_hidden_proj=0,
                   lstm_type='fc_lstm', is_batchnorm=False, is_bucketing=False):
    if num_lstm_layer > 0:
        net_forward = lstm_unroll(net=net,
                                  num_lstm_layer=num_lstm_layer,
                                  seq_len=seq_len,
                                  num_hidden_lstm_list=num_hidden_lstm_list,
                                  dropout=dropout,
                                  num_hidden_proj=num_hidden_proj,
                                  lstm_type=lstm_type,
                                  is_batchnorm=is_batchnorm,
                                  prefix="forward_",
                                  direction="forward",
                                  is_bucketing=is_bucketing)

        net_backward = lstm_unroll(net=net,
                                   num_lstm_layer=num_lstm_layer,
                                   seq_len=seq_len,
                                   num_hidden_lstm_list=num_hidden_lstm_list,
                                   dropout=dropout,
                                   num_hidden_proj=num_hidden_proj,
                                   lstm_type=lstm_type,
                                   is_batchnorm=is_batchnorm,
                                   prefix="backward_",
                                   direction="backward",
                                   is_bucketing=is_bucketing)
        hidden_all = []
        for i in range(seq_len):
            hidden_all.append(mx.sym.Concat(*[net_forward[i], net_backward[i]], dim=1))
        net = hidden_all
    return net


# bidirectional LSTM with two inputs and two outputs (forward and backward kept separate)
def bi_lstm_unroll_two_input_two_output(net1, net2, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
                                        num_hidden_proj=0,
                                        lstm_type='fc_lstm',
                                        is_batchnorm=False,
                                        is_bucketing=False):
    if num_lstm_layer > 0:
        net_forward = lstm_unroll(net=net1,
                                  num_lstm_layer=num_lstm_layer,
                                  seq_len=seq_len,
                                  num_hidden_lstm_list=num_hidden_lstm_list,
                                  dropout=dropout,
                                  num_hidden_proj=num_hidden_proj,
                                  lstm_type=lstm_type,
                                  is_batchnorm=is_batchnorm,
                                  prefix="forward_",
                                  direction="forward",
                                  is_bucketing=is_bucketing)

        net_backward = lstm_unroll(net=net2,
                                   num_lstm_layer=num_lstm_layer,
                                   seq_len=seq_len,
                                   num_hidden_lstm_list=num_hidden_lstm_list,
                                   dropout=dropout,
                                   num_hidden_proj=num_hidden_proj,
                                   lstm_type=lstm_type,
                                   is_batchnorm=is_batchnorm,
                                   prefix="backward_",
                                   direction="backward",
                                   is_bucketing=is_bucketing)
        return net_forward, net_backward
    else:
        return net1, net2
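

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original training pipeline. It only
    # shows how the unroll helpers above are typically wired together; the
    # sequence length, layer sizes, dropout rate, and the 'data' variable are
    # illustrative assumptions, not values taken from the original example.
    seq_len = 100
    data = mx.sym.Variable('data')  # assumed layout: (batch, seq_len, feat_dim)
    # lstm_unroll/bi_lstm_unroll expect `net` to be a list with one symbol per timestep.
    slices = mx.sym.SliceChannel(data=data, num_outputs=seq_len, axis=1, squeeze_axis=True)
    net = [slices[t] for t in range(seq_len)]
    # Bidirectional stack: forward and backward hidden states are concatenated per timestep.
    hidden_all = bi_lstm_unroll(net=net,
                                num_lstm_layer=2,
                                seq_len=seq_len,
                                num_hidden_lstm_list=[256, 256],
                                dropout=0.3,
                                lstm_type='fc_lstm')
    print("timesteps:", len(hidden_all))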