1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# pylint: disable=missing-docstring
19from __future__ import print_function
20
21from collections import namedtuple
22
23import mxnet as mx
24from nce import nce_loss
25
# State passed from one _lstm call to the next: `c` is the value combined with
# the forget gate, `h` is the value fed to the hidden-to-hidden projection.
LSTMState = namedtuple("LSTMState", ("c", "h"))

# Weight/bias symbols of one layer: the `i2h_*` pair is used by the projection
# of the cell input, the `h2h_*` pair by the projection of the previous `h`.
LSTMParam = namedtuple(
    "LSTMParam",
    ("i2h_weight", "i2h_bias", "h2h_weight", "h2h_bias"),
)
29
30
def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
    """Build the symbol graph of one LSTM cell for one time step and layer.

    Parameters
    ----------
    num_hidden : int
        Width of each gate; both projections produce ``num_hidden * 4`` units.
    indata : mx.sym.Symbol
        Input to the cell.
    prev_state : LSTMState
        Provides the previous ``c`` and ``h`` symbols.
    param : LSTMParam
        Provides the weight and bias symbols of both projections.
    seqidx : int
        Time-step index, used only to name the created symbols.
    layeridx : int
        Layer index, used only to name the created symbols.
    dropout : float, optional
        When greater than zero, ``indata`` goes through ``mx.sym.Dropout``
        with this probability before the input projection.

    Returns
    -------
    LSTMState
        The new ``c`` and ``h`` symbols.
    """
    name_args = (seqidx, layeridx)
    gate_width = num_hidden * 4

    cell_input = indata
    if dropout > 0.:
        cell_input = mx.sym.Dropout(data=cell_input, p=dropout)

    # Input-to-hidden and hidden-to-hidden projections, summed into one
    # pre-activation that holds all four gates.
    input_proj = mx.sym.FullyConnected(data=cell_input,
                                       weight=param.i2h_weight,
                                       bias=param.i2h_bias,
                                       num_hidden=gate_width,
                                       name="t%d_l%d_i2h" % name_args)
    hidden_proj = mx.sym.FullyConnected(data=prev_state.h,
                                        weight=param.h2h_weight,
                                        bias=param.h2h_bias,
                                        num_hidden=gate_width,
                                        name="t%d_l%d_h2h" % name_args)
    pre_activation = input_proj + hidden_proj
    pieces = mx.sym.SliceChannel(pre_activation, num_outputs=4,
                                 name="t%d_l%d_slice" % name_args)

    # Slice k gets the k-th activation below. The activations are created in
    # slice order because unnamed symbols are auto-numbered on creation.
    activations = ("sigmoid", "tanh", "sigmoid", "sigmoid")
    in_gate, in_transform, forget_gate, out_gate = [
        mx.sym.Activation(pieces[k], act_type=kind)
        for k, kind in enumerate(activations)
    ]

    retained = forget_gate * prev_state.c
    written = in_gate * in_transform
    next_c = retained + written
    squashed_c = mx.sym.Activation(next_c, act_type="tanh")
    next_h = out_gate * squashed_c
    return LSTMState(c=next_c, h=next_h)
55
56
def get_lstm_net(vocab_size, seq_len, num_lstm_layer, num_hidden,
                 num_embed=100, dropout=0.):
    """Build an unrolled, stacked LSTM symbol with one NCE loss per time step.

    Parameters
    ----------
    vocab_size : int
        Used as ``input_dim`` of the data embedding and forwarded to
        ``nce_loss`` as ``vocab_size``.
    seq_len : int
        Number of unrolled time steps. ``data``, ``label`` and
        ``label_weight`` are each split into this many slices.
    num_lstm_layer : int
        Number of stacked LSTM layers evaluated at every time step.
    num_hidden : int
        Hidden size of every LSTM layer.
    num_embed : int, optional
        Output dimension of the data embedding. Defaults to 100, the value
        that used to be hard-coded.
    dropout : float, optional
        Forwarded to every ``_lstm`` call. The default 0 disables dropout,
        matching the previous behaviour.

    Returns
    -------
    mx.sym.Symbol
        ``mx.sym.Group`` of the ``seq_len`` symbols returned by ``nce_loss``.

    Notes
    -----
    The created variables are named ``data``, ``label``, ``label_weight``,
    ``embed_weight``, ``label_embed_weight`` and, for each layer ``i``,
    ``l{i}_i2h_weight``, ``l{i}_i2h_bias``, ``l{i}_h2h_weight``,
    ``l{i}_h2h_bias``, ``l{i}_init_c`` and ``l{i}_init_h``.
    """
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    label_embed_weight = mx.sym.Variable('label_embed_weight')
    data_embed = mx.sym.Embedding(data=data, input_dim=vocab_size,
                                  weight=embed_weight,
                                  output_dim=num_embed, name='data_embed')
    datavec = mx.sym.SliceChannel(data=data_embed,
                                  num_outputs=seq_len,
                                  squeeze_axis=True, name='data_slice')
    labelvec = mx.sym.SliceChannel(data=label,
                                   num_outputs=seq_len,
                                   squeeze_axis=True, name='label_slice')
    labelweightvec = mx.sym.SliceChannel(data=label_weight,
                                         num_outputs=seq_len,
                                         squeeze_axis=True, name='label_weight_slice')

    # Width of the vector handed to nce_loss: the last layer's hidden size,
    # or the embedding size when there are no LSTM layers at all. This used
    # to be a literal 100, which only matched when num_hidden was 100.
    # NOTE(review): assumes nce_loss's `num_hidden` is the width of `data`;
    # nce_loss lives in the `nce` module -- confirm against its definition.
    output_width = num_hidden if num_lstm_layer > 0 else num_embed

    probs = []
    for seqidx in range(seq_len):
        hidden = datavec[seqidx]

        # Feed the step's input through the layer stack, updating the state
        # each layer carries over to the next time step.
        for i in range(num_lstm_layer):
            next_state = _lstm(num_hidden, indata=hidden,
                               prev_state=last_states[i],
                               param=param_cells[i],
                               seqidx=seqidx, layeridx=i,
                               dropout=dropout)
            hidden = next_state.h
            last_states[i] = next_state

        probs.append(nce_loss(data=hidden,
                              label=labelvec[seqidx],
                              label_weight=labelweightvec[seqidx],
                              embed_weight=label_embed_weight,
                              vocab_size=vocab_size,
                              num_hidden=output_width))
    return mx.sym.Group(probs)
105