# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Linear modules."""

from collections.abc import Iterable  # pylint: disable=g-importing-member

from flax.nn import initializers
from flax.core import Scope
from flax import struct

from jax import lax
import jax.numpy as jnp
import numpy as np


default_kernel_init = initializers.lecun_normal()


def _normalize_axes(axes, ndim):
  # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
  return tuple([ax if ax >= 0 else ndim + ax for ax in axes])


def dense_general(
    scope,
    inputs,
    features,
    axis=-1,
    batch_dims=(),
    bias=True,
    dtype=jnp.float32,
    kernel_init=default_kernel_init,
    bias_init=initializers.zeros,
    precision=None):
  """Applies a linear transformation to the inputs along multiple dimensions.

  Args:
    scope: functional scope.
    inputs: The nd-array to be transformed.
    features: tuple with numbers of output features.
    axis: tuple with axes to apply the transformation on.
    batch_dims: tuple with batch axes.
    bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
  Returns:
    The transformed input.
  """
  inputs = jnp.asarray(inputs, dtype)

  if not isinstance(features, Iterable):
    features = (features,)
  if not isinstance(axis, Iterable):
    axis = (axis,)
  if not isinstance(batch_dims, Iterable):
    batch_dims = (batch_dims,)
  features, axis, batch_dims = tuple(features), tuple(axis), tuple(batch_dims)

  if batch_dims:
    max_dim = np.max(batch_dims)
    if set(batch_dims) != set(range(max_dim + 1)):
      raise ValueError('batch_dims %s must be consecutive leading '
                       'dimensions starting from 0.' % str(batch_dims))

  ndim = inputs.ndim
  n_batch_dims = len(batch_dims)
  axis = _normalize_axes(axis, ndim)
  batch_dims = _normalize_axes(batch_dims, ndim)
  n_axis, n_features = len(axis), len(features)

  def kernel_init_wrap(rng, shape, dtype=jnp.float32):
    size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
    flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]),
                  np.prod(shape[-n_features:]),)
    kernel = jnp.concatenate([kernel_init(rng, flat_shape, dtype)
                              for _ in range(size_batch_dims)], axis=0)
    return jnp.reshape(kernel, shape)

  batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
  kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
  kernel = scope.param('kernel', kernel_init_wrap, batch_shape + kernel_shape)
  kernel = jnp.asarray(kernel, dtype)

  batch_ind = tuple(range(n_batch_dims))
  contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
  out = lax.dot_general(inputs,
                        kernel,
                        ((axis, contract_ind), (batch_dims, batch_ind)),
                        precision=precision)
  if bias:
    def bias_init_wrap(rng, shape, dtype=jnp.float32):
      size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
      flat_shape = (np.prod(shape[-n_features:]),)
      bias = jnp.concatenate([bias_init(rng, flat_shape, dtype)
                              for _ in range(size_batch_dims)], axis=0)
      return jnp.reshape(bias, shape)

    bias = scope.param('bias', bias_init_wrap, batch_shape + features)

    # Reshape bias for broadcast.
    expand_dims = sorted(
        set(range(inputs.ndim)) - set(axis) - set(batch_dims))
    for ax in expand_dims:
      bias = jnp.expand_dims(bias, ax)
    bias = jnp.asarray(bias, dtype)
    out = out + bias
  return out
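

# Example usage (an illustrative sketch, not part of this module's API):
# `flax.core.init` binds a fresh `scope` and returns `(output, variables)`.
# The input shape and the choice of contracted axes below are hypothetical.
#
#   from jax import random
#   from flax.core import init
#
#   x = jnp.ones((16, 8, 9))
#   # Contract the trailing two axes (sizes 8 and 9) into 4 output features.
#   y, variables = init(dense_general)(
#       random.PRNGKey(0), x, features=4, axis=(-2, -1))
#   # y.shape == (16, 4); the kernel parameter has shape (8, 9, 4).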


def dense(scope,
          inputs,
          features,
          bias=True,
          dtype=jnp.float32,
          precision=None,
          kernel_init=default_kernel_init,
          bias_init=initializers.zeros):
  """Applies a linear transformation to the inputs along the last dimension.

  Args:
    scope: functional scope.
    inputs: The nd-array to be transformed.
    features: the number of output features.
    bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
  Returns:
    The transformed input.
  """
  inputs = jnp.asarray(inputs, dtype)
  kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features))
  kernel = jnp.asarray(kernel, dtype)
  y = lax.dot_general(inputs, kernel,
                      (((inputs.ndim - 1,), (0,)), ((), ())),
                      precision=precision)
  if bias:
    bias = scope.param('bias', bias_init, (features,))
    bias = jnp.asarray(bias, dtype)
    y = y + bias
  return y
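

# Example usage (sketch, same `flax.core.init` assumption as above): a plain
# dense layer contracts the last input axis against a
# (features_in, features_out) kernel. The shapes below are illustrative.
#
#   from jax import random
#   from flax.core import init
#
#   x = jnp.ones((4, 8))
#   y, variables = init(dense)(random.PRNGKey(0), x, features=3)
#   # y.shape == (4, 3); kernel shape (8, 3), bias shape (3,).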
148 """ 149 inputs = jnp.asarray(inputs, dtype) 150 kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features)) 151 kernel = jnp.asarray(kernel, dtype) 152 y = lax.dot_general(inputs, kernel, 153 (((inputs.ndim - 1,), (0,)), ((), ())), 154 precision=precision) 155 if bias: 156 bias = scope.param('bias', bias_init, (features,)) 157 bias = jnp.asarray(bias, dtype) 158 y = y + bias 159 return y 160 161 162def _conv_dimension_numbers(input_shape): 163 """Computes the dimension numbers based on the input shape.""" 164 ndim = len(input_shape) 165 lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) 166 rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) 167 out_spec = lhs_spec 168 return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) 169 170 171def conv(scope, 172 inputs, 173 features, 174 kernel_size, 175 strides=None, 176 padding='SAME', 177 input_dilation=None, 178 kernel_dilation=None, 179 feature_group_count=1, 180 bias=True, 181 dtype=jnp.float32, 182 precision=None, 183 kernel_init=default_kernel_init, 184 bias_init=initializers.zeros): 185 """Applies a convolution to the inputs. 186 187 Args: 188 inputs: input data with dimensions (batch, spatial_dims..., features). 189 features: number of convolution filters. 190 kernel_size: shape of the convolutional kernel. 191 strides: a sequence of `n` integers, representing the inter-window 192 strides. 193 padding: either the string `'SAME'`, the string `'VALID'`, or a sequence 194 of `n` `(low, high)` integer pairs that give the padding to apply before 195 and after each spatial dimension. 196 input_dilation: `None`, or a sequence of `n` integers, giving the 197 dilation factor to apply in each spatial dimension of `inputs`. 198 Convolution with input dilation `d` is equivalent to transposed 199 convolution with stride `d`. 200 kernel_dilation: `None`, or a sequence of `n` integers, giving the 201 dilation factor to apply in each spatial dimension of the convolution 202 kernel. Convolution with kernel dilation is also known as 'atrous 203 convolution'. 204 feature_group_count: integer, default 1. If specified divides the input 205 features into groups. 206 bias: whether to add a bias to the output (default: True). 207 dtype: the dtype of the computation (default: float32). 208 precision: numerical precision of the computation see `jax.lax.Precision` 209 for details. 210 kernel_init: initializer for the convolutional kernel. 211 bias_init: initializer for the bias. 212 Returns: 213 The convolved data. 
214 """ 215 216 inputs = jnp.asarray(inputs, dtype) 217 218 if strides is None: 219 strides = (1,) * (inputs.ndim - 2) 220 221 in_features = inputs.shape[-1] 222 assert in_features % feature_group_count == 0 223 kernel_shape = kernel_size + (in_features // feature_group_count, features) 224 kernel = scope.param('kernel', kernel_init, kernel_shape) 225 kernel = jnp.asarray(kernel, dtype) 226 227 dimension_numbers = _conv_dimension_numbers(inputs.shape) 228 y = lax.conv_general_dilated( 229 inputs, 230 kernel, 231 strides, 232 padding, 233 lhs_dilation=input_dilation, 234 rhs_dilation=kernel_dilation, 235 dimension_numbers=dimension_numbers, 236 feature_group_count=feature_group_count, 237 precision=precision) 238 239 if bias: 240 bias = scope.param('bias', bias_init, (features,)) 241 bias = jnp.asarray(bias, dtype) 242 y = y + bias 243 return y 244 245 246def conv_transpose(scope, 247 inputs, 248 features, 249 kernel_size, 250 strides=None, 251 padding='SAME', 252 kernel_dilation=None, 253 bias=True, 254 dtype=jnp.float32, 255 precision=None, 256 kernel_init=default_kernel_init, 257 bias_init=initializers.zeros): 258 """Applies a transposed convolution to the inputs. Behaviour mirrors that of 259 `jax.lax.conv_transpose`. 260 261 Args: 262 scope: functional scope. 263 inputs: input data with dimensions (batch, spatial_dims..., features). 264 features: number of convolution filters. 265 kernel_size: shape of the convolutional kernel. 266 strides: a sequence of `n` integers, representing the inter-window 267 strides. 268 padding: either the string `'SAME'`, the string `'VALID'`, or a sequence 269 of `n` `(low, high)` integer pairs that give the padding to apply before 270 and after each spatial dimension. 271 kernel_dilation: `None`, or a sequence of `n` integers, giving the 272 dilation factor to apply in each spatial dimension of the convolution 273 kernel. Convolution with kernel dilation is also known as 'atrous 274 convolution'. 275 bias: whether to add a bias to the output (default: True). 276 dtype: the dtype of the computation (default: float32). 277 precision: numerical precision of the computation see `jax.lax.Precision` 278 for details. 279 kernel_init: initializer for the convolutional kernel. 280 bias_init: initializer for the bias. 281 Returns: 282 The convolved data. 283 """ 284 inputs = jnp.asarray(inputs, dtype) 285 strides = strides or (1,) * (inputs.ndim - 2) 286 287 in_features = inputs.shape[-1] 288 kernel_shape = kernel_size + (in_features, features) 289 kernel = scope.param('kernel', kernel_init, kernel_shape) 290 kernel = jnp.asarray(kernel, dtype) 291 292 y = lax.conv_transpose(inputs, kernel, strides, padding, 293 rhs_dilation=kernel_dilation, precision=precision) 294 295 if bias: 296 bias = scope.param('bias', bias_init, (features,)) 297 bias = jnp.asarray(bias, dtype) 298 y = y + bias 299 return y 300 301 302default_embed_init = initializers.variance_scaling(1.0, 'fan_in', 'normal', 303 out_axis=0) 304 305 306@struct.dataclass 307class Embedding: 308 table: np.ndarray 309 310 def lookup(self, indices): 311 """Embeds the inputs along the last dimension. 312 313 Args: 314 indices: input data, all dimensions are considered batch dimensions. 315 316 Returns: 317 Output which is embedded input data. The output shape follows the input, 318 with an additional `features` dimension appended. 
319 """ 320 if indices.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: 321 raise ValueError('Input type must be an integer or unsigned integer.') 322 return self.table[indices] 323 324 def attend(self, query): 325 """Attend over the embedding using a query array. 326 327 Args: 328 query: array with last dimension equal the feature depth `features` of the 329 embedding. 330 331 Returns: 332 An array with final dim `num_embeddings` corresponding to the batched 333 inner-product of the array of query vectors against each embedding. 334 Commonly used for weight-sharing between embeddings and logit transform 335 in NLP models. 336 """ 337 return jnp.dot(query, self.table.T) 338 339 340def embedding(scope: Scope, num_embeddings: int, features: int, init_fn=default_embed_init) -> Embedding: 341 """Creates embedding dataclass. 342 343 Args: 344 num_embeddings: number of embeddings. 345 features: Number of feature dimensions for each embedding. 346 embedding_init: embedding initializer. 347 348 Returns: 349 Embedding dataclass with lookup and attend methods. 350 """ 351 table = scope.param('table', init_fn, (num_embeddings, features)) 352 return Embedding(table) 353