# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from collections import namedtuple

import mxnet as mx

from stt_layer_batchnorm import batchnorm

GRUState = namedtuple("GRUState", ["h"])
GRUParam = namedtuple("GRUParam", ["gates_i2h_weight", "gates_i2h_bias",
                                   "gates_h2h_weight", "gates_h2h_bias",
                                   "trans_i2h_weight", "trans_i2h_bias",
                                   "trans_h2h_weight", "trans_h2h_bias"])
GRUModel = namedtuple("GRUModel", ["rnn_exec", "symbol",
                                   "init_states", "last_states",
                                   "seq_data", "seq_labels", "seq_outputs",
                                   "param_blocks"])


def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.,
        is_batchnorm=False, gamma=None, beta=None, name=None):
    """
    GRU cell symbol.

    Reference:
    * Chung, Junyoung, et al. "Empirical evaluation of gated recurrent neural
      networks on sequence modeling." arXiv preprint arXiv:1412.3555 (2014).
    """
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    # Input-to-hidden projection for both gates (update and reset), computed in
    # one FullyConnected of width 2 * num_hidden and sliced apart below.
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.gates_i2h_weight,
                                bias=param.gates_i2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_i2h" % (seqidx, layeridx))

    if is_batchnorm:
        if name is not None:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
        else:
            i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.gates_h2h_weight,
                                bias=param.gates_h2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=2,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    update_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    reset_gate = mx.sym.Activation(slice_gates[1], act_type="sigmoid")
    # Candidate hidden state: the reset gate scales the previous hidden state
    # before it enters the tanh transform.
    htrans_i2h = mx.sym.FullyConnected(data=indata,
                                       weight=param.trans_i2h_weight,
                                       bias=param.trans_i2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_i2h" % (seqidx, layeridx))
    h_after_reset = prev_state.h * reset_gate
    htrans_h2h = mx.sym.FullyConnected(data=h_after_reset,
                                       weight=param.trans_h2h_weight,
                                       bias=param.trans_h2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_h2h" % (seqidx, layeridx))
    h_trans = htrans_i2h + htrans_h2h
    h_trans_active = mx.sym.Activation(h_trans, act_type="tanh")
    # Interpolate between the previous state and the candidate: the update gate
    # decides how much of the candidate replaces the old hidden state.
    next_h = prev_state.h + update_gate * (h_trans_active - prev_state.h)
    return GRUState(h=next_h)
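
# A single-timestep use of the cell above, for orientation (a hedged sketch:
# the variable names and the hidden size are placeholders, not values used
# elsewhere in this repository):
#
#   param = GRUParam(*[mx.sym.Variable("p%d" % j) for j in range(8)])
#   state = GRUState(h=mx.sym.Variable("init_h"))
#   x_t = mx.sym.Variable("x_t")
#   next_state = gru(num_hidden=64, indata=x_t, prev_state=state,
#                    param=param, seqidx=0, layeridx=0)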

def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False, prefix="",
               direction="forward", is_bucketing=False):
    if num_gru_layer > 0:
        param_cells = []
        last_states = []
        for i in range(num_gru_layer):
            param_cells.append(GRUParam(gates_i2h_weight=mx.sym.Variable(prefix + "l%d_i2h_gates_weight" % i),
                                        gates_i2h_bias=mx.sym.Variable(prefix + "l%d_i2h_gates_bias" % i),
                                        gates_h2h_weight=mx.sym.Variable(prefix + "l%d_h2h_gates_weight" % i),
                                        gates_h2h_bias=mx.sym.Variable(prefix + "l%d_h2h_gates_bias" % i),
                                        trans_i2h_weight=mx.sym.Variable(prefix + "l%d_i2h_trans_weight" % i),
                                        trans_i2h_bias=mx.sym.Variable(prefix + "l%d_i2h_trans_bias" % i),
                                        trans_h2h_weight=mx.sym.Variable(prefix + "l%d_h2h_trans_weight" % i),
                                        trans_h2h_bias=mx.sym.Variable(prefix + "l%d_h2h_trans_bias" % i)))
            state = GRUState(h=mx.sym.Variable(prefix + "l%d_init_h" % i))
            last_states.append(state)
        assert len(last_states) == num_gru_layer
        # Declare batchnorm parameters (gamma, beta): shared per layer when
        # bucketing, otherwise a separate pair per timestep.
        if is_batchnorm:
            batchnorm_gamma = []
            batchnorm_beta = []
            if is_bucketing:
                for l in range(num_gru_layer):
                    batchnorm_gamma.append(mx.sym.Variable(prefix + "l%d_i2h_gamma" % l))
                    batchnorm_beta.append(mx.sym.Variable(prefix + "l%d_i2h_beta" % l))
            else:
                for seqidx in range(seq_len):
                    batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx))
                    batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx))

        hidden_all = []
        for seqidx in range(seq_len):
            if direction == "forward":
                k = seqidx
                hidden = net[k]
            elif direction == "backward":
                k = seq_len - seqidx - 1
                hidden = net[k]
            else:
                raise ValueError("direction should be either forward or backward")

            # stack GRU layers
            for i in range(num_gru_layer):
                if i == 0:
                    dp_ratio = 0.
                else:
                    dp_ratio = dropout
                if is_batchnorm:
                    if is_bucketing:
                        next_state = gru(num_hidden_gru_list[i], indata=hidden,
                                         prev_state=last_states[i],
                                         param=param_cells[i],
                                         seqidx=k, layeridx=i, dropout=dp_ratio,
                                         is_batchnorm=is_batchnorm,
                                         gamma=batchnorm_gamma[i],
                                         beta=batchnorm_beta[i],
                                         name=prefix + ("t%d_l%d" % (seqidx, i)))
                    else:
                        next_state = gru(num_hidden_gru_list[i], indata=hidden,
                                         prev_state=last_states[i],
                                         param=param_cells[i],
                                         seqidx=k, layeridx=i, dropout=dp_ratio,
                                         is_batchnorm=is_batchnorm,
                                         gamma=batchnorm_gamma[k],
                                         beta=batchnorm_beta[k],
                                         name=prefix + ("t%d_l%d" % (seqidx, i)))
                else:
                    next_state = gru(num_hidden_gru_list[i], indata=hidden,
                                     prev_state=last_states[i],
                                     param=param_cells[i],
                                     seqidx=k, layeridx=i, dropout=dp_ratio,
                                     is_batchnorm=is_batchnorm,
                                     name=prefix)
                hidden = next_state.h
                last_states[i] = next_state
            # apply dropout to the top-layer output before the decoder
            if dropout > 0.:
                hidden = mx.sym.Dropout(data=hidden, p=dropout)

            if direction == "forward":
                hidden_all.append(hidden)
            elif direction == "backward":
                hidden_all.insert(0, hidden)
            else:
                raise ValueError("direction should be either forward or backward")
        net = hidden_all

    return net
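
# Illustrative call to gru_unroll (a hedged sketch; "data" and every size
# below are placeholder assumptions, not values used elsewhere in this
# repository):
#
#   seq_len = 5
#   data = mx.sym.Variable("data")  # assumed shape: (batch, seq_len, feat_dim)
#   slices = mx.sym.SliceChannel(data=data, num_outputs=seq_len,
#                                axis=1, squeeze_axis=True)
#   hidden = gru_unroll(net=slices, num_gru_layer=2, seq_len=seq_len,
#                       num_hidden_gru_list=[64, 64], dropout=0.2)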

def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False,
                  is_bucketing=False):
    if num_gru_layer > 0:
        net_forward = gru_unroll(net=net,
                                 num_gru_layer=num_gru_layer,
                                 seq_len=seq_len,
                                 num_hidden_gru_list=num_hidden_gru_list,
                                 dropout=dropout,
                                 is_batchnorm=is_batchnorm,
                                 prefix="forward_",
                                 direction="forward",
                                 is_bucketing=is_bucketing)
        net_backward = gru_unroll(net=net,
                                  num_gru_layer=num_gru_layer,
                                  seq_len=seq_len,
                                  num_hidden_gru_list=num_hidden_gru_list,
                                  dropout=dropout,
                                  is_batchnorm=is_batchnorm,
                                  prefix="backward_",
                                  direction="backward",
                                  is_bucketing=is_bucketing)
        hidden_all = []
        for i in range(seq_len):
            hidden_all.append(mx.sym.Concat(*[net_forward[i], net_backward[i]], dim=1))
        net = hidden_all
    return net


def bi_gru_unroll_two_input_two_output(net1, net2, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0.,
                                       is_batchnorm=False, is_bucketing=False):
    if num_gru_layer > 0:
        net_forward = gru_unroll(net=net1,
                                 num_gru_layer=num_gru_layer,
                                 seq_len=seq_len,
                                 num_hidden_gru_list=num_hidden_gru_list,
                                 dropout=dropout,
                                 is_batchnorm=is_batchnorm,
                                 prefix="forward_",
                                 direction="forward",
                                 is_bucketing=is_bucketing)
        net_backward = gru_unroll(net=net2,
                                  num_gru_layer=num_gru_layer,
                                  seq_len=seq_len,
                                  num_hidden_gru_list=num_hidden_gru_list,
                                  dropout=dropout,
                                  is_batchnorm=is_batchnorm,
                                  prefix="backward_",
                                  direction="backward",
                                  is_bucketing=is_bucketing)
        return net_forward, net_backward
    else:
        return net1, net2
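
# Minimal, self-contained sketch for eyeballing the unrolled bidirectional
# graph. The input name "data" and all sizes below are illustrative
# assumptions, not values used elsewhere in this repository.
if __name__ == "__main__":
    demo_seq_len = 4
    demo_data = mx.sym.Variable("data")  # assumed shape: (batch, seq_len, feat_dim)
    # gru_unroll expects per-timestep symbols, so slice the sequence axis apart.
    demo_slices = mx.sym.SliceChannel(data=demo_data, num_outputs=demo_seq_len,
                                      axis=1, squeeze_axis=True)
    # Each timestep output is the concat of the forward and backward hidden
    # states, so its width is 2 * num_hidden.
    demo_outputs = bi_gru_unroll(net=demo_slices,
                                 num_gru_layer=1,
                                 seq_len=demo_seq_len,
                                 num_hidden_gru_list=[32],
                                 dropout=0.)
    print(mx.sym.Group(demo_outputs).list_arguments())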