# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint:skip-file
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math

LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
                                     "init_states", "last_states",
                                     "seq_data", "seq_labels", "seq_outputs",
                                     "param_blocks"])


def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
    """LSTM cell symbol"""
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)
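

# A minimal sketch of wiring a single lstm() step and sanity-checking its
# output shape with infer_shape.  The sizes used here (batch_size=8,
# num_embed=32, num_hidden=16) and the helper name are arbitrary
# assumptions chosen only for illustration.
def _demo_single_lstm_step():
    param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                      i2h_bias=mx.sym.Variable("l0_i2h_bias"),
                      h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                      h2h_bias=mx.sym.Variable("l0_h2h_bias"))
    state = LSTMState(c=mx.sym.Variable("l0_init_c"),
                      h=mx.sym.Variable("l0_init_h"))
    x = mx.sym.Variable("x")
    next_state = lstm(16, indata=x, prev_state=state, param=param,
                      seqidx=0, layeridx=0)
    # Both c and h of the next state should be (batch_size, num_hidden).
    _, out_shapes, _ = next_state.h.infer_shape(x=(8, 32),
                                                l0_init_c=(8, 16),
                                                l0_init_h=(8, 16))
    print(out_shapes)  # expected: [(8, 16)]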

# We define a new unrolling function here because the original one in
# lstm.py concats all the labels at the last layer together, making the
# mini-batch size of the label different from that of the data.  I think
# the existing data-parallelization code needs some modification to allow
# this situation to work properly.
def lstm_unroll(num_lstm_layer, seq_len, input_size,
                num_hidden, num_embed, num_label, dropout=0.):

    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer

    # embedding layer
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]

        # stack LSTM
        for i in range(num_lstm_layer):
            if i == 0:
                dp_ratio = 0.
            else:
                dp_ratio = dropout
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i, dropout=dp_ratio)
            hidden = next_state.h
            last_states[i] = next_state
        # decoder
        if dropout > 0.:
            hidden = mx.sym.Dropout(data=hidden, p=dropout)
        hidden_all.append(hidden)

    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                                 weight=cls_weight, bias=cls_bias, name='pred')

    ################################################################################
    # Make the label the same shape as our produced data path.
    # I did not observe a big speed difference between the following two ways.

    label = mx.sym.transpose(data=label)
    label = mx.sym.Reshape(data=label, target_shape=(0,))

    # label_slice = mx.sym.SliceChannel(data=label, num_outputs=seq_len)
    # label = [label_slice[t] for t in range(seq_len)]
    # label = mx.sym.Concat(*label, dim=0)
    # label = mx.sym.Reshape(data=label, target_shape=(0,))
    ################################################################################

    sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

    return sm


def lstm_inference_symbol(num_lstm_layer, input_size,
                          num_hidden, num_embed, num_label, dropout=0.):
    seqidx = 0
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer
    data = mx.sym.Variable("data")

    hidden = mx.sym.Embedding(data=data,
                              input_dim=input_size,
                              output_dim=num_embed,
                              weight=embed_weight,
                              name="embed")
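    # At inference time the network is unrolled for a single timestep only:
    # each call pushes one token through the stacked layers, and the updated
    # (c, h) states are returned as extra outputs so the caller can feed
    # them back in as l%d_init_c / l%d_init_h on the next step.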
    # stack LSTM
    for i in range(num_lstm_layer):
        if i == 0:
            dp = 0.
        else:
            dp = dropout
        next_state = lstm(num_hidden, indata=hidden,
                          prev_state=last_states[i],
                          param=param_cells[i],
                          seqidx=seqidx, layeridx=i, dropout=dp)
        hidden = next_state.h
        last_states[i] = next_state
    # decoder
    if dropout > 0.:
        hidden = mx.sym.Dropout(data=hidden, p=dropout)
    fc = mx.sym.FullyConnected(data=hidden, num_hidden=num_label,
                               weight=cls_weight, bias=cls_bias, name='pred')
    sm = mx.sym.SoftmaxOutput(data=fc, name='softmax')
    output = [sm]
    for state in last_states:
        output.append(state.c)
        output.append(state.h)
    return mx.sym.Group(output)
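

# A usage sketch: build the unrolled training symbol and sanity-check its
# shapes.  Every size below (batch_size, seq_len, vocabulary size, layer
# widths) is an arbitrary assumption chosen only for illustration.
if __name__ == "__main__":
    batch_size, seq_len = 32, 10
    num_layers, hidden, embed_dim, vocab = 2, 200, 128, 1000
    sym = lstm_unroll(num_lstm_layer=num_layers, seq_len=seq_len,
                      input_size=vocab, num_hidden=hidden,
                      num_embed=embed_dim, num_label=vocab, dropout=0.5)
    shapes = {"data": (batch_size, seq_len),
              "softmax_label": (batch_size, seq_len)}
    for i in range(num_layers):
        shapes["l%d_init_c" % i] = (batch_size, hidden)
        shapes["l%d_init_h" % i] = (batch_size, hidden)
    arg_shapes, out_shapes, _ = sym.infer_shape(**shapes)
    # The concatenated hidden states cover seq_len * batch_size rows, so the
    # softmax output should come out as (seq_len * batch_size, vocab).
    print(out_shapes)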