# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
# Make the in-tree mxnet package importable when run from the example dir.
sys.path.insert(0, "../../python")
import mxnet as mx
from collections import namedtuple

GRUState = namedtuple("GRUState", ["h"])
GRUParam = namedtuple("GRUParam", ["gates_i2h_weight", "gates_i2h_bias",
                                   "gates_h2h_weight", "gates_h2h_bias",
                                   "trans_i2h_weight", "trans_i2h_bias",
                                   "trans_h2h_weight", "trans_h2h_bias"])
GRUModel = namedtuple("GRUModel", ["rnn_exec", "symbol",
                                   "init_states", "last_states",
                                   "seq_data", "seq_labels", "seq_outputs",
                                   "param_blocks"])

def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
    """
    GRU cell symbol. Computes

        z = sigmoid(W_z x + U_z h_prev)           # update gate
        r = sigmoid(W_r x + U_r h_prev)           # reset gate
        h_cand = tanh(W_h x + U_h (r * h_prev))   # candidate state
        h = h_prev + z * (h_cand - h_prev)        # next hidden state

    Reference:
    * Chung, Junyoung, et al. "Empirical evaluation of gated recurrent neural
      networks on sequence modeling." arXiv preprint arXiv:1412.3555 (2014).
    """
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    # Compute both gates with a single fused FullyConnected of width
    # 2 * num_hidden, then split the result into update and reset gates.
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.gates_i2h_weight,
                                bias=param.gates_i2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.gates_h2h_weight,
                                bias=param.gates_h2h_bias,
                                num_hidden=num_hidden * 2,
                                name="t%d_l%d_gates_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=2,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    update_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    reset_gate = mx.sym.Activation(slice_gates[1], act_type="sigmoid")
    # Candidate ("transform") state: the reset gate scales the previous
    # hidden state before it enters the h2h projection.
    htrans_i2h = mx.sym.FullyConnected(data=indata,
                                       weight=param.trans_i2h_weight,
                                       bias=param.trans_i2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_i2h" % (seqidx, layeridx))
    h_after_reset = prev_state.h * reset_gate
    htrans_h2h = mx.sym.FullyConnected(data=h_after_reset,
                                       weight=param.trans_h2h_weight,
                                       bias=param.trans_h2h_bias,
                                       num_hidden=num_hidden,
                                       name="t%d_l%d_trans_h2h" % (seqidx, layeridx))
    h_trans = htrans_i2h + htrans_h2h
    h_trans_active = mx.sym.Activation(h_trans, act_type="tanh")
    # Equivalent to (1 - z) * h_prev + z * h_cand.
    next_h = prev_state.h + update_gate * (h_trans_active - prev_state.h)
    return GRUState(h=next_h)

def gru_unroll(num_gru_layer, seq_len, input_size,
               num_hidden, num_embed, num_label, dropout=0.):
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    param_cells = []
    last_states = []
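    # Per-layer parameter symbols are created once here and reused at every
    # time step below, so weights are shared across the unrolled sequence.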
    for i in range(num_gru_layer):
        param_cells.append(GRUParam(gates_i2h_weight=mx.sym.Variable("l%d_i2h_gates_weight" % i),
                                    gates_i2h_bias=mx.sym.Variable("l%d_i2h_gates_bias" % i),
                                    gates_h2h_weight=mx.sym.Variable("l%d_h2h_gates_weight" % i),
                                    gates_h2h_bias=mx.sym.Variable("l%d_h2h_gates_bias" % i),
                                    trans_i2h_weight=mx.sym.Variable("l%d_i2h_trans_weight" % i),
                                    trans_i2h_bias=mx.sym.Variable("l%d_i2h_trans_bias" % i),
                                    trans_h2h_weight=mx.sym.Variable("l%d_h2h_trans_weight" % i),
                                    trans_h2h_bias=mx.sym.Variable("l%d_h2h_trans_bias" % i)))
        state = GRUState(h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_gru_layer

    # embedding layer
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]

        # stack GRU layers; dropout is applied between layers, but not on
        # the raw embedding input to the first layer
        for i in range(num_gru_layer):
            if i == 0:
                dp_ratio = 0.
            else:
                dp_ratio = dropout
            next_state = gru(num_hidden, indata=hidden,
                             prev_state=last_states[i],
                             param=param_cells[i],
                             seqidx=seqidx, layeridx=i, dropout=dp_ratio)
            hidden = next_state.h
            last_states[i] = next_state
        # decoder
        if dropout > 0.:
            hidden = mx.sym.Dropout(data=hidden, p=dropout)
        hidden_all.append(hidden)

    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                                 weight=cls_weight, bias=cls_bias, name='pred')
    label = mx.sym.transpose(data=label)
    label = mx.sym.Reshape(data=label, target_shape=(0,))
    return mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
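
# A minimal usage sketch (hypothetical hyperparameters, not part of the
# original example): unroll a 2-layer GRU language model over 32 time steps
# and inspect the learnable arguments of the resulting symbol.
if __name__ == '__main__':
    sym = gru_unroll(num_gru_layer=2, seq_len=32, input_size=10000,
                     num_hidden=200, num_embed=200, num_label=10000,
                     dropout=0.5)
    print(sym.list_arguments())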