1# Copyright 2021 The Flax Authors.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Linear modules."""
17from collections.abc import Iterable  # pylint: disable=g-importing-member
19from flax.nn import initializers
21from flax.core import Scope
23from flax import struct
25from jax import lax
27import jax.numpy as jnp
28import numpy as np
# Default initializer (LeCun normal) for the `kernel` parameter of
# `dense_general`, `dense`, `conv` and `conv_transpose`.
default_kernel_init = initializers.lecun_normal()
34def _normalize_axes(axes, ndim):
35  # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
36  return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
def dense_general(
    scope,
    inputs,
    features,
    axis=-1,
    batch_dims=(),
    bias=True,
    dtype=jnp.float32,
    kernel_init=default_kernel_init,
    bias_init=initializers.zeros,
    precision=None):
  """Applies a linear transformation to the inputs along multiple dimensions.

  Args:
    scope: functional scope used to create/fetch the `kernel` and `bias`
      parameters.
    inputs: The nd-array to be transformed.
    features: int or tuple with numbers of output features.
    axis: int or tuple with axes to apply the transformation on.
    batch_dims: int or tuple with batch axes. Must be consecutive leading
      dimensions starting from 0.
    bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
  Returns:
    The transformed input. Following `lax.dot_general`, its dimensions are the
    batch dimensions, then the remaining (non-contracted, non-batch) dimensions
    of `inputs` in order, then `features`.
  Raises:
    ValueError: if `batch_dims` are not consecutive leading dimensions
      starting from 0.
  """
  inputs = jnp.asarray(inputs, dtype)

  if not isinstance(features, Iterable):
    features = (features,)
  if not isinstance(axis, Iterable):
    axis = (axis,)
  if not isinstance(batch_dims, Iterable):
    batch_dims = (batch_dims,)
  features, axis, batch_dims = tuple(features), tuple(axis), tuple(batch_dims)

  if batch_dims:
    max_dim = np.max(batch_dims)
    if set(batch_dims) != set(range(max_dim + 1)):
      raise ValueError('batch_dims %s must be consecutive leading '
                       'dimensions starting from 0.' % str(batch_dims))

  ndim = inputs.ndim
  n_batch_dims = len(batch_dims)
  axis = _normalize_axes(axis, ndim)
  batch_dims = _normalize_axes(batch_dims, ndim)
  n_axis, n_features = len(axis), len(features)

  def kernel_init_wrap(rng, shape, dtype=jnp.float32):
    # Builds one flat (contracted, features) matrix per batch slice, stacks
    # them and reshapes to the full kernel shape.
    # NOTE(review): the same `rng` is passed for every batch slice, so all
    # batch slices start from identical initial values.
    size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
    flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]),
                  np.prod(shape[-n_features:]),)
    kernel = jnp.concatenate([kernel_init(rng, flat_shape, dtype)
                              for _ in range(size_batch_dims)], axis=0)
    return jnp.reshape(kernel, shape)

  batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
  kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
  kernel = scope.param('kernel', kernel_init_wrap, batch_shape + kernel_shape)
  kernel = jnp.asarray(kernel, dtype)

  # Kernel layout is (batch axes, contracted axes, feature axes).
  batch_ind = tuple(range(n_batch_dims))
  contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
  out = lax.dot_general(inputs,
                        kernel,
                        ((axis, contract_ind), (batch_dims, batch_ind)),
                        precision=precision)
  if bias:
    def bias_init_wrap(rng, shape, dtype=jnp.float32):
      size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
      flat_shape = (np.prod(shape[-n_features:]),)
      bias = jnp.concatenate([bias_init(rng, flat_shape, dtype)
                              for _ in range(size_batch_dims)], axis=0)
      return jnp.reshape(bias, shape)

    bias = scope.param('bias', bias_init_wrap, batch_shape + features)

    # Reshape bias for broadcast against `out`, whose layout is
    # (batch dims, free input dims, features). Batch dims keep their size and
    # every free input dim gets size 1, wherever `axis` sits in `inputs`.
    # (Expanding at the free dims' *input* positions would misalign the bias
    # whenever `axis` is not the trailing dimensions of `inputs`.)
    n_free_dims = ndim - n_axis - n_batch_dims
    bias = jnp.reshape(bias, batch_shape + (1,) * n_free_dims + features)
    bias = jnp.asarray(bias, dtype)
    out = out + bias
  return out
def dense(scope,
          inputs,
          features,
          bias=True,
          dtype=jnp.float32,
          precision=None,
          kernel_init=default_kernel_init,
          bias_init=initializers.zeros):
  """Applies a linear transformation to the inputs along the last dimension.

  The kernel has shape `(inputs.shape[-1], features)` and is contracted
  against the last axis of `inputs`; the optional bias has shape
  `(features,)` and is added to the result.

  Args:
    scope: functional scope used to create/fetch the `kernel` and `bias`
      parameters.
    inputs: The nd-array to be transformed.
    features: the number of output features.
    bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
  Returns:
    The transformed input.
  """
  x = jnp.asarray(inputs, dtype)
  last_axis = x.ndim - 1

  kernel = jnp.asarray(
      scope.param('kernel', kernel_init, (x.shape[-1], features)), dtype)

  # Contract the last input axis with the kernel's first axis; no batch axes.
  contracting_dims = ((last_axis,), (0,))
  no_batch_dims = ((), ())
  y = lax.dot_general(x, kernel, (contracting_dims, no_batch_dims),
                      precision=precision)

  if not bias:
    return y
  bias_param = jnp.asarray(scope.param('bias', bias_init, (features,)), dtype)
  return y + bias_param
162def _conv_dimension_numbers(input_shape):
163  """Computes the dimension numbers based on the input shape."""
164  ndim = len(input_shape)
165  lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
166  rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
167  out_spec = lhs_spec
168  return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
def conv(scope,
         inputs,
         features,
         kernel_size,
         strides=None,
         padding='SAME',
         input_dilation=None,
         kernel_dilation=None,
         feature_group_count=1,
         bias=True,
         dtype=jnp.float32,
         precision=None,
         kernel_init=default_kernel_init,
         bias_init=initializers.zeros):
  """Applies a convolution to the inputs.

  Args:
    scope: functional scope used to create/fetch the `kernel` and `bias`
      parameters.
    inputs: input data with dimensions (batch, spatial_dims..., features).
    features: number of convolution filters.
    kernel_size: sequence (tuple or list) of `n` integers giving the shape of
      the convolutional kernel.
    strides: a sequence of `n` integers, representing the inter-window
      strides. Defaults to 1 in every spatial dimension when `None`.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
      of `n` `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    input_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `inputs`.
      Convolution with input dilation `d` is equivalent to transposed
      convolution with stride `d`.
    kernel_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of the convolution
      kernel. Convolution with kernel dilation is also known as 'atrous
      convolution'.
    feature_group_count: integer, default 1. If specified divides the input
      features into groups.
    bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the convolutional kernel.
    bias_init: initializer for the bias.
  Returns:
    The convolved data.
  Raises:
    ValueError: if the number of input features is not divisible by
      `feature_group_count`.
  """

  inputs = jnp.asarray(inputs, dtype)

  if strides is None:
    strides = (1,) * (inputs.ndim - 2)

  in_features = inputs.shape[-1]
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if in_features % feature_group_count != 0:
    raise ValueError(
        'Number of input features (%d) must be divisible by '
        'feature_group_count (%d).' % (in_features, feature_group_count))
  # tuple() so list-valued kernel sizes (e.g. [3, 3]) are accepted as well.
  # Kernel layout: (spatial..., in_features per group, features).
  kernel_shape = tuple(kernel_size) + (in_features // feature_group_count,
                                       features)
  kernel = scope.param('kernel', kernel_init, kernel_shape)
  kernel = jnp.asarray(kernel, dtype)

  dimension_numbers = _conv_dimension_numbers(inputs.shape)
  y = lax.conv_general_dilated(
      inputs,
      kernel,
      strides,
      padding,
      lhs_dilation=input_dilation,
      rhs_dilation=kernel_dilation,
      dimension_numbers=dimension_numbers,
      feature_group_count=feature_group_count,
      precision=precision)

  if bias:
    bias = scope.param('bias', bias_init, (features,))
    bias = jnp.asarray(bias, dtype)
    y = y + bias
  return y
def conv_transpose(scope,
                   inputs,
                   features,
                   kernel_size,
                   strides=None,
                   padding='SAME',
                   kernel_dilation=None,
                   bias=True,
                   dtype=jnp.float32,
                   precision=None,
                   kernel_init=default_kernel_init,
                   bias_init=initializers.zeros):
  """Applies a transposed convolution to the inputs.

  Behaviour mirrors that of `jax.lax.conv_transpose`.

  Args:
    scope: functional scope used to create/fetch the `kernel` and `bias`
      parameters.
    inputs: input data with dimensions (batch, spatial_dims..., features).
    features: number of convolution filters.
    kernel_size: sequence (tuple or list) of `n` integers giving the shape of
      the convolutional kernel.
    strides: a sequence of `n` integers, representing the inter-window
      strides. Any falsy value (`None` or an empty sequence) means 1 in every
      spatial dimension.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
      of `n` `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    kernel_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of the convolution
      kernel. Convolution with kernel dilation is also known as 'atrous
      convolution'.
    bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the convolutional kernel.
    bias_init: initializer for the bias.
  Returns:
    The convolved data.
  """
  inputs = jnp.asarray(inputs, dtype)
  strides = strides or (1,) * (inputs.ndim - 2)

  in_features = inputs.shape[-1]
  # tuple() so list-valued kernel sizes (e.g. [3, 3]) are accepted as well.
  kernel_shape = tuple(kernel_size) + (in_features, features)
  kernel = scope.param('kernel', kernel_init, kernel_shape)
  kernel = jnp.asarray(kernel, dtype)

  y = lax.conv_transpose(inputs, kernel, strides, padding,
                         rhs_dilation=kernel_dilation, precision=precision)

  if bias:
    bias = scope.param('bias', bias_init, (features,))
    bias = jnp.asarray(bias, dtype)
    y = y + bias
  return y
# Default initializer for the table created by `embedding`: variance scaling
# with scale 1.0, 'fan_in' mode, a normal distribution and `out_axis=0`.
default_embed_init = initializers.variance_scaling(1.0, 'fan_in', 'normal',
                                                   out_axis=0)
class Embedding:
  """Embedding table with `lookup` and `attend` operations.

  NOTE(review): `embedding` builds this as `Embedding(table)`, which needs a
  generated `__init__` (e.g. a `@struct.dataclass` decorator; `struct` is
  imported in this file). Confirm the decorator is present.
  """
  # The embedding matrix; `embedding` creates it with shape
  # (num_embeddings, features).
  table: np.ndarray

  def lookup(self, indices):
    """Embeds the inputs along the last dimension.

    Args:
      indices: array of signed or unsigned integer indices into the first
        axis of `table`; all dimensions are considered batch dimensions.

    Returns:
      Output which is embedded input data.  The output shape follows the input,
      with an additional `features` dimension appended.

    Raises:
      ValueError: if `indices` does not have an integer dtype.
    """
    # Any integer width is valid for indexing (int8..int64, uint8..uint64),
    # not just the 32/64-bit types.
    if not jnp.issubdtype(indices.dtype, jnp.integer):
      raise ValueError('Input type must be an integer or unsigned integer.')
    return self.table[indices]

  def attend(self, query):
    """Attend over the embedding using a query array.

    Args:
      query: array with last dimension equal the feature depth `features` of the
        embedding.

    Returns:
      An array with final dim `num_embeddings` corresponding to the batched
      inner-product of the array of query vectors against each embedding.
      Commonly used for weight-sharing between embeddings and logit transform
      in NLP models.
    """
    return jnp.dot(query, self.table.T)
def embedding(scope: Scope,
              num_embeddings: int,
              features: int,
              init_fn=default_embed_init) -> Embedding:
  """Creates an `Embedding` whose table is a parameter of `scope`.

  Args:
    scope: functional scope that owns the `table` parameter.
    num_embeddings: number of embeddings.
    features: number of feature dimensions for each embedding.
    init_fn: initializer for the embedding table, passed to `scope.param`
      together with the shape `(num_embeddings, features)`. Defaults to
      `default_embed_init`.

  Returns:
    An `Embedding` wrapping the table, with `lookup` and `attend` methods.
  """
  table = scope.param('table', init_fn, (num_embeddings, features))
  return Embedding(table)