1# Copyright 2021 The Flax Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Linear modules."""
16
17from collections.abc import Iterable  # pylint: disable=g-importing-member
18
19from flax.nn import initializers
20
21from flax.core import Scope
22
23from flax import struct
24
25from jax import lax
26
27import jax.numpy as jnp
28import numpy as np
29
30
31default_kernel_init = initializers.lecun_normal()
32
33
34def _normalize_axes(axes, ndim):
35  # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
36  return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
37
38
39def dense_general(
40    scope,
41    inputs,
42    features,
43    axis=-1,
44    batch_dims=(),
45    bias=True,
46    dtype=jnp.float32,
47    kernel_init=default_kernel_init,
48    bias_init=initializers.zeros,
49    precision=None):
50  """Applies a linear transformation to the inputs along multiple dimensions.
51
52  Args:
53    inputs: The nd-array to be transformed.
54    features: tuple with numbers of output features.
55    axis: tuple with axes to apply the transformation on.
56    batch_dims: tuple with batch axes.
57    bias: whether to add a bias to the output (default: True).
58    dtype: the dtype of the computation (default: float32).
59    kernel_init: initializer function for the weight matrix.
60    bias_init: initializer function for the bias.
61    precision: numerical precision of the computation see `jax.lax.Precision`
62      for details.
63  Returns:
64    The transformed input.
65  """
66  inputs = jnp.asarray(inputs, dtype)
67
68  if not isinstance(features, Iterable):
69    features = (features,)
70  if not isinstance(axis, Iterable):
71    axis = (axis,)
72  if not isinstance(batch_dims, Iterable):
73    batch_dims = (batch_dims,)
74  features, axis, batch_dims = tuple(features), tuple(axis), tuple(batch_dims)
75
76  if batch_dims:
77    max_dim = np.max(batch_dims)
78    if set(batch_dims) != set(range(max_dim + 1)):
79      raise ValueError('batch_dims %s must be consecutive leading '
80                        'dimensions starting from 0.' % str(batch_dims))
81
82  ndim = inputs.ndim
83  n_batch_dims = len(batch_dims)
84  axis = _normalize_axes(axis, ndim)
85  batch_dims = _normalize_axes(batch_dims, ndim)
86  n_axis, n_features = len(axis), len(features)
87
88  def kernel_init_wrap(rng, shape, dtype=jnp.float32):
89    size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
90    flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]),
91                  np.prod(shape[-n_features:]),)
92    kernel = jnp.concatenate([kernel_init(rng, flat_shape, dtype)
93                              for _ in range(size_batch_dims)], axis=0)
94    return jnp.reshape(kernel, shape)
95
96  batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
97  kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
98  kernel = scope.param('kernel', kernel_init_wrap, batch_shape + kernel_shape)
99  kernel = jnp.asarray(kernel, dtype)
100
101  batch_ind = tuple(range(n_batch_dims))
102  contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
103  out = lax.dot_general(inputs,
104                        kernel,
105                        ((axis, contract_ind), (batch_dims, batch_ind)),
106                        precision=precision)
107  if bias:
108    def bias_init_wrap(rng, shape, dtype=jnp.float32):
109      size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
110      flat_shape = (np.prod(shape[-n_features:]),)
111      bias = jnp.concatenate([bias_init(rng, flat_shape, dtype)
112                              for _ in range(size_batch_dims)], axis=0)
113      return jnp.reshape(bias, shape)
114
115    bias = scope.param('bias', bias_init_wrap, batch_shape + features)
116
117    # Reshape bias for broadcast.
118    expand_dims = sorted(
119        set(range(inputs.ndim)) - set(axis) - set(batch_dims))
120    for ax in expand_dims:
121      bias = jnp.expand_dims(bias, ax)
122    bias = jnp.asarray(bias, dtype)
123    out = out + bias
124  return out
125
126
127def dense(scope,
128          inputs,
129          features,
130          bias=True,
131          dtype=jnp.float32,
132          precision=None,
133          kernel_init=default_kernel_init,
134          bias_init=initializers.zeros):
135  """Applies a linear transformation to the inputs along the last dimension.
136
137  Args:
138    inputs: The nd-array to be transformed.
139    features: the number of output features.
140    bias: whether to add a bias to the output (default: True).
141    dtype: the dtype of the computation (default: float32).
142    precision: numerical precision of the computation see `jax.lax.Precision`
143      for details.
144    kernel_init: initializer function for the weight matrix.
145    bias_init: initializer function for the bias.
146  Returns:
147    The transformed input.
148  """
149  inputs = jnp.asarray(inputs, dtype)
150  kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features))
151  kernel = jnp.asarray(kernel, dtype)
152  y = lax.dot_general(inputs, kernel,
153                      (((inputs.ndim - 1,), (0,)), ((), ())),
154                      precision=precision)
155  if bias:
156    bias = scope.param('bias', bias_init, (features,))
157    bias = jnp.asarray(bias, dtype)
158    y = y + bias
159  return y
160
161
162def _conv_dimension_numbers(input_shape):
163  """Computes the dimension numbers based on the input shape."""
164  ndim = len(input_shape)
165  lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
166  rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
167  out_spec = lhs_spec
168  return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
169
170
171def conv(scope,
172         inputs,
173         features,
174         kernel_size,
175         strides=None,
176         padding='SAME',
177         input_dilation=None,
178         kernel_dilation=None,
179         feature_group_count=1,
180         bias=True,
181         dtype=jnp.float32,
182         precision=None,
183         kernel_init=default_kernel_init,
184         bias_init=initializers.zeros):
185  """Applies a convolution to the inputs.
186
187  Args:
188    inputs: input data with dimensions (batch, spatial_dims..., features).
189    features: number of convolution filters.
190    kernel_size: shape of the convolutional kernel.
191    strides: a sequence of `n` integers, representing the inter-window
192      strides.
193    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
194      of `n` `(low, high)` integer pairs that give the padding to apply before
195      and after each spatial dimension.
196    input_dilation: `None`, or a sequence of `n` integers, giving the
197      dilation factor to apply in each spatial dimension of `inputs`.
198      Convolution with input dilation `d` is equivalent to transposed
199      convolution with stride `d`.
200    kernel_dilation: `None`, or a sequence of `n` integers, giving the
201      dilation factor to apply in each spatial dimension of the convolution
202      kernel. Convolution with kernel dilation is also known as 'atrous
203      convolution'.
204    feature_group_count: integer, default 1. If specified divides the input
205      features into groups.
206    bias: whether to add a bias to the output (default: True).
207    dtype: the dtype of the computation (default: float32).
208    precision: numerical precision of the computation see `jax.lax.Precision`
209      for details.
210    kernel_init: initializer for the convolutional kernel.
211    bias_init: initializer for the bias.
212  Returns:
213    The convolved data.
214  """
215
216  inputs = jnp.asarray(inputs, dtype)
217
218  if strides is None:
219    strides = (1,) * (inputs.ndim - 2)
220
221  in_features = inputs.shape[-1]
222  assert in_features % feature_group_count == 0
223  kernel_shape = kernel_size + (in_features // feature_group_count, features)
224  kernel = scope.param('kernel', kernel_init, kernel_shape)
225  kernel = jnp.asarray(kernel, dtype)
226
227  dimension_numbers = _conv_dimension_numbers(inputs.shape)
228  y = lax.conv_general_dilated(
229      inputs,
230      kernel,
231      strides,
232      padding,
233      lhs_dilation=input_dilation,
234      rhs_dilation=kernel_dilation,
235      dimension_numbers=dimension_numbers,
236      feature_group_count=feature_group_count,
237      precision=precision)
238
239  if bias:
240    bias = scope.param('bias', bias_init, (features,))
241    bias = jnp.asarray(bias, dtype)
242    y = y + bias
243  return y
244
245
246def conv_transpose(scope,
247                   inputs,
248                   features,
249                   kernel_size,
250                   strides=None,
251                   padding='SAME',
252                   kernel_dilation=None,
253                   bias=True,
254                   dtype=jnp.float32,
255                   precision=None,
256                   kernel_init=default_kernel_init,
257                   bias_init=initializers.zeros):
258  """Applies a transposed convolution to the inputs. Behaviour mirrors that of
259  `jax.lax.conv_transpose`.
260
261  Args:
262    scope: functional scope.
263    inputs: input data with dimensions (batch, spatial_dims..., features).
264    features: number of convolution filters.
265    kernel_size: shape of the convolutional kernel.
266    strides: a sequence of `n` integers, representing the inter-window
267      strides.
268    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
269      of `n` `(low, high)` integer pairs that give the padding to apply before
270      and after each spatial dimension.
271    kernel_dilation: `None`, or a sequence of `n` integers, giving the
272      dilation factor to apply in each spatial dimension of the convolution
273      kernel. Convolution with kernel dilation is also known as 'atrous
274      convolution'.
275    bias: whether to add a bias to the output (default: True).
276    dtype: the dtype of the computation (default: float32).
277    precision: numerical precision of the computation see `jax.lax.Precision`
278      for details.
279    kernel_init: initializer for the convolutional kernel.
280    bias_init: initializer for the bias.
281  Returns:
282    The convolved data.
283  """
284  inputs = jnp.asarray(inputs, dtype)
285  strides = strides or (1,) * (inputs.ndim - 2)
286
287  in_features = inputs.shape[-1]
288  kernel_shape = kernel_size + (in_features, features)
289  kernel = scope.param('kernel', kernel_init, kernel_shape)
290  kernel = jnp.asarray(kernel, dtype)
291
292  y = lax.conv_transpose(inputs, kernel, strides, padding,
293                          rhs_dilation=kernel_dilation, precision=precision)
294
295  if bias:
296    bias = scope.param('bias', bias_init, (features,))
297    bias = jnp.asarray(bias, dtype)
298    y = y + bias
299  return y
300
301
302default_embed_init = initializers.variance_scaling(1.0, 'fan_in', 'normal',
303                                                   out_axis=0)
304
305
306@struct.dataclass
307class Embedding:
308  table: np.ndarray
309
310  def lookup(self, indices):
311    """Embeds the inputs along the last dimension.
312
313    Args:
314      indices: input data, all dimensions are considered batch dimensions.
315
316    Returns:
317      Output which is embedded input data.  The output shape follows the input,
318      with an additional `features` dimension appended.
319    """
320    if indices.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
321      raise ValueError('Input type must be an integer or unsigned integer.')
322    return self.table[indices]
323
324  def attend(self, query):
325    """Attend over the embedding using a query array.
326
327    Args:
328      query: array with last dimension equal the feature depth `features` of the
329        embedding.
330
331    Returns:
332      An array with final dim `num_embeddings` corresponding to the batched
333      inner-product of the array of query vectors against each embedding.
334      Commonly used for weight-sharing between embeddings and logit transform
335      in NLP models.
336    """
337    return jnp.dot(query, self.table.T)
338
339
340def embedding(scope: Scope, num_embeddings: int, features: int, init_fn=default_embed_init) -> Embedding:
341  """Creates embedding dataclass.
342
343    Args:
344      num_embeddings: number of embeddings.
345      features: Number of feature dimensions for each embedding.
346      embedding_init: embedding initializer.
347
348    Returns:
349      Embedding dataclass with lookup and attend methods.
350  """
351  table = scope.param('table', init_fn, (num_embeddings, features))
352  return Embedding(table)
353