# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=no-member, invalid-name, protected-access, no-self-use
# pylint: disable=too-many-branches, too-many-arguments
# pylint: disable=too-many-lines, arguments-differ
"""Definition of various recurrent neural network cells."""
__all__ = ['RecurrentCell', 'HybridRecurrentCell',
           'RNNCell', 'LSTMCell', 'GRUCell',
           'SequentialRNNCell', 'HybridSequentialRNNCell', 'DropoutCell',
           'ModifierCell', 'ZoneoutCell', 'ResidualCell',
           'BidirectionalCell']

from ... import symbol, ndarray
from ...base import string_types, numeric_types, _as_list
from ..block import Block, HybridBlock
from ..utils import _indent
from .. import tensor_types
from ..nn import LeakyReLU


def _cells_state_info(cells, batch_size):
    return sum([c.state_info(batch_size) for c in cells], [])

def _cells_begin_state(cells, **kwargs):
    return sum([c.begin_state(**kwargs) for c in cells], [])

def _get_begin_state(cell, F, begin_state, inputs, batch_size):
    if begin_state is None:
        if F is ndarray:
            ctx = inputs.context if isinstance(inputs, tensor_types) else inputs[0].context
            with ctx:
                begin_state = cell.begin_state(func=F.zeros, batch_size=batch_size)
        else:
            begin_state = cell.begin_state(func=F.zeros, batch_size=batch_size)
    return begin_state

def _format_sequence(length, inputs, layout, merge, in_layout=None):
    assert inputs is not None, \
        "unroll(inputs=None) has been deprecated. " \
        "Please create input variables outside unroll."

    axis = layout.find('T')
    batch_axis = layout.find('N')
    batch_size = 0
    in_axis = in_layout.find('T') if in_layout is not None else axis
    if isinstance(inputs, symbol.Symbol):
        F = symbol
        if merge is False:
            assert len(inputs.list_outputs()) == 1, \
                "unroll doesn't allow grouped symbol as input. Please convert " \
                "to list with list(inputs) first or let unroll handle splitting."
            inputs = list(symbol.split(inputs, axis=in_axis, num_outputs=length,
                                       squeeze_axis=1))
    elif isinstance(inputs, ndarray.NDArray):
        F = ndarray
        batch_size = inputs.shape[batch_axis]
        if merge is False:
            assert length is None or length == inputs.shape[in_axis]
            inputs = _as_list(ndarray.split(inputs, axis=in_axis,
                                            num_outputs=inputs.shape[in_axis],
                                            squeeze_axis=1))
    else:
        assert length is None or len(inputs) == length
        if isinstance(inputs[0], symbol.Symbol):
            F = symbol
        else:
            F = ndarray
            batch_size = inputs[0].shape[0]
        if merge is True:
            inputs = F.stack(*inputs, axis=axis)
            in_axis = axis

    if isinstance(inputs, tensor_types) and axis != in_axis:
        inputs = F.swapaxes(inputs, dim1=axis, dim2=in_axis)

    return inputs, axis, F, batch_size

def _mask_sequence_variable_length(F, data, length, valid_length, time_axis, merge):
    assert valid_length is not None
    if not isinstance(data, tensor_types):
        data = F.stack(*data, axis=time_axis)
    outputs = F.SequenceMask(data, sequence_length=valid_length, use_sequence_length=True,
                             axis=time_axis)
    if not merge:
        outputs = _as_list(F.split(outputs, num_outputs=length, axis=time_axis,
                                   squeeze_axis=True))
    return outputs

def _reverse_sequences(sequences, unroll_step, valid_length=None):
    if isinstance(sequences[0], symbol.Symbol):
        F = symbol
    else:
        F = ndarray

    if valid_length is None:
        reversed_sequences = list(reversed(sequences))
    else:
        reversed_sequences = F.SequenceReverse(F.stack(*sequences, axis=0),
                                               sequence_length=valid_length,
                                               use_sequence_length=True)
        if unroll_step > 1 or F is symbol:
            reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)
        else:
            reversed_sequences = [reversed_sequences[0]]

    return reversed_sequences


class RecurrentCell(Block):
    """Abstract base class for RNN cells

    Parameters
    ----------
    prefix : str, optional
        Prefix for names of `Block`s
        (this prefix is also used for names of weights if `params` is `None`
        i.e. if `params` are being created and not reused)
    params : Parameter or None, default None
        Container for weight sharing between cells.
        A new Parameter container is created if `params` is `None`.
    """
    def __init__(self, prefix=None, params=None):
        super(RecurrentCell, self).__init__(prefix=prefix, params=params)
        self._modified = False
        self.reset()

    def reset(self):
        """Reset before re-using the cell for another graph."""
        self._init_counter = -1
        self._counter = -1
        for cell in self._children.values():
            cell.reset()

    def state_info(self, batch_size=0):
        """Shape and layout information of states."""
        raise NotImplementedError()

    def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
        """Initial state for this cell.

        Parameters
        ----------
        func : callable, default ndarray.zeros
            Function for creating initial state.

            For Symbol API, func can be `symbol.zeros`, `symbol.uniform`,
            `symbol.var`, etc. Use `symbol.var` if you want to directly
            feed input as states.

            For NDArray API, func can be `ndarray.zeros`, `ndarray.ones`, etc.
        batch_size: int, default 0
            Only required for NDArray API. Size of the batch ('N' in layout)
            dimension of input.

        **kwargs :
            Additional keyword arguments passed to func. For example
            `mean`, `std`, `dtype`, etc.

        Returns
        -------
        states : nested list of Symbol
            Starting states for the first RNN step.
        """
        assert not self._modified, \
            "After applying modifier cells (e.g. ZoneoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        states = []
        for info in self.state_info(batch_size):
            self._init_counter += 1
            if info is not None:
                info.update(kwargs)
            else:
                info = kwargs
            state = func(name='%sbegin_state_%d'%(self._prefix, self._init_counter),
                         **info)
            states.append(state)
        return states

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        """Unrolls an RNN cell across time steps.

        Parameters
        ----------
        length : int
            Number of steps to unroll.
        inputs : Symbol or list of Symbol
            If `inputs` is a single Symbol (usually the output
            of Embedding symbol), it should have shape
            (batch_size, length, ...) if `layout` is 'NTC',
            or (length, batch_size, ...) if `layout` is 'TNC'.

            If `inputs` is a list of symbols (usually output of
            previous unroll), they should all have shape
            (batch_size, ...).
        begin_state : nested list of Symbol, optional
            Input states created by `begin_state()`
            or output state of another cell.
            Created from `begin_state()` if `None`.
        layout : str, optional
            `layout` of input symbol. Only used if inputs
            is a single Symbol.
        merge_outputs : bool, optional
            If `False`, returns outputs as a list of Symbols.
            If `True`, concatenates output across time steps
            and returns a single symbol with shape
            (batch_size, length, ...) if layout is 'NTC',
            or (length, batch_size, ...) if layout is 'TNC'.
            If `None`, outputs whichever format is faster.
        valid_length : Symbol, NDArray or None
            `valid_length` specifies the length of the sequences in the batch without padding.
            This option is especially useful for building sequence-to-sequence models where
            the input and output sequences would potentially be padded.
            If `valid_length` is None, all sequences are assumed to have the same length.
            If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
            The ith element will be the length of the ith sequence in the batch.
            The last valid state will be returned and the padded outputs will be masked with 0.
            Note that `valid_length` must be smaller than or equal to `length`.

        Returns
        -------
        outputs : list of Symbol or Symbol
            Symbol (if `merge_outputs` is True) or list of Symbols
            (if `merge_outputs` is False) corresponding to the output from
            the RNN from this unrolling.

        states : list of Symbol
            The new state of this RNN after this unrolling.
            The type of this symbol is the same as the output of `begin_state()`.
        """
        # pylint: disable=too-many-locals
        self.reset()

        inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        states = begin_state
        outputs = []
        all_states = []
        for i in range(length):
            output, states = self(inputs[i], states)
            outputs.append(output)
            if valid_length is not None:
                all_states.append(states)
        if valid_length is not None:
            states = [F.SequenceLast(F.stack(*ele_list, axis=0),
                                     sequence_length=valid_length,
                                     use_sequence_length=True,
                                     axis=0)
                      for ele_list in zip(*all_states)]
            outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis, True)
        outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)

        return outputs, states

    #pylint: disable=no-self-use
    def _get_activation(self, F, inputs, activation, **kwargs):
        """Get activation function. Converts `activation` to a function if it is a string."""
        func = {'tanh': F.tanh,
                'relu': F.relu,
                'sigmoid': F.sigmoid,
                'softsign': F.softsign}.get(activation)
        if func:
            return func(inputs, **kwargs)
        elif isinstance(activation, string_types):
            return F.Activation(inputs, act_type=activation, **kwargs)
        elif isinstance(activation, LeakyReLU):
            return F.LeakyReLU(inputs, act_type='leaky', slope=activation._alpha, **kwargs)
        return activation(inputs, **kwargs)

    def forward(self, inputs, states):
        """Unrolls the recurrent cell for one time step.

        Parameters
        ----------
        inputs : sym.Variable
            Input symbol, 2D, of shape (batch_size, num_units).
        states : list of sym.Variable
            RNN state from previous step or the output of begin_state().

        Returns
        -------
        output : Symbol
            Symbol corresponding to the output from the RNN when unrolling
            for a single time step.
        states : list of Symbol
            The new state of this RNN after this unrolling.
            The type of this symbol is the same as the output of `begin_state()`.
            This can be used as an input state to the next time step
            of this RNN.

        See Also
        --------
        begin_state: This function can provide the states for the first time step.
        unroll: This function unrolls an RNN for a given number of (>=1) time steps.
        """
        # pylint: disable= arguments-differ
        self._counter += 1
        return super(RecurrentCell, self).forward(inputs, states)
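
# An illustrative sketch of the RecurrentCell API (not part of the library code;
# it assumes a concrete cell such as the RNNCell defined below and NDArray inputs):
#
#   cell = RNNCell(hidden_size=100)
#   cell.initialize()
#   states = cell.begin_state(batch_size=32)                  # list of (32, 100) zeros
#   out, states = cell(ndarray.ones((32, 50)), states)        # a single step
#   seq = ndarray.ones((32, 10, 50))                          # layout 'NTC'
#   outs, states = cell.unroll(10, seq, merge_outputs=True)   # outs: (32, 10, 100)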


class HybridRecurrentCell(RecurrentCell, HybridBlock):
    """HybridRecurrentCell supports hybridize."""
    def __init__(self, prefix=None, params=None):
        super(HybridRecurrentCell, self).__init__(prefix=prefix, params=params)

    def hybrid_forward(self, F, x, *args, **kwargs):
        raise NotImplementedError


class RNNCell(HybridRecurrentCell):
    r"""Elman RNN recurrent neural network cell.

    Each call computes the following function:

    .. math::

        h_t = \tanh(w_{ih} * x_t + b_{ih}  +  w_{hh} * h_{(t-1)} + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is the hidden
    state of the previous layer at time `t` or :math:`input_t` for the first layer.
    If `activation` is 'relu', then ReLU is used instead of tanh.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol.
    activation : str or Symbol, default 'tanh'
        Type of activation function.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    prefix : str, default ``'rnn_'``
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.


    Inputs:
        - **data**: input tensor with shape `(batch_size, input_size)`.
        - **states**: a list of one initial recurrent state tensor with shape
          `(batch_size, num_hidden)`.

    Outputs:
        - **out**: output tensor with shape `(batch_size, num_hidden)`.
        - **next_states**: a list of one output recurrent state tensor with the
          same shape as `states`.
    """
    def __init__(self, hidden_size, activation='tanh',
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None):
        super(RNNCell, self).__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_size
        self._activation = activation
        self._input_size = input_size
        self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
                                          init=i2h_weight_initializer,
                                          allow_deferred_init=True)
        self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
                                          init=h2h_weight_initializer,
                                          allow_deferred_init=True)
        self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
                                        init=i2h_bias_initializer,
                                        allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)

    def state_info(self, batch_size=0):
        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]

    def _alias(self):
        return 'rnn'

    def __repr__(self):
        s = '{name}({mapping}'
        if hasattr(self, '_activation'):
            s += ', {_activation}'
        s += ')'
        shape = self.i2h_weight.shape
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        prefix = 't%d_'%self._counter
        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
                               num_hidden=self._hidden_size,
                               name=prefix+'i2h')
        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                               num_hidden=self._hidden_size,
                               name=prefix+'h2h')
        i2h_plus_h2h = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
        output = self._get_activation(F, i2h_plus_h2h, self._activation,
                                      name=prefix+'out')

        return output, [output]
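
# Illustrative sketch (assumed shapes, not part of the library code): a single
# RNNCell step with a ReLU activation instead of the default tanh.
#
#   cell = RNNCell(hidden_size=64, activation='relu')
#   cell.initialize()
#   h0 = cell.begin_state(batch_size=8)
#   y, h1 = cell(ndarray.random.uniform(shape=(8, 32)), h0)
#   # y.shape == (8, 64); h1 == [y]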


class LSTMCell(HybridRecurrentCell):
    r"""Long Short-Term Memory (LSTM) network cell.

    Each call computes the following function:

    .. math::
        \begin{array}{ll}
        i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
        f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
        o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
        c_t = f_t * c_{(t-1)} + i_t * g_t \\
        h_t = o_t * \tanh(c_t)
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the
    cell state at time `t`, :math:`x_t` is the hidden state of the previous
    layer at time `t` or :math:`input_t` for the first layer, and :math:`i_t`,
    :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and
    output gates, respectively.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    prefix : str, default ``'lstm_'``
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None, default None
        Container for weight sharing between cells.
        Created if `None`.
    activation : str, default 'tanh'
        Activation type to use. See nd/symbol Activation
        for supported types.
    recurrent_activation : str, default 'sigmoid'
        Activation type to use for the recurrent step. See nd/symbol Activation
        for supported types.

    Inputs:
        - **data**: input tensor with shape `(batch_size, input_size)`.
        - **states**: a list of two initial recurrent state tensors. Each has shape
          `(batch_size, num_hidden)`.

    Outputs:
        - **out**: output tensor with shape `(batch_size, num_hidden)`.
        - **next_states**: a list of two output recurrent state tensors. Each has
          the same shape as `states`.
    """
    # pylint: disable=too-many-instance-attributes
    def __init__(self, hidden_size,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None, activation='tanh',
                 recurrent_activation='sigmoid'):
        super(LSTMCell, self).__init__(prefix=prefix, params=params)

        self._hidden_size = hidden_size
        self._input_size = input_size
        self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
                                          init=i2h_weight_initializer,
                                          allow_deferred_init=True)
        self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
                                          init=h2h_weight_initializer,
                                          allow_deferred_init=True)
        self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
                                        init=i2h_bias_initializer,
                                        allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)
        self._activation = activation
        self._recurrent_activation = recurrent_activation


    def state_info(self, batch_size=0):
        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'},
                {'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]

    def _alias(self):
        return 'lstm'

    def __repr__(self):
        s = '{name}({mapping})'
        shape = self.i2h_weight.shape
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        # pylint: disable=too-many-locals
        prefix = 't%d_'%self._counter
        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
                               num_hidden=self._hidden_size*4, name=prefix+'i2h')
        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                               num_hidden=self._hidden_size*4, name=prefix+'h2h')
        gates = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
        in_gate = self._get_activation(
            F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
        forget_gate = self._get_activation(
            F, slice_gates[1], self._recurrent_activation, name=prefix+'f')
        in_transform = self._get_activation(
            F, slice_gates[2], self._activation, name=prefix+'c')
        out_gate = self._get_activation(
            F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
        next_c = F.elemwise_add(F.elemwise_mul(forget_gate, states[1], name=prefix+'mul0'),
                                F.elemwise_mul(in_gate, in_transform, name=prefix+'mul1'),
                                name=prefix+'state')
        next_h = F.elemwise_mul(out_gate, F.Activation(next_c, act_type=self._activation, name=prefix+'activation0'),
                                name=prefix+'out')

        return next_h, [next_h, next_c]
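
# Illustrative sketch (assumed shapes, not part of the library code): unrolling an
# LSTMCell over padded sequences and masking the padding with valid_length.
#
#   cell = LSTMCell(hidden_size=100)
#   cell.initialize()
#   x = ndarray.random.uniform(shape=(4, 10, 50))     # (N, T, C)
#   valid = ndarray.array([10, 7, 3, 9])              # per-sample lengths
#   outs, [h, c] = cell.unroll(10, x, merge_outputs=True, valid_length=valid)
#   # outs: (4, 10, 100) with padded steps masked to 0; h and c: (4, 100)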


class GRUCell(HybridRecurrentCell):
    r"""Gated Recurrent Unit (GRU) network cell.

    Note: this is an implementation of the cuDNN version of GRUs
    (slight modification compared to Cho et al. 2014; the reset gate :math:`r_t`
    is applied after matrix multiplication).

    Each call computes the following function:

    .. math::
        \begin{array}{ll}
        r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
        i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
        n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)} + b_{hn})) \\
        h_t = (1 - i_t) * n_t + i_t * h_{(t-1)} \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden
    state of the previous layer at time `t` or :math:`input_t` for the first layer,
    and :math:`r_t`, :math:`i_t`, :math:`n_t` are the reset, update, and new gates, respectively.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    prefix : str, default ``'gru_'``
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None, default None
        Container for weight sharing between cells.
        Created if `None`.


    Inputs:
        - **data**: input tensor with shape `(batch_size, input_size)`.
        - **states**: a list of one initial recurrent state tensor with shape
          `(batch_size, num_hidden)`.

    Outputs:
        - **out**: output tensor with shape `(batch_size, num_hidden)`.
        - **next_states**: a list of one output recurrent state tensor with the
          same shape as `states`.
    """
    def __init__(self, hidden_size,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None):
        super(GRUCell, self).__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_size
        self._input_size = input_size
        self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
                                          init=i2h_weight_initializer,
                                          allow_deferred_init=True)
        self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
                                          init=h2h_weight_initializer,
                                          allow_deferred_init=True)
        self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
                                        init=i2h_bias_initializer,
                                        allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)

    def state_info(self, batch_size=0):
        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]

    def _alias(self):
        return 'gru'

    def __repr__(self):
        s = '{name}({mapping})'
        shape = self.i2h_weight.shape
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        # pylint: disable=too-many-locals
        prefix = 't%d_'%self._counter
        prev_state_h = states[0]
        i2h = F.FullyConnected(data=inputs,
                               weight=i2h_weight,
                               bias=i2h_bias,
                               num_hidden=self._hidden_size * 3,
                               name=prefix+'i2h')
        h2h = F.FullyConnected(data=prev_state_h,
                               weight=h2h_weight,
                               bias=h2h_bias,
                               num_hidden=self._hidden_size * 3,
                               name=prefix+'h2h')

        i2h_r, i2h_z, i2h = F.SliceChannel(i2h, num_outputs=3,
                                           name=prefix+'i2h_slice')
        h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
                                           name=prefix+'h2h_slice')

        reset_gate = F.Activation(F.elemwise_add(i2h_r, h2h_r, name=prefix+'plus0'), act_type="sigmoid",
                                  name=prefix+'r_act')
        update_gate = F.Activation(F.elemwise_add(i2h_z, h2h_z, name=prefix+'plus1'), act_type="sigmoid",
                                   name=prefix+'z_act')

        next_h_tmp = F.Activation(F.elemwise_add(i2h,
                                                 F.elemwise_mul(reset_gate, h2h, name=prefix+'mul0'),
                                                 name=prefix+'plus2'),
                                  act_type="tanh",
                                  name=prefix+'h_act')

        ones = F.ones_like(update_gate, name=prefix+"ones_like0")
        next_h = F.elemwise_add(F.elemwise_mul(F.elemwise_sub(ones, update_gate, name=prefix+'minus0'),
                                               next_h_tmp,
                                               name=prefix+'mul1'),
                                F.elemwise_mul(update_gate, prev_state_h, name=prefix+'mul20'),
                                name=prefix+'out')

        return next_h, [next_h]
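
# Illustrative sketch (assumed shapes, not part of the library code): GRUCell is
# used like RNNCell/LSTMCell but carries a single hidden state.
#
#   cell = GRUCell(hidden_size=100)
#   cell.initialize()
#   outs, states = cell.unroll(10, ndarray.ones((2, 10, 50)), merge_outputs=False)
#   # outs is a list of 10 arrays of shape (2, 100); states == [h] with h: (2, 100)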


class SequentialRNNCell(RecurrentCell):
    """Sequentially stacking multiple RNN cells."""
    def __init__(self, prefix=None, params=None):
        super(SequentialRNNCell, self).__init__(prefix=prefix, params=params)

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        return s.format(name=self.__class__.__name__,
                        modstr='\n'.join(['({i}): {m}'.format(i=i, m=_indent(m.__repr__(), 2))
                                          for i, m in self._children.items()]))

    def add(self, cell):
        """Appends a cell to the stack.

        Parameters
        ----------
        cell : RecurrentCell
            The cell to add.
        """
        self.register_child(cell)

    def state_info(self, batch_size=0):
        return _cells_state_info(self._children.values(), batch_size)

    def begin_state(self, **kwargs):
        assert not self._modified, \
            "After applying modifier cells (e.g. ZoneoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        return _cells_begin_state(self._children.values(), **kwargs)

    def __call__(self, inputs, states):
        self._counter += 1
        next_states = []
        p = 0
        assert all(not isinstance(cell, BidirectionalCell) for cell in self._children.values())
        for cell in self._children.values():
            assert not isinstance(cell, BidirectionalCell)
            n = len(cell.state_info())
            state = states[p:p+n]
            p += n
            inputs, state = cell(inputs, state)
            next_states.append(state)
        return inputs, sum(next_states, [])

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        # pylint: disable=too-many-locals
        self.reset()

        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
        num_cells = len(self._children)
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        p = 0
        next_states = []
        for i, cell in enumerate(self._children.values()):
            n = len(cell.state_info())
            states = begin_state[p:p+n]
            p += n
            inputs, states = cell.unroll(length, inputs=inputs, begin_state=states,
                                         layout=layout,
                                         merge_outputs=None if i < num_cells-1 else merge_outputs,
                                         valid_length=valid_length)
            next_states.extend(states)

        return inputs, next_states

    def __getitem__(self, i):
        return self._children[str(i)]

    def __len__(self):
        return len(self._children)

    def hybrid_forward(self, *args, **kwargs):
        # pylint: disable=missing-docstring
        raise NotImplementedError
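
# Illustrative sketch (arbitrary layer sizes, not part of the library code):
# stacking two LSTM layers with a SequentialRNNCell.
#
#   stack = SequentialRNNCell()
#   stack.add(LSTMCell(100))
#   stack.add(LSTMCell(100))
#   stack.initialize()
#   outs, states = stack.unroll(10, ndarray.ones((2, 10, 50)), merge_outputs=True)
#   # states holds [h0, c0, h1, c1], the final states of both layers in order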


class HybridSequentialRNNCell(HybridRecurrentCell):
    """Sequentially stacking multiple HybridRNN cells."""
    def __init__(self, prefix=None, params=None):
        super(HybridSequentialRNNCell, self).__init__(prefix=prefix, params=params)

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        return s.format(name=self.__class__.__name__,
                        modstr='\n'.join(['({i}): {m}'.format(i=i, m=_indent(m.__repr__(), 2))
                                          for i, m in self._children.items()]))

    def add(self, cell):
        """Appends a cell to the stack.

        Parameters
        ----------
        cell : RecurrentCell
            The cell to add.
        """
        self.register_child(cell)

    def state_info(self, batch_size=0):
        return _cells_state_info(self._children.values(), batch_size)

    def begin_state(self, **kwargs):
        assert not self._modified, \
            "After applying modifier cells (e.g. ZoneoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        return _cells_begin_state(self._children.values(), **kwargs)

    def __call__(self, inputs, states):
        self._counter += 1
        next_states = []
        p = 0
        assert all(not isinstance(cell, BidirectionalCell) for cell in self._children.values())
        for cell in self._children.values():
            n = len(cell.state_info())
            state = states[p:p+n]
            p += n
            inputs, state = cell(inputs, state)
            next_states.append(state)
        return inputs, sum(next_states, [])

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        self.reset()

        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
        num_cells = len(self._children)
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        p = 0
        next_states = []
        for i, cell in enumerate(self._children.values()):
            n = len(cell.state_info())
            states = begin_state[p:p+n]
            p += n
            inputs, states = cell.unroll(length, inputs=inputs, begin_state=states,
                                         layout=layout,
                                         merge_outputs=None if i < num_cells-1 else merge_outputs,
                                         valid_length=valid_length)
            next_states.extend(states)

        return inputs, next_states

    def __getitem__(self, i):
        return self._children[str(i)]

    def __len__(self):
        return len(self._children)

    def hybrid_forward(self, F, inputs, states):
        return self.__call__(inputs, states)
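
# Illustrative sketch (assumed shapes, not part of the library code): the hybrid
# variant can be hybridized so that per-step computation of each child cell is
# compiled into a graph.
#
#   stack = HybridSequentialRNNCell()
#   stack.add(GRUCell(100))
#   stack.add(GRUCell(100))
#   stack.initialize()
#   stack.hybridize()
#   out, states = stack(ndarray.ones((2, 50)), stack.begin_state(batch_size=2))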


class DropoutCell(HybridRecurrentCell):
    """Applies dropout on input.

    Parameters
    ----------
    rate : float
        Fraction of the input elements to drop out, which
        is 1 - the fraction to retain.
    axes : tuple of int, default ()
        The axes on which dropout mask is shared. If empty, regular dropout is applied.


    Inputs:
        - **data**: input tensor with shape `(batch_size, size)`.
        - **states**: a list of recurrent state tensors.

    Outputs:
        - **out**: output tensor with shape `(batch_size, size)`.
        - **next_states**: returns input `states` directly.
    """
    def __init__(self, rate, axes=(), prefix=None, params=None):
        super(DropoutCell, self).__init__(prefix, params)
        assert isinstance(rate, numeric_types), "rate must be a number"
        self._rate = rate
        self._axes = axes

    def __repr__(self):
        s = '{name}(rate={_rate}, axes={_axes})'
        return s.format(name=self.__class__.__name__,
                        **self.__dict__)

    def state_info(self, batch_size=0):
        return []

    def _alias(self):
        return 'dropout'

    def hybrid_forward(self, F, inputs, states):
        if self._rate > 0:
            inputs = F.Dropout(data=inputs, p=self._rate, axes=self._axes,
                               name='t%d_fwd'%self._counter)
        return inputs, states

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        self.reset()

        inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
        if isinstance(inputs, tensor_types):
            return self.hybrid_forward(F, inputs, begin_state if begin_state else [])
        return super(DropoutCell, self).unroll(
            length, inputs, begin_state=begin_state, layout=layout,
            merge_outputs=merge_outputs, valid_length=None)
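
# Illustrative sketch (not part of the library code): DropoutCell is typically
# interleaved with recurrent cells inside a sequential stack.
#
#   stack = SequentialRNNCell()
#   stack.add(LSTMCell(100))
#   stack.add(DropoutCell(0.5))
#   stack.add(LSTMCell(100))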


class ModifierCell(HybridRecurrentCell):
    """Base class for modifier cells. A modifier
    cell takes a base cell, applies modifications
    to it (e.g. Zoneout), and returns a new cell.

    After applying modifiers, the base cell should
    no longer be called directly. The modifier cell
    should be used instead.
    """
    def __init__(self, base_cell):
        assert not base_cell._modified, \
            "Cell %s is already modified. One cell cannot be modified twice"%base_cell.name
        base_cell._modified = True
        super(ModifierCell, self).__init__(prefix=base_cell.prefix+self._alias(),
                                           params=None)
        self.base_cell = base_cell

    @property
    def params(self):
        return self.base_cell.params

    def state_info(self, batch_size=0):
        return self.base_cell.state_info(batch_size)

    def begin_state(self, func=symbol.zeros, **kwargs):
        assert not self._modified, \
            "After applying modifier cells (e.g. DropoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        self.base_cell._modified = False
        begin = self.base_cell.begin_state(func=func, **kwargs)
        self.base_cell._modified = True
        return begin

    def hybrid_forward(self, F, inputs, states):
        raise NotImplementedError

    def __repr__(self):
        s = '{name}({base_cell})'
        return s.format(name=self.__class__.__name__,
                        **self.__dict__)


class ZoneoutCell(ModifierCell):
    """Applies Zoneout on base cell."""
    def __init__(self, base_cell, zoneout_outputs=0., zoneout_states=0.):
        assert not isinstance(base_cell, BidirectionalCell), \
            "BidirectionalCell doesn't support zoneout since it doesn't support step. " \
            "Please add ZoneoutCell to the cells underneath instead."
        assert not isinstance(base_cell, SequentialRNNCell) or not base_cell._bidirectional, \
            "Bidirectional SequentialRNNCell doesn't support zoneout. " \
            "Please add ZoneoutCell to the cells underneath instead."
        super(ZoneoutCell, self).__init__(base_cell)
        self.zoneout_outputs = zoneout_outputs
        self.zoneout_states = zoneout_states
        self._prev_output = None

    def __repr__(self):
        s = '{name}(p_out={zoneout_outputs}, p_state={zoneout_states}, {base_cell})'
        return s.format(name=self.__class__.__name__,
                        **self.__dict__)

    def _alias(self):
        return 'zoneout'

    def reset(self):
        super(ZoneoutCell, self).reset()
        self._prev_output = None

    def hybrid_forward(self, F, inputs, states):
        cell, p_outputs, p_states = self.base_cell, self.zoneout_outputs, self.zoneout_states
        next_output, next_states = cell(inputs, states)
        mask = (lambda p, like: F.Dropout(F.ones_like(like), p=p))

        prev_output = self._prev_output
        if prev_output is None:
            prev_output = F.zeros_like(next_output)

        output = (F.where(mask(p_outputs, next_output), next_output, prev_output)
                  if p_outputs != 0. else next_output)
        states = ([F.where(mask(p_states, new_s), new_s, old_s) for new_s, old_s in
                   zip(next_states, states)] if p_states != 0. else next_states)

        self._prev_output = output

        return output, states
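
# Illustrative sketch (arbitrary probabilities, not part of the library code):
# wrapping a cell with zoneout on both its outputs and its states.
#
#   cell = ZoneoutCell(LSTMCell(100), zoneout_outputs=0.1, zoneout_states=0.05)
#   cell.initialize()
#   outs, states = cell.unroll(10, ndarray.ones((2, 10, 50)), merge_outputs=True)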


class ResidualCell(ModifierCell):
    """
    Adds a residual connection as described in Wu et al., 2016
    (https://arxiv.org/abs/1609.08144).
    Output of the cell is the output of the base cell plus the input.
    """

    def __init__(self, base_cell):
        # pylint: disable=useless-super-delegation
        super(ResidualCell, self).__init__(base_cell)

    def hybrid_forward(self, F, inputs, states):
        output, states = self.base_cell(inputs, states)
        output = F.elemwise_add(output, inputs, name='t%d_fwd'%self._counter)
        return output, states

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        self.reset()

        self.base_cell._modified = False
        outputs, states = self.base_cell.unroll(length, inputs=inputs, begin_state=begin_state,
                                                layout=layout, merge_outputs=merge_outputs,
                                                valid_length=valid_length)
        self.base_cell._modified = True

        merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is None else \
                        merge_outputs
        inputs, axis, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
        if valid_length is not None:
            # mask the padded inputs to zero
            inputs = _mask_sequence_variable_length(F, inputs, length, valid_length, axis,
                                                    merge_outputs)
        if merge_outputs:
            outputs = F.elemwise_add(outputs, inputs)
        else:
            outputs = [F.elemwise_add(i, j) for i, j in zip(outputs, inputs)]

        return outputs, states
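
# Illustrative sketch (not part of the library code): a residual connection
# requires the base cell's output size to match its input size (100 below).
#
#   cell = ResidualCell(GRUCell(100))
#   cell.initialize()
#   outs, states = cell.unroll(10, ndarray.ones((2, 10, 100)), merge_outputs=True)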


class BidirectionalCell(HybridRecurrentCell):
    """Bidirectional RNN cell.

    Parameters
    ----------
    l_cell : RecurrentCell
        Cell for forward unrolling.
    r_cell : RecurrentCell
        Cell for backward unrolling.
    """
    def __init__(self, l_cell, r_cell, output_prefix='bi_'):
        super(BidirectionalCell, self).__init__(prefix='', params=None)
        self.register_child(l_cell, 'l_cell')
        self.register_child(r_cell, 'r_cell')
        self._output_prefix = output_prefix

    def __call__(self, inputs, states):
        raise NotImplementedError("BidirectionalCell cannot be stepped. Please use unroll instead.")

    def __repr__(self):
        s = '{name}(forward={l_cell}, backward={r_cell})'
        return s.format(name=self.__class__.__name__,
                        l_cell=self._children['l_cell'],
                        r_cell=self._children['r_cell'])

    def state_info(self, batch_size=0):
        return _cells_state_info(self._children.values(), batch_size)

    def begin_state(self, **kwargs):
        assert not self._modified, \
            "After applying modifier cells (e.g. DropoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        return _cells_begin_state(self._children.values(), **kwargs)

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        # pylint: disable=too-many-locals
        self.reset()

        inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
        reversed_inputs = list(_reverse_sequences(inputs, length, valid_length))
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        states = begin_state
        l_cell, r_cell = self._children.values()
        l_outputs, l_states = l_cell.unroll(length, inputs=inputs,
                                            begin_state=states[:len(l_cell.state_info(batch_size))],
                                            layout=layout, merge_outputs=merge_outputs,
                                            valid_length=valid_length)
        r_outputs, r_states = r_cell.unroll(length,
                                            inputs=reversed_inputs,
                                            begin_state=states[len(l_cell.state_info(batch_size)):],
                                            layout=layout, merge_outputs=False,
                                            valid_length=valid_length)
        reversed_r_outputs = _reverse_sequences(r_outputs, length, valid_length)

        if merge_outputs is None:
            merge_outputs = isinstance(l_outputs, tensor_types)
            l_outputs, _, _, _ = _format_sequence(None, l_outputs, layout, merge_outputs)
            reversed_r_outputs, _, _, _ = _format_sequence(None, reversed_r_outputs, layout,
                                                           merge_outputs)

        if merge_outputs:
            reversed_r_outputs = F.stack(*reversed_r_outputs, axis=axis)
            outputs = F.concat(l_outputs, reversed_r_outputs, dim=2,
                               name='%sout'%self._output_prefix)

        else:
            outputs = [F.concat(l_o, r_o, dim=1, name='%st%d'%(self._output_prefix, i))
                       for i, (l_o, r_o) in enumerate(zip(l_outputs, reversed_r_outputs))]
        if valid_length is not None:
            outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis,
                                                     merge_outputs)
        states = l_states + r_states
        return outputs, states
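
# Illustrative sketch (assumed shapes, not part of the library code): the
# bidirectional wrapper concatenates forward and backward outputs along the
# channel axis, so the output width doubles.
#
#   bi = BidirectionalCell(LSTMCell(100), LSTMCell(100))
#   bi.initialize()
#   outs, states = bi.unroll(10, ndarray.ones((2, 10, 50)), merge_outputs=True)
#   # outs: (2, 10, 200); states == l_states + r_states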