# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17"""Contain helpers for creating LSTM symbolic graph for training and inference """

from __future__ import print_function

from collections import namedtuple

import mxnet as mx


__all__ = ["lstm_unroll", "init_states"]


# LSTMState carries the cell state (c) and hidden state (h) between time
# steps; LSTMParam bundles the weights and biases shared by all time steps
# of one layer.
LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])


def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
    """LSTM cell symbol"""
    # Project the input and the previous hidden state to 4 * num_hidden
    # units: one slice per gate (input, forget, output) plus the candidate
    # cell values.
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    # Standard LSTM update: c_t = f * c_{t-1} + i * g, h_t = o * tanh(c_t)
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)
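
# A minimal shape sketch for a single cell (hypothetical names and sizes,
# not part of the example itself): a 100-unit cell maps (32, 80) inputs and
# (32, 100) states to (32, 100) states.
#
#   x = mx.sym.Variable('x')
#   state = LSTMState(c=mx.sym.Variable('c'), h=mx.sym.Variable('h'))
#   param = LSTMParam(i2h_weight=mx.sym.Variable('i2h_w'),
#                     i2h_bias=mx.sym.Variable('i2h_b'),
#                     h2h_weight=mx.sym.Variable('h2h_w'),
#                     h2h_bias=mx.sym.Variable('h2h_b'))
#   out = _lstm(num_hidden=100, indata=x, prev_state=state, param=param,
#               seqidx=0, layeridx=0)
#   _, out_shapes, _ = out.h.infer_shape(x=(32, 80), c=(32, 100), h=(32, 100))
#   assert out_shapes[0] == (32, 100)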


def _lstm_unroll_base(num_lstm_layer, seq_len, num_hidden):
    """Returns the symbol for the LSTM model up to, but not including, the loss/softmax layer"""
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer

    # Slice the input sequence into one symbol per time step
    data = mx.sym.Variable('data')
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    # Unroll the stacked LSTM over seq_len time steps
    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            next_state = _lstm(
                num_hidden=num_hidden,
                indata=hidden,
                prev_state=last_states[i],
                param=param_cells[i],
                seqidx=seqidx,
                layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)

    # Stack the per-step hidden states into (seq_len * batch, num_hidden) and
    # project to the output classes (11 here, e.g. 10 digits plus one CTC
    # blank in the captcha example this model is used with)
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11, name="pred_fc")
    return pred_fc
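
# A shape sketch (hypothetical sizes): with a batch of 32 and 80 time steps
# of 30 features each (e.g. one 30-pixel image column per step), each step
# yields a (32, num_hidden) state, so pred_fc has shape (80 * 32, 11):
#
#   pred = _lstm_unroll_base(num_lstm_layer=2, seq_len=80, num_hidden=100)
#   shapes = {'data': (32, 80, 30)}
#   for i in range(2):
#       shapes['l%d_init_c' % i] = (32, 100)
#       shapes['l%d_init_h' % i] = (32, 100)
#   _, out_shapes, _ = pred.infer_shape(**shapes)
#   assert out_shapes[0] == (2560, 11)   # (seq_len * batch, num_classes)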


def _add_warp_ctc_loss(pred, seq_len, num_label, label):
    """Adds mx.sym.WarpCTC on top of pred symbol and returns the resulting symbol.

    Requires MXNet to be built with the WarpCTC plugin.
    """
    label = mx.sym.Reshape(data=label, shape=(-1,))
    label = mx.sym.Cast(data=label, dtype='int32')
    return mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len)


def _add_mxnet_ctc_loss(pred, seq_len, label):
    """Adds mx.sym.contrib.ctc_loss on top of pred symbol and returns the resulting symbol"""
    # Split pred's first axis (seq_len * batch) back into (seq_len, batch) and
    # keep the class axis, since ctc_loss expects (seq_len, batch, num_classes)
    pred_ctc = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0))

    loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label)
    ctc_loss = mx.sym.MakeLoss(loss)

    # Group a gradient-blocked softmax with the loss so that class
    # probabilities remain available as an output during training
    softmax_class = mx.symbol.SoftmaxActivation(data=pred)
    softmax_loss = mx.sym.MakeLoss(softmax_class)
    softmax_loss = mx.sym.BlockGrad(softmax_loss)
    return mx.sym.Group([softmax_loss, ctc_loss])
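
# A reshape sketch showing the special codes above (hypothetical sizes,
# seq_len=4, batch=2, 3 classes): -4 splits the first input axis into the
# two dims that follow it (4 and the inferred -1), and 0 copies the class
# dimension through unchanged.
#
#   x = mx.nd.zeros((4 * 2, 3))
#   assert x.reshape((-4, 4, -1, 0)).shape == (4, 2, 3)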


def _add_ctc_loss(pred, seq_len, num_label, loss_type):
    """Adds CTC loss on top of pred symbol and returns the resulting symbol"""
    label = mx.sym.Variable('label')
    if loss_type == 'warpctc':
        print("Using WarpCTC Loss")
        sm = _add_warp_ctc_loss(pred, seq_len, num_label, label)
    else:
        print("Using MXNet CTC Loss")
        assert loss_type == 'ctc'
        sm = _add_mxnet_ctc_loss(pred, seq_len, label)
    return sm


def lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_label, loss_type=None):
    """
    Creates an unrolled LSTM symbol: for inference if loss_type is not specified,
    and for training if it is. loss_type must be one of 'ctc' or 'warpctc'.

    Parameters
    ----------
    num_lstm_layer: int
        Number of stacked LSTM layers
    seq_len: int
        Number of time steps to unroll over
    num_hidden: int
        Number of hidden units per LSTM layer
    num_label: int
        Maximum label length (passed to WarpCTC as label_length)
    loss_type: str, optional
        'ctc' or 'warpctc'; if None, a softmax output is added for inference

    Returns
    -------
    mxnet.symbol.symbol.Symbol
    """
    # Create the base network (shared between training and inference)
    pred = _lstm_unroll_base(num_lstm_layer, seq_len, num_hidden)

    if loss_type:
        # Training mode: add CTC loss
        return _add_ctc_loss(pred, seq_len, num_label, loss_type)
    else:
        # Inference mode: add softmax
        return mx.sym.softmax(data=pred, name='softmax')
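
# Usage sketch (hypothetical sizes, mirroring the captcha example this
# module is used with):
#
#   train_sym = lstm_unroll(num_lstm_layer=2, seq_len=80, num_hidden=100,
#                           num_label=4, loss_type='ctc')
#   infer_sym = lstm_unroll(num_lstm_layer=2, seq_len=80, num_hidden=100,
#                           num_label=4)   # no loss_type: softmax for inference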


def init_states(batch_size, num_lstm_layer, num_hidden):
    """
    Returns the names and shapes of the initial states of the LSTM network

    Parameters
    ----------
    batch_size: int
    num_lstm_layer: int
    num_hidden: int

    Returns
    -------
    list of tuple of str and tuple of int and int
    """
    init_c = [('l%d_init_c' % i, (batch_size, num_hidden)) for i in range(num_lstm_layer)]
    init_h = [('l%d_init_h' % i, (batch_size, num_hidden)) for i in range(num_lstm_layer)]
    return init_c + init_h
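
# Usage sketch: the returned (name, shape) pairs match the l%d_init_c and
# l%d_init_h variables created in _lstm_unroll_base, and are typically used
# to describe the extra inputs when binding the symbol or building a DataIter.
#
#   init_states(batch_size=32, num_lstm_layer=2, num_hidden=100)
#   # -> [('l0_init_c', (32, 100)), ('l1_init_c', (32, 100)),
#   #     ('l0_init_h', (32, 100)), ('l1_init_h', (32, 100))]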