# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint:skip-file
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math

# bookkeeping containers for the symbolic graph: per-cell state,
# per-cell parameters, and the fully assembled model
LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
                                     "init_states", "last_states",
                                     "seq_data", "seq_labels", "seq_outputs",
                                     "param_blocks"])

def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
    """LSTM Cell symbol"""
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    # the input-to-hidden and hidden-to-hidden projections compute all
    # four gates at once, hence num_hidden * 4 outputs each
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    # standard LSTM state update
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)
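
# Illustrative sketch (assumed names and hidden size; not used elsewhere
# in this file): build a single LSTM step symbol with freshly created
# parameter and state variables.
def _example_lstm_step(num_hidden=256):
    param = LSTMParam(i2h_weight=mx.sym.Variable("ex_i2h_weight"),
                      i2h_bias=mx.sym.Variable("ex_i2h_bias"),
                      h2h_weight=mx.sym.Variable("ex_h2h_weight"),
                      h2h_bias=mx.sym.Variable("ex_h2h_bias"))
    state = LSTMState(c=mx.sym.Variable("ex_init_c"),
                      h=mx.sym.Variable("ex_init_h"))
    # one step at t=0 in layer 0; returns the next (c, h) pair of symbols
    return lstm(num_hidden, indata=mx.sym.Variable("ex_data"),
                prev_state=state, param=param, seqidx=0, layeridx=0)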


# we define a new unrolling function here because the original
# one in lstm.py concats all the labels at the last layer together,
# making the mini-batch size of the label different from that of the
# data.  I think the existing data-parallelization code needs some
# modification to allow this situation to work properly.
def lstm_unroll(num_lstm_layer, seq_len, input_size,
                num_hidden, num_embed, num_label, dropout=0.):

    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer

    # embedding layer
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    # split the embedded sequence into seq_len per-step slices
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]

        # stack LSTM; no dropout on the input of the first layer
        for i in range(num_lstm_layer):
            dp_ratio = 0. if i == 0 else dropout
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i, dropout=dp_ratio)
            hidden = next_state.h
            last_states[i] = next_state
        # decoder
        if dropout > 0.:
            hidden = mx.sym.Dropout(data=hidden, p=dropout)
        hidden_all.append(hidden)

    # concatenate the per-step hidden states along the batch axis (time-major)
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                                 weight=cls_weight, bias=cls_bias, name='pred')

    ################################################################################
    # Make the label the same shape as our produced data path.
    # I did not observe a big speed difference between the following two ways.

    label = mx.sym.transpose(data=label)
    label = mx.sym.Reshape(data=label, target_shape=(0,))

    #label_slice = mx.sym.SliceChannel(data=label, num_outputs=seq_len)
    #label = [label_slice[t] for t in range(seq_len)]
    #label = mx.sym.Concat(*label, dim=0)
    #label = mx.sym.Reshape(data=label, target_shape=(0,))
    ################################################################################
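    # Worked example (illustrative): with batch_size=2 and seq_len=3,
    # labels [[a1, a2, a3], [b1, b2, b3]] are transposed to
    # [[a1, b1], [a2, b2], [a3, b3]] and then flattened to
    # [a1, b1, a2, b2, a3, b3], i.e. time-major order, matching
    # hidden_concat above, which stacks the per-step hidden states along dim=0.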

    sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

    return sm

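# Minimal usage sketch (illustrative only; the hyper-parameters are
# arbitrary assumptions, not values prescribed by this example): build
# the unrolled training symbol and inspect its inputs.
def _example_unroll():
    sym = lstm_unroll(num_lstm_layer=2, seq_len=32, input_size=128,
                      num_hidden=256, num_embed=128, num_label=128,
                      dropout=0.5)
    # the arguments include every weight and bias, the per-layer init
    # states, and the 'data' / 'softmax_label' inputs
    return sym.list_arguments()
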
def lstm_inference_symbol(num_lstm_layer, input_size,
                          num_hidden, num_embed, num_label, dropout=0.):
    seqidx = 0
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer
    data = mx.sym.Variable("data")

    hidden = mx.sym.Embedding(data=data,
                              input_dim=input_size,
                              output_dim=num_embed,
                              weight=embed_weight,
                              name="embed")
    # stack LSTM; no dropout on the input of the first layer
    for i in range(num_lstm_layer):
        dp = 0. if i == 0 else dropout
        next_state = lstm(num_hidden, indata=hidden,
                          prev_state=last_states[i],
                          param=param_cells[i],
                          seqidx=seqidx, layeridx=i, dropout=dp)
        hidden = next_state.h
        last_states[i] = next_state
    # decoder
    if dropout > 0.:
        hidden = mx.sym.Dropout(data=hidden, p=dropout)
    fc = mx.sym.FullyConnected(data=hidden, num_hidden=num_label,
                               weight=cls_weight, bias=cls_bias, name='pred')
    sm = mx.sym.SoftmaxOutput(data=fc, name='softmax')
    # group the softmax output with the updated states so the caller can
    # feed them back in as the next step's init states
    output = [sm]
    for state in last_states:
        output.append(state.c)
        output.append(state.h)
    return mx.sym.Group(output)
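
# Illustrative inference sketch (not part of the original example; all
# sizes here are assumptions for demonstration): bind the symbol for
# batch size 1, run one step, and carry the returned states forward.
def _example_inference_step():
    num_lstm_layer, num_hidden = 2, 256
    sym = lstm_inference_symbol(num_lstm_layer=num_lstm_layer, input_size=128,
                                num_hidden=num_hidden, num_embed=128,
                                num_label=128)
    init_states = {"l%d_init_%s" % (l, t): (1, num_hidden)
                   for l in range(num_lstm_layer) for t in ("c", "h")}
    exe = sym.simple_bind(ctx=mx.cpu(), grad_req="null",
                          data=(1,), **init_states)
    # a real run would first load trained parameters into exe.arg_dict;
    # they are left uninitialized in this sketch
    for name in init_states:              # start from zero states
        exe.arg_dict[name][:] = 0.
    exe.arg_dict["data"][:] = 0           # feed a single token id
    exe.forward()
    prob = exe.outputs[0]                 # softmax over the labels
    next_states = exe.outputs[1:]         # copy back into l*_init_c / l*_init_h
    return prob, next_states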