# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=no-member, invalid-name, protected-access, no-self-use
# pylint: disable=too-many-branches, too-many-arguments
# pylint: disable=too-many-lines, arguments-differ
"""Definition of various recurrent neural network layers."""
import re

__all__ = ['RNN', 'LSTM', 'GRU']

from ... import ndarray, symbol
from .. import HybridBlock, tensor_types
from . import rnn_cell
from ...util import is_np_array


class _RNNLayer(HybridBlock):
    """Implementation of recurrent layers."""
    def __init__(self, hidden_size, num_layers, layout,
                 dropout, bidirectional, input_size,
                 i2h_weight_initializer, h2h_weight_initializer,
                 i2h_bias_initializer, h2h_bias_initializer,
                 mode, projection_size, h2r_weight_initializer,
                 lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
                 dtype, use_sequence_length=False, **kwargs):
        super(_RNNLayer, self).__init__(**kwargs)
        assert layout in ('TNC', 'NTC'), \
            "Invalid layout %s; must be one of ['TNC', 'NTC']" % layout
        self._hidden_size = hidden_size
        self._projection_size = projection_size if projection_size else None
        self._num_layers = num_layers
        self._mode = mode
        self._layout = layout
        self._dropout = dropout
        self._dir = 2 if bidirectional else 1
        self._input_size = input_size
        self._i2h_weight_initializer = i2h_weight_initializer
        self._h2h_weight_initializer = h2h_weight_initializer
        self._i2h_bias_initializer = i2h_bias_initializer
        self._h2h_bias_initializer = h2h_bias_initializer
        self._h2r_weight_initializer = h2r_weight_initializer
        self._lstm_state_clip_min = lstm_state_clip_min
        self._lstm_state_clip_max = lstm_state_clip_max
        self._lstm_state_clip_nan = lstm_state_clip_nan
        self._dtype = dtype
        self._use_sequence_length = use_sequence_length
        self.skip_states = None

        self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

        ng, ni, nh = self._gates, input_size, hidden_size
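        # Each layer and direction registers an i2h weight of shape (ng*nh, ni),
        # an h2h weight of shape (ng*nh, nh) -- or (ng*nh, projection_size) plus an
        # h2r projection weight when projection is used -- and i2h/h2h bias
        # vectors of shape (ng*nh,).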
        if not projection_size:
            for i in range(num_layers):
                for j in ['l', 'r'][:self._dir]:
                    self._register_param('{}{}_i2h_weight'.format(j, i),
                                         shape=(ng*nh, ni),
                                         init=i2h_weight_initializer, dtype=dtype)
                    self._register_param('{}{}_h2h_weight'.format(j, i),
                                         shape=(ng*nh, nh),
                                         init=h2h_weight_initializer, dtype=dtype)
                    self._register_param('{}{}_i2h_bias'.format(j, i),
                                         shape=(ng*nh,),
                                         init=i2h_bias_initializer, dtype=dtype)
                    self._register_param('{}{}_h2h_bias'.format(j, i),
                                         shape=(ng*nh,),
                                         init=h2h_bias_initializer, dtype=dtype)
                ni = nh * self._dir
        else:
            np = self._projection_size
            for i in range(num_layers):
                for j in ['l', 'r'][:self._dir]:
                    self._register_param('{}{}_i2h_weight'.format(j, i),
                                         shape=(ng*nh, ni),
                                         init=i2h_weight_initializer, dtype=dtype)
                    self._register_param('{}{}_h2h_weight'.format(j, i),
                                         shape=(ng*nh, np),
                                         init=h2h_weight_initializer, dtype=dtype)
                    self._register_param('{}{}_i2h_bias'.format(j, i),
                                         shape=(ng*nh,),
                                         init=i2h_bias_initializer, dtype=dtype)
                    self._register_param('{}{}_h2h_bias'.format(j, i),
                                         shape=(ng*nh,),
                                         init=h2h_bias_initializer, dtype=dtype)
                    self._register_param('{}{}_h2r_weight'.format(j, i),
                                         shape=(np, nh),
                                         init=h2r_weight_initializer, dtype=dtype)
                ni = np * self._dir

    def _register_param(self, name, shape, init, dtype):
        p = self.params.get(name, shape=shape, init=init,
                            allow_deferred_init=True, dtype=dtype)
        setattr(self, name, p)
        return p

    def __repr__(self):
        s = '{name}({mapping}, {_layout}'
        if self._num_layers != 1:
            s += ', num_layers={_num_layers}'
        if self._dropout != 0:
            s += ', dropout={_dropout}'
        if self._dir == 2:
            s += ', bidirectional'
        s += ')'
        shape = self.l0_i2h_weight.shape
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates)
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def _collect_params_with_prefix(self, prefix=''):
        if prefix:
            prefix += '.'
        pattern = re.compile(r'(l|r)(\d)_(i2h|h2h|h2r)_(weight|bias)\Z')
        def convert_key(m, bidirectional): # for compatibility with old parameter format
            d, l, g, t = [m.group(i) for i in range(1, 5)]
            if bidirectional:
                return '_unfused.{}.{}_cell.{}_{}'.format(l, d, g, t)
            else:
                return '_unfused.{}.{}_{}'.format(l, g, t)
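        # e.g. 'l0_i2h_weight' maps to '_unfused.0.i2h_weight' for a unidirectional
        # layer, or to '_unfused.0.l_cell.i2h_weight' when the layer is bidirectional.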
        bidirectional = any(pattern.match(k).group(1) == 'r' for k in self._reg_params)

        ret = {prefix + convert_key(pattern.match(key), bidirectional) : val
               for key, val in self._reg_params.items()}
        for name, child in self._children.items():
            ret.update(child._collect_params_with_prefix(prefix + name))
        return ret

    def state_info(self, batch_size=0):
        raise NotImplementedError

    def _unfuse(self):
        """Unfuses the fused RNN into a stack of RNN cells."""
        assert not self._projection_size, "_unfuse does not support projection layer yet!"
        assert not self._lstm_state_clip_min and not self._lstm_state_clip_max, \
                "_unfuse does not support state clipping yet!"
        get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
                                                                  activation='relu',
                                                                  **kwargs),
                    'rnn_tanh': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
                                                                  activation='tanh',
                                                                  **kwargs),
                    'lstm': lambda **kwargs: rnn_cell.LSTMCell(self._hidden_size,
                                                               **kwargs),
                    'gru': lambda **kwargs: rnn_cell.GRUCell(self._hidden_size,
                                                             **kwargs)}[self._mode]

        stack = rnn_cell.HybridSequentialRNNCell(prefix=self.prefix, params=self.params)
        with stack.name_scope():
            ni = self._input_size
            for i in range(self._num_layers):
                kwargs = {'input_size': ni,
                          'i2h_weight_initializer': self._i2h_weight_initializer,
                          'h2h_weight_initializer': self._h2h_weight_initializer,
                          'i2h_bias_initializer': self._i2h_bias_initializer,
                          'h2h_bias_initializer': self._h2h_bias_initializer}
                if self._dir == 2:
                    stack.add(rnn_cell.BidirectionalCell(
                        get_cell(prefix='l%d_'%i, **kwargs),
                        get_cell(prefix='r%d_'%i, **kwargs)))
                else:
                    stack.add(get_cell(prefix='l%d_'%i, **kwargs))

                if self._dropout > 0 and i != self._num_layers - 1:
                    stack.add(rnn_cell.DropoutCell(self._dropout))

                ni = self._hidden_size * self._dir

        return stack

    def cast(self, dtype):
        super(_RNNLayer, self).cast(dtype)
        self._dtype = dtype

    def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
        """Initial state for this cell.

        Parameters
        ----------
        batch_size: int
            Only required for the `NDArray` API. Size of the batch ('N' in the
            layout), i.e. the batch dimension of the input.
        func : callable, default `ndarray.zeros`
            Function for creating initial state.

            For the Symbol API, func can be `symbol.zeros`, `symbol.uniform`,
            `symbol.var`, etc. Use `symbol.var` if you want to directly
            feed input as states.

            For the NDArray API, func can be `ndarray.zeros`, `ndarray.ones`, etc.

        **kwargs :
            Additional keyword arguments passed to func. For example
            `mean`, `std`, `dtype`, etc.

        Returns
        -------
        states : nested list of Symbol or NDArray
            Starting states for the first RNN step.
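
        Examples
        --------
        A minimal illustration, assuming a 2-layer LSTM (any Gluon RNN layer
        can be used in the same way):

        >>> layer = mx.gluon.rnn.LSTM(100, 2)
        >>> layer.initialize()
        >>> states = layer.begin_state(batch_size=32)
        >>> len(states)  # LSTM uses a hidden state and a cell state
        2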
        """
        states = []
        for i, info in enumerate(self.state_info(batch_size)):
            if info is not None:
                info.update(kwargs)
            else:
                info = kwargs
            state = func(name='%sh0_%d' % (self.prefix, i), **info)
            if is_np_array():
                state = state.as_np_ndarray()
            states.append(state)
        return states

    def __call__(self, inputs, states=None, sequence_length=None, **kwargs):
        self.skip_states = states is None
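        # When no states are passed in, create zero begin states here and drop the
        # returned states again in hybrid_forward (see skip_states).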
        if states is None:
            if isinstance(inputs, ndarray.NDArray):
                batch_size = inputs.shape[self._layout.find('N')]
                states = self.begin_state(batch_size, ctx=inputs.context, dtype=inputs.dtype)
            else:
                states = self.begin_state(0, func=symbol.zeros)
        if isinstance(states, tensor_types):
            states = [states]

        if self._use_sequence_length:
            return super(_RNNLayer, self).__call__(inputs, states, sequence_length, **kwargs)
        else:
            return super(_RNNLayer, self).__call__(inputs, states, **kwargs)

    def hybrid_forward(self, F, inputs, states, sequence_length=None, **kwargs):
        if F is ndarray:
            batch_size = inputs.shape[self._layout.find('N')]
            for state, info in zip(states, self.state_info(batch_size)):
                if state.shape != info['shape']:
                    raise ValueError(
                        "Invalid recurrent state shape. Expecting %s, got %s." % (
                            str(info['shape']), str(state.shape)))
        out = self._forward_kernel(F, inputs, states, sequence_length, **kwargs)

        # out is (output, state)
        return out[0] if self.skip_states else out

    def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
        """Forward using the cuDNN or CPU kernel."""
        swapaxes = F.np.swapaxes if is_np_array() else F.swapaxes
        if self._layout == 'NTC':
            inputs = swapaxes(inputs, 0, 1)
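        # Flatten all registered weights and biases and concatenate them into the
        # single parameter blob expected by the fused kernel: all weights first,
        # then all biases, iterating over layers, directions and gate groups
        # (the h2r projection weight is included when projection is enabled).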
        if self._projection_size is None:
            params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
                      for t in ['weight', 'bias']
                      for l in range(self._num_layers)
                      for d in ['l', 'r'][:self._dir]
                      for g in ['i2h', 'h2h'])
        else:
            params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
                      for t in ['weight', 'bias']
                      for l in range(self._num_layers)
                      for d in ['l', 'r'][:self._dir]
                      for g in ['i2h', 'h2h', 'h2r']
                      if g != 'h2r' or t != 'bias')

        rnn_param_concat = F.np._internal.rnn_param_concat if is_np_array() \
            else F._internal._rnn_param_concat
        params = rnn_param_concat(*params, dim=0)

        if self._use_sequence_length:
            rnn_args = states + [sequence_length]
        else:
            rnn_args = states

        rnn_fn = F.npx.rnn if is_np_array() else F.RNN
        rnn = rnn_fn(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
                     state_size=self._hidden_size, projection_size=self._projection_size,
                     num_layers=self._num_layers, bidirectional=self._dir == 2,
                     p=self._dropout, state_outputs=True, mode=self._mode,
                     lstm_state_clip_min=self._lstm_state_clip_min,
                     lstm_state_clip_max=self._lstm_state_clip_max,
                     lstm_state_clip_nan=self._lstm_state_clip_nan)

        if self._mode == 'lstm':
            outputs, states = rnn[0], [rnn[1], rnn[2]]
        else:
            outputs, states = rnn[0], [rnn[1]]

        if self._layout == 'NTC':
            outputs = swapaxes(outputs, 0, 1)

        return outputs, states


class RNN(_RNNLayer):
    r"""Applies a multi-layer Elman RNN with `tanh` or `ReLU` non-linearity to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is the output
    of the previous layer at time `t` or :math:`input_t` for the first layer.
    If activation='relu', then `ReLU` is used instead of `tanh`.

    Parameters
    ----------
    hidden_size: int
        The number of features in the hidden state h.
    num_layers: int, default 1
        Number of recurrent layers.
    activation: {'relu', 'tanh'}, default 'relu'
        The activation function to use.
    layout : str, default 'TNC'
        The format of input and output tensors. T, N and C stand for
        sequence length, batch size, and feature dimensions respectively.
    dropout: float, default 0
        If non-zero, introduces a dropout layer on the outputs of each
        RNN layer except the last layer.
    bidirectional: bool, default False
        If `True`, becomes a bidirectional RNN.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    input_size: int, default 0
        The number of expected features in the input x.
        If not specified, it will be inferred from the input.
    dtype : str, default 'float32'
        Type to initialize the parameters and default states to.
    prefix : str or None
        Prefix of this `Block`.
    params : ParameterDict or None
        Shared Parameters for this `Block`.


    Inputs:
        - **data**: input tensor with shape `(sequence_length, batch_size, input_size)`
          when `layout` is "TNC". For other layouts, dimensions are permuted accordingly
          using the transpose() operator, which adds performance overhead. Consider
          creating batches in TNC layout during the data batching step.

        - **states**: initial recurrent state tensor with shape
          `(num_layers, batch_size, num_hidden)`. If `bidirectional` is True,
          shape will instead be `(2*num_layers, batch_size, num_hidden)`. If
          `states` is None, zeros will be used as default begin states.

    Outputs:
        - **out**: output tensor with shape `(sequence_length, batch_size, num_hidden)`
          when `layout` is "TNC". If `bidirectional` is True, output shape will instead
          be `(sequence_length, batch_size, 2*num_hidden)`.
        - **out_states**: output recurrent state tensor with the same shape as `states`.
          If `states` is None, `out_states` will not be returned.


    Examples
    --------
    >>> layer = mx.gluon.rnn.RNN(100, 3)
    >>> layer.initialize()
    >>> input = mx.nd.random.uniform(shape=(5, 3, 10))
    >>> # by default zeros are used as begin state
    >>> output = layer(input)
    >>> # manually specify begin state.
    >>> h0 = mx.nd.random.uniform(shape=(3, 3, 100))
    >>> output, hn = layer(input, h0)
    """
    def __init__(self, hidden_size, num_layers=1, activation='relu',
                 layout='TNC', dropout=0, bidirectional=False,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, dtype='float32', **kwargs):
        super(RNN, self).__init__(hidden_size, num_layers, layout,
                                  dropout, bidirectional, input_size,
                                  i2h_weight_initializer, h2h_weight_initializer,
                                  i2h_bias_initializer, h2h_bias_initializer,
                                  'rnn_'+activation, None, None, None, None, False,
                                  dtype, **kwargs)

    def state_info(self, batch_size=0):
        return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
                 '__layout__': 'LNC', 'dtype': self._dtype}]


class LSTM(_RNNLayer):
    r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
        i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
        f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
        o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
        c_t = f_t * c_{(t-1)} + i_t * g_t \\
        h_t = o_t * \tanh(c_t)
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the
    cell state at time `t`, :math:`x_t` is the hidden state of the previous
    layer at time `t` or :math:`input_t` for the first layer, and :math:`i_t`,
    :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and
    output gates, respectively.

    Parameters
    ----------
    hidden_size: int
        The number of features in the hidden state h.
    num_layers: int, default 1
        Number of recurrent layers.
    layout : str, default 'TNC'
        The format of input and output tensors. T, N and C stand for
        sequence length, batch size, and feature dimensions respectively.
    dropout: float, default 0
        If non-zero, introduces a dropout layer on the outputs of each
        RNN layer except the last layer.
    bidirectional: bool, default False
        If `True`, becomes a bidirectional RNN.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector. Use 'lstmbias' to initialize the
        forget-gate bias to 1 and all other biases to zero.
    h2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    projection_size: int, default None
        The number of features after projection.
    h2r_weight_initializer : str or Initializer, default None
        Initializer for the projected recurrent weights matrix, used for the linear
        transformation of the recurrent state to the projected space.
    state_clip_min : float or None, default None
        Minimum clip value of LSTM states. This option must be used together with
        state_clip_max. If None, clipping is not applied.
    state_clip_max : float or None, default None
        Maximum clip value of LSTM states. This option must be used together with
        state_clip_min. If None, clipping is not applied.
    state_clip_nan : boolean, default False
        Whether to stop NaN from propagating in state by clipping it to min/max.
        If the clipping range is not specified, this option is ignored.
    dtype : str, default 'float32'
        Type to initialize the parameters and default states to.
    input_size: int, default 0
        The number of expected features in the input x.
        If not specified, it will be inferred from the input.
    prefix : str or None
        Prefix of this `Block`.
    params : `ParameterDict` or `None`
        Shared Parameters for this `Block`.


    Inputs:
        - **data**: input tensor with shape `(sequence_length, batch_size, input_size)`
          when `layout` is "TNC". For other layouts, dimensions are permuted accordingly
          using the transpose() operator, which adds performance overhead. Consider
          creating batches in TNC layout during the data batching step.
        - **states**: a list of two initial recurrent state tensors. Each has shape
          `(num_layers, batch_size, num_hidden)`. If `bidirectional` is True,
          shape will instead be `(2*num_layers, batch_size, num_hidden)`. If
          `states` is None, zeros will be used as default begin states.

    Outputs:
        - **out**: output tensor with shape `(sequence_length, batch_size, num_hidden)`
          when `layout` is "TNC". If `bidirectional` is True, output shape will instead
          be `(sequence_length, batch_size, 2*num_hidden)`.
        - **out_states**: a list of two output recurrent state tensors with the same
          shape as in `states`. If `states` is None, `out_states` will not be returned.


    Examples
    --------
    >>> layer = mx.gluon.rnn.LSTM(100, 3)
    >>> layer.initialize()
    >>> input = mx.nd.random.uniform(shape=(5, 3, 10))
    >>> # by default zeros are used as begin state
    >>> output = layer(input)
    >>> # manually specify begin state.
    >>> h0 = mx.nd.random.uniform(shape=(3, 3, 100))
    >>> c0 = mx.nd.random.uniform(shape=(3, 3, 100))
    >>> output, hn = layer(input, [h0, c0])
    """
    def __init__(self, hidden_size, num_layers=1, layout='TNC',
                 dropout=0, bidirectional=False, input_size=0,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 projection_size=None, h2r_weight_initializer=None,
                 state_clip_min=None, state_clip_max=None, state_clip_nan=False,
                 dtype='float32', **kwargs):
        super(LSTM, self).__init__(hidden_size, num_layers, layout,
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
                                   'lstm', projection_size, h2r_weight_initializer,
                                   state_clip_min, state_clip_max, state_clip_nan,
                                   dtype, **kwargs)

    def state_info(self, batch_size=0):
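        # With projection, the recurrent hidden state carries projection_size
        # features while the cell state keeps hidden_size features.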
        if self._projection_size is None:
            return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
                     '__layout__': 'LNC', 'dtype': self._dtype},
                    {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
                     '__layout__': 'LNC', 'dtype': self._dtype}]
        else:
            return [{'shape': (self._num_layers * self._dir, batch_size, self._projection_size),
                     '__layout__': 'LNC', 'dtype': self._dtype},
                    {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
                     '__layout__': 'LNC', 'dtype': self._dtype}]


class GRU(_RNNLayer):
    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.

    Note: this is an implementation of the cuDNN version of GRUs
    (slight modification compared to Cho et al. 2014; the reset gate :math:`r_t`
    is applied after matrix multiplication).

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
        r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
        i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
        n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)} + b_{hn})) \\
        h_t = (1 - i_t) * n_t + i_t * h_{(t-1)}
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden
    state of the previous layer at time `t` or :math:`input_t` for the first layer,
    and :math:`r_t`, :math:`i_t`, :math:`n_t` are the reset, input, and new gates, respectively.

    Parameters
    ----------
    hidden_size: int
        The number of features in the hidden state h.
    num_layers: int, default 1
        Number of recurrent layers.
    layout : str, default 'TNC'
        The format of input and output tensors. T, N and C stand for
        sequence length, batch size, and feature dimensions respectively.
    dropout: float, default 0
        If non-zero, introduces a dropout layer on the outputs of each
        RNN layer except the last layer.
    bidirectional: bool, default False
        If True, becomes a bidirectional RNN.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    dtype : str, default 'float32'
        Type to initialize the parameters and default states to.
    input_size: int, default 0
        The number of expected features in the input x.
        If not specified, it will be inferred from the input.
    prefix : str or None
        Prefix of this `Block`.
    params : ParameterDict or None
        Shared Parameters for this `Block`.


    Inputs:
        - **data**: input tensor with shape `(sequence_length, batch_size, input_size)`
          when `layout` is "TNC". For other layouts, dimensions are permuted accordingly
          using the transpose() operator, which adds performance overhead. Consider
          creating batches in TNC layout during the data batching step.
        - **states**: initial recurrent state tensor with shape
          `(num_layers, batch_size, num_hidden)`. If `bidirectional` is True,
          shape will instead be `(2*num_layers, batch_size, num_hidden)`. If
          `states` is None, zeros will be used as default begin states.

    Outputs:
        - **out**: output tensor with shape `(sequence_length, batch_size, num_hidden)`
          when `layout` is "TNC". If `bidirectional` is True, output shape will instead
          be `(sequence_length, batch_size, 2*num_hidden)`.
        - **out_states**: output recurrent state tensor with the same shape as `states`.
          If `states` is None, `out_states` will not be returned.


    Examples
    --------
    >>> layer = mx.gluon.rnn.GRU(100, 3)
    >>> layer.initialize()
    >>> input = mx.nd.random.uniform(shape=(5, 3, 10))
    >>> # by default zeros are used as begin state
    >>> output = layer(input)
    >>> # manually specify begin state.
    >>> h0 = mx.nd.random.uniform(shape=(3, 3, 100))
    >>> output, hn = layer(input, h0)
    """
    def __init__(self, hidden_size, num_layers=1, layout='TNC',
                 dropout=0, bidirectional=False, input_size=0,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 dtype='float32', **kwargs):
        super(GRU, self).__init__(hidden_size, num_layers, layout,
                                  dropout, bidirectional, input_size,
                                  i2h_weight_initializer, h2h_weight_initializer,
                                  i2h_bias_initializer, h2h_bias_initializer,
                                  'gru', None, None, None, None, False,
                                  dtype, **kwargs)

    def state_info(self, batch_size=0):
        return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
                 '__layout__': 'LNC', 'dtype': self._dtype}]