1# Copyright 2021 The Flax Authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15"""DEPRECATION WARNING: 16 The `flax.nn` module is Deprecated, use `flax.linen` instead. 17 Learn more and find an upgrade guide at 18 https://github.com/google/flax/blob/master/flax/linen/README.md" 19 Recurrent neural network modules. 20 21THe RNNCell modules are designed to fit in with the scan function in JAX:: 22 23 _, initial_params = LSTMCell.init(rng_1, time_series[0]) 24 model = nn.Model(LSTMCell, initial_params) 25 carry = LSTMCell.initialize_carry(rng_2, (batch_size,), memory_size) 26 carry, y = jax.lax.scan(model, carry, time_series) 27 28""" 29 30import abc 31 32from . import activation 33from . import base 34from . import initializers 35from . import linear 36 37from jax import numpy as jnp 38from jax import random 39from jax import lax 40import numpy as np 41 42 43class RNNCellBase(base.Module): 44 """DEPRECATION WARNING: 45 The `flax.nn` module is Deprecated, use `flax.linen` instead. 46 Learn more and find an upgrade guide at 47 https://github.com/google/flax/blob/master/flax/linen/README.md" 48 RNN cell base class.""" 49 50 @staticmethod 51 @abc.abstractmethod 52 def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros): 53 """initialize the RNN cell carry. 54 55 Args: 56 rng: random number generator passed to the init_fn. 57 batch_dims: a tuple providing the shape of the batch dimensions. 58 size: the size or number of features of the memory. 59 init_fn: initializer function for the carry. 60 Returns: 61 An initialized carry for the given RNN cell. 62 """ 63 pass 64 65 66class LSTMCell(RNNCellBase): 67 """DEPRECATION WARNING: 68 The `flax.nn` module is Deprecated, use `flax.linen` instead. 69 Learn more and find an upgrade guide at 70 https://github.com/google/flax/blob/master/flax/linen/README.md" 71 LSTM cell.""" 72 73 def apply(self, carry, inputs, 74 gate_fn=activation.sigmoid, activation_fn=activation.tanh, 75 kernel_init=linear.default_kernel_init, 76 recurrent_kernel_init=initializers.orthogonal(), 77 bias_init=initializers.zeros): 78 r"""A long short-term memory (LSTM) cell. 79 80 the mathematical definition of the cell is as follows 81 .. math:: 82 \begin{array}{ll} 83 i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ 84 f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ 85 g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ 86 o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ 87 c' = f * c + i * g \\ 88 h' = o * \tanh(c') \\ 89 \end{array} 90 where x is the input, h is the output of the previous time step, and c is 91 the memory. 92 93 Args: 94 carry: the hidden state of the LSTM cell, 95 initialized using `LSTMCell.initialize_carry`. 96 inputs: an ndarray with the input for the current time step. 97 All dimensions except the final are considered batch dimensions. 98 gate_fn: activation function used for gates (default: sigmoid) 99 activation_fn: activation function used for output and memory update 100 (default: tanh). 101 kernel_init: initializer function for the kernels that transform 102 the input (default: lecun_normal). 103 recurrent_kernel_init: initializer function for the kernels that transform 104 the hidden state (default: orthogonal). 105 bias_init: initializer for the bias parameters (default: zeros) 106 Returns: 107 A tuple with the new carry and the output. 108 """ 109 c, h = carry 110 hidden_features = h.shape[-1] 111 # input and recurrent layers are summed so only one needs a bias. 112 dense_h = linear.Dense.partial( 113 inputs=h, features=hidden_features, bias=True, 114 kernel_init=recurrent_kernel_init, bias_init=bias_init) 115 dense_i = linear.Dense.partial( 116 inputs=inputs, features=hidden_features, bias=False, 117 kernel_init=kernel_init) 118 i = gate_fn(dense_i(name='ii') + dense_h(name='hi')) 119 f = gate_fn(dense_i(name='if') + dense_h(name='hf')) 120 g = activation_fn(dense_i(name='ig') + dense_h(name='hg')) 121 o = gate_fn(dense_i(name='io') + dense_h(name='ho')) 122 new_c = f * c + i * g 123 new_h = o * activation_fn(new_c) 124 return (new_c, new_h), new_h 125 126 @staticmethod 127 def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros): 128 """initialize the RNN cell carry. 129 130 Args: 131 rng: random number generator passed to the init_fn. 132 batch_dims: a tuple providing the shape of the batch dimensions. 133 size: the size or number of features of the memory. 134 init_fn: initializer function for the carry. 135 Returns: 136 An initialized carry for the given RNN cell. 137 """ 138 key1, key2 = random.split(rng) 139 mem_shape = batch_dims + (size,) 140 return init_fn(key1, mem_shape), init_fn(key2, mem_shape) 141 142 143class OptimizedLSTMCell(RNNCellBase): 144 """DEPRECATION WARNING: 145 The `flax.nn` module is Deprecated, use `flax.linen` instead. 146 Learn more and find an upgrade guide at 147 https://github.com/google/flax/blob/master/flax/linen/README.md" 148 More efficient LSTM Cell that concatenates state components before matmul. 149 150 Parameters are compatible with `flax.nn.LSTMCell`. 151 """ 152 153 class DummyDense(base.Module): 154 """Dummy module for creating parameters matching `flax.nn.Dense`.""" 155 156 def apply(self, 157 inputs, 158 features, 159 kernel_init, 160 bias_init, 161 bias=True): 162 k = self.param('kernel', (inputs.shape[-1], features), kernel_init) 163 b = (self.param('bias', (features,), bias_init) 164 if bias else jnp.zeros((features,))) 165 return k, b 166 167 def apply(self, 168 carry, 169 inputs, 170 gate_fn=activation.sigmoid, 171 activation_fn=activation.tanh, 172 kernel_init=linear.default_kernel_init, 173 recurrent_kernel_init=initializers.orthogonal(), 174 bias_init=initializers.zeros): 175 r"""A long short-term memory (LSTM) cell. 176 177 the mathematical definition of the cell is as follows 178 .. math:: 179 \begin{array}{ll} 180 i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ 181 f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ 182 g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ 183 o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ 184 c' = f * c + i * g \\ 185 h' = o * \tanh(c') \\ 186 \end{array} 187 where x is the input, h is the output of the previous time step, and c is 188 the memory. 189 190 Args: 191 carry: the hidden state of the LSTM cell, initialized using 192 `LSTMCell.initialize_carry`. 193 inputs: an ndarray with the input for the current time step. All 194 dimensions except the final are considered batch dimensions. 195 gate_fn: activation function used for gates (default: sigmoid) 196 activation_fn: activation function used for output and memory update 197 (default: tanh). 198 kernel_init: initializer function for the kernels that transform 199 the input (default: lecun_normal). 200 recurrent_kernel_init: initializer function for the kernels that transform 201 the hidden state (default: orthogonal). 202 bias_init: initializer for the bias parameters (default: zeros) 203 204 Returns: 205 A tuple with the new carry and the output. 206 """ 207 c, h = carry 208 hidden_features = h.shape[-1] 209 210 def _concat_dense(inputs, params, use_bias=True): 211 kernels, biases = zip(*params.values()) 212 kernel = jnp.asarray(jnp.concatenate(kernels, axis=-1), jnp.float32) 213 214 y = jnp.dot(inputs, kernel) 215 if use_bias: 216 bias = jnp.asarray(jnp.concatenate(biases, axis=-1), jnp.float32) 217 y = y + bias 218 219 # Split the result back into individual (i, f, g, o) outputs. 220 split_indices = np.cumsum([b.shape[0] for b in biases[:-1]]) 221 ys = jnp.split(y, split_indices, axis=-1) 222 return dict(zip(params.keys(), ys)) 223 224 # Create the params in the same order as LSTMCell for initialization 225 # compatibility. 226 dense_params_h = {} 227 dense_params_i = {} 228 for component in ['i', 'f', 'g', 'o']: 229 dense_params_i[component] = OptimizedLSTMCell.DummyDense( 230 inputs=inputs, features=hidden_features, bias=False, 231 kernel_init=kernel_init, bias_init=bias_init, 232 name=f'i{component}') 233 dense_params_h[component] = OptimizedLSTMCell.DummyDense( 234 inputs=h, features=hidden_features, bias=True, 235 kernel_init=recurrent_kernel_init, bias_init=bias_init, 236 name=f'h{component}') 237 dense_h = _concat_dense(h, dense_params_h, use_bias=True) 238 dense_i = _concat_dense(inputs, dense_params_i, use_bias=False) 239 240 i = gate_fn(dense_h['i'] + dense_i['i']) 241 f = gate_fn(dense_h['f'] + dense_i['f']) 242 g = activation_fn(dense_h['g'] + dense_i['g']) 243 o = gate_fn(dense_h['o'] + dense_i['o']) 244 245 new_c = f * c + i * g 246 new_h = o * activation_fn(new_c) 247 return (new_c, new_h), new_h 248 249 @staticmethod 250 def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros): 251 """initialize the RNN cell carry. 252 253 Args: 254 rng: random number generator passed to the init_fn. 255 batch_dims: a tuple providing the shape of the batch dimensions. 256 size: the size or number of features of the memory. 257 init_fn: initializer function for the carry. 258 259 Returns: 260 An initialized carry for the given RNN cell. 261 """ 262 key1, key2 = random.split(rng) 263 mem_shape = batch_dims + (size,) 264 return init_fn(key1, mem_shape), init_fn(key2, mem_shape) 265 266 267class GRUCell(RNNCellBase): 268 """DEPRECATION WARNING: 269 The `flax.nn` module is Deprecated, use `flax.linen` instead. 270 Learn more and find an upgrade guide at 271 https://github.com/google/flax/blob/master/flax/linen/README.md" 272 GRU cell.""" 273 274 def apply(self, carry, inputs, 275 gate_fn=activation.sigmoid, activation_fn=activation.tanh, 276 kernel_init=linear.default_kernel_init, 277 recurrent_kernel_init=initializers.orthogonal(), 278 bias_init=initializers.zeros): 279 r"""Gated recurrent unit (GRU) cell. 280 281 the mathematical definition of the cell is as follows 282 .. math:: 283 \begin{array}{ll} 284 r = \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\ 285 z = \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\ 286 n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ 287 h' = (1 - z) * n + z * h 288 \end{array} 289 where x is the input and h, is the output of the previous time step. 290 291 Args: 292 carry: the hidden state of the LSTM cell, 293 initialized using `GRUCell.initialize_carry`. 294 inputs: an ndarray with the input for the current time step. 295 All dimensions except the final are considered batch dimensions. 296 gate_fn: activation function used for gates (default: sigmoid) 297 activation_fn: activation function used for output and memory update 298 (default: tanh). 299 kernel_init: initializer function for the kernels that transform 300 the input (default: lecun_normal). 301 recurrent_kernel_init: initializer function for the kernels that transform 302 the hidden state (default: orthogonal). 303 bias_init: initializer for the bias parameters (default: zeros) 304 Returns: 305 A tuple with the new carry and the output. 306 """ 307 h = carry 308 hidden_features = h.shape[-1] 309 # input and recurrent layers are summed so only one needs a bias. 310 dense_h = linear.Dense.partial( 311 inputs=h, features=hidden_features, bias=False, 312 kernel_init=recurrent_kernel_init, bias_init=bias_init) 313 dense_i = linear.Dense.partial( 314 inputs=inputs, features=hidden_features, bias=True, 315 kernel_init=kernel_init, bias_init=bias_init) 316 r = gate_fn(dense_i(name='ir') + dense_h(name='hr')) 317 z = gate_fn(dense_i(name='iz') + dense_h(name='hz')) 318 # add bias because the linear transformations aren't directly summed. 319 n = activation_fn(dense_i(name='in') + r * dense_h(name='hn', bias=True)) 320 new_h = (1. - z) * n + z * h 321 return new_h, new_h 322 323 @staticmethod 324 def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros): 325 """initialize the RNN cell carry. 326 327 Args: 328 rng: random number generator passed to the init_fn. 329 batch_dims: a tuple providing the shape of the batch dimensions. 330 size: the size or number of features of the memory. 331 init_fn: initializer function for the carry. 332 Returns: 333 An initialized carry for the given RNN cell. 334 """ 335 mem_shape = batch_dims + (size,) 336 return init_fn(rng, mem_shape) 337 338 339class ConvLSTM(RNNCellBase): 340 r"""DEPRECATION WARNING: 341 The `flax.nn` module is Deprecated, use `flax.linen` instead. 342 Learn more and find an upgrade guide at 343 https://github.com/google/flax/blob/master/flax/linen/README.md" 344 A convolutional LSTM cell. 345 346 The implementation is based on xingjian2015convolutional. 347 Given x_t and the previous state (h_{t-1}, c_{t-1}) 348 the core computes 349 350 .. math:: 351 352 \begin{array}{ll} 353 i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ 354 f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ 355 g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ 356 o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ 357 c_t = f_t c_{t-1} + i_t g_t \\ 358 h_t = o_t \tanh(c_t) 359 \end{array} 360 361 where * denotes the convolution operator; 362 i_t, f_t, o_t are input, forget and output gate activations, 363 and g_t is a vector of cell updates. 364 365 Notes: 366 Forget gate initialization: 367 Following jozefowicz2015empirical we add 1.0 to b_f 368 after initialization in order to reduce the scale of forgetting in 369 the beginning of the training. 370 """ 371 372 def apply(self, 373 carry, 374 inputs, 375 features, 376 kernel_size, 377 strides=None, 378 padding='SAME', 379 bias=True, 380 dtype=jnp.float32): 381 """Constructs a convolutional LSTM. 382 383 Args: 384 carry: the hidden state of the Conv2DLSTM cell, 385 initialized using `Conv2DLSTM.initialize_carry`. 386 inputs: input data with dimensions (batch, spatial_dims..., features). 387 features: number of convolution filters. 388 kernel_size: shape of the convolutional kernel. 389 strides: a sequence of `n` integers, representing the inter-window 390 strides. 391 padding: either the string `'SAME'`, the string `'VALID'`, or a sequence 392 of `n` `(low, high)` integer pairs that give the padding to apply before 393 and after each spatial dimension. 394 bias: whether to add a bias to the output (default: True). 395 dtype: the dtype of the computation (default: float32). 396 Returns: 397 A tuple with the new carry and the output. 398 """ 399 c, h = carry 400 input_to_hidden = linear.Conv.partial( 401 features=4*features, 402 kernel_size=kernel_size, 403 strides=strides, 404 padding=padding, 405 bias=bias, 406 dtype=dtype, 407 name="ih") 408 409 hidden_to_hidden = linear.Conv.partial( 410 features=4*features, 411 kernel_size=kernel_size, 412 strides=strides, 413 padding=padding, 414 bias=bias, 415 dtype=dtype, 416 name="hh") 417 418 gates = input_to_hidden(inputs) + hidden_to_hidden(h) 419 i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1) 420 421 f = activation.sigmoid(f + 1) 422 new_c = f * c + activation.sigmoid(i) * jnp.tanh(g) 423 new_h = activation.sigmoid(o) * jnp.tanh(new_c) 424 return (new_c, new_h), new_h 425 426 @staticmethod 427 def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros): 428 """initialize the RNN cell carry. 429 430 Args: 431 rng: random number generator passed to the init_fn. 432 batch_dims: a tuple providing the shape of the batch dimensions. 433 size: the input_shape + (features,). 434 init_fn: initializer function for the carry. 435 Returns: 436 An initialized carry for the given RNN cell. 437 """ 438 key1, key2 = random.split(rng) 439 mem_shape = batch_dims + size 440 return init_fn(key1, mem_shape), init_fn(key2, mem_shape) 441