# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""LSTM projection cell with cell clip and projection clip."""
__all__ = ['LSTMPCellWithClip']

from mxnet.gluon.contrib.rnn import LSTMPCell


class LSTMPCellWithClip(LSTMPCell):
    r"""Long-Short Term Memory Projected (LSTMP) network cell with cell clip
    and projection clip. Each call computes the following function:

    .. math::

        \DeclareMathOperator{\sigmoid}{sigmoid}
        \begin{array}{ll}
        i_t = \sigmoid(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
        f_t = \sigmoid(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rg} r_{(t-1)} + b_{rg}) \\
        o_t = \sigmoid(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
        c_t = c_{\text{clip}}(f_t * c_{(t-1)} + i_t * g_t) \\
        h_t = o_t * \tanh(c_t) \\
        r_t = p_{\text{clip}}(W_{hr} h_t)
        \end{array}

    where :math:`c_{\text{clip}}` is the cell clip applied to the next cell state,
    :math:`r_t` is the projected recurrent activation at time `t`,
    :math:`p_{\text{clip}}` is the projection clip applied to the projected output,
    :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the
    cell state at time `t`, :math:`x_t` is the input at time `t`, and :math:`i_t`,
    :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and
    output gates, respectively.

    Parameters
    ----------
    hidden_size : int
        Number of units in cell state symbol.
    projection_size : int
        Number of units in output symbol.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the hidden state.
    h2r_weight_initializer : str or Initializer
        Initializer for the projection weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the input bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the recurrent bias vector.
    input_size : int, default 0
        Number of expected features in the input. If not specified,
        it is inferred from the first input.
    prefix : str
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    cell_clip : float, default None
        If not `None`, clip the cell state to `[-cell_clip, cell_clip]`.
    projection_clip : float, default None
        If not `None`, clip the projected output to
        `[-projection_clip, projection_clip]`.
    """
    def __init__(self, hidden_size, projection_size,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 h2r_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, cell_clip=None, projection_clip=None,
                 prefix=None, params=None):
        super(LSTMPCellWithClip, self).__init__(hidden_size,
                                                projection_size,
                                                i2h_weight_initializer,
                                                h2h_weight_initializer,
                                                h2r_weight_initializer,
                                                i2h_bias_initializer,
                                                h2h_bias_initializer,
                                                input_size,
                                                prefix=prefix,
                                                params=params)

        self._cell_clip = cell_clip
        self._projection_clip = projection_clip

    # pylint: disable=arguments-differ
    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, h2r_weight, i2h_bias, h2h_bias):
        r"""Hybrid forward computation for the Long-Short Term Memory Projected
        network cell with cell clip and projection clip.

        Parameters
        ----------
        inputs : input tensor with shape `(batch_size, input_size)`.
        states : a list of two initial recurrent state tensors, with shape
            `(batch_size, projection_size)` and `(batch_size, hidden_size)`
            respectively.

        Returns
        -------
        out : output tensor with shape `(batch_size, projection_size)`.
        next_states : a list of two output recurrent state tensors. Each has
            the same shape as the corresponding tensor in `states`.
        """
        prefix = 't%d_' % self._counter
        # Fused linear transforms: each produces the pre-activations of all
        # four gates at once, hence num_hidden = 4 * hidden_size.
        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
                               num_hidden=self._hidden_size * 4, name=prefix + 'i2h')
        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                               num_hidden=self._hidden_size * 4, name=prefix + 'h2h')
        gates = i2h + h2h
        # Split into input, forget, candidate (cell), and output gate slices.
        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix + 'slice')
        in_gate = F.Activation(slice_gates[0], act_type='sigmoid', name=prefix + 'i')
        forget_gate = F.Activation(slice_gates[1], act_type='sigmoid', name=prefix + 'f')
        in_transform = F.Activation(slice_gates[2], act_type='tanh', name=prefix + 'c')
        out_gate = F.Activation(slice_gates[3], act_type='sigmoid', name=prefix + 'o')
        # New cell state c_t = f_t * c_{t-1} + i_t * g_t, optionally clipped.
        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
                                   name=prefix + 'state')
        if self._cell_clip is not None:
            next_c = next_c.clip(-self._cell_clip, self._cell_clip)
        hidden = F._internal._mul(out_gate, F.Activation(next_c, act_type='tanh'),
                                  name=prefix + 'hidden')
        # Project the hidden state down to projection_size, optionally clipped.
        next_r = F.FullyConnected(data=hidden, num_hidden=self._projection_size,
                                  weight=h2r_weight, no_bias=True, name=prefix + 'out')
        if self._projection_clip is not None:
            next_r = next_r.clip(-self._projection_clip, self._projection_clip)

        return next_r, [next_r, next_c]
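
# ----------------------------------------------------------------------
# Minimal usage sketch (an editor's illustration, not part of the original
# module): constructs the cell and runs a single step to show the shapes
# of the clipped projection and cell state. It assumes a working MXNet 1.x
# installation; the sizes and clip values below are arbitrary.
if __name__ == '__main__':
    import mxnet as mx

    batch_size, input_size = 2, 10
    cell = LSTMPCellWithClip(hidden_size=20, projection_size=5,
                             input_size=input_size,
                             cell_clip=3.0, projection_clip=3.0)
    cell.initialize()

    x = mx.nd.random.uniform(shape=(batch_size, input_size))
    # Initial states: projected recurrent state r_{t-1} of shape
    # (batch_size, projection_size) and cell state c_{t-1} of shape
    # (batch_size, hidden_size).
    states = cell.begin_state(batch_size=batch_size)
    output, next_states = cell(x, states)
    print(output.shape)          # (2, 5): the clipped projection r_t
    print(next_states[1].shape)  # (2, 20): the clipped cell state c_t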