# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15"""DEPRECATION WARNING:
16  The `flax.nn` module is Deprecated, use `flax.linen` instead.
17  Learn more and find an upgrade guide at
18  https://github.com/google/flax/blob/master/flax/linen/README.md"
19  Recurrent neural network modules.
20
21THe RNNCell modules are designed to fit in with the scan function in JAX::
22
23  _, initial_params = LSTMCell.init(rng_1, time_series[0])
24  model = nn.Model(LSTMCell, initial_params)
25  carry = LSTMCell.initialize_carry(rng_2, (batch_size,), memory_size)
26  carry, y = jax.lax.scan(model, carry, time_series)
27
28"""

import abc

from . import activation
from . import base
from . import initializers
from . import linear

from jax import numpy as jnp
from jax import random
from jax import lax
import numpy as np


class RNNCellBase(base.Module):
  """DEPRECATION WARNING:
  The `flax.nn` module is deprecated; use `flax.linen` instead.
  Learn more and find an upgrade guide at
  https://github.com/google/flax/blob/master/flax/linen/README.md

  RNN cell base class."""

  @staticmethod
  @abc.abstractmethod
  def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros):
    """Initialize the RNN cell carry.

    Args:
      rng: random number generator passed to the init_fn.
      batch_dims: a tuple providing the shape of the batch dimensions.
      size: the size or number of features of the memory.
      init_fn: initializer function for the carry.

    Returns:
      An initialized carry for the given RNN cell.
    """
    pass


class LSTMCell(RNNCellBase):
  """DEPRECATION WARNING:
  The `flax.nn` module is deprecated; use `flax.linen` instead.
  Learn more and find an upgrade guide at
  https://github.com/google/flax/blob/master/flax/linen/README.md

  LSTM cell."""

  def apply(self, carry, inputs,
            gate_fn=activation.sigmoid, activation_fn=activation.tanh,
            kernel_init=linear.default_kernel_init,
            recurrent_kernel_init=initializers.orthogonal(),
            bias_init=initializers.zeros):
    r"""A long short-term memory (LSTM) cell.

    The mathematical definition of the cell is as follows:

    .. math::

        \begin{array}{ll}
        i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
        f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
        o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
        c' = f * c + i * g \\
        h' = o * \tanh(c') \\
        \end{array}

    where x is the input, h is the output of the previous time step, and c is
    the memory.

    Args:
      carry: the hidden state of the LSTM cell,
        initialized using `LSTMCell.initialize_carry`.
      inputs: an ndarray with the input for the current time step.
        All dimensions except the final are considered batch dimensions.
      gate_fn: activation function used for gates (default: sigmoid).
      activation_fn: activation function used for output and memory update
        (default: tanh).
      kernel_init: initializer function for the kernels that transform
        the input (default: lecun_normal).
      recurrent_kernel_init: initializer function for the kernels that
        transform the hidden state (default: orthogonal).
      bias_init: initializer for the bias parameters (default: zeros).

    Returns:
      A tuple with the new carry and the output.
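
    Example (a minimal usage sketch; `batch_size`, `features`, and the
    input `x` are illustrative placeholders)::

      carry = LSTMCell.initialize_carry(
          random.PRNGKey(0), (batch_size,), features)
      (carry, y), initial_params = LSTMCell.init(random.PRNGKey(1), carry, x)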
108    """
109    c, h = carry
110    hidden_features = h.shape[-1]
111    # input and recurrent layers are summed so only one needs a bias.
112    dense_h = linear.Dense.partial(
113        inputs=h, features=hidden_features, bias=True,
114        kernel_init=recurrent_kernel_init, bias_init=bias_init)
115    dense_i = linear.Dense.partial(
116        inputs=inputs, features=hidden_features, bias=False,
117        kernel_init=kernel_init)
118    i = gate_fn(dense_i(name='ii') + dense_h(name='hi'))
119    f = gate_fn(dense_i(name='if') + dense_h(name='hf'))
120    g = activation_fn(dense_i(name='ig') + dense_h(name='hg'))
121    o = gate_fn(dense_i(name='io') + dense_h(name='ho'))
122    new_c = f * c + i * g
123    new_h = o * activation_fn(new_c)
124    return (new_c, new_h), new_h
125
  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros):
    """Initialize the RNN cell carry.

    Args:
      rng: random number generator passed to the init_fn.
      batch_dims: a tuple providing the shape of the batch dimensions.
      size: the size or number of features of the memory.
      init_fn: initializer function for the carry.

    Returns:
      An initialized carry for the given RNN cell.
    """
    key1, key2 = random.split(rng)
    mem_shape = batch_dims + (size,)
    return init_fn(key1, mem_shape), init_fn(key2, mem_shape)


class OptimizedLSTMCell(RNNCellBase):
  """DEPRECATION WARNING:
  The `flax.nn` module is deprecated; use `flax.linen` instead.
  Learn more and find an upgrade guide at
  https://github.com/google/flax/blob/master/flax/linen/README.md

  A more efficient LSTM cell that concatenates state components before
  matmul.

  Parameters are compatible with `flax.nn.LSTMCell`.
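
  Example (a sketch of the intended drop-in use; `rng_1`, `rng_2`,
  `batch_size`, `features`, and `x` are placeholders, and the resulting
  parameter names and shapes match those of `flax.nn.LSTMCell`)::

    carry = OptimizedLSTMCell.initialize_carry(
        rng_1, (batch_size,), features)
    (carry, y), params = OptimizedLSTMCell.init(rng_2, carry, x)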
151  """
152
153  class DummyDense(base.Module):
154    """Dummy module for creating parameters matching `flax.nn.Dense`."""
155
156    def apply(self,
157              inputs,
158              features,
159              kernel_init,
160              bias_init,
161              bias=True):
162      k = self.param('kernel', (inputs.shape[-1], features), kernel_init)
163      b = (self.param('bias', (features,), bias_init)
164           if bias else jnp.zeros((features,)))
165      return k, b
166
  def apply(self,
            carry,
            inputs,
            gate_fn=activation.sigmoid,
            activation_fn=activation.tanh,
            kernel_init=linear.default_kernel_init,
            recurrent_kernel_init=initializers.orthogonal(),
            bias_init=initializers.zeros):
    r"""A long short-term memory (LSTM) cell.

    The mathematical definition of the cell is as follows:

    .. math::

        \begin{array}{ll}
        i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
        f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
        o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
        c' = f * c + i * g \\
        h' = o * \tanh(c') \\
        \end{array}

    where x is the input, h is the output of the previous time step, and c is
    the memory.

    Args:
      carry: the hidden state of the LSTM cell, initialized using
        `OptimizedLSTMCell.initialize_carry`.
      inputs: an ndarray with the input for the current time step. All
        dimensions except the final are considered batch dimensions.
      gate_fn: activation function used for gates (default: sigmoid).
      activation_fn: activation function used for output and memory update
        (default: tanh).
      kernel_init: initializer function for the kernels that transform
        the input (default: lecun_normal).
      recurrent_kernel_init: initializer function for the kernels that
        transform the hidden state (default: orthogonal).
      bias_init: initializer for the bias parameters (default: zeros).

    Returns:
      A tuple with the new carry and the output.
    """
    c, h = carry
    hidden_features = h.shape[-1]

    def _concat_dense(inputs, params, use_bias=True):
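      # `params` maps each gate name ('i', 'f', 'g', 'o') to a
      # (kernel, bias) pair created by DummyDense below; concatenating the
      # kernels lets all four gates share a single matmul.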
      kernels, biases = zip(*params.values())
      kernel = jnp.asarray(jnp.concatenate(kernels, axis=-1), jnp.float32)

      y = jnp.dot(inputs, kernel)
      if use_bias:
        bias = jnp.asarray(jnp.concatenate(biases, axis=-1), jnp.float32)
        y = y + bias

      # Split the result back into individual (i, f, g, o) outputs.
      split_indices = np.cumsum([b.shape[0] for b in biases[:-1]])
      ys = jnp.split(y, split_indices, axis=-1)
      return dict(zip(params.keys(), ys))

    # Create the params in the same order as LSTMCell for initialization
    # compatibility.
    dense_params_h = {}
    dense_params_i = {}
    for component in ['i', 'f', 'g', 'o']:
      dense_params_i[component] = OptimizedLSTMCell.DummyDense(
          inputs=inputs, features=hidden_features, bias=False,
          kernel_init=kernel_init, bias_init=bias_init,
          name=f'i{component}')
      dense_params_h[component] = OptimizedLSTMCell.DummyDense(
          inputs=h, features=hidden_features, bias=True,
          kernel_init=recurrent_kernel_init, bias_init=bias_init,
          name=f'h{component}')
    dense_h = _concat_dense(h, dense_params_h, use_bias=True)
    dense_i = _concat_dense(inputs, dense_params_i, use_bias=False)

    i = gate_fn(dense_h['i'] + dense_i['i'])
    f = gate_fn(dense_h['f'] + dense_i['f'])
    g = activation_fn(dense_h['g'] + dense_i['g'])
    o = gate_fn(dense_h['o'] + dense_i['o'])

    new_c = f * c + i * g
    new_h = o * activation_fn(new_c)
    return (new_c, new_h), new_h

  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros):
    """Initialize the RNN cell carry.

    Args:
      rng: random number generator passed to the init_fn.
      batch_dims: a tuple providing the shape of the batch dimensions.
      size: the size or number of features of the memory.
      init_fn: initializer function for the carry.

    Returns:
      An initialized carry for the given RNN cell.
    """
    key1, key2 = random.split(rng)
    mem_shape = batch_dims + (size,)
    return init_fn(key1, mem_shape), init_fn(key2, mem_shape)


class GRUCell(RNNCellBase):
  """DEPRECATION WARNING:
  The `flax.nn` module is deprecated; use `flax.linen` instead.
  Learn more and find an upgrade guide at
  https://github.com/google/flax/blob/master/flax/linen/README.md

  GRU cell."""

  def apply(self, carry, inputs,
            gate_fn=activation.sigmoid, activation_fn=activation.tanh,
            kernel_init=linear.default_kernel_init,
            recurrent_kernel_init=initializers.orthogonal(),
            bias_init=initializers.zeros):
    r"""Gated recurrent unit (GRU) cell.

    The mathematical definition of the cell is as follows:

    .. math::

        \begin{array}{ll}
        r = \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\
        z = \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\
        n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
        h' = (1 - z) * n + z * h
        \end{array}

    where x is the input and h is the output of the previous time step.

    Args:
      carry: the hidden state of the GRU cell,
        initialized using `GRUCell.initialize_carry`.
      inputs: an ndarray with the input for the current time step.
        All dimensions except the final are considered batch dimensions.
      gate_fn: activation function used for gates (default: sigmoid).
      activation_fn: activation function used for output and memory update
        (default: tanh).
      kernel_init: initializer function for the kernels that transform
        the input (default: lecun_normal).
      recurrent_kernel_init: initializer function for the kernels that
        transform the hidden state (default: orthogonal).
      bias_init: initializer for the bias parameters (default: zeros).

    Returns:
      A tuple with the new carry and the output.
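
    Example (a minimal usage sketch; `batch_size`, `features`, and the
    input `x` are illustrative placeholders; unlike the LSTM cells, the
    GRU carry is a single array rather than a tuple)::

      carry = GRUCell.initialize_carry(
          random.PRNGKey(0), (batch_size,), features)
      (carry, y), initial_params = GRUCell.init(random.PRNGKey(1), carry, x)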
306    """
307    h = carry
308    hidden_features = h.shape[-1]
309    # input and recurrent layers are summed so only one needs a bias.
310    dense_h = linear.Dense.partial(
311        inputs=h, features=hidden_features, bias=False,
312        kernel_init=recurrent_kernel_init, bias_init=bias_init)
313    dense_i = linear.Dense.partial(
314        inputs=inputs, features=hidden_features, bias=True,
315        kernel_init=kernel_init, bias_init=bias_init)
316    r = gate_fn(dense_i(name='ir') + dense_h(name='hr'))
317    z = gate_fn(dense_i(name='iz') + dense_h(name='hz'))
318    # add bias because the linear transformations aren't directly summed.
319    n = activation_fn(dense_i(name='in') + r * dense_h(name='hn', bias=True))
320    new_h = (1. - z) * n + z * h
321    return new_h, new_h
322
  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros):
    """Initialize the RNN cell carry.

    Args:
      rng: random number generator passed to the init_fn.
      batch_dims: a tuple providing the shape of the batch dimensions.
      size: the size or number of features of the memory.
      init_fn: initializer function for the carry.

    Returns:
      An initialized carry for the given RNN cell.
    """
    mem_shape = batch_dims + (size,)
    return init_fn(rng, mem_shape)


class ConvLSTM(RNNCellBase):
  r"""DEPRECATION WARNING:
  The `flax.nn` module is deprecated; use `flax.linen` instead.
  Learn more and find an upgrade guide at
  https://github.com/google/flax/blob/master/flax/linen/README.md

  A convolutional LSTM cell.

  The implementation is based on xingjian2015convolutional.
  Given x_t and the previous state (h_{t-1}, c_{t-1}),
  the core computes

  .. math::

     \begin{array}{ll}
     i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\
     f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\
     g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\
     o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\
     c_t = f_t c_{t-1} + i_t g_t \\
     h_t = o_t \tanh(c_t)
     \end{array}

  where * denotes the convolution operator;
  i_t, f_t, and o_t are the input, forget, and output gate activations,
  and g_t is a vector of cell updates.

  Notes:
    Forget gate initialization:
      Following jozefowicz2015empirical, we add 1.0 to b_f
      after initialization in order to reduce the scale of forgetting at
      the beginning of training.
  """

  def apply(self,
            carry,
            inputs,
            features,
            kernel_size,
            strides=None,
            padding='SAME',
            bias=True,
            dtype=jnp.float32):
    """Constructs a convolutional LSTM.

    Args:
      carry: the hidden state of the ConvLSTM cell,
        initialized using `ConvLSTM.initialize_carry`.
      inputs: input data with dimensions (batch, spatial_dims..., features).
      features: number of convolution filters.
      kernel_size: shape of the convolutional kernel.
      strides: a sequence of `n` integers, representing the inter-window
        strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply
        before and after each spatial dimension.
      bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: float32).

    Returns:
      A tuple with the new carry and the output.
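
    Example (a minimal usage sketch with illustrative shapes: 32x32 inputs,
    16 features, and placeholder names `rng_1`, `rng_2`, `batch_size`,
    and `x`)::

      # The carry shape is batch_dims + spatial_dims + (features,).
      carry = ConvLSTM.initialize_carry(rng_1, (batch_size,), (32, 32, 16))
      (carry, y), params = ConvLSTM.init(
          rng_2, carry, x, features=16, kernel_size=(3, 3))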
398    """
399    c, h = carry
400    input_to_hidden = linear.Conv.partial(
401        features=4*features,
402        kernel_size=kernel_size,
403        strides=strides,
404        padding=padding,
405        bias=bias,
406        dtype=dtype,
407        name="ih")
408
409    hidden_to_hidden = linear.Conv.partial(
410        features=4*features,
411        kernel_size=kernel_size,
412        strides=strides,
413        padding=padding,
414        bias=bias,
415        dtype=dtype,
416        name="hh")
417
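    # Each convolution produces all four gates at once; split them along the
    # feature axis. Note the (i, g, f, o) ordering of the split here.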
    gates = input_to_hidden(inputs) + hidden_to_hidden(h)
    i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1)

    # Add 1.0 to the forget-gate preactivation (see the class docstring note
    # on forget gate initialization, following jozefowicz2015empirical).
    f = activation.sigmoid(f + 1)
    new_c = f * c + activation.sigmoid(i) * jnp.tanh(g)
    new_h = activation.sigmoid(o) * jnp.tanh(new_c)
    return (new_c, new_h), new_h

  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros):
    """Initialize the RNN cell carry.

    Args:
      rng: random number generator passed to the init_fn.
      batch_dims: a tuple providing the shape of the batch dimensions.
      size: the shape of the memory, i.e. input_shape + (features,).
      init_fn: initializer function for the carry.

    Returns:
      An initialized carry for the given RNN cell.
    """
    key1, key2 = random.split(rng)
    mem_shape = batch_dims + size
    return init_fn(key1, mem_shape), init_fn(key2, mem_shape)