1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""LSTM projection cell with cell clip and projection clip."""
18__all__ = ['LSTMPCellWithClip']
19
20from mxnet.gluon.contrib.rnn import LSTMPCell
21
22
class LSTMPCellWithClip(LSTMPCell):
    r"""Long-Short Term Memory Projected (LSTMP) network cell with cell clip and projection clip.
    Each call computes the following function:

    .. math::

        \DeclareMathOperator{\sigmoid}{sigmoid}
        \begin{array}{ll}
        i_t = \sigmoid(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
        f_t = \sigmoid(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rc} r_{(t-1)} + b_{rg}) \\
        o_t = \sigmoid(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
        c_t = c_{\text{clip}}(f_t * c_{(t-1)} + i_t * g_t) \\
        h_t = o_t * \tanh(c_t) \\
        r_t = p_{\text{clip}}(W_{hr} h_t)
        \end{array}

    where :math:`c_{\text{clip}}` is the cell clip applied on the next cell;
    :math:`r_t` is the projected recurrent activation at time `t`,
    :math:`p_{\text{clip}}` means apply projection clip on the projected output.
    :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the
    cell state at time `t`, :math:`x_t` is the input at time `t`, and :math:`i_t`,
    :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and
    out gates, respectively.

    Parameters
    ----------
    hidden_size : int
        Number of units in cell state symbol.
    projection_size : int
        Number of units in output symbol.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the hidden state.
    h2r_weight_initializer : str or Initializer
        Initializer for the projection weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'lstmbias'
        Initializer for the bias vector. By default, bias for the forget
        gate is initialized to 1 while all other biases are initialized
        to zero.
    h2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    input_size : int, default 0
        Number of units in the input symbol. Zero means it will be
        inferred from the input on the first forward pass.
    prefix : str
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    cell_clip : float
        Clip cell state between `[-cell_clip, cell_clip]` in LSTMPCellWithClip cell.
        `None` (default) disables cell clipping.
    projection_clip : float
        Clip projection between `[-projection_clip, projection_clip]` in LSTMPCellWithClip cell.
        `None` (default) disables projection clipping.
    """
    def __init__(self, hidden_size, projection_size,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 h2r_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, cell_clip=None, projection_clip=None, prefix=None, params=None):
        # Delegate all parameter creation (i2h/h2h/h2r weights and biases)
        # to the base LSTMPCell; this subclass only adds the two clip values.
        super(LSTMPCellWithClip, self).__init__(hidden_size,
                                                projection_size,
                                                i2h_weight_initializer,
                                                h2h_weight_initializer,
                                                h2r_weight_initializer,
                                                i2h_bias_initializer,
                                                h2h_bias_initializer,
                                                input_size,
                                                prefix=prefix,
                                                params=params)

        # Clip thresholds; `None` means the corresponding clip is skipped.
        self._cell_clip = cell_clip
        self._projection_clip = projection_clip

    # pylint: disable= arguments-differ
    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, h2r_weight, i2h_bias, h2h_bias):
        r"""Hybrid forward computation for Long-Short Term Memory Projected network cell
        with cell clip and projection clip.

        Parameters
        ----------
        inputs : input tensor with shape `(batch_size, input_size)`.
        states : a list of two initial recurrent state tensors, with shape
            `(batch_size, projection_size)` and `(batch_size, hidden_size)` respectively.

        Returns
        --------
        out : output tensor with shape `(batch_size, num_hidden)`.
        next_states : a list of two output recurrent state tensors. Each has
            the same shape as `states`.
        """
        # Per-timestep name prefix so symbols from each unrolled step are unique.
        prefix = 't%d_'%self._counter
        # Fused linear transforms: all four gates are computed in one
        # FullyConnected each for the input and the (projected) recurrent state,
        # hence num_hidden = 4 * hidden_size.
        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
                               num_hidden=self._hidden_size*4, name=prefix+'i2h')
        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                               num_hidden=self._hidden_size*4, name=prefix+'h2h')
        gates = i2h + h2h
        # Split the fused pre-activations into the four gates, in the fixed
        # order: input, forget, cell (candidate), output.
        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
        in_gate = F.Activation(slice_gates[0], act_type='sigmoid', name=prefix+'i')
        forget_gate = F.Activation(slice_gates[1], act_type='sigmoid', name=prefix+'f')
        in_transform = F.Activation(slice_gates[2], act_type='tanh', name=prefix+'c')
        out_gate = F.Activation(slice_gates[3], act_type='sigmoid', name=prefix+'o')
        # c_t = f_t * c_{t-1} + i_t * g_t (named symbol for checkpoint/graph clarity).
        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
                                   name=prefix+'state')
        # Optional cell-state clipping to [-cell_clip, cell_clip].
        if self._cell_clip is not None:
            next_c = next_c.clip(-self._cell_clip, self._cell_clip)
        # h_t = o_t * tanh(c_t)
        hidden = F._internal._mul(out_gate, F.Activation(next_c, act_type='tanh'),
                                  name=prefix+'hidden')
        # r_t = W_hr h_t — project hidden state down to projection_size (no bias).
        next_r = F.FullyConnected(data=hidden, num_hidden=self._projection_size,
                                  weight=h2r_weight, no_bias=True, name=prefix+'out')
        # Optional projection clipping to [-projection_clip, projection_clip].
        if self._projection_clip is not None:
            next_r = next_r.clip(-self._projection_clip, self._projection_clip)

        # The projected output doubles as the first recurrent state.
        return next_r, [next_r, next_c]
140