1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Bidirectional LM encoder."""
19__all__ = ['BiLMEncoder']
20
21from mxnet import gluon
22from mxnet.gluon import rnn
23from .utils import _get_rnn_cell
24
25
26class BiLMEncoder(gluon.HybridBlock):
27    r"""Bidirectional LM encoder.
28
29    We implement the encoder of the biLM proposed in the following work::
30
31        @inproceedings{Peters:2018,
32        author={Peters, Matthew E. and  Neumann, Mark and Iyyer, Mohit and Gardner, Matt and Clark,
33        Christopher and Lee, Kenton and Zettlemoyer, Luke},
34        title={Deep contextualized word representations},
35        booktitle={Proc. of NAACL},
36        year={2018}
37        }
38
39    Parameters
40    ----------
41    mode : str
42        The type of RNN cell to use. Options are 'lstmpc', 'rnn_tanh', 'rnn_relu', 'lstm', 'gru'.
43    num_layers : int
44        The number of RNN cells in the encoder.
45    input_size : int
46        The initial input size of in the RNN cell.
47    hidden_size : int
48        The hidden size of the RNN cell.
49    dropout : float
50        The dropout rate to use for encoder output.
51    skip_connection : bool
52        Whether to add skip connections (add RNN cell input to output)
53    proj_size : int
54        The projection size of each LSTMPCellWithClip cell
55    cell_clip : float
56        Clip cell state between [-cellclip, cell_clip] in LSTMPCellWithClip cell
57    proj_clip : float
58        Clip projection between [-projclip, projclip] in LSTMPCellWithClip cell
59    """
60
61    def __init__(self, mode, num_layers, input_size, hidden_size, dropout=0.0,
62                 skip_connection=True, proj_size=None, cell_clip=None, proj_clip=None, **kwargs):
63        super(BiLMEncoder, self).__init__(**kwargs)
64
65        self._mode = mode
66        self._num_layers = num_layers
67        self._input_size = input_size
68        self._hidden_size = hidden_size
69        self._dropout = dropout
70        self._skip_connection = skip_connection
71        self._proj_size = proj_size
72        self._cell_clip = cell_clip
73        self._proj_clip = proj_clip
74
75        with self.name_scope():
76            lstm_input_size = self._input_size
77            self.forward_layers = rnn.HybridSequentialRNNCell()
78            with self.forward_layers.name_scope():
79                for layer_index in range(self._num_layers):
80                    forward_layer = _get_rnn_cell(mode=self._mode,
81                                                  num_layers=1,
82                                                  input_size=lstm_input_size,
83                                                  hidden_size=self._hidden_size,
84                                                  dropout=0
85                                                  if layer_index == num_layers - 1
86                                                  else self._dropout,
87                                                  weight_dropout=0,
88                                                  var_drop_in=0,
89                                                  var_drop_state=0,
90                                                  var_drop_out=0,
91                                                  skip_connection=False
92                                                  if layer_index == 0
93                                                  else self._skip_connection,
94                                                  proj_size=self._proj_size,
95                                                  cell_clip=self._cell_clip,
96                                                  proj_clip=self._proj_clip)
97
98                    self.forward_layers.add(forward_layer)
99                    lstm_input_size = self._hidden_size \
100                        if self._proj_size is None else self._proj_size
101
102            lstm_input_size = self._input_size
103            self.backward_layers = rnn.HybridSequentialRNNCell()
104            with self.backward_layers.name_scope():
105                for layer_index in range(self._num_layers):
106                    backward_layer = _get_rnn_cell(mode=self._mode,
107                                                   num_layers=1,
108                                                   input_size=lstm_input_size,
109                                                   hidden_size=self._hidden_size,
110                                                   dropout=0
111                                                   if layer_index == num_layers - 1
112                                                   else self._dropout,
113                                                   weight_dropout=0,
114                                                   var_drop_in=0,
115                                                   var_drop_state=0,
116                                                   var_drop_out=0,
117                                                   skip_connection=False
118                                                   if layer_index == 0
119                                                   else self._skip_connection,
120                                                   proj_size=self._proj_size,
121                                                   cell_clip=self._cell_clip,
122                                                   proj_clip=self._proj_clip)
123                    self.backward_layers.add(backward_layer)
124                    lstm_input_size = self._hidden_size \
125                        if self._proj_size is None else self._proj_size
126
127    def begin_state(self, func, **kwargs):
128        return [self.forward_layers[0][0].begin_state(func=func, **kwargs)
129                for _ in range(self._num_layers)], \
130               [self.backward_layers[0][0].begin_state(func=func, **kwargs)
131                for _ in range(self._num_layers)]
132
133    def hybrid_forward(self, F, inputs, states=None, mask=None):
134        # pylint: disable=arguments-differ
135        # pylint: disable=unused-argument
136        """Defines the forward computation for cache cell. Arguments can be either
137        :py:class:`NDArray` or :py:class:`Symbol`.
138
139        Parameters
140        ----------
141        inputs : NDArray
142            The input data layout='TNC'.
143        states : Tuple[List[List[NDArray]]]
144            The states. including:
145            states[0] indicates the states used in forward layer,
146            Each layer has a list of two initial tensors with
147            shape (batch_size, proj_size) and (batch_size, hidden_size).
148            states[1] indicates the states used in backward layer,
149            Each layer has a list of two initial tensors with
150            shape (batch_size, proj_size) and (batch_size, hidden_size).
151
152        Returns
153        --------
154        out: NDArray
155            The output data with shape (num_layers, seq_len, batch_size, 2*input_size).
156        [states_forward, states_backward] : List
157            Including:
158            states_forward: The out states from forward layer,
159            which has the same structure with *states[0]*.
160            states_backward: The out states from backward layer,
161            which has the same structure with *states[1]*.
162        """
163        states_forward, states_backward = states
164        if mask is not None:
165            sequence_length = mask.sum(axis=1)
166
167        outputs_forward = []
168        outputs_backward = []
169
170        for layer_index in range(self._num_layers):
171            if layer_index == 0:
172                layer_inputs = inputs
173            else:
174                layer_inputs = outputs_forward[layer_index - 1]
175            output, states_forward[layer_index] = F.contrib.foreach(
176                self.forward_layers[layer_index],
177                layer_inputs,
178                states_forward[layer_index])
179            outputs_forward.append(output)
180
181            if layer_index == 0:
182                layer_inputs = inputs
183            else:
184                layer_inputs = outputs_backward[layer_index - 1]
185
186            if mask is not None:
187                layer_inputs = F.SequenceReverse(layer_inputs,
188                                                 sequence_length=sequence_length,
189                                                 use_sequence_length=True, axis=0)
190            else:
191                layer_inputs = F.SequenceReverse(layer_inputs, axis=0)
192            output, states_backward[layer_index] = F.contrib.foreach(
193                self.backward_layers[layer_index],
194                layer_inputs,
195                states_backward[layer_index])
196            if mask is not None:
197                backward_out = F.SequenceReverse(output,
198                                                 sequence_length=sequence_length,
199                                                 use_sequence_length=True, axis=0)
200            else:
201                backward_out = F.SequenceReverse(output, axis=0)
202            outputs_backward.append(backward_out)
203        out = F.concat(*[F.stack(*outputs_forward, axis=0),
204                         F.stack(*outputs_backward, axis=0)], dim=-1)
205
206        return out, [states_forward, states_backward]
207