# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=no-member, invalid-name, protected-access, no-self-use
# pylint: disable=too-many-branches, too-many-arguments, no-self-use
# pylint: disable=too-many-lines, arguments-differ
"""Definition of various recurrent neural network cells."""
__all__ = ['RecurrentCell', 'HybridRecurrentCell',
           'RNNCell', 'LSTMCell', 'GRUCell',
           'SequentialRNNCell', 'HybridSequentialRNNCell', 'DropoutCell',
           'ModifierCell', 'ZoneoutCell', 'ResidualCell',
           'BidirectionalCell']

from ... import symbol, ndarray
from ...base import string_types, numeric_types, _as_list
from ..block import Block, HybridBlock
from ..utils import _indent
from .. import tensor_types
from ..nn import LeakyReLU


def _cells_state_info(cells, batch_size):
    return sum([c.state_info(batch_size) for c in cells], [])

def _cells_begin_state(cells, **kwargs):
    return sum([c.begin_state(**kwargs) for c in cells], [])

def _get_begin_state(cell, F, begin_state, inputs, batch_size):
    if begin_state is None:
        if F is ndarray:
            ctx = inputs.context if isinstance(inputs, tensor_types) else inputs[0].context
            with ctx:
                begin_state = cell.begin_state(func=F.zeros, batch_size=batch_size)
        else:
            begin_state = cell.begin_state(func=F.zeros, batch_size=batch_size)
    return begin_state
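
# The helpers above flatten per-cell state descriptions into one list: each
# cell reports its states via ``state_info()`` as dicts (shape plus a
# ``__layout__`` hint), and ``_get_begin_state`` turns those descriptions into
# zero-initialized arrays when the caller did not pass initial states.
# A minimal interactive sketch (not part of this module; assumes ``mxnet`` is
# importable as ``mx``):
#
#   >>> cell = mx.gluon.rnn.LSTMCell(hidden_size=8)
#   >>> cell.state_info(batch_size=4)
#   [{'shape': (4, 8), '__layout__': 'NC'}, {'shape': (4, 8), '__layout__': 'NC'}]
#   >>> states = cell.begin_state(batch_size=4)
#   >>> [s.shape for s in states]
#   [(4, 8), (4, 8)]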
def _format_sequence(length, inputs, layout, merge, in_layout=None):
    assert inputs is not None, \
        "unroll(inputs=None) has been deprecated. " \
        "Please create input variables outside unroll."

    axis = layout.find('T')
    batch_axis = layout.find('N')
    batch_size = 0
    in_axis = in_layout.find('T') if in_layout is not None else axis
    if isinstance(inputs, symbol.Symbol):
        F = symbol
        if merge is False:
            assert len(inputs.list_outputs()) == 1, \
                "unroll doesn't allow grouped symbol as input. Please convert " \
                "to list with list(inputs) first or let unroll handle splitting."
            inputs = list(symbol.split(inputs, axis=in_axis, num_outputs=length,
                                       squeeze_axis=1))
    elif isinstance(inputs, ndarray.NDArray):
        F = ndarray
        batch_size = inputs.shape[batch_axis]
        if merge is False:
            assert length is None or length == inputs.shape[in_axis]
            inputs = _as_list(ndarray.split(inputs, axis=in_axis,
                                            num_outputs=inputs.shape[in_axis],
                                            squeeze_axis=1))
    else:
        assert length is None or len(inputs) == length
        if isinstance(inputs[0], symbol.Symbol):
            F = symbol
        else:
            F = ndarray
            batch_size = inputs[0].shape[0]
        if merge is True:
            inputs = F.stack(*inputs, axis=axis)
            in_axis = axis

    if isinstance(inputs, tensor_types) and axis != in_axis:
        inputs = F.swapaxes(inputs, dim1=axis, dim2=in_axis)

    return inputs, axis, F, batch_size

def _mask_sequence_variable_length(F, data, length, valid_length, time_axis, merge):
    assert valid_length is not None
    if not isinstance(data, tensor_types):
        data = F.stack(*data, axis=time_axis)
    outputs = F.SequenceMask(data, sequence_length=valid_length, use_sequence_length=True,
                             axis=time_axis)
    if not merge:
        outputs = _as_list(F.split(outputs, num_outputs=length, axis=time_axis,
                                   squeeze_axis=True))
    return outputs

def _reverse_sequences(sequences, unroll_step, valid_length=None):
    if isinstance(sequences[0], symbol.Symbol):
        F = symbol
    else:
        F = ndarray

    if valid_length is None:
        reversed_sequences = list(reversed(sequences))
    else:
        reversed_sequences = F.SequenceReverse(F.stack(*sequences, axis=0),
                                               sequence_length=valid_length,
                                               use_sequence_length=True)
        if unroll_step > 1 or F is symbol:
            reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)
        else:
            reversed_sequences = [reversed_sequences[0]]

    return reversed_sequences
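
# ``_format_sequence`` is what lets ``unroll`` accept either a single
# (batch, time, channel) tensor or a list of per-step (batch, channel) slices.
# With ``merge=False`` it splits along the time axis ('T' in the layout string),
# which is equivalent to the public call below (a minimal sketch, not part of
# this module; assumes ``mxnet`` is importable as ``mx``):
#
#   >>> x = mx.nd.random.uniform(shape=(2, 5, 4))           # layout 'NTC'
#   >>> steps = mx.nd.split(x, num_outputs=5, axis=1, squeeze_axis=True)
#   >>> len(steps), steps[0].shape
#   (5, (2, 4))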


class RecurrentCell(Block):
    """Abstract base class for RNN cells

    Parameters
    ----------
    prefix : str, optional
        Prefix for names of `Block`s
        (this prefix is also used for names of weights if `params` is `None`
        i.e. if `params` are being created and not reused)
    params : Parameter or None, default None
        Container for weight sharing between cells.
        A new Parameter container is created if `params` is `None`.
    """
    def __init__(self, prefix=None, params=None):
        super(RecurrentCell, self).__init__(prefix=prefix, params=params)
        self._modified = False
        self.reset()

    def reset(self):
        """Reset before re-using the cell for another graph."""
        self._init_counter = -1
        self._counter = -1
        for cell in self._children.values():
            cell.reset()

    def state_info(self, batch_size=0):
        """shape and layout information of states"""
        raise NotImplementedError()

    def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
        """Initial state for this cell.

        Parameters
        ----------
        func : callable, default ndarray.zeros
            Function for creating initial state.

            For Symbol API, func can be `symbol.zeros`, `symbol.uniform`,
            `symbol.var`, etc. Use `symbol.var` if you want to directly
            feed input as states.

            For NDArray API, func can be `ndarray.zeros`, `ndarray.ones`, etc.
        batch_size : int, default 0
            Only required for NDArray API. Size of the batch ('N' in layout)
            dimension of input.

        **kwargs :
            Additional keyword arguments passed to func. For example
            `mean`, `std`, `dtype`, etc.

        Returns
        -------
        states : nested list of Symbol
            Starting states for the first RNN step.
        """
        assert not self._modified, \
            "After applying modifier cells (e.g. ZoneoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        states = []
        for info in self.state_info(batch_size):
            self._init_counter += 1
            if info is not None:
                info.update(kwargs)
            else:
                info = kwargs
            state = func(name='%sbegin_state_%d'%(self._prefix, self._init_counter),
                         **info)
            states.append(state)
        return states

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        """Unrolls an RNN cell across time steps.

        Parameters
        ----------
        length : int
            Number of steps to unroll.
        inputs : Symbol, list of Symbol, or None
            If `inputs` is a single Symbol (usually the output
            of Embedding symbol), it should have shape
            (batch_size, length, ...) if `layout` is 'NTC',
            or (length, batch_size, ...) if `layout` is 'TNC'.

            If `inputs` is a list of symbols (usually output of
            previous unroll), they should all have shape
            (batch_size, ...).
        begin_state : nested list of Symbol, optional
            Input states created by `begin_state()`
            or output state of another cell.
            Created from `begin_state()` if `None`.
        layout : str, optional
            `layout` of input symbol. Only used if inputs
            is a single Symbol.
        merge_outputs : bool, optional
            If `False`, returns outputs as a list of Symbols.
            If `True`, concatenates output across time steps
            and returns a single symbol with shape
            (batch_size, length, ...) if layout is 'NTC',
            or (length, batch_size, ...) if layout is 'TNC'.
            If `None`, output whatever is faster.
        valid_length : Symbol, NDArray or None
            `valid_length` specifies the length of the sequences in the batch without padding.
            This option is especially useful for building sequence-to-sequence models where
            the input and output sequences would potentially be padded.
            If `valid_length` is None, all sequences are assumed to have the same length.
            If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
            The ith element will be the length of the ith sequence in the batch.
            The last valid state will be returned and the padded outputs will be masked with 0.
            Note that `valid_length` must be smaller than or equal to `length`.

        Returns
        -------
        outputs : list of Symbol or Symbol
            Symbol (if `merge_outputs` is True) or list of Symbols
            (if `merge_outputs` is False) corresponding to the output from
            the RNN from this unrolling.

        states : list of Symbol
            The new state of this RNN after this unrolling.
            The type of this symbol is same as the output of `begin_state()`.
        """
        # pylint: disable=too-many-locals
        self.reset()

        inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        states = begin_state
        outputs = []
        all_states = []
        for i in range(length):
            output, states = self(inputs[i], states)
            outputs.append(output)
            if valid_length is not None:
                all_states.append(states)
        if valid_length is not None:
            states = [F.SequenceLast(F.stack(*ele_list, axis=0),
                                     sequence_length=valid_length,
                                     use_sequence_length=True,
                                     axis=0)
                      for ele_list in zip(*all_states)]
            outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis, True)
        outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)

        return outputs, states
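
    # A minimal sketch of the ``valid_length`` behaviour documented above (not
    # part of this module; assumes ``mxnet`` is importable as ``mx``). Steps past
    # each sequence's valid length are masked to zero in the outputs, and the
    # returned states come from the last valid step of each sequence:
    #
    #   >>> cell = mx.gluon.rnn.GRUCell(hidden_size=8)
    #   >>> cell.initialize()
    #   >>> x = mx.nd.random.uniform(shape=(2, 5, 4))         # batch=2, length=5
    #   >>> valid = mx.nd.array([5, 3])                       # second sequence has 3 real steps
    #   >>> outputs, states = cell.unroll(5, x, layout='NTC', merge_outputs=True,
    #   ...                               valid_length=valid)
    #   >>> outputs.shape, states[0].shape
    #   ((2, 5, 8), (2, 8))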

    #pylint: disable=no-self-use
    def _get_activation(self, F, inputs, activation, **kwargs):
        """Get the activation function. Converts string names to functions."""
        func = {'tanh': F.tanh,
                'relu': F.relu,
                'sigmoid': F.sigmoid,
                'softsign': F.softsign}.get(activation)
        if func:
            return func(inputs, **kwargs)
        elif isinstance(activation, string_types):
            return F.Activation(inputs, act_type=activation, **kwargs)
        elif isinstance(activation, LeakyReLU):
            return F.LeakyReLU(inputs, act_type='leaky', slope=activation._alpha, **kwargs)
        return activation(inputs, **kwargs)

    def forward(self, inputs, states):
        """Unrolls the recurrent cell for one time step.

        Parameters
        ----------
        inputs : sym.Variable
            Input symbol, 2D, of shape (batch_size, num_units).
        states : list of sym.Variable
            RNN state from previous step or the output of begin_state().

        Returns
        -------
        output : Symbol
            Symbol corresponding to the output from the RNN when unrolling
            for a single time step.
        states : list of Symbol
            The new state of this RNN after this unrolling.
            The type of this symbol is same as the output of `begin_state()`.
            This can be used as an input state to the next time step
            of this RNN.

        See Also
        --------
        begin_state: This function can provide the states for the first time step.
        unroll: This function unrolls an RNN for a given number of (>=1) time steps.
        """
        # pylint: disable= arguments-differ
        self._counter += 1
        return super(RecurrentCell, self).forward(inputs, states)


class HybridRecurrentCell(RecurrentCell, HybridBlock):
    """HybridRecurrentCell supports hybridize."""
    def __init__(self, prefix=None, params=None):
        super(HybridRecurrentCell, self).__init__(prefix=prefix, params=params)

    def hybrid_forward(self, F, x, *args, **kwargs):
        raise NotImplementedError
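
# Calling a cell (``cell(x_t, states)``) steps it for exactly one time step,
# as described in ``RecurrentCell.forward`` above; ``unroll`` simply repeats
# this in a Python loop. A minimal sketch (not part of this module; assumes
# ``mxnet`` is importable as ``mx``):
#
#   >>> cell = mx.gluon.rnn.RNNCell(hidden_size=8)
#   >>> cell.initialize()
#   >>> states = cell.begin_state(batch_size=2)
#   >>> x_t = mx.nd.random.uniform(shape=(2, 4))             # one time step
#   >>> output, states = cell(x_t, states)
#   >>> output.shape, states[0].shape
#   ((2, 8), (2, 8))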


class RNNCell(HybridRecurrentCell):
    r"""Elman RNN recurrent neural network cell.

    Each call computes the following function:

    .. math::

        h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is the hidden
    state of the previous layer at time `t` or :math:`input_t` for the first layer.
    If ``activation='relu'``, then `ReLU` is used instead of `tanh`.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol
    activation : str or Symbol, default 'tanh'
        Type of activation function.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    prefix : str, default ``'rnn_'``
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.


    Inputs:
        - **data**: input tensor with shape `(batch_size, input_size)`.
        - **states**: a list of one initial recurrent state tensor with shape
          `(batch_size, num_hidden)`.

    Outputs:
        - **out**: output tensor with shape `(batch_size, num_hidden)`.
        - **next_states**: a list of one output recurrent state tensor with the
          same shape as `states`.
    """
    def __init__(self, hidden_size, activation='tanh',
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None):
        super(RNNCell, self).__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_size
        self._activation = activation
        self._input_size = input_size
        self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
                                          init=i2h_weight_initializer,
                                          allow_deferred_init=True)
        self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
                                          init=h2h_weight_initializer,
                                          allow_deferred_init=True)
        self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
                                        init=i2h_bias_initializer,
                                        allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)

    def state_info(self, batch_size=0):
        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]

    def _alias(self):
        return 'rnn'

    def __repr__(self):
        s = '{name}({mapping}'
        if hasattr(self, '_activation'):
            s += ', {_activation}'
        s += ')'
        shape = self.i2h_weight.shape
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        prefix = 't%d_'%self._counter
        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
                               num_hidden=self._hidden_size,
                               name=prefix+'i2h')
        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                               num_hidden=self._hidden_size,
                               name=prefix+'h2h')
        i2h_plus_h2h = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
        output = self._get_activation(F, i2h_plus_h2h, self._activation,
                                      name=prefix+'out')

        return output, [output]
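
# A minimal usage sketch for the cell above (not part of this module; assumes
# ``mxnet`` is importable as ``mx``): unroll a vanilla RNN over 3 time steps.
#
#   >>> rnn_cell = mx.gluon.rnn.RNNCell(hidden_size=16)
#   >>> rnn_cell.initialize()
#   >>> x = mx.nd.random.uniform(shape=(4, 3, 10))           # (batch, time, input)
#   >>> outputs, states = rnn_cell.unroll(3, x, layout='NTC', merge_outputs=True)
#   >>> outputs.shape, states[0].shape
#   ((4, 3, 16), (4, 16))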


class LSTMCell(HybridRecurrentCell):
    r"""Long Short-Term Memory (LSTM) network cell.

    Each call computes the following function:

    .. math::
        \begin{array}{ll}
        i_t = sigmoid(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
        f_t = sigmoid(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
        o_t = sigmoid(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
        c_t = f_t * c_{(t-1)} + i_t * g_t \\
        h_t = o_t * \tanh(c_t)
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the
    cell state at time `t`, :math:`x_t` is the hidden state of the previous
    layer at time `t` or :math:`input_t` for the first layer, and :math:`i_t`,
    :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and
    output gates, respectively.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    prefix : str, default ``'lstm_'``
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None, default None
        Container for weight sharing between cells.
        Created if `None`.
    activation : str, default 'tanh'
        Activation type to use. See nd/symbol Activation
        for supported types.
    recurrent_activation : str, default 'sigmoid'
        Activation type to use for the recurrent step. See nd/symbol Activation
        for supported types.

    Inputs:
        - **data**: input tensor with shape `(batch_size, input_size)`.
        - **states**: a list of two initial recurrent state tensors. Each has shape
          `(batch_size, num_hidden)`.

    Outputs:
        - **out**: output tensor with shape `(batch_size, num_hidden)`.
        - **next_states**: a list of two output recurrent state tensors. Each has
          the same shape as `states`.
    """
    # pylint: disable=too-many-instance-attributes
    def __init__(self, hidden_size,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None, activation='tanh',
                 recurrent_activation='sigmoid'):
        super(LSTMCell, self).__init__(prefix=prefix, params=params)

        self._hidden_size = hidden_size
        self._input_size = input_size
        self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
                                          init=i2h_weight_initializer,
                                          allow_deferred_init=True)
        self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
                                          init=h2h_weight_initializer,
                                          allow_deferred_init=True)
        self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
                                        init=i2h_bias_initializer,
                                        allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)
        self._activation = activation
        self._recurrent_activation = recurrent_activation

    def state_info(self, batch_size=0):
        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'},
                {'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]

    def _alias(self):
        return 'lstm'

    def __repr__(self):
        s = '{name}({mapping})'
        shape = self.i2h_weight.shape
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        # pylint: disable=too-many-locals
        prefix = 't%d_'%self._counter
        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
                               num_hidden=self._hidden_size*4, name=prefix+'i2h')
        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                               num_hidden=self._hidden_size*4, name=prefix+'h2h')
        gates = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
        in_gate = self._get_activation(
            F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
        forget_gate = self._get_activation(
            F, slice_gates[1], self._recurrent_activation, name=prefix+'f')
        in_transform = self._get_activation(
            F, slice_gates[2], self._activation, name=prefix+'c')
        out_gate = self._get_activation(
            F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
        next_c = F.elemwise_add(F.elemwise_mul(forget_gate, states[1], name=prefix+'mul0'),
                                F.elemwise_mul(in_gate, in_transform, name=prefix+'mul1'),
                                name=prefix+'state')
        next_h = F.elemwise_mul(out_gate, F.Activation(next_c, act_type=self._activation, name=prefix+'activation0'),
                                name=prefix+'out')

        return next_h, [next_h, next_c]
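
# A minimal usage sketch for LSTMCell (not part of this module; assumes
# ``mxnet`` is importable as ``mx``). The LSTM carries two state tensors, the
# hidden state h and the cell state c:
#
#   >>> lstm_cell = mx.gluon.rnn.LSTMCell(hidden_size=32)
#   >>> lstm_cell.initialize()
#   >>> x = mx.nd.random.uniform(shape=(4, 6, 10))           # (batch, time, input)
#   >>> outputs, states = lstm_cell.unroll(6, x, layout='NTC', merge_outputs=True)
#   >>> outputs.shape, [s.shape for s in states]
#   ((4, 6, 32), [(4, 32), (4, 32)])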


class GRUCell(HybridRecurrentCell):
    r"""Gated Recurrent Unit (GRU) network cell.
    Note: this is an implementation of the cuDNN version of GRUs
    (slight modification compared to Cho et al. 2014; the reset gate :math:`r_t`
    is applied after matrix multiplication).

    Each call computes the following function:

    .. math::
        \begin{array}{ll}
        r_t = sigmoid(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
        i_t = sigmoid(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
        n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)} + b_{hn})) \\
        h_t = (1 - i_t) * n_t + i_t * h_{(t-1)} \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden
    state of the previous layer at time `t` or :math:`input_t` for the first layer,
    and :math:`r_t`, :math:`i_t`, :math:`n_t` are the reset, input, and new gates, respectively.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    prefix : str, default ``'gru_'``
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None, default None
        Container for weight sharing between cells.
        Created if `None`.


    Inputs:
        - **data**: input tensor with shape `(batch_size, input_size)`.
        - **states**: a list of one initial recurrent state tensor with shape
          `(batch_size, num_hidden)`.

    Outputs:
        - **out**: output tensor with shape `(batch_size, num_hidden)`.
        - **next_states**: a list of one output recurrent state tensor with the
          same shape as `states`.
    """
    def __init__(self, hidden_size,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None):
        super(GRUCell, self).__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_size
        self._input_size = input_size
        self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
                                          init=i2h_weight_initializer,
                                          allow_deferred_init=True)
        self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
                                          init=h2h_weight_initializer,
                                          allow_deferred_init=True)
        self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
                                        init=i2h_bias_initializer,
                                        allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)

    def state_info(self, batch_size=0):
        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]

    def _alias(self):
        return 'gru'

    def __repr__(self):
        s = '{name}({mapping})'
        shape = self.i2h_weight.shape
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        # pylint: disable=too-many-locals
        prefix = 't%d_'%self._counter
        prev_state_h = states[0]
        i2h = F.FullyConnected(data=inputs,
                               weight=i2h_weight,
                               bias=i2h_bias,
                               num_hidden=self._hidden_size * 3,
                               name=prefix+'i2h')
        h2h = F.FullyConnected(data=prev_state_h,
                               weight=h2h_weight,
                               bias=h2h_bias,
                               num_hidden=self._hidden_size * 3,
                               name=prefix+'h2h')

        i2h_r, i2h_z, i2h = F.SliceChannel(i2h, num_outputs=3,
                                           name=prefix+'i2h_slice')
        h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
                                           name=prefix+'h2h_slice')

        reset_gate = F.Activation(F.elemwise_add(i2h_r, h2h_r, name=prefix+'plus0'), act_type="sigmoid",
                                  name=prefix+'r_act')
        update_gate = F.Activation(F.elemwise_add(i2h_z, h2h_z, name=prefix+'plus1'), act_type="sigmoid",
                                   name=prefix+'z_act')

        next_h_tmp = F.Activation(F.elemwise_add(i2h,
                                                 F.elemwise_mul(reset_gate, h2h, name=prefix+'mul0'),
                                                 name=prefix+'plus2'),
                                  act_type="tanh",
                                  name=prefix+'h_act')

        ones = F.ones_like(update_gate, name=prefix+"ones_like0")
        next_h = F.elemwise_add(F.elemwise_mul(F.elemwise_sub(ones, update_gate, name=prefix+'minus0'),
                                               next_h_tmp,
                                               name=prefix+'mul1'),
                                F.elemwise_mul(update_gate, prev_state_h, name=prefix+'mul20'),
                                name=prefix+'out')

        return next_h, [next_h]
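
# A minimal usage sketch for GRUCell (not part of this module; assumes
# ``mxnet`` is importable as ``mx``). Like RNNCell it keeps a single hidden
# state tensor:
#
#   >>> gru_cell = mx.gluon.rnn.GRUCell(hidden_size=32)
#   >>> gru_cell.initialize()
#   >>> x = mx.nd.random.uniform(shape=(4, 6, 10))
#   >>> outputs, states = gru_cell.unroll(6, x, layout='NTC', merge_outputs=True)
#   >>> outputs.shape, [s.shape for s in states]
#   ((4, 6, 32), [(4, 32)])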


class SequentialRNNCell(RecurrentCell):
    """Sequentially stacking multiple RNN cells."""
    def __init__(self, prefix=None, params=None):
        super(SequentialRNNCell, self).__init__(prefix=prefix, params=params)

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        return s.format(name=self.__class__.__name__,
                        modstr='\n'.join(['({i}): {m}'.format(i=i, m=_indent(m.__repr__(), 2))
                                          for i, m in self._children.items()]))

    def add(self, cell):
        """Appends a cell into the stack.

        Parameters
        ----------
        cell : RecurrentCell
            The cell to add.
        """
        self.register_child(cell)

    def state_info(self, batch_size=0):
        return _cells_state_info(self._children.values(), batch_size)

    def begin_state(self, **kwargs):
        assert not self._modified, \
            "After applying modifier cells (e.g. ZoneoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        return _cells_begin_state(self._children.values(), **kwargs)

    def __call__(self, inputs, states):
        self._counter += 1
        next_states = []
        p = 0
        assert all(not isinstance(cell, BidirectionalCell) for cell in self._children.values())
        for cell in self._children.values():
            assert not isinstance(cell, BidirectionalCell)
            n = len(cell.state_info())
            state = states[p:p+n]
            p += n
            inputs, state = cell(inputs, state)
            next_states.append(state)
        return inputs, sum(next_states, [])

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        # pylint: disable=too-many-locals
        self.reset()

        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
        num_cells = len(self._children)
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        p = 0
        next_states = []
        for i, cell in enumerate(self._children.values()):
            n = len(cell.state_info())
            states = begin_state[p:p+n]
            p += n
            inputs, states = cell.unroll(length, inputs=inputs, begin_state=states,
                                         layout=layout,
                                         merge_outputs=None if i < num_cells-1 else merge_outputs,
                                         valid_length=valid_length)
            next_states.extend(states)

        return inputs, next_states

    def __getitem__(self, i):
        return self._children[str(i)]

    def __len__(self):
        return len(self._children)

    def hybrid_forward(self, *args, **kwargs):
        # pylint: disable=missing-docstring
        raise NotImplementedError
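
# A minimal stacking sketch (not part of this module; assumes ``mxnet`` is
# importable as ``mx``). Cells are applied in the order they were added, and
# the flat state list concatenates each layer's states (here 2 LSTM states +
# 1 GRU state):
#
#   >>> stack = mx.gluon.rnn.SequentialRNNCell()
#   >>> stack.add(mx.gluon.rnn.LSTMCell(hidden_size=32))
#   >>> stack.add(mx.gluon.rnn.GRUCell(hidden_size=16))
#   >>> stack.initialize()
#   >>> x = mx.nd.random.uniform(shape=(4, 6, 10))
#   >>> outputs, states = stack.unroll(6, x, layout='NTC', merge_outputs=True)
#   >>> outputs.shape, [s.shape for s in states]
#   ((4, 6, 16), [(4, 32), (4, 32), (4, 16)])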


class HybridSequentialRNNCell(HybridRecurrentCell):
    """Sequentially stacking multiple HybridRNN cells."""
    def __init__(self, prefix=None, params=None):
        super(HybridSequentialRNNCell, self).__init__(prefix=prefix, params=params)

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        return s.format(name=self.__class__.__name__,
                        modstr='\n'.join(['({i}): {m}'.format(i=i, m=_indent(m.__repr__(), 2))
                                          for i, m in self._children.items()]))

    def add(self, cell):
        """Appends a cell into the stack.

        Parameters
        ----------
        cell : RecurrentCell
            The cell to add.
        """
        self.register_child(cell)

    def state_info(self, batch_size=0):
        return _cells_state_info(self._children.values(), batch_size)

    def begin_state(self, **kwargs):
        assert not self._modified, \
            "After applying modifier cells (e.g. ZoneoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        return _cells_begin_state(self._children.values(), **kwargs)

    def __call__(self, inputs, states):
        self._counter += 1
        next_states = []
        p = 0
        assert all(not isinstance(cell, BidirectionalCell) for cell in self._children.values())
        for cell in self._children.values():
            n = len(cell.state_info())
            state = states[p:p+n]
            p += n
            inputs, state = cell(inputs, state)
            next_states.append(state)
        return inputs, sum(next_states, [])

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        self.reset()

        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
        num_cells = len(self._children)
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        p = 0
        next_states = []
        for i, cell in enumerate(self._children.values()):
            n = len(cell.state_info())
            states = begin_state[p:p+n]
            p += n
            inputs, states = cell.unroll(length, inputs=inputs, begin_state=states,
                                         layout=layout,
                                         merge_outputs=None if i < num_cells-1 else merge_outputs,
                                         valid_length=valid_length)
            next_states.extend(states)

        return inputs, next_states

    def __getitem__(self, i):
        return self._children[str(i)]

    def __len__(self):
        return len(self._children)

    def hybrid_forward(self, F, inputs, states):
        return self.__call__(inputs, states)


class DropoutCell(HybridRecurrentCell):
    """Applies dropout on input.

    Parameters
    ----------
    rate : float
        Fraction of the input units to drop out (i.e. 1 minus the fraction
        to retain).
    axes : tuple of int, default ()
        The axes on which dropout mask is shared. If empty, regular dropout is applied.


    Inputs:
        - **data**: input tensor with shape `(batch_size, size)`.
        - **states**: a list of recurrent state tensors.

    Outputs:
        - **out**: output tensor with shape `(batch_size, size)`.
        - **next_states**: returns input `states` directly.
    """
    def __init__(self, rate, axes=(), prefix=None, params=None):
        super(DropoutCell, self).__init__(prefix, params)
        assert isinstance(rate, numeric_types), "rate must be a number"
        self._rate = rate
        self._axes = axes

    def __repr__(self):
        s = '{name}(rate={_rate}, axes={_axes})'
        return s.format(name=self.__class__.__name__,
                        **self.__dict__)

    def state_info(self, batch_size=0):
        return []

    def _alias(self):
        return 'dropout'

    def hybrid_forward(self, F, inputs, states):
        if self._rate > 0:
            inputs = F.Dropout(data=inputs, p=self._rate, axes=self._axes,
                               name='t%d_fwd'%self._counter)
        return inputs, states

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        self.reset()

        inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
        if isinstance(inputs, tensor_types):
            return self.hybrid_forward(F, inputs, begin_state if begin_state else [])
        return super(DropoutCell, self).unroll(
            length, inputs, begin_state=begin_state, layout=layout,
            merge_outputs=merge_outputs, valid_length=None)
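
# DropoutCell carries no state of its own (``state_info`` returns ``[]``), so
# it is typically interleaved between recurrent layers in a stack. A minimal
# sketch (not part of this module; assumes ``mxnet`` is importable as ``mx``):
#
#   >>> stack = mx.gluon.rnn.SequentialRNNCell()
#   >>> stack.add(mx.gluon.rnn.LSTMCell(hidden_size=32))
#   >>> stack.add(mx.gluon.rnn.DropoutCell(rate=0.5))
#   >>> stack.add(mx.gluon.rnn.LSTMCell(hidden_size=32))
#   >>> len(stack.state_info(batch_size=4))   # 2 states per LSTM layer, none for dropout
#   4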


class ModifierCell(HybridRecurrentCell):
    """Base class for modifier cells. A modifier
    cell takes a base cell, applies modifications
    to it (e.g. Zoneout), and returns a new cell.

    After applying modifiers the base cell should
    no longer be called directly. The modifier cell
    should be used instead.
    """
    def __init__(self, base_cell):
        assert not base_cell._modified, \
            "Cell %s is already modified. One cell cannot be modified twice"%base_cell.name
        base_cell._modified = True
        super(ModifierCell, self).__init__(prefix=base_cell.prefix+self._alias(),
                                           params=None)
        self.base_cell = base_cell

    @property
    def params(self):
        return self.base_cell.params

    def state_info(self, batch_size=0):
        return self.base_cell.state_info(batch_size)

    def begin_state(self, func=symbol.zeros, **kwargs):
        assert not self._modified, \
            "After applying modifier cells (e.g. DropoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        self.base_cell._modified = False
        begin = self.base_cell.begin_state(func=func, **kwargs)
        self.base_cell._modified = True
        return begin

    def hybrid_forward(self, F, inputs, states):
        raise NotImplementedError

    def __repr__(self):
        s = '{name}({base_cell})'
        return s.format(name=self.__class__.__name__,
                        **self.__dict__)


class ZoneoutCell(ModifierCell):
    """Applies Zoneout on base cell."""
    def __init__(self, base_cell, zoneout_outputs=0., zoneout_states=0.):
        assert not isinstance(base_cell, BidirectionalCell), \
            "BidirectionalCell doesn't support zoneout since it doesn't support step. " \
            "Please add ZoneoutCell to the cells underneath instead."
        assert not isinstance(base_cell, SequentialRNNCell) or not base_cell._bidirectional, \
            "Bidirectional SequentialRNNCell doesn't support zoneout. " \
            "Please add ZoneoutCell to the cells underneath instead."
        super(ZoneoutCell, self).__init__(base_cell)
        self.zoneout_outputs = zoneout_outputs
        self.zoneout_states = zoneout_states
        self._prev_output = None

    def __repr__(self):
        s = '{name}(p_out={zoneout_outputs}, p_state={zoneout_states}, {base_cell})'
        return s.format(name=self.__class__.__name__,
                        **self.__dict__)

    def _alias(self):
        return 'zoneout'

    def reset(self):
        super(ZoneoutCell, self).reset()
        self._prev_output = None

    def hybrid_forward(self, F, inputs, states):
        cell, p_outputs, p_states = self.base_cell, self.zoneout_outputs, self.zoneout_states
        next_output, next_states = cell(inputs, states)
        mask = (lambda p, like: F.Dropout(F.ones_like(like), p=p))

        prev_output = self._prev_output
        if prev_output is None:
            prev_output = F.zeros_like(next_output)

        output = (F.where(mask(p_outputs, next_output), next_output, prev_output)
                  if p_outputs != 0. else next_output)
        states = ([F.where(mask(p_states, new_s), new_s, old_s) for new_s, old_s in
                   zip(next_states, states)] if p_states != 0. else next_states)

        self._prev_output = output

        return output, states
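
# Zoneout stochastically keeps the previous output/state instead of the new
# one during training; the wrapped cell must not be called directly afterwards.
# A minimal sketch (not part of this module; assumes ``mxnet`` is importable
# as ``mx``):
#
#   >>> base = mx.gluon.rnn.LSTMCell(hidden_size=32)
#   >>> cell = mx.gluon.rnn.ZoneoutCell(base, zoneout_outputs=0.2, zoneout_states=0.1)
#   >>> cell.initialize()
#   >>> x = mx.nd.random.uniform(shape=(4, 6, 10))
#   >>> outputs, states = cell.unroll(6, x, layout='NTC', merge_outputs=True)
#   >>> outputs.shape
#   (4, 6, 32)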


class ResidualCell(ModifierCell):
    """
    Adds a residual connection as described in Wu et al., 2016
    (https://arxiv.org/abs/1609.08144).
    The output of the cell is the output of the base cell plus the input.
    """

    def __init__(self, base_cell):
        # pylint: disable=useless-super-delegation
        super(ResidualCell, self).__init__(base_cell)

    def hybrid_forward(self, F, inputs, states):
        output, states = self.base_cell(inputs, states)
        output = F.elemwise_add(output, inputs, name='t%d_fwd'%self._counter)
        return output, states

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        self.reset()

        self.base_cell._modified = False
        outputs, states = self.base_cell.unroll(length, inputs=inputs, begin_state=begin_state,
                                                layout=layout, merge_outputs=merge_outputs,
                                                valid_length=valid_length)
        self.base_cell._modified = True

        merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is None else \
            merge_outputs
        inputs, axis, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
        if valid_length is not None:
            # mask the padded inputs to zero
            inputs = _mask_sequence_variable_length(F, inputs, length, valid_length, axis,
                                                    merge_outputs)
        if merge_outputs:
            outputs = F.elemwise_add(outputs, inputs)
        else:
            outputs = [F.elemwise_add(i, j) for i, j in zip(outputs, inputs)]

        return outputs, states
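
# The residual connection requires the base cell's output to have the same
# shape as its input. A minimal sketch (not part of this module; assumes
# ``mxnet`` is importable as ``mx``):
#
#   >>> cell = mx.gluon.rnn.ResidualCell(mx.gluon.rnn.GRUCell(hidden_size=10))
#   >>> cell.initialize()
#   >>> x = mx.nd.random.uniform(shape=(4, 6, 10))            # input size == hidden size
#   >>> outputs, states = cell.unroll(6, x, layout='NTC', merge_outputs=True)
#   >>> outputs.shape
#   (4, 6, 10)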


class BidirectionalCell(HybridRecurrentCell):
    """Bidirectional RNN cell.

    Parameters
    ----------
    l_cell : RecurrentCell
        Cell for forward unrolling
    r_cell : RecurrentCell
        Cell for backward unrolling
    """
    def __init__(self, l_cell, r_cell, output_prefix='bi_'):
        super(BidirectionalCell, self).__init__(prefix='', params=None)
        self.register_child(l_cell, 'l_cell')
        self.register_child(r_cell, 'r_cell')
        self._output_prefix = output_prefix

    def __call__(self, inputs, states):
        raise NotImplementedError("Bidirectional cannot be stepped. Please use unroll")

    def __repr__(self):
        s = '{name}(forward={l_cell}, backward={r_cell})'
        return s.format(name=self.__class__.__name__,
                        l_cell=self._children['l_cell'],
                        r_cell=self._children['r_cell'])

    def state_info(self, batch_size=0):
        return _cells_state_info(self._children.values(), batch_size)

    def begin_state(self, **kwargs):
        assert not self._modified, \
            "After applying modifier cells (e.g. DropoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        return _cells_begin_state(self._children.values(), **kwargs)

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        # pylint: disable=too-many-locals
        self.reset()

        inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
        reversed_inputs = list(_reverse_sequences(inputs, length, valid_length))
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        states = begin_state
        l_cell, r_cell = self._children.values()
        l_outputs, l_states = l_cell.unroll(length, inputs=inputs,
                                            begin_state=states[:len(l_cell.state_info(batch_size))],
                                            layout=layout, merge_outputs=merge_outputs,
                                            valid_length=valid_length)
        r_outputs, r_states = r_cell.unroll(length,
                                            inputs=reversed_inputs,
                                            begin_state=states[len(l_cell.state_info(batch_size)):],
                                            layout=layout, merge_outputs=False,
                                            valid_length=valid_length)
        reversed_r_outputs = _reverse_sequences(r_outputs, length, valid_length)

        if merge_outputs is None:
            merge_outputs = isinstance(l_outputs, tensor_types)
            l_outputs, _, _, _ = _format_sequence(None, l_outputs, layout, merge_outputs)
            reversed_r_outputs, _, _, _ = _format_sequence(None, reversed_r_outputs, layout,
                                                           merge_outputs)

        if merge_outputs:
            reversed_r_outputs = F.stack(*reversed_r_outputs, axis=axis)
            outputs = F.concat(l_outputs, reversed_r_outputs, dim=2,
                               name='%sout'%self._output_prefix)

        else:
            outputs = [F.concat(l_o, r_o, dim=1, name='%st%d'%(self._output_prefix, i))
                       for i, (l_o, r_o) in enumerate(zip(l_outputs, reversed_r_outputs))]
        if valid_length is not None:
            outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis,
                                                     merge_outputs)
        states = l_states + r_states
        return outputs, states
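
# A minimal usage sketch for BidirectionalCell (not part of this module;
# assumes ``mxnet`` is importable as ``mx``). The forward and backward outputs
# are concatenated along the channel dimension, so the output size is the sum
# of the two hidden sizes:
#
#   >>> bi_cell = mx.gluon.rnn.BidirectionalCell(
#   ...     mx.gluon.rnn.LSTMCell(hidden_size=16),
#   ...     mx.gluon.rnn.LSTMCell(hidden_size=16))
#   >>> bi_cell.initialize()
#   >>> x = mx.nd.random.uniform(shape=(4, 6, 10))
#   >>> outputs, states = bi_cell.unroll(6, x, layout='NTC', merge_outputs=True)
#   >>> outputs.shape, len(states)
#   ((4, 6, 32), 4)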