1# Copyright 2018 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# Pytype is too slow to check this file. 16# pytype: skip-file 17 18import builtins 19from enum import IntEnum 20import functools 21import itertools 22import operator 23from typing import (Any, Callable, List, NamedTuple, Optional, Sequence,\ 24 Union, Tuple) 25import warnings 26 27import numpy as np 28 29import jax 30from jax import core 31from jax import ad_util 32from jax import api 33from jax import api_util 34from jax import linear_util as lu 35from jax import dtypes 36from jax import lazy 37from jax import tree_util 38from jax.config import flags, config 39from jax.core import (Primitive, _canonicalize_dimension, UnshapedArray, 40 ShapedArray, ConcreteArray, raise_to_shaped, 41 abstract_token, canonicalize_shape) 42from jax.abstract_arrays import array_types 43from jax.interpreters import partial_eval as pe 44from jax.interpreters import xla 45from jax.interpreters import pxla 46from jax.interpreters import ad 47from jax.interpreters import invertible_ad as iad 48from jax.interpreters import batching 49from jax.interpreters import masking 50from jax._src.util import (cache, safe_zip, partial, prod, safe_map, 51 canonicalize_axis, split_list) 52from jax.tree_util import tree_map 53from jax.lib import pytree 54from jax.lib import xla_bridge 55from jax.lib import xla_client 56 57xb = xla_bridge 58xc = xla_client 59xops = xla_client.ops 60 61FLAGS = flags.FLAGS 62 63_max = 
builtins.max 64_min = builtins.min 65_reduce = functools.reduce 66 67Array = Any 68DType = Any 69Shape = Sequence[int] 70 71def _try_broadcast_shapes(shapes): 72 assert shapes 73 if len(shapes) == 1: return shapes[0] 74 rank, *others = {len(shape) for shape in shapes} 75 if others: return None # must have consistent rank 76 if not rank: return () # scalar case 77 result_shape = [None] * rank 78 for i, sizes in enumerate(zip(*shapes)): 79 if sizes[:-1] == sizes[1:]: 80 result_shape[i] = sizes[0] # all equal sizes for this dimension 81 else: 82 sizes = [d for d in sizes if d != 1] 83 if sizes[:-1] != sizes[1:]: 84 return None # must have equal sizes other than 1-sized axes 85 result_shape[i] = sizes[0] if sizes else 1 86 return tuple(result_shape) 87 88@cache() 89def broadcast_shapes(*shapes): 90 """Returns the shape that results from NumPy broadcasting of `shapes`.""" 91 if len(shapes) == 1: 92 return shapes[0] 93 ndim = _max(len(shape) for shape in shapes) 94 shapes = [(1,) * (ndim - len(shape)) + shape for shape in shapes] 95 result_shape = _try_broadcast_shapes(shapes) 96 if result_shape is None: 97 raise ValueError("Incompatible shapes for broadcasting: {}" 98 .format(tuple(map(tuple, shapes)))) 99 return result_shape 100 101def _identity(x): return x 102 103### traceables 104 105def neg(x: Array) -> Array: 106 r"""Elementwise negation: :math:`-x`.""" 107 return neg_p.bind(x) 108 109def sign(x: Array) -> Array: 110 r"""Elementwise sign. 111 112 For floating-point inputs, returns 113 :math:`\mathrm{sign}(x) = \begin{cases} 114 -1 & x < 0\\ 115 -0 & x = -0\\ 116 \mathit{NaN} & x = \mathit{NaN}\\ 117 +0 & x = +0\\ 118 1 & x > 0 119 \end{cases}` 120 121 For signed integer inputs, returns 122 :math:`\mathrm{sign}(x) = \begin{cases} 123 -1 & x < 0\\ 124 0 & x = 0\\ 125 1 & x > 0 126 \end{cases}` 127 128 For complex inputs, returns the complex phase, i.e. 129 :math:`\mathrm{sign}(x) = \frac{x}{|x|}`. 
130 """ 131 return sign_p.bind(x) 132 133def nextafter(x1: Array, x2: Array) -> Array: 134 r"""Returns the next representable value after `x1` in the direction of `x2`. 135 136 Note that in some environments flush-denormal-to-zero semantics is used. 137 This means that, around zero, this function returns strictly non-zero 138 values which appear as zero in any operations. Consider this example:: 139 >>> jnp.nextafter(0, 1) # denormal numbers are representable 140 DeviceArray(1.e-45, dtype=float32) 141 >>> jnp.nextafter(0, 1) * 1 # but are flushed to zero 142 DeviceArray(0., dtype=float32) 143 144 For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``. 145 """ 146 return nextafter_p.bind(_brcast(x1, x2), _brcast(x2, x1)) 147 148def floor(x: Array) -> Array: 149 r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`.""" 150 return floor_p.bind(x) 151 152def ceil(x: Array) -> Array: 153 r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`.""" 154 return ceil_p.bind(x) 155 156class RoundingMethod(IntEnum): 157 AWAY_FROM_ZERO = 0 158 TO_NEAREST_EVEN = 1 159 160def round(x: Array, 161 rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO 162 ) -> Array: 163 r"""Elementwise round. 164 165 Rounds values to the nearest integer. 166 167 Args: 168 x: an array or scalar value to round. 169 rounding_method: the method to use when rounding halfway values 170 (e.g., `0.5`). See ``lax.RoundingMethod`` for the list of possible 171 values. 172 173 Returns: 174 An array containing the elementwise rounding of x. 175 """ 176 rounding_method = RoundingMethod(rounding_method) 177 return round_p.bind(x, rounding_method=rounding_method) 178 179def is_finite(x: Array) -> Array: 180 r"""Elementwise :math:`\mathrm{isfinite}`. 181 182 For each element x returns `True` if and only if x is not :math:`\pm\infty` or 183 :math:`\mathit{NaN}`. 
184 """ 185 return is_finite_p.bind(x) 186 187def exp(x: Array) -> Array: 188 r"""Elementwise exponential: :math:`e^x`.""" 189 return exp_p.bind(x) 190 191def expm1(x: Array) -> Array: 192 r"""Elementwise :math:`e^{x} - 1`.""" 193 return expm1_p.bind(x) 194 195def log(x: Array) -> Array: 196 r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.""" 197 return log_p.bind(x) 198 199def log1p(x: Array) -> Array: 200 r"""Elementwise :math:`\mathrm{log}(1 + x)`.""" 201 return log1p_p.bind(x) 202 203def tanh(x: Array) -> Array: 204 r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.""" 205 return tanh_p.bind(x) 206 207def sin(x: Array) -> Array: 208 r"""Elementwise sine: :math:`\mathrm{sin}(x)`.""" 209 return sin_p.bind(x) 210 211def cos(x: Array) -> Array: 212 r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.""" 213 return cos_p.bind(x) 214 215def atan2(x: Array, y: Array) -> Array: 216 r"""Elementwise arc tangent of two variables: 217 :math:`\mathrm{atan}({x \over y})`.""" 218 return atan2_p.bind(x, y) 219 220def betainc(a: Array, b: Array, x: Array) -> Array: 221 r"""Elementwise regularized incomplete beta integral.""" 222 return regularized_incomplete_beta_p.bind(a, b, x) 223 224def lgamma(x: Array) -> Array: 225 r"""Elementwise log gamma: :math:`\mathrm{log}(\Gamma(x))`.""" 226 return lgamma_p.bind(x) 227 228def digamma(x: Array) -> Array: 229 r"""Elementwise digamma: :math:`\psi(x)`.""" 230 return digamma_p.bind(x) 231 232def igamma(a: Array, x: Array) -> Array: 233 r"""Elementwise regularized incomplete gamma function.""" 234 return igamma_p.bind(a, x) 235 236def igammac(a: Array, x: Array) -> Array: 237 r"""Elementwise complementary regularized incomplete gamma function.""" 238 return igammac_p.bind(a, x) 239 240def igamma_grad_a(a: Array, x: Array) -> Array: 241 r"""Elementwise derivative of the regularized incomplete gamma function.""" 242 return igamma_grad_a_p.bind(a, x) 243 244def random_gamma_grad(a: Array, x: Array) -> Array: 245 
r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" 246 return random_gamma_grad_p.bind(a, x) 247 248def bessel_i0e(x: Array) -> Array: 249 r"""Exponentially scaled modified Bessel function of order 0: 250 :math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)` 251 """ 252 return bessel_i0e_p.bind(x) 253 254def bessel_i1e(x: Array) -> Array: 255 r"""Exponentially scaled modified Bessel function of order 1: 256 :math:`\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)` 257 """ 258 return bessel_i1e_p.bind(x) 259 260def erf(x: Array) -> Array: 261 r"""Elementwise error function: :math:`\mathrm{erf}(x)`.""" 262 return erf_p.bind(x) 263 264def erfc(x: Array) -> Array: 265 r"""Elementwise complementary error function: 266 :math:`\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)`.""" 267 return erfc_p.bind(x) 268 269def erf_inv(x: Array) -> Array: 270 r"""Elementwise inverse error function: :math:`\mathrm{erf}^{-1}(x)`.""" 271 return erf_inv_p.bind(x) 272 273def real(x: Array) -> Array: 274 r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`. 275 276 Returns the real part of a complex number. 277 """ 278 return real_p.bind(x) 279 280def imag(x: Array) -> Array: 281 r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`. 282 283 Returns the imaginary part of a complex number. 284 """ 285 return imag_p.bind(x) 286 287def complex(x: Array, y: Array) -> Array: 288 r"""Elementwise make complex number: :math:`x + jy`. 289 290 Builds a complex number from real and imaginary parts. 
291 """ 292 return complex_p.bind(_brcast(x, y), _brcast(y, x)) 293 294def conj(x: Array) -> Array: 295 r"""Elementwise complex conjugate function: :math:`\overline{x}`.""" 296 return conj_p.bind(x, input_dtype=_dtype(x)) 297 298def abs(x: Array) -> Array: 299 r"""Elementwise absolute value: :math:`|x|`.""" 300 return abs_p.bind(x) 301 302def pow(x: Array, y: Array) -> Array: 303 r"""Elementwise power: :math:`x^y`.""" 304 return pow_p.bind(x, y) 305 306def integer_pow(x: Array, y: int) -> Array: 307 r"""Elementwise power: :math:`x^y`, where :math:`y` is a fixed integer.""" 308 return integer_pow_p.bind(x, y=y) 309 310def sqrt(x: Array) -> Array: 311 r"""Elementwise square root: :math:`\sqrt{x}`.""" 312 return sqrt_p.bind(x) 313 314def rsqrt(x: Array) -> Array: 315 r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}.""" 316 return rsqrt_p.bind(x) 317 318def bitwise_not(x: Array) -> Array: 319 r"""Elementwise NOT: :math:`\neg x`.""" 320 return not_p.bind(x) 321 322def bitwise_and(x: Array, y: Array) -> Array: 323 r"""Elementwise AND: :math:`x \wedge y`.""" 324 return and_p.bind(x, y) 325 326def bitwise_or(x: Array, y: Array) -> Array: 327 r"""Elementwise OR: :math:`x \vee y`.""" 328 return or_p.bind(x, y) 329 330def bitwise_xor(x: Array, y: Array) -> Array: 331 r"""Elementwise exclusive OR: :math:`x \oplus y`.""" 332 return xor_p.bind(x, y) 333 334def population_count(x: Array) -> Array: 335 r"""Elementwise popcount, count the number of set bits in each element.""" 336 return population_count_p.bind(x) 337 338def add(x: Array, y: Array) -> Array: 339 r"""Elementwise addition: :math:`x + y`.""" 340 return add_p.bind(x, y) 341 342def sub(x: Array, y: Array) -> Array: 343 r"""Elementwise subtraction: :math:`x - y`.""" 344 return sub_p.bind(x, y) 345 346def mul(x: Array, y: Array) -> Array: 347 r"""Elementwise multiplication: :math:`x \times y`.""" 348 return mul_p.bind(x, y) 349 350def div(x: Array, y: Array) -> Array: 351 r"""Elementwise division: :math:`x 
\over y`.""" 352 return div_p.bind(x, y) 353 354def rem(x: Array, y: Array) -> Array: 355 r"""Elementwise remainder: :math:`x \bmod y`.""" 356 return rem_p.bind(x, y) 357 358def max(x: Array, y: Array) -> Array: 359 r"""Elementwise maximum: :math:`\mathrm{max}(x, y)` 360 361 For complex numbers, uses a lexicographic comparison on the 362 `(real, imaginary)` pairs.""" 363 return max_p.bind(x, y) 364 365def min(x: Array, y: Array) -> Array: 366 r"""Elementwise minimum: :math:`\mathrm{min}(x, y)` 367 368 For complex numbers, uses a lexicographic comparison on the 369 `(real, imaginary)` pairs.""" 370 return min_p.bind(x, y) 371 372def shift_left(x: Array, y: Array) -> Array: 373 r"""Elementwise left shift: :math:`x \ll y`.""" 374 return shift_left_p.bind(x, y) 375 376def shift_right_arithmetic(x: Array, y: Array) -> Array: 377 r"""Elementwise arithmetic right shift: :math:`x \gg y`.""" 378 return shift_right_arithmetic_p.bind(x, y) 379 380def shift_right_logical(x: Array, y: Array) -> Array: 381 r"""Elementwise logical right shift: :math:`x \gg y`.""" 382 return shift_right_logical_p.bind(x, y) 383 384def eq(x: Array, y: Array) -> Array: 385 r"""Elementwise equals: :math:`x = y`.""" 386 return eq_p.bind(x, y) 387 388def ne(x: Array, y: Array) -> Array: 389 r"""Elementwise not-equals: :math:`x \neq y`.""" 390 return ne_p.bind(x, y) 391 392def ge(x: Array, y: Array) -> Array: 393 r"""Elementwise greater-than-or-equals: :math:`x \geq y`.""" 394 return ge_p.bind(x, y) 395 396def gt(x: Array, y: Array) -> Array: 397 r"""Elementwise greater-than: :math:`x > y`.""" 398 return gt_p.bind(x, y) 399 400def le(x: Array, y: Array) -> Array: 401 r"""Elementwise less-than-or-equals: :math:`x \leq y`.""" 402 return le_p.bind(x, y) 403 404def lt(x: Array, y: Array) -> Array: 405 r"""Elementwise less-than: :math:`x < y`.""" 406 return lt_p.bind(x, y) 407 408def convert_element_type(operand: Array, new_dtype: DType) -> Array: 409 """Elementwise cast. 
410 411 Wraps XLA's `ConvertElementType 412 <https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_ 413 operator, which performs an elementwise conversion from one type to another. 414 Similar to a C++ `static_cast`. 415 416 Args: 417 operand: an array or scalar value to be cast. 418 new_dtype: the new type. Should be a NumPy type. 419 420 Returns: 421 An array with the same shape as `operand`, cast elementwise to `new_dtype`. 422 """ 423 new_dtype = dtypes.canonicalize_dtype(new_dtype) 424 # Avoids dropping precision by casting Python scalars to the default Jax 425 # type. If we passed a Python scalar directly to the bind call below, it is 426 # cast to the default type as part of the calling convention. 427 if type(operand) in dtypes.python_scalar_dtypes: 428 operand = np.asarray(operand, new_dtype) 429 old_dtype = dtypes.canonicalize_dtype(_dtype(operand)) 430 if old_dtype == new_dtype: 431 if isinstance(operand, (core.Tracer, xla.DeviceArray)): 432 return operand 433 else: 434 return _device_put_raw(np.asarray(operand)) 435 if (dtypes.issubdtype(old_dtype, np.complexfloating) and 436 not dtypes.issubdtype(new_dtype, np.complexfloating)): 437 msg = "Casting complex values to real discards the imaginary part" 438 warnings.warn(msg, np.ComplexWarning, stacklevel=2) 439 return convert_element_type_p.bind(operand, new_dtype=new_dtype) 440 441def bitcast_convert_type(operand: Array, new_dtype: DType) -> Array: 442 """Elementwise bitcast. 443 444 Wraps XLA's `BitcastConvertType 445 <https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_ 446 operator, which performs a bit cast from one type to another. The bitwidth 447 of the source and destination types must match. 448 449 Args: 450 operand: an array or scalar value to be cast 451 new_dtype: the new type. Should be a NumPy type. 452 453 Returns: 454 An array with the same shape as `operand`, bitcast elementwise to 455 `new_dtype`. 
456 """ 457 new_dtype = dtypes.canonicalize_dtype(new_dtype) 458 old_dtype = _dtype(operand) 459 if old_dtype != new_dtype: 460 return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype) 461 else: 462 return operand 463 464def clamp(min: Array, x: Array, max: Array) -> Array: 465 r"""Elementwise clamp. 466 467 Returns :math:`\mathrm{clamp}(x) = \begin{cases} 468 \mathit{min} & \text{if } x < \mathit{min},\\ 469 \mathit{max} & \text{if } x > \mathit{max},\\ 470 x & \text{otherwise} 471 \end{cases}`. 472 """ 473 return clamp_p.bind(min, x, max) 474 475def concatenate(operands: Sequence[Array], dimension: int) -> Array: 476 """Concatenates a sequence of arrays along `dimension`. 477 478 Wraps XLA's `Concatenate 479 <https://www.tensorflow.org/xla/operation_semantics#concatenate>`_ 480 operator. 481 482 Args: 483 operands: a sequence of arrays to concatenate. The arrays must have equal 484 shapes, except in the `dimension` axis. 485 dimension: the dimension along which to concatenate the arrays. 486 487 Returns: 488 An array containing the concatenation. 489 """ 490 return concatenate_p.bind(*operands, dimension=dimension) 491 492Precision = xla_client.PrecisionConfig.Precision 493Precision.__str__ = lambda precision: precision.name 494PrecisionType = Any 495PrecisionLike = Union[None, PrecisionType, Tuple[PrecisionType, PrecisionType]] 496 497 498class ConvDimensionNumbers(NamedTuple): 499 """Describes batch, spatial, and feature dimensions of a convolution. 500 501 Args: 502 lhs_spec: a tuple of nonnegative integer dimension numbers containing 503 `(batch dimension, feature dimension, spatial dimensions...)`. 504 rhs_spec: a tuple of nonnegative integer dimension numbers containing 505 `(out feature dimension, in feature dimension, spatial dimensions...)`. 506 out_spec: a tuple of nonnegative integer dimension numbers containing 507 `(batch dimension, feature dimension, spatial dimensions...)`. 
508 """ 509 lhs_spec: Sequence[int] 510 rhs_spec: Sequence[int] 511 out_spec: Sequence[int] 512 513ConvGeneralDilatedDimensionNumbers = Union[ 514 None, ConvDimensionNumbers, Tuple[str, str, str]] 515 516def conv_general_dilated( 517 lhs: Array, rhs: Array, window_strides: Sequence[int], 518 padding: Union[str, Sequence[Tuple[int, int]]], 519 lhs_dilation: Optional[Sequence[int]] = None, 520 rhs_dilation: Optional[Sequence[int]] = None, 521 dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, 522 feature_group_count: int = 1, batch_group_count: int = 1, 523 precision: PrecisionLike = None) -> Array: 524 """General n-dimensional convolution operator, with optional dilation. 525 526 Wraps XLA's `Conv 527 <https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_ 528 operator. 529 530 Args: 531 lhs: a rank `n+2` dimensional input array. 532 rhs: a rank `n+2` dimensional array of kernel weights. 533 window_strides: a sequence of `n` integers, representing the inter-window 534 strides. 535 padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of 536 `n` `(low, high)` integer pairs that give the padding to apply before and 537 after each spatial dimension. 538 lhs_dilation: `None`, or a sequence of `n` integers, giving the 539 dilation factor to apply in each spatial dimension of `lhs`. LHS dilation 540 is also known as transposed convolution. 541 rhs_dilation: `None`, or a sequence of `n` integers, giving the 542 dilation factor to apply in each spatial dimension of `rhs`. RHS dilation 543 is also known as atrous convolution. 544 dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or 545 a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string 546 of length `n+2`. 547 feature_group_count: integer, default 1. See XLA HLO docs. 548 batch_group_count: integer, default 1. See XLA HLO docs. 549 precision: Optional. 
Either ``None``, which means the default precision for 550 the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``, 551 ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two 552 ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``. 553 554 Returns: 555 An array containing the convolution result. 556 557 In the string case of `dimension_numbers`, each character identifies by 558 position: 559 560 - the batch dimensions in `lhs`, `rhs`, and the output with the character 561 'N', 562 - the feature dimensions in `lhs` and the output with the character 'C', 563 - the input and output feature dimensions in rhs with the characters 'I' 564 and 'O' respectively, and 565 - spatial dimension correspondences between lhs, rhs, and the output using 566 any distinct characters. 567 568 For example, to indicate dimension numbers consistent with the `conv` function 569 with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As 570 another example, to indicate dimension numbers consistent with the TensorFlow 571 Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the 572 latter form of convolution dimension specification, window strides are 573 associated with spatial dimension character labels according to the order in 574 which the labels appear in the `rhs_spec` string, so that `window_strides[0]` 575 is matched with the dimension corresponding to the first character 576 appearing in rhs_spec that is not `'I'` or `'O'`. 577 578 If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')` 579 (for a 2D convolution). 580 """ 581 dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) 582 if lhs_dilation is None: 583 lhs_dilation = (1,) * (lhs.ndim - 2) 584 elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1): 585 raise ValueError( 586 "String padding is not implemented for transposed convolution " 587 "using this op. 
Please either exactly specify the required padding or " 588 "use conv_transpose.") 589 if rhs_dilation is None: 590 rhs_dilation = (1,) * (rhs.ndim - 2) 591 if isinstance(padding, str): 592 lhs_perm, rhs_perm, _ = dnums 593 rhs_shape = np.take(rhs.shape, rhs_perm)[2:] 594 effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)] 595 padding = padtype_to_pads( 596 np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, 597 window_strides, padding) 598 return conv_general_dilated_p.bind( 599 lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding), 600 lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation), 601 dimension_numbers=dnums, 602 feature_group_count=feature_group_count, 603 batch_group_count=batch_group_count, 604 lhs_shape=lhs.shape, rhs_shape=rhs.shape, 605 precision=_canonicalize_precision(precision)) 606 607def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, 608 preferred_element_type: Optional[DType] = None) -> Array: 609 """Vector/vector, matrix/vector, and matrix/matrix multiplication. 610 611 Wraps XLA's `Dot 612 <https://www.tensorflow.org/xla/operation_semantics#dot>`_ 613 operator. 614 615 For more general contraction, see the `dot_general` operator. 616 617 Args: 618 lhs: an array of rank 1 or 2. 619 rhs: an array of rank 1 or 2. 620 precision: Optional. Either ``None``, which means the default precision for 621 the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``, 622 ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two 623 ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``. 624 preferred_element_type: Optional. Either ``None``, which means the default 625 accumulation type for the input types, or a datatype, indicating to 626 accumulate results to and return a result with that datatype. 627 628 Returns: 629 An array containing the product. 
630 """ 631 if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and lhs.shape[-1] == rhs.shape[0]: 632 return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())), 633 precision=precision, preferred_element_type=preferred_element_type) 634 else: 635 raise TypeError("Incompatible shapes for dot: got {} and {}.".format( 636 lhs.shape, rhs.shape)) 637 638 639DotDimensionNumbers = Tuple[Tuple[Sequence[int], Sequence[int]], 640 Tuple[Sequence[int], Sequence[int]]] 641 642def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers, 643 precision: PrecisionLike = None, 644 preferred_element_type: Optional[DType] = None) -> Array: 645 """More general contraction operator. 646 647 Wraps XLA's `DotGeneral 648 <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_ 649 operator. 650 651 Args: 652 lhs: an array 653 rhs: an array 654 dimension_numbers: a tuple of tuples of the form 655 `((lhs_contracting_dims, rhs_contracting_dims), 656 (lhs_batch_dims, rhs_batch_dims))` 657 precision: Optional. Either ``None``, which means the default precision for 658 the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``, 659 ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two 660 ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``. 661 preferred_element_type: Optional. Either ``None``, which means the default 662 accumulation type for the input types, or a datatype, indicating to 663 accumulate results to and return a result with that datatype. 664 665 Returns: 666 An array containing the result. 
667 """ 668 contract_dims_seq, batch_dims_seq = dimension_numbers 669 contract_dims = tuple(map(lambda x: tuple(x), contract_dims_seq)) 670 batch_dims = tuple(map(lambda x: tuple(x), batch_dims_seq)) 671 return dot_general_p.bind(lhs, rhs, 672 dimension_numbers=(contract_dims, batch_dims), 673 precision=_canonicalize_precision(precision), 674 preferred_element_type=preferred_element_type) 675 676def broadcast(operand: Array, sizes: Sequence[int]) -> Array: 677 """Broadcasts an array, adding new major dimensions. 678 679 Wraps XLA's `Broadcast 680 <https://www.tensorflow.org/xla/operation_semantics#broadcast>`_ 681 operator. 682 683 Args: 684 operand: an array 685 sizes: a sequence of integers, giving the sizes of new major dimensions 686 to add. 687 688 Returns: 689 An array containing the result. 690 """ 691 dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand))) 692 return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims) 693 694def broadcast_in_dim(operand: Array, shape: Shape, 695 broadcast_dimensions: Sequence[int]) -> Array: 696 """Wraps XLA's `BroadcastInDim 697 <https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_ 698 operator. 699 """ 700 shape = _broadcast_in_dim_shape_rule( 701 operand, shape=shape, broadcast_dimensions=broadcast_dimensions) 702 if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) 703 and isinstance(operand, (xla.DeviceArray, core.Tracer))): 704 return operand 705 return broadcast_in_dim_p.bind( 706 operand, shape=tuple(shape), 707 broadcast_dimensions=tuple(broadcast_dimensions)) 708 709def broadcast_to_rank(x: Array, rank: int) -> Array: 710 """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``.""" 711 return broadcast(x, (1,) * (rank - x.ndim)) 712 713def reshape(operand: Array, new_sizes: Shape, 714 dimensions: Optional[Sequence[int]] = None) -> Array: 715 """Wraps XLA's `Reshape 716 <https://www.tensorflow.org/xla/operation_semantics#reshape>`_ 717 operator. 
718 719 For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` / 720 ``lax.expand_dims``. These preserve information about axis identity that may 721 be useful for advanced transformation rules. 722 """ 723 new_sizes = canonicalize_shape(new_sizes) # TODO 724 new_sizes = tuple(new_sizes) 725 same_shape = np.shape(operand) == new_sizes 726 same_dims = dimensions is None or tuple(dimensions) == tuple(range(np.ndim(operand))) 727 if np.shape(operand) and same_shape and same_dims: 728 return operand 729 else: 730 return reshape_p.bind( 731 operand, new_sizes=new_sizes, 732 dimensions=None if dimensions is None or same_dims else tuple(dimensions)) 733 734def pad(operand: Array, padding_value: Array, 735 padding_config: Sequence[Tuple[int, int, int]]) -> Array: 736 """Applies low, high, and/or interior padding to an array. 737 738 Wraps XLA's `Pad 739 <https://www.tensorflow.org/xla/operation_semantics#pad>`_ 740 operator. 741 742 Args: 743 operand: an array to be padded. 744 padding_value: the value to be inserted as padding. Must have the same dtype 745 as ``operand``. 746 padding_config: a sequence of ``(low, high, interior)`` tuples of integers, 747 giving the amount of low, high, and interior (dilation) padding to insert 748 in each dimension. 749 750 Returns: 751 The ``operand`` array with padding value ``padding_value`` inserted in each 752 dimension according to the ``padding_config``. 753 """ 754 return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) 755 756def rev(operand: Array, dimensions: Sequence[int]) -> Array: 757 """Wraps XLA's `Rev 758 <https://www.tensorflow.org/xla/operation_semantics#rev_reverse>`_ 759 operator. 760 """ 761 return rev_p.bind(operand, dimensions=tuple(dimensions)) 762 763def select(pred: Array, on_true: Array, on_false: Array) -> Array: 764 """Wraps XLA's `Select 765 <https://www.tensorflow.org/xla/operation_semantics#select>`_ 766 operator. 
767 """ 768 return select_p.bind(pred, on_true, on_false) 769 770def slice(operand: Array, start_indices: Sequence[int], 771 limit_indices: Sequence[int], 772 strides: Optional[Sequence[int]] = None) -> Array: 773 """Wraps XLA's `Slice 774 <https://www.tensorflow.org/xla/operation_semantics#slice>`_ 775 operator. 776 """ 777 return slice_p.bind(operand, start_indices=tuple(start_indices), 778 limit_indices=tuple(limit_indices), 779 strides=None if strides is None else tuple(strides)) 780 781def dynamic_slice(operand: Array, start_indices: Sequence[Array], 782 slice_sizes: Shape) -> Array: 783 """Wraps XLA's `DynamicSlice 784 <https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_ 785 operator. 786 787 Args: 788 operand: an array to slice. 789 start_indices: a list of scalar indices, one per dimension. These values 790 may be dynamic. 791 slice_sizes: the size of the slice. Must be a sequence of non-negative 792 integers with length equal to `ndim(operand)`. Inside a JIT compiled 793 function, only static values are supported (all JAX arrays inside JIT 794 must have statically known size). 795 796 Returns: 797 An array containing the slice. 798 """ 799 start_indices = _dynamic_slice_indices(operand, start_indices) 800 return dynamic_slice_p.bind(operand, *start_indices, 801 slice_sizes=tuple(slice_sizes)) 802 803def dynamic_update_slice(operand: Array, update: Array, 804 start_indices: Array) -> Array: 805 """Wraps XLA's `DynamicUpdateSlice 806 <https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_ 807 operator. 808 809 Args: 810 operand: an array to slice. 811 update: an array containing the new values to write onto `operand`. 812 start_indices: a list of scalar indices, one per dimension. 813 814 Returns: 815 An array containing the slice. 
816 """ 817 start_indices = _dynamic_slice_indices(operand, start_indices) 818 return dynamic_update_slice_p.bind(operand, update, *start_indices) 819 820 821class GatherDimensionNumbers(NamedTuple): 822 """ 823 Describes the dimension number arguments to an `XLA's Gather operator 824 <https://www.tensorflow.org/xla/operation_semantics#gather>`_. See the XLA 825 documentation for more details of what the dimension numbers mean. 826 827 Args: 828 offset_dims: the set of dimensions in the `gather` output that offset into 829 an array sliced from `operand`. Must be a tuple of integers in ascending 830 order, each representing a dimension number of the output. 831 collapsed_slice_dims: the set of dimensions `i` in `operand` that have 832 `slice_sizes[i] == 1` and that should not have a corresponding dimension 833 in the output of the gather. Must be a tuple of integers in ascending 834 order. 835 start_index_map: for each dimension in `start_indices`, gives the 836 corresponding dimension in `operand` that is to be sliced. Must be a 837 tuple of integers with size equal to `start_indices.shape[-1]`. 838 839 Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is 840 implicit; there is always an index vector dimension and it must always be the 841 last dimension. To gather scalar indices, add a trailing dimension of size 1. 842 """ 843 offset_dims: Sequence[int] 844 collapsed_slice_dims: Sequence[int] 845 start_index_map: Sequence[int] 846 847 848def gather(operand: Array, start_indices: Array, 849 dimension_numbers: GatherDimensionNumbers, 850 slice_sizes: Shape) -> Array: 851 """Gather operator. 852 853 Wraps `XLA's Gather operator 854 <https://www.tensorflow.org/xla/operation_semantics#gather>`_. 855 856 The semantics of gather are complicated, and its API might change in the 857 future. 
class ScatterDimensionNumbers(NamedTuple):
  """Dimension-number arguments for `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_.

  The XLA documentation explains in detail what each group of dimension
  numbers means.

  Args:
    update_window_dims: the dimensions of `updates` that are window
      dimensions. Must be a tuple of integers in ascending order, each one a
      dimension number.
    inserted_window_dims: the size 1 window dimensions that must be inserted
      into the shape of `updates`. Must be a tuple of integers in ascending
      order, each one a dimension number of the output. This is the
      counterpart of `collapsed_slice_dims` for `gather`.
    scatter_dims_to_operand_dims: for every dimension of `scatter_indices`,
      the dimension of `operand` it corresponds to. Must be a sequence of
      integers whose size equals `scatter_indices.shape[-1]`.

  `index_vector_dim`, which XLA's `ScatterDimensionNumbers` structure carries
  explicitly, is implicit here: an index vector dimension always exists and is
  always the last dimension. Scalar indices are scattered by giving them a
  trailing dimension of size 1.
  """
  update_window_dims: Sequence[int]
  inserted_window_dims: Sequence[int]
  scatter_dims_to_operand_dims: Sequence[int]
def scatter_mul(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers, *,
                indices_are_sorted: bool = False,
                unique_indices: bool = False) -> Array:
  """Scatter-multiply operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, combining
  each update with the existing value of `operand` by multiplication.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: the array that the scatter is applied to.
    scatter_indices: an array giving, for each update in `updates`, the
      indices in `operand` at which that update is applied.
    updates: the values to scatter onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object describing how
      the dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the indices to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve
      performance on some backends.

  Returns:
    An array containing the product of `operand` and the scattered updates.
  """
  # The constant is only passed through `_abstractify`, so just its abstract
  # value reaches `_reduction_jaxpr`.
  element_aval = _abstractify(_const(operand, 1))
  update_jaxpr, update_consts = _reduction_jaxpr(mul, element_aval)
  return scatter_mul_p.bind(
      operand, scatter_indices, updates,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)
def scatter_max(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers, *,
                indices_are_sorted: bool = False,
                unique_indices: bool = False) -> Array:
  """Scatter-max operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, combining
  each update with the existing value of `operand` using the `max` function.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: the array that the scatter is applied to.
    scatter_indices: an array giving, for each update in `updates`, the
      indices in `operand` at which that update is applied.
    updates: the values to scatter onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object describing how
      the dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the indices to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve
      performance on some backends.

  Returns:
    An array containing the elementwise maximum of `operand` and the scattered
    updates.
  """
  # The constant is only passed through `_abstractify`, so just its abstract
  # value reaches `_reduction_jaxpr`.
  element_aval = _abstractify(_const(operand, 0))
  update_jaxpr, update_consts = _reduction_jaxpr(max, element_aval)
  return scatter_max_p.bind(
      operand, scatter_indices, updates,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)
def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array:
  """Takes elements of `src` at the given per-axis indices using `gather`.

  Args:
    src: the array to take values from.
    idxs: a sequence of index arrays, one per entry of `axes`. Each is given a
      new axis 1 and the results are concatenated along that axis (so they are
      presumably 1-D and of equal length; confirm against callers).
    axes: the dimensions of `src` indexed by the corresponding entry of
      `idxs`. They are used both as the collapsed slice dimensions and as the
      start index map of the gather.

  Returns:
    The output of `gather` for those indices, where every dimension in `axes`
    has slice size 1 and is collapsed.
  """
  index_columns = [expand_dims(idx, (1,)) for idx in idxs]
  gather_indices = concatenate(index_columns, 1)
  # Reduce every index modulo the size of the axis it addresses.
  axis_sizes = np.array([src.shape[axis] for axis in axes])
  gather_indices = gather_indices % axis_sizes

  # Full extent along every dimension except the indexed ones.
  window = list(src.shape)
  for axis in axes:
    window[axis] = 1

  num_offset_dims = src.ndim - gather_indices.shape[1]
  dimension_numbers = GatherDimensionNumbers(
      offset_dims=tuple(range(1, num_offset_dims + 1)),
      collapsed_slice_dims=axes,
      start_index_map=axes)
  return gather(src, gather_indices, dimension_numbers=dimension_numbers,
                slice_sizes=tuple(window))
def argmin(operand: Array, axis: int,
           index_dtype: DType) -> Array:
  """Computes the index of the minimum element along ``axis``.

  Args:
    operand: the array to search.
    axis: the dimension of ``operand`` along which the minimum is located.
    index_dtype: the dtype requested for the returned indices; it is
      canonicalized before being passed to the primitive.

  Returns:
    The result of binding ``argmin_p``: a single array of indices (not a
    tuple).
  """
  return argmin_p.bind(operand, axes=(axis,),
                       index_dtype=dtypes.canonicalize_dtype(index_dtype))

def argmax(operand: Array, axis: int,
           index_dtype: DType) -> Array:
  """Computes the index of the maximum element along ``axis``.

  Args:
    operand: the array to search.
    axis: the dimension of ``operand`` along which the maximum is located.
    index_dtype: the dtype requested for the returned indices; it is
      canonicalized before being passed to the primitive.

  Returns:
    The result of binding ``argmax_p``: a single array of indices (not a
    tuple).
  """
  return argmax_p.bind(operand, axes=(axis,),
                       index_dtype=dtypes.canonicalize_dtype(index_dtype))
@cache()
def _reduction_jaxpr(computation, aval):
  """Traces a binary `computation` into a jaxpr.

  `computation` is traced on two unknown values that both have abstract value
  `aval`; its result is wrapped in a 1-tuple for tracing. Results are memoized
  by `@cache()`, so passing the same `computation` object and an equal `aval`
  reuses the earlier trace.

  Args:
    computation: a callable taking two arguments.
    aval: the abstract value used for both arguments.

  Returns:
    A `(jaxpr, consts)` pair as produced by `pe.trace_to_jaxpr`.
  """
  unknown = pe.PartialVal.unknown(aval)
  wrapped = lu.wrap_init(lambda lhs, rhs: (computation(lhs, rhs),))
  traced_jaxpr, _, traced_consts = pe.trace_to_jaxpr(
      wrapped, (unknown, unknown), instantiate=False)
  return traced_jaxpr, traced_consts
_get_min_identity(dtype)) and _reduce_and 1180 elif monoid_op is max: 1181 return np.equal(aval.val, _get_max_identity(dtype)) and _reduce_max 1182 elif monoid_op is min: 1183 return np.equal(aval.val, _get_min_identity(dtype)) and _reduce_min 1184 return None 1185 1186def _get_max_identity(dtype: DType) -> Array: 1187 if dtypes.issubdtype(dtype, np.inexact): 1188 return np.array(-np.inf, dtype) 1189 elif dtypes.issubdtype(dtype, np.integer): 1190 return np.array(dtypes.iinfo(dtype).min, dtype) 1191 elif dtypes.issubdtype(dtype, np.bool_): 1192 return np.array(False, np.bool_) 1193 1194def _get_min_identity(dtype: DType) -> Array: 1195 if dtypes.issubdtype(dtype, np.inexact): 1196 return np.array(np.inf, dtype) 1197 elif dtypes.issubdtype(dtype, np.integer): 1198 return np.array(dtypes.iinfo(dtype).max, dtype) 1199 elif dtypes.issubdtype(dtype, np.bool_): 1200 return np.array(True, np.bool_) 1201 1202def _reduce_sum(operand: Array, axes: Sequence[int]) -> Array: 1203 return reduce_sum_p.bind(operand, axes=tuple(axes)) 1204 1205def _reduce_prod(operand: Array, axes: Sequence[int]) -> Array: 1206 return reduce_prod_p.bind(operand, axes=tuple(axes)) 1207 1208def _reduce_max(operand: Array, axes: Sequence[int]) -> Array: 1209 return reduce_max_p.bind(operand, axes=tuple(axes)) 1210 1211def _reduce_min(operand: Array, axes: Sequence[int]) -> Array: 1212 return reduce_min_p.bind(operand, axes=tuple(axes)) 1213 1214def _reduce_or(operand: Array, axes: Sequence[int]) -> Array: 1215 return reduce_or_p.bind(operand, axes=tuple(axes)) 1216 1217def _reduce_and(operand: Array, axes: Sequence[int]) -> Array: 1218 return reduce_and_p.bind(operand, axes=tuple(axes)) 1219 1220def reduce_window(operand: Array, init_value: Array, computation: Callable, 1221 window_dimensions: Shape, window_strides: Sequence[int], 1222 padding: Union[str, Sequence[Tuple[int, int]]], 1223 base_dilation: Optional[Sequence[int]] = None, 1224 window_dilation: Optional[Sequence[int]] = None) -> Array: 1225 
"""Wraps XLA's `ReduceWindowWithGeneralPadding 1226 <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_ 1227 operator. 1228 """ 1229 if isinstance(padding, str): 1230 dilated_window_dims = (window_dimensions if window_dilation is None else 1231 _dilate_shape(window_dimensions, window_dilation)) 1232 padding = tuple(padtype_to_pads(operand.shape, dilated_window_dims, 1233 window_strides, padding)) 1234 else: 1235 padding = tuple(padding) 1236 if base_dilation is None: 1237 base_dilation = (1,) * len(window_dimensions) 1238 if window_dilation is None: 1239 window_dilation = (1,) * len(window_dimensions) 1240 monoid_reducer = _get_monoid_window_reducer(computation, init_value) 1241 if monoid_reducer: 1242 return monoid_reducer(operand, window_dimensions, window_strides, padding, 1243 base_dilation, window_dilation) 1244 else: 1245 jaxpr, consts = _reduction_jaxpr(computation, _abstractify(init_value)) 1246 return reduce_window_p.bind( 1247 operand, init_value, jaxpr=jaxpr, consts=consts, 1248 window_dimensions=tuple(window_dimensions), 1249 window_strides=tuple(window_strides), padding=padding, 1250 base_dilation=tuple(base_dilation), 1251 window_dilation=tuple(window_dilation)) 1252 1253def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]: 1254 aval = core.get_aval(x) 1255 if (type(aval) is ConcreteArray) and aval.shape == (): 1256 if monoid_op is add: 1257 return aval.val == 0 and _reduce_window_sum 1258 elif monoid_op is max: 1259 return aval.val == _get_max_identity(aval.dtype) and _reduce_window_max 1260 elif monoid_op is min: 1261 return aval.val == _get_min_identity(aval.dtype) and _reduce_window_min 1262 return None 1263 1264def _reduce_window_sum(operand: Array, window_dimensions: Shape, 1265 window_strides: Sequence[int], 1266 padding: Sequence[Tuple[int, int]], 1267 base_dilation: Optional[Sequence[int]] = None, 1268 window_dilation: Optional[Sequence[int]] = None) -> Array: 1269 if base_dilation is None: 1270 
def _reduce_window_prod(operand: Array, window_dimensions: Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]],
                        base_dilation: Optional[Sequence[int]] = None,
                        window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds `reduce_window_p` with multiplication as the reduction.

  The initial value is `_const(operand, 1)` and the reduction jaxpr is traced
  from `mul`.

  Args:
    operand: the array to reduce over windows.
    window_dimensions: the window size, one entry per dimension.
    window_strides: the window strides.
    padding: a sequence of `(low, high)` padding pairs.
    base_dilation: base dilation factors; defaults to 1 for every entry of
      `window_dimensions`.
    window_dilation: window dilation factors; defaults to 1 for every entry of
      `window_dimensions`.

  Returns:
    The result of binding `reduce_window_p`.
  """
  one = _const(operand, 1)
  mul_jaxpr, mul_consts = _reduction_jaxpr(mul, _abstractify(one))
  base_dilation = ((1,) * len(window_dimensions) if base_dilation is None
                   else base_dilation)
  window_dilation = ((1,) * len(window_dimensions) if window_dilation is None
                     else window_dilation)
  return reduce_window_p.bind(
      operand, one,
      jaxpr=mul_jaxpr,
      consts=mul_consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
def _select_and_scatter(operand: Array, select: Callable,
                        window_dimensions: Shape, window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]], source: Array,
                        init_value: Array, scatter: Callable,
                        base_dilation: Sequence[int],
                        window_dilation: Sequence[int]) -> Array:
  """Binds `select_and_scatter_p` for the given `select` and `scatter`.

  Both callables are traced to jaxprs with `_reduction_jaxpr`, using the
  abstract value of `init_value` for their arguments.

  Args:
    operand: the first array operand passed to the primitive.
    select: a binary callable traced into `select_jaxpr`.
    window_dimensions: the window size.
    window_strides: the window strides.
    padding: a sequence of `(low, high)` padding pairs.
    source: the second array operand passed to the primitive.
    init_value: the third operand; also supplies the abstract value used for
      tracing.
    scatter: a binary callable traced into `scatter_jaxpr`.
    base_dilation: base dilation factors.
    window_dilation: window dilation factors.

  Returns:
    The result of binding `select_and_scatter_p`.
  """
  init_aval = _abstractify(init_value)
  select_jaxpr, select_consts = _reduction_jaxpr(select, init_aval)
  scatter_jaxpr, scatter_consts = _reduction_jaxpr(scatter, init_aval)
  static_params = dict(
      select_jaxpr=select_jaxpr,
      select_consts=select_consts,
      scatter_jaxpr=scatter_jaxpr,
      scatter_consts=scatter_consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
  return select_and_scatter_p.bind(operand, source, init_value,
                                   **static_params)
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
         is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
  """Wraps XLA's `Sort
  <https://www.tensorflow.org/xla/operation_semantics#sort>`_
  operator.

  Args:
    operand : Array or sequence of arrays
    dimension : integer dimension along which to sort. Default: -1.
    is_stable : boolean specifying whether to use a stable sort. Default: True.
    num_keys : number of operands to treat as sort keys. Default: 1.
      For num_keys > 1, the sort order will be determined lexicographically
      using the first `num_keys` arrays, with the first key being primary.
      The remaining operands will be returned with the same permutation.

  Returns:
    operand : sorted version of the input or inputs. A single array when
      `operand` is not a sequence, otherwise a tuple of arrays.

  Raises:
    TypeError: if `operand` is an empty sequence.
    ValueError: if `num_keys` is outside `[1, len(operand)]` for a sequence,
      or differs from 1 for a single operand.
  """
  # Single operand: only one key makes sense, and the lone result is unwrapped.
  if not isinstance(operand, Sequence):
    if num_keys != 1:
      raise ValueError(f"num_keys={num_keys} must equal 1 for a single operand.")
    axis = canonicalize_axis(dimension, len(operand.shape))
    results = sort_p.bind(operand, dimension=axis, is_stable=is_stable,
                          num_keys=1)
    return results[0]

  num_operands = len(operand)
  if num_operands == 0:
    raise TypeError("Sort requires at least one operand")
  if not 1 <= num_keys <= num_operands:
    raise ValueError(f"num_keys={num_keys} must be between 1 and len(operand)={num_operands}")
  # The rank of the first operand is used to resolve a negative `dimension`.
  axis = canonicalize_axis(dimension, len(operand[0].shape))
  results = sort_p.bind(*operand, dimension=axis, is_stable=is_stable,
                        num_keys=num_keys)
  return tuple(results)
def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Array:
  """Returns an array of `shape` filled with `fill_value`.

  Args:
    shape: sequence of integers, describing the shape of the output array.
    fill_value: the scalar value to fill the new array with.
    dtype: the type of the output array, or `None`. If not `None`,
      `fill_value` will be cast to `dtype`; otherwise the dtype of
      `fill_value` is used.

  Raises:
    TypeError: if `fill_value` is not a scalar.
  """
  out_shape = canonicalize_shape(shape)
  value_shape = np.shape(fill_value)
  if value_shape:
    raise TypeError(
        "full must be called with scalar fill_value, got fill_value.shape {}."
        .format(value_shape))
  out_dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
  scalar = convert_element_type(fill_value, out_dtype)
  return broadcast(scalar, out_shape)
def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
  """Like numpy.eye, create a 2D array with ones on a diagonal.

  Args:
    dtype: the element type of the result; canonicalized before use.
    shape: a pair `(N, M)` giving the number of rows and columns.
    offset: the diagonal offset; an element is set where
      `row + offset == column`.

  Returns:
    An `(N, M)` array of `dtype`. With omnistaging enabled it is built from
    two iotas; otherwise a lazily evaluated device constant is returned.
  """
  num_rows, num_cols = [int(dim) for dim in shape]
  dims = (num_rows, num_cols)
  diagonal = int(offset)
  out_dtype = dtypes.canonicalize_dtype(dtype)

  if not config.omnistaging_enabled:
    lazy_expr = lazy.eye(out_dtype, dims, diagonal)
    aval = ShapedArray(dims, out_dtype)
    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

  # Same primitive binding order as before: row iota, add, column iota, eq.
  shifted_rows = add(broadcasted_iota(np.int32, dims, 0), np.int32(diagonal))
  columns = broadcasted_iota(np.int32, dims, 1)
  on_diagonal = eq(shifted_rows, columns)
  return convert_element_type_p.bind(on_diagonal, new_dtype=out_dtype)
def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
  """Like numpy.tri, create a 2D array with ones below a diagonal.

  Args:
    dtype: the element type of the result; canonicalized before use.
    shape: a pair `(N, M)` giving the number of rows and columns.
    offset: the diagonal offset; an element is set where
      `row + offset >= column`.

  Returns:
    An `(N, M)` array of `dtype`. With omnistaging enabled it is built from
    two iotas; otherwise a lazily evaluated device constant is returned.
  """
  num_rows, num_cols = [int(dim) for dim in shape]
  dims = (num_rows, num_cols)
  diagonal = int(offset)
  out_dtype = dtypes.canonicalize_dtype(dtype)

  if not config.omnistaging_enabled:
    lazy_expr = lazy.tri(out_dtype, dims, diagonal)
    aval = ShapedArray(dims, out_dtype)
    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

  # Same primitive binding order as before: row iota, add, column iota, ge.
  shifted_rows = add(broadcasted_iota(np.int32, dims, 0), np.int32(diagonal))
  columns = broadcasted_iota(np.int32, dims, 1)
  at_or_below = ge(shifted_rows, columns)
  return convert_element_type_p.bind(at_or_below, new_dtype=out_dtype)
def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
         padding: str, precision: PrecisionLike = None) -> Array:
  """Convenience wrapper around `conv_general_dilated`.

  The arguments are forwarded unchanged; no dilation or dimension numbers are
  passed, so `conv_general_dilated` uses its defaults for those.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'` or the string `'VALID'`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

  Returns:
    An array containing the convolution result.
  """
  return conv_general_dilated(lhs, rhs, window_strides, padding,
                              precision=precision)
1601 padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of 1602 `n` `(low, high)` integer pairs that give the padding to apply before and 1603 after each spatial dimension. 1604 lhs_dilation: `None`, or a sequence of `n` integers, giving the 1605 dilation factor to apply in each spatial dimension of `lhs`. LHS dilation 1606 is also known as transposed convolution. 1607 rhs_dilation: `None`, or a sequence of `n` integers, giving the 1608 dilation factor to apply in each spatial dimension of `rhs`. RHS dilation 1609 is also known as atrous convolution. 1610 precision: Optional. Either ``None``, which means the default precision for 1611 the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``, 1612 ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two 1613 ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``. 1614 1615 Returns: 1616 An array containing the convolution result. 1617 """ 1618 return conv_general_dilated( 1619 lhs, rhs, window_strides, padding, lhs_dilation=lhs_dilation, 1620 rhs_dilation=rhs_dilation, precision=precision) 1621 1622 1623def _conv_transpose_padding(k, s, padding): 1624 """Calculate before and after padding for a dim of transposed convolution. 1625 1626 Args: 1627 k: int: kernel dimension. 1628 s: int: dimension stride value. 1629 padding: 'same' or 'valid' padding mode for original forward conv. 1630 1631 Returns: 1632 2-tuple: ints: before and after padding for transposed convolution. 
1633 """ 1634 if padding == 'SAME': 1635 pad_len = k + s - 2 1636 if s > k - 1: 1637 pad_a = k - 1 1638 else: 1639 pad_a = int(np.ceil(pad_len / 2)) 1640 elif padding == 'VALID': 1641 pad_len = k + s - 2 + _max(k - s, 0) 1642 pad_a = k - 1 1643 else: 1644 raise ValueError('Padding mode must be `SAME` or `VALID`.') 1645 pad_b = pad_len - pad_a 1646 return pad_a, pad_b 1647 1648 1649def _flip_axes(x, axes): 1650 """Flip ndarray 'x' along each axis specified in axes tuple.""" 1651 for axis in axes: 1652 x = np.flip(x, axis) 1653 return x 1654 1655 1656def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], 1657 padding: Union[str, Sequence[Tuple[int, int]]], 1658 rhs_dilation: Optional[Sequence[int]] = None, 1659 dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, 1660 transpose_kernel: bool = False, 1661 precision: PrecisionLike = None) -> Array: 1662 """Convenience wrapper for calculating the N-d convolution "transpose". 1663 1664 This function directly calculates a fractionally strided conv rather than 1665 indirectly calculating the gradient (transpose) of a forward convolution. 1666 1667 Args: 1668 lhs: a rank `n+2` dimensional input array. 1669 rhs: a rank `n+2` dimensional array of kernel weights. 1670 strides: sequence of `n` integers, sets fractional stride. 1671 padding: 'SAME', 'VALID' will set as transpose of corresponding forward 1672 conv, or a sequence of `n` integer 2-tuples describing before-and-after 1673 padding for each `n` spatial dimension. 1674 rhs_dilation: `None`, or a sequence of `n` integers, giving the 1675 dilation factor to apply in each spatial dimension of `rhs`. RHS dilation 1676 is also known as atrous convolution. 1677 dimension_numbers: tuple of dimension descriptors as in 1678 lax.conv_general_dilated. Defaults to tensorflow convention. 1679 transpose_kernel: if True flips spatial axes and swaps the input/output 1680 channel axes of the kernel. 
This makes the output of this function identical 1681 to the gradient-derived functions like keras.layers.Conv2DTranspose 1682 applied to the same kernel. For typical use in neural nets this is completely 1683 pointless and just makes input/output channel specification confusing. 1684 precision: Optional. Either ``None``, which means the default precision for 1685 the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``, 1686 ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two 1687 ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``. 1688 1689 Returns: 1690 Transposed N-d convolution, with output padding following the conventions of 1691 keras.layers.Conv2DTranspose. 1692 """ 1693 assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2 1694 ndims = len(lhs.shape) 1695 one = (1,) * (ndims - 2) 1696 # Set dimensional layout defaults if not specified. 1697 if dimension_numbers is None: 1698 if ndims == 2: 1699 dimension_numbers = ('NC', 'IO', 'NC') 1700 elif ndims == 3: 1701 dimension_numbers = ('NHC', 'HIO', 'NHC') 1702 elif ndims == 4: 1703 dimension_numbers = ('NHWC', 'HWIO', 'NHWC') 1704 elif ndims == 5: 1705 dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC') 1706 else: 1707 raise ValueError('No 4+ dimensional dimension_number defaults.') 1708 dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) 1709 k_shape = np.take(rhs.shape, dn.rhs_spec) 1710 k_sdims = k_shape[2:] 1711 # Calculate correct output shape given padding and strides. 
1712 pads: Union[str, Sequence[Tuple[int, int]]] 1713 if padding in {'SAME', 'VALID'}: 1714 if rhs_dilation is None: 1715 rhs_dilation = (1,) * (rhs.ndim - 2) 1716 effective_k_size = map(lambda k, r: (k-1) * r + 1, k_sdims, rhs_dilation) 1717 pads = [_conv_transpose_padding(k, s, padding) 1718 for k,s in zip(effective_k_size, strides)] 1719 else: 1720 pads = padding 1721 if transpose_kernel: 1722 # flip spatial dims and swap input / output channel axes 1723 rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) 1724 rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) 1725 return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn, 1726 precision=precision) 1727 1728 1729def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None, 1730 shape: Optional[Shape] = None) -> Array: 1731 """Create a full array like np.full based on the example array `x`. 1732 1733 Args: 1734 x: example array-like, used for shape and dtype information. 1735 fill_value: a scalar value to fill the entries of the output array. 1736 dtype: optional, a dtype parameter for the output ndarray. 1737 shape: optional, a shape parameter for the output ndarray. 1738 1739 Returns: 1740 An ndarray with the same shape as `x` with its entries set equal to 1741 `fill_value`, similar to the output of np.full. 1742 """ 1743 fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape) 1744 if not config.omnistaging_enabled: 1745 fill_value = tie_in(x, fill_value) 1746 return full(fill_shape, fill_value, dtype or _dtype(x)) 1747 1748 1749def collapse(operand: Array, start_dimension: int, 1750 stop_dimension: int) -> Array: 1751 """Collapses dimensions of an array into a single dimension. 1752 1753 For example, if ``operand`` is an array with shape ``[2, 3, 4]``, 1754 ``collapse(operand, 0, 2).shape == [6, 4]``. The elements of the collapsed 1755 dimension are laid out major-to-minor, i.e., with the lowest-numbered 1756 dimension as the slowest varying dimension. 
1757 1758 Args: 1759 operand: an input array. 1760 start_dimension: the start of the dimensions to collapse (inclusive). 1761 stop_dimension: the end of the dimensions to collapse (exclusive). 1762 1763 Returns: 1764 An array where dimensions ``[start_dimension, stop_dimension)`` have been 1765 collapsed (raveled) into a single dimension. 1766 """ 1767 lo, hi = start_dimension, stop_dimension 1768 size = prod(operand.shape[lo:hi]) 1769 new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:] 1770 return reshape(operand, new_shape) 1771 1772 1773def slice_in_dim(operand: Array, start_index: Optional[int], 1774 limit_index: Optional[int], 1775 stride: int = 1, axis: int = 0)-> Array: 1776 """Convenience wrapper around slice applying to only one dimension.""" 1777 start_indices = [0] * operand.ndim 1778 limit_indices = list(operand.shape) 1779 strides = [1] * operand.ndim 1780 1781 # translate `None` 1782 len_axis = operand.shape[axis] 1783 start_index_int = _canonicalize_dimension(start_index) if start_index is not None else 0 1784 limit_index_int = _canonicalize_dimension(limit_index) if limit_index is not None else len_axis 1785 1786 # translate negative indices 1787 if start_index_int < 0: 1788 start_index_int = start_index_int + len_axis 1789 if limit_index_int < 0: 1790 limit_index_int = limit_index_int + len_axis 1791 1792 axis = int(axis) 1793 start_indices[axis] = start_index_int 1794 limit_indices[axis] = limit_index_int 1795 strides[axis] = int(stride) 1796 1797 return slice(operand, start_indices, limit_indices, strides) 1798 1799 1800def index_in_dim(operand: Array, index: int, axis: int = 0, 1801 keepdims: bool = True) -> Array: 1802 """Convenience wrapper around slice to perform int indexing.""" 1803 index, axis = int(index), int(axis) 1804 axis_size = operand.shape[axis] 1805 wrapped_index = index + axis_size if index < 0 else index 1806 if not 0 <= wrapped_index < axis_size: 1807 msg = 'index {} is out of bounds for axis {} with size {}' 1808 
raise IndexError(msg.format(index, axis, axis_size)) 1809 result = slice_in_dim(operand, wrapped_index, wrapped_index + 1, 1, axis) 1810 if keepdims: 1811 return result 1812 else: 1813 return squeeze(result, (axis,)) 1814 1815 1816def dynamic_slice_in_dim(operand: Array, start_index: Array, 1817 slice_size: int, axis: int = 0) -> Array: 1818 """Convenience wrapper around dynamic_slice applying to one dimension.""" 1819 start_indices = [_zero(start_index)] * operand.ndim 1820 slice_sizes = list(operand.shape) 1821 1822 axis = int(axis) 1823 start_indices[axis] = start_index 1824 slice_sizes[axis] = int(slice_size) 1825 return dynamic_slice(operand, start_indices, slice_sizes) 1826 1827 1828def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0, 1829 keepdims: bool = True) -> Array: 1830 """Convenience wrapper around dynamic_slice to perform int indexing.""" 1831 result = dynamic_slice_in_dim(operand, index, 1, axis) 1832 if keepdims: 1833 return result 1834 else: 1835 return squeeze(result, (axis,)) 1836 1837 1838def dynamic_update_slice_in_dim(operand: Array, update: Array, 1839 start_index: Array, axis: int) -> Array: 1840 """Convenience wrapper around :func:`dynamic_update_slice` to update a slice 1841 in a single ``axis``. 1842 """ 1843 axis = int(axis) 1844 start_indices = [_zero(start_index)] * _ndim(operand) 1845 start_indices[axis] = start_index 1846 return dynamic_update_slice(operand, update, start_indices) 1847 1848 1849def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array, 1850 axis: int) -> Array: 1851 """Convenience wrapper around :func:`dynamic_update_slice` to update a slice 1852 of size 1 in a single ``axis``. 
1853 """ 1854 axis = int(axis) 1855 if _ndim(update) != _ndim(operand): 1856 assert _ndim(update) + 1 == _ndim(operand) 1857 update = expand_dims(update, (axis,)) 1858 return dynamic_update_slice_in_dim(operand, update, index, axis) 1859 1860 1861def batch_matmul(lhs: Array, rhs: Array, 1862 precision: PrecisionLike = None) -> Array: 1863 """Batch matrix multiplication.""" 1864 if _min(lhs.ndim, rhs.ndim) < 2: 1865 raise ValueError('Arguments to batch_matmul must be at least 2D, got {}, {}' 1866 .format(lhs.ndim, rhs.ndim)) 1867 if lhs.ndim != rhs.ndim: 1868 raise ValueError('Arguments to batch_matmul must have same ndim, got {}, {}' 1869 .format(lhs.ndim, rhs.ndim)) 1870 lhs_contract = (lhs.ndim - 1,) 1871 rhs_contract = (rhs.ndim - 2,) 1872 batch = tuple(range(lhs.ndim - 2)) 1873 return dot_general(lhs, rhs, ((lhs_contract, rhs_contract), (batch, batch)), 1874 precision=precision) 1875 1876 1877# These functions also exist in the XLA client library, but we treat them 1878# as non-primitive to maintain a smaller set of autodiff primitives. 
def square(x: Array) -> Array:
  r"""Elementwise square: :math:`x^2`."""
  return integer_pow(x, 2)

def reciprocal(x: Array) -> Array:
  r"""Elementwise reciprocal: :math:`1 \over x`."""
  return integer_pow(x, -1)

def _upcast_fp16_for_computation(f):
  """Wraps unary `f` so float16/bfloat16 inputs are computed in float32.

  For those two dtypes the input is converted to float32, passed through `f`,
  and the result converted back to the input dtype. Other dtypes call `f`
  directly.
  """
  @functools.wraps(f)
  def f_wrapped(x):
    dtype = _dtype(x)
    if dtype == np.float16 or dtype == dtypes.bfloat16:
      return convert_element_type(
        f(convert_element_type(x, np.float32)), dtype)
    return f(x)

  return f_wrapped

def tan(x: Array) -> Array:
  r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
  return tan_p.bind(x)

def asin(x: Array) -> Array:
  r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`."""
  return asin_p.bind(x)

def acos(x: Array) -> Array:
  r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`."""
  return acos_p.bind(x)

def atan(x: Array) -> Array:
  r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`."""
  return atan_p.bind(x)

def sinh(x: Array) -> Array:
  r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`."""
  return sinh_p.bind(x)

def cosh(x: Array) -> Array:
  r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`."""
  return cosh_p.bind(x)

def asinh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic sine: :math:`\mathrm{asinh}(x)`."""
  return asinh_p.bind(x)

def acosh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic cosine: :math:`\mathrm{acosh}(x)`."""
  return acosh_p.bind(x)

def atanh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic tangent: :math:`\mathrm{atanh}(x)`."""
  return atanh_p.bind(x)


# Add some methods to ShapedArray that rely on lax primitives

ShapedArray.broadcast = core.aval_method(broadcast)
ShapedArray.transpose = core.aval_method(transpose)  # clobbered by lax_numpy
ShapedArray.reshape = core.aval_method(reshape)  # clobbered by lax_numpy

def _iter(tracer):
  """Returns an iterator over the slices of `tracer` along its leading axis.

  Raises:
    TypeError: if `tracer` is 0-dimensional.
  """
  if tracer.ndim == 0:
    raise TypeError("iteration over a 0-d array")  # same as numpy error
  else:
    n = int(tracer.shape[0])
    # All slices are built eagerly into a list; the lazy generator form below
    # was left disabled by the original author.
    # return (index_in_dim(tracer, i, keepdims=False) for i in range(n))
    return iter([index_in_dim(tracer, i, keepdims=False) for i in range(n)])
ShapedArray._iter = staticmethod(_iter)

# Add some ad handlers that use (or could use) lax primitives

def zeros_like_array(x):
  """Returns `full_like(x, 0)`: zeros with the shape and dtype of `x`."""
  return full_like(x, 0)

# Register `add` as the adder for Python scalars, array types and device
# arrays, and `zeros_like_array` as the zeros-like handler for device arrays.
for t in itertools.chain(
    dtypes.python_scalar_dtypes.keys(), array_types,
    [xla._CppDeviceArray, xla._DeviceArray, pxla.ShardedDeviceArray]):
  ad_util.jaxval_adders[t] = add
ad_util.jaxval_zeros_likers[xla._DeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[xla._CppDeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[pxla.ShardedDeviceArray] = zeros_like_array


### primitives


# dtype rules: each takes the operand avals (or, for `_complex_basetype`, a
# dtype) and returns the result dtype.
# `_input_dtype`: canonicalized dtype of the first argument.
_input_dtype = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype)
# `_fixed_dtype(dtype)`: rule that always returns the canonicalized `dtype`.
_fixed_dtype = lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
# `_complex_basetype`: dtype of `np.abs` applied to a zero of `dtype`.
_complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype

def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None,
                       multiple_results=False):
  """Creates a `Primitive` named `name` with the standard rule wiring.

  The impl is `xla.apply_primitive`, abstract evaluation is
  `standard_abstract_eval` driven by `shape_rule` and `dtype_rule`, and the
  XLA translation is `translation_rule` or, when that is falsy,
  `standard_translate` bound to `name`.

  Returns:
    The newly created primitive.
  """
  prim = Primitive(name)
  prim.multiple_results = multiple_results
  prim.def_impl(partial(xla.apply_primitive, prim))
  prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
  xla.translations[prim] = translation_rule or partial(standard_translate, name)
  return prim


def standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs):
  """Abstract evaluation rule shared by standard primitives.

  Picks the argument type with the highest `array_abstraction_level` and
  produces outputs at that level: concrete values via `prim.impl`, shaped
  arrays via `shape_rule` and `dtype_rule`, or unshaped arrays via
  `dtype_rule` alone.

  Returns:
    A single abstract value, or a list of them if `prim.multiple_results`.

  Raises:
    TypeError: if the selected type is none of ConcreteArray, ShapedArray or
      UnshapedArray.
  """
  assert all(isinstance(arg, UnshapedArray) for arg in args), args
  least_specialized = _max(
      map(type, args), key=operator.attrgetter('array_abstraction_level'))
  if least_specialized is ConcreteArray:
    out_vals = prim.impl(*[x.val for x in args], **kwargs)
    if not prim.multiple_results:
      out_vals = [out_vals]
    out_avals = safe_map(ConcreteArray, out_vals)
  elif least_specialized is ShapedArray:
    # NOTE: the local name `dtypes` shadows the `jax.dtypes` module for the
    # rest of this function.
    shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
    if not prim.multiple_results:
      shapes, dtypes = [shapes], [dtypes]
    out_avals = safe_map(ShapedArray, shapes, dtypes)
  elif least_specialized is UnshapedArray:
    dtypes = dtype_rule(*args, **kwargs)
    if not prim.multiple_results:
      dtypes = [dtypes]
    out_avals = safe_map(UnshapedArray, dtypes)
  else:
    raise TypeError(args, least_specialized)
  if not prim.multiple_results:
    return out_avals[0]
  return out_avals


def standard_translate(name, c, *args, **kwargs):
  """Calls the `xops` function whose name is the CamelCase form of `name`.

  E.g. `'next_after'` resolves to `xops.NextAfter`. The builder `c` is not
  used.
  """
  xla_opname = ''.join(term.capitalize() for term in name.split('_'))
  return getattr(xops, xla_opname)(*args, **kwargs)


def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):
  """Checks a unary operand dtype and returns `result_dtype(aval.dtype)`.

  Raises:
    TypeError: if `aval.dtype` is not a subtype of any of `accepted_dtypes`.
  """
  if not any(dtypes.issubdtype(aval.dtype, t) for t in accepted_dtypes):
    msg = '{} does not accept dtype {}. Accepted dtypes are subtypes of {}.'
    typename = str(np.dtype(aval.dtype).name)
    accepted_typenames = (t.__name__ for t in accepted_dtypes)
    raise TypeError(msg.format(name, typename, ', '.join(accepted_typenames)))
  return result_dtype(aval.dtype)


def unop(result_dtype, accepted_dtypes, name, translation_rule=None):
  """Creates a unary primitive whose output shape equals its input shape.

  Also registers vectorized batching and masking rules for it.
  """
  dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name)
  prim = standard_primitive(_attrgetter('shape'), dtype_rule, name,
                            translation_rule=translation_rule)
  batching.defvectorized(prim)
  masking.defvectorized(prim)
  return prim
# `unop` whose result dtype is the input dtype.
standard_unop = partial(unop, _identity)
# Defined after `unop` but only looked up when `unop` is called.
_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)


def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs):
  """Checks each operand dtype by position and returns `result_dtype(*avals)`.

  `accepted_dtypes[i]` is the collection of dtypes accepted at position `i`.
  After the per-position check, `_check_same_dtypes` is applied to all
  operand dtypes.

  Raises:
    TypeError: if an operand dtype is not accepted at its position; float0
      operands get a dedicated message.
  """
  aval_dtypes = [aval.dtype for aval in avals]
  for i, (aval_dtype, types) in enumerate(zip(aval_dtypes, accepted_dtypes)):
    if not any(dtypes.issubdtype(aval_dtype, t) for t in types):
      if aval_dtype is dtypes.float0:
        raise TypeError(
            f"Called {name} with a float0 at position {i}. "
            "float0s do not support any operations by design, because they "
            "are not compatible with non-trivial vector spaces. No implicit dtype "
            "conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
            "to cast a float0 array to a regular zeros array. \n"
            "If you didn't expect to get a float0 you might have accidentally "
            "taken a gradient with respect to an integer argument.")
      else:
        msg = ('{} does not accept dtype {} at position {}. '
               'Accepted dtypes at position {} are subtypes of {}.')
        typename = str(np.dtype(aval_dtype).name)
        typenames = ', '.join(t.__name__ for t in types)
        raise TypeError(msg.format(name, typename, i, i, typenames))
  _check_same_dtypes(name, False, *aval_dtypes)
  return result_dtype(*avals)


def _broadcasting_shape_rule(name, *avals):
  """Shape rule for n-ary ops: scalars are ignored, others must broadcast.

  Operands with an empty shape are dropped; the remaining shapes must all have
  the same rank and be compatible per `_try_broadcast_shapes`.

  Raises:
    TypeError: on mismatched ranks or incompatible shapes.
  """
  shapes = [aval.shape for aval in avals if aval.shape]
  if not shapes:
    return ()
  if len({len(shape) for shape in shapes}) != 1:
    msg = '{} got arrays of different rank: {}.'
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
  result_shape = _try_broadcast_shapes(shapes)
  if result_shape is None:
    msg = '{} got incompatible shapes for broadcasting: {}.'
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
  return result_shape

def naryop(result_dtype, accepted_dtypes, name, translation_rule=None):
  """Creates an n-ary primitive with broadcasting shape semantics.

  Also registers broadcasting batching and n-ary masking rules for it.
  """
  dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name)
  shape_rule = partial(_broadcasting_shape_rule, name)
  prim = standard_primitive(shape_rule, dtype_rule, name,
                            translation_rule=translation_rule)
  batching.defbroadcasting(prim)
  masking.defnaryop(prim)
  return prim
# `naryop` whose result dtype is the (canonicalized) dtype of the first operand.
standard_naryop = partial(naryop, _input_dtype)


def _broadcast_translate(translate: Callable):
  # Decorator for translation rules which adds explicit broadcasting of
  # positional arguments. This is necessary only for a handful of primitives
  # whose XLA implementations do not support broadcasting.
  def _broadcast_array(array, array_shape, result_shape):
    if array_shape == result_shape:
      return array
    # Map the operand's dimensions onto the trailing dimensions of the result.
    bcast_dims = tuple(range(len(result_shape) - len(array_shape),
                             len(result_shape)))
    result = xops.BroadcastInDim(array, result_shape, bcast_dims)
    return result

  def _broadcasted_translation_rule(c, *args, **kwargs):
    shapes = [c.get_shape(arg).dimensions() for arg in args]
    result_shape = broadcast_shapes(*shapes)
    args = [_broadcast_array(arg, arg_shape, result_shape)
            for arg, arg_shape in zip(args, shapes)]
    return translate(c, *args, **kwargs)
  return _broadcasted_translation_rule

# NOTE(mattjj): this isn't great for orchestrate fwd mode because it means JVPs
# get two extra ops in them: a reshape and a broadcast_in_dim (or sometimes just
# a broadcast). but saving the shape info with the primitives isn't great either
# because then we can't trace these ops without shape data.
def _brcast(x, *others):
  # Used in jvprules to make naryop broadcasting explicit for transposability.
  # Requires shape info during jvp tracing, which isn't strictly necessary.
  # We don't need full numpy broadcasting, but otherwise the logic is the same
  # so we reuse the broadcast_shapes function after filtering out scalars.
  shapes = tuple(filter(None, map(np.shape, (x,) + others)))
  # When every operand is a scalar, `shapes` is () and so is `shape`.
  shape = shapes and broadcast_shapes(*shapes)
  if np.shape(x) != shape:
    return _brcast_to(x, shape)
  else:
    return x


def _brcast_to(x, shape):
  """Broadcasts `x` to `shape`; `x` must not already have that shape.

  A non-scalar `x` must have the same rank as `shape`: the dimensions whose
  sizes differ from `shape` are squeezed out and then re-created by
  `broadcast_in_dim`. A scalar `x` is broadcast directly.
  """
  x_shape = np.shape(x)
  assert x_shape != shape
  if x_shape:
    assert len(x_shape) == len(shape)
    broadcast_dimensions, = np.where(np.equal(x_shape, shape))
    squeezed_dimensions, = np.where(np.not_equal(x_shape, shape))
    squeezed = squeeze(x, squeezed_dimensions)
    return broadcast_in_dim(squeezed, shape, broadcast_dimensions)
  else:
    return broadcast(x, shape)


# Sets of accepted dtype classes used when declaring primitives below.
_float = {np.floating}
_complex = {np.complexfloating}
_complex_elem_types = {np.float32, np.float64}
_int = {np.integer}
_bool = {np.bool_}

_num = _int | _float | _complex
_any = _int | _float | _complex | _bool
_bool_or_int = _int | _bool

neg_p = standard_unop(_num, 'neg')
ad.deflinear2(neg_p, lambda t, operand: [neg(t)])

def _sign_translation_rule(c, x):
  """XLA translation for sign.

  Unsigned integers get a select: 0 where `x == 0`, otherwise 1. All other
  dtypes use `xops.Sign`.
  """
  shape = c.get_shape(x)
  dtype = shape.numpy_dtype()
  if dtypes.issubdtype(dtype, np.unsignedinteger):
    zero = xb.constant(c, np.array(0, dtype=dtype))
    dims = c.get_shape(x).dimensions()
    return xops.Select(xops.Eq(x, zero), xops.Broadcast(zero, dims),
                       xops.Broadcast(xb.constant(c, np.array(1, dtype=dtype)),
                                      dims))
  return xops.Sign(x)

sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)
ad.defjvp_zero(sign_p)

nextafter_p = standard_naryop(
  [_float, _float], 'nextafter',
  translation_rule=_broadcast_translate(partial(standard_translate, 'next_after')))

floor_p = standard_unop(_float, 'floor')
ad.defjvp_zero(floor_p)

ceil_p = standard_unop(_float, 'ceil')
ad.defjvp_zero(ceil_p)

def _round_to_nearest_even(x):
  """Rounds `x` to the nearest integer, with ties going to the even one.

  Starts from `floor(x)` and adds one when the fractional part exceeds 0.5,
  or equals 0.5 while `floor(x)` is odd.
  """
  half = _const(x, 0.5)
  one = _const(x, 1)
  round_val = floor(x)
  fraction = x - round_val
  # round_val - 2 * floor(x / 2): equals one exactly when round_val is odd.
  nearest_even_int = sub(
    round_val, mul(_const(x, 2), floor(mul(half, x))))
  is_odd = eq(nearest_even_int, one)
  return select(
    bitwise_or(gt(fraction, half),
               bitwise_and(eq(fraction, half), is_odd)),
    add(round_val, one), round_val)

def _round_translation_rule(c, x, *, rounding_method):
  """XLA translation for round, dispatching on `rounding_method`."""
  if rounding_method is RoundingMethod.AWAY_FROM_ZERO:
    return xops.Round(x)
  else: # rounding_method is RoundingMethod.TO_NEAREST_EVEN
    rounding_fun = xla.lower_fun(_round_to_nearest_even, multiple_results=False)
    return rounding_fun(c, x)

round_p = standard_unop(_float, 'round')
xla.translations[round_p] = _round_translation_rule
ad.defjvp_zero(round_p)

is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
ad.defjvp_zero(is_finite_p)

exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
iad.definverse(exp_p, lambda r, x: log(r))
# For exp_p it is more efficient to use the reconstructed output for the vjp
# rule instead of computing it again from the input.
2199iad.primitive_ivjps[exp_p] = lambda x, y, ct: [[log(y[0])], [ct[0] * y[0]]] 2200 2201log_p = standard_unop(_float | _complex, 'log') 2202ad.defjvp(log_p, lambda g, x: div(g, x)) 2203iad.definverse(log_p, lambda r, x: exp(r)) 2204 2205expm1_p = standard_unop(_float | _complex, 'expm1') 2206ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans)))) 2207 2208log1p_p = standard_unop(_float | _complex, 'log1p') 2209ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x)))) 2210 2211tanh_p = standard_unop(_float | _complex, 'tanh') 2212ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)), 2213 sub(_one(x), ans))) 2214 2215sin_p = standard_unop(_float | _complex, 'sin') 2216ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) 2217 2218cos_p = standard_unop(_float | _complex, 'cos') 2219ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) 2220 2221@partial(xla.lower_fun, multiple_results=False) 2222@_upcast_fp16_for_computation 2223def tan_translation_rule(x): 2224 return div(sin(x), cos(x)) 2225 2226tan_p = standard_unop(_float | _complex, 'tan', 2227 translation_rule=tan_translation_rule) 2228ad.defjvp(tan_p, lambda g, x: mul(g, _const(x, 1) + square(tan(x)))) 2229 2230 2231@partial(xla.lower_fun, multiple_results=False) 2232def asin_translation_rule(x): 2233 if dtypes.issubdtype(_dtype(x), np.complexfloating): 2234 return mul(_const(x, -1j), asinh(mul(_const(x, 1j), x))) 2235 else: 2236 return mul(_const(x, 2), 2237 atan2(x, add(_const(x, 1), sqrt(sub(_const(x, 1), square(x)))))) 2238 2239asin_p = standard_unop(_float | _complex, 'asin', 2240 translation_rule=asin_translation_rule) 2241ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x)))) 2242 2243 2244@partial(xla.lower_fun, multiple_results=False) 2245def acos_translation_rule(x): 2246 if dtypes.issubdtype(_dtype(x), np.complexfloating): 2247 result = mul(_const(x, 1j), acosh(x)) 2248 # By convention, numpy chooses the branch with positive real part. 
2249 rpart = real(result) 2250 return select( 2251 gt(rpart, _const(rpart, 0)), 2252 result, 2253 neg(result) 2254 ) 2255 else: 2256 return select( 2257 ne(x, _const(x, -1.0)), 2258 mul(_const(x, 2), 2259 atan2(sqrt(sub(_const(x, 1), square(x))), add(_const(x, 1), x))), 2260 full_like(x, np.pi)) 2261 2262acos_p = standard_unop(_float | _complex, 'acos', 2263 translation_rule=acos_translation_rule) 2264ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x)))) 2265 2266@partial(xla.lower_fun, multiple_results=False) 2267def atan_translation_rule(x): 2268 if dtypes.issubdtype(_dtype(x), np.complexfloating): 2269 return mul(_const(x, -1j), atanh(mul(_const(x, 1j), x))) 2270 else: 2271 return atan2(x, _const(x, 1)) 2272 2273atan_p = standard_unop(_float | _complex, 'atan', 2274 translation_rule=atan_translation_rule) 2275ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x))) 2276 2277atan2_p = standard_naryop([_float, _float], 'atan2') 2278ad.defjvp(atan2_p, 2279 lambda g, x, y: _brcast(g, y) * (y / (square(x) + square(y))), 2280 lambda g, x, y: _brcast(g, x) * -x / (square(x) + square(y))) 2281 2282sinh_p = standard_unop(_float | _complex, 'sinh') 2283ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x))) 2284 2285cosh_p = standard_unop(_float | _complex, 'cosh') 2286ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x))) 2287 2288asinh_p = standard_unop(_float | _complex, 'asinh') 2289ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x)))) 2290 2291acosh_p = standard_unop(_float | _complex, 'acosh') 2292ad.defjvp(acosh_p, 2293 lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x))))) 2294 2295atanh_p = standard_unop(_float | _complex, 'atanh') 2296ad.defjvp(atanh_p, 2297 lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x)))) 2298 2299regularized_incomplete_beta_p = standard_naryop( 2300 [_float, _float, _float], 'regularized_incomplete_beta', 2301 translation_rule=_broadcast_translate( 2302 partial(standard_translate, 
'regularized_incomplete_beta'))) 2303 2304def betainc_gradx(g, a, b, x): 2305 lbeta = lgamma(a) + lgamma(b) - lgamma(a + b) 2306 partial_x = exp((b - 1) * log1p(-x) + 2307 (a - 1) * log(x) - lbeta) 2308 return partial_x * g 2309 2310def betainc_grad_not_implemented(g, a, b, x): 2311 raise ValueError("Betainc gradient with respect to a and b not supported.") 2312 2313ad.defjvp(regularized_incomplete_beta_p, 2314 betainc_grad_not_implemented, 2315 betainc_grad_not_implemented, 2316 betainc_gradx) 2317 2318lgamma_p = standard_unop(_float, 'lgamma') 2319ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x))) 2320 2321digamma_p = standard_unop(_float, 'digamma') 2322 2323igamma_p = standard_naryop( 2324 [_float, _float], 'igamma', 2325 translation_rule=_broadcast_translate(partial(standard_translate, 'igamma'))) 2326igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a', 2327 translation_rule=_broadcast_translate(partial(standard_translate, 2328 'igamma_grad_a'))) 2329 2330def igamma_gradx(g, a, x): 2331 return _brcast(g, a, x) * exp(-x + (a - _ones(a)) * log(x) - lgamma(a)) 2332 2333def igamma_grada(g, a, x): 2334 return _brcast(g, a, x) * igamma_grad_a(a, x) 2335 2336ad.defjvp(igamma_p, igamma_grada, igamma_gradx) 2337 2338igammac_p = standard_naryop( 2339 [_float, _float], 'igammac', 2340 translation_rule=_broadcast_translate(partial(standard_translate, 'igammac'))) 2341 2342def igammac_gradx(g, a, x): 2343 return -igamma_gradx(g, a, x) 2344 2345def igammac_grada(g, a, x): 2346 return -igamma_grada(g, a, x) 2347 2348ad.defjvp(igammac_p, igammac_grada, igammac_gradx) 2349 2350random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad', 2351 translation_rule=_broadcast_translate(partial(standard_translate, 2352 'random_gamma_grad'))) 2353 2354bessel_i0e_p = standard_unop(_float, 'bessel_i0e') 2355ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y)) 2356 2357bessel_i1e_p = standard_unop(_float, 'bessel_i1e') 2358def 
_bessel_i1e_jvp(g, y, x): 2359 eps = dtypes.finfo(_dtype(x)).eps 2360 x_is_not_tiny = abs(x) > eps 2361 safe_x = select(x_is_not_tiny, x, full_like(x, eps)) 2362 dy_dx = bessel_i0e(safe_x) - y * (sign(safe_x) + reciprocal(safe_x)) 2363 dy_dx = select(x_is_not_tiny, dy_dx, full_like(x, 0.5)) 2364 return g * dy_dx 2365ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp) 2366 2367erf_p = standard_unop(_float, 'erf') 2368ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)), 2369 mul(g, exp(neg(square(x)))))) 2370 2371erfc_p = standard_unop(_float, 'erfc') 2372ad.defjvp(erfc_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)), 2373 mul(neg(g), exp(neg(square(x)))))) 2374 2375erf_inv_p = standard_unop(_float, 'erf_inv') 2376ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.), 2377 mul(g, exp(square(ans))))) 2378 2379real_p = unop(_complex_basetype, _complex, 'real') 2380ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))]) 2381 2382imag_p = unop(_complex_basetype, _complex, 'imag') 2383ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))]) 2384 2385_complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype 2386complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types], 2387 'complex') 2388ad.deflinear2(complex_p, lambda t, *args: [real(t), imag(neg(t))]) 2389 2390conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj') 2391 2392def _conj_transpose_rule(t, x, *, input_dtype): 2393 assert ad.is_undefined_primal(x) 2394 if dtypes.issubdtype(input_dtype, np.complexfloating): 2395 return [conj(t)] 2396 else: 2397 return [real(t)] 2398 2399xla.translations[conj_p] = lambda c, x, **kwargs: xops.Conj(x) 2400ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p) 2401ad.primitive_transposes[conj_p] = _conj_transpose_rule 2402 2403abs_p = unop(_complex_basetype, _num, 'abs') 2404 2405def _abs_jvp_rule(g, ans, x): 2406 if _iscomplex(x): 2407 
def _integer_pow_translation_rule(c, x, *, y):
  """Lowers `integer_pow` to XLA multiplications via repeated squaring.

  A zero exponent yields a broadcast of the constant one with the operand's
  dtype and dimensions. A negative exponent computes the positive power and
  then takes its reciprocal.
  """
  if y == 0:
    x_shape = c.get_shape(x)
    unit = xb.constant(c, np.array(1, dtype=x_shape.numpy_dtype()))
    return xops.Broadcast(unit, x_shape.dimensions())

  invert = y < 0
  exponent = -y if invert else y
  base = x
  result = None
  while exponent > 0:
    # Fold the current power of the base into the result when its bit is set.
    if exponent & 1:
      result = base if result is None else xops.Mul(result, base)
    exponent >>= 1
    # Only square when another bit remains, to avoid emitting an unused op.
    if exponent > 0:
      base = xops.Mul(base, base)
  if invert:
    return xops.Reciprocal(result)
  return result
2463 _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow', 2464 translation_rule=_integer_pow_translation_rule) 2465batching.defvectorized(integer_pow_p) 2466masking.defvectorized(integer_pow_p) 2467ad.defjvp(integer_pow_p, _integer_pow_jvp) 2468 2469_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x) 2470 2471not_p = standard_unop(_bool_or_int, 'not') 2472ad.defjvp_zero(not_p) 2473 2474and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and') 2475ad.defjvp_zero(and_p) 2476 2477or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or') 2478ad.defjvp_zero(or_p) 2479 2480xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor') 2481ad.defjvp_zero(xor_p) 2482 2483population_count_p = standard_unop(_int, 'population_count') 2484 2485def _add_transpose(t, x, y): 2486 # The following linearity assertion is morally true, but because in some cases we 2487 # instantiate zeros for convenience, it doesn't always hold. 2488 # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y) 2489 return [t, t] 2490 2491add_p = standard_naryop([_num, _num], 'add') 2492ad.defjvp(add_p, lambda g, x, y: _brcast(g, y), lambda g, x, y: _brcast(g, x)) 2493ad.primitive_transposes[add_p] = _add_transpose 2494def _add_inverse(r, x, y): 2495 xr = r - y 2496 yr = r - x 2497 return xr, yr 2498iad.definverse(add_p, _add_inverse) 2499 2500def _sub_transpose(t, x, y): 2501 # The following linearity assertion is morally true, but because in some cases 2502 # we instantiate zeros for convenience, it doesn't always hold. 
def _div_transpose_rule(cotangent, x, y):
  """Transpose of `div` with respect to its numerator.

  The numerator `x` must be an undefined primal and the denominator `y` a
  known value. Returns a pair: the cotangent for `x` (a symbolic zero when
  the incoming cotangent is one, otherwise ``cotangent / y``) and `None` for
  `y`.
  """
  assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
  if type(cotangent) is ad_util.Zero:
    return ad_util.Zero(x.aval), None
  return div(cotangent, y), None
def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None):
  """XLA translation shared by `max` and `min`.

  Non-complex operands are passed straight to `minmax`. Complex operands are
  ordered lexicographically: `cmp` is applied to the real parts, and to the
  imaginary parts where the real parts are equal; the resulting predicate
  selects between `x` and `y`.
  """
  if not dtypes.issubdtype(c.get_shape(x).numpy_dtype(), np.complexfloating):
    return minmax(x, y)

  x_re = xops.Real(x)
  y_re = xops.Real(y)
  re_equal = xops.Eq(x_re, y_re)
  by_imag = cmp(xops.Imag(x), xops.Imag(y))
  by_real = cmp(x_re, y_re)
  pick_x = xops.Select(re_equal, by_imag, by_real)
  return _broadcasting_select(c, pick_x, x, y)
def _convert_element_type_transpose_rule(ct, operand, *, new_dtype):
  """Transpose rule for `convert_element_type_p`.

  Returns a one-element list with the cotangent for `operand`:
    * a symbolic zero on the operand's aval when `ct` is a symbolic zero;
    * a symbolic float0 zero of the operand's shape when the operand dtype
      maps to the float0 tangent dtype;
    * otherwise `ct` converted back to the operand's dtype.
  """
  assert ad.is_undefined_primal(operand)
  source_dtype = operand.aval.dtype
  if type(ct) is ad_util.Zero:
    return [ad_util.Zero(operand.aval)]
  if core.primal_dtype_to_tangent_dtype(source_dtype) is dtypes.float0:
    float0_aval = ShapedArray(operand.aval.shape, dtype=dtypes.float0)
    return [ad_util.Zero(float0_aval)]
  return [convert_element_type_p.bind(ct, new_dtype=source_dtype)]
def _conv_general_dilated_shape_rule(
    lhs: ShapedArray, rhs: ShapedArray, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, **unused_kwargs) -> Tuple[int, ...]:
  """Validates the operands of `conv_general_dilated` and returns its shape.

  The checks run in a fixed order, and the first one that fails raises
  `ValueError`: equal operand ranks, positive `feature_group_count` and its
  divisibility constraints, positive `batch_group_count` and its divisibility
  constraints, at most one group count greater than one, and matching numbers
  of window dimensions and window strides.

  Returns:
    The output shape, computed by `conv_shape_tuple` on the permuted and
    dilated operand shapes and permuted back according to the output spec.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  if len(lhs.shape) != len(rhs.shape):
    raise ValueError(
        "conv_general_dilated lhs and rhs must have the same number of "
        f"dimensions, but got {lhs.shape} and {rhs.shape}.")

  # Feature-group constraints.
  if not feature_group_count > 0:
    raise ValueError(
        "conv_general_dilated feature_group_count "
        f"must be a positive integer, got {feature_group_count}.")
  lhs_spec = dimension_numbers.lhs_spec
  lhs_features = lhs.shape[lhs_spec[1]]
  per_group, leftover = divmod(lhs_features, feature_group_count)
  if leftover:
    raise ValueError(
        "conv_general_dilated feature_group_count must divide lhs feature "
        f"dimension size, but {feature_group_count} does not divide "
        f"{lhs_features}.")
  rhs_spec = dimension_numbers.rhs_spec
  rhs_in_features = rhs.shape[rhs_spec[1]]
  if per_group != rhs_in_features:
    raise ValueError(
        "conv_general_dilated lhs feature dimension size divided by "
        "feature_group_count must equal the rhs input feature dimension "
        f"size, but {lhs_features} // {feature_group_count} != "
        f"{rhs_in_features}.")
  rhs_out_features = rhs.shape[rhs_spec[0]]
  if rhs_out_features % feature_group_count:
    raise ValueError(
        "conv_general_dilated rhs output feature dimension size must be a "
        f"multiple of feature_group_count, but {rhs_out_features} is not a "
        f"multiple of {feature_group_count}.")

  # Batch-group constraints.
  if not batch_group_count > 0:
    raise ValueError(
        "conv_general_dilated batch_group_count "
        f"must be a positive integer, got {batch_group_count}.")
  lhs_batch = lhs.shape[lhs_spec[0]]
  if lhs_batch % batch_group_count != 0:
    raise ValueError(
        "conv_general_dilated batch_group_count must divide lhs batch "
        f"dimension size, but {batch_group_count} does not divide "
        f"{lhs_batch}.")
  if rhs_out_features % batch_group_count:
    raise ValueError(
        "conv_general_dilated rhs output feature dimension size must be a "
        f"multiple of batch_group_count, but {rhs_out_features} is not a "
        f"multiple of {batch_group_count}.")

  if batch_group_count > 1 and feature_group_count > 1:
    raise ValueError(
        "At most one of batch_group_count and feature_group_count may be > "
        f"1, got batch_group_count={batch_group_count} and "
        f"feature_group_count={feature_group_count}")

  num_window_dims = len(_conv_sdims(rhs_spec))
  if num_window_dims != len(window_strides):
    raise ValueError(
        "conv_general_dilated window and window_strides must have "
        "the same number of dimensions, but got "
        f"{num_window_dims} and {len(window_strides)}")

  lhs_perm, rhs_perm, out_perm = dimension_numbers
  lhs_trans = _dilate_shape(np.take(lhs.shape, lhs_perm), lhs_dilation)
  rhs_trans = _dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
                               batch_group_count)
  # Undo the output permutation to express the shape in the caller's layout.
  return tuple(np.take(out_trans, np.argsort(out_perm)))
def _conv_general_dilated_transpose_lhs(
    g, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
    dimension_numbers, feature_group_count, batch_group_count,
    lhs_shape, rhs_shape, precision):
  """Transpose of `conv_general_dilated` with respect to its lhs operand.

  Convolves the cotangent `g` with the spatially reversed kernel, swapping
  the roles of `window_strides` and `lhs_dilation` and using the kernel spec
  with its first two entries exchanged. When either group count exceeds one,
  the group axis is first moved from the kernel's dimension `rhs_spec[0]`
  into `rhs_spec[1]`, and the transposed convolution is run as a
  feature-grouped convolution. For batch grouping the result is additionally
  reshaped to move the group axis from `lhs_spec[1]` into `lhs_spec[0]`.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  assert batch_group_count == 1 or feature_group_count == 1
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  lhs_sdims = _conv_sdims(lhs_spec)
  rhs_sdims = _conv_sdims(rhs_spec)
  out_sdims = _conv_sdims(out_spec)
  t_rhs_spec = _conv_spec_transpose(rhs_spec)

  # At most one of the two counts is > 1 (asserted above); both cases need
  # the same kernel reshuffle and both run as a feature-grouped convolution.
  groups = feature_group_count if feature_group_count > 1 else batch_group_count
  if groups > 1:
    rhs = _reshape_axis_out_of(rhs_spec[0], groups, rhs)
    rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[1], rhs)
    feature_group_count = groups

  trans_dimension_numbers = ConvDimensionNumbers(out_spec, t_rhs_spec, lhs_spec)
  padding = _conv_general_vjp_lhs_padding(
      np.take(lhs_shape, lhs_sdims), np.take(rhs_shape, rhs_sdims),
      window_strides, np.take(g.shape, out_sdims), padding, lhs_dilation,
      rhs_dilation)
  flipped_kernel = rev(rhs, rhs_sdims)
  out = conv_general_dilated(
      g, flipped_kernel, window_strides=lhs_dilation, padding=padding,
      lhs_dilation=window_strides, rhs_dilation=rhs_dilation,
      dimension_numbers=trans_dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=1, precision=precision)
  if batch_group_count > 1:
    out = _reshape_axis_out_of(lhs_spec[1], batch_group_count, out)
    out = _reshape_axis_into(lhs_spec[1], lhs_spec[0], out)
  return out
def _conv_general_dilated_translation_rule(
    c, lhs, rhs, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, precision, expand_complex_convolutions, **unused_kwargs):
  """XLA translation rule for `conv_general_dilated`.

  Emits a single `ConvGeneralDilated` op, unless `expand_complex_convolutions`
  is set and the lhs dtype is complex; in that case the convolution is
  expanded into three convolutions on the real and imaginary parts.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  dimension_numbers = _conv_general_proto(dimension_numbers)
  precision_config = _precision_config(precision)
  dtype = c.get_shape(lhs).numpy_dtype()

  def conv(x, y):
    return xops.ConvGeneralDilated(
        x, y, window_strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers, feature_group_count, batch_group_count,
        precision_config=precision_config)

  if not (expand_complex_convolutions and
          np.issubdtype(dtype, np.complexfloating)):
    return conv(lhs, rhs)

  # Gauss's trick for complex multiplication: three multiplications and five
  # additions instead of the naive four multiplications and two additions.
  # https://en.wikipedia.org/wiki/Multiplication_algorithm#Complex_multiplication_algorithm
  #
  # The speed-up costs accuracy, especially when the real and imaginary parts
  # differ hugely in magnitude: the relative error bound (e.g. 1p-24 for
  # float32) then holds relative to the larger of the two result parts rather
  # than for the real and imaginary parts independently.
  lhs_re, lhs_im = xops.Real(lhs), xops.Imag(lhs)
  rhs_re, rhs_im = xops.Real(rhs), xops.Imag(rhs)
  k1 = conv(xops.Add(lhs_re, lhs_im), rhs_re)
  k2 = conv(lhs_re, xops.Sub(rhs_im, rhs_re))
  k3 = conv(lhs_im, xops.Add(rhs_re, rhs_im))
  return xops.Complex(xops.Sub(k1, k3), xops.Add(k1, k2))
_reshape_axis_out_of(lhs_spec[0] + int(lhs_bdim <= lhs_spec[0]), 2904 batch_group_count, lhs) 2905 new_lhs = _reshape_axis_into(lhs_bdim + int(lhs_spec[0] < lhs_bdim), 2906 lhs_spec[0] + 1, 2907 new_lhs) 2908 new_lhs = _reshape_axis_into(lhs_spec[0], lhs_spec[0], new_lhs) 2909 out = conv_general_dilated(new_lhs, rhs, window_strides, padding, 2910 lhs_dilation, rhs_dilation, dimension_numbers, 2911 feature_group_count, batch_group_count, 2912 precision=precision) 2913 out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out) 2914 return out, out_spec[0] 2915 2916 elif rhs_bdim is not None: 2917 if feature_group_count == 1 and batch_group_count == 1: 2918 new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs) 2919 out = conv_general_dilated(lhs, new_rhs, window_strides, padding, 2920 lhs_dilation, rhs_dilation, dimension_numbers, 2921 feature_group_count, batch_group_count, 2922 precision=precision) 2923 out = _reshape_axis_out_of(out_spec[1], rhs.shape[rhs_bdim], out) 2924 return out, out_spec[1] 2925 else: 2926 # groups need to be outermost, so we need to factor them out of the 2927 # rhs output feature dim, then factor the batch dim into the remaining rhs 2928 # output feature dim, then put groups back in. We do something 2929 # similar on the output. An alternative which would require more FLOPs but 2930 # fewer reshapes would be to broadcast lhs. 
def _masked(padded_value, logical_shape, dimensions, value=0):
  """
  Overwrites the padding of `padded_value` along `dimensions` with `value`
  (0 by default). An element counts as padding when, in any of the given
  dimensions, its index is not below the logical size of that dimension.
  With no dimensions given, `padded_value` is returned unchanged.
  """
  if len(dimensions) == 0:
    return padded_value

  in_bounds = [
      broadcasted_iota(np.int32, padded_value.shape, axis) < logical_shape[axis]
      for axis in dimensions]
  # An element is kept only if it lies inside the logical extent on every axis.
  keep = _reduce(operator.and_, in_bounds)
  return select(keep, padded_value, full_like(padded_value, value))
def _reshape_axis_into(src, dst, x):
  """Merges axis `src` of `x` into axis `dst`.

  `dst` indexes the shape that remains after `src` has been removed. The
  result has one dimension fewer than `x`, and its `dst` dimension is
  multiplied by the size of `src`. The permutation passed to `reshape`
  places `src` at position `dst`, directly before the axis it is merged with.
  """
  merged_shape = list(np.delete(x.shape, src))
  merged_shape[dst] *= x.shape[src]
  axis_order = [axis for axis in range(x.ndim) if axis != src]
  axis_order.insert(dst, src)
  return reshape(x, merged_shape, axis_order)
_precision_config(precision): 3022 if precision is not None: 3023 config = xla_client.PrecisionConfig() 3024 if isinstance(precision, tuple): 3025 config.operand_precision.extend(precision) 3026 else: 3027 config.operand_precision.extend((precision, precision)) 3028 return config 3029 return None 3030 3031 3032def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, 3033 preferred_element_type: Optional[DType]): 3034 (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers 3035 if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) 3036 for d in (lhs_contracting, lhs_batch)): 3037 msg = ("dot_general requires lhs dimension numbers to be nonnegative and " 3038 "less than the number of axes of the lhs value, got " 3039 f"lhs_batch of {lhs_batch} and lhs_contracting of {lhs_contracting} " 3040 f"for lhs of rank {lhs.ndim}") 3041 raise TypeError(msg) 3042 if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, rhs.ndim)) 3043 for d in (rhs_contracting, rhs_batch)): 3044 msg = ("dot_general requires rhs dimension numbers to be nonnegative and " 3045 "less than the number of axes of the rhs value, got " 3046 f"rhs_batch of {rhs_batch} and rhs_contracting of {rhs_contracting} " 3047 f"for rhs of rank {rhs.ndim}") 3048 raise TypeError(msg) 3049 if len(lhs_batch) != len(rhs_batch): 3050 msg = ("dot_general requires equal numbers of lhs_batch and rhs_batch " 3051 "dimensions, got lhs_batch {} and rhs_batch {}.") 3052 raise TypeError(msg.format(lhs_batch, rhs_batch)) 3053 lhs_contracting_set, lhs_batch_set = set(lhs_contracting), set(lhs_batch) 3054 rhs_contracting_set, rhs_batch_set = set(rhs_contracting), set(rhs_batch) 3055 if len(lhs_batch_set) != len(lhs_batch): 3056 msg = ("dot_general requires lhs batch dimensions to be distinct, got " 3057 f"lhs_batch {lhs_batch}.") 3058 raise TypeError(msg) 3059 if len(rhs_batch_set) != len(rhs_batch): 3060 msg = ("dot_general requires rhs batch dimensions to be 
distinct, got " 3061 f"rhs_batch {rhs_batch}.") 3062 raise TypeError(msg) 3063 if len(lhs_contracting_set) != len(lhs_contracting): 3064 msg = ("dot_general requires lhs contracting dimensions to be distinct, " 3065 f"got lhs_contracting {lhs_contracting}.") 3066 raise TypeError(msg) 3067 if len(rhs_contracting_set) != len(rhs_contracting): 3068 msg = ("dot_general requires rhs contracting dimensions to be distinct, " 3069 f"got rhs_contracting {rhs_contracting}.") 3070 raise TypeError(msg) 3071 if lhs_contracting_set & lhs_batch_set: 3072 msg = ("dot_general requires lhs batch dimensions to be disjoint from " 3073 "contracting dimensions, got lhs_batch {} and lhs_contracting {}.") 3074 raise TypeError(msg.format(lhs_batch, lhs_contracting)) 3075 if rhs_contracting_set & rhs_batch_set: 3076 msg = ("dot_general requires rhs batch dimensions to be disjoint from " 3077 "contracting dimensions, got rhs_batch {} and rhs_contracting {}.") 3078 raise TypeError(msg.format(rhs_batch, rhs_contracting)) 3079 lhs_batch_shape = np.take(lhs.shape, lhs_batch) 3080 rhs_batch_shape = np.take(rhs.shape, rhs_batch) 3081 if not np.all(np.equal(lhs_batch_shape, rhs_batch_shape)): 3082 msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions " 3083 "to have the same shape, got {} and {}.") 3084 raise TypeError(msg.format(lhs_batch_shape, rhs_batch_shape)) 3085 lhs_contracting_shape = np.take(lhs.shape, lhs_contracting) 3086 rhs_contracting_shape = np.take(rhs.shape, rhs_contracting) 3087 if not np.all(np.equal(lhs_contracting_shape, rhs_contracting_shape)): 3088 msg = ("dot_general requires contracting dimensions to have the same " 3089 "shape, got {} and {}.") 3090 raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape)) 3091 3092 batch_shape = tuple(lhs_batch_shape) 3093 lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch))) 3094 lhs_tensored_shape = tuple(np.delete(lhs.shape, lhs_contract_or_batch)) 3095 
rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch))) 3096 rhs_tensored_shape = tuple(np.delete(rhs.shape, rhs_contract_or_batch)) 3097 return batch_shape + lhs_tensored_shape + rhs_tensored_shape 3098 3099def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, 3100 preferred_element_type: Optional[DType]): 3101 input_dtype = naryop_dtype_rule(_input_dtype, [_any, _any], 'dot_general', lhs, rhs) 3102 if preferred_element_type is None: 3103 return input_dtype 3104 if dtypes.issubdtype(input_dtype, np.integer) and not dtypes.issubdtype(preferred_element_type, np.integer): 3105 raise TypeError("`preferred_element_type` and the original type must both be integral or both be floating point.") 3106 if dtypes.issubdtype(input_dtype, np.signedinteger) and not dtypes.issubdtype(preferred_element_type, np.signedinteger): 3107 raise TypeError("`preferred_element_type` must have the same signedness as the original type.") 3108 input_bitwidth = np.dtype(input_dtype).itemsize 3109 preferred_bitwidth = np.dtype(preferred_element_type).itemsize 3110 if preferred_bitwidth < input_bitwidth: 3111 raise TypeError("`preferred_element_type` must not be narrower than the original type.") 3112 return preferred_element_type 3113 3114def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision, 3115 preferred_element_type: Optional[DType], 3116 swap_ans=False): 3117 (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers 3118 x_ndim = g.ndim - y.ndim + len(x_batch) + 2 * len(x_contract) 3119 x_kept = remaining(range(x_ndim), x_contract, x_batch) 3120 y_kept = remaining(range(y.ndim), y_contract, y_batch) 3121 if swap_ans: 3122 ans_batch, ans_y, _ = ranges_like(x_batch, y_kept, x_kept) 3123 else: 3124 ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept) 3125 dims = ((ans_y, y_kept), (ans_batch, y_batch)) 3126 x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) 3127 out_axes = np.argsort(list(x_batch) + 
def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
                               preferred_element_type: Optional[DType],
                               swap_ans=False):
  """Transpose rule for ``dot_general`` with respect to one operand.

  Args:
    g: cotangent of the ``dot_general`` output.
    y: the operand that is held fixed.
    dimension_numbers: ``((x_contract, y_contract), (x_batch, y_batch))`` as
      passed to the forward ``dot_general``.
    precision: forwarded unchanged to the inner ``dot_general``.
    preferred_element_type: forwarded unchanged to the inner ``dot_general``.
    swap_ans: when true, the two groups of kept axes are looked up in ``g`` in
      the opposite order; ``_dot_general_transpose_rhs`` sets this.

  Returns:
    ``dot_general(g, y, ...)`` transposed by the inverse of the axis layout
    ``x_batch + x_kept + x_contract``.
  """
  contract_dims, batch_dims = dimension_numbers
  x_contract, y_contract = contract_dims
  x_batch, y_batch = batch_dims
  # Rank of the operand being differentiated, reconstructed from the ranks of
  # the cotangent and of the other operand.
  x_ndim = g.ndim - y.ndim + len(x_batch) + 2 * len(x_contract)
  # NOTE(review): `remaining` and `ranges_like` are helpers defined elsewhere
  # in this file; presumably they return the non-listed axes and consecutive
  # index ranges respectively -- confirm against their definitions.
  x_kept = remaining(range(x_ndim), x_contract, x_batch)
  y_kept = remaining(range(y.ndim), y_contract, y_batch)
  if not swap_ans:
    ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
  else:
    ans_batch, ans_y, _ = ranges_like(x_batch, y_kept, x_kept)
  grad_dimension_numbers = ((ans_y, y_kept), (ans_batch, y_batch))
  # Order the contracted axes of x the way the matching axes of y are sorted.
  contract_order = np.argsort(y_contract)
  x_contract_sorted_by_y = list(np.take(x_contract, contract_order))
  result_layout = list(x_batch) + x_kept + x_contract_sorted_by_y
  out_axes = tuple(np.argsort(result_layout))
  product = dot_general(g, y, grad_dimension_numbers, precision=precision,
                        preferred_element_type=preferred_element_type)
  return transpose(product, out_axes)

def _dot_general_transpose_rhs(g, x, *, dimension_numbers, precision,
                               preferred_element_type: Optional[DType]):
  """Transpose rule for ``dot_general`` with respect to its second operand.

  Swaps the lhs/rhs entries of ``dimension_numbers`` and delegates to
  ``_dot_general_transpose_lhs`` with ``swap_ans=True``.
  """
  contract_dims, batch_dims = dimension_numbers
  x_contract, y_contract = contract_dims
  x_batch, y_batch = batch_dims
  mirrored_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
  return _dot_general_transpose_lhs(
      g, x, dimension_numbers=mirrored_dimension_numbers, precision=precision,
      preferred_element_type=preferred_element_type, swap_ans=True)


def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
                            precision,
                            preferred_element_type: Optional[DType]):
  """Batching rule for ``dot_general``.

  Rewrites ``dimension_numbers`` for the operands that carry a mapped axis
  (see ``_dot_general_batch_dim_nums``) and issues a single ``dot_general``.

  Returns:
    ``(batched_result, result_batch_dim)``.
  """
  lhs, rhs = batched_args
  operand_ranks = (lhs.ndim, rhs.ndim)
  new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
      operand_ranks, batch_dims, dimension_numbers)
  return dot_general(lhs, rhs, new_dimension_numbers, precision=precision,
                     preferred_element_type=preferred_element_type), result_batch_dim
def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
  """Computes ``dot_general`` dimension numbers after a mapped axis is added.

  A ``dot_general`` has three kinds of dimensions:
    - contraction dimensions appear in lhs and rhs but not the result
    - batch dimensions appear in lhs, rhs, and result
    - tensor product dimensions appear in the result and one of lhs or rhs

  Args:
    ndims: ``(lhs_ndim, rhs_ndim)``.
    batch_dims: ``(lbd, rbd)``, the mapped axis of each operand or ``None``;
      at least one must be set.
    dimension_numbers: ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))``.

  Returns:
    ``(new_dimension_numbers, result_batch_dim)`` where ``result_batch_dim``
    is a Python ``int``.
  """
  lhs_ndim, rhs_ndim = ndims
  lbd, rbd = batch_dims
  assert lbd is not None or rbd is not None
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers

  def shift_past(dims, new_axis):
    # Every axis at or after the inserted axis moves one position right.
    return tuple(np.add(dims, np.greater_equal(dims, new_axis)))

  def count_free_before(ndim, batch, contract, new_axis):
    # Number of axes that are neither batch nor contracting and lie strictly
    # before `new_axis`.
    free = tuple(d for d in range(ndim)
                 if d not in batch and d not in contract)
    return sum(np.less(free, new_axis))

  if lbd is None:
    # Only rhs is mapped: the mapped axis becomes an rhs tensor-product
    # dimension, placed after the batch and lhs tensor-product dimensions.
    result_batch_dim = (lhs_ndim - len(lhs_contract) +
                        count_free_before(rhs_ndim, rhs_batch, rhs_contract, rbd))
    rhs_batch = shift_past(rhs_batch, rbd)
    rhs_contract = shift_past(rhs_contract, rbd)
  elif rbd is None:
    # Only lhs is mapped: the mapped axis becomes an lhs tensor-product
    # dimension, placed after the batch dimensions.
    result_batch_dim = (len(lhs_batch) +
                        count_free_before(lhs_ndim, lhs_batch, lhs_contract, lbd))
    lhs_batch = shift_past(lhs_batch, lbd)
    lhs_contract = shift_past(lhs_contract, lbd)
  else:
    # Both operands are mapped: the mapped axes become a new leading batch
    # dimension of the dot_general.
    lhs_batch = (lbd,) + shift_past(lhs_batch, lbd)
    rhs_batch = (rbd,) + shift_past(rhs_batch, rbd)
    lhs_contract = shift_past(lhs_contract, lbd)
    rhs_contract = shift_past(rhs_contract, rbd)
    result_batch_dim = 0

  new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
  return new_dimension_numbers, int(result_batch_dim)
def _dot_using_sum_of_products(lhs, rhs, *, dimension_numbers):
  """Evaluates a ``dot_general`` as an elementwise product plus a reduction.

  Both operands are transposed to ``batch + free + contracting`` axis order,
  given unit axes so that they broadcast against each other, multiplied
  (``bitwise_and`` for booleans) and then reduced over the trailing
  contracting axes with ``add`` (``bitwise_or`` for booleans).
  """
  contract_dims, batch_dims = dimension_numbers
  lhs_contract_dims, rhs_contract_dims = contract_dims
  lhs_batch_dims, rhs_batch_dims = batch_dims

  def free_dims(operand, batch, contract):
    # Axes of `operand` that are neither batch nor contracting, ascending.
    return tuple(sorted(set(range(np.ndim(operand))) - set(batch) - set(contract)))

  lhs_noncontract_dims = free_dims(lhs, lhs_batch_dims, lhs_contract_dims)
  rhs_noncontract_dims = free_dims(rhs, rhs_batch_dims, rhs_contract_dims)
  n_batch = len(lhs_batch_dims)
  n_lhs_free = len(lhs_noncontract_dims)
  n_rhs_free = len(rhs_noncontract_dims)
  out_ndim = n_batch + n_lhs_free + n_rhs_free

  lhs = transpose(lhs,
                  lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims)
  rhs = transpose(rhs,
                  rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims)

  # lhs gets unit axes where the rhs free axes will live; rhs gets unit axes
  # where the lhs free axes live.
  lhs = expand_dims(lhs, tuple(range(n_batch + n_lhs_free, out_ndim)))
  rhs = expand_dims(rhs, tuple(range(n_batch, n_batch + n_lhs_free)))

  is_bool = lhs.dtype == np.bool_
  op_product = bitwise_and if is_bool else mul
  op_sum = bitwise_or if is_bool else add
  contracted_axes = tuple(range(out_ndim, out_ndim + len(lhs_contract_dims)))
  return reduce(op_product(lhs, rhs), _zero(lhs), op_sum, contracted_axes)

def _dot_general_translation_rule(c, lhs, rhs, *, dimension_numbers, precision,
                                  preferred_element_type: Optional[DType]):
  """XLA translation rule for ``dot_general``: emits ``xops.DotGeneral``.

  A non-``None`` ``preferred_element_type`` is converted with
  ``xla_client.dtype_to_etype`` before being handed to XLA.
  """
  etype = (None if preferred_element_type is None
           else xla_client.dtype_to_etype(preferred_element_type))
  dot_dims = xc.make_dot_dimension_numbers(dimension_numbers)
  return xops.DotGeneral(lhs, rhs, dot_dims,
                         precision_config=_precision_config(precision),
                         preferred_element_type=etype)
def _dot_general_masking_rule(padded_vals, logical_shapes, *, dimension_numbers,
                              precision,
                              preferred_element_type: Optional[DType]):
  """Masking rule for ``dot_general``.

  Only the contraction dims of one side need masking; the lhs is masked here,
  but that choice is arbitrary.  Could check the sizes of lhs and rhs and mask
  whichever is smallest.
  """
  lhs, rhs = padded_vals
  lhs_shape, _ = logical_shapes
  contract_dims, _ = dimension_numbers
  lhs_contract, _ = contract_dims
  masked_lhs = _masked(lhs, lhs_shape, lhs_contract)
  return dot_general(masked_lhs, rhs, dimension_numbers, precision=precision,
                     preferred_element_type=preferred_element_type)

dot_general_p = standard_primitive(
    _dot_general_shape_rule, _dot_general_dtype_rule, 'dot_general',
    _dot_general_translation_rule)
ad.defbilinear(
    dot_general_p, _dot_general_transpose_lhs, _dot_general_transpose_rhs)
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
masking.masking_rules[dot_general_p] = _dot_general_masking_rule

def _broadcast_shape_rule(operand, sizes):
  """Shape rule for ``broadcast``: ``sizes`` is prepended to the operand shape."""
  _check_shapelike('broadcast', 'sizes', sizes)
  leading = tuple(sizes)
  return leading + operand.shape

def _broadcast_batch_rule(batched_args, batch_dims, *, sizes):
  """Batching rule for ``broadcast``.

  The new leading axes push an existing batch axis ``len(sizes)`` positions
  to the right; an unbatched operand stays unbatched.
  """
  operand, = batched_args
  bdim, = batch_dims
  if bdim is None:
    new_bdim = None
  else:
    new_bdim = bdim + len(sizes)
  return broadcast(operand, sizes), new_bdim

broadcast_p = standard_primitive(
    _broadcast_shape_rule, _input_dtype, 'broadcast')
# The transpose of a broadcast sums the cotangent over the added leading axes.
ad.deflinear2(
    broadcast_p,
    lambda t, _, sizes: [_reduce_sum(t, range(len(sizes)))])
batching.primitive_batchers[broadcast_p] = _broadcast_batch_rule
def _broadcast_in_dim_impl(operand, *, shape, broadcast_dimensions):
  """Eager implementation of ``broadcast_in_dim``.

  When the operand is a device array whose sizes already equal the target
  sizes at ``broadcast_dimensions``, the broadcast is recorded as a lazy
  expression on a new ``xla._DeviceArray`` that reuses the operand's device
  buffer.  Every other case goes through ``xla.apply_primitive``.
  """
  if type(operand) is np.ndarray:
    operand = _device_put_raw(operand)

  lazy_ok = xla.type_is_device_array(operand) and np.all(
      np.equal(operand.shape, np.take(shape, broadcast_dimensions)))
  if not lazy_ok:
    return xla.apply_primitive(broadcast_in_dim_p, operand, shape=shape,
                               broadcast_dimensions=broadcast_dimensions)

  # Run the shape rule for its validation; it returns the target shape.
  shape = _broadcast_in_dim_shape_rule(
      operand, shape=shape, broadcast_dimensions=broadcast_dimensions)
  aval = ShapedArray(shape, _dtype(operand))
  source_expr = operand._lazy_expr
  if source_expr is None:
    source_expr = lazy.array(operand.shape)
  lazy_expr = lazy.broadcast(source_expr, shape, broadcast_dimensions)
  return xla._DeviceArray(aval, operand._device, lazy_expr,
                          operand.device_buffer)
def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
  """Shape rule for ``broadcast_in_dim``.

  Validates ``shape`` and ``broadcast_dimensions`` against the operand and
  returns ``shape`` unchanged.

  Raises:
    TypeError: if ``broadcast_dimensions`` does not have one entry per operand
      axis, the target rank is lower than the operand rank, an entry is not a
      valid output axis, an operand size is neither 1 nor the matching target
      size, or ``broadcast_dimensions`` is not strictly increasing.
  """
  _check_shapelike('broadcast_in_dim', 'shape', shape)
  _check_shapelike('broadcast_in_dim', 'broadcast_dimensions',
                   broadcast_dimensions)
  operand_ndim = np.ndim(operand)
  target_ndim = len(shape)

  if operand_ndim != len(broadcast_dimensions):
    raise TypeError(
        ('broadcast_in_dim broadcast_dimensions must have length equal to '
         'operand ndim; got broadcast_dimensions {} for operand ndim {}.')
        .format(broadcast_dimensions, operand_ndim))

  if target_ndim < operand_ndim:
    raise TypeError(
        ('broadcast_in_dim target broadcast shape must have equal or higher rank '
         'to the operand shape; got operand ndim {} and target broadcast ndim {}.')
        .format(operand_ndim, target_ndim))

  valid_axes = set(range(target_ndim))
  if not set(broadcast_dimensions).issubset(valid_axes):
    raise TypeError(
        ('broadcast_in_dim broadcast_dimensions must be a subset of output '
         'dimensions, got {} for operand ndim {} and shape {}.')
        .format(broadcast_dimensions, operand_ndim, shape))

  for axis in range(operand_ndim):
    size = operand.shape[axis]
    if size != shape[broadcast_dimensions[axis]] and size != 1:
      raise TypeError(
          ("broadcast_in_dim operand dimension sizes must either be 1, or be "
           "equal to their corresponding dimensions in the target broadcast "
           "shape; got operand of shape {}, target broadcast shape {}, "
           "broadcast_dimensions {} ")
          .format(operand.shape, shape, broadcast_dimensions))

  has_repeats = len(broadcast_dimensions) != len(set(broadcast_dimensions))
  if has_repeats or (
      tuple(broadcast_dimensions) != tuple(sorted(broadcast_dimensions))):
    raise TypeError(
        ("broadcast_in_dim broadcast_dimensions must be strictly increasing; "
         "got broadcast_dimensions {}").format(broadcast_dimensions))

  return shape

def _broadcast_in_dim_transpose_rule(ct, operand, *, shape, broadcast_dimensions):
  """Transpose rule for ``broadcast_in_dim``.

  Sums the cotangent over every output axis that is not the image of a
  non-unit operand axis, then re-inserts the operand's size-1 axes.
  """
  operand_shape = operand.aval.shape
  unit_dimensions = tuple(
      axis for axis, size in enumerate(operand_shape) if size == 1)
  # Output axes that correspond to non-unit operand axes are kept.
  kept_out_axes = tuple(np.delete(broadcast_dimensions, unit_dimensions))
  summed_axes = tuple(np.delete(range(len(shape)), kept_out_axes))
  reduced = _reduce_sum(ct, summed_axes)
  return [expand_dims(reduced, unit_dimensions)]

def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape,
                                 broadcast_dimensions):
  """Batching rule for ``broadcast_in_dim``: the batch axis is moved to 0."""
  operand, = batched_args
  bdim, = batch_dims
  batch_size = operand.shape[bdim]
  moved = batching.moveaxis(operand, bdim, 0)
  batched_shape = (batch_size,) + shape
  batched_dims = (0,) + tuple(np.add(1, broadcast_dimensions))
  return broadcast_in_dim(moved, batched_shape, batched_dims), 0


broadcast_in_dim_p = standard_primitive(
    _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_impl(_broadcast_in_dim_impl)
ad.deflinear2(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
def _clamp_shape_rule(min, operand, max):
  """Shape rule for ``clamp``.

  ``min`` and ``max`` must each be scalar or have exactly the operand's
  shape; the result has the operand's shape.
  """
  min_mismatch = min.shape and min.shape != operand.shape
  if min_mismatch:
    raise TypeError(
        "clamp requires min.shape == operand.shape or min.shape == (), got {}."
        .format(min.shape))
  max_mismatch = max.shape and max.shape != operand.shape
  if max_mismatch:
    raise TypeError(
        "clamp requires max.shape == operand.shape or max.shape == (), got {}."
        .format(max.shape))
  return operand.shape

_clamp_dtype_rule = partial(
    naryop_dtype_rule, _input_dtype, [_any, _any, _any], 'clamp')

clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')


def _clamp_min_jvp(g, min, operand, max):
  # The tangent of `min` passes through where min > operand and min < max.
  active = bitwise_and(gt(min, operand), lt(min, max))
  return select(active, _brcast(g, operand), _zeros(operand))


def _clamp_operand_jvp(g, min, operand, max):
  # The tangent of `operand` passes through strictly inside (min, max).
  active = bitwise_and(gt(operand, min), lt(operand, max))
  return select(active, g, _zeros(operand))


def _clamp_max_jvp(g, min, operand, max):
  # The tangent of `max` passes through where max < operand.
  return select(lt(max, operand), _brcast(g, operand), _zeros(operand))


ad.defjvp(clamp_p, _clamp_min_jvp, _clamp_operand_jvp, _clamp_max_jvp)
batching.defbroadcasting(clamp_p)
def _concatenate_shape_rule(*operands, **kwargs):
  """Shape rule for ``concatenate``.

  Requires at least one operand, all of them abstract arrays of equal rank
  whose shapes agree everywhere except along ``dimension``.  The result has
  the summed size along ``dimension``.

  Raises:
    TypeError: if any of the requirements above is violated.
  """
  dimension = kwargs.pop('dimension')
  if not operands:
    raise TypeError("concatenate expects at least one operand, got 0.")

  for candidate in operands:
    if not isinstance(candidate, UnshapedArray):
      # Report the type of the first offending object.
      raise TypeError("All objects to concatenate must be arrays, got {}."
                      .format(type(candidate)))

  ranks = {operand.ndim for operand in operands}
  if len(ranks) != 1:
    raise TypeError("Cannot concatenate arrays with different ranks, got {}."
                    .format(", ".join(str(o.ndim) for o in operands)))

  if not 0 <= dimension < operands[0].ndim:
    raise TypeError(
        "concatenate dimension out of bounds: dimension {} for shapes {}."
        .format(dimension, ", ".join([str(o.shape) for o in operands])))

  # Shapes with the concatenated axis removed must all coincide.
  other_dims = [operand.shape[:dimension] + operand.shape[dimension+1:]
                for operand in operands]
  if not other_dims[:-1] == other_dims[1:]:
    full_shapes = [operand.shape for operand in operands]
    raise TypeError(
        ("Cannot concatenate arrays with shapes that differ in dimensions "
         "other than the one being concatenated: concatenating along "
         "dimension {} for shapes {}.")
        .format(dimension, ", ".join(map(str, full_shapes))))

  concat_size = sum(o.shape[dimension] for o in operands)
  first_shape = operands[0].shape
  return first_shape[:dimension] + (concat_size,) + first_shape[dimension+1:]

def _concatenate_dtype_rule(*operands, **kwargs):
  """Dtype rule for ``concatenate``: checks all dtypes, returns the first."""
  operand_dtypes = (o.dtype for o in operands)
  _check_same_dtypes('concatenate', False, *operand_dtypes)
  return operands[0].dtype

def _concatenate_translation_rule(c, *operands, **kwargs):
  """XLA translation rule for ``concatenate``: emits ``xops.ConcatInDim``."""
  dimension = kwargs.pop('dimension')
  return xops.ConcatInDim(c, operands, dimension)

def _concatenate_transpose_rule(t, *operands, dimension):
  """Transpose rule for ``concatenate``.

  Slices the cotangent back into one piece per operand.  Operands that are
  not undefined primals get ``None``; a symbolic-zero cotangent yields
  symbolic zeros.
  """
  operand_shapes = [o.aval.shape if ad.is_undefined_primal(o) else o.shape
                    for o in operands]
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None
            for o in operands]

  n_operands = len(operands)
  # Cumulative end offsets of each operand along the concatenated axis.
  limit_points = np.cumsum([shape[dimension] for shape in operand_shapes])
  starts = np.zeros((n_operands, t.ndim), dtype=int)
  starts[1:, dimension] = limit_points[:-1]
  limits = np.tile(t.shape, (n_operands, 1))
  limits[:, dimension] = limit_points

  pieces = []
  for o, start, limit in zip(operands, starts, limits):
    pieces.append(slice(t, start, limit) if ad.is_undefined_primal(o) else None)
  return pieces
def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
  """Batching rule for ``concatenate``.

  Batched operands get their batch axis moved to the front; unbatched ones
  are broadcast along a new leading axis of the batch size.  The
  concatenation then runs along ``dimension + 1``.
  """
  pairs = list(zip(batched_args, batch_dims))
  # Batch size, read from the first operand that is actually batched.
  size = next(op.shape[bdim] for op, bdim in pairs if bdim is not None)
  operands = []
  for op, bdim in pairs:
    if bdim is not None:
      operands.append(batching.moveaxis(op, bdim, 0))
    else:
      operands.append(broadcast(op, (size,)))
  return concatenate(operands, dimension + 1), 0

# The concatenate_p masking rule requires use of a while-loop construct and so
# is defined in lax_control_flow.py

concatenate_p = standard_primitive(
    _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate',
    _concatenate_translation_rule)
ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
# Re-register the raw rule as the transpose entry after `deflinear2`.
ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule


def _pad_dtype_rule(operand, padding_value, *, padding_config):
  """Dtype rule for ``pad``: operand and padding value must share a dtype."""
  if operand.dtype != padding_value.dtype:
    raise TypeError(
        "pad operand and padding_value must be same dtype: got {} and {}."
        .format(operand.dtype, padding_value.dtype))
  return _input_dtype(operand, padding_value)

def _pad_shape_rule(operand, padding_value, *, padding_config):
  """Shape rule for ``pad``.

  ``padding_config`` holds one ``(low, high, interior)`` triple per operand
  axis.  Each output size is ``low + high + d`` plus ``interior`` elements
  between each pair of neighbouring operand elements.

  Raises:
    ValueError: if ``padding_config`` has the wrong length or contains a
      negative interior padding.
  """
  del padding_value
  operand_shape = np.shape(operand)
  if not len(padding_config) == np.ndim(operand):
    raise ValueError("length of padding_config must equal the number of axes "
                     f"of operand, got padding_config {padding_config} "
                     f"for operand shape {np.shape(operand)}")
  for _, _, i in padding_config:
    if not i >= 0:
      raise ValueError("interior padding in padding_config must be nonnegative, "
                       f"got padding_config {padding_config}")
  out_sizes = []
  for (l, h, i), d in zip(padding_config, operand_shape):
    out_sizes.append(l + h + d + (_max(0, d - 1) * i if i > 0 else 0))
  return tuple(out_sizes)
def _pad_transpose(t, operand, padding_value, *, padding_config):
  """Transpose rule for ``pad``.

  Args:
    t: cotangent of the padded output, or an ``ad_util.Zero``.
    operand: the padded operand (possibly an undefined primal).
    padding_value: the scalar padding value (possibly an undefined primal).
    padding_config: one ``(low, high, interior)`` triple per axis.

  Returns:
    ``[t_operand, t_padv]``; an entry is ``None`` when the corresponding
    input is not an undefined primal.
  """
  operand_is_linear = ad.is_undefined_primal(operand)
  padv_is_linear = ad.is_undefined_primal(padding_value)
  if type(t) is ad_util.Zero:
    t_operand = ad_util.Zero(operand.aval) if operand_is_linear else None
    t_padv = ad_util.Zero(padding_value.aval) if padv_is_linear else None
  else:
    lo, hi, interior = zip(*padding_config)
    total = lambda x: _reduce_sum(x, list(range(t.ndim)))

    def t_op():
      # Undo the edge padding with negative padding, then undo the interior
      # padding with a strided slice.
      unpad_config = safe_zip(np.negative(lo), np.negative(hi),
                              np.zeros_like(interior))
      unpadded = pad(t, np.array(0., t.dtype), unpad_config)
      return slice(unpadded, np.zeros_like(lo), unpadded.shape, np.add(interior, 1))

    # The padding-value cotangent is "everything in t that did not come from
    # the operand", so the operand part of t is needed even when the operand
    # itself is not linear.  (Previously `total(None)` was evaluated in that
    # case.)
    operand_part = t_op() if (operand_is_linear or padv_is_linear) else None
    t_operand = operand_part if operand_is_linear else None
    t_padv = sub(total(t), total(operand_part)) if padv_is_linear else None
  return [t_operand, t_padv]

def _pad_batch_rule(batched_args, batch_dims, *, padding_config):
  """Batching rule for ``pad``; only an unbatched padding value is supported.

  The batch axis of the operand receives a ``(0, 0, 0)`` padding entry.

  Raises:
    NotImplementedError: if the padding value is batched.
  """
  operand, padding_value = batched_args
  operand_bdim, padding_value_bdim = batch_dims
  if padding_value_bdim is None:
    assert operand_bdim is not None
    padding_config = list(padding_config)
    padding_config.insert(operand_bdim, (0, 0, 0))
    return pad(operand, padding_value, padding_config), operand_bdim
  else:
    raise NotImplementedError  # loop and stack

def _pad_translation_rule(c, operand, padding_value, *, padding_config):
  """XLA translation rule for ``pad``: emits ``xops.Pad``."""
  return xops.Pad(operand, padding_value,
                  xc.make_padding_config(padding_config))

def _pad_masking_rule(padded_vals, logical_shapes, padding_config):
  """Masking rule for ``pad``.

  Pads the padded operand, then masks the axes that have a non-trivial
  padding entry, using the padding value as the fill value.
  """
  operand, padding_value = padded_vals
  shape, _ = logical_shapes

  out = pad(operand, padding_value, padding_config)
  # Size used for masking along each axis: low padding plus the logical size
  # scaled by (interior + 1).
  out_shape = [lo + shape[i] * (interior + 1)
               for i, (lo, hi, interior) in enumerate(padding_config)]
  padded_dims = [i for i, config in enumerate(padding_config)
                 if config != (0, 0, 0)]
  return _masked(out, out_shape, padded_dims, padding_value)

pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad',
                           translation_rule=_pad_translation_rule)
ad.deflinear2(pad_p, _pad_transpose)
batching.primitive_batchers[pad_p] = _pad_batch_rule
masking.masking_rules[pad_p] = _pad_masking_rule
# The squeeze primitive exists for the benefit of masking and other
# transformations that need to keep track of axis identity.
# For example, consider reshaping a 2D array with shape (1, N) into a 1D array
# with shape (N,). This results in the following JAXpr:
#   reshape[ dimension=None new_sizes=(N,) ]
# For N > 1, we can match up the output array axis with the second axis of the
# input. But for N = 1, it is not clear how axes match up: all we know from the
# JAXpr is that we are reshaping from (1, 1) to (1,).
# In contrast, squeeze[ dimensions=(0,) ] is unambiguous.

def squeeze(array: Array, dimensions: Tuple[int, ...]) -> Array:
  """Squeeze any number of size 1 dimensions from an array."""
  rank = np.ndim(array)
  squeeze_dims = tuple(sorted(canonicalize_axis(i, rank) for i in dimensions))
  # Nothing to squeeze: hand the input back without binding the primitive.
  return squeeze_p.bind(array, dimensions=squeeze_dims) if squeeze_dims else array

def _squeeze_dtype_rule(operand, *, dimensions):
  """Dtype rule for ``squeeze``: the dtype is unchanged."""
  return operand.dtype

def _squeeze_shape_rule(operand, *, dimensions):
  """Shape rule for ``squeeze``; see ``_compute_squeeze_shape``."""
  operand_shape = np.shape(operand)
  return _compute_squeeze_shape(operand_shape, dimensions)

def _compute_squeeze_shape(shape, dimensions):
  """Returns ``shape`` with the axes listed in ``dimensions`` removed.

  Raises:
    ValueError: if ``dimensions`` has repeats, contains an index outside
      ``[0, len(shape))``, or names an axis whose size is not 1.
  """
  unique_dims = set(dimensions)
  rank = len(shape)
  if len(unique_dims) != len(dimensions):
    raise ValueError(f"dimensions are not unique: {dimensions}")
  for d in unique_dims:
    if not 0 <= d < rank:
      raise ValueError(f"dimensions outside range [0, ndim): {dimensions}")
  for d in dimensions:
    if shape[d] != 1:
      raise ValueError(
          "cannot select an axis to squeeze out which has size not equal to "
          f"one, got shape={shape} and dimensions={dimensions}")
  return tuple(size for axis, size in enumerate(shape)
               if axis not in unique_dims)

def _squeeze_translation_rule(c, arg, *, dimensions):
  """XLA translation rule for ``squeeze``: a reshape to the squeezed shape."""
  xla_dims = c.get_shape(arg).dimensions()
  return xops.Reshape(arg, _compute_squeeze_shape(xla_dims, dimensions))
def _squeeze_transpose_rule(t, operand, *, dimensions):
  """Transpose rule for ``squeeze``: re-insert the removed unit axes."""
  assert ad.is_undefined_primal(operand)
  restored = expand_dims(t, dimensions)
  return [restored]

def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
  """Batching rule for ``squeeze``.

  The batch axis is moved to the front, so every squeezed axis index shifts
  up by one.
  """
  operand, = batched_args
  bdim, = batch_dims
  moved = batching.moveaxis(operand, bdim, 0)
  shifted = tuple(np.add(1, dimensions))
  return squeeze(moved, dimensions=shifted), 0

squeeze_p = standard_primitive(
    _squeeze_shape_rule, _squeeze_dtype_rule, 'squeeze',
    _squeeze_translation_rule)
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule


def expand_dims(array: Array, dimensions: Tuple[int, ...]) -> Array:
  """Insert any number of size 1 dimensions into an array."""
  out_rank = np.ndim(array) + len(dimensions)
  # Positions are interpreted relative to the rank of the result.
  new_axes = frozenset(canonicalize_axis(i, out_rank) for i in dimensions)
  out_shape = list(np.shape(array))
  for axis in sorted(new_axes):
    out_shape.insert(axis, 1)
  # Every output axis that is not newly inserted comes from the input.
  source_axes = [axis for axis in range(out_rank) if axis not in new_axes]
  return broadcast_in_dim(array, out_shape, source_axes)


# We have a nonstandard reshape impl so that we can be lazy about data movement.
def _reshape_impl(operand, *, new_sizes, dimensions):
  """Eager implementation of ``reshape``.

  A device-array reshape that only adds singleton dimensions (and has no
  ``dimensions`` permutation) is expressed as a lazy broadcast that reuses the
  operand's device buffer.  Everything else goes through
  ``xla.apply_primitive``.
  """
  old_sizes = np.shape(operand)
  bcast_dims = None
  if xla.type_is_device_array(operand) and dimensions is None:
    bcast_dims = _is_singleton_reshape(old_sizes, new_sizes)
  if bcast_dims is None:
    return xla.apply_primitive(reshape_p, operand, new_sizes=new_sizes,
                               dimensions=dimensions)

  aval = ShapedArray(new_sizes, operand.dtype)
  source_expr = operand._lazy_expr
  if source_expr is None:
    source_expr = lazy.array(operand.shape)
  lazy_expr = lazy.broadcast(source_expr, new_sizes, bcast_dims)
  return xla._DeviceArray(aval, operand._device, lazy_expr,
                          operand.device_buffer)
def _is_singleton_reshape(old, new):
  """Detects reshapes that only add singleton dimensions.

  Such reshapes can be expressed as (lazy) broadcasts.

  Args:
    old: the original shape.
    new: the requested shape.

  Returns:
    The list of positions in ``new`` that correspond to the dimensions of
    ``old`` (in order), or ``None`` if ``new`` is not ``old`` with extra
    size-1 dimensions inserted.
  """
  old_sizes = iter(old)
  pending = next(old_sizes, None)  # next size of `old` still to be matched
  bcast_dims = []
  for position, size in enumerate(new):
    if pending == size:
      # A matching size always consumes a dimension of `old`, even if it is 1.
      bcast_dims.append(position)
      pending = next(old_sizes, None)
    elif size == 1:
      # An inserted singleton dimension.
      continue
    else:
      return None
  # Every dimension of `old` must have been matched.
  return bcast_dims if pending is None else None

def _reshape_shape_rule(operand, *, new_sizes, dimensions):
  """Shape rule for ``reshape``.

  Args:
    operand: the array being reshaped.
    new_sizes: the requested shape; entries must be non-negative and have the
      same product as the operand's shape.
    dimensions: ``None`` or a permutation of the operand's axes.

  Returns:
    ``tuple(new_sizes)``.

  Raises:
    TypeError: if a size is negative, the total size changes, or
      ``dimensions`` is not a permutation of ``range(ndim)`` (repeated
      entries are rejected).
  """
  if not np.all(np.greater_equal(new_sizes, 0)):
    # Zero is a legal size, so the requirement is non-negativity.
    msg = 'reshape new_sizes must all be non-negative, got {}.'
    raise TypeError(msg.format(new_sizes))
  if prod(np.shape(operand)) != prod(new_sizes):
    msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
    raise TypeError(msg.format(new_sizes, np.shape(operand)))
  if dimensions is not None:
    # Compare sorted tuples rather than sets so that repeats such as (0, 1, 1)
    # are rejected too, matching the check in the transpose shape rule.
    if tuple(sorted(dimensions)) != tuple(range(np.ndim(operand))):
      msg = ('reshape dimensions must be a permutation of operand dimensions, '
             'got dimensions {} for shape {}.')
      raise TypeError(msg.format(dimensions, np.shape(operand)))
  return tuple(new_sizes)

def _reshape_dtype_rule(operand, *, new_sizes, dimensions):
  """Dtype rule for ``reshape``: the dtype is unchanged."""
  return operand.dtype

def _reshape_translation_rule(c, operand, *, new_sizes, dimensions):
  """XLA translation rule for ``reshape``: emits ``xops.Reshape``."""
  if dimensions is None:
    return xops.Reshape(operand, new_sizes)
  else:
    return xops.Reshape(operand, dimensions, new_sizes)

def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions):
  """Transpose rule for ``reshape``.

  Reshapes the cotangent back to the operand's shape; when ``dimensions`` was
  given, reshapes to the permuted shape first and then undoes the permutation.
  """
  assert ad.is_undefined_primal(operand)
  if dimensions is None:
    return [reshape(t, operand.aval.shape)]
  else:
    return [transpose(reshape(t, np.take(operand.aval.shape, dimensions)),
                      np.argsort(dimensions))]
def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
  """Batching rule for ``reshape``.

  Moves the batch axis to the front, keeps it as the leading output size, and
  shifts a ``dimensions`` permutation up by one with the batch axis first.
  """
  operand, = batched_args
  bdim, = batch_dims
  moved = batching.moveaxis(operand, bdim, 0)
  batched_dimensions = dimensions
  if dimensions is not None:
    batched_dimensions = (0,) + tuple(np.add(1, dimensions))
  batched_sizes = moved.shape[:1] + new_sizes
  return reshape(moved, batched_sizes, batched_dimensions), 0

def _reshape_masking_rule(padded_args, logical_shapes, polymorphic_shapes,
                          new_sizes, dimensions):
  """Masking rule for ``reshape``.

  Raises:
    NotImplementedError: if the old and new shapes do not agree once constant
      sizes are merged around the polymorphic ones.
  """
  operand, = padded_args
  old_shape, = polymorphic_shapes

  def is_poly(size):
    return type(size) is masking.Poly and not size.is_constant

  def merge_const_sizes(shape):
    """Merges all nonpolymorphic sizes into the previous polymorphic size."""
    poly_dims = [i for i, size in enumerate(shape) if is_poly(size)]
    starts = [0] + poly_dims
    stops = poly_dims + [len(shape)]
    return [prod(shape[start:stop]) for start, stop in zip(starts, stops)]

  old_merged = merge_const_sizes(old_shape)
  new_merged = merge_const_sizes(new_sizes)
  if old_merged != new_merged:
    raise NotImplementedError(
        "Reshape on padded dimensions causing fragmentation is not supported.")

  padded_sizes = masking.padded_shape_as_value(new_sizes)
  return reshape(operand, new_sizes=padded_sizes, dimensions=dimensions)

reshape_p = standard_primitive(
    _reshape_shape_rule, _reshape_dtype_rule, 'reshape',
    _reshape_translation_rule)
reshape_p.def_impl(_reshape_impl)
ad.deflinear2(reshape_p, _reshape_transpose_rule)
batching.primitive_batchers[reshape_p] = _reshape_batch_rule
masking.masking_rules[reshape_p] = _reshape_masking_rule
def _rev_shape_rule(operand, *, dimensions):
  """Shape rule for ``rev``: validates ``dimensions``; the shape is unchanged.

  Raises:
    TypeError: if ``dimensions`` has repeats or an entry that is not less than
      the operand's rank.
  """
  _check_shapelike('rev', 'dimensions', dimensions)
  if len(set(dimensions)) != len(dimensions):
    raise TypeError('rev dimensions must be unique, got {}.'.format(dimensions))
  if dimensions and not _max(dimensions) < operand.ndim:
    raise TypeError(
        ('rev dimensions must all be less than operand ndim, got dimensions '
         '{} for operand ndim {}.').format(dimensions, operand.ndim))
  return operand.shape

def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
  """Batching rule for ``rev``: the batch axis stays where it is.

  Reversed axes at or after the batch axis shift up by one.
  """
  operand, = batched_args
  bdim, = batch_dims
  new_dimensions = []
  for axis in dimensions:
    new_dimensions.append(axis + 1 if axis >= bdim else axis)
  return rev(operand, new_dimensions), bdim

rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
# Reversal is its own transpose.
ad.deflinear2(
    rev_p,
    lambda t, _, dimensions: [rev(t, dimensions)])
batching.primitive_batchers[rev_p] = _rev_batch_rule


def _transpose_impl(operand, *, permutation):
  """Eager implementation of ``transpose``.

  For device arrays the permutation is recorded as a lazy expression on a new
  ``xla._DeviceArray`` that reuses the operand's device buffer; other inputs
  go through ``xla.apply_primitive``.
  """
  if not xla.type_is_device_array(operand):
    return xla.apply_primitive(transpose_p, operand, permutation=permutation)

  source_expr = operand._lazy_expr
  if source_expr is None:
    source_expr = lazy.array(operand.shape)
  lazy_expr = lazy.transpose(source_expr, permutation)
  aval = ShapedArray(lazy_expr.shape, operand.dtype)
  return xla._DeviceArray(aval, operand._device, lazy_expr,
                          operand.device_buffer)
def _transpose_shape_rule(operand, *, permutation):
  """Shape rule for ``transpose``.

  Raises:
    TypeError: if ``permutation`` is not a tuple/list/ndarray or is not a
      permutation of ``range(operand.ndim)``.
  """
  accepted_types = (tuple, list, np.ndarray)
  if not isinstance(permutation, accepted_types):
    raise TypeError(
        "transpose permutation must be a tuple/list/ndarray, got {}."
        .format(type(permutation)))
  expected_axes = tuple(range(operand.ndim))
  if tuple(sorted(permutation)) != expected_axes:
    raise TypeError(
        ("transpose permutation isn't a permutation of operand dimensions, "
         "got permutation {} for operand shape {}.")
        .format(permutation, operand.shape))
  return tuple(np.take(operand.shape, permutation))

def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
  """Batching rule for ``transpose``: the batch axis comes out at position 0."""
  operand, = batched_args
  bdim, = batch_dims
  # Axes at or after the batch axis shift up by one in the batched operand.
  shifted = tuple(axis if axis < bdim else axis+1 for axis in permutation)
  return transpose(operand, (bdim,) + shifted), 0

def _transpose_masking_rule(padded_vals, logical_shapes, permutation):
  """Masking rule for ``transpose``: transposes the padded value directly."""
  return transpose(*padded_vals, permutation=permutation)

transpose_p = standard_primitive(
    _transpose_shape_rule, _input_dtype, 'transpose')
transpose_p.def_impl(_transpose_impl)
# The transpose rule applies the inverse permutation to the cotangent.
ad.deflinear2(
    transpose_p,
    lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
masking.masking_rules[transpose_p] = _transpose_masking_rule


def _select_shape_rule(pred, on_true, on_false):
  """Shape rule for ``select``.

  ``on_true`` and ``on_false`` must have the same shape, and ``pred`` must be
  scalar or share that shape.
  """
  branch_shape = on_true.shape
  if branch_shape != on_false.shape:
    raise TypeError(
        "select on_true and on_false must have the same shape, got {} and {}."
        .format(on_true.shape, on_false.shape))
  if pred.shape and pred.shape != branch_shape:
    raise TypeError(
        ("select pred must be scalar or have the same shape as on_true and "
         "on_false, got pred shape {} for on_true and on_false of shape {}.")
        .format(pred.shape, on_true.shape))
  return branch_shape
def _select_dtype_rule(pred, on_true, on_false):
  """Dtype rule for ``select``.

  The two branches must share a dtype and ``pred`` must be boolean; the
  result has the dtype of ``on_true``.
  """
  _check_same_dtypes("select", False, on_true.dtype, on_false.dtype)
  pred_is_bool = dtypes.issubdtype(pred.dtype, np.bool_)
  if not pred_is_bool:
    raise TypeError(
        "select pred must be boolean type, got {}.".format(pred.dtype))
  return on_true.dtype

def _select_transpose_rule(t, pred, on_true, on_false):
  """Transpose rule for ``select``.

  The cotangent is routed to ``on_true`` where ``pred`` holds and to
  ``on_false`` elsewhere.  ``pred`` never receives a cotangent, and a branch
  that is not an undefined primal gets ``None``.
  """
  assert not ad.is_undefined_primal(pred)
  true_is_linear = ad.is_undefined_primal(on_true)
  false_is_linear = ad.is_undefined_primal(on_false)

  if type(t) is ad_util.Zero:
    true_ct = ad_util.Zero(on_true.aval) if true_is_linear else None
    false_ct = ad_util.Zero(on_false.aval) if false_is_linear else None
    return [None, true_ct, false_ct]

  zeros = full_like(t, 0)
  true_ct = select(pred, t, zeros) if true_is_linear else None
  false_ct = select(pred, zeros, t) if false_is_linear else None
  return [None, true_ct, false_ct]
def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
  """Batching rule for ``select``.

  Special cases that need no transposes are handled first; otherwise the
  batch axes are moved to the front and ``pred`` or the branches are
  broadcast so that their ranks agree.

  Returns:
    ``(result, result_batch_dim)``.
  """
  pred, on_true, on_false, = batched_args
  pred_bdim, ot_bdim, of_bdim = batch_dims
  # Batch size, read from the first argument that is actually batched.
  size = next(x.shape[i] for x, i in zip(batched_args, batch_dims)
              if i is not None)

  # avoid transposes and some broadcasts in special cases
  if pred_bdim == ot_bdim == of_bdim:
    if np.shape(pred) != np.shape(on_true):
      # vmapped function had a scalar pred with nonscalar args
      assert np.ndim(pred) == 1
      pred = broadcast_in_dim(pred, on_true.shape, [pred_bdim])
    return select(pred, on_true, on_false), pred_bdim

  if np.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None:
    if ot_bdim == of_bdim:
      return select(pred, on_true, on_false), ot_bdim
    if np.shape(on_true) == np.shape(on_false):
      on_false = batching.moveaxis(on_false, of_bdim, ot_bdim)
      return select(pred, on_true, on_false), ot_bdim

  # General case: put every batch axis at the front.
  if np.shape(pred):
    pred = batching.bdim_at_front(pred, pred_bdim, size)
  if not np.shape(on_true) == np.shape(on_false) == ():
    on_true = batching.bdim_at_front(on_true, ot_bdim, size)
    on_false = batching.bdim_at_front(on_false, of_bdim, size)
  assert np.shape(on_true) == np.shape(on_false)

  pred_rank = np.ndim(pred)
  if 0 < pred_rank < np.ndim(on_true):
    # vmapped function had a scalar pred with nonscalar args
    assert np.ndim(pred) == 1
    pred = broadcast_in_dim(pred, on_true.shape, [0])
  if np.ndim(pred) > np.ndim(on_true):
    assert np.ndim(on_true) == 0
    on_true = broadcast(on_true, pred.shape)
    on_false = broadcast(on_false, pred.shape)
  return select(pred, on_true, on_false), 0

def _select_masking_rule(padded_vals, logical_shapes):
  """Masking rule for ``select``: all three padded shapes must coincide."""
  pred_shape, true_shape, false_shape = [
      masking.padded_shape_as_value(val.shape) for val in padded_vals]
  assert np.array_equal(pred_shape, true_shape)
  assert np.array_equal(pred_shape, false_shape)
  return select(*padded_vals)

def _select_jvp(primals, tangents):
  """JVP rule for ``select``.

  The output tangent selects between the branch tangents with the same
  predicate; a symbolic-zero branch tangent is replaced by zeros shaped like
  the other branch's tangent.
  """
  pred, on_true, on_false = primals
  _, on_true_dot, on_false_dot = tangents
  out = select(pred, on_true, on_false)
  if type(on_true_dot) is ad_util.Zero:
    on_true_dot = _zeros(on_false_dot)
  elif type(on_false_dot) is ad_util.Zero:
    on_false_dot = _zeros(on_true_dot)
  return out, select(pred, on_true_dot, on_false_dot)

select_p = standard_primitive(_select_shape_rule, _select_dtype_rule, 'select')
ad.primitive_jvps[select_p] = _select_jvp
ad.primitive_transposes[select_p] = _select_transpose_rule
batching.primitive_batchers[select_p] = _select_batch_rule
masking.masking_rules[select_p] = _select_masking_rule
("slice limit_indices must have the same length as start_indices, " 3833 "got start_inidices {} and limit_indices {}.") 3834 raise TypeError(msg.format(start_indices, limit_indices)) 3835 if (not masking.is_polymorphic(limit_indices) and 3836 not masking.is_polymorphic(operand.shape) and 3837 not np.all(np.less_equal(limit_indices, operand.shape))): 3838 msg = ("slice limit_indices must be less than or equal to operand shape, " 3839 "got limit_indices {} for operand shape {}.") 3840 raise TypeError(msg.format(limit_indices, operand.shape)) 3841 if not np.all(np.greater_equal(start_indices, 0)): 3842 msg = ("slice start_indices must be greater than or equal to zero, " 3843 "got start_indices of {}.") 3844 raise TypeError(msg.format(start_indices)) 3845 if (not masking.is_polymorphic(limit_indices) and 3846 not np.all(np.greater_equal(limit_indices, start_indices))): 3847 msg = ("slice limit_indices must be greater than or equal to start_indices," 3848 " got start_indices {} and limit_indices {}.") 3849 raise TypeError(msg.format(start_indices, limit_indices)) 3850 if strides is None: 3851 strides = np.ones(operand.ndim, np.int32) 3852 else: 3853 _check_shapelike("slice", "strides", strides) 3854 if len(strides) != operand.ndim: 3855 msg = ("slice strides must have length equal to the number of dimensions " 3856 "of the operand, got strides {} for operand shape {}.") 3857 raise TypeError(msg.format(strides, operand.shape)) 3858 if not np.all(np.greater(strides, 0)): 3859 msg = "slice strides must be positive, got {}" 3860 raise TypeError(msg.format(strides)) 3861 3862 diff = np.subtract(limit_indices, start_indices) 3863 # Not np.divmod since Poly.__rdivmod__ is ignored by NumPy, breaks poly stride 3864 return tuple(q + (r > 0) for q, r in map(divmod, diff, strides)) 3865 3866def _slice_translation_rule(c, operand, *, start_indices, limit_indices, 3867 strides): 3868 return xops.Slice(operand, start_indices, limit_indices, 3869 strides or [1] * len(start_indices)) 


def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
  """Transpose rule for ``slice``: pads the cotangent back to the operand shape.

  Args:
    t: cotangent of the slice output.
    operand: the undefined primal; only ``operand.aval.shape`` is read.
    start_indices: per-dimension start of the slice.
    limit_indices: per-dimension exclusive end of the slice.
    strides: per-dimension strides, or ``None``.

  Returns:
    A one-element list holding ``t`` zero-padded to ``operand.aval.shape``.
  """
  assert ad.is_undefined_primal(operand)
  full_shape = operand.aval.shape
  has_unit_strides = strides is None or np.all(np.equal(strides, 1))
  if has_unit_strides:
    # Contiguous slice: pad only at the two edges, nothing between elements.
    trailing = np.subtract(full_shape, limit_indices)
    interior = (0,) * len(start_indices)
    pads = zip(start_indices, trailing, interior)
  else:
    # Recompute the effective end of the slice from the cotangent's shape:
    # n elements with stride s span 1 + (n - 1) * s positions; an empty
    # dimension spans none.
    covered = np.add(1, np.multiply(np.subtract(t.shape, 1), strides))
    covered = np.where(np.array(t.shape) == 0, 0, covered)
    real_limits = np.add(start_indices, covered)
    trailing = np.subtract(full_shape, real_limits)
    # stride - 1 zeros go between consecutive cotangent elements.
    interior = np.subtract(strides, 1)
    pads = safe_zip(start_indices, trailing, interior)
  result = pad(t, _const(t, 0), pads)
  assert result.shape == full_shape, (
    f"result.shape={result.shape} operand_shape={full_shape}")
  return [result]


def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
                         limit_indices, strides):
  """Batching rule for ``slice``: takes the whole batch axis, slices the rest.

  Returns:
    ``(out, bdim)`` where the batch dimension stays at its input position.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims

  def with_batch_entry(per_dim, entry):
    # Copy the per-dimension parameter and splice in the batch axis' value.
    widened = list(per_dim)
    widened.insert(bdim, entry)
    return widened

  # The batch axis is kept whole: start 0, limit = its full size, stride 1.
  batched_starts = with_batch_entry(start_indices, 0)
  batched_limits = with_batch_entry(limit_indices, operand.shape[bdim])
  batched_strides = None if strides is None else with_batch_entry(strides, 1)

  sliced = slice(operand, batched_starts, batched_limits, batched_strides)
  return sliced, bdim

def _slice_masking_rule(
    padded_vals, logical_shapes, start_indices, limit_indices, strides):
  """Masking rule for ``slice``: re-applies the slice to the padded operand.

  Each slice parameter is passed through ``masking.padded_shape_as_value``
  first; ``logical_shapes`` is not used.
  """
  (operand,) = padded_vals
  if strides:
    strides = masking.padded_shape_as_value(strides)
  else:
    # Falsy strides (None or empty) are normalised to None.
    strides = None
  return slice(operand,
               start_indices=masking.padded_shape_as_value(start_indices),
               limit_indices=masking.padded_shape_as_value(limit_indices),
               strides=strides)

slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice',
                             _slice_translation_rule)
ad.deflinear2(slice_p, _slice_transpose_rule)
batching.primitive_batchers[slice_p] = _slice_batching_rule
masking.masking_rules[slice_p] = _slice_masking_rule


def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes):
  """Shape rule for ``dynamic_slice``.

  Args:
    operand: abstract operand; ``ndim`` and ``shape`` are read.
    *start_indices: one start index per operand dimension.
    slice_sizes: static size of the slice along each operand dimension.

  Returns:
    ``tuple(slice_sizes)``, the shape of the result.

  Raises:
    TypeError: if the number of start indices or slice sizes does not match
      the operand rank, or if any slice size is negative or larger than the
      corresponding operand dimension.
  """
  if operand.ndim != len(start_indices):
    msg = ("dynamic_slice start_indices must have length equal to the number "
           "of dimensions of the operand, got indices {} for operand shape {}.")
    raise TypeError(msg.format(start_indices, operand.shape))
  if len(start_indices) != len(slice_sizes):
    msg = ("dynamic_slice slice_sizes must have the same length as "
           "start_indices, got start_indices length {} and slice_sizes {}.")
    raise TypeError(msg.format(len(start_indices), slice_sizes))
  if not np.all(np.less_equal(slice_sizes, operand.shape)):
    msg = ("slice slice_sizes must be less than or equal to operand shape, "
           "got slice_sizes {} for operand shape {}.")
    raise TypeError(msg.format(slice_sizes, operand.shape))
  if not np.all(np.greater_equal(slice_sizes, 0)):
    msg = ("slice slice_sizes must be greater than or equal to zero, "
           "got slice_sizes of {}.")
    raise TypeError(msg.format(slice_sizes))
  return tuple(slice_sizes)

def _dynamic_slice_dtype_rule(operand, *start_indices, slice_sizes):
  """Dtype rule for ``dynamic_slice``: the result has the operand's dtype.

  Raises:
    TypeError: if any start index is not an integer type, or the start
      indices do not all share the dtype of the first one.
  """
  if any(i.dtype != start_indices[0].dtype or
         not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices):
    msg = ("index arguments to dynamic_slice must be integers of the same "
           "type, got: {}")
    raise TypeError(msg.format(", ".join(i.dtype.name for i in start_indices)))
  return operand.dtype

def _dynamic_slice_translation_rule(c, operand, *start_indices, slice_sizes):
  """Lowers ``dynamic_slice`` to ``xops.DynamicSlice``; ``c`` is unused."""
  return xops.DynamicSlice(operand, start_indices, slice_sizes)

def _dynamic_slice_jvp(primals, tangents, *, slice_sizes):
  """JVP rule for ``dynamic_slice``.

  ``primals`` is ``(operand, *start_indices)``. Only the operand tangent
  (``tangents[0]``) is used: it is sliced at the same indices as the primal,
  or passed through unchanged when it is a symbolic ``ad_util.Zero``.

  Returns:
    ``(primal_out, tangent_out)``.
  """
  tangent_out = tangents[0]
  if type(tangent_out) is not ad_util.Zero:
    tangent_out = dynamic_slice(tangent_out, primals[1:], slice_sizes)
  return dynamic_slice(primals[0], primals[1:], slice_sizes), tangent_out


def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
  """Transpose rule for ``dynamic_slice``.

  Writes the cotangent ``t`` into a zero array of the operand's shape at
  ``start_indices``. The start indices themselves get ``None`` cotangents.

  Returns:
    A list with the operand cotangent followed by one ``None`` per start
    index.
  """
  assert ad.is_undefined_primal(operand)
  assert all(not ad.is_undefined_primal(s) for s in start_indices)
  operand_shape, operand_dtype = operand.aval.shape, operand.aval.dtype
  if type(t) is ad_util.Zero:
    # Symbolic zero in, symbolic zero out: nothing to materialise.
    return [ad_util.Zero(operand.aval)] + [None] * len(start_indices)
  else:
    if config.omnistaging_enabled:
      zeros = full(operand_shape, 0, operand_dtype)
    else:
      zeros = full(operand_shape, tie_in(t, _zero(t)))
    return ([dynamic_update_slice(zeros, t, start_indices)] +
            [None] * len(start_indices))

def _batch_dynamic_slice_indices(indices, bdims):
  """Packs per-dimension start indices into a single index array.

  Args:
    indices: sequence of start indices, one per operand dimension.
    bdims: batch dimension of each index, or ``None`` where unbatched.

  Returns:
    ``(index_array, bdim)``. With no indices, an empty int32 array and
    ``None``. With no batched index, the indices concatenated along axis 0
    and ``None``. Otherwise an array built from ``(size, 1)`` columns
    concatenated along axis 1, batched along axis 0.
  """
  if len(indices) == 0:
    return np.array([], 'int32'), None
  # Batch size is taken from the first batched index; -1 means none is batched.
  size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None), -1)
  if size < 0:
    return concatenate([broadcast(i, (1,)) for i in indices], 0), None
  # Batched indices map their axis onto axis 0; unbatched ones are broadcast
  # across the whole (size, 1) column.
  indices = concatenate(
    [broadcast_in_dim(x, (size, 1),
                      broadcast_dimensions=((0,) if i is not None else ()))
     for x, i in zip(indices, bdims)],
    dimension=1)
  return indices, 0

def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
  """Batching rule for ``dynamic_slice``, expressed through gather."""
  # A dynamic slice is a special case of gather; we can delegate to the gather
  # batching rule.
  # TODO(phawkins): consider removing dynamic_slice entirely and using gather
  # always.
  operand, *start_indices = batched_args
  operand_bd, *start_idx_bds = batch_dims
  # Shape of the operand as the unbatched computation sees it.
  operand_shape = (operand.shape if operand_bd is batching.not_mapped
                   else tuple(np.delete(operand.shape, operand_bd)))
  dims = tuple(range(len(operand_shape)))
  # Every operand dimension is both indexed and kept as an offset dimension.
  dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(),
                                 start_index_map=dims)
  index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds)
  return _gather_batching_rule(
    [operand, index], [operand_bd, index_bdim], dimension_numbers=dnums,
    slice_sizes=slice_sizes)


dynamic_slice_p = standard_primitive(
    _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
    _dynamic_slice_translation_rule)
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp  # TODO
ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule


def _dynamic_update_slice_shape_rule(operand, update, *start_indices):
  """Shape rule for ``dynamic_update_slice``: the result has the operand shape.

  Raises:
    TypeError: if ``update`` and ``operand`` differ in rank, if the number of
      start indices differs from the operand rank, or if ``update`` is larger
      than ``operand`` along any dimension.
  """
  if operand.ndim != update.ndim:
    msg = ("dynamic_update_slice update must have the same rank as operand, "
           "got update shape {} for operand shape {}.")
    raise TypeError(msg.format(update.shape, operand.shape))
  if operand.ndim != len(start_indices):
    msg = ("dynamic_update_slice start_indices must have length equal to the "
           "rank of operand, got indices {} for operand shape {}.")
    raise TypeError(msg.format(start_indices, operand.shape))
  if not np.all(np.less_equal(update.shape, operand.shape)):
    msg = ("dynamic_update_slice update shape must be smaller than operand "
           "shape, got update shape {} for operand shape {}.")
    raise TypeError(msg.format(update.shape, operand.shape))
  return operand.shape

def _dynamic_update_slice_dtype_rule(operand, update, *start_indices):
  """Dtype rule for ``dynamic_update_slice``: the result has the operand dtype.

  Raises:
    TypeError: if any start index is not an integer type, or the start
      indices do not all share the dtype of the first one. The
      operand/update dtype agreement is delegated to ``_check_same_dtypes``.
  """
  _check_same_dtypes("dynamic_update_slice", False, operand.dtype, update.dtype)
  if any(i.dtype != start_indices[0].dtype or
         not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices):
    msg = ("index arguments to dynamic_update_slice must be integers of the "
           "same type, got {}")
    raise TypeError(msg.format(", ".join(i.dtype.name for i in start_indices)))
  return operand.dtype

def _dynamic_update_slice_jvp(primals, tangents):
  """JVP rule for ``dynamic_update_slice``.

  ``primals`` is ``(operand, update, *start_indices)``. The tangent is the
  same update applied to the operand and update tangents, or a symbolic zero
  when both of those are symbolic zeros.

  Returns:
    ``(val_out, tangent_out)``.
  """
  operand, update = primals[:2]
  start_indices = primals[2:]
  g_operand, g_update = tangents[:2]
  val_out = dynamic_update_slice(operand, update, start_indices)
  if type(g_operand) is ad_util.Zero and type(g_update) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    g_operand = ad.instantiate_zeros(g_operand)
    g_update = ad.instantiate_zeros(g_update)
    tangent_out = dynamic_update_slice(g_operand, g_update, start_indices)
  return val_out, tangent_out

def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices):
  """Transpose rule for ``dynamic_update_slice``.

  The operand cotangent is ``t`` with the updated window zeroed out; the
  update cotangent is that window sliced out of ``t``. A cotangent is only
  produced for arguments that are undefined primals.

  Returns:
    ``[operand_t, update_t]`` followed by one ``None`` per start index.
  """
  assert all(not ad.is_undefined_primal(x) for x in start_indices)
  if ad.is_undefined_primal(update):
    update_shape = update.aval.shape
  else:
    update_shape = update.shape
  if type(t) is ad_util.Zero:
    operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
    update_t = ad_util.Zero(update.aval) if ad.is_undefined_primal(update) else None
  else:
    dus = dynamic_update_slice
    ds = dynamic_slice
    zeros = _zeros(t, shape=update_shape)
    operand_t = dus(t, zeros, start_indices) if ad.is_undefined_primal(operand) else None
    update_t = ds(t, start_indices, update_shape) if ad.is_undefined_primal(update) else None
  return [operand_t, update_t] + [None] * len(start_indices)

def _dynamic_update_slice_translation_rule(c, operand, update, *start_indices):
  """Lowers ``dynamic_update_slice`` to ``xops.DynamicUpdateSlice``."""
  return xops.DynamicUpdateSlice(operand, update, start_indices)

def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
  """Batching rule for ``dynamic_update_slice``, expressed through scatter."""
  # A dynamic update slice is a special case of scatter; we can delegate to the
  # scatter batching rule.
  # TODO(phawkins): consider removing dynamic_update_slice entirely and using
  # scatter always.
  operand, update, *start_idx = batched_args
  operand_bd, update_bd, *start_idx_bd = batch_dims
  # Shape of the update as the unbatched computation sees it.
  update_shape = (np.shape(update) if update_bd is batching.not_mapped
                  else tuple(np.delete(np.shape(update), update_bd)))
  dims = tuple(range(len(update_shape)))
  dnums = ScatterDimensionNumbers(update_window_dims=dims,
                                  inserted_window_dims=(),
                                  scatter_dims_to_operand_dims=dims)
  index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd)
  return _scatter_batching_rule(
    scatter, (operand, index, update), (operand_bd, index_bdim, update_bd),
    update_jaxpr=None, update_consts=None, dimension_numbers=dnums,
    indices_are_sorted=True, unique_indices=True)


dynamic_update_slice_p = standard_primitive(
    _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule,
    'dynamic_update_slice', _dynamic_update_slice_translation_rule)
ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp
ad.primitive_transposes[dynamic_update_slice_p] = \
  _dynamic_update_slice_transpose_rule
batching.primitive_batchers[dynamic_update_slice_p] = \
  _dynamic_update_slice_batching_rule


def _gather_dimensions_proto(indices_shape, dimension_numbers):
  """Builds the ``xla_client.GatherDimensionNumbers`` proto for a gather.

  ``index_vector_dim`` is always set to the last dimension of the indices,
  so ``indices_shape`` must have rank of at least one.
  """
  assert type(dimension_numbers) is GatherDimensionNumbers
  proto = xla_client.GatherDimensionNumbers()
  proto.offset_dims.extend(dimension_numbers.offset_dims)
  proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims)
  proto.start_index_map.extend(dimension_numbers.start_index_map)
  assert indices_shape.rank() > 0
  proto.index_vector_dim = indices_shape.rank() - 1
  return proto

def _gather_dtype_rule(operand, start_indices, **kwargs):
  """Dtype rule for ``gather``: the canonicalized operand dtype.

  Raises:
    ValueError: if ``start_indices`` does not have an integer dtype.
  """
  if not dtypes.issubdtype(start_indices.dtype, np.integer):
    raise ValueError("start_indices must have an integer type")
  return dtypes.canonicalize_dtype(operand.dtype)

_rank = lambda arr: len(arr.shape)

def _is_sorted(dims, op_name, name):
  """Raises ``TypeError`` unless ``dims`` is in non-decreasing order."""
  for i in range(1, len(dims)):
    if dims[i] < dims[i - 1]:
      raise TypeError(f"{name} in {op_name} op must be sorted; got {dims}")

def _sorted_dims_in_range(dims, rank, op_name, name):
  """Raises ``TypeError`` if any of ``dims`` lies outside ``[0, rank)``.

  Only the first and last entries are inspected, so ``dims`` is expected to
  be sorted already (see ``_is_sorted``). An empty ``dims`` is always valid.
  """
  if len(dims) == 0:
    return
  invalid_dim = None
  if dims[0] < 0:
    invalid_dim = dims[0]
  elif dims[-1] >= rank:
    invalid_dim = dims[-1]
  # Compare against None rather than testing truthiness: an offending
  # dimension of 0 (possible when rank == 0) is falsy and would be missed.
  if invalid_dim is not None:
    raise TypeError(f"Invalid {name} set in {op_name} op; valid range is "
                    f"[0, {rank}); got: {invalid_dim}.")

def _no_duplicate_dims(dims, op_name, name):
  """Raises ``TypeError`` if ``dims`` contains a repeated entry."""
  if len(set(dims)) != len(dims):
    raise TypeError(f"{name} in {op_name} op must not repeat; got: {dims}.")

def _gather_shape_rule(operand, start_indices, *, dimension_numbers,
                       slice_sizes):
  """Validates the well-formedness of the arguments to Gather.

  The code implements the checks based on the detailed operation semantics of
  XLA's `Gather <https://www.tensorflow.org/xla/operation_semantics#gather>`_
  operator and following the outline of the implementation of
  ShapeInference::InferGatherShape in TensorFlow.

  Returns:
    The output shape as a tuple: offset dimensions take the non-collapsed
    slice sizes in order, the remaining dimensions take the batch dimensions
    of ``start_indices`` in order.

  Raises:
    TypeError: if any of the dimension numbers or slice sizes is malformed.
  """

  offset_dims = dimension_numbers.offset_dims
  collapsed_slice_dims = dimension_numbers.collapsed_slice_dims
  start_index_map = dimension_numbers.start_index_map

  # Note: in JAX, index_vector_dim is always computed as below, cf. the
  # documentation of the GatherDimensionNumbers class.
  index_vector_dim = _rank(start_indices) - 1

  # This case should never happen in JAX, due to the implicit construction of
  # index_vector_dim, but is included for completeness.
  if _rank(start_indices) < index_vector_dim or index_vector_dim < 0:
    raise TypeError(f"Gather index leaf dimension must be within [0, rank("
                    f"start_indices) + 1). rank(start_indices) is "
                    f"{_rank(start_indices)} and gather index leaf dimension "
                    f"is {index_vector_dim}.")

  expanded_start_indices_shape = list(start_indices.shape)

  # This case should never happen in JAX, due to the implicit construction of
  # index_vector_dim, but is included for completeness.
  if len(expanded_start_indices_shape) == index_vector_dim:
    expanded_start_indices_shape.append(1)

  # Start ValidateGatherDimensions
  # In XLA's own error messages, "offset_dims" is called "Output window
  # dimensions". For consistency's sake, our error messages stick to
  # "offset_dims".
  _is_sorted(offset_dims, "gather", "offset_dims")
  _no_duplicate_dims(offset_dims, "gather", "offset_dims")

  output_shape_rank = len(offset_dims) + _rank(start_indices) - 1

  for i, offset_dim in enumerate(offset_dims):
    if offset_dim < 0 or offset_dim >= output_shape_rank:
      raise TypeError(f"Offset dimension {i} in gather op is out of bounds; "
                      f"got {offset_dim}, but should have been in "
                      f"[0, {output_shape_rank})")

  if len(start_index_map) != start_indices.shape[index_vector_dim]:
    raise TypeError(f"Gather op has {len(start_index_map)} elements in "
                    f"start_index_map and the bound of dimension "
                    f"index_vector_dim={index_vector_dim} of start_indices is "
                    f"{start_indices.shape[index_vector_dim]}. These two "
                    f"numbers must be equal.")

  for i, operand_dim_for_start_index_i in enumerate(start_index_map):
    if (operand_dim_for_start_index_i < 0 or
        operand_dim_for_start_index_i >= _rank(operand)):
      raise TypeError(f"Invalid start_index_map; domain is "
                      f"[0, {_rank(operand)}), got: "
                      f"{i}->{operand_dim_for_start_index_i}.")

  _no_duplicate_dims(start_index_map, "gather", "start_index_map")

  # _is_sorted and _sorted_dims_in_range are checked in the opposite order
  # compared to the XLA implementation. In cases when the input is not sorted
  # AND there are problematic collapsed_slice_dims, the error message will thus
  # be different.
  _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims")
  _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather",
                        "collapsed_slice_dims")
  _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims")
  # End ValidateGatherDimensions

  if _rank(operand) != len(slice_sizes):
    raise TypeError(f"Gather op must have one slice size for every input "
                    f"dimension; got: len(slice_sizes)={len(slice_sizes)}, "
                    f"input_shape.rank={_rank(operand)}")

  if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims):
    raise TypeError(f"All components of the offset index in a gather op must "
                    f"either be a offset dimension or explicitly collapsed; "
                    f"got len(slice_sizes)={len(slice_sizes)}, "
                    f"output_slice_sizes={offset_dims}, collapsed_slice_dims="
                    f"{collapsed_slice_dims}.")

  for i, slice_size in enumerate(slice_sizes):
    corresponding_input_size = operand.shape[i]

    if slice_size < 0 or slice_size > corresponding_input_size:
      raise TypeError(f"Slice size at index {i} in gather op is out of range, "
                      f"must be within [0, {corresponding_input_size + 1}), "
                      f"got {slice_size}.")

  for i, collapsed_dim in enumerate(collapsed_slice_dims):
    bound = slice_sizes[collapsed_dim]
    if bound > 1:
      raise TypeError(f"Gather op can only collapse slice dims with bound 1 "
                      f"or 0, but bound is {bound} for index "
                      f"{collapsed_dim} at position {i}.")

  # Drop the index-vector dimension; what remains are the batch dimensions.
  expanded_start_indices_shape.pop(index_vector_dim)
  start_indices_shape = iter(expanded_start_indices_shape)

  # Interleave: offset positions consume the surviving slice sizes, all other
  # positions consume the batch dimensions of start_indices.
  slice_sizes = iter(np.delete(slice_sizes, collapsed_slice_dims))
  return tuple(next(slice_sizes) if i in offset_dims
               else next(start_indices_shape) for i in range(output_shape_rank))

def _gather_translation_rule(c, operand, start_indices, *, dimension_numbers,
                             slice_sizes):
  """Lowers ``gather`` to ``xops.Gather`` (indices not assumed sorted)."""
  indices_shape = c.get_shape(start_indices)
  return xops.Gather(
    operand, start_indices,
    _gather_dimensions_proto(indices_shape, dimension_numbers), slice_sizes,
    indices_are_sorted=False)

def _gather_jvp_rule(g, operand, start_indices, *, dimension_numbers,
                     slice_sizes):
  """JVP of ``gather`` w.r.t. the operand: gathers the tangent ``g``."""
  return gather(g, start_indices, dimension_numbers, slice_sizes)

def _gather_transpose_rule(t, operand, start_indices, *, dimension_numbers,
                           slice_sizes):
  """Transpose rule for ``gather``: scatter-adds ``t`` into a zero operand.

  Returns:
    ``[operand_cotangent, None]``; the indices get no cotangent.
  """
  assert ad.is_undefined_primal(operand)
  operand_shape = operand.aval.shape
  if type(t) is ad_util.Zero:
    out = ad_util.Zero(operand.aval)
  else:
    if config.omnistaging_enabled:
      zeros = full(operand_shape, _zero(t))
    else:
      zeros = full(operand_shape, tie_in(t, _zero(t)))
    # The scatter dimension numbers mirror the gather's one for one.
    scatter_dnums = ScatterDimensionNumbers(
      update_window_dims=dimension_numbers.offset_dims,
      inserted_window_dims=dimension_numbers.collapsed_slice_dims,
      scatter_dims_to_operand_dims=dimension_numbers.start_index_map)
    # scatter_add so that repeated indices accumulate their contributions.
    out = scatter_add(zeros, start_indices, t, scatter_dnums,
                      indices_are_sorted=False,
                      unique_indices=False)
  return [out, None]

def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
                          slice_sizes):
  """Batching rule for ``gather``.

  Handles three cases: only the operand batched, only the indices batched,
  and both batched. In every case the output batch dimension is 0.
  """
  operand, start_indices = batched_args
  operand_bdim, start_indices_bdim = batch_dims

  if operand_bdim is not None and start_indices_bdim is None:
    # Only the operand is batched: treat the batch axis as an extra leading
    # offset dimension that is taken whole, and shift everything else by one.
    operand = batching.moveaxis(operand, operand_bdim, 0)
    slice_sizes = (operand.shape[0],) + slice_sizes
    offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims))
    collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims))
    start_index_map = tuple(np.add(1, dimension_numbers.start_index_map))
    dnums = GatherDimensionNumbers(
        offset_dims=offset_dims,
        collapsed_slice_dims=collapsed_slice_dims,
        start_index_map=start_index_map)
    return gather(operand, start_indices, dimension_numbers=dnums,
                  slice_sizes=slice_sizes), 0

  elif operand_bdim is None and start_indices_bdim is not None:
    # Only the indices are batched: the batch axis becomes a leading batch
    # dimension of the indices, so only offset_dims shift by one.
    start_indices = batching.moveaxis(start_indices, start_indices_bdim, 0)
    offset_dims = tuple(np.add(1, dimension_numbers.offset_dims))
    dnums = GatherDimensionNumbers(
        offset_dims=offset_dims,
        collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
        start_index_map=dimension_numbers.start_index_map)
    return gather(operand, start_indices, dimension_numbers=dnums,
                  slice_sizes=slice_sizes), 0

  else:
    # move batch dimensions to the front to simplify logic
    operand = batching.moveaxis(operand, operand_bdim, 0)
    start_indices = batching.moveaxis(start_indices, start_indices_bdim, 0)

    # Example: user code had start_indices shape (3, 4, 5), and we have to deal
    # with start_indices shape (7, 3, 4, 5). We transform that to a
    # start_indices of shape (7, 3, 4, 6) where we concatenated an iota that
    # counts along our batch dimension to the front of the ndindex.
    count_shape = list(start_indices.shape)
    count_shape[-1] = 1
    counts = broadcasted_iota(start_indices.dtype, tuple(count_shape), 0)
    start_indices = concatenate([counts, start_indices], len(count_shape) - 1)

    # The operand's batch axis is indexed by the iota and collapsed away; its
    # slice size is 1, or 0 when the batch is empty.
    slice_sizes = (_min(operand.shape[0], 1),) + slice_sizes
    collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims))
    offset_dims = tuple(np.add(1, dimension_numbers.offset_dims))
    start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map))

    dnums = GatherDimensionNumbers(
        offset_dims=offset_dims,
        collapsed_slice_dims=collapsed_slice_dims,
        start_index_map=start_index_map)
    return gather(operand, start_indices, dimension_numbers=dnums,
                  slice_sizes=slice_sizes), 0

gather_p = standard_primitive(
    _gather_shape_rule, _gather_dtype_rule, 'gather',
    _gather_translation_rule)
ad.defjvp(gather_p, _gather_jvp_rule, None)

ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule


def _scatter_dimensions_proto(indices_shape, dimension_numbers):
  """Builds the ``xla_client.ScatterDimensionNumbers`` proto for a scatter.

  ``index_vector_dim`` is always set to the last dimension of the indices,
  so ``indices_shape`` must have rank of at least one.
  """
  assert type(dimension_numbers) is ScatterDimensionNumbers
  proto = xla_client.ScatterDimensionNumbers()
  proto.update_window_dims.extend(dimension_numbers.update_window_dims)
  proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims)
  proto.scatter_dims_to_operand_dims.extend(
      dimension_numbers.scatter_dims_to_operand_dims)
  assert indices_shape.rank() > 0
  proto.index_vector_dim = indices_shape.rank() - 1
  return proto

def _scatter_dtype_rule(operand, scatter_indices, updates, **kwargs):
  """Dtype rule for scatter ops: the canonicalized operand dtype.

  Raises:
    ValueError: if ``scatter_indices`` does not have an integer dtype. The
      operand/updates dtype agreement is delegated to ``_check_same_dtypes``.
  """
  if not dtypes.issubdtype(scatter_indices.dtype, np.integer):
    raise ValueError("scatter_indices must have an integer type")
  _check_same_dtypes("scatter", False, operand.dtype, updates.dtype)
  return dtypes.canonicalize_dtype(operand.dtype)

4364def _scatter_shape_rule(operand, scatter_indices, updates, *, update_jaxpr, 4365 update_consts, dimension_numbers, indices_are_sorted, 4366 unique_indices): 4367 """Validates the well-formedness of the ``dimension_numbers`` argument to 4368 Scatter. 4369 4370 The code implements the checks based on the detailed operation semantics of 4371 XLA's `Scatter <https://www.tensorflow.org/xla/operation_semantics#scatter>`_ 4372 operator and following the outline of the implementation of 4373 ShapeInference::InferScatterShape in TensorFlow. 4374 """ 4375 4376 update_window_dims = dimension_numbers.update_window_dims 4377 inserted_window_dims = dimension_numbers.inserted_window_dims 4378 scatter_dims_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims 4379 # Note: in JAX, index_vector_dim is always computed as below, cf. the 4380 # documentation of the ScatterDimensionNumbers class. 4381 index_vector_dim = _rank(scatter_indices) - 1 4382 4383 # This case should never happen in JAX, due to the implicit construction of 4384 # index_vector_dim, but is included for completeness. 4385 if _rank(scatter_indices) < index_vector_dim or index_vector_dim < 0: 4386 raise TypeError(f"Scatter index leaf dimension must be within [0, " 4387 f"rank(scatter_indices) + 1). rank(scatter_indices) is " 4388 f"{_rank(scatter_indices)} and scatter index leaf " 4389 f"dimension is {index_vector_dim}.") 4390 4391 expanded_scatter_indices_shape = list(scatter_indices.shape) 4392 # This case should never happen in JAX, due to the implicit construction of 4393 # index_vector_dim, but is included for completeness. 
4394 if len(expanded_scatter_indices_shape) == index_vector_dim: 4395 expanded_scatter_indices_shape.append(1) 4396 4397 expected_updates_rank = (len(expanded_scatter_indices_shape) - 1 + 4398 len(update_window_dims)) 4399 4400 if _rank(updates) != expected_updates_rank: 4401 raise TypeError(f"Updates tensor must be of rank {expected_updates_rank}; " 4402 f"got {_rank(updates)}.") 4403 4404 # Validate update_window_dims 4405 _is_sorted(update_window_dims, "scatter", "update_window_dims") 4406 _no_duplicate_dims(update_window_dims, "scatter", "update_window_dims") 4407 _sorted_dims_in_range(update_window_dims, _rank(updates), "scatter", 4408 "update_window_dims") 4409 4410 # Validate inserted_window_dims 4411 _is_sorted(inserted_window_dims, "scatter", "inserted_window_dims") 4412 _no_duplicate_dims(inserted_window_dims, "scatter", "inserted_window_dims") 4413 _sorted_dims_in_range(inserted_window_dims, _rank(operand), "scatter", 4414 "inserted_window_dims") 4415 4416 # Validate window_size 4417 window_size = len(update_window_dims) + len(inserted_window_dims) 4418 if _rank(operand) != window_size: 4419 raise TypeError(f"Scatter op has window of size {window_size}; doesn't " 4420 f"match operand of rank {_rank(operand)}.") 4421 4422 # Validate scatter_dims_to_operand_dims 4423 if (len(scatter_dims_to_operand_dims) != 4424 scatter_indices.shape[index_vector_dim]): 4425 raise TypeError(f"Scatter op has {len(scatter_dims_to_operand_dims)} " 4426 f"elements in scatter_dims_to_operand_dims and the bound " 4427 f"of dimension index_vector_dim={index_vector_dim} of " 4428 f"scatter_indices is " 4429 f"{scatter_indices.shape[index_vector_dim]}. 
These two " 4430 f"numbers must be equal") 4431 4432 for i in range(len(scatter_dims_to_operand_dims)): 4433 dim = scatter_dims_to_operand_dims[i] 4434 if dim < 0 or dim >= _rank(operand): 4435 raise TypeError(f"Invalid scatter_dims_to_operand_dims mapping; domain " 4436 f"is [0, {_rank(operand)}), got: {i}->{dim}.") 4437 4438 _no_duplicate_dims(scatter_dims_to_operand_dims, "scatter", 4439 "scatter_dims_to_operand_dims") 4440 4441 max_update_slice_sizes = [operand.shape[i] for i in range(len(operand.shape)) 4442 if not i in set(inserted_window_dims)] 4443 4444 for i in range(len(update_window_dims)): 4445 update_window_dim = update_window_dims[i] 4446 if updates.shape[update_window_dim] > max_update_slice_sizes[i]: 4447 raise TypeError(f"Bounds of the window dimensions of updates must not " 4448 f"exceed the bounds of the corresponding dimensions of " 4449 f"operand. For dimension {update_window_dim}, updates " 4450 f"bound is {updates.shape[update_window_dim]}, operand " 4451 f"bound is {max_update_slice_sizes[i]}.") 4452 4453 update_scatter_dims = [dim for dim in range(_rank(updates)) if dim not in 4454 set(update_window_dims)] 4455 4456 scatter_dims_seen = 0 4457 for i in update_scatter_dims: 4458 if scatter_dims_seen == index_vector_dim: 4459 scatter_dims_seen += 1 4460 if updates.shape[i] != expanded_scatter_indices_shape[scatter_dims_seen]: 4461 raise TypeError(f"Bounds of the scatter dimensions of updates must be " 4462 f"the same as the bounds of the corresponding dimensions " 4463 f"of scatter indices. 
For scatter dimension {i}, updates " 4464 f"bound is {updates.shape[i]}, scatter_indices bound is " 4465 f"{expanded_scatter_indices_shape[scatter_dims_seen]}.") 4466 scatter_dims_seen += 1 4467 4468 return operand.shape 4469 4470def _scatter_translation_rule(c, operand, scatter_indices, updates, *, 4471 update_jaxpr, update_consts, dimension_numbers, 4472 indices_are_sorted, unique_indices): 4473 dtype = c.get_shape(operand).numpy_dtype() 4474 init_value = xb.constant(c, np.array(0, dtype)) 4475 update_computation = _reduction_computation( 4476 c, update_jaxpr, update_consts, init_value) 4477 indices_shape = c.get_shape(scatter_indices) 4478 return xops.Scatter(operand, scatter_indices, updates, update_computation, 4479 _scatter_dimensions_proto(indices_shape, dimension_numbers), 4480 indices_are_sorted, unique_indices) 4481 4482def _scatter_add_translation_rule( 4483 c, operand, scatter_indices, updates, *, update_jaxpr, update_consts, 4484 dimension_numbers, indices_are_sorted, unique_indices, 4485 expand_complex128=False): 4486 dtype = c.get_shape(operand).numpy_dtype() 4487 scatter_dims = _scatter_dimensions_proto(c.get_shape(scatter_indices), 4488 dimension_numbers) 4489 4490 def _make_reducer(dtype): 4491 subc = xla_bridge.make_computation_builder("scatter_add_reducer") 4492 shape = xc.Shape.array_shape(np.dtype(dtype), ()) 4493 args = [xb.parameter(subc, 0, shape), xb.parameter(subc, 1, shape)] 4494 out = xops.Add(args[0], args[1]) 4495 return subc.build(out) 4496 4497 if expand_complex128 and dtype == np.complex128: 4498 update_computation = _make_reducer(np.float64) 4499 re = xops.Scatter(xops.Real(operand), scatter_indices, xops.Real(updates), 4500 update_computation, scatter_dims, indices_are_sorted, 4501 unique_indices) 4502 im = xops.Scatter(xops.Imag(operand), scatter_indices, xops.Imag(updates), 4503 update_computation, scatter_dims, indices_are_sorted, 4504 unique_indices) 4505 return xops.Complex(re, im) 4506 else: 4507 update_computation = 
_make_reducer(dtype) 4508 return xops.Scatter(operand, scatter_indices, updates, update_computation, 4509 scatter_dims, indices_are_sorted, unique_indices) 4510 4511def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts, 4512 dimension_numbers, indices_are_sorted, unique_indices): 4513 operand, scatter_indices, updates = primals 4514 g_operand, g_scatter_indices, g_updates = tangents 4515 val_out = scatter_add_p.bind( 4516 operand, scatter_indices, updates, update_jaxpr=update_jaxpr, 4517 update_consts=update_consts, dimension_numbers=dimension_numbers, 4518 indices_are_sorted=indices_are_sorted, unique_indices=unique_indices) 4519 if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: 4520 tangent_out = ad_util.Zero.from_value(val_out) 4521 else: 4522 g_operand = ad.instantiate_zeros(g_operand) 4523 g_updates = ad.instantiate_zeros(g_updates) 4524 tangent_out = scatter_add_p.bind( 4525 g_operand, scatter_indices, g_updates, update_jaxpr=update_jaxpr, 4526 update_consts=update_consts, dimension_numbers=dimension_numbers, 4527 indices_are_sorted=indices_are_sorted, unique_indices=unique_indices) 4528 return val_out, tangent_out 4529 4530def _scatter_add_transpose_rule(t, operand, scatter_indices, updates, *, 4531 update_jaxpr, update_consts, dimension_numbers, 4532 indices_are_sorted, unique_indices): 4533 assert not ad.is_undefined_primal(scatter_indices) 4534 if ad.is_undefined_primal(updates): 4535 updates_shape = updates.aval.shape 4536 else: 4537 updates_shape = updates.shape 4538 if type(t) is ad_util.Zero: 4539 operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None 4540 update_t = ad_util.Zero(updates.aval) if ad.is_undefined_primal(updates) else None 4541 else: 4542 operand_t = update_t = None 4543 if ad.is_undefined_primal(operand): 4544 operand_t = t 4545 4546 if ad.is_undefined_primal(updates): 4547 gather_dnums = GatherDimensionNumbers( 4548 offset_dims=dimension_numbers.update_window_dims, 
def _scatter_mul_transpose_rule(t, operand, scatter_indices, updates, *,
                                update_jaxpr, update_consts, dimension_numbers,
                                indices_are_sorted, unique_indices):
  """Transpose rule for ``scatter_mul_p``.

  Args:
    t: cotangent of the scatter-mul output (possibly a symbolic zero).
    operand, scatter_indices, updates: the primitive's inputs; ``operand``
      and/or ``updates`` may be undefined primals, ``scatter_indices`` may not.

  Returns:
    ``[operand_t, None, update_t]``; an entry is ``None`` when the matching
    input is not an undefined primal.

  NOTE(review): the updates cotangent is ``gather(t * operand)``. When several
  updates multiply into the same location this leaves out the other updates'
  factors -- confirm this is only relied on for non-overlapping indices.
  """
  assert not ad.is_undefined_primal(scatter_indices)
  operand_is_linear = ad.is_undefined_primal(operand)
  updates_is_linear = ad.is_undefined_primal(updates)
  updates_shape = updates.aval.shape if updates_is_linear else updates.shape

  if type(t) is ad_util.Zero:
    operand_t = ad_util.Zero(operand.aval) if operand_is_linear else None
    update_t = ad_util.Zero(updates.aval) if updates_is_linear else None
    return [operand_t, None, update_t]

  operand_t = update_t = None
  if operand_is_linear:
    operand_t = scatter_mul(
        t, scatter_indices, updates, dimension_numbers=dimension_numbers,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)

  if updates_is_linear:
    # Gather that reads back, for every update, the slice it was scattered to.
    gather_dnums = GatherDimensionNumbers(
        offset_dims=dimension_numbers.update_window_dims,
        collapsed_slice_dims=dimension_numbers.inserted_window_dims,
        start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
    window_dims = dimension_numbers.update_window_dims
    slice_sizes = []
    next_window = 0
    for dim in range(len(t.shape)):
      if dim in dimension_numbers.inserted_window_dims:
        size = 1
      else:
        size = updates_shape[window_dims[next_window]]
        next_window += 1
      slice_sizes.append(size)
    update_t = gather(mul(t, operand), scatter_indices,
                      dimension_numbers=gather_dnums, slice_sizes=slice_sizes)

  return [operand_t, None, update_t]
def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
                          update_consts, dimension_numbers,
                          indices_are_sorted, unique_indices):
  """JVP rule shared by ``scatter_min_p`` and ``scatter_max_p``.

  Args:
    scatter_op: the primitive being differentiated; bound to compute the
      primal output.
    primals: ``(operand, scatter_indices, updates)``.
    tangents: matching tangents; the one for ``scatter_indices`` is unused.
    update_jaxpr, update_consts, dimension_numbers, indices_are_sorted,
      unique_indices: the primitive's parameters, forwarded to ``bind``.

  Returns:
    ``(val_out, tangent_out)``. ``tangent_out`` is a symbolic zero when both
    the operand and updates tangents are symbolic zeros.
  """
  operand, scatter_indices, updates = primals
  g_operand, g_scatter_indices, g_updates = tangents

  scatter_dnums = dimension_numbers
  updates_shape = updates.shape

  val_out = scatter_op.bind(
      operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=scatter_dnums,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)

  if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    g_operand = ad.instantiate_zeros(g_operand)
    g_updates = ad.instantiate_zeros(g_updates)

    # gather_dnums and slice_sizes define the gather op that is the inverse of
    # the scatter op specified by scatter_dnums
    gather_dnums = GatherDimensionNumbers(
        offset_dims=scatter_dnums.update_window_dims,
        collapsed_slice_dims=scatter_dnums.inserted_window_dims,
        start_index_map=scatter_dnums.scatter_dims_to_operand_dims)

    # One slice size per operand dimension: 1 for inserted window dims,
    # otherwise the size of the next update window dimension.
    slice_sizes = []
    pos = 0
    for i in range(len(operand.shape)):
      if i in scatter_dnums.inserted_window_dims:
        slice_sizes.append(1)
      else:
        slice_sizes.append(updates_shape[scatter_dnums.update_window_dims[pos]])
        pos += 1

    # For consistency with other max operations, if there are two or more values
    # in updates that are contending to replace the same index location, the
    # resulting tangent at that location will be the average of the associated
    # tangents for the values in updates.

    # Operand values at the location each update targets.
    initial_vals = gather(
        operand, scatter_indices, gather_dnums, np.array(slice_sizes))

    # Final (post-scatter) values at the location each update targets.
    target_vals = gather(
        val_out, scatter_indices, gather_dnums, np.array(slice_sizes))

    # An update "succeeded" if it equals the final value at its target; the
    # operand value was "retained" if it equals that final value too.
    successful_updates = (updates == target_vals)
    retained_values = (initial_vals == target_vals)

    # For each update: how many successful updates share its target location
    # (count them with a scatter-add of 0/1 flags, then gather the count back).
    num_updates = gather(
        scatter_add(_zeros(operand),
                    scatter_indices,
                    select(successful_updates, _ones(updates), _zeros(updates)),
                    scatter_dnums),
        scatter_indices,
        gather_dnums,
        np.array(slice_sizes))

    # For each update: how many updates (successful or not) share its target.
    num_refs = gather(
        scatter_add(_zeros(operand),
                    scatter_indices,
                    _ones(updates),
                    scatter_dnums),
        scatter_indices,
        gather_dnums,
        np.array(slice_sizes))

    # Weight of each winning tangent: the operand counts as one extra winner
    # when its value was retained.
    updates_normalizer = select(retained_values,
                                1.0 / (num_updates + 1),
                                1.0 / num_updates)

    # Unsuccessful updates contribute nothing to the tangent.
    updates_coef = select(successful_updates,
                          updates_normalizer,
                          _zeros(updates))

    # Weight the operand tangent should end up with at each target location.
    operand_normalizer = select(retained_values,
                                1.0 / (num_updates + 1),
                                _zeros(num_updates))

    # The final scatter_add starts from g_operand, and each of the num_refs
    # updates aimed at a location adds an equal share of the correction, so
    # the shares sum to (operand_normalizer - 1) times the operand tangent and
    # leave operand_normalizer times it in the result.
    operand_coef = (-1.0 + operand_normalizer) / num_refs

    # This can be simplified once scatter has transpose implemented
    target_tangents = gather(
        g_operand, scatter_indices, gather_dnums, np.array(slice_sizes))

    tangent_updates = (target_tangents * operand_coef +
                       g_updates * updates_coef)

    tangent_out = scatter_add(g_operand,
                              scatter_indices,
                              tangent_updates,
                              scatter_dnums,
                              indices_are_sorted=indices_are_sorted,
                              unique_indices=unique_indices)

  return val_out, tangent_out
def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
                 dimension_numbers, indices_are_sorted, unique_indices):
  """JVP rule for ``scatter_p``.

  Returns ``(val_out, tangent_out)``. When both the operand and updates
  tangents are symbolic zeros, the primal is a plain ``scatter_p.bind`` and the
  tangent is a symbolic zero. Otherwise both outputs are rebuilt from
  scatter-adds over masked inputs (steps a-d below), so that the primal and
  the tangent agree on which update wins at overlapping indices.

  The tangent of ``scatter_indices`` is unpacked but never used.
  """
  operand, scatter_indices, updates = primals
  g_operand, g_scatter_indices, g_updates = tangents
  dnums = dimension_numbers

  if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
    val_out = scatter_p.bind(
        operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
        update_consts=update_consts, dimension_numbers=dnums,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
    return val_out, ad_util.Zero.from_value(val_out)

  g_operand = ad.instantiate_zeros(g_operand)
  g_updates = ad.instantiate_zeros(g_updates)

  # If there are overlapping indices in the scatter, it is unspecified which
  # update "wins". So we use the following perhaps surprising scheme:
  # a) attach a positive ID to each update in updates, and perform the scatter
  #    on the IDs
  # b) perform the inverse gather on the scattered IDs (similar to
  #    _scatter_add_transpose_rule).
  # c) use the gathered IDs to mask the primal and tangent values.
  # d) perform a scatter-add on the masked primal and tangent values. A benefit
  #    of using scatter-add here is that we don't need a `scatter` transpose
  #    rule.


  # a) attach a positive ID to each update in `updates`, and perform a scatter
  #    on the IDs.
  # ids_shape is updates.shape with every update-window dimension set to 1.
  ids_shape = np.array(updates.shape, dtype=np.int64)
  ids_shape[dnums.update_window_dims,] = 1
  num_ids = np.prod(ids_shape)
  id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
  # IDs start at 1 (iota + ones); 0 is the fill value of the ID operand below
  # and so marks operand positions no update wrote to.
  update_ids = add(reshape(iota(id_dtype, num_ids), ids_shape),
                   _ones(updates, dtype=id_dtype))

  scattered_ids = scatter(full(operand.shape, 0, id_dtype),
                          scatter_indices, update_ids, dnums,
                          indices_are_sorted=indices_are_sorted,
                          unique_indices=unique_indices)

  # b) compute the inverse gather that "undoes" the scatter on the id values.
  gather_dnums = GatherDimensionNumbers(
      offset_dims=dnums.update_window_dims,
      collapsed_slice_dims=dnums.inserted_window_dims,
      start_index_map=dnums.scatter_dims_to_operand_dims)
  slice_sizes = []
  pos = 0
  for i in range(len(scattered_ids.shape)):
    if i in dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
      pos += 1
  gathered_update_ids = gather(scattered_ids, scatter_indices,
                               dimension_numbers=gather_dnums,
                               slice_sizes=slice_sizes)

  # c) mask off input elements that do not correspond to a primal output.
  # Operand entries survive only where no ID was scattered (ID still 0);
  # update entries survive only where their own ID is the one read back.
  masked_operand = select(eq(scattered_ids, _zeros(scattered_ids)),
                          operand, _zeros(operand))
  masked_updates = select(eq(update_ids, gathered_update_ids),
                          updates, _zeros(updates))
  masked_g_operand = select(eq(scattered_ids, _zeros(scattered_ids)),
                            g_operand, _zeros(g_operand))
  masked_g_updates = select(eq(update_ids, gathered_update_ids),
                            g_updates, _zeros(g_updates))

  # d) perform scatter-adds to compute the primal and tangent outputs.
  val_out = scatter_add(masked_operand, scatter_indices, masked_updates,
                        dimension_numbers=dnums,
                        indices_are_sorted=indices_are_sorted,
                        unique_indices=unique_indices)
  tangent_out = scatter_add(masked_g_operand, scatter_indices, masked_g_updates,
                            dimension_numbers=dnums,
                            indices_are_sorted=indices_are_sorted,
                            unique_indices=unique_indices)
  return val_out, tangent_out
consts, init_value) 4907 out = xops.Reduce(c, operands, init_values, xla_computation, dimensions) 4908 return xops.Tuple(c, (out,)) 4909 xla_computation = _reduction_computation(c, jaxpr, consts, init_values, singleton=False) 4910 return xops.Reduce(c, operands, init_values, xla_computation, dimensions) 4911 4912 4913def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr, 4914 consts, dimensions): 4915 num_operands = len(batched_args) // 2 4916 operands, init_values = split_list(batched_args, [num_operands]) 4917 operand_bdims, init_value_bdims = split_list(batch_dims, [num_operands]) 4918 if all(init_value_bdim is None for init_value_bdim in init_value_bdims): 4919 # Assume all batch dims are the same for each of the operands 4920 assert all(operand_bdim is not None for operand_bdim in operand_bdims) 4921 assert all(operand_bdim == operand_bdims[0] for operand_bdim in operand_bdims) 4922 # TODO(sharadmv): handle the case when batch dims are different across 4923 # operands or when some are unbatched 4924 operand_bdim = operand_bdims[0] 4925 new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions] 4926 new_operand_bdim = operand_bdim - int(np.sum(np.less(dimensions, operand_bdim))) 4927 new_operand_bdims = [new_operand_bdim] * num_operands 4928 return reduce_p.bind(*(operands + init_values), 4929 computation=computation, dimensions=tuple(new_dimensions), 4930 consts=consts, 4931 jaxpr=jaxpr), new_operand_bdims 4932 else: 4933 raise NotImplementedError # loop and stack 4934 4935 4936def _reduction_computation(c, jaxpr, consts, init_values, singleton=True): 4937 if singleton: 4938 init_values = [init_values] 4939 shapes = safe_map(c.get_shape, init_values + init_values) 4940 axis_env = xla.AxisEnv(1, (), ()) # no parallel primitives inside reductions 4941 subc = xla_bridge.make_computation_builder("reduction_computation") 4942 assert len(consts) == 0, "Reduction computations cannot have constants" 4943 args = [xb.parameter(subc, i, shape) 
def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
                          axes, input_shape=None, **reduce_kwargs):
  """Masking rule for single-operand reducers.

  Padding entries along the reduced axes are replaced by the reducer's
  identity element before ``prim`` is bound, so they do not affect the result.

  Args:
    prim: the reduction primitive to bind.
    identity: callable ``(shape, dtype) -> array`` giving the identity fill.
    padded_vals: one-element sequence holding the padded operand.
    logical_shapes: one-element sequence holding the operand's logical shape.
    axes: the axes being reduced.
    input_shape: if not ``None``, ``input_shape=<padded shape>`` is passed to
      ``prim.bind``.
    **reduce_kwargs: further parameters forwarded to ``prim.bind``.

  Returns:
    The result of binding ``prim`` on the masked operand.
  """
  (padded_val,), (logical_shape,) = padded_vals, logical_shapes
  padded_shape = masking.padded_shape_as_value(padded_val.shape)
  # One mask per reduced axis: True where the position is within the logical
  # extent of that axis.
  masks = [broadcasted_iota(np.int32, padded_shape, i) < d
           for i, d in enumerate(logical_shape) if i in axes]
  if masks:
    mask = _reduce(operator.and_, masks)
    masked_val = select(mask, padded_val,
                        identity(padded_shape, padded_val.dtype))
  else:
    # No axis is reduced, so there is nothing to mask. Folding an empty list
    # with functools.reduce (no initial value) would raise TypeError.
    masked_val = padded_val
  prim_bind = partial(prim.bind, **reduce_kwargs)
  bind = prim_bind if input_shape is None else partial(prim_bind, input_shape=padded_shape)
  return bind(masked_val, axes=axes)
def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
  """Transpose rule for ``reduce_sum``.

  Broadcasts the cotangent back to the operand's shape: the cotangent's
  dimensions are mapped onto the operand dimensions that were not summed.

  Returns:
    A one-element list holding the operand cotangent.
  """
  assert ad.is_undefined_primal(operand)
  operand_shape = operand.aval.shape
  all_dims = np.arange(len(operand_shape))
  kept_dims = tuple(np.delete(all_dims, axes))
  operand_ct = broadcast_in_dim(cotangent, operand_shape, kept_dims)
  assert operand_ct.shape == operand_shape
  return [operand_ct]
def _reduce_prod_jvp_rule(primals, tangents, *, axes):
  """JVP rule for ``reduce_prod``.

  The reduced axes are moved to the front and flattened into one leading
  axis; the product along that axis is then expressed as a tree of pairwise
  multiplications, and ``api.jvp`` differentiates that tree.

  Args:
    primals: one-element sequence holding the operand.
    tangents: one-element sequence holding the operand tangent.
    axes: tuple of axes being reduced.

  Returns:
    Whatever ``api.jvp`` returns for the tree reduction (primal and tangent
    outputs).
  """
  operand, = primals
  tangent, = tangents
  input_shape = np.array(operand.shape)

  n = np.prod(input_shape[list(axes)])
  non_axes = np.delete(np.arange(len(input_shape)), axes)

  # Move the reduced axes to the front, and flatten them to 1D.
  permutation = axes + tuple(non_axes)
  new_shape = (n,) + tuple(input_shape[non_axes])
  operand = reshape(operand, new_shape, permutation)
  tangent = reshape(tangent, new_shape, permutation)

  def _reduce_prod_tree(x, axis=0):
    """Reduce by repeatedly splitting the array and multiplying."""
    while x.shape[axis] > 1:
      size = x.shape[axis]
      n1 = (size + 1) // 2
      n2 = size - n1
      # Slice along `axis`, the same axis that is measured, padded and
      # squeezed here.
      x1 = slice_in_dim(x, 0, n1, axis=axis)
      x2 = slice_in_dim(x, n1, None, axis=axis)
      if n2 != n1:
        # Odd length: pad the shorter half with a one so the halves line up.
        paddings = [(0, 0, 0)] * len(x.shape)
        paddings[axis] = (0, 1, 0)
        x2 = pad(x2, _const(x, 1), paddings)
      x = x1 * x2
    if x.shape[axis] == 0:
      # Empty product: every output element is one.
      return full(input_shape[non_axes], _one(x))
    return squeeze(x, (axis,))

  return api.jvp(_reduce_prod_tree, (operand,), (tangent,))
def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
  """JVP rule registered for ``reduce_max_p`` and ``reduce_min_p``.

  Operand positions that match the reduction result (as decided by
  ``_eq_meet``) are marked with ones; the output tangent is the sum of ``g``
  over the marked positions divided by how many were marked, i.e. their mean.
  """
  # TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
  # locations in a single pass (rather than comparing equality) and use a
  # gather, and/or even push along the chosen elements of g (b/112040122)
  keepdims_shape = list(operand.shape)
  for i in range(len(keepdims_shape)):
    if i in axes:
      keepdims_shape[i] = 1
  is_chosen = _eq_meet(operand, reshape(ans, keepdims_shape))
  location_indicators = convert_element_type(is_chosen, g.dtype)
  counts = _reduce_sum(location_indicators, axes)
  chosen_tangent_sum = _reduce_sum(mul(g, location_indicators), axes)
  return div(chosen_tangent_sum, counts)
def _argminmax_translation_rule(value_comparator, identity,
                                c, operand, *, axes, index_dtype):
  """Lower argmin/argmax to a two-operand ``xops.Reduce`` over (value, index).

  Args:
    value_comparator: XLA op builder ``(x, y) -> pred``; true when the left
      value should be kept.
    identity: callable ``dtype -> scalar`` giving the initial value.
    c: builder the operand belongs to.
    operand: XLA op to reduce.
    axes: one-element sequence holding the axis to reduce over.
    index_dtype: dtype of the returned indices.

  Returns:
    The index component of the reduction result.
  """
  (axis,) = axes
  operand_shape = c.get_shape(operand)
  value_dtype = operand_shape.numpy_dtype()

  # Comparator over two (value, index) pairs.
  subc = xb.make_computation_builder("argminmax_comparator")
  scalar_value = xc.Shape.array_shape(operand_shape.xla_element_type(), ())
  scalar_index = xc.Shape.array_shape(index_dtype, ())
  lhs_value, lhs_index, rhs_value, rhs_index = (
      xb.parameter(subc, num, shp)
      for num, shp in enumerate(
          (scalar_value, scalar_index, scalar_value, scalar_index)))
  take_lhs_value = value_comparator(lhs_value, rhs_value)
  # On equal values the smaller index is kept.
  tie_prefers_lhs = xops.And(xops.Eq(lhs_value, rhs_value),
                             xops.Lt(lhs_index, rhs_index))
  take_lhs_index = xops.Or(take_lhs_value, tie_prefers_lhs)
  chosen_value = xops.Select(take_lhs_value, lhs_value, rhs_value)
  chosen_index = xops.Select(take_lhs_index, lhs_index, rhs_index)
  xops.Tuple(subc, [chosen_value, chosen_index])
  comparator = subc.build()

  # Pair every operand element with its position along `axis`.
  positions_shape = xc.Shape.array_shape(index_dtype,
                                         operand_shape.dimensions())
  positions = xc.ops.Iota(c, positions_shape, axis)
  init_values = [xb.constant(c, identity(value_dtype)),
                 xb.constant(c, np.array(0, index_dtype))]
  reduced = xops.Reduce(c, [operand, positions], init_values, comparator,
                        [axis])
  return xops.GetTupleElement(reduced, 1)
select(eq(a, expand_dims(op(a, (axis,)), (axis,))), idxs, 5146 maxval) 5147 return _reduce_min(mask_idxs, (axis,)) 5148 5149_argmin_translation_rule = partial(_argminmax_translation_rule, xops.Lt, 5150 _get_min_identity) 5151_argmax_translation_rule = partial(_argminmax_translation_rule, xops.Gt, 5152 _get_max_identity) 5153 5154argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 5155 'argmin', _argmin_translation_rule) 5156batching.defreducer(argmin_p) 5157ad.defjvp_zero(argmin_p) 5158xla.backend_specific_translations['gpu'][argmin_p] = xla.lower_fun( 5159 partial(_argminmax_gpu_translation_rule, _reduce_min), 5160 multiple_results=False) 5161 5162argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 5163 'argmax', _argmax_translation_rule) 5164batching.defreducer(argmax_p) 5165ad.defjvp_zero(argmax_p) 5166xla.backend_specific_translations['gpu'][argmax_p] = xla.lower_fun( 5167 partial(_argminmax_gpu_translation_rule, _reduce_max), 5168 multiple_results=False) 5169 5170 5171def _reduce_logical_shape_rule(operand, *, axes): 5172 if operand.dtype != np.bool_: 5173 msg = "logical reduction requires operand dtype bool, got {}." 
5174 raise TypeError(msg.format(operand.dtype)) 5175 return tuple(np.delete(operand.shape, axes)) 5176 5177def _reduce_logical_translation_rule(prim, identity, c, operand, *, axes): 5178 scalar = ShapedArray((), np.bool_) 5179 return xops.Reduce(c, [operand], [xb.constant(c, identity(np.bool_))], 5180 xla.primitive_subcomputation(prim, scalar, scalar), axes) 5181 5182_reduce_or_translation_rule = partial(_reduce_logical_translation_rule, 5183 or_p, _get_max_identity) 5184reduce_or_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bool_), 5185 'reduce_or', _reduce_or_translation_rule) 5186batching.defreducer(reduce_or_p) 5187 5188 5189_reduce_and_translation_rule = partial(_reduce_logical_translation_rule, 5190 and_p, _get_min_identity) 5191reduce_and_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bool_), 5192 'reduce_and', _reduce_and_translation_rule) 5193batching.defreducer(reduce_and_p) 5194 5195def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts, 5196 window_dimensions, window_strides, padding, 5197 base_dilation, window_dilation): 5198 if operand.dtype != init_value.dtype: 5199 msg = ("reduce_window got inconsistent dtypes for operand and init_value: " 5200 " got operand dtype {} and init_value dtype {}.") 5201 raise TypeError(msg.format(operand.dtype, init_value.dtype)) 5202 if init_value.shape != (): 5203 msg = ("reduce_window expected init_value to be a scalar but init_value " 5204 "has shape {}.") 5205 raise TypeError(msg.format(init_value.shape)) 5206 return _common_reduce_window_shape_rule( 5207 operand, window_dimensions, window_strides, padding, base_dilation, 5208 window_dilation) 5209 5210def _reduce_window_translation_rule(c, operand, init_value, *, jaxpr, consts, 5211 window_dimensions, window_strides, padding, 5212 base_dilation, window_dilation): 5213 xla_computation = _reduction_computation(c, jaxpr, consts, init_value) 5214 return xops.ReduceWindowWithGeneralPadding( 5215 operand, init_value, 
xla_computation, window_dimensions, 5216 window_strides, base_dilation, window_dilation, padding) 5217 5218def _generic_reduce_window_batch_rule( 5219 batched_args, batch_dims, *, jaxpr, consts, window_dimensions, 5220 window_strides, padding, base_dilation, window_dilation): 5221 operand, init = batched_args 5222 bdim, init_bdim = batch_dims 5223 if init_bdim is not None: 5224 raise NotImplementedError("reduce_window batching is not implemented for " 5225 "initial values") 5226 5227 def reduce_window(x, window_dimensions, window_strides, padding, base_dilation, 5228 window_dilation): 5229 return reduce_window_p.bind( 5230 x, init, jaxpr=jaxpr, consts=consts, window_dimensions=window_dimensions, 5231 window_strides=window_strides, padding=padding, base_dilation=base_dilation, 5232 window_dilation=window_dilation) 5233 return _reduce_window_batch_rule( 5234 reduce_window, (operand,), (bdim,), window_dimensions=window_dimensions, 5235 window_strides=window_strides, padding=padding, base_dilation=base_dilation, 5236 window_dilation=window_dilation) 5237 5238 5239reduce_window_p = standard_primitive( 5240 _reduce_window_shape_rule, _input_dtype, 'reduce_window', 5241 _reduce_window_translation_rule) 5242batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule 5243 5244 5245def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides, 5246 padding, base_dilation, window_dilation): 5247 if not dtypes.issubdtype(operand.dtype, np.number): 5248 msg = "operand to reduce_window_sum must have a number dtype, got {}" 5249 raise TypeError(msg.format(np.dtype(operand.dtype).name)) 5250 return _common_reduce_window_shape_rule(operand, window_dimensions, 5251 window_strides, padding, base_dilation, 5252 window_dilation) 5253 5254def _reduce_window_sum_translation_rule(c, operand, *, window_dimensions, 5255 window_strides, padding, base_dilation, 5256 window_dilation): 5257 dtype = c.get_shape(operand).numpy_dtype() 5258 scalar = 
ShapedArray((), dtype) 5259 return xops.ReduceWindowWithGeneralPadding( 5260 operand, xb.constant(c, np.array(0, dtype)), 5261 xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions, 5262 window_strides, base_dilation, window_dilation, padding) 5263 5264def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions, 5265 window_strides, padding, base_dilation, 5266 window_dilation): 5267 assert ad.is_undefined_primal(operand) 5268 input_shape = operand.aval.shape 5269 pads = _conv_general_vjp_lhs_padding( 5270 input_shape, window_dimensions, window_strides, cotangent.shape, padding, 5271 base_dilation, window_dilation) 5272 ones = [1] * len(input_shape) 5273 padding_config = [(lo, hi, stride - 1) 5274 for (lo, hi), stride in zip(pads, window_strides)] 5275 pad_cotangent = pad(cotangent, _zero(cotangent), padding_config) 5276 result = _reduce_window_sum(pad_cotangent, window_dimensions, base_dilation, 5277 [(0, 0)] * len(input_shape), 5278 base_dilation=ones, 5279 window_dilation=window_dilation) 5280 assert result.shape == input_shape, (result.shape, input_shape) 5281 return [result] 5282 5283def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *, 5284 window_dimensions, window_strides, padding, 5285 base_dilation, window_dilation): 5286 operand, = batched_args 5287 bdim, = bdims 5288 5289 if bdim is not None: 5290 window_dimensions = \ 5291 window_dimensions[:bdim] + (1,) + window_dimensions[bdim:] 5292 window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:] 5293 padding = padding[:bdim] + ((0, 0),) + padding[bdim:] 5294 base_dilation = base_dilation[:bdim] + (1,) + base_dilation[bdim:] 5295 window_dilation = window_dilation[:bdim] + (1,) + window_dilation[bdim:] 5296 5297 operand = reduce_window(operand, window_dimensions, window_strides, padding, 5298 base_dilation, window_dilation) 5299 return operand, bdim 5300 5301reduce_window_sum_p = standard_primitive( 5302 _reduce_window_sum_shape_rule, 
_input_dtype, 'reduce_window_sum', 5303 _reduce_window_sum_translation_rule) 5304ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule) 5305batching.primitive_batchers[reduce_window_sum_p] = partial( 5306 _reduce_window_batch_rule, _reduce_window_sum) 5307 5308def _reduce_window_chooser_translation_rule( 5309 prim, identity, c, operand, *, window_dimensions, window_strides, padding, 5310 base_dilation, window_dilation): 5311 dtype = c.get_shape(operand).numpy_dtype() 5312 scalar = ShapedArray((), dtype) 5313 return xops.ReduceWindowWithGeneralPadding( 5314 operand, xb.constant(c, identity(dtype)), 5315 xla.primitive_subcomputation(prim, scalar, scalar), window_dimensions, 5316 window_strides, base_dilation, window_dilation, padding) 5317 5318def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions, 5319 window_strides, padding, base_dilation, 5320 window_dilation): 5321 assert prim is max_p or prim is min_p 5322 select_prim = ge_p if prim is max_p else le_p 5323 return _select_and_gather_add(g, operand, select_prim, window_dimensions, 5324 window_strides, padding, base_dilation, 5325 window_dilation) 5326 5327 5328def _common_reduce_window_shape_rule(operand, window_dimensions, 5329 window_strides, padding, base_dilation, 5330 window_dilation): 5331 _check_shapelike("reduce_window", "window_dimensions", window_dimensions, 5332 non_zero_shape=True) 5333 _check_shapelike("reduce_window", "window_strides", window_strides, 5334 non_zero_shape=True) 5335 _check_shapelike("reduce_window", "base_dilation", base_dilation) 5336 _check_shapelike("reduce_window", "window_dilation", window_dilation) 5337 if operand.ndim != len(window_dimensions): 5338 msg = ("reduce_window got the wrong number of window_dimensions for " 5339 "operand: got operand shape {} with window_dimensions {}.") 5340 raise TypeError(msg.format(operand.shape, window_dimensions)) 5341 if len(window_strides) != len(window_dimensions): 5342 msg = ("reduce_window got 
inconsistent window_strides and " 5343 "window_dimensions: got window_strides {} and window_dimensions {}.") 5344 raise TypeError(msg.format(window_strides, window_dimensions)) 5345 if len(base_dilation) != len(window_dimensions): 5346 msg = ("reduce_window got inconsistent base_dilation and " 5347 "window_dimensions: got base_dilation {} and window_dimensions {}.") 5348 raise TypeError(msg.format(base_dilation, window_dimensions)) 5349 if len(window_dilation) != len(window_dimensions): 5350 msg = ("reduce_window got inconsistent window_dilation and " 5351 "window_dimensions: got window_dilation {} and window_dimensions " 5352 "{}.") 5353 raise TypeError(msg.format(window_dilation, window_dimensions)) 5354 5355 return reduce_window_shape_tuple(operand.shape, window_dimensions, 5356 window_strides, padding, base_dilation, 5357 window_dilation) 5358 5359def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, 5360 padding, base_dilation=None, 5361 window_dilation=None): 5362 if base_dilation is not None: 5363 operand_shape = _dilate_shape(operand_shape, base_dilation) 5364 if window_dilation is not None: 5365 window_dimensions = _dilate_shape(window_dimensions, window_dilation) 5366 operand_padded = np.add(operand_shape, np.add(*zip(*padding))) 5367 t = np.floor_divide( 5368 np.subtract(operand_padded, window_dimensions), window_strides) + 1 5369 return tuple(t) 5370 5371_reduce_window_max_translation_rule = partial( 5372 _reduce_window_chooser_translation_rule, max_p, _get_max_identity) 5373reduce_window_max_p = standard_primitive( 5374 _common_reduce_window_shape_rule, _input_dtype, 'reduce_window_max', 5375 _reduce_window_max_translation_rule) 5376ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, max_p)) 5377batching.primitive_batchers[reduce_window_max_p] = partial( 5378 _reduce_window_batch_rule, _reduce_window_max) 5379 5380_reduce_window_min_translation_rule = partial( 5381 _reduce_window_chooser_translation_rule, 
min_p, _get_min_identity) 5382reduce_window_min_p = standard_primitive( 5383 _common_reduce_window_shape_rule, _input_dtype, 'reduce_window_min', 5384 _reduce_window_min_translation_rule) 5385ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, min_p)) 5386 5387_reduce_window_min_batch_rule = partial(_reduce_window_batch_rule, 5388 _reduce_window_min) 5389batching.primitive_batchers[reduce_window_min_p] = partial( 5390 _reduce_window_batch_rule, _reduce_window_min) 5391 5392 5393def _select_and_scatter_shape_rule( 5394 operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr, 5395 scatter_consts, window_dimensions, window_strides, padding): 5396 _check_shapelike("select_and_scatter", "window_dimensions", window_dimensions) 5397 _check_shapelike("select_and_scatter", "window_strides", window_strides) 5398 if len(window_dimensions) != len(window_strides): 5399 msg = ("select_and_scatter got inconsistent window_strides and " 5400 "window_dimensions: got window_strides {} and window_dimensions {}.") 5401 raise TypeError(msg.format(window_strides, window_dimensions)) 5402 return operand.shape 5403 5404def _select_and_scatter_translation( 5405 c, operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr, 5406 scatter_consts, window_dimensions, window_strides, padding): 5407 select = _reduction_computation(c, select_jaxpr, select_consts, init_value) 5408 scatter = _reduction_computation(c, scatter_jaxpr, scatter_consts, init_value) 5409 return xops.SelectAndScatterWithGeneralPadding( 5410 operand, select, window_dimensions, window_strides, padding, source, 5411 init_value, scatter) 5412 5413select_and_scatter_p = standard_primitive( 5414 _select_and_scatter_shape_rule, _input_dtype, 'select_and_scatter', 5415 _select_and_scatter_translation) 5416 5417 5418def _select_and_scatter_add_shape_rule( 5419 source, operand, *, select_prim, window_dimensions, window_strides, 5420 padding): 5421 return operand.shape 5422 5423def 
_select_and_scatter_add_translation( 5424 c, source, operand, *, select_prim, window_dimensions, window_strides, 5425 padding): 5426 dtype = c.get_shape(operand).numpy_dtype() 5427 scalar = ShapedArray((), dtype) 5428 select = xla.primitive_subcomputation(select_prim, scalar, scalar) 5429 scatter = xla.primitive_subcomputation(add_p, scalar, scalar) 5430 zero = xb.constant(c, np.array(0, dtype)) 5431 return xops.SelectAndScatterWithGeneralPadding( 5432 operand, select, window_dimensions, window_strides, padding, source, zero, 5433 scatter) 5434 5435def _select_and_scatter_add_jvp( 5436 primals, tangents, *, select_prim, window_dimensions, window_strides, 5437 padding): 5438 source, operand = primals 5439 g_source, g_operand = tangents 5440 val_out = _select_and_scatter_add( 5441 source, operand, select_prim, window_dimensions, window_strides, 5442 padding) 5443 del g_operand 5444 if type(g_source) is ad_util.Zero: 5445 tangent_out = ad_util.Zero.from_value(val_out) 5446 else: 5447 tangent_out = _select_and_scatter_add( 5448 g_source, operand, select_prim, window_dimensions, 5449 window_strides, padding) 5450 return val_out, tangent_out 5451 5452def _select_and_scatter_add_transpose( 5453 t, source, operand, *, select_prim, window_dimensions, window_strides, 5454 padding): 5455 assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand) 5456 ones = (1,) * len(window_dimensions) 5457 source_t = _select_and_gather_add(t, operand, select_prim, window_dimensions, 5458 window_strides, padding, ones, ones) 5459 return [source_t, None] 5460 5461def _select_and_scatter_add_batch_rule( 5462 batched_args, batch_dims, *, select_prim, window_dimensions, window_strides, 5463 padding): 5464 source, operand = batched_args 5465 s_bdim, o_bdim = batch_dims 5466 size = next(a.shape[bdim] for a, bdim in zip(batched_args, batch_dims) 5467 if bdim is not None) 5468 source = batching.bdim_at_front(source, s_bdim, size) 5469 operand = batching.bdim_at_front(operand, 
o_bdim, size) 5470 5471 window_dimensions = (1,) + window_dimensions 5472 window_strides = (1,) + window_strides 5473 padding = ((0, 0),) + padding 5474 out = _select_and_scatter_add(source, operand, select_prim, window_dimensions, 5475 window_strides, padding) 5476 return out, 0 5477 5478select_and_scatter_add_p = standard_primitive( 5479 _select_and_scatter_add_shape_rule, _input_dtype, 'select_and_scatter_add', 5480 _select_and_scatter_add_translation) 5481ad.primitive_transposes[select_and_scatter_add_p] = \ 5482 _select_and_scatter_add_transpose 5483ad.primitive_jvps[select_and_scatter_add_p] = _select_and_scatter_add_jvp 5484batching.primitive_batchers[select_and_scatter_add_p] = \ 5485 _select_and_scatter_add_batch_rule 5486 5487def _select_and_gather_add_shape_rule( 5488 tangents, operand, *, select_prim, window_dimensions, window_strides, 5489 padding, base_dilation, window_dilation): 5490 if tangents.shape != operand.shape: 5491 msg = ("select_and_gather_add tangents and operand shapes must match, " 5492 "got {} and {}.") 5493 raise TypeError(msg.format(tangents.shape, operand.shape)) 5494 return _common_reduce_window_shape_rule( 5495 operand, window_dimensions, window_strides, padding, base_dilation, 5496 window_dilation) 5497 5498 5499_UINT_DTYPES = { 5500 16: np.uint16, 5501 32: np.uint32, 5502 64: np.uint64, 5503} 5504 5505_INT_DTYPES = { 5506 16: np.int16, 5507 32: np.int32, 5508 64: np.int64, 5509} 5510 5511def _select_and_gather_add_translation( 5512 c, tangents, operand, *, select_prim, window_dimensions, window_strides, 5513 padding, base_dilation, window_dilation, max_bits=64): 5514 shape = c.get_shape(operand) 5515 dtype = shape.numpy_dtype() 5516 etype = shape.xla_element_type() 5517 nbits = dtypes.finfo(dtype).bits 5518 5519 assert nbits <= max_bits 5520 double_word_reduction = nbits * 2 <= max_bits 5521 5522 const = lambda c, dtype, x: xb.constant(c, np.array(x, dtype=dtype), 5523 canonicalize_types=False) 5524 5525 if double_word_reduction: 
5526 # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so 5527 # we implement a pair-wise ReduceWindow by packing two k-bit values into 5528 # 2k-bit unsigned integer using bit tricks. 5529 word_dtype = _UINT_DTYPES[nbits] 5530 double_word_dtype = _UINT_DTYPES[nbits * 2] 5531 word_type = xla_client.dtype_to_etype(word_dtype) 5532 double_word_type = xla_client.dtype_to_etype(double_word_dtype) 5533 5534 # Packs two values into a tuple. 5535 def pack(a, b): 5536 a = xops.BitcastConvertType(a, word_type) 5537 b = xops.BitcastConvertType(b, word_type) 5538 a = xops.ConvertElementType(a, double_word_type) 5539 b = xops.ConvertElementType(b, double_word_type) 5540 a = xops.ShiftLeft(a, const(c, double_word_dtype, nbits)) 5541 return xops.Or(a, b) 5542 5543 # Unpacks the first element of a tuple. 5544 def fst(c, t): 5545 st = xops.ShiftRightLogical(t, const(c, double_word_dtype, nbits)) 5546 return xops.BitcastConvertType(xops.ConvertElementType(st, word_type), etype) 5547 5548 # Unpacks the second element of a tuple. 5549 def snd(t): 5550 return xops.BitcastConvertType(xops.ConvertElementType(t, word_type), etype) 5551 5552 else: 5553 # The double-word trick above only works if we have a sufficiently large 5554 # type. As an alternative, we can pack two half words into a single word, 5555 # at the cost of precision. 5556 # TODO(b/73062247): add support for tuple reductions and remove this case. 5557 warnings.warn("Using reduced precision for gradient of reduce-window " 5558 "min/max operator to work around missing XLA support for " 5559 "pair-reductions. This is likely from a second or " 5560 "higher derivative of a max-pooling operation.") 5561 r_nbits = nbits // 2 5562 # Drop/round the bottom mantissa bits. 5563 nexp = dtypes.finfo(dtype).nexp 5564 nmant = r_nbits - nexp - 1 5565 5566 double_word_dtype = word_dtype = _UINT_DTYPES[nbits] 5567 word_type = xla_client.dtype_to_etype(word_dtype) 5568 5569 # Packs two values into a tuple. 
5570 def pack(a, b): 5571 a = xops.ReducePrecision(a, exponent_bits=nexp, mantissa_bits=nmant) 5572 b = xops.ReducePrecision(b, exponent_bits=nexp, mantissa_bits=nmant) 5573 a = xops.BitcastConvertType(a, word_type) 5574 b = xops.BitcastConvertType(b, word_type) 5575 b = xops.ShiftRightLogical(b, const(c, word_dtype, r_nbits)) 5576 return xops.Or(a, b) 5577 5578 # Unpacks the first element of a tuple. 5579 def fst(c, t): 5580 st = xops.And(t, const(c, word_dtype, ((1 << r_nbits) - 1) << r_nbits)) 5581 return xops.BitcastConvertType(st, etype) 5582 5583 # Unpacks the second element of a tuple. 5584 def snd(t): 5585 return xops.BitcastConvertType(xops.ShiftLeft(t, const(c, word_dtype, r_nbits)), 5586 etype) 5587 5588 def reducer(): 5589 c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer") 5590 x = xb.parameter(c, 0, 5591 xla_client.Shape.array_shape(np.dtype(double_word_dtype), ())) 5592 y = xb.parameter(c, 1, 5593 xla_client.Shape.array_shape(np.dtype(double_word_dtype), ())) 5594 assert select_prim is ge_p or select_prim is le_p 5595 which = xops.Ge if select_prim is ge_p else xops.Le 5596 xops.Select(which(fst(c, x), fst(c, y)), x, y) 5597 return c.build() 5598 5599 5600 assert select_prim is ge_p or select_prim is le_p, select_prim 5601 init = -np.inf if select_prim is ge_p else np.inf 5602 out = xops.ReduceWindowWithGeneralPadding( 5603 pack(operand, tangents), pack(const(c, dtype, init), const(c, dtype, 0)), 5604 reducer(), window_dimensions, window_strides, base_dilation, 5605 window_dilation, padding) 5606 return snd(out) 5607 5608def _select_and_gather_add_jvp( 5609 primals, tangents, *, select_prim, window_dimensions, window_strides, 5610 padding, base_dilation, window_dilation): 5611 source, operand = primals 5612 g_source, g_operand = tangents 5613 val_out = _select_and_gather_add( 5614 source, operand, select_prim, window_dimensions, window_strides, 5615 padding, base_dilation, window_dilation) 5616 del g_operand 5617 if 
type(g_source) is ad_util.Zero: 5618 tangent_out = ad_util.Zero.from_value(val_out) 5619 else: 5620 tangent_out = _select_and_gather_add( 5621 g_source, operand, select_prim, window_dimensions, 5622 window_strides, padding, base_dilation, window_dilation) 5623 return val_out, tangent_out 5624 5625def _select_and_gather_add_transpose( 5626 t, tangents, operand, *, select_prim, window_dimensions, window_strides, 5627 padding, base_dilation, window_dilation): 5628 assert select_prim in (le_p, ge_p) 5629 assert ad.is_undefined_primal(tangents) and not ad.is_undefined_primal(operand) 5630 if any(d != 1 for d in window_dilation): 5631 msg = ("VJP not implemented for select_and_gather (MaxPool) with window " 5632 "dilation, got window_dilation={}.") 5633 raise NotImplementedError(msg.format(window_dilation)) 5634 if type(t) is ad_util.Zero: 5635 return [ad_util.Zero(tangents.aval), None] 5636 has_base_dilation = any(d != 1 for d in base_dilation) 5637 if has_base_dilation: 5638 select_identity = (_get_max_identity if select_prim is ge_p 5639 else _get_min_identity) 5640 operand = pad(operand, select_identity(operand.dtype), 5641 tuple((0, 0, d - 1) for d in base_dilation)) 5642 result = _select_and_scatter_add(t, operand, select_prim, window_dimensions, 5643 window_strides, padding) 5644 if has_base_dilation: 5645 result = slice(operand, (0,) * len(operand.shape), operand.shape, 5646 base_dilation) 5647 return [result, None] 5648 5649def _select_and_gather_add_batching_rule( 5650 batched_args, batch_dims, *, select_prim, window_dimensions, window_strides, 5651 padding, base_dilation, window_dilation): 5652 t, x = batched_args 5653 t_bdim, x_bdim = batch_dims 5654 size = next(a.shape[bdim] for a, bdim in zip(batched_args, batch_dims) 5655 if bdim is not None) 5656 t = batching.bdim_at_front(t, t_bdim, size) 5657 x = batching.bdim_at_front(x, x_bdim, size) 5658 window_dimensions = (1,) + window_dimensions 5659 window_strides = (1,) + window_strides 5660 padding = ((0, 0),) 
+ padding 5661 base_dilation = (1,) + base_dilation 5662 window_dilation = (1,) + window_dilation 5663 out = _select_and_gather_add(t, x, select_prim, window_dimensions, 5664 window_strides, padding, base_dilation, 5665 window_dilation) 5666 return (out, 0) 5667 5668 5669select_and_gather_add_p = standard_primitive( 5670 _select_and_gather_add_shape_rule, _input_dtype, 'select_and_gather_add', 5671 _select_and_gather_add_translation) 5672ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp 5673ad.primitive_transposes[select_and_gather_add_p] = \ 5674 _select_and_gather_add_transpose 5675batching.primitive_batchers[select_and_gather_add_p] = \ 5676 _select_and_gather_add_batching_rule 5677xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial( 5678 _select_and_gather_add_translation, 5679 max_bits=32) 5680 5681def _sort_abstract_eval(*args, **kwargs): 5682 args = tuple(raise_to_shaped(arg) for arg in args) 5683 if any(arg.shape != args[0].shape for arg in args[1:]): 5684 shapes = " ".join(str(a.shape) for a in args) 5685 raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}") 5686 return args 5687 5688 5689def _float_to_int_for_sort(x): 5690 # Switch from a floating point value to a integer value in such a way that 5691 # when using the integer value to compare, we get the same result for normal 5692 # values, and -nan is treated as the smallest value, and nan is treated as 5693 # the largest value. 5694 # If f is a float, and 5695 # x = bit_cast<int32>(f); 5696 # y = x < 0 ? int32_max - x : x; 5697 # then y is ordered as an int32 such that finite values have the obvious 5698 # order, -0 is ordered before 0, and -NaN and NaN appear at the beginning 5699 # and end of the ordering. 5700 # Note that in order to avoid -x to overflow, we calculate 5701 # int32_max - x as unsigned, and then convert back to signed. 
5702 if x.dtype == dtypes.bfloat16: 5703 x = convert_element_type(x, np.float32) 5704 nbits = np.finfo(x).bits 5705 signed_dtype = _INT_DTYPES[nbits] 5706 unsigned_dtype = _UINT_DTYPES[nbits] 5707 5708 signed = bitcast_convert_type(x, signed_dtype) 5709 unsigned = bitcast_convert_type(x, unsigned_dtype) 5710 flipped = bitcast_convert_type( 5711 sub(unsigned_dtype(np.iinfo(signed_dtype).max), unsigned), signed_dtype) 5712 return select(lt(signed, _zero(signed)), flipped, signed) 5713 5714# Default comparator that sorts the operands lexicographically on the 5715# first `num_keys` arguments. 5716# For floating point types, a total order is created where 5717# -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN. 5718# For complex types, the (real, imag) pairs are sorted lexicographically 5719# (following NumPy's semantics). 5720# This code adds complex-number support and lexicographic ordering to the algorithm from: 5721# https://github.com/tensorflow/tensorflow/blob/ba43780830f09da72081fe5061c436f1c6203a92/tensorflow/compiler/xla/client/lib/comparators.h#L33 5722def _sort_lt_comparator(*operands, num_keys=1): 5723 assert len(operands) >= 2 and len(operands) % 2 == 0, operands 5724 assert len(operands) // 2 >= num_keys, (operands, num_keys) 5725 x_keys, y_keys = [], [] 5726 for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]): 5727 assert x.dtype == y.dtype, (x.dtype, y.dtype) 5728 if np.issubdtype(x.dtype, np.complexfloating): 5729 x_keys.extend([_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))]) 5730 y_keys.extend([_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))]) 5731 elif np.issubdtype(x.dtype, np.floating): 5732 x_keys.append(_float_to_int_for_sort(x)) 5733 y_keys.append(_float_to_int_for_sort(y)) 5734 else: 5735 x_keys.append(x) 5736 y_keys.append(y) 5737 5738 p = None 5739 for xk, yk in zip(x_keys[::-1], y_keys[::-1]): 5740 p = (bitwise_or(lt(xk, yk), bitwise_and(eq(xk, yk), p)) if p is not None 5741 else lt(xk, 
yk)) 5742 return p 5743 5744 5745def _sort_translation_rule(c, *operands, dimension, is_stable, num_keys): 5746 types = [c.get_shape(x).xla_element_type() for x in operands] 5747 subc = xla_bridge.make_computation_builder("sort_lt_comparator") 5748 params = [xb.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ())) 5749 for i, typ in enumerate(types) for j in range(2)] 5750 result = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys), 5751 multiple_results=False)(subc, *params) 5752 comparator = subc.build(result) 5753 out = xops.Sort(c, operands, dimension=dimension, is_stable=is_stable, 5754 comparator=comparator) 5755 return out if len(operands) != 1 else xops.Tuple(c, [out]) 5756 5757def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys): 5758 shape = primals[0].shape 5759 iotas = [] 5760 for dim, size in enumerate(shape): 5761 dtype = np.int32 if size < np.iinfo(np.int32).max else np.int64 5762 iotas.append(broadcasted_iota(dtype, shape, dim)) 5763 primals = sort_p.bind(*(primals + (iotas[dimension],)), dimension=dimension, 5764 is_stable=is_stable, num_keys=num_keys) 5765 idx = tuple(primals[-1] if i == dimension else iotas[i] 5766 for i in range(len(shape))) 5767 tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents) 5768 return tuple(primals[:-1]), tangents_out 5769 5770def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys): 5771 prototype_arg, new_bdim = next( 5772 (a, b) for a, b in zip(batched_args, batch_dims) if b is not None) 5773 new_args = [] 5774 for arg, bdim in zip(batched_args, batch_dims): 5775 if bdim is None: 5776 dims = np.delete(np.arange(prototype_arg.ndim), new_bdim) 5777 new_args.append(broadcast_in_dim(arg, prototype_arg.shape, dims)) 5778 else: 5779 new_args.append(batching.moveaxis(arg, bdim, new_bdim)) 5780 new_dimension = dimension + (new_bdim <= dimension) 5781 bdims = (new_bdim,) * len(new_args) 5782 return (sort_p.bind(*new_args, 
dimension=new_dimension, is_stable=is_stable, num_keys=num_keys), 5783 bdims) 5784 5785 5786sort_p = Primitive('sort') 5787sort_p.multiple_results = True 5788sort_p.def_impl(partial(xla.apply_primitive, sort_p)) 5789sort_p.def_abstract_eval(_sort_abstract_eval) 5790xla.translations[sort_p] = _sort_translation_rule 5791ad.primitive_jvps[sort_p] = _sort_jvp 5792batching.primitive_batchers[sort_p] = _sort_batch_rule 5793 5794 5795def _top_k_abstract_eval(operand, *, k): 5796 if k < 0: 5797 raise ValueError("k argument to top_k must be nonnegative, got {}".format(k)) 5798 if len(operand.shape) == 0: 5799 raise TypeError("top_k operand must have >= 1 dimension, got {}" 5800 .format(operand.shape)) 5801 shape = list(operand.shape) 5802 if shape[-1] < k: 5803 msg = "k argument to top_k must be no larger than minor dimension; {} vs {}" 5804 raise ValueError(msg.format(k, shape)) 5805 shape[-1] = k 5806 return (ShapedArray(shape, operand.dtype), 5807 ShapedArray(shape, np.dtype(np.int32))) 5808 5809def _top_k_jvp(primals, tangents, *, k): 5810 operand, = primals 5811 tangent, = tangents 5812 primals_out = top_k(operand, k) 5813 if type(tangent) is ad_util.Zero: 5814 tangent_out = ad_util.Zero.from_value(primals_out[0]) 5815 else: 5816 _, k_idxs = primals_out 5817 idx_shape = k_idxs.shape 5818 rank = len(idx_shape) 5819 gather_index_shape = idx_shape + (1,) 5820 gather_indices = [] 5821 for i in range(rank-1): 5822 _iota = iota(k_idxs.dtype, idx_shape[i]) 5823 if not config.omnistaging_enabled: 5824 _iota = tie_in(operand, _iota) 5825 _iota = broadcast_in_dim(_iota, gather_index_shape, (i,)) 5826 gather_indices.append(_iota) 5827 gather_indices.append(reshape(k_idxs, gather_index_shape)) 5828 gather_indices = concatenate(gather_indices, dimension=rank) 5829 slice_sizes = (1,) * rank 5830 dnums = GatherDimensionNumbers( 5831 offset_dims=(), 5832 collapsed_slice_dims=tuple(range(rank)), 5833 start_index_map=tuple(range(rank))) 5834 tangent_out = gather(tangent, gather_indices, 
dnums, slice_sizes) 5835 return primals_out, (tangent_out, ad_util.Zero.from_value(primals_out[1])) 5836 5837def _top_k_batch_rule(batched_args, batch_dims, *, k): 5838 operand, = batched_args 5839 bdim, = batch_dims 5840 if bdim == operand.ndim-1: 5841 perm = np.arange(operand.ndim) 5842 perm[bdim-1], perm[bdim] = perm[bdim], perm[bdim-1] 5843 top_k_v, top_k_i = top_k(transpose(operand, perm), k=k) 5844 return (transpose(top_k_v, perm), 5845 transpose(top_k_i, perm)), (bdim, bdim) 5846 else: 5847 return top_k(operand, k=k), (bdim, bdim) 5848 5849top_k_p = Primitive('top_k') 5850top_k_p.multiple_results = True 5851top_k_p.def_impl(partial(xla.apply_primitive, top_k_p)) 5852top_k_p.def_abstract_eval(_top_k_abstract_eval) 5853xla.translations[top_k_p] = partial(standard_translate, 'top_k') 5854ad.primitive_jvps[top_k_p] = _top_k_jvp 5855batching.primitive_batchers[top_k_p] = _top_k_batch_rule 5856 5857def _stop_gradient_jvp_rule(primals, tangents): 5858 # if we don't call stop_gradient here, we'd only peel off one autodiff tracer 5859 x, = primals 5860 return stop_gradient(x), ad_util.Zero.from_value(x) 5861 5862def _stop_gradient_batch_rule(batched_args, batch_dims): 5863 x, = batched_args 5864 dim, = batch_dims 5865 return stop_gradient(x), dim 5866 5867ad.primitive_jvps[ad_util.stop_gradient_p] = _stop_gradient_jvp_rule 5868batching.primitive_batchers[ad_util.stop_gradient_p] = _stop_gradient_batch_rule 5869 5870 5871def create_token(_=None): 5872 """Creates an XLA token value with no preconditions for sequencing effects. 5873 5874 Experimental. 5875 5876 The argument is ignored. It exists for backward compatibility. 
5877 """ 5878 if config.omnistaging_enabled: 5879 return create_token_p.bind() 5880 else: 5881 x = _ 5882 if x is None: 5883 raise ValueError( 5884 'create_token needs a tie-in operand unless omnistaging is enabled.') 5885 return create_token_p.bind(stop_gradient(x)) 5886 5887create_token_p = Primitive("create_token") 5888create_token_p.def_impl(partial(xla.apply_primitive, create_token_p)) 5889create_token_p.def_abstract_eval(lambda *_: abstract_token) 5890xla.translations[create_token_p] = lambda c, *_: xops.CreateToken(c) 5891 5892def after_all(*operands): 5893 """Merges one or more XLA token values. Experimental. 5894 5895 Wraps the XLA AfterAll operator.""" 5896 return after_all_p.bind(*operands) 5897 5898def _after_all_abstract_eval(*operands): 5899 if any(x is not abstract_token for x in operands): 5900 raise TypeError("Arguments to after_all must be tokens") 5901 return abstract_token 5902 5903 5904def _after_all_translation_rule(c, *operands): 5905 return xops.AfterAll(c, operands) 5906 5907after_all_p = Primitive("after_all") 5908after_all_p.def_impl(partial(xla.apply_primitive, after_all_p)) 5909after_all_p.def_abstract_eval(_after_all_abstract_eval) 5910xla.translations[after_all_p] = _after_all_translation_rule 5911 5912 5913def infeed(token, shape=None, partitions=None): 5914 """Consumes an infeed value of `shape` from the host. Experimental. 5915 5916 `token` is used to sequence infeed and outfeed effects. 5917 `partitions` may be specified inside a `sharded_jit` function. 5918 """ 5919 flat_shapes, treedef = pytree.flatten(shape) 5920 for shape in flat_shapes: 5921 if not isinstance(shape, ShapedArray): 5922 raise TypeError("shape argument to infeed must be a pytree of " 5923 "ShapedArray values, got {}".format(shape)) 5924 if partitions is not None: 5925 # Always replicate token. 5926 # We specifically use type() to raise an error for PartitionSpecs. 
5927 if type(partitions) != tuple: # pylint: disable=unidiomatic-typecheck 5928 raise ValueError(f"'partitions' argument to infeed should be a tuple, " 5929 f"got {partitions}") 5930 partitions = partitions + (None,) 5931 xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes), 5932 partitions=partitions) 5933 return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1]) 5934 5935def _infeed_abstract_eval(token, *, shapes, partitions): 5936 if token is not abstract_token: 5937 raise TypeError("First argument to infeed must be a token") 5938 return shapes + (abstract_token,) 5939 5940 5941def _infeed_translation_rule(c, token, *, shapes, partitions): 5942 shape = tuple(shape.with_major_to_minor_layout_if_absent() 5943 for x in shapes for shape in xla.aval_to_xla_shapes(x)) 5944 build_infeed = partial(xops.InfeedWithToken, token, 5945 xla_client.Shape.tuple_shape(shape)) 5946 if partitions: 5947 xs_and_token = xb.with_sharding(c, partitions, build_infeed) 5948 else: 5949 # Note that infeed will default to replication if inside a sharded 5950 # computation and no sharding is specified. 5951 xs_and_token = build_infeed() 5952 xs = xops.GetTupleElement(xs_and_token, 0) 5953 token = xops.GetTupleElement(xs_and_token, 1) 5954 outs = [xops.GetTupleElement(xs, i) for i in range(len(shapes))] + [token] 5955 return xops.Tuple(c, outs) 5956 5957infeed_p = Primitive("infeed") 5958infeed_p.multiple_results = True 5959infeed_p.def_impl(partial(xla.apply_primitive, infeed_p)) 5960infeed_p.def_abstract_eval(_infeed_abstract_eval) 5961xla.translations[infeed_p] = _infeed_translation_rule 5962 5963def outfeed(token, xs): 5964 """Outfeeds value `xs` to the host. Experimental. 5965 5966 `token` is used to sequence infeed and outfeed effects. 
5967 """ 5968 flat_xs, _ = pytree.flatten(xs) 5969 return outfeed_p.bind(token, *flat_xs) 5970 5971def _outfeed_abstract_eval(token, *xs): 5972 if token is not abstract_token: 5973 raise TypeError("First argument to outfeed must be a token") 5974 return abstract_token 5975 5976 5977def _outfeed_translation_rule(c, token, *xs): 5978 t = xops.Tuple(c, xs) 5979 return xops.OutfeedWithToken(t, token, c.get_shape(t)) 5980 5981outfeed_p = Primitive("outfeed") 5982outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p)) 5983outfeed_p.def_abstract_eval(_outfeed_abstract_eval) 5984xla.translations[outfeed_p] = _outfeed_translation_rule 5985 5986def rng_uniform(a, b, shape): 5987 """Stateful PRNG generator. Experimental and its use is discouraged. 5988 5989 Returns uniformly distributed random numbers in the range [a, b) 5990 5991 You should use jax.random for most purposes; this function exists only for 5992 niche use cases with special performance requirements. 5993 5994 This API may be removed at any time. 5995 """ 5996 return rng_uniform_p.bind(a, b, shape=tuple(shape)) 5997 5998def _rng_uniform_abstract_eval(a, b, *, shape): 5999 if a.dtype != b.dtype: 6000 raise ValueError( 6001 "Arguments to rng_uniform must have identical dtypes, got {} " 6002 "and {}.".format(a.dtype, b.dtype)) 6003 if a.shape != () or b.shape != (): 6004 raise ValueError( 6005 "Arguments to rng_uniform must be scalars; got shapes {} and {}." 
def _dilate_shape(shape, dilation):
  """Utility function for computing the shape resulting from a dilation.

  Args:
    shape: sequence (tuple, list or ndarray) of dimension sizes.
    dilation: sequence of dilation factors. When it is shorter than `shape` it
      is left-padded with 1s so that it lines up with the trailing dimensions.

  Returns:
    A NumPy array holding, for every dimension, 0 if the input size is 0 and
    ``dilation * (size - 1) + 1`` otherwise.

  Raises:
    TypeError: if any dilation factor is not strictly positive.
  """
  if not np.all(np.greater(dilation, 0)):
    msg = "All dilations must be positive, got {}."
    raise TypeError(msg.format(dilation))
  dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation)
  # Use np.equal rather than `shape == 0`: for a tuple/list the `==` operator
  # yields a single Python bool (always False), which silently disabled the
  # zero-size guard and produced negative sizes for empty dimensions.
  return np.where(np.equal(shape, 0), 0,
                  np.multiply(dilation, np.subtract(shape, 1)) + 1)
def _check_same_dtypes(name, ignore_fp_precision, *ttypes):
  """Raise ``TypeError`` unless every entry of ``ttypes`` is the same dtype.

  Args:
    name: name of the calling operation, used only in the error message.
    ignore_fp_precision: if true, all real floating dtypes are treated as one
      dtype and all complex floating dtypes as another before comparing.
    *ttypes: dtype-like values; each is passed through ``np.dtype`` first.

  Raises:
    TypeError: if the (possibly coarsened) dtypes do not all canonicalize to
      a single dtype.
  """
  # the `ignore_fp_precision` flag exists because the XLA shape inference logic
  # allows mixed floating point precision, but the HLO verifier often rejects it
  def _drop_precision(dt):
    if dtypes.issubdtype(dt, np.floating):
      return np.floating
    if dtypes.issubdtype(dt, np.complexfloating):
      return np.complexfloating
    return dt

  types = [np.dtype(t) for t in ttypes]  # canonicalize
  if ignore_fp_precision:
    types = [_drop_precision(dt) for dt in types]

  distinct = {dtypes.canonicalize_dtype(t) for t in types}
  if len(distinct) == 1:
    return

  if ignore_fp_precision:
    msg = ("{} requires arguments to have same dtypes up to floating point "
           "precision, got {}.")
  else:
    msg = "{} requires arguments to have the same dtypes, got {}."
  raise TypeError(msg.format(name, ", ".join(map(str, types))))
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
  """Return the output shape of a convolution as a tuple.

  Both input shapes are read positionally: entries 0 and 1 are the two
  leading (non-spatial) dimensions and entries 2 onwards are spatial.

  Args:
    lhs_shape: shape of the convolution input.
    rhs_shape: shape of the convolution kernel.
    strides: one stride per spatial dimension.
    pads: either a padding-mode string (resolved with `padtype_to_pads`) or a
      sequence holding one (low, high) pair per spatial dimension.
    batch_group_count: divisor applied to ``lhs_shape[0]``; must divide it
      evenly.

  Returns:
    ``(lhs_shape[0] // batch_group_count, rhs_shape[0])`` followed by the
    spatial output sizes, each clamped below at 0.

  Raises:
    TypeError: if the number of pad pairs differs from the number of spatial
      dimensions.
  """
  if isinstance(pads, str):
    pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
  num_spatial = len(lhs_shape) - 2
  if len(pads) != num_spatial:
    msg = "Wrong number of explicit pads for convolution: expected {}, got {}."
    raise TypeError(msg.format(num_spatial, len(pads)))

  # Total (low + high) padding per spatial dimension.
  pad_totals = np.sum(np.array(pads).reshape(-1, 2), axis=1)
  padded_input = np.add(lhs_shape[2:], pad_totals)
  window_positions = np.floor_divide(
      np.subtract(padded_input, rhs_shape[2:]), strides) + 1
  spatial_out = np.maximum(0, window_positions)
  assert lhs_shape[0] % batch_group_count == 0
  leading = (lhs_shape[0] // batch_group_count, rhs_shape[0])
  return tuple(leading + tuple(spatial_out))
def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
  """Check that `obj` is a shape-like value (e.g. tuple of nonnegative ints).

  Args:
    fun_name: name of the calling function, used only in error messages.
    arg_name: name of the argument being checked, used only in error messages.
    obj: the value to validate; must be a tuple, list or ndarray.
    non_zero_shape: if true, every element must be at least 1 rather than at
      least 0.

  Returns:
    None. An empty `obj` is accepted without further checks.

  Raises:
    TypeError: if `obj` has the wrong container type, is not rank 1, is
      rejected by `canonicalize_shape`, or has an element below the bound.
  """
  if not isinstance(obj, (tuple, list, np.ndarray)):
    msg = "{} {} must be of type tuple/list/ndarray, got {}."
    raise TypeError(msg.format(fun_name, arg_name, type(obj)))
  # bool(obj) for an ndarray raises an error, so we check len
  if not len(obj):  # pylint: disable=g-explicit-length-test
    return
  obj_arr = np.array(obj)
  if obj_arr.ndim != 1:
    msg = "{} {} must be rank 1, got {}."
    # All three placeholders must be filled; supplying only the rank made
    # str.format raise IndexError instead of the intended TypeError.
    raise TypeError(msg.format(fun_name, arg_name, obj_arr.ndim))
  try:
    canonicalize_shape(obj_arr)
  except TypeError as err:
    msg = "{} {} must have every element be an integer type, got {}."
    raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj)))) from err
  lower_bound, bound_error = (
      (1, "strictly positive") if non_zero_shape else (0, "nonnegative"))
  if not (obj_arr >= lower_bound).all():
    msg = "{} {} must have every element be {}, got {}."
    raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))
def _canonicalize_precision(precision):
  """Validate a precision argument and return it unchanged.

  Accepted values are ``None``, a single ``Precision`` instance, or a tuple of
  exactly two ``Precision`` instances.

  Raises:
    ValueError: for any other value.
  """
  if precision is None:
    return None
  if isinstance(precision, Precision):
    return precision
  if (isinstance(precision, tuple) and len(precision) == 2
      and all(isinstance(p, Precision) for p in precision)):
    return precision
  raise ValueError("Precision argument must be None, a lax.Precision value "
                   f"or a tuple of two lax.Precision values; got {precision}")
def conv_general_permutations(dimension_numbers):
  """Utility for convolution dimension permutations relative to Conv HLO.

  Args:
    dimension_numbers: a triple ``(lhs_spec, rhs_spec, out_spec)`` of
      character sequences. The lhs and out specs must contain 'N' and 'C'
      exactly once each, the rhs spec 'O' and 'I'; every other character
      names a spatial dimension.

  Returns:
    A triple ``(lhs_perm, rhs_perm, out_perm)`` of index tuples. Each starts
    with the positions of the spec's two special characters and continues
    with the positions of its spatial characters, ordered as they appear in
    `rhs_spec`.

  Raises:
    TypeError: if a special character is missing or repeated, a spec has
      duplicate characters, or the specs disagree on the spatial characters.
  """
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  specs = (lhs_spec, rhs_spec, out_spec)
  charpairs = (("N", "C"), ("O", "I"), ("N", "C"))

  for position, (spec, (first, second)) in enumerate(zip(specs, charpairs)):
    if spec.count(first) != 1 or spec.count(second) != 1:
      msg = ("convolution dimension_numbers[{}] must contain the characters "
             "'{}' and '{}' exactly once, got {}.")
      raise TypeError(msg.format(position, first, second, spec))
    if len(set(spec)) != len(spec):
      msg = ("convolution dimension_numbers[{}] cannot have duplicate "
             "characters, got {}.")
      raise TypeError(msg.format(position, spec))

  lhs_spatial, rhs_spatial, out_spatial = (
      set(spec) - set(pair) for spec, pair in zip(specs, charpairs))
  if not (lhs_spatial == rhs_spatial == out_spatial):
    msg = ("convolution dimension_numbers elements must each have the same "
           "set of spatial characters, got {}.")
    raise TypeError(msg.format(dimension_numbers))

  def _perm_for(spec, pair):
    spatial_positions = [i for i, ch in enumerate(spec) if ch not in pair]
    if spec is not rhs_spec:
      # lhs/out spatial dimensions follow the order used by the kernel spec.
      spatial_positions.sort(key=lambda i: rhs_spec.index(spec[i]))
    head_char, second_char = pair
    return (spec.index(head_char), spec.index(second_char), *spatial_positions)

  lhs_perm = _perm_for(lhs_spec, charpairs[0])
  rhs_perm = _perm_for(rhs_spec, charpairs[1])
  out_perm = _perm_for(out_spec, charpairs[2])
  return lhs_perm, rhs_perm, out_perm
def _check_user_dtype_supported(dtype, fun_name=None):
  """Validate a user-supplied dtype and warn if it will be truncated.

  Args:
    dtype: the requested dtype. The Python scalar types ``bool``, ``int``,
      ``float`` and ``complex`` are accepted without further checks.
    fun_name: optional name of the calling function, added to the error and
      warning messages when given.

  Raises:
    TypeError: if the NumPy dtype kind is not one of "biufc" and the dtype is
      not bfloat16.

  Warns:
    If `dtype` is not None and differs from ``dtypes.canonicalize_dtype(dtype)``,
    a warning naming the dtype it will be truncated to.
  """
  # Avoid using `dtype in [...]` because of numpy dtype equality overloading.
  if isinstance(dtype, type) and dtype in {bool, int, float, complex}:
    return
  np_dtype = np.dtype(dtype)
  if np_dtype.kind not in "biufc" and np_dtype.type != dtypes.bfloat16:
    msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
    msg += f" in {fun_name}" if fun_name else ""
    raise TypeError(msg)
  if dtype is not None and np_dtype != dtypes.canonicalize_dtype(dtype):
    # The optional fragment carries its own leading space, so the message has
    # no doubled space when `fun_name` is not given.
    msg = ("Explicitly requested dtype {}{} is not available, "
           "and will be truncated to dtype {}. To enable more dtypes, set the "
           "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
           "environment variable. "
           "See https://github.com/google/jax#current-gotchas for more.")
    requested_in = f" requested in {fun_name}" if fun_name else ""
    truncated_dtype = dtypes.canonicalize_dtype(dtype).name
    warnings.warn(msg.format(dtype, requested_in, truncated_dtype), stacklevel=2)
6452 try: 6453 jax.lax.tie_in = tie_in 6454 except AttributeError: 6455 pass 6456 6457 def _tie_in_transpose_rule(t, x, y): 6458 if ad.is_undefined_primal(x): 6459 return [ad_util.Zero(x.aval), t] 6460 else: 6461 return [ad_util.Zero.from_value(x), t] 6462 6463 def _tie_in_batch_rule(batched_args, batch_dims): 6464 y = tie_in(*batched_args) 6465 _, bdim_y = batch_dims 6466 return y, bdim_y 6467 6468 def _tie_in_impl(x, y): 6469 core.check_valid_jaxtype(x) 6470 core.check_valid_jaxtype(y) 6471 return y 6472 6473 def _tie_in_jvp(primals, tangents): 6474 x, y = primals 6475 x_dot, y_dot = tangents 6476 if type(y_dot) is ad_util.Zero or core.get_aval(y_dot).dtype is dtypes.float0: 6477 return y, y_dot # skip tying in in this case 6478 else: 6479 return ad.linear_jvp(tie_in_p, primals, tangents) 6480 6481 tie_in_p.def_impl(_tie_in_impl) 6482 tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y)) 6483 xla.translations[tie_in_p] = lambda c, x, y: y 6484 ad.primitive_jvps[tie_in_p] = _tie_in_jvp 6485 ad.primitive_transposes[tie_in_p] = partial(ad.linear_transpose2, _tie_in_transpose_rule) 6486 batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule 6487 masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1] 6488