# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Bidirectional LM encoder."""
__all__ = ['BiLMEncoder']

from mxnet import gluon
from mxnet.gluon import rnn
from .utils import _get_rnn_cell


class BiLMEncoder(gluon.HybridBlock):
    r"""Bidirectional LM encoder.

    We implement the encoder of the biLM proposed in the following work::

        @inproceedings{Peters:2018,
                author={Peters, Matthew E. and Neumann, Mark and Iyyer, Mohit and Gardner, Matt and Clark,
                Christopher and Lee, Kenton and Zettlemoyer, Luke},
                title={Deep contextualized word representations},
                booktitle={Proc. of NAACL},
                year={2018}
        }

    Parameters
    ----------
    mode : str
        The type of RNN cell to use. Options are 'lstmpc', 'rnn_tanh', 'rnn_relu', 'lstm', 'gru'.
    num_layers : int
        The number of RNN cells in the encoder.
    input_size : int
        The initial input size of the RNN cell.
    hidden_size : int
        The hidden size of the RNN cell.
    dropout : float
        The dropout rate to use for encoder output.
    skip_connection : bool
        Whether to add skip connections (add RNN cell input to output).
    proj_size : int
        The projection size of each LSTMPCellWithClip cell.
    cell_clip : float
        Clip cell state between [-cell_clip, cell_clip] in LSTMPCellWithClip cell.
    proj_clip : float
        Clip projection between [-proj_clip, proj_clip] in LSTMPCellWithClip cell.
    """

    def __init__(self, mode, num_layers, input_size, hidden_size, dropout=0.0,
                 skip_connection=True, proj_size=None, cell_clip=None, proj_clip=None, **kwargs):
        super(BiLMEncoder, self).__init__(**kwargs)

        self._mode = mode
        self._num_layers = num_layers
        self._input_size = input_size
        self._hidden_size = hidden_size
        self._dropout = dropout
        self._skip_connection = skip_connection
        self._proj_size = proj_size
        self._cell_clip = cell_clip
        self._proj_clip = proj_clip

        with self.name_scope():
            # Forward stack: no dropout after the last layer and no skip
            # connection into the first layer.
            lstm_input_size = self._input_size
            self.forward_layers = rnn.HybridSequentialRNNCell()
            with self.forward_layers.name_scope():
                for layer_index in range(self._num_layers):
                    forward_layer = _get_rnn_cell(mode=self._mode,
                                                  num_layers=1,
                                                  input_size=lstm_input_size,
                                                  hidden_size=self._hidden_size,
                                                  dropout=0
                                                  if layer_index == num_layers - 1
                                                  else self._dropout,
                                                  weight_dropout=0,
                                                  var_drop_in=0,
                                                  var_drop_state=0,
                                                  var_drop_out=0,
                                                  skip_connection=False
                                                  if layer_index == 0
                                                  else self._skip_connection,
                                                  proj_size=self._proj_size,
                                                  cell_clip=self._cell_clip,
                                                  proj_clip=self._proj_clip)

                    self.forward_layers.add(forward_layer)
                    lstm_input_size = self._hidden_size \
                        if self._proj_size is None else self._proj_size

            # Backward stack, mirroring the forward stack.
            lstm_input_size = self._input_size
            self.backward_layers = rnn.HybridSequentialRNNCell()
            with self.backward_layers.name_scope():
                for layer_index in range(self._num_layers):
                    backward_layer = _get_rnn_cell(mode=self._mode,
                                                   num_layers=1,
                                                   input_size=lstm_input_size,
                                                   hidden_size=self._hidden_size,
                                                   dropout=0
                                                   if layer_index == num_layers - 1
                                                   else self._dropout,
                                                   weight_dropout=0,
                                                   var_drop_in=0,
                                                   var_drop_state=0,
                                                   var_drop_out=0,
                                                   skip_connection=False
                                                   if layer_index == 0
                                                   else self._skip_connection,
                                                   proj_size=self._proj_size,
                                                   cell_clip=self._cell_clip,
                                                   proj_clip=self._proj_clip)
                    self.backward_layers.add(backward_layer)
                    lstm_input_size = self._hidden_size \
                        if self._proj_size is None else self._proj_size

    def begin_state(self, func, **kwargs):
        return ([self.forward_layers[0][0].begin_state(func=func, **kwargs)
                 for _ in range(self._num_layers)],
                [self.backward_layers[0][0].begin_state(func=func, **kwargs)
                 for _ in range(self._num_layers)])

    def hybrid_forward(self, F, inputs, states=None, mask=None):
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """Defines the forward computation of the BiLM encoder. Arguments can be either
        :py:class:`NDArray` or :py:class:`Symbol`.

        Parameters
        ----------
        inputs : NDArray
            The input data, with layout 'TNC'.
        states : Tuple[List[List[NDArray]]]
            The states, including:
            states[0], the states used in the forward layers;
            each layer has a list of two initial tensors with
            shape (batch_size, proj_size) and (batch_size, hidden_size).
            states[1], the states used in the backward layers;
            each layer has a list of two initial tensors with
            shape (batch_size, proj_size) and (batch_size, hidden_size).
        mask : NDArray
            The valid-length mask of the inputs, with shape (batch_size, seq_len);
            valid positions are 1 and padded positions are 0, so mask.sum(axis=1)
            gives the valid length of each sample.

        Returns
        -------
        out : NDArray
            The output data, with shape (num_layers, seq_len, batch_size, 2*proj_size)
            when a projection size is set, and (num_layers, seq_len, batch_size,
            2*hidden_size) otherwise.
        [states_forward, states_backward] : List
            Including:
            states_forward: the output states of the forward layers,
            with the same structure as *states[0]*.
            states_backward: the output states of the backward layers,
            with the same structure as *states[1]*.
        """
        states_forward, states_backward = states
        if mask is not None:
            sequence_length = mask.sum(axis=1)

        outputs_forward = []
        outputs_backward = []

        for layer_index in range(self._num_layers):
            # Forward direction: feed the previous forward layer's output (or the
            # raw inputs for the first layer) through the forward RNN cell.
            if layer_index == 0:
                layer_inputs = inputs
            else:
                layer_inputs = outputs_forward[layer_index - 1]
            output, states_forward[layer_index] = F.contrib.foreach(
                self.forward_layers[layer_index],
                layer_inputs,
                states_forward[layer_index])
            outputs_forward.append(output)

            # Backward direction: reverse the sequence, run the backward RNN cell,
            # then reverse its output back to the original time order.
            if layer_index == 0:
                layer_inputs = inputs
            else:
                layer_inputs = outputs_backward[layer_index - 1]

            if mask is not None:
                layer_inputs = F.SequenceReverse(layer_inputs,
                                                 sequence_length=sequence_length,
                                                 use_sequence_length=True, axis=0)
            else:
                layer_inputs = F.SequenceReverse(layer_inputs, axis=0)
            output, states_backward[layer_index] = F.contrib.foreach(
                self.backward_layers[layer_index],
                layer_inputs,
                states_backward[layer_index])
            if mask is not None:
                backward_out = F.SequenceReverse(output,
                                                 sequence_length=sequence_length,
                                                 use_sequence_length=True, axis=0)
            else:
                backward_out = F.SequenceReverse(output, axis=0)
            outputs_backward.append(backward_out)

        # Stack per-layer outputs along a new leading axis and concatenate the
        # forward and backward features along the last dimension.
        out = F.concat(F.stack(*outputs_forward, axis=0),
                       F.stack(*outputs_backward, axis=0), dim=-1)

        return out, [states_forward, states_backward]
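

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API): it shows how a
# BiLMEncoder might be constructed and called imperatively. The mode
# ('lstmpc'), layer sizes, clip values, and tensor shapes below are
# illustrative assumptions, not values prescribed by this module. Because of
# the relative import above, run it as a module from within the package
# (e.g. ``python -m <package>.bilm_encoder``) or copy it into a script that
# imports BiLMEncoder from the installed package.
if __name__ == '__main__':
    import mxnet as mx

    seq_len, batch_size, input_size = 10, 4, 128
    encoder = BiLMEncoder(mode='lstmpc', num_layers=2, input_size=input_size,
                          hidden_size=512, proj_size=input_size,
                          cell_clip=3, proj_clip=3)
    encoder.initialize()

    # Inputs use layout 'TNC'; the mask marks valid (non-padded) positions.
    inputs = mx.nd.random.uniform(shape=(seq_len, batch_size, input_size))
    states = encoder.begin_state(func=mx.nd.zeros, batch_size=batch_size)
    mask = mx.nd.ones(shape=(batch_size, seq_len))

    out, (states_forward, states_backward) = encoder(inputs, states, mask)
    print(out.shape)  # (num_layers, seq_len, batch_size, 2 * proj_size)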