# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# An explicitly unrolled LSTM with fixed sequence length.
using MXNet

#--LSTMState
# Recurrent state carried between time steps: cell memory `c` and hidden
# output `h`.  Both are symbolic graph nodes, not concrete arrays.
struct LSTMState
    c :: mx.SymbolicNode
    h :: mx.SymbolicNode
end
#--/LSTMState

#--LSTMParam
# Learnable parameters of one LSTM layer, shared across all unrolled time
# steps: input-to-hidden and hidden-to-hidden weights and biases.
struct LSTMParam
    i2h_W :: mx.SymbolicNode
    h2h_W :: mx.SymbolicNode
    i2h_b :: mx.SymbolicNode
    h2h_b :: mx.SymbolicNode
end
#--/LSTMParam

#--lstm_cell
# One LSTM step: combine the input `data` with `prev_state` through the
# shared parameters in `param` and return the next `LSTMState`.
#
# Keywords:
# - `num_hidden`: size of the hidden state.
# - `dropout`: if > 0, dropout probability applied to `data` before the
#   input projection.
# - `name`: symbol prefix for the nodes created here.
function lstm_cell(data::mx.SymbolicNode, prev_state::LSTMState, param::LSTMParam;
                   num_hidden::Int=512, dropout::Real=0, name::Symbol=gensym())

    if dropout > 0
        data = mx.Dropout(data, p=dropout)
    end

    # A single fused projection computes the pre-activations of all four
    # gates at once (hence 4*num_hidden); SliceChannel splits them back out.
    i2h = mx.FullyConnected(data, weight=param.i2h_W, bias=param.i2h_b,
                            num_hidden=4num_hidden, name=Symbol(name, "_i2h"))
    h2h = mx.FullyConnected(prev_state.h, weight=param.h2h_W, bias=param.h2h_b,
                            num_hidden=4num_hidden, name=Symbol(name, "_h2h"))

    gates = mx.SliceChannel(i2h + h2h, num_outputs=4, name=Symbol(name, "_gates"))

    in_gate     = mx.Activation(gates[1], act_type=:sigmoid)
    in_trans    = mx.Activation(gates[2], act_type=:tanh)
    forget_gate = mx.Activation(gates[3], act_type=:sigmoid)
    out_gate    = mx.Activation(gates[4], act_type=:sigmoid)

    # Standard LSTM update: forget part of the old memory, add the gated
    # candidate, then expose a squashed, gated view of the new memory.
    next_c = (forget_gate .* prev_state.c) + (in_gate .* in_trans)
    next_h = out_gate .* mx.Activation(next_c, act_type=:tanh)

    return LSTMState(next_c, next_h)
end
#--/lstm_cell

#--LSTM-part1
# Build an `n_layer`-deep LSTM explicitly unrolled over `seq_len` time steps.
# Each step reads a `_data_$t` variable, embeds it to `dim_embed`, runs it
# through the stacked cells (hidden size `dim_hidden`), and produces a
# softmax over `n_class` classes against a `_label_$t` variable.
#
# Keywords:
# - `dropout`: applied between layers and before the decoder (not on data).
# - `name`: prefix for every variable/node name in the graph.
# - `output_states`: if true, the final cell/hidden states of every layer
#   are appended to the grouped outputs (useful for stateful inference).
function LSTM(n_layer::Int, seq_len::Int, dim_hidden::Int, dim_embed::Int, n_class::Int;
              dropout::Real=0, name::Symbol=gensym(), output_states::Bool=false)

    # placeholder nodes for all parameters
    embed_W = mx.Variable(Symbol(name, "_embed_weight"))
    pred_W  = mx.Variable(Symbol(name, "_pred_weight"))
    pred_b  = mx.Variable(Symbol(name, "_pred_bias"))

    # One (parameters, initial-state) pair per layer; the parameters are
    # reused at every time step, while the state entry is threaded through
    # the unrolled steps below.
    layer_param_states = map(1:n_layer) do i
        param = LSTMParam(mx.Variable(Symbol(name, "_l$(i)_i2h_weight")),
                          mx.Variable(Symbol(name, "_l$(i)_h2h_weight")),
                          mx.Variable(Symbol(name, "_l$(i)_i2h_bias")),
                          mx.Variable(Symbol(name, "_l$(i)_h2h_bias")))
        state = LSTMState(mx.Variable(Symbol(name, "_l$(i)_init_c")),
                          mx.Variable(Symbol(name, "_l$(i)_init_h")))
        (param, state)
    end
    #...
    #--/LSTM-part1

    #--LSTM-part2
    # now unroll over time
    outputs = mx.SymbolicNode[]
    for t = 1:seq_len
        data  = mx.Variable(Symbol(name, "_data_$t"))
        label = mx.Variable(Symbol(name, "_label_$t"))
        hidden = mx.FullyConnected(data, weight=embed_W, num_hidden=dim_embed,
                                   no_bias=true, name=Symbol(name, "_embed_$t"))

        # stack LSTM cells
        for i = 1:n_layer
            l_param, l_state = layer_param_states[i]
            dp = i == 1 ? 0 : dropout # don't do dropout for data
            next_state = lstm_cell(hidden, l_state, l_param, num_hidden=dim_hidden, dropout=dp,
                                   name=Symbol(name, "_lstm_$t"))
            hidden = next_state.h
            layer_param_states[i] = (l_param, next_state)
        end

        # prediction / decoder
        if dropout > 0
            hidden = mx.Dropout(hidden, p=dropout)
        end
        pred = mx.FullyConnected(hidden, weight=pred_W, bias=pred_b, num_hidden=n_class,
                                 name=Symbol(name, "_pred_$t"))
        smax = mx.SoftmaxOutput(pred, label, name=Symbol(name, "_softmax_$t"))
        push!(outputs, smax)
    end
    #...
    #--/LSTM-part2

    #--LSTM-part3
    # Append block-gradient nodes to the final states so that gradients do
    # not flow back through the exposed last-state outputs.
    for i = 1:n_layer
        l_param, l_state = layer_param_states[i]
        final_state = LSTMState(mx.BlockGrad(l_state.c, name=Symbol(name, "_l$(i)_last_c")),
                                mx.BlockGrad(l_state.h, name=Symbol(name, "_l$(i)_last_h")))
        layer_param_states[i] = (l_param, final_state)
    end

    # now group all outputs together
    if output_states
        outputs = outputs ∪ [x[2].c for x in layer_param_states] ∪
                  [x[2].h for x in layer_param_states]
    end
    return mx.Group(outputs...)
end
#--/LSTM-part3


# Negative Log-likelihood evaluation metric.  Accumulates the summed NLL
# and the number of samples seen so far.
mutable struct NLL <: mx.AbstractEvalMetric
    nll_sum  :: Float64
    n_sample :: Int

    NLL() = new(0.0, 0)
end

# Accumulate the NLL of one mini-batch.  `labels` holds integer class ids
# (0-based, hence the `+ 1` below); `preds` holds per-class probabilities
# with one column per sample.
function mx.update!(metric::NLL, labels::Vector{<:mx.NDArray}, preds::Vector{<:mx.NDArray})
    @assert length(labels) == length(preds)
    nll = 0.0
    for (label, pred) in zip(labels, preds)
        @mx.nd_as_jl ro=(label, pred) begin
            # Pick, for each sample j, the predicted probability of its true
            # class, clamp at 1e-20 to avoid log(0), and sum the negative logs.
            nll -= sum(
                log.(
                    max.(
                        getindex.(
                            (pred,),
                            round.(Int, label .+ 1),
                            1:length(label)),
                        1e-20)
                )
            )
        end
    end

    # Average over the output heads (one per unrolled time step).
    nll = nll / length(labels)
    metric.nll_sum  += nll
    metric.n_sample += length(labels[1])
end

# Report the mean NLL and the corresponding perplexity exp(NLL).
function mx.get(metric :: NLL)
    nll  = metric.nll_sum / metric.n_sample
    perp = exp(nll)
    return [(:NLL, nll), (:perplexity, perp)]
end

# Reset the accumulators between evaluation epochs.
function mx.reset!(metric :: NLL)
    metric.nll_sum  = 0.0
    metric.n_sample = 0
end