# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Contains helpers for creating LSTM symbolic graphs for training and inference."""

from __future__ import print_function

from collections import namedtuple

import mxnet as mx


__all__ = ["lstm_unroll", "init_states"]


LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])


def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
    """LSTM cell symbol"""
    # Input-to-hidden and hidden-to-hidden projections; the factor of 4 packs
    # the input, transform (candidate), forget, and output gates into one matrix.
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)


def _lstm_unroll_base(num_lstm_layer, seq_len, num_hidden):
    """Returns symbol for the LSTM model up to the loss/softmax layer."""
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer

    # Slice the input into seq_len time steps; each slice is fed directly
    # to the first LSTM layer (no embedding is applied)
    data = mx.sym.Variable('data')
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            next_state = _lstm(
                num_hidden=num_hidden,
                indata=hidden,
                prev_state=last_states[i],
                param=param_cells[i],
                seqidx=seqidx,
                layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)

    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    # 11 output classes: in this OCR example, the 10 digits plus one CTC blank label
    pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11, name="pred_fc")
    return pred_fc


def _add_warp_ctc_loss(pred, seq_len, num_label, label):
    """Adds Symbol.WarpCTC on top of pred symbol and returns the resulting symbol."""
    # Flatten the label matrix and cast to int32, as expected by WarpCTC
    label = mx.sym.Reshape(data=label, shape=(-1,))
    label = mx.sym.Cast(data=label, dtype='int32')
    return mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len)


def _add_mxnet_ctc_loss(pred, seq_len, label):
    """Adds Symbol.contrib.ctc_loss on top of pred symbol and returns the resulting symbol."""
    # Split the first axis (seq_len * batch, num_classes) into
    # (seq_len, batch, num_classes), the layout ctc_loss expects
    pred_ctc = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0))

    loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label)
    ctc_loss = mx.sym.MakeLoss(loss)

    softmax_class = mx.symbol.SoftmaxActivation(data=pred)
    softmax_loss = mx.sym.MakeLoss(softmax_class)
    softmax_loss = mx.sym.BlockGrad(softmax_loss)
    return mx.sym.Group([softmax_loss, ctc_loss])


def _add_ctc_loss(pred, seq_len, num_label, loss_type):
    """Adds CTC loss on top of pred symbol and returns the resulting symbol."""
    label = mx.sym.Variable('label')
    if loss_type == 'warpctc':
        print("Using WarpCTC Loss")
        sm = _add_warp_ctc_loss(pred, seq_len, num_label, label)
    else:
        print("Using MXNet CTC Loss")
        assert loss_type == 'ctc'
        sm = _add_mxnet_ctc_loss(pred, seq_len, label)
    return sm


def lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_label, loss_type=None):
    """
    Creates an unrolled LSTM symbol for inference if loss_type is not specified, and for training
    if loss_type is specified. loss_type must be one of 'ctc' or 'warpctc'.

    Parameters
    ----------
    num_lstm_layer: int
    seq_len: int
    num_hidden: int
    num_label: int
    loss_type: str
        'ctc' or 'warpctc'

    Returns
    -------
    mxnet.symbol.symbol.Symbol
    """
    # Create the base (shared between training and inference) and add loss to the end
    pred = _lstm_unroll_base(num_lstm_layer, seq_len, num_hidden)

    if loss_type:
        # Training mode, add loss
        return _add_ctc_loss(pred, seq_len, num_label, loss_type)
    else:
        # Inference mode, add softmax
        return mx.sym.softmax(data=pred, name='softmax')


def init_states(batch_size, num_lstm_layer, num_hidden):
    """
    Returns the names and shapes of the initial states of the LSTM network

    Parameters
    ----------
    batch_size: int
    num_lstm_layer: int
    num_hidden: int

    Returns
    -------
    list of tuple of str and tuple of int and int
    """
    init_c = [('l%d_init_c' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    return init_c + init_h
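

# A minimal usage sketch (not part of the original example): builds the training
# and inference symbols and prints their inputs. The hyperparameter values below
# are hypothetical placeholders chosen for illustration, not values from this file.
if __name__ == "__main__":
    num_lstm_layer, seq_len, num_hidden, num_label, batch_size = 2, 80, 100, 4, 32

    # Training graph: unrolled LSTM followed by the built-in MXNet CTC loss
    train_sym = lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_label, loss_type='ctc')
    print(train_sym.list_arguments())

    # Inference graph: the same base network followed by a softmax
    infer_sym = lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_label)
    print(infer_sym.list_outputs())

    # The initial cell/hidden states must be supplied as extra inputs
    # (e.g. via a DataIter's provide_data); init_states gives their names and shapes
    print(init_states(batch_size, num_lstm_layer, num_hidden))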