# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=missing-docstring
"""Unrolled multi-layer LSTM whose per-timestep outputs are scored by ``nce_loss``."""
from __future__ import print_function

from collections import namedtuple

import mxnet as mx
from nce import nce_loss

# Cell state ``c`` and hidden output ``h`` of one LSTM layer at one timestep.
LSTMState = namedtuple("LSTMState", ["c", "h"])
# Weight/bias variables of one LSTM layer, shared across all timesteps:
# input-to-hidden (i2h) and hidden-to-hidden (h2h) projections.
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])


def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
    """Build the symbol graph for a single LSTM cell step.

    Args:
        num_hidden: Number of hidden units. Each projection emits
            ``4 * num_hidden`` values, which are sliced into the four gates.
        indata: Input symbol for this step.
        prev_state: ``LSTMState`` of this layer at the previous timestep.
        param: ``LSTMParam`` holding this layer's shared weights and biases.
        seqidx: Timestep index; used only to name the created symbols.
        layeridx: Layer index; used only to name the created symbols.
        dropout: Dropout probability applied to ``indata``. No dropout
            symbol is added when it is 0.

    Returns:
        ``LSTMState`` holding the new cell state and hidden output.
    """
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    # Slice order within the fused projection: input gate, candidate
    # (tanh) transform, forget gate, output gate.
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)


def get_lstm_net(vocab_size, seq_len, num_lstm_layer, num_hidden,
                 embed_size=100, dropout=0.):
    """Build an unrolled LSTM network trained with ``nce_loss`` at every step.

    The graph reads these input variables:

    - ``data``: token ids.
    - ``label`` and ``label_weight``: targets and their weights, each split
      into ``seq_len`` pieces.
    - ``l{i}_init_c`` / ``l{i}_init_h``: initial LSTM states, which the
      caller must feed.

    The split uses ``SliceChannel``'s default axis (presumably axis 1,
    i.e. time; confirm against the data iterator).

    Args:
        vocab_size: Vocabulary size used by the input embedding and by
            ``nce_loss``.
        seq_len: Number of unrolled timesteps.
        num_lstm_layer: Number of stacked LSTM layers.
        num_hidden: Hidden units per LSTM layer.
        embed_size: Width of the input embedding (default 100, the value
            previously hard-coded).
        dropout: Dropout probability applied to each LSTM layer's input
            (default 0, i.e. disabled as before).

    Returns:
        ``mx.sym.Group`` of the ``seq_len`` per-timestep ``nce_loss`` outputs.
    """
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        last_states.append(LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                                     h=mx.sym.Variable("l%d_init_h" % i)))

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    label_embed_weight = mx.sym.Variable('label_embed_weight')
    data_embed = mx.sym.Embedding(data=data, input_dim=vocab_size,
                                  weight=embed_weight,
                                  output_dim=embed_size, name='data_embed')
    datavec = mx.sym.SliceChannel(data=data_embed,
                                  num_outputs=seq_len,
                                  squeeze_axis=True, name='data_slice')
    labelvec = mx.sym.SliceChannel(data=label,
                                   num_outputs=seq_len,
                                   squeeze_axis=True, name='label_slice')
    labelweightvec = mx.sym.SliceChannel(data=label_weight,
                                         num_outputs=seq_len,
                                         squeeze_axis=True, name='label_weight_slice')

    # Width of the vector handed to nce_loss at each step. It is the top
    # LSTM layer's hidden size, or the raw embedding width when there are no
    # LSTM layers. Previously this was hard-coded to 100, which only matched
    # when num_hidden == 100. NOTE(review): assumes nce_loss's ``num_hidden``
    # must equal the width of ``data``; confirm against nce.py.
    out_dim = num_hidden if num_lstm_layer > 0 else embed_size

    probs = []
    for seqidx in range(seq_len):
        hidden = datavec[seqidx]
        for i in range(num_lstm_layer):
            next_state = _lstm(num_hidden, indata=hidden,
                               prev_state=last_states[i],
                               param=param_cells[i],
                               seqidx=seqidx, layeridx=i,
                               dropout=dropout)
            hidden = next_state.h
            # Carry this layer's state forward to the next timestep.
            last_states[i] = next_state

        probs.append(nce_loss(data=hidden,
                              label=labelvec[seqidx],
                              label_weight=labelweightvec[seqidx],
                              embed_weight=label_embed_weight,
                              vocab_size=vocab_size,
                              num_hidden=out_dim))
    return mx.sym.Group(probs)