# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# An explicitly unrolled LSTM with fixed sequence length.
using MXNet

#--LSTMState
# The recurrent state carried from one time step to the next: the cell state
# `c` and the hidden output `h`.
struct LSTMState
  c :: mx.SymbolicNode
  h :: mx.SymbolicNode
end
#--/LSTMState

#--LSTMParam
# The learnable parameters of one LSTM layer: input-to-hidden and
# hidden-to-hidden weights and biases.
struct LSTMParam
  i2h_W :: mx.SymbolicNode
  h2h_W :: mx.SymbolicNode
  i2h_b :: mx.SymbolicNode
  h2h_b :: mx.SymbolicNode
end
#--/LSTMParam

#--lstm_cell
function lstm_cell(data::mx.SymbolicNode, prev_state::LSTMState, param::LSTMParam;
                   num_hidden::Int=512, dropout::Real=0, name::Symbol=gensym())

  if dropout > 0
    data = mx.Dropout(data, p=dropout)
  end

  # a single fully-connected layer of width 4*num_hidden computes the
  # pre-activations of all four gates at once
  i2h = mx.FullyConnected(data, weight=param.i2h_W, bias=param.i2h_b,
                          num_hidden=4num_hidden, name=Symbol(name, "_i2h"))
  h2h = mx.FullyConnected(prev_state.h, weight=param.h2h_W, bias=param.h2h_b,
                          num_hidden=4num_hidden, name=Symbol(name, "_h2h"))

  # split the summed pre-activations into the four gates
  gates = mx.SliceChannel(i2h + h2h, num_outputs=4, name=Symbol(name, "_gates"))

  in_gate     = mx.Activation(gates[1], act_type=:sigmoid)
  in_trans    = mx.Activation(gates[2], act_type=:tanh)
  forget_gate = mx.Activation(gates[3], act_type=:sigmoid)
  out_gate    = mx.Activation(gates[4], act_type=:sigmoid)

  # standard LSTM update: new cell state, then gated hidden output
  next_c = (forget_gate .* prev_state.c) + (in_gate .* in_trans)
  next_h = out_gate .* mx.Activation(next_c, act_type=:tanh)

  return LSTMState(next_c, next_h)
end
#--/lstm_cell
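
# Illustrative sketch (not part of the original example; the names `:x`, `:demo`,
# etc. are arbitrary): wiring a single cell for one time step and listing the
# placeholders it expects.  Kept inside a block comment so that `include`-ing
# this file does not build the extra graph.
#=
x     = mx.Variable(:x)
param = LSTMParam(mx.Variable(:i2h_weight), mx.Variable(:h2h_weight),
                  mx.Variable(:i2h_bias),   mx.Variable(:h2h_bias))
init  = LSTMState(mx.Variable(:init_c), mx.Variable(:init_h))
state = lstm_cell(x, init, param, num_hidden=256, name=:demo)
mx.list_arguments(mx.Group(state.c, state.h))
=#
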
#--LSTM-part1
function LSTM(n_layer::Int, seq_len::Int, dim_hidden::Int, dim_embed::Int, n_class::Int;
              dropout::Real=0, name::Symbol=gensym(), output_states::Bool=false)

  # placeholder nodes for all parameters
  embed_W = mx.Variable(Symbol(name, "_embed_weight"))
  pred_W  = mx.Variable(Symbol(name, "_pred_weight"))
  pred_b  = mx.Variable(Symbol(name, "_pred_bias"))

  # per-layer parameters and initial states, shared across all time steps
  layer_param_states = map(1:n_layer) do i
    param = LSTMParam(mx.Variable(Symbol(name, "_l$(i)_i2h_weight")),
                      mx.Variable(Symbol(name, "_l$(i)_h2h_weight")),
                      mx.Variable(Symbol(name, "_l$(i)_i2h_bias")),
                      mx.Variable(Symbol(name, "_l$(i)_h2h_bias")))
    state = LSTMState(mx.Variable(Symbol(name, "_l$(i)_init_c")),
                      mx.Variable(Symbol(name, "_l$(i)_init_h")))
    (param, state)
  end
  #...
  #--/LSTM-part1

  #--LSTM-part2
  # now unroll over time
  outputs = mx.SymbolicNode[]
  for t = 1:seq_len
    data   = mx.Variable(Symbol(name, "_data_$t"))
    label  = mx.Variable(Symbol(name, "_label_$t"))
    hidden = mx.FullyConnected(data, weight=embed_W, num_hidden=dim_embed,
                               no_bias=true, name=Symbol(name, "_embed_$t"))

    # stack LSTM cells
    for i = 1:n_layer
      l_param, l_state = layer_param_states[i]
      dp = i == 1 ? 0 : dropout # don't do dropout for data
      next_state = lstm_cell(hidden, l_state, l_param, num_hidden=dim_hidden, dropout=dp,
                             name=Symbol(name, "_lstm_$t"))
      hidden = next_state.h
      layer_param_states[i] = (l_param, next_state)
    end

    # prediction / decoder
    if dropout > 0
      hidden = mx.Dropout(hidden, p=dropout)
    end
    pred = mx.FullyConnected(hidden, weight=pred_W, bias=pred_b, num_hidden=n_class,
                             name=Symbol(name, "_pred_$t"))
    smax = mx.SoftmaxOutput(pred, label, name=Symbol(name, "_softmax_$t"))
    push!(outputs, smax)
  end
  #...
  #--/LSTM-part2

  #--LSTM-part3
  # append block-gradient nodes to the final states
  for i = 1:n_layer
    l_param, l_state = layer_param_states[i]
    final_state = LSTMState(mx.BlockGrad(l_state.c, name=Symbol(name, "_l$(i)_last_c")),
                            mx.BlockGrad(l_state.h, name=Symbol(name, "_l$(i)_last_h")))
    layer_param_states[i] = (l_param, final_state)
  end

  # now group all outputs together
  if output_states
    outputs = outputs ∪ [x[2].c for x in layer_param_states] ∪
                        [x[2].h for x in layer_param_states]
  end
  return mx.Group(outputs...)
end
#--/LSTM-part3
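
# Illustrative sketch (not part of the original example; the hyper-parameters
# below are arbitrary): unrolling a 2-layer LSTM over 21 steps produces one
# softmax output per step, with data/label placeholders named
# `<name>_data_t` and `<name>_label_t` for t = 1:21.  Kept inside a block
# comment so that `include`-ing this file has no side effects.
#=
sym = LSTM(2, 21, 256, 256, 128, dropout=0.5, name=:char_lstm)
mx.list_arguments(sym)
=#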

# Negative Log-likelihood metric: accumulates the summed NLL and the number of
# samples seen, so that `mx.get` can report the average NLL and perplexity.
mutable struct NLL <: mx.AbstractEvalMetric
  nll_sum  :: Float64
  n_sample :: Int

  NLL() = new(0.0, 0)
end

function mx.update!(metric::NLL, labels::Vector{<:mx.NDArray}, preds::Vector{<:mx.NDArray})
  @assert length(labels) == length(preds)
  nll = 0.0
  for (label, pred) in zip(labels, preds)
    @mx.nd_as_jl ro=(label, pred) begin
      # For each sample j pick the predicted probability of its true class
      # (labels are 0-based, hence the `+ 1`), floor it at 1e-20 to avoid
      # log(0), and accumulate the negative log-likelihood.
      nll -= sum(
        log.(
          max.(
            getindex.(
              (pred,),
              round.(Int, label .+ 1),
              1:length(label)),
            1e-20)))
    end
  end

  nll = nll / length(labels)
  metric.nll_sum += nll
  metric.n_sample += length(labels[1])
end

function mx.get(metric :: NLL)
  nll  = metric.nll_sum / metric.n_sample
  perp = exp(nll)  # perplexity is the exponential of the average NLL
  return [(:NLL, nll), (:perplexity, perp)]
end

function mx.reset!(metric :: NLL)
  metric.nll_sum  = 0.0
  metric.n_sample = 0
end
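
# Illustrative sketch (not part of the original example; `sym`, `data_provider`
# and the optimizer settings are placeholders): the NLL metric plugs into
# training through the `eval_metric` keyword of `mx.fit`.
#=
model = mx.FeedForward(sym, context=mx.cpu())
mx.fit(model, mx.ADAM(), data_provider, eval_metric=NLL())
=#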