1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# Pytype is too slow to check this file.
16# pytype: skip-file
17
18import builtins
19from enum import IntEnum
20import functools
21import itertools
22import operator
23from typing import (Any, Callable, List, NamedTuple, Optional, Sequence,\
24                    Union, Tuple)
25import warnings
26
27import numpy as np
28
29import jax
30from jax import core
31from jax import ad_util
32from jax import api
33from jax import api_util
34from jax import linear_util as lu
35from jax import dtypes
36from jax import lazy
37from jax import tree_util
38from jax.config import flags, config
39from jax.core import (Primitive, _canonicalize_dimension, UnshapedArray,
40                      ShapedArray, ConcreteArray, raise_to_shaped,
41                      abstract_token, canonicalize_shape)
42from jax.abstract_arrays import array_types
43from jax.interpreters import partial_eval as pe
44from jax.interpreters import xla
45from jax.interpreters import pxla
46from jax.interpreters import ad
47from jax.interpreters import invertible_ad as iad
48from jax.interpreters import batching
49from jax.interpreters import masking
50from jax._src.util import (cache, safe_zip, partial, prod, safe_map,
51                           canonicalize_axis, split_list)
52from jax.tree_util import tree_map
53from jax.lib import pytree
54from jax.lib import xla_bridge
55from jax.lib import xla_client
56
# Short module aliases used throughout this file.
xb, xc = xla_bridge, xla_client
xops = xc.ops

FLAGS = flags.FLAGS

# Handles on the builtins: this module defines its own `max` and `min`
# functions below, which shadow the builtin names at module scope.
_max, _min = builtins.max, builtins.min
_reduce = functools.reduce

# Loose aliases used only for signature annotations.
Array = Any
DType = Any
Shape = Sequence[int]
70
71def _try_broadcast_shapes(shapes):
72  assert shapes
73  if len(shapes) == 1: return shapes[0]
74  rank, *others = {len(shape) for shape in shapes}
75  if others: return None  # must have consistent rank
76  if not rank: return ()  # scalar case
77  result_shape = [None] * rank
78  for i, sizes in enumerate(zip(*shapes)):
79    if sizes[:-1] == sizes[1:]:
80      result_shape[i] = sizes[0]  # all equal sizes for this dimension
81    else:
82      sizes = [d for d in sizes if d != 1]
83      if sizes[:-1] != sizes[1:]:
84        return None  # must have equal sizes other than 1-sized axes
85      result_shape[i] = sizes[0] if sizes else 1
86  return tuple(result_shape)
87
@cache()
def broadcast_shapes(*shapes):
  """Returns the shape that results from NumPy broadcasting of `shapes`.

  Lower-rank shapes are padded with leading 1-sized dimensions before being
  compared, as in NumPy. With no arguments the result is the scalar shape
  ``()``, matching ``np.broadcast_shapes()``.

  Args:
    *shapes: shapes given as tuples of dimension sizes.

  Returns:
    The broadcast shape as a tuple. When exactly one shape is given it is
    returned unchanged.

  Raises:
    ValueError: if the shapes are not broadcast-compatible.
  """
  if not shapes:
    return ()
  if len(shapes) == 1:
    return shapes[0]
  ndim = _max(len(shape) for shape in shapes)
  padded = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
  result_shape = _try_broadcast_shapes(padded)
  if result_shape is None:
    # Report the caller's shapes, not the internally rank-padded ones, so the
    # message matches what the user actually passed.
    raise ValueError("Incompatible shapes for broadcasting: {}"
                     .format(tuple(map(tuple, shapes))))
  return result_shape
100
101def _identity(x): return x
102
103### traceables
104
def neg(x: Array) -> Array:
  r"""Elementwise negation: :math:`-x`.

  Args:
    x: input array.

  Returns:
    The elementwise negation of ``x``.
  """
  negated = neg_p.bind(x)
  return negated
108
def sign(x: Array) -> Array:
  r"""Elementwise sign.

  Floating-point inputs map to
  :math:`\mathrm{sign}(x) = \begin{cases}
  -1 & x < 0\\
  -0 & x = -0\\
  \mathit{NaN} & x = \mathit{NaN}\\
  +0 & x = +0\\
  1 & x > 0
  \end{cases}`

  Signed integer inputs map to
  :math:`\mathrm{sign}(x) = \begin{cases}
  -1 & x < 0\\
  0 & x = 0\\
  1 & x > 0
  \end{cases}`

  Complex inputs map to their complex phase, i.e.
  :math:`\mathrm{sign}(x) = \frac{x}{|x|}`.
  """
  signs = sign_p.bind(x)
  return signs
132
def nextafter(x1: Array, x2: Array) -> Array:
  r"""Returns the next representable value after `x1` in the direction of `x2`.

  Note that some environments use flush-denormal-to-zero semantics. Around
  zero, this function then returns strictly non-zero (denormal) values which
  behave as zero in any subsequent operation. For example::

    >>> jnp.nextafter(0, 1)  # denormal numbers are representable
    DeviceArray(1.e-45, dtype=float32)
    >>> jnp.nextafter(0, 1) * 1  # but are flushed to zero
    DeviceArray(0., dtype=float32)

  For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.
  """
  start = _brcast(x1, x2)
  direction = _brcast(x2, x1)
  return nextafter_p.bind(start, direction)
147
def floor(x: Array) -> Array:
  r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`.

  Args:
    x: input array.
  """
  floored = floor_p.bind(x)
  return floored
151
def ceil(x: Array) -> Array:
  r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`.

  Args:
    x: input array.
  """
  ceiled = ceil_p.bind(x)
  return ceiled
155
class RoundingMethod(IntEnum):
  """Selects how ``round`` treats halfway values such as ``0.5``.

  ``round`` converts its ``rounding_method`` argument with
  ``RoundingMethod(...)``, so the plain integer values below are accepted too.
  """
  # Default mode of `round`; presumably halfway values go away from zero
  # (0.5 -> 1, -0.5 -> -1) as the name suggests -- confirm against lowering.
  AWAY_FROM_ZERO = 0
  # Presumably halfway values go to the nearest even integer
  # (0.5 -> 0, 1.5 -> 2) as the name suggests -- confirm against lowering.
  TO_NEAREST_EVEN = 1
159
def round(x: Array,
          rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO
          ) -> Array:
  r"""Elementwise round.

  Rounds each value to the nearest integer.

  Args:
    x: an array or scalar value to round.
    rounding_method: how halfway values (e.g., `0.5`) are rounded. Anything
      accepted by the ``lax.RoundingMethod`` constructor may be passed; see
      that enum for the available methods.

  Returns:
    An array containing the elementwise rounding of x.
  """
  return round_p.bind(x, rounding_method=RoundingMethod(rounding_method))
178
def is_finite(x: Array) -> Array:
  r"""Elementwise :math:`\mathrm{isfinite}`.

  An element is reported as `True` exactly when it is neither
  :math:`\pm\infty` nor :math:`\mathit{NaN}`.
  """
  finite_mask = is_finite_p.bind(x)
  return finite_mask
186
def exp(x: Array) -> Array:
  r"""Elementwise exponential: :math:`e^x`.

  Args:
    x: input array.
  """
  exponentiated = exp_p.bind(x)
  return exponentiated
190
def expm1(x: Array) -> Array:
  r"""Elementwise :math:`e^{x} - 1`.

  Args:
    x: input array.
  """
  shifted_exp = expm1_p.bind(x)
  return shifted_exp
194
def log(x: Array) -> Array:
  r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.

  Args:
    x: input array.
  """
  logarithm = log_p.bind(x)
  return logarithm
198
def log1p(x: Array) -> Array:
  r"""Elementwise :math:`\mathrm{log}(1 + x)`.

  Args:
    x: input array.
  """
  shifted_log = log1p_p.bind(x)
  return shifted_log
202
def tanh(x: Array) -> Array:
  r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.

  Args:
    x: input array.
  """
  result = tanh_p.bind(x)
  return result
206
def sin(x: Array) -> Array:
  r"""Elementwise sine: :math:`\mathrm{sin}(x)`.

  Args:
    x: input array.
  """
  result = sin_p.bind(x)
  return result
210
def cos(x: Array) -> Array:
  r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.

  Args:
    x: input array.
  """
  result = cos_p.bind(x)
  return result
214
def atan2(x: Array, y: Array) -> Array:
  r"""Elementwise arc tangent of two variables:
    :math:`\mathrm{atan}({x \over y})`.

  Note that ``x`` plays the role of the numerator and ``y`` the denominator.
  """
  angle = atan2_p.bind(x, y)
  return angle
219
def betainc(a: Array, b: Array, x: Array) -> Array:
  r"""Elementwise regularized incomplete beta integral.

  Args:
    a, b: the two shape parameters of the beta integral.
    x: the upper limit of integration.
  """
  integral = regularized_incomplete_beta_p.bind(a, b, x)
  return integral
223
def lgamma(x: Array) -> Array:
  r"""Elementwise log gamma: :math:`\mathrm{log}(\Gamma(x))`.

  Args:
    x: input array.
  """
  result = lgamma_p.bind(x)
  return result
227
def digamma(x: Array) -> Array:
  r"""Elementwise digamma: :math:`\psi(x)`.

  Args:
    x: input array.
  """
  result = digamma_p.bind(x)
  return result
231
def igamma(a: Array, x: Array) -> Array:
  r"""Elementwise regularized incomplete gamma function.

  Args:
    a: the shape parameter.
    x: the point at which the function is evaluated.
  """
  result = igamma_p.bind(a, x)
  return result
235
def igammac(a: Array, x: Array) -> Array:
  r"""Elementwise complementary regularized incomplete gamma function.

  Args:
    a: the shape parameter.
    x: the point at which the function is evaluated.
  """
  result = igammac_p.bind(a, x)
  return result
239
def igamma_grad_a(a: Array, x: Array) -> Array:
  r"""Elementwise derivative of the regularized incomplete gamma function.

  Args:
    a: the shape parameter.
    x: the point at which the derivative is evaluated.
  """
  result = igamma_grad_a_p.bind(a, x)
  return result
243
def random_gamma_grad(a: Array, x: Array) -> Array:
  r"""Elementwise derivative of samples from `Gamma(a, 1)`.

  Args:
    a: the gamma shape parameter.
    x: the samples.
  """
  result = random_gamma_grad_p.bind(a, x)
  return result
247
def bessel_i0e(x: Array) -> Array:
  r"""Exponentially scaled modified Bessel function of order 0:
  :math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)`
  """
  result = bessel_i0e_p.bind(x)
  return result
253
def bessel_i1e(x: Array) -> Array:
  r"""Exponentially scaled modified Bessel function of order 1:
  :math:`\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)`
  """
  result = bessel_i1e_p.bind(x)
  return result
259
def erf(x: Array) -> Array:
  r"""Elementwise error function: :math:`\mathrm{erf}(x)`.

  Args:
    x: input array.
  """
  result = erf_p.bind(x)
  return result
263
def erfc(x: Array) -> Array:
  r"""Elementwise complementary error function:
    :math:`\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)`."""
  result = erfc_p.bind(x)
  return result
268
def erf_inv(x: Array) -> Array:
  r"""Elementwise inverse error function: :math:`\mathrm{erf}^{-1}(x)`.

  Args:
    x: input array.
  """
  result = erf_inv_p.bind(x)
  return result
272
def real(x: Array) -> Array:
  r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`.

  Yields the real component of each complex element of ``x``.
  """
  real_part = real_p.bind(x)
  return real_part
279
def imag(x: Array) -> Array:
  r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`.

  Yields the imaginary component of each complex element of ``x``.
  """
  imag_part = imag_p.bind(x)
  return imag_part
286
def complex(x: Array, y: Array) -> Array:
  r"""Elementwise make complex number: :math:`x + jy`.

  Combines real parts ``x`` and imaginary parts ``y`` into complex numbers.
  """
  real_part = _brcast(x, y)
  imag_part = _brcast(y, x)
  return complex_p.bind(real_part, imag_part)
293
def conj(x: Array) -> Array:
  r"""Elementwise complex conjugate function: :math:`\overline{x}`.

  Args:
    x: input array.
  """
  source_dtype = _dtype(x)
  return conj_p.bind(x, input_dtype=source_dtype)
297
def abs(x: Array) -> Array:
  r"""Elementwise absolute value: :math:`|x|`.

  Args:
    x: input array.
  """
  magnitude = abs_p.bind(x)
  return magnitude
301
def pow(x: Array, y: Array) -> Array:
  r"""Elementwise power: :math:`x^y`.

  Args:
    x: the base.
    y: the exponent.
  """
  powered = pow_p.bind(x, y)
  return powered
305
def integer_pow(x: Array, y: int) -> Array:
  r"""Elementwise power: :math:`x^y`, where :math:`y` is a fixed integer.

  Args:
    x: the base array.
    y: the exponent, passed to the primitive as a static parameter.
  """
  powered = integer_pow_p.bind(x, y=y)
  return powered
309
def sqrt(x: Array) -> Array:
  r"""Elementwise square root: :math:`\sqrt{x}`.

  Args:
    x: input array.
  """
  root = sqrt_p.bind(x)
  return root
313
def rsqrt(x: Array) -> Array:
  r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`.

  Args:
    x: input array.
  """
  return rsqrt_p.bind(x)
317
def bitwise_not(x: Array) -> Array:
  r"""Elementwise NOT: :math:`\neg x`.

  Args:
    x: input array.
  """
  inverted = not_p.bind(x)
  return inverted
321
def bitwise_and(x: Array, y: Array) -> Array:
  r"""Elementwise AND: :math:`x \wedge y`.

  Args:
    x, y: the operands.
  """
  result = and_p.bind(x, y)
  return result
325
def bitwise_or(x: Array, y: Array) -> Array:
  r"""Elementwise OR: :math:`x \vee y`.

  Args:
    x, y: the operands.
  """
  result = or_p.bind(x, y)
  return result
329
def bitwise_xor(x: Array, y: Array) -> Array:
  r"""Elementwise exclusive OR: :math:`x \oplus y`.

  Args:
    x, y: the operands.
  """
  result = xor_p.bind(x, y)
  return result
333
def population_count(x: Array) -> Array:
  r"""Elementwise popcount: the number of set bits in each element.

  Args:
    x: input array.
  """
  bit_counts = population_count_p.bind(x)
  return bit_counts
337
def add(x: Array, y: Array) -> Array:
  r"""Elementwise addition: :math:`x + y`.

  Args:
    x, y: the operands.
  """
  total = add_p.bind(x, y)
  return total
341
def sub(x: Array, y: Array) -> Array:
  r"""Elementwise subtraction: :math:`x - y`.

  Args:
    x: the minuend.
    y: the subtrahend.
  """
  difference = sub_p.bind(x, y)
  return difference
345
def mul(x: Array, y: Array) -> Array:
  r"""Elementwise multiplication: :math:`x \times y`.

  Args:
    x, y: the operands.
  """
  product = mul_p.bind(x, y)
  return product
349
def div(x: Array, y: Array) -> Array:
  r"""Elementwise division: :math:`x \over y`.

  Args:
    x: the dividend.
    y: the divisor.
  """
  quotient = div_p.bind(x, y)
  return quotient
353
def rem(x: Array, y: Array) -> Array:
  r"""Elementwise remainder: :math:`x \bmod y`.

  Args:
    x: the dividend.
    y: the divisor.
  """
  remainder = rem_p.bind(x, y)
  return remainder
357
def max(x: Array, y: Array) -> Array:
  r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`

  Complex numbers are ordered lexicographically by their
  `(real, imaginary)` pairs."""
  larger = max_p.bind(x, y)
  return larger
364
def min(x: Array, y: Array) -> Array:
  r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`

  Complex numbers are ordered lexicographically by their
  `(real, imaginary)` pairs."""
  smaller = min_p.bind(x, y)
  return smaller
371
def shift_left(x: Array, y: Array) -> Array:
  r"""Elementwise left shift: :math:`x \ll y`.

  Args:
    x: the values to shift.
    y: the shift amounts.
  """
  shifted = shift_left_p.bind(x, y)
  return shifted
375
def shift_right_arithmetic(x: Array, y: Array) -> Array:
  r"""Elementwise arithmetic right shift: :math:`x \gg y`.

  Args:
    x: the values to shift.
    y: the shift amounts.
  """
  shifted = shift_right_arithmetic_p.bind(x, y)
  return shifted
379
def shift_right_logical(x: Array, y: Array) -> Array:
  r"""Elementwise logical right shift: :math:`x \gg y`.

  Args:
    x: the values to shift.
    y: the shift amounts.
  """
  shifted = shift_right_logical_p.bind(x, y)
  return shifted
383
def eq(x: Array, y: Array) -> Array:
  r"""Elementwise equals: :math:`x = y`.

  Args:
    x, y: the values to compare.
  """
  comparison = eq_p.bind(x, y)
  return comparison
387
def ne(x: Array, y: Array) -> Array:
  r"""Elementwise not-equals: :math:`x \neq y`.

  Args:
    x, y: the values to compare.
  """
  comparison = ne_p.bind(x, y)
  return comparison
391
def ge(x: Array, y: Array) -> Array:
  r"""Elementwise greater-than-or-equals: :math:`x \geq y`.

  Args:
    x, y: the values to compare.
  """
  comparison = ge_p.bind(x, y)
  return comparison
395
def gt(x: Array, y: Array) -> Array:
  r"""Elementwise greater-than: :math:`x > y`.

  Args:
    x, y: the values to compare.
  """
  comparison = gt_p.bind(x, y)
  return comparison
399
def le(x: Array, y: Array) -> Array:
  r"""Elementwise less-than-or-equals: :math:`x \leq y`.

  Args:
    x, y: the values to compare.
  """
  comparison = le_p.bind(x, y)
  return comparison
403
def lt(x: Array, y: Array) -> Array:
  r"""Elementwise less-than: :math:`x < y`.

  Args:
    x, y: the values to compare.
  """
  comparison = lt_p.bind(x, y)
  return comparison
407
def convert_element_type(operand: Array, new_dtype: DType) -> Array:
  """Elementwise cast.

  Wraps XLA's `ConvertElementType
  <https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
  operator, which performs an elementwise conversion from one type to another.
  Similar to a C++ `static_cast`.

  Casting complex values to a non-complex dtype emits a ``ComplexWarning``,
  since the imaginary part is discarded.

  Args:
    operand: an array or scalar value to be cast.
    new_dtype: the new type. Should be a NumPy type.

  Returns:
    An array with the same shape as `operand`, cast elementwise to `new_dtype`.
  """
  new_dtype = dtypes.canonicalize_dtype(new_dtype)
  # Avoids dropping precision by casting Python scalars to the default Jax
  # type. If we passed a Python scalar directly to the bind call below, it is
  # cast to the default type as part of the calling convention.
  if type(operand) in dtypes.python_scalar_dtypes:
    operand = np.asarray(operand, new_dtype)
  old_dtype = dtypes.canonicalize_dtype(_dtype(operand))
  if old_dtype == new_dtype:
    # Nothing to convert. Tracers and DeviceArrays are returned as-is; any
    # other value is handed to `_device_put_raw` as a NumPy array.
    if isinstance(operand, (core.Tracer, xla.DeviceArray)):
      return operand
    else:
      return _device_put_raw(np.asarray(operand))
  if (dtypes.issubdtype(old_dtype, np.complexfloating) and
      not dtypes.issubdtype(new_dtype, np.complexfloating)):
    # NumPy 2.0 removed the top-level `np.ComplexWarning` alias; the class has
    # lived in `np.exceptions` since NumPy 1.25. Prefer the new location and
    # fall back to the old one on older NumPy releases.
    complex_warning = (getattr(getattr(np, "exceptions", None),
                               "ComplexWarning", None)
                       or np.ComplexWarning)
    msg = "Casting complex values to real discards the imaginary part"
    warnings.warn(msg, complex_warning, stacklevel=2)
  return convert_element_type_p.bind(operand, new_dtype=new_dtype)
440
def bitcast_convert_type(operand: Array, new_dtype: DType) -> Array:
  """Elementwise bitcast.

  Wraps XLA's `BitcastConvertType
  <https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
  operator, which reinterprets the bits of each element as another type. The
  source and destination types must have the same bitwidth.

  Args:
    operand: an array or scalar value to be cast
    new_dtype: the new type. Should be a NumPy type.

  Returns:
    An array with the same shape as `operand`, bitcast elementwise to
    `new_dtype`. If `operand` already has that (canonical) dtype it is
    returned unchanged.
  """
  target_dtype = dtypes.canonicalize_dtype(new_dtype)
  if _dtype(operand) == target_dtype:
    return operand
  return bitcast_convert_type_p.bind(operand, new_dtype=target_dtype)
463
def clamp(min: Array, x: Array, max: Array) -> Array:
  r"""Elementwise clamp of ``x`` into the range given by ``min`` and ``max``.

  Returns :math:`\mathrm{clamp}(x) = \begin{cases}
  \mathit{min} & \text{if } x < \mathit{min},\\
  \mathit{max} & \text{if } x > \mathit{max},\\
  x & \text{otherwise}
  \end{cases}`.
  """
  clamped = clamp_p.bind(min, x, max)
  return clamped
474
def concatenate(operands: Sequence[Array], dimension: int) -> Array:
  """Joins a sequence of arrays end to end along `dimension`.

  Wraps XLA's `Concatenate
  <https://www.tensorflow.org/xla/operation_semantics#concatenate>`_
  operator.

  Args:
    operands: the arrays to join. Their shapes must agree everywhere except
      along the `dimension` axis.
    dimension: the axis along which the arrays are joined.

  Returns:
    An array containing the concatenation.
  """
  arrays = tuple(operands)
  return concatenate_p.bind(*arrays, dimension=dimension)
491
# Exposed as `lax.Precision`; its members (DEFAULT, HIGH, HIGHEST) are
# accepted by the `precision` arguments of the contraction ops below.
Precision = xla_client.PrecisionConfig.Precision
# Make str(<member>) print just the member name. This assigns onto the
# xla_client enum class itself, so it affects every user of that class.
Precision.__str__ = lambda precision: precision.name
# Loose alias: the enum type comes from xla_client.
PrecisionType = Any
# Accepted forms of a `precision` argument: None (backend default), a single
# Precision value, or a (lhs_precision, rhs_precision) pair.
PrecisionLike = Union[None, PrecisionType, Tuple[PrecisionType, PrecisionType]]
496
497
class ConvDimensionNumbers(NamedTuple):
  """Describes batch, spatial, and feature dimensions of a convolution.

  Args:
    lhs_spec: a tuple of nonnegative integer dimension numbers containing
      `(batch dimension, feature dimension, spatial dimensions...)`.
    rhs_spec: a tuple of nonnegative integer dimension numbers containing
      `(out feature dimension, in feature dimension, spatial dimensions...)`.
    out_spec: a tuple of nonnegative integer dimension numbers containing
      `(batch dimension, feature dimension, spatial dimensions...)`.
  """
  # Input layout: (batch, feature, *spatial) dimension numbers.
  lhs_spec: Sequence[int]
  # Kernel layout: (out feature, in feature, *spatial) dimension numbers.
  rhs_spec: Sequence[int]
  # Output layout: (batch, feature, *spatial) dimension numbers.
  out_spec: Sequence[int]
512
# Accepted forms of `dimension_numbers` for `conv_general_dilated`: None for
# the default layout, an explicit ConvDimensionNumbers, or a 3-tuple of layout
# strings (lhs_spec, rhs_spec, out_spec) such as ('NCHW', 'OIHW', 'NCHW').
ConvGeneralDilatedDimensionNumbers = Union[
  None, ConvDimensionNumbers, Tuple[str, str, str]]
515
def conv_general_dilated(
  lhs: Array, rhs: Array, window_strides: Sequence[int],
  padding: Union[str, Sequence[Tuple[int, int]]],
  lhs_dilation: Optional[Sequence[int]] = None,
  rhs_dilation: Optional[Sequence[int]] = None,
  dimension_numbers: ConvGeneralDilatedDimensionNumbers  = None,
  feature_group_count: int = 1, batch_group_count: int = 1,
  precision: PrecisionLike = None) -> Array:
  """General n-dimensional convolution operator, with optional dilation.

  Wraps XLA's `Conv
  <https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
  operator.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension. Each pair may be any sequence (e.g. a
      list); it is converted to a tuple.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
      a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
      of length `n+2`.
    feature_group_count: integer, default 1. See XLA HLO docs.
    batch_group_count: integer, default 1. See XLA HLO docs.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

  Returns:
    An array containing the convolution result.

  Raises:
    ValueError: if a string `padding` is combined with a non-trivial
      `lhs_dilation`.

  In the string case of `dimension_numbers`, each character identifies by
  position:

  - the batch dimensions in `lhs`, `rhs`, and the output with the character
    'N',
  - the feature dimensions in `lhs` and the output with the character 'C',
  - the input and output feature dimensions in rhs with the characters 'I'
    and 'O' respectively, and
  - spatial dimension correspondences between lhs, rhs, and the output using
    any distinct characters.

  For example, to indicate dimension numbers consistent with the `conv` function
  with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
  another example, to indicate dimension numbers consistent with the TensorFlow
  Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
  latter form of convolution dimension specification, window strides are
  associated with spatial dimension character labels according to the order in
  which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
  is matched with the dimension corresponding to the first character
  appearing in rhs_spec that is not `'I'` or `'O'`.

  If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
  (for a 2D convolution).
  """
  dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  if lhs_dilation is None:
    lhs_dilation = (1,) * (lhs.ndim - 2)
  elif isinstance(padding, str) and any(d != 1 for d in lhs_dilation):
    # Works for any iterable of factors (tuples, lists, NumPy arrays), unlike
    # `Sequence.count`, which NumPy arrays do not provide.
    raise ValueError(
        "String padding is not implemented for transposed convolution "
        "using this op. Please either exactly specify the required padding or "
        "use conv_transpose.")
  if rhs_dilation is None:
    rhs_dilation = (1,) * (rhs.ndim - 2)
  if isinstance(padding, str):
    # Resolve 'SAME'/'VALID' against the spatial dims of lhs and of the
    # dilated kernel, taken in the order given by the dimension numbers.
    lhs_perm, rhs_perm, _ = dnums
    rhs_shape = np.take(rhs.shape, rhs_perm)[2:]
    effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)]
    padding = padtype_to_pads(
        np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,
        window_strides, padding)
  # Normalize every (low, high) pair to a tuple so the primitive parameter is a
  # hashable tuple of tuples even when the caller passed nested lists.
  padding = tuple(tuple(pair) for pair in padding)
  return conv_general_dilated_p.bind(
      lhs, rhs, window_strides=tuple(window_strides), padding=padding,
      lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
      dimension_numbers=dnums,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      lhs_shape=lhs.shape, rhs_shape=rhs.shape,
      precision=_canonicalize_precision(precision))
606
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
        preferred_element_type: Optional[DType] = None) -> Array:
  """Vector/vector, matrix/vector, and matrix/matrix multiplication.

  Wraps XLA's `Dot
  <https://www.tensorflow.org/xla/operation_semantics#dot>`_
  operator.

  For more general contraction, see the `dot_general` operator.

  Args:
    lhs: an array of rank 1 or 2.
    rhs: an array of rank 1 or 2.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    An array containing the product.

  Raises:
    TypeError: if either operand is not of rank 1 or 2, or the last dimension
      of ``lhs`` differs from the first dimension of ``rhs``.
  """
  shapes_compatible = (1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2
                       and lhs.shape[-1] == rhs.shape[0])
  if not shapes_compatible:
    raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
        lhs.shape, rhs.shape))
  # Contract the last axis of lhs with the first axis of rhs; no batch axes.
  contracting = ((lhs.ndim - 1,), (0,))
  batch = ((), ())
  return dot_general(lhs, rhs, (contracting, batch), precision=precision,
                     preferred_element_type=preferred_element_type)
637
638
# `dimension_numbers` for `dot_general`, of the form
#   ((lhs_contracting_dims, rhs_contracting_dims),
#    (lhs_batch_dims, rhs_batch_dims))
DotDimensionNumbers = Tuple[Tuple[Sequence[int], Sequence[int]],
                            Tuple[Sequence[int], Sequence[int]]]
641
def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
                precision: PrecisionLike = None,
                preferred_element_type: Optional[DType] = None) -> Array:
  """More general contraction operator.

  Wraps XLA's `DotGeneral
  <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
  operator.

  Args:
    lhs: an array
    rhs: an array
    dimension_numbers: a tuple of tuples of the form
      `((lhs_contracting_dims, rhs_contracting_dims),
      (lhs_batch_dims, rhs_batch_dims))`
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype. The
      datatype is canonicalized (e.g. ``float64`` becomes ``float32`` when
      64-bit mode is disabled), as in ``convert_element_type``.

  Returns:
    An array containing the result.
  """
  contract_dims_seq, batch_dims_seq = dimension_numbers
  # Normalize the nested sequences to tuples so the primitive parameter is
  # hashable and has a single canonical form.
  contract_dims = tuple(map(tuple, contract_dims_seq))
  batch_dims = tuple(map(tuple, batch_dims_seq))
  if preferred_element_type is not None:
    # Accept dtype-likes (e.g. `np.float32`, 'float32') and canonicalize them
    # like the other dtype-taking ops in this module, so a 64-bit request does
    # not leak through when 64-bit mode is disabled.
    preferred_element_type = dtypes.canonicalize_dtype(
        np.dtype(preferred_element_type))
  return dot_general_p.bind(lhs, rhs,
                            dimension_numbers=(contract_dims, batch_dims),
                            precision=_canonicalize_precision(precision),
                            preferred_element_type=preferred_element_type)
675
def broadcast(operand: Array, sizes: Sequence[int]) -> Array:
  """Broadcasts an array by prepending new major dimensions.

  Wraps XLA's `Broadcast
  <https://www.tensorflow.org/xla/operation_semantics#broadcast>`_
  operator.

  Args:
    operand: an array
    sizes: the sizes of the new leading dimensions, in order.

  Returns:
    An array of shape ``tuple(sizes) + np.shape(operand)``.
  """
  num_new = len(sizes)
  # The operand's existing axes become the trailing axes of the result.
  kept_dims = tuple(range(num_new, num_new + np.ndim(operand)))
  out_shape = tuple(sizes) + np.shape(operand)
  return broadcast_in_dim(operand, out_shape, kept_dims)
693
def broadcast_in_dim(operand: Array, shape: Shape,
                     broadcast_dimensions: Sequence[int]) -> Array:
  """Wraps XLA's `BroadcastInDim
  <https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_
  operator.

  The arguments are validated by ``_broadcast_in_dim_shape_rule`` before the
  primitive is bound; see the XLA documentation above for the meaning of
  ``shape`` and ``broadcast_dimensions``.
  """
  out_shape = _broadcast_in_dim_shape_rule(
    operand, shape=shape, broadcast_dimensions=broadcast_dimensions)
  # Same rank and no mapped dimensions: hand back an existing DeviceArray or
  # Tracer unchanged instead of binding the primitive.
  is_noop = (np.ndim(operand) == len(out_shape)
             and len(broadcast_dimensions) == 0
             and isinstance(operand, (xla.DeviceArray, core.Tracer)))
  if is_noop:
    return operand
  return broadcast_in_dim_p.bind(
      operand, shape=tuple(out_shape),
      broadcast_dimensions=tuple(broadcast_dimensions))
708
def broadcast_to_rank(x: Array, rank: int) -> Array:
  """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``.

  ``np.ndim`` is used (as in ``broadcast``) rather than ``x.ndim``, so Python
  scalars and other array-likes without an ``ndim`` attribute are accepted.
  If ``x`` already has rank ``rank`` or more, no dimensions are added.
  """
  return broadcast(x, (1,) * (rank - np.ndim(x)))
712
def reshape(operand: Array, new_sizes: Shape,
            dimensions: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `Reshape
  <https://www.tensorflow.org/xla/operation_semantics#reshape>`_
  operator.

  For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
  ``lax.expand_dims``. These preserve information about axis identity that may
  be useful for advanced transformation rules.

  A non-scalar ``operand`` that already has shape ``new_sizes`` is returned
  as-is when ``dimensions`` is ``None`` or the identity order.
  """
  new_sizes = tuple(canonicalize_shape(new_sizes))
  current_shape = np.shape(operand)
  identity_order = (dimensions is None
                    or tuple(dimensions) == tuple(range(np.ndim(operand))))
  if current_shape and current_shape == new_sizes and identity_order:
    return operand
  # An identity `dimensions` order is dropped so equivalent calls bind the
  # primitive with the same parameters.
  return reshape_p.bind(
      operand, new_sizes=new_sizes,
      dimensions=None if identity_order else tuple(dimensions))
733
def pad(operand: Array, padding_value: Array,
        padding_config: Sequence[Tuple[int, int, int]]) -> Array:
  """Applies low, high, and/or interior padding to an array.

  Wraps XLA's `Pad
  <https://www.tensorflow.org/xla/operation_semantics#pad>`_
  operator.

  Args:
    operand: an array to be padded.
    padding_value: the value to be inserted as padding. Must have the same dtype
      as ``operand``.
    padding_config: a sequence of ``(low, high, interior)`` tuples of integers,
      giving the amount of low, high, and interior (dilation) padding to insert
      in each dimension. Each entry may be any sequence (e.g. a list); it is
      converted to a tuple.

  Returns:
    The ``operand`` array with padding value ``padding_value`` inserted in each
    dimension according to the ``padding_config``.
  """
  # Normalize every per-dimension entry to a tuple so the primitive parameter
  # is a hashable tuple of tuples even when callers pass nested lists.
  config = tuple(tuple(dim_config) for dim_config in padding_config)
  return pad_p.bind(operand, padding_value, padding_config=config)
755
def rev(operand: Array, dimensions: Sequence[int]) -> Array:
  """Wraps XLA's `Rev
  <https://www.tensorflow.org/xla/operation_semantics#rev_reverse>`_
  operator.

  Args:
    operand: the array to reverse.
    dimensions: the axes along which to reverse it.
  """
  axes = tuple(dimensions)
  return rev_p.bind(operand, dimensions=axes)
762
def select(pred: Array, on_true: Array, on_false: Array) -> Array:
  """Wraps XLA's `Select
  <https://www.tensorflow.org/xla/operation_semantics#select>`_
  operator.

  Args:
    pred: the predicate choosing between the two branches.
    on_true: values taken where ``pred`` is true.
    on_false: values taken where ``pred`` is false.
  """
  chosen = select_p.bind(pred, on_true, on_false)
  return chosen
769
def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `Slice
  <https://www.tensorflow.org/xla/operation_semantics#slice>`_
  operator.

  Args:
    operand: the array to slice.
    start_indices: static start index for each dimension.
    limit_indices: static limit index for each dimension.
    strides: optional static stride for each dimension.
  """
  starts = tuple(start_indices)
  limits = tuple(limit_indices)
  steps = None if strides is None else tuple(strides)
  return slice_p.bind(operand, start_indices=starts, limit_indices=limits,
                      strides=steps)
780
def dynamic_slice(operand: Array, start_indices: Sequence[Array],
                  slice_sizes: Shape) -> Array:
  """Wraps XLA's `DynamicSlice
  <https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
  operator.

  Args:
    operand: an array to slice.
    start_indices: a list of scalar indices, one per dimension. These values
      may be dynamic.
    slice_sizes: the size of the slice. Must be a sequence of non-negative
      integers with length equal to `ndim(operand)`. Inside a JIT compiled
      function, only static values are supported (all JAX arrays inside JIT
      must have statically known size).

  Returns:
    An array containing the slice.
  """
  index_operands = _dynamic_slice_indices(operand, start_indices)
  sizes = tuple(slice_sizes)
  return dynamic_slice_p.bind(operand, *index_operands, slice_sizes=sizes)
802
def dynamic_update_slice(operand: Array, update: Array,
                         start_indices: Array) -> Array:
  """Wraps XLA's `DynamicUpdateSlice
  <https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
  operator.

  Args:
    operand: the array to update.
    update: an array containing the new values to write onto `operand`,
      starting at `start_indices`.
    start_indices: a list of scalar indices, one per dimension. These values
      may be dynamic.

  Returns:
    An array with the same shape as `operand`, equal to `operand` except that
    the region starting at `start_indices` is overwritten with `update`.
  """
  start_indices = _dynamic_slice_indices(operand, start_indices)
  return dynamic_update_slice_p.bind(operand, update, *start_indices)
819
820
class GatherDimensionNumbers(NamedTuple):
  """
  Describes the dimension number arguments to an `XLA's Gather operator
  <https://www.tensorflow.org/xla/operation_semantics#gather>`_. See the XLA
  documentation for more details of what the dimension numbers mean.

  Args:
    offset_dims: the set of dimensions in the `gather` output that offset into
      an array sliced from `operand`. Must be a tuple of integers in ascending
      order, each representing a dimension number of the output.
    collapsed_slice_dims: the set of dimensions `i` in `operand` that have
      `slice_sizes[i] == 1` and that should not have a corresponding dimension
      in the output of the gather. Must be a tuple of integers in ascending
      order.
    start_index_map: for each dimension in `start_indices`, gives the
      corresponding dimension in `operand` that is to be sliced. Must be a
      tuple of integers with size equal to `start_indices.shape[-1]`.

  Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is
  implicit; there is always an index vector dimension and it must always be the
  last dimension. To gather scalar indices, add a trailing dimension of size 1.
  """
  # Output dims that index into each gathered slice (ascending).
  offset_dims: Sequence[int]
  # Operand dims with slice size 1 that are dropped from the output (ascending).
  collapsed_slice_dims: Sequence[int]
  # Operand dim addressed by each entry of an index vector in `start_indices`.
  start_index_map: Sequence[int]
846
847
def gather(operand: Array, start_indices: Array,
           dimension_numbers: GatherDimensionNumbers,
           slice_sizes: Shape) -> Array:
  """Gather operator.

  Thin wrapper around `XLA's Gather operator
  <https://www.tensorflow.org/xla/operation_semantics#gather>`_.

  Gather has intricate semantics and its API may change. In most cases
  `Numpy-style indexing
  <https://docs.scipy.org/doc/numpy-1.16.0/reference/arrays.indexing.html>`_
  (e.g., `x[:, (1,4,7), ...]`) is preferable to calling `gather` directly.

  Args:
    operand: the array to take slices from.
    start_indices: the positions at which the slices begin.
    dimension_numbers: a `lax.GatherDimensionNumbers` describing how the
      dimensions of `operand`, `start_indices` and the output correspond.
    slice_sizes: non-negative integers, one per dimension of `operand`, giving
      the size of every slice.

  Returns:
    The gathered array.
  """
  sizes = canonicalize_shape(slice_sizes)
  return gather_p.bind(operand, start_indices,
                       dimension_numbers=dimension_numbers, slice_sizes=sizes)
875
876
class ScatterDimensionNumbers(NamedTuple):
  """
  Describes the dimension number arguments to an `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_. See the XLA
  documentation for more details of what the dimension numbers mean.

  Args:
    update_window_dims: the set of dimensions in the `updates` that are window
      dimensions. Must be a tuple of integers in ascending
      order, each representing a dimension number.
    inserted_window_dims: the set of size 1 window dimensions that must be
      inserted into the shape of `updates`. Must be a tuple of integers in
      ascending order, each representing a dimension number of the output.
      These are the mirror image of `collapsed_slice_dims` in the case of
      `gather`.
    scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives
      the corresponding dimension in `operand`. Must be a sequence of integers
      with size equal to `scatter_indices.shape[-1]`.

  Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is
  implicit; there is always an index vector dimension and it must always be the
  last dimension. To scatter scalar indices, add a trailing dimension of size 1.
  """
  # Dimensions of `updates` that are window dimensions.
  update_window_dims: Sequence[int]
  # Size-1 window dimensions inserted into the shape of `updates`.
  inserted_window_dims: Sequence[int]
  # Index-vector position -> operand dimension it addresses.
  scatter_dims_to_operand_dims: Sequence[int]
902
def scatter_add(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers, *,
                indices_are_sorted: bool = False,
                unique_indices: bool = False) -> Array:
  """Scatter-add operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  addition is used to combine updates and values from `operand`.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the indices to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve
      performance on some backends.

  Returns:
    An array containing the sum of `operand` and the scattered updates.
  """
  # Trace `add` into a jaxpr over scalars of operand's dtype; 0 is the
  # scalar used to derive the abstract argument type.
  jaxpr, consts = _reduction_jaxpr(add, _abstractify(_const(operand, 0)))
  return scatter_add_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
937
def scatter_mul(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers, *,
                indices_are_sorted: bool = False,
                unique_indices: bool = False) -> Array:
  """Scatter-multiply operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  multiplication is used to combine updates and values from `operand`.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the indices to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve
      performance on some backends.

  Returns:
    An array containing the product of `operand` and the scattered updates.
  """
  # Trace `mul` into a jaxpr over scalars of operand's dtype.
  jaxpr, consts = _reduction_jaxpr(mul, _abstractify(_const(operand, 1)))
  return scatter_mul_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
972
def scatter_min(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers, *,
                indices_are_sorted: bool = False,
                unique_indices: bool = False) -> Array:
  """Scatter-min operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  the `min` function is used to combine updates and values from `operand`.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the indices to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve
      performance on some backends.

  Returns:
    An array containing the elementwise minimum of `operand` and the scattered
    updates.
  """
  # NOTE(review): the scalar 0 presumably only supplies the abstract
  # shape/dtype for tracing `min` (it is passed through `_abstractify`), not
  # an identity value — confirm.
  jaxpr, consts = _reduction_jaxpr(min, _abstractify(_const(operand, 0)))
  return scatter_min_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
1007
def scatter_max(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers, *,
                indices_are_sorted: bool = False,
                unique_indices: bool = False) -> Array:
  """Scatter-max operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  the `max` function is used to combine updates and values from `operand`.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the indices to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve
      performance on some backends.

  Returns:
    An array containing the elementwise maximum of `operand` and the scattered
    updates.
  """
  # NOTE(review): the scalar 0 presumably only supplies the abstract
  # shape/dtype for tracing `max` (it is passed through `_abstractify`), not
  # an identity value — confirm.
  jaxpr, consts = _reduction_jaxpr(max, _abstractify(_const(operand, 0)))
  return scatter_max_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
1042
1043# Define this outside of scatter to ensure cache hits.
1044_scatter_reduction_computation = lambda x, y: y
1045
def scatter(operand: Array, scatter_indices: Array, updates: Array,
            dimension_numbers: ScatterDimensionNumbers, *,
            indices_are_sorted: bool = False,
            unique_indices: bool = False) -> Array:
  """Scatter-update operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where updates
  replace values from `operand`.

  If multiple updates are performed to the same index of operand, they may be
  applied in any order.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the indices to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve
      performance on some backends.

  Returns:
    An array containing `operand` with the scattered updates written into it.
  """
  # `_scatter_reduction_computation` is a module-level function so repeated
  # calls reuse the cached jaxpr; it returns the update and drops the old value.
  jaxpr, consts = _reduction_jaxpr(_scatter_reduction_computation,
                                   _abstractify(_const(operand, 0)))
  return scatter_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
1084
def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array:
  """Gathers elements of ``src`` at integer indices ``idxs`` along ``axes``.

  Args:
    src: the array to index into.
    idxs: one index array per entry of ``axes``. Each is given a new axis 1
      and the results are concatenated along that axis, so they presumably
      must share a common leading shape — TODO confirm against callers.
    axes: the dimensions of ``src`` that ``idxs`` index into. May be any
      sequence of ints; it is converted to a tuple so that the resulting
      gather dimension numbers are hashable.

  Returns:
    The result of a `gather` of size-1 slices of ``src`` at those indices.
  """
  # Lists would leave unhashable fields in `GatherDimensionNumbers`, which is
  # passed as a primitive parameter.
  axes = tuple(axes)
  indices = concatenate([expand_dims(i, (1,)) for i in idxs], 1)
  # Reduce each index modulo its axis size so negative indices wrap around.
  indices = indices % np.array([src.shape[ax] for ax in axes])
  slice_sizes = list(src.shape)
  for ax in axes:
    slice_sizes[ax] = 1
  offset_dims = tuple(range(1, src.ndim - indices.shape[1] + 1))
  dnums = GatherDimensionNumbers(
      offset_dims=offset_dims,
      collapsed_slice_dims=axes,
      start_index_map=axes)
  return gather(src, indices, dimension_numbers=dnums,
                slice_sizes=tuple(slice_sizes))
1098
def transpose(operand: Array, permutation: Sequence[int]) -> Array:
  """Wraps XLA's `Transpose
  <https://www.tensorflow.org/xla/operation_semantics#transpose>`_
  operator.

  An identity ``permutation`` short-circuits and returns ``operand`` itself
  without binding any primitive.
  """
  perm = tuple(permutation)
  identity = tuple(range(len(perm)))
  if perm != identity:
    return transpose_p.bind(operand, permutation=perm)
  return operand
1109
def argmin(operand: Array, axis: int,
           index_dtype: DType) -> Array:
  """Computes the index of the minimum element along ``axis``.

  Args:
    operand: the array to search.
    axis: the single dimension to reduce over.
    index_dtype: dtype of the returned indices; canonicalized before use.

  Returns:
    A single array holding the index of the minimum along ``axis`` (not a
    tuple, despite what the previous annotation claimed).
  """
  return argmin_p.bind(operand, axes=(axis,),
                       index_dtype=dtypes.canonicalize_dtype(index_dtype))
1115
def argmax(operand: Array, axis: int,
           index_dtype: DType) -> Array:
  """Computes the index of the maximum element along ``axis``.

  Args:
    operand: the array to search.
    axis: the single dimension to reduce over.
    index_dtype: dtype of the returned indices; canonicalized before use.

  Returns:
    A single array holding the index of the maximum along ``axis`` (not a
    tuple, despite what the previous annotation claimed).
  """
  return argmax_p.bind(operand, axes=(axis,),
                       index_dtype=dtypes.canonicalize_dtype(index_dtype))
1121
def reduce(operands: Array, init_values: Array, computation: Callable,
           dimensions: Sequence[int]) -> Array:
  """Wraps XLA's `Reduce
  <https://www.tensorflow.org/xla/operation_semantics#reduce>`_
  operator.

  ``operands`` and ``init_values`` may be pytrees, but they must share the
  same tree structure. When there is exactly one operand and
  ``_get_monoid_reducer`` recognizes ``computation`` with its init value, the
  matching specialized reduction is used. Otherwise ``computation`` is traced
  into a jaxpr and ``reduce_p`` is bound; the result is unflattened with the
  output tree of the traced computation.

  Raises:
    ValueError: if the tree structures or leaf counts of ``operands`` and
      ``init_values`` differ.
  """
  flat_operands, operand_tree = tree_util.tree_flatten(operands)
  flat_init_values, init_value_tree = tree_util.tree_flatten(init_values)
  if operand_tree != init_value_tree:
    raise ValueError('Operands must have the same tree structure as init_values:'
                     f' {operand_tree} vs. {init_value_tree}')
  if len(flat_operands) != len(flat_init_values):
    raise ValueError('Must have same total number of operands as init_values: '
                     f' {len(flat_operands)} vs. {len(flat_init_values)}')

  fast_reducer = _get_monoid_reducer(computation, flat_init_values)
  if fast_reducer:
    return fast_reducer(*flat_operands, dimensions)

  init_avals = tuple(safe_map(_abstractify, flat_init_values))
  jaxpr, consts, out_tree = _variadic_reduction_jaxpr(
      computation, init_avals, init_value_tree)
  outs = reduce_p.bind(*flat_operands, *flat_init_values,
                       computation=computation, jaxpr=jaxpr, consts=consts,
                       dimensions=tuple(dimensions))
  return tree_util.tree_unflatten(out_tree, outs)
1146
@cache()
def _reduction_jaxpr(computation, aval):
  """Traces binary ``computation`` on two unknown values of type ``aval``.

  The result is memoized by ``cache()``, keyed on ``computation`` and ``aval``.

  Returns:
    A ``(jaxpr, consts)`` pair; the traced function returns a 1-tuple holding
    ``computation(x, y)``.
  """
  def wrapped(x, y):
    return (computation(x, y),)
  unknown = pe.PartialVal.unknown(aval)
  traced, _, consts = pe.trace_to_jaxpr(lu.wrap_init(wrapped),
                                        (unknown, unknown), instantiate=False)
  return traced, consts
1153
@cache()
def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
  """Traces a pytree-valued binary ``computation`` into a jaxpr.

  ``flat_avals`` is unflattened with ``aval_tree``. ``computation`` is traced
  with two arguments of that structure, using flattened unknown partial values.
  The result is memoized by ``cache()``.

  Returns:
    ``(jaxpr, consts, out_tree)``, where ``out_tree`` is the tree structure of
    ``computation``'s output.
  """
  avals = tree_util.tree_unflatten(aval_tree, flat_avals)
  # Both arguments of the computation have the same structure as the inits.
  flat_in_avals, in_tree = tree_util.tree_flatten((avals, avals))
  unknowns = tuple(pe.PartialVal.unknown(a) for a in flat_in_avals)
  flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(computation), in_tree)
  traced, _, consts = pe.trace_to_jaxpr(flat_fun, unknowns, instantiate=False)
  return traced, consts, out_tree_thunk()
1164
def _get_monoid_reducer(monoid_op: Callable, xs: Array) -> Optional[Callable]:
  """Picks a specialized reduction for a single concrete scalar init value.

  ``xs`` is the sequence of flat init values. A dedicated ``_reduce_*``
  function is returned only when there is exactly one init value, it is a
  concrete scalar, ``monoid_op`` is one of the recognized ops, and the value
  equals that op's identity. Otherwise a falsy value is returned: ``None``,
  or the NumPy boolean produced by the failed identity comparison.
  ``bitwise_or`` / ``bitwise_and`` are only recognized for boolean dtype.
  """
  if len(xs) != 1:
    return None
  (x,) = xs
  aval = core.get_aval(x)
  dtype = _dtype(x)
  if type(aval) is not ConcreteArray or aval.shape != ():
    return None
  value = aval.val
  if monoid_op is add:
    return np.equal(value, 0) and _reduce_sum
  if monoid_op is mul:
    return np.equal(value, 1) and _reduce_prod
  if monoid_op is bitwise_or and dtype == np.bool_:
    return np.equal(value, _get_max_identity(dtype)) and _reduce_or
  if monoid_op is bitwise_and and dtype == np.bool_:
    return np.equal(value, _get_min_identity(dtype)) and _reduce_and
  if monoid_op is max:
    return np.equal(value, _get_max_identity(dtype)) and _reduce_max
  if monoid_op is min:
    return np.equal(value, _get_min_identity(dtype)) and _reduce_min
  return None
1185
1186def _get_max_identity(dtype: DType) -> Array:
1187  if dtypes.issubdtype(dtype, np.inexact):
1188    return np.array(-np.inf, dtype)
1189  elif dtypes.issubdtype(dtype, np.integer):
1190    return np.array(dtypes.iinfo(dtype).min, dtype)
1191  elif dtypes.issubdtype(dtype, np.bool_):
1192    return np.array(False, np.bool_)
1193
1194def _get_min_identity(dtype: DType) -> Array:
1195  if dtypes.issubdtype(dtype, np.inexact):
1196    return np.array(np.inf, dtype)
1197  elif dtypes.issubdtype(dtype, np.integer):
1198    return np.array(dtypes.iinfo(dtype).max, dtype)
1199  elif dtypes.issubdtype(dtype, np.bool_):
1200    return np.array(True, np.bool_)
1201
def _reduce_sum(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_sum_p`` over ``axes`` (converted to a tuple)."""
  axes = tuple(axes)
  return reduce_sum_p.bind(operand, axes=axes)

def _reduce_prod(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_prod_p`` over ``axes`` (converted to a tuple)."""
  axes = tuple(axes)
  return reduce_prod_p.bind(operand, axes=axes)

def _reduce_max(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_max_p`` over ``axes`` (converted to a tuple)."""
  axes = tuple(axes)
  return reduce_max_p.bind(operand, axes=axes)

def _reduce_min(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_min_p`` over ``axes`` (converted to a tuple)."""
  axes = tuple(axes)
  return reduce_min_p.bind(operand, axes=axes)

def _reduce_or(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_or_p`` over ``axes`` (converted to a tuple)."""
  axes = tuple(axes)
  return reduce_or_p.bind(operand, axes=axes)

def _reduce_and(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_and_p`` over ``axes`` (converted to a tuple)."""
  axes = tuple(axes)
  return reduce_and_p.bind(operand, axes=axes)
1219
def reduce_window(operand: Array, init_value: Array, computation: Callable,
                  window_dimensions: Shape, window_strides: Sequence[int],
                  padding: Union[str, Sequence[Tuple[int, int]]],
                  base_dilation: Optional[Sequence[int]] = None,
                  window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `ReduceWindowWithGeneralPadding
  <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
  operator.

  Args:
    operand: the array whose windows are reduced.
    init_value: the initial value of each window's reduction.
    computation: binary function used to combine values within a window.
    window_dimensions: size of the window along each dimension.
    window_strides: stride of the window along each dimension.
    padding: either a string padding type, converted to explicit pads by
      ``padtype_to_pads`` using the (dilated) window shape, or a sequence of
      ``(low, high)`` pairs, one per dimension.
    base_dilation: dilation of ``operand``; defaults to 1 in every dimension.
    window_dilation: dilation of the window; defaults to 1 in every dimension.

  Returns:
    The result of the windowed reduction. When ``_get_monoid_window_reducer``
    recognizes ``computation`` with a concrete identity ``init_value``, a
    specialized reduce-window primitive is used instead of tracing
    ``computation``.
  """
  if isinstance(padding, str):
    dilated_window_dims = (window_dimensions if window_dilation is None else
                           _dilate_shape(window_dimensions, window_dilation))
    padding = tuple(padtype_to_pads(operand.shape, dilated_window_dims,
                                    window_strides, padding))
  else:
    # Normalize each (low, high) pair to a tuple so the `padding` param is
    # hashable even if the caller passed e.g. a list of lists.
    padding = tuple((lo, hi) for lo, hi in padding)
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  monoid_reducer = _get_monoid_window_reducer(computation, init_value)
  if monoid_reducer:
    return monoid_reducer(operand, window_dimensions, window_strides, padding,
                          base_dilation, window_dilation)
  else:
    jaxpr, consts = _reduction_jaxpr(computation, _abstractify(init_value))
    return reduce_window_p.bind(
        operand, init_value, jaxpr=jaxpr, consts=consts,
        window_dimensions=tuple(window_dimensions),
        window_strides=tuple(window_strides), padding=padding,
        base_dilation=tuple(base_dilation),
        window_dilation=tuple(window_dilation))
1252
def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
  """Picks a specialized reduce-window function for a concrete scalar init.

  A ``_reduce_window_*`` function is returned only when ``x`` is a concrete
  scalar equal to the identity of ``monoid_op`` (``add``, ``max`` or ``min``).
  Otherwise the result is falsy: ``None``, or the result of the failed
  identity comparison.
  """
  aval = core.get_aval(x)
  if type(aval) is not ConcreteArray or aval.shape != ():
    return None
  if monoid_op is add:
    return aval.val == 0 and _reduce_window_sum
  if monoid_op is max:
    return aval.val == _get_max_identity(aval.dtype) and _reduce_window_max
  if monoid_op is min:
    return aval.val == _get_min_identity(aval.dtype) and _reduce_window_min
  return None
1263
def _reduce_window_sum(operand: Array, window_dimensions: Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[Tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds ``reduce_window_sum_p``; missing dilations default to all ones."""
  base_dilation = ((1,) * len(window_dimensions) if base_dilation is None
                   else base_dilation)
  window_dilation = ((1,) * len(window_dimensions) if window_dilation is None
                     else window_dilation)
  params = dict(window_dimensions=tuple(window_dimensions),
                window_strides=tuple(window_strides),
                padding=tuple(padding),
                base_dilation=tuple(base_dilation),
                window_dilation=tuple(window_dilation))
  return reduce_window_sum_p.bind(operand, **params)
1278
def _reduce_window_prod(operand: Array, window_dimensions: Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]],
                        base_dilation: Optional[Sequence[int]] = None,
                        window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Product-reduces windows of ``operand`` via the generic ``reduce_window_p``.

  ``mul`` is traced into a jaxpr, and a scalar 1 of operand's dtype is used as
  the init value. Missing dilations default to all ones.
  """
  one = _const(operand, 1)
  mul_jaxpr, mul_consts = _reduction_jaxpr(mul, _abstractify(one))
  base_dilation = ((1,) * len(window_dimensions) if base_dilation is None
                   else base_dilation)
  window_dilation = ((1,) * len(window_dimensions) if window_dilation is None
                     else window_dilation)
  params = dict(window_dimensions=tuple(window_dimensions),
                window_strides=tuple(window_strides),
                padding=tuple(padding),
                base_dilation=tuple(base_dilation),
                window_dilation=tuple(window_dilation))
  return reduce_window_p.bind(operand, one, jaxpr=mul_jaxpr,
                              consts=mul_consts, **params)
1296
def _reduce_window_max(operand: Array, window_dimensions: Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[Tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds ``reduce_window_max_p``; missing dilations default to all ones."""
  base_dilation = ((1,) * len(window_dimensions) if base_dilation is None
                   else base_dilation)
  window_dilation = ((1,) * len(window_dimensions) if window_dilation is None
                     else window_dilation)
  params = dict(window_dimensions=tuple(window_dimensions),
                window_strides=tuple(window_strides),
                padding=tuple(padding),
                base_dilation=tuple(base_dilation),
                window_dilation=tuple(window_dilation))
  return reduce_window_max_p.bind(operand, **params)
1311
def _reduce_window_min(operand: Array, window_dimensions: Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[Tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds ``reduce_window_min_p``; missing dilations default to all ones."""
  base_dilation = ((1,) * len(window_dimensions) if base_dilation is None
                   else base_dilation)
  window_dilation = ((1,) * len(window_dimensions) if window_dilation is None
                     else window_dilation)
  params = dict(window_dimensions=tuple(window_dimensions),
                window_strides=tuple(window_strides),
                padding=tuple(padding),
                base_dilation=tuple(base_dilation),
                window_dilation=tuple(window_dilation))
  return reduce_window_min_p.bind(operand, **params)
1326
def _select_and_scatter(operand: Array, select: Callable,
                        window_dimensions: Shape, window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]], source: Array,
                        init_value: Array, scatter: Callable,
                        base_dilation: Sequence[int],
                        window_dilation: Sequence[int]) -> Array:
  """Binds ``select_and_scatter_p`` with traced ``select``/``scatter`` jaxprs.

  Both binary functions are traced with ``_reduction_jaxpr``, using the
  abstract value of ``init_value`` as their argument type.
  """
  init_aval = _abstractify(init_value)
  select_jaxpr, select_consts = _reduction_jaxpr(select, init_aval)
  scatter_jaxpr, scatter_consts = _reduction_jaxpr(scatter, init_aval)
  params = dict(window_dimensions=tuple(window_dimensions),
                window_strides=tuple(window_strides),
                padding=tuple(padding),
                base_dilation=tuple(base_dilation),
                window_dilation=tuple(window_dilation))
  return select_and_scatter_p.bind(
      operand, source, init_value,
      select_jaxpr=select_jaxpr, select_consts=select_consts,
      scatter_jaxpr=scatter_jaxpr, scatter_consts=scatter_consts, **params)
1342
def _select_and_scatter_add(source: Array, operand: Array,
                            select_prim: core.Primitive,
                            window_dimensions: Shape,
                            window_strides: Sequence[int],
                            padding: Sequence[Tuple[int, int]]) -> Array:
  """Binds ``select_and_scatter_add_p``; sequence params become tuples."""
  dims = tuple(window_dimensions)
  strides = tuple(window_strides)
  pads = tuple(padding)
  return select_and_scatter_add_p.bind(
      source, operand, select_prim=select_prim, window_dimensions=dims,
      window_strides=strides, padding=pads)
1352
def _select_and_gather_add(tangents: Array, operand: Array,
                           select_prim: core.Primitive,
                           window_dimensions: Shape,
                           window_strides: Sequence[int],
                           padding: Sequence[Tuple[int, int]],
                           base_dilation: Sequence[int],
                           window_dilation: Sequence[int]) -> Array:
  """Extracts the tangent corresponding to the minimum or maximum element in each
  window of the `operand` array.

  Wraps XLA's `ReduceWindow
  <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
  operator, which applies a reduction function to all elements in each window of the
  input multi-dimensional array. In this case, the input multi-dimensional array is
  built by packing each element in the `operand` array with its corresponding
  element in the `tangents` array.

  Args:
    tangents: an array
    operand: an array with the same shape as `tangents`
    select_prim: a reduction function (restricted to `ge_p` and `le_p`)
    window_dimensions: an array of integers for window dimension values
    window_strides: an array of integers for window stride values
    padding: a sequence of `(low, high)` padding pairs, one per dimension
    base_dilation: an array of integers for base dilation values
    window_dilation: an array of integers for window dilation values

  Returns:
    An array containing the elements in `tangents` corresponding to the output of the
    reduction of `operand` in each window.
  """
  return select_and_gather_add_p.bind(
      tangents, operand, select_prim=select_prim,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
1389
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
         is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
  """Wraps XLA's `Sort
  <https://www.tensorflow.org/xla/operation_semantics#sort>`_
  operator.

  Args:
    operand : a single array, or a sequence of arrays sorted together.
    dimension : the axis to sort along (negative values count from the end).
      Default: -1.
    is_stable : whether the sort must be stable. Default: True.
    num_keys : how many of the leading operands act as sort keys. Default: 1.
      With more than one key the order is lexicographic over the first
      `num_keys` arrays, the first being most significant; the remaining
      operands are permuted the same way.

  Returns:
    operand : the sorted array, or a tuple of sorted arrays when a sequence
    was given.

  Raises:
    TypeError: if an empty sequence is given.
    ValueError: if `num_keys` is out of range for the given operands.
  """
  if not isinstance(operand, Sequence):
    if num_keys != 1:
      raise ValueError(f"num_keys={num_keys} must equal 1 for a single operand.")
    axis = canonicalize_axis(dimension, len(operand.shape))
    results = sort_p.bind(operand, dimension=axis, is_stable=is_stable,
                          num_keys=1)
    return results[0]

  count = len(operand)
  if count == 0:
    raise TypeError("Sort requires at least one operand")
  if not (1 <= num_keys <= count):
    raise ValueError(f"num_keys={num_keys} must be between 1 and len(operand)={count}")
  axis = canonicalize_axis(dimension, len(operand[0].shape))
  results = sort_p.bind(*operand, dimension=axis, is_stable=is_stable,
                        num_keys=num_keys)
  return tuple(results)
1422
def sort_key_val(keys: Array, values: Array, dimension: int = -1,
                 is_stable: bool = True) -> Tuple[Array, Array]:
  """Sorts ``keys`` along ``dimension`` and applies same permutation to ``values``.

  ``keys`` is the only sort key. The result is the pair
  ``(sorted_keys, permuted_values)``.
  """
  axis = canonicalize_axis(dimension, len(keys.shape))
  sorted_keys, sorted_values = sort_p.bind(
      keys, values, dimension=axis, is_stable=is_stable, num_keys=1)
  return sorted_keys, sorted_values
1429
def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
  """Returns top ``k`` values and their indices along the last axis of ``operand``.

  ``k`` is coerced with ``int()`` first.

  Raises:
    ValueError: if ``k`` is negative.
  """
  k = int(k)
  if k >= 0:
    return top_k_p.bind(operand, k=k)
  raise ValueError(f"k argument to top_k must be nonnegative, got {k}")
1436
def tie_in(x: Array, y: Array) -> Array:
  """Deprecated no-op kept for backward compatibility.

  ``x`` is ignored entirely and ``y`` is returned unchanged.
  """
  del x  # Intentionally unused.
  return y
1440
def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Array:
  """Returns an array of `shape` filled with `fill_value`.

  Args:
    shape: sequence of integers, describing the shape of the output array.
    fill_value: the (scalar) value to fill the new array with.
    dtype: the type of the output array, or `None` to use the dtype of
      `fill_value`. Either way the dtype is canonicalized and `fill_value` is
      converted to it.

  Raises:
    TypeError: if `fill_value` is not a scalar.
  """
  shape = canonicalize_shape(shape)
  if np.shape(fill_value):
    msg = "full must be called with scalar fill_value, got fill_value.shape {}."
    raise TypeError(msg.format(np.shape(fill_value)))
  # Test for None explicitly: `dtype or ...` depends on the truthiness of dtype
  # objects, and `np.dtype` instances have historically been falsy (their
  # `__len__` is the number of fields), which silently dropped the dtype.
  if dtype is None:
    dtype = _dtype(fill_value)
  dtype = dtypes.canonicalize_dtype(dtype)
  fill_value = convert_element_type(fill_value, dtype)
  return broadcast(fill_value, shape)
1457
def _device_put_raw(x):
  """Returns ``x`` as a device array.

  A value that is already an ``xla.DeviceArray`` is returned untouched.
  Anything else is transferred with ``xla.device_put`` and wrapped by a
  result handler for its shaped abstract value.
  """
  if not isinstance(x, xla.DeviceArray):
    aval = raise_to_shaped(core.get_aval(x))
    handler = xla.array_result_handler(None, aval)
    x = handler(*xla.device_put(x))
  return x
1464
def iota(dtype: DType, size: int) -> Array:
  """Wraps XLA's `Iota
  <https://www.tensorflow.org/xla/operation_semantics#iota>`_
  operator.

  With omnistaging enabled, ``size`` must be concrete and ``iota_p`` is
  bound. Otherwise a lazily-evaluated ``xla._DeviceArray`` is built; in that
  case ``size`` may also be a ``masking.Poly``.
  """
  if not config.omnistaging_enabled:
    length = size if type(size) is masking.Poly else int(size)
    shape = canonicalize_shape((length,))
    dtype = dtypes.canonicalize_dtype(dtype)
    lazy_expr = lazy.iota(dtype, shape[0])
    return xla._DeviceArray(ShapedArray(shape, dtype), None, lazy_expr,
                            xla.DeviceConstant())
  dtype = dtypes.canonicalize_dtype(dtype)
  length = core.concrete_or_error(int, size, "size argument of lax.iota")
  return iota_p.bind(dtype=dtype, shape=(length,), dimension=0)
1481
def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
  """Convenience wrapper around ``iota`` for results of arbitrary rank.

  Binds ``iota_p`` with the given ``shape``, counting along ``dimension``,
  which must be concrete.
  """
  return iota_p.bind(
      dtype=dtypes.canonicalize_dtype(dtype),
      shape=canonicalize_shape(shape),
      dimension=core.concrete_or_error(
          int, dimension, "dimension argument of lax.broadcasted_iota"))
1489
def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
  """Like numpy.eye: a 2D array that is one where ``col == row + offset``."""
  n_rows, n_cols = tuple(map(int, shape))
  offset = int(offset)
  dtype = dtypes.canonicalize_dtype(dtype)
  if not config.omnistaging_enabled:
    lazy_expr = lazy.eye(dtype, (n_rows, n_cols), offset)
    return xla._DeviceArray(ShapedArray((n_rows, n_cols), dtype), None,
                            lazy_expr, xla.DeviceConstant())
  dims = (n_rows, n_cols)
  shifted_rows = add(broadcasted_iota(np.int32, dims, 0), np.int32(offset))
  cols = broadcasted_iota(np.int32, dims, 1)
  return convert_element_type_p.bind(eq(shifted_rows, cols), new_dtype=dtype)
1503
def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
  """Creates a Kronecker delta array.

  The result has shape ``shape``. It is one exactly where the indices along
  all of ``axes`` are equal and zero elsewhere, and it is broadcast (constant)
  along the remaining dimensions.

  Args:
    dtype: dtype of the result.
    shape: full output shape.
    axes: dimensions of ``shape`` whose indices are compared.
  """
  shape = tuple(map(int, shape))
  axes = tuple(map(int, axes))
  dtype = dtypes.canonicalize_dtype(dtype)
  base_shape = tuple(np.take(shape, axes))
  if config.omnistaging_enabled:
    iotas = [broadcasted_iota(np.uint32, base_shape, i)
             for i in range(len(base_shape))]
    eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
    if eyes:
      result = convert_element_type_p.bind(_reduce(operator.and_, eyes),
                                           new_dtype=dtype)
    else:
      # With fewer than two axes there are no index pairs to compare, so the
      # delta is vacuously one everywhere. `_reduce` would fail on an empty list.
      result = full(base_shape, 1, dtype)
    return broadcast_in_dim(result, shape, axes)
  else:
    lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
    aval = ShapedArray(shape, dtype)
    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
1520
def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
  """Like numpy.tri: a 2D array that is one where ``col <= row + offset``."""
  n_rows, n_cols = tuple(map(int, shape))
  offset = int(offset)
  dtype = dtypes.canonicalize_dtype(dtype)
  if not config.omnistaging_enabled:
    lazy_expr = lazy.tri(dtype, (n_rows, n_cols), offset)
    return xla._DeviceArray(ShapedArray((n_rows, n_cols), dtype), None,
                            lazy_expr, xla.DeviceConstant())
  dims = (n_rows, n_cols)
  shifted_rows = add(broadcasted_iota(np.int32, dims, 0), np.int32(offset))
  cols = broadcasted_iota(np.int32, dims, 1)
  return convert_element_type_p.bind(ge(shifted_rows, cols), new_dtype=dtype)
1534
def stop_gradient(x):
  """Stops gradient computation.

  Operationally ``stop_gradient`` is the identity function, that is, it returns
  argument `x` unchanged. However, ``stop_gradient`` prevents the flow of
  gradients during forward or reverse-mode automatic differentiation. If there
  are multiple nested gradient computations, ``stop_gradient`` stops gradients
  for all of them.

  ``x`` may be a pytree; only leaves with floating or complex dtypes are
  wrapped, and all other leaves are returned as-is.

  For example:

  >>> jax.grad(lambda x: x**2)(3.)
  array(6., dtype=float32)
  >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
  array(0., dtype=float32)
  >>> jax.grad(jax.grad(lambda x: x**2))(3.)
  array(2., dtype=float32)
  >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
  array(0., dtype=float32)
  """
  def _stop_leaf(leaf):
    leaf_dtype = _dtype(leaf)
    is_inexact = (dtypes.issubdtype(leaf_dtype, np.floating) or
                  dtypes.issubdtype(leaf_dtype, np.complexfloating))
    # Only bind the primitive on inexact dtypes, to avoid some staging.
    return ad_util.stop_gradient_p.bind(leaf) if is_inexact else leaf
  return tree_map(_stop_leaf, x)
1562
1563
1564### convenience wrappers around traceables
1565
1566
def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
         padding: str, precision: PrecisionLike = None) -> Array:
  """Convenience wrapper around `conv_general_dilated`.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'` or the string `'VALID'`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

  Returns:
    An array containing the convolution result.
  """
  return conv_general_dilated(lhs, rhs, window_strides, padding,
                              precision=precision)
1587
def conv_with_general_padding(lhs: Array, rhs: Array,
                              window_strides: Sequence[int],
                              padding: Union[str, Sequence[Tuple[int, int]]],
                              lhs_dilation: Optional[Sequence[int]],
                              rhs_dilation: Optional[Sequence[int]],
                              precision: PrecisionLike = None) -> Array:
  """Convenience wrapper around `conv_general_dilated`.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

  Returns:
    An array containing the convolution result.
  """
  return conv_general_dilated(
      lhs, rhs, window_strides, padding, lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation, precision=precision)
1621
1622
1623def _conv_transpose_padding(k, s, padding):
1624  """Calculate before and after padding for a dim of transposed convolution.
1625
1626  Args:
1627    k: int: kernel dimension.
1628    s: int: dimension stride value.
1629    padding: 'same' or 'valid' padding mode for original forward conv.
1630
1631  Returns:
1632    2-tuple: ints: before and after padding for transposed convolution.
1633  """
1634  if padding == 'SAME':
1635    pad_len = k + s - 2
1636    if s > k - 1:
1637      pad_a = k - 1
1638    else:
1639      pad_a = int(np.ceil(pad_len / 2))
1640  elif padding == 'VALID':
1641    pad_len = k + s - 2 + _max(k - s, 0)
1642    pad_a = k - 1
1643  else:
1644    raise ValueError('Padding mode must be `SAME` or `VALID`.')
1645  pad_b = pad_len - pad_a
1646  return pad_a, pad_b
1647
1648
1649def _flip_axes(x, axes):
1650  """Flip ndarray 'x' along each axis specified in axes tuple."""
1651  for axis in axes:
1652    x = np.flip(x, axis)
1653  return x
1654
1655
def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
                   padding: Union[str, Sequence[Tuple[int, int]]],
                   rhs_dilation: Optional[Sequence[int]] = None,
                   dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
                   transpose_kernel: bool = False,
                   precision: PrecisionLike = None) -> Array:
  """Convenience wrapper for calculating the N-d convolution "transpose".

  This function directly calculates a fractionally strided conv rather than
  indirectly calculating the gradient (transpose) of a forward convolution.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    strides: sequence of `n` integers, sets fractional stride.
    padding: 'SAME', 'VALID' will set as transpose of corresponding forward
      conv, or a sequence of `n` integer 2-tuples describing before-and-after
      padding for each `n` spatial dimension.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: tuple of dimension descriptors as in
      lax.conv_general_dilated. Defaults to tensorflow convention (channels
      last, e.g. ('NHWC', 'HWIO', 'NHWC') for rank-4 inputs); defaults exist
      only for ranks 2 through 5.
    transpose_kernel: if True flips spatial axes and swaps the input/output
      channel axes of the kernel. This makes the output of this function identical
      to the gradient-derived functions like keras.layers.Conv2DTranspose
      applied to the same kernel. For typical use in neural nets this is completely
      pointless and just makes input/output channel specification confusing.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

  Returns:
    Transposed N-d convolution, with output padding following the conventions of
    keras.layers.Conv2DTranspose.

  Raises:
    ValueError: if `dimension_numbers` is None and the inputs have rank > 5.
  """
  assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2
  ndims = len(lhs.shape)
  one = (1,) * (ndims - 2)
  # Set dimensional layout defaults if not specified.
  if dimension_numbers is None:
    if ndims == 2:
      dimension_numbers = ('NC', 'IO', 'NC')
    elif ndims == 3:
      dimension_numbers = ('NHC', 'HIO', 'NHC')
    elif ndims == 4:
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    elif ndims == 5:
      dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
    else:
      raise ValueError('No 4+ dimensional dimension_number defaults.')
  dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  # Kernel shape in (O, I, spatial...) order; keep only the spatial sizes.
  k_shape = np.take(rhs.shape, dn.rhs_spec)
  k_sdims = k_shape[2:]
  # Calculate correct output shape given padding and strides.
  pads: Union[str, Sequence[Tuple[int, int]]]
  if padding in {'SAME', 'VALID'}:
    if rhs_dilation is None:
      rhs_dilation = (1,) * (rhs.ndim - 2)
    # Spatial extent of each kernel dimension once dilation is applied.
    effective_k_size = map(lambda k, r: (k-1) * r + 1, k_sdims, rhs_dilation)
    pads = [_conv_transpose_padding(k, s, padding)
            for k,s in zip(effective_k_size, strides)]
  else:
    pads = padding
  if transpose_kernel:
    # flip spatial dims and swap input / output channel axes
    rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
    rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
  # The transpose is expressed as a unit-stride conv with the strides applied
  # as lhs dilation.
  return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn,
                              precision=precision)
1727
1728
def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None,
              shape: Optional[Shape] = None) -> Array:
  """Create a full array like np.full based on the example array `x`.

  Args:
    x: example array-like, used for shape and dtype information.
    fill_value: a scalar value to fill the entries of the output array.
    dtype: optional, a dtype parameter for the output ndarray. Defaults to the
      dtype of `x`.
    shape: optional, a shape parameter for the output ndarray. Defaults to the
      shape of `x`.

  Returns:
    An ndarray of shape `shape` (or the shape of `x` when `shape` is None)
    with its entries set equal to `fill_value`, similar to the output of
    np.full.
  """
  fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape)
  if not config.omnistaging_enabled:
    fill_value = tie_in(x, fill_value)
  # Explicit None check so an explicitly passed dtype is never replaced based
  # on the truthiness of the dtype object.
  if dtype is None:
    dtype = _dtype(x)
  return full(fill_shape, fill_value, dtype)
1747
1748
def collapse(operand: Array, start_dimension: int,
             stop_dimension: int) -> Array:
  """Merges a contiguous run of dimensions of ``operand`` into one.

  For example, if ``operand`` has shape ``(2, 3, 4)``, then
  ``collapse(operand, 0, 2)`` has shape ``(6, 4)``. Elements of the merged
  dimension are ordered major-to-minor, i.e. the lowest-numbered collapsed
  dimension varies slowest.

  Args:
    operand: an input array.
    start_dimension: first dimension to merge (inclusive).
    stop_dimension: end of the merged run (exclusive).

  Returns:
    ``operand`` reshaped so that dimensions
    ``[start_dimension, stop_dimension)`` form a single dimension whose size is
    their product.
  """
  dims = operand.shape
  merged_size = prod(dims[start_dimension:stop_dimension])
  new_shape = dims[:start_dimension] + (merged_size,) + dims[stop_dimension:]
  return reshape(operand, new_shape)
1771
1772
def slice_in_dim(operand: Array, start_index: Optional[int],
                 limit_index: Optional[int],
                 stride: int = 1, axis: int = 0)-> Array:
  """Convenience wrapper around slice applying to only one dimension.

  ``None`` bounds select from the start / through the end of ``axis``.
  A negative bound has the axis length added once; no further clamping is
  done. All other dimensions are kept whole with stride 1.
  """
  ndim = operand.ndim
  starts = [0] * ndim
  limits = list(operand.shape)
  steps = [1] * ndim

  axis_len = operand.shape[axis]
  lo = 0 if start_index is None else _canonicalize_dimension(start_index)
  hi = axis_len if limit_index is None else _canonicalize_dimension(limit_index)

  # Negative bounds count back from the end of the axis.
  if lo < 0:
    lo = lo + axis_len
  if hi < 0:
    hi = hi + axis_len

  axis = int(axis)
  starts[axis] = lo
  limits[axis] = hi
  steps[axis] = int(stride)

  return slice(operand, starts, limits, steps)
1798
1799
def index_in_dim(operand: Array, index: int, axis: int = 0,
                 keepdims: bool = True) -> Array:
  """Selects position ``index`` along ``axis`` via a static slice.

  Negative ``index`` values count from the end. When ``keepdims`` is False
  the indexed axis is squeezed away.

  Raises:
    IndexError: if ``index`` is out of range for ``axis``.
  """
  index, axis = int(index), int(axis)
  size = operand.shape[axis]
  pos = index if index >= 0 else index + size
  if pos < 0 or pos >= size:
    raise IndexError('index {} is out of bounds for axis {} with size {}'
                     .format(index, axis, size))
  sliced = slice_in_dim(operand, pos, pos + 1, 1, axis)
  return sliced if keepdims else squeeze(sliced, (axis,))
1814
1815
def dynamic_slice_in_dim(operand: Array, start_index: Array,
                         slice_size: int, axis: int = 0) -> Array:
  """Dynamically slices ``slice_size`` elements of ``axis`` from ``start_index``.

  Every other dimension starts at zero and is taken in full.
  """
  starts = [_zero(start_index)] * operand.ndim
  sizes = list(operand.shape)

  axis = int(axis)
  starts[axis] = start_index
  sizes[axis] = int(slice_size)
  return dynamic_slice(operand, starts, sizes)
1826
1827
def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0,
                         keepdims: bool = True) -> Array:
  """Selects a dynamic position along ``axis`` via a size-1 dynamic slice.

  When ``keepdims`` is False the indexed axis is squeezed away.
  """
  sliced = dynamic_slice_in_dim(operand, index, 1, axis)
  return sliced if keepdims else squeeze(sliced, (axis,))
1836
1837
def dynamic_update_slice_in_dim(operand: Array, update: Array,
                                start_index: Array, axis: int) -> Array:
  """Convenience wrapper around :func:`dynamic_update_slice` whose update is
     offset by ``start_index`` along ``axis`` and starts at zero elsewhere.
  """
  axis = int(axis)
  zero = _zero(start_index)
  offsets = [zero] * _ndim(operand)
  offsets[axis] = start_index
  return dynamic_update_slice(operand, update, offsets)
1847
1848
def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
                                axis: int) -> Array:
  """Convenience wrapper around :func:`dynamic_update_slice` that writes a
     size-1 slice at ``index`` along ``axis``.

  ``update`` may either have the same rank as ``operand`` or lack ``axis``
  entirely, in which case that axis is reinserted with ``expand_dims``.
  """
  axis = int(axis)
  if _ndim(update) == _ndim(operand):
    full_rank_update = update
  else:
    assert _ndim(update) + 1 == _ndim(operand)
    full_rank_update = expand_dims(update, (axis,))
  return dynamic_update_slice_in_dim(operand, full_rank_update, index, axis)
1859
1860
def batch_matmul(lhs: Array, rhs: Array,
                 precision: PrecisionLike = None) -> Array:
  """Batch matrix multiplication.

  Contracts the last axis of ``lhs`` with the second-to-last axis of ``rhs``.
  All leading axes are treated as batch dimensions.

  Raises:
    ValueError: if either argument has rank < 2 or the ranks differ.
  """
  lhs_ndim, rhs_ndim = lhs.ndim, rhs.ndim
  if _min(lhs_ndim, rhs_ndim) < 2:
    raise ValueError('Arguments to batch_matmul must be at least 2D, got {}, {}'
                     .format(lhs_ndim, rhs_ndim))
  if lhs_ndim != rhs_ndim:
    raise ValueError('Arguments to batch_matmul must have same ndim, got {}, {}'
                     .format(lhs_ndim, rhs_ndim))
  batch_dims = tuple(range(lhs_ndim - 2))
  contracting = ((lhs_ndim - 1,), (rhs_ndim - 2,))
  return dot_general(lhs, rhs, (contracting, (batch_dims, batch_dims)),
                     precision=precision)
1875
1876
1877# These functions also exist in the XLA client library, but we treat them
1878# as non-primitive to maintain a smaller set of autodiff primitives.
1879
def square(x: Array) -> Array:
  r"""Elementwise square, :math:`x^2`; computed as ``integer_pow(x, 2)``."""
  return integer_pow(x, 2)

def reciprocal(x: Array) -> Array:
  r"""Elementwise reciprocal, :math:`1 \over x`; computed as ``integer_pow(x, -1)``."""
  return integer_pow(x, -1)
1887
def _upcast_fp16_for_computation(f):
  """Decorator: evaluates unary ``f`` in float32 for float16/bfloat16 inputs.

  For those dtypes the argument is converted to float32, ``f`` is applied,
  and the result is converted back to the input dtype. Inputs of any other
  dtype are passed to ``f`` unchanged.
  """
  @functools.wraps(f)
  def f_wrapped(x):
    in_dtype = _dtype(x)
    if not (in_dtype == np.float16 or in_dtype == dtypes.bfloat16):
      return f(x)
    upcast = convert_element_type(x, np.float32)
    return convert_element_type(f(upcast), in_dtype)

  return f_wrapped
1898
def tan(x: Array) -> Array:
  r"""Elementwise tangent, :math:`\mathrm{tan}(x)`; binds ``tan_p``."""
  return tan_p.bind(x)

def asin(x: Array) -> Array:
  r"""Elementwise arc sine, :math:`\mathrm{asin}(x)`; binds ``asin_p``."""
  return asin_p.bind(x)

def acos(x: Array) -> Array:
  r"""Elementwise arc cosine, :math:`\mathrm{acos}(x)`; binds ``acos_p``."""
  return acos_p.bind(x)

def atan(x: Array) -> Array:
  r"""Elementwise arc tangent, :math:`\mathrm{atan}(x)`; binds ``atan_p``."""
  return atan_p.bind(x)

def sinh(x: Array) -> Array:
  r"""Elementwise hyperbolic sine, :math:`\mathrm{sinh}(x)`; binds ``sinh_p``."""
  return sinh_p.bind(x)

def cosh(x: Array) -> Array:
  r"""Elementwise hyperbolic cosine, :math:`\mathrm{cosh}(x)`; binds ``cosh_p``."""
  return cosh_p.bind(x)

def asinh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic sine, :math:`\mathrm{asinh}(x)`; binds ``asinh_p``."""
  return asinh_p.bind(x)

def acosh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic cosine, :math:`\mathrm{acosh}(x)`; binds ``acosh_p``."""
  return acosh_p.bind(x)

def atanh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic tangent, :math:`\mathrm{atanh}(x)`; binds ``atanh_p``."""
  return atanh_p.bind(x)
1934
1935
1936# Add some methods to ShapedArray that rely on lax primitives
1937
# Each lax function is wrapped with `core.aval_method` and attached to
# ShapedArray under the same name.
ShapedArray.broadcast = core.aval_method(broadcast)
ShapedArray.transpose = core.aval_method(transpose)  # clobbered by lax_numpy
ShapedArray.reshape = core.aval_method(reshape)      # clobbered by lax_numpy
1941
def _iter(tracer):
  """Iterates over the leading axis of ``tracer``.

  Installed below as ``ShapedArray._iter``. Element ``i`` is
  ``index_in_dim(tracer, i, keepdims=False)``. All elements are built eagerly
  into a list before the iterator is returned.

  Raises:
    TypeError: if ``tracer`` is 0-dimensional.
  """
  if tracer.ndim == 0:
    raise TypeError("iteration over a 0-d array")  # same as numpy error
  n = int(tracer.shape[0])
  return iter([index_in_dim(tracer, i, keepdims=False) for i in range(n)])
ShapedArray._iter = staticmethod(_iter)
1950
1951# Add some ad handlers that use (or could use) lax primitives
1952
def zeros_like_array(x):
  """Returns zeros shaped like ``x``, via ``full_like(x, 0)`` (dtype of ``x``)."""
  zero_fill = 0
  return full_like(x, zero_fill)
1955
# Register `add` in `ad_util.jaxval_adders` for Python scalar types, every
# type in `array_types`, and the device array classes listed here.
for t in itertools.chain(
    dtypes.python_scalar_dtypes.keys(), array_types,
    [xla._CppDeviceArray, xla._DeviceArray, pxla.ShardedDeviceArray]):
  ad_util.jaxval_adders[t] = add
# The device array classes also get `zeros_like_array` as their zeros-liker.
ad_util.jaxval_zeros_likers[xla._DeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[xla._CppDeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[pxla.ShardedDeviceArray] = zeros_like_array
1963
1964
1965### primitives
1966
1967
1968_input_dtype = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype)
1969_fixed_dtype = lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
1970_complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype
1971
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None,
                       multiple_results=False):
  """Creates and registers a `Primitive` driven by shape and dtype rules.

  The primitive's impl is `xla.apply_primitive` and its abstract eval is
  `standard_abstract_eval` with the given rules. Its XLA translation is
  ``translation_rule`` if that is truthy, otherwise `standard_translate` for
  ``name``.
  """
  prim = Primitive(name)
  prim.multiple_results = multiple_results
  prim.def_impl(partial(xla.apply_primitive, prim))
  prim.def_abstract_eval(
      partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
  xla.translations[prim] = (translation_rule if translation_rule
                            else partial(standard_translate, name))
  return prim
1980
1981
def standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs):
  """Abstract evaluation shared by primitives built with `standard_primitive`.

  Dispatches on the least specialized abstract value type among ``args``
  (by ``array_abstraction_level``):
    * ConcreteArray: run ``prim.impl`` on the concrete values and wrap each
      output as a ConcreteArray.
    * ShapedArray: build ShapedArrays from ``shape_rule`` and ``dtype_rule``.
    * UnshapedArray: build UnshapedArrays from ``dtype_rule`` alone.

  Returns:
    A single aval, or a list of avals when ``prim.multiple_results`` is set.

  Raises:
    TypeError: if the least specialized type is none of the above.
  """
  assert all(isinstance(arg, UnshapedArray) for arg in args), args
  least_specialized = _max(
      map(type, args), key=operator.attrgetter('array_abstraction_level'))
  if least_specialized is ConcreteArray:
    out_vals = prim.impl(*[x.val for x in args], **kwargs)
    if not prim.multiple_results:
      out_vals = [out_vals]
    out_avals = safe_map(ConcreteArray, out_vals)
  elif least_specialized is ShapedArray:
    # Named `out_dtypes` (not `dtypes`) so the `jax.dtypes` module is not
    # shadowed anywhere in this function.
    shapes, out_dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
    if not prim.multiple_results:
      shapes, out_dtypes = [shapes], [out_dtypes]
    out_avals = safe_map(ShapedArray, shapes, out_dtypes)
  elif least_specialized is UnshapedArray:
    out_dtypes = dtype_rule(*args, **kwargs)
    if not prim.multiple_results:
      out_dtypes = [out_dtypes]
    out_avals = safe_map(UnshapedArray, out_dtypes)
  else:
    raise TypeError(args, least_specialized)
  if not prim.multiple_results:
    return out_avals[0]
  return out_avals
2006
2007
def standard_translate(name, c, *args, **kwargs):
  """Lowers primitive ``name`` to the like-named ``xops`` builder function.

  The builder name is the CamelCase form of ``name`` (e.g. ``'next_after'``
  becomes ``NextAfter``). The computation builder ``c`` is not used.
  """
  op_name = ''.join(map(str.capitalize, name.split('_')))
  return getattr(xops, op_name)(*args, **kwargs)
2011
2012
def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):
  """Dtype rule for unary primitives.

  Returns ``result_dtype(aval.dtype)`` after checking that ``aval.dtype`` is a
  subtype of one of ``accepted_dtypes``.

  Raises:
    TypeError: if the operand dtype is not accepted.
  """
  in_dtype = aval.dtype
  if any(dtypes.issubdtype(in_dtype, t) for t in accepted_dtypes):
    return result_dtype(in_dtype)
  msg = '{} does not accept dtype {}. Accepted dtypes are subtypes of {}.'
  typename = str(np.dtype(in_dtype).name)
  accepted = ', '.join(t.__name__ for t in accepted_dtypes)
  raise TypeError(msg.format(name, typename, accepted))
2020
2021
def _attrgetter(name):
  """Returns a rule ``f(x, **kwargs)`` that reads attribute ``name`` of ``x``."""
  return lambda x, **kwargs: getattr(x, name)


def unop(result_dtype, accepted_dtypes, name, translation_rule=None):
  """Builds an elementwise unary primitive.

  The output shape is the operand's shape. The output dtype comes from
  `unop_dtype_rule`. The primitive is registered as vectorized for both
  batching and masking.
  """
  prim = standard_primitive(
      _attrgetter('shape'),
      partial(unop_dtype_rule, result_dtype, accepted_dtypes, name),
      name, translation_rule=translation_rule)
  batching.defvectorized(prim)
  masking.defvectorized(prim)
  return prim
standard_unop = partial(unop, _identity)
2031
2032
def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs):
  """Dtype rule shared by n-ary primitives built with `naryop`.

  For each position ``i``, checks that operand ``i``'s dtype is a subtype of
  one of ``accepted_dtypes[i]``. It then calls `_check_same_dtypes` on all
  operand dtypes and returns ``result_dtype(*avals)``.

  Raises:
    TypeError: if an operand's dtype is not accepted at its position. A
      dedicated message is used for float0 operands.
  """
  aval_dtypes = [aval.dtype for aval in avals]
  for i, (aval_dtype, types) in enumerate(zip(aval_dtypes, accepted_dtypes)):
    if not any(dtypes.issubdtype(aval_dtype, t) for t in types):
      # Compare by equality: a dtype equal to float0 need not be the same
      # object as `dtypes.float0`.
      if aval_dtype == dtypes.float0:
        raise TypeError(
            f"Called {name} with a float0 at position {i}. "
            "float0s do not support any operations by design, because they "
            "are not compatible with non-trivial vector spaces. No implicit dtype "
            "conversion is done. You can use np.zeros_like(arr, dtype=float) "
            "to cast a float0 array to a regular zeros array. \n"
            "If you didn't expect to get a float0 you might have accidentally "
            "taken a gradient with respect to an integer argument.")
      else:
        msg = ('{} does not accept dtype {} at position {}. '
               'Accepted dtypes at position {} are subtypes of {}.')
        typename = str(np.dtype(aval_dtype).name)
        typenames = ', '.join(t.__name__ for t in types)
        raise TypeError(msg.format(name, typename, i, i, typenames))
  _check_same_dtypes(name, False, *aval_dtypes)
  return result_dtype(*avals)
2054
2055
2056def _broadcasting_shape_rule(name, *avals):
2057  shapes = [aval.shape for aval in avals if aval.shape]
2058  if not shapes:
2059    return ()
2060  if len({len(shape) for shape in shapes}) != 1:
2061    msg = '{} got arrays of different rank: {}.'
2062    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
2063  result_shape = _try_broadcast_shapes(shapes)
2064  if result_shape is None:
2065    msg = '{} got incompatible shapes for broadcasting: {}.'
2066    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
2067  return result_shape
2068
def naryop(result_dtype, accepted_dtypes, name, translation_rule=None):
  """Builds an elementwise n-ary primitive.

  Shapes follow `_broadcasting_shape_rule` and dtypes follow
  `naryop_dtype_rule`. The primitive is registered for broadcasting under
  batching and as an n-ary op under masking.
  """
  prim = standard_primitive(
      partial(_broadcasting_shape_rule, name),
      partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name),
      name, translation_rule=translation_rule)
  batching.defbroadcasting(prim)
  masking.defnaryop(prim)
  return prim
standard_naryop = partial(naryop, _input_dtype)
2078
2079
def _broadcast_translate(translate: Callable):
  """Wraps a translation rule so its positional operands are explicitly broadcast.

  This is only needed for the few primitives whose XLA implementations do not
  broadcast on their own. Each operand whose shape differs from the common
  shape (from `broadcast_shapes`) is broadcast with `BroadcastInDim`. The
  operand's dimensions are aligned to the trailing dimensions of the result.
  """
  def _to_shape(operand, operand_shape, target_shape):
    if operand_shape == target_shape:
      return operand
    first = len(target_shape) - len(operand_shape)
    trailing_dims = tuple(range(first, len(target_shape)))
    return xops.BroadcastInDim(operand, target_shape, trailing_dims)

  def _broadcasted_translation_rule(c, *args, **kwargs):
    arg_shapes = [c.get_shape(a).dimensions() for a in args]
    out_shape = broadcast_shapes(*arg_shapes)
    widened = [_to_shape(a, s, out_shape) for a, s in zip(args, arg_shapes)]
    return translate(c, *widened, **kwargs)
  return _broadcasted_translation_rule
2099
2100# NOTE(mattjj): this isn't great for orchestrate fwd mode because it means JVPs
2101# get two extra ops in them: a reshape and a broadcast_in_dim (or sometimes just
2102# a broadcast). but saving the shape info with the primitives isn't great either
2103# because then we can't trace these ops without shape data.
2104def _brcast(x, *others):
2105  # Used in jvprules to make naryop broadcasting explicit for transposability.
2106  # Requires shape info during jvp tracing, which isn't strictly necessary.
2107  # We don't need full numpy broadcasting, but otherwise the logic is the same
2108  # so we reuse the broadcast_shapes function after filtering out scalars.
2109  shapes = tuple(filter(None, map(np.shape, (x,) + others)))
2110  shape = shapes and broadcast_shapes(*shapes)
2111  if np.shape(x) != shape:
2112    return _brcast_to(x, shape)
2113  else:
2114    return x
2115
2116
def _brcast_to(x, shape):
  """Broadcasts ``x`` to ``shape``, which must differ from ``x``'s shape.

  A scalar ``x`` is broadcast directly. Otherwise ``x`` must have the same
  rank as ``shape``. Dimensions where the sizes differ are squeezed out of
  ``x`` and then re-created by ``broadcast_in_dim``.
  """
  x_shape = np.shape(x)
  assert x_shape != shape
  if not x_shape:
    return broadcast(x, shape)
  assert len(x_shape) == len(shape)
  same = np.equal(x_shape, shape)
  kept_dims, = np.where(same)
  dropped_dims, = np.where(~same)
  return broadcast_in_dim(squeeze(x, dropped_dims), shape, kept_dims)
2128
2129
# Sets of NumPy abstract scalar types used as accepted-dtype specs for the
# primitive constructors below (e.g. `standard_unop(_num, 'neg')`).
_float = {np.floating}
_complex = {np.complexfloating}
# Not used in this section; presumably the real element types complex values
# are built from -- confirm at its use sites.
_complex_elem_types = {np.float32, np.float64}
_int = {np.integer}
_bool = {np.bool_}

_num = _int | _float | _complex
_any = _int | _float | _complex | _bool
_bool_or_int = _int | _bool
2139
# Negation over any numeric dtype. It is registered as linear; the cotangent
# rule maps t -> neg(t).
neg_p = standard_unop(_num, 'neg')
ad.deflinear2(neg_p, lambda t, operand: [neg(t)])
2142
def _sign_translation_rule(c, x):
  """XLA lowering for ``sign_p``.

  Unsigned integer operands are lowered explicitly: 0 maps to 0 and every
  other value maps to 1. All other dtypes use ``xops.Sign``.
  """
  shape = c.get_shape(x)
  dtype = shape.numpy_dtype()
  if not dtypes.issubdtype(dtype, np.unsignedinteger):
    return xops.Sign(x)
  # Reuse the shape fetched above instead of querying the builder again.
  dims = shape.dimensions()
  zero = xb.constant(c, np.array(0, dtype=dtype))
  one = xb.constant(c, np.array(1, dtype=dtype))
  return xops.Select(xops.Eq(x, zero), xops.Broadcast(zero, dims),
                     xops.Broadcast(one, dims))
2153
# sign: numeric dtypes, custom lowering for unsigned ints; JVP is zero.
sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)
ad.defjvp_zero(sign_p)

# nextafter: two floating operands. `_broadcast_translate` broadcasts them
# explicitly before lowering to the XLA `NextAfter` op.
nextafter_p = standard_naryop(
  [_float, _float], 'nextafter',
  translation_rule=_broadcast_translate(partial(standard_translate, 'next_after')))

# floor / ceil: floating dtypes only; JVPs are zero.
floor_p = standard_unop(_float, 'floor')
ad.defjvp_zero(floor_p)

ceil_p = standard_unop(_float, 'ceil')
ad.defjvp_zero(ceil_p)
2166
def _round_to_nearest_even(x):
  """Rounds ``x`` to the nearest integer, resolving exact .5 ties to even.

  Written with lax ops only. `_round_translation_rule` lowers it with
  ``xla.lower_fun``.
  """
  half = _const(x, 0.5)
  one = _const(x, 1)
  floored = floor(x)
  frac = x - floored
  # floor(x) - 2 * floor(x / 2) is 1 when floor(x) is odd and 0 when even.
  parity = sub(floored, mul(_const(x, 2), floor(mul(half, x))))
  floored_is_odd = eq(parity, one)
  # Round up when past the midpoint, or exactly at it with an odd floor.
  round_up = bitwise_or(gt(frac, half),
                        bitwise_and(eq(frac, half), floored_is_odd))
  return select(round_up, add(floored, one), floored)
2179
def _round_translation_rule(c, x, *, rounding_method):
  """XLA lowering for ``round_p``.

  ``AWAY_FROM_ZERO`` maps directly onto ``xops.Round``. Any other method is
  treated as ``TO_NEAREST_EVEN`` and lowered from `_round_to_nearest_even`.
  """
  if rounding_method is not RoundingMethod.AWAY_FROM_ZERO:
    to_nearest_even = xla.lower_fun(_round_to_nearest_even,
                                    multiple_results=False)
    return to_nearest_even(c, x)
  return xops.Round(x)
2186
# round: the default translation installed by standard_unop is replaced by
# `_round_translation_rule`, which handles the `rounding_method` param.
round_p = standard_unop(_float, 'round')
xla.translations[round_p] = _round_translation_rule
ad.defjvp_zero(round_p)

# is_finite: floating inputs, boolean output; JVP is zero.
is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
ad.defjvp_zero(is_finite_p)

# exp: JVP reuses the primal output (g * ans); inverse is log.
exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
iad.definverse(exp_p, lambda r, x: log(r))
# For exp_p it is more efficient to use the reconstructed output for the vjp
# rule instead of computing it again from the input.
iad.primitive_ivjps[exp_p] = lambda x, y, ct: [[log(y[0])], [ct[0] * y[0]]]

# log: JVP is g / x; inverse is exp.
log_p = standard_unop(_float | _complex, 'log')
ad.defjvp(log_p, lambda g, x: div(g, x))
iad.definverse(log_p, lambda r, x: exp(r))

# expm1: d/dx expm1(x) = exp(x) = ans + 1.
expm1_p = standard_unop(_float | _complex, 'expm1')
ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))

# log1p: JVP is g / (x + 1).
log1p_p = standard_unop(_float | _complex, 'log1p')
ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))

# tanh: JVP is g * (1 + ans) * (1 - ans), i.e. g * (1 - tanh(x)**2).
tanh_p = standard_unop(_float | _complex, 'tanh')
ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
                                         sub(_one(x), ans)))

# sin: JVP is g * cos(x).
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))

# cos: JVP is -(g * sin(x)).
cos_p = standard_unop(_float | _complex, 'cos')
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
2220
@partial(xla.lower_fun, multiple_results=False)
@_upcast_fp16_for_computation
def tan_translation_rule(x):
  """Lower `tan` as sin(x) / cos(x).

  The `_upcast_fp16_for_computation` decorator presumably evaluates low
  precision inputs in a wider float type; confirm against its definition.
  """
  numerator = sin(x)
  denominator = cos(x)
  return div(numerator, denominator)
2225
# d/dx tan(x) = 1 + tan(x)**2.
tan_p = standard_unop(_float | _complex, 'tan',
                       translation_rule=tan_translation_rule)
ad.defjvp(tan_p, lambda g, x: mul(g, _const(x, 1) + square(tan(x))))
2229
2230
@partial(xla.lower_fun, multiple_results=False)
def asin_translation_rule(x):
  """Lower `asin` in terms of other lax primitives.

  Complex inputs use asin(x) = -i * asinh(i * x). Real inputs use
  asin(x) = 2 * atan2(x, 1 + sqrt(1 - x**2)).
  """
  if not dtypes.issubdtype(_dtype(x), np.complexfloating):
    one = _const(x, 1)
    return mul(_const(x, 2),
               atan2(x, add(one, sqrt(sub(one, square(x))))))
  return mul(_const(x, -1j), asinh(mul(_const(x, 1j), x)))
2238
# d/dx asin(x) = 1 / sqrt(1 - x**2).
asin_p = standard_unop(_float | _complex, 'asin',
                       translation_rule=asin_translation_rule)
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
2242
2243
@partial(xla.lower_fun, multiple_results=False)
def acos_translation_rule(x):
  """Lower `acos` in terms of other lax primitives."""
  if not dtypes.issubdtype(_dtype(x), np.complexfloating):
    # acos(x) = 2 * atan2(sqrt(1 - x**2), 1 + x). At x == -1 both atan2
    # arguments are 0, so pi is substituted explicitly there.
    one = _const(x, 1)
    via_atan2 = mul(_const(x, 2),
                    atan2(sqrt(sub(one, square(x))), add(one, x)))
    return select(ne(x, _const(x, -1.0)), via_atan2, full_like(x, np.pi))
  candidate = mul(_const(x, 1j), acosh(x))
  # By convention, numpy chooses the branch with positive real part.
  real_part = real(candidate)
  return select(gt(real_part, _const(real_part, 0)),
                candidate,
                neg(candidate))
2261
# d/dx acos(x) = -1 / sqrt(1 - x**2).
acos_p = standard_unop(_float | _complex, 'acos',
                       translation_rule=acos_translation_rule)
ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
2265
@partial(xla.lower_fun, multiple_results=False)
def atan_translation_rule(x):
  """Lower `atan` in terms of other lax primitives.

  Complex inputs use atan(x) = -i * atanh(i * x). Real inputs use
  atan(x) = atan2(x, 1).
  """
  if not dtypes.issubdtype(_dtype(x), np.complexfloating):
    return atan2(x, _const(x, 1))
  return mul(_const(x, -1j), atanh(mul(_const(x, 1j), x)))
2272
# d/dx atan(x) = 1 / (1 + x**2).
atan_p = standard_unop(_float | _complex, 'atan',
                       translation_rule=atan_translation_rule)
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))

# d atan2(x, y) / dx = y / (x**2 + y**2);  d atan2(x, y) / dy = -x / (x**2 + y**2).
atan2_p = standard_naryop([_float, _float], 'atan2')
ad.defjvp(atan2_p,
  lambda g, x, y: _brcast(g, y) * (y / (square(x) + square(y))),
  lambda g, x, y: _brcast(g, x) * -x / (square(x) + square(y)))

sinh_p = standard_unop(_float | _complex, 'sinh')
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))

cosh_p = standard_unop(_float | _complex, 'cosh')
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))

# d/dx asinh(x) = 1 / sqrt(x**2 + 1).
asinh_p = standard_unop(_float | _complex, 'asinh')
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))

# d/dx acosh(x) = 1 / sqrt((x - 1) * (x + 1)).
acosh_p = standard_unop(_float | _complex, 'acosh')
ad.defjvp(acosh_p,
          lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))

# d/dx atanh(x) = 1 / ((1 + x) * (1 - x)).
atanh_p = standard_unop(_float | _complex, 'atanh')
ad.defjvp(atanh_p,
          lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))

# Regularized incomplete beta function I_x(a, b), lowered through the
# broadcasting wrapper around the builder op of the same name.
regularized_incomplete_beta_p = standard_naryop(
    [_float, _float, _float], 'regularized_incomplete_beta',
    translation_rule=_broadcast_translate(
      partial(standard_translate, 'regularized_incomplete_beta')))
2303
def betainc_gradx(g, a, b, x):
  """JVP of regularized_incomplete_beta with respect to `x`.

  The derivative x**(a-1) * (1-x)**(b-1) / B(a, b) is evaluated in log space
  for stability and then scaled by the tangent `g`.
  """
  log_beta = lgamma(a) + lgamma(b) - lgamma(a + b)
  log_density = (b - 1) * log1p(-x) + (a - 1) * log(x) - log_beta
  return exp(log_density) * g
2309
def betainc_grad_not_implemented(g, a, b, x):
  """JVP placeholder for the `a` and `b` operands of betainc.

  Raises:
    ValueError: always, because these derivatives are not implemented.
  """
  message = "Betainc gradient with respect to a and b not supported."
  raise ValueError(message)
2312
# Only the derivative w.r.t. x is available; differentiating w.r.t. a or b
# raises ValueError.
ad.defjvp(regularized_incomplete_beta_p,
  betainc_grad_not_implemented,
  betainc_grad_not_implemented,
  betainc_gradx)

# d/dx lgamma(x) = digamma(x).
lgamma_p = standard_unop(_float, 'lgamma')
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))

# No JVP rule is registered for digamma here.
digamma_p = standard_unop(_float, 'digamma')

igamma_p = standard_naryop(
  [_float, _float], 'igamma',
  translation_rule=_broadcast_translate(partial(standard_translate, 'igamma')))
# Derivative of igamma w.r.t. its first argument `a` (used by igamma_grada).
igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a',
  translation_rule=_broadcast_translate(partial(standard_translate,
                                               'igamma_grad_a')))
2329
def igamma_gradx(g, a, x):
  """JVP of igamma w.r.t. `x`: x**(a-1) * exp(-x) / Gamma(a), via log space."""
  tangent = _brcast(g, a, x)
  log_density = -x + (a - _ones(a)) * log(x) - lgamma(a)
  return tangent * exp(log_density)
2332
def igamma_grada(g, a, x):
  """JVP of igamma w.r.t. `a`, delegating to the igamma_grad_a primitive."""
  tangent = _brcast(g, a, x)
  return tangent * igamma_grad_a(a, x)
2335
ad.defjvp(igamma_p, igamma_grada, igamma_gradx)

# Upper regularized incomplete gamma; its JVPs are the negated igamma JVPs.
igammac_p = standard_naryop(
  [_float, _float], 'igammac',
  translation_rule=_broadcast_translate(partial(standard_translate, 'igammac')))
2341
def igammac_gradx(g, a, x):
  """JVP of igammac w.r.t. `x`: the negation of igamma's x-JVP."""
  igamma_tangent = igamma_gradx(g, a, x)
  return -igamma_tangent
2344
def igammac_grada(g, a, x):
  """JVP of igammac w.r.t. `a`: the negation of igamma's a-JVP."""
  igamma_tangent = igamma_grada(g, a, x)
  return -igamma_tangent
2347
ad.defjvp(igammac_p, igammac_grada, igammac_gradx)

# Lowered through the broadcasting wrapper; no JVP rule is registered here.
random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad',
  translation_rule=_broadcast_translate(partial(standard_translate,
                                               'random_gamma_grad')))

# d/dx i0e(x) = i1e(x) - sign(x) * i0e(x), reusing the primal output `y`.
bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))

bessel_i1e_p = standard_unop(_float, 'bessel_i1e')
def _bessel_i1e_jvp(g, y, x):
  """JVP of bessel_i1e given the primal output `y` and input `x`.

  Uses d/dx i1e(x) = i0e(x) - i1e(x) * (sign(x) + 1/x). The 1/x term is
  singular at 0, so where |x| <= eps the input is replaced by eps and the
  derivative by its limiting value 0.5.
  """
  eps = dtypes.finfo(_dtype(x)).eps
  not_tiny = abs(x) > eps
  x_safe = select(not_tiny, x, full_like(x, eps))
  slope = bessel_i0e(x_safe) - y * (sign(x_safe) + reciprocal(x_safe))
  slope = select(not_tiny, slope, full_like(x, 0.5))
  return g * slope
ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp)

# d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2).
erf_p = standard_unop(_float, 'erf')
ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
                                  mul(g, exp(neg(square(x))))))

# d/dx erfc(x) = -2 / sqrt(pi) * exp(-x**2).
erfc_p = standard_unop(_float, 'erfc')
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
                                   mul(neg(g), exp(neg(square(x))))))

# d/dx erf_inv(x) = sqrt(pi) / 2 * exp(erf_inv(x)**2), reusing the output.
erf_inv_p = standard_unop(_float, 'erf_inv')
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
                                            mul(g, exp(square(ans)))))

# real and imag are linear. Their transposes rebuild a complex cotangent:
# real maps t to complex(t, 0); imag maps t to complex(0, -t).
real_p = unop(_complex_basetype, _complex, 'real')
ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])

imag_p = unop(_complex_basetype, _complex, 'imag')
ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))])

# Result dtype: the promotion of `dtype` with complex64.
_complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype
complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
                  'complex')
# Transpose of complex(x, y) sends t to (real(t), imag(-t)), consistent with
# the real/imag transposes above.
ad.deflinear2(complex_p, lambda t, *args: [real(t), imag(neg(t))])

# conj accepts real element types or complex inputs; the output dtype is the
# complex promotion of the input dtype.
conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
2391
def _conj_transpose_rule(t, x, *, input_dtype):
  """Transpose of `conj`.

  For a complex input the cotangent is conj(t). For a real input (which conj
  promoted to a complex output) only the real part of t flows back.
  """
  assert ad.is_undefined_primal(x)
  input_is_complex = dtypes.issubdtype(input_dtype, np.complexfloating)
  return [conj(t) if input_is_complex else real(t)]
2398
# The lowering ignores conj's params (e.g. input_dtype); XLA's Conj suffices.
xla.translations[conj_p] = lambda c, x, **kwargs: xops.Conj(x)
ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
ad.primitive_transposes[conj_p] = _conj_transpose_rule

# abs of a complex input yields its real base type.
abs_p = unop(_complex_basetype, _num, 'abs')
2404
def _abs_jvp_rule(g, ans, x):
  """JVP of `abs` given the primal output `ans`.

  Real inputs: +g where x >= 0, -g elsewhere. Complex inputs:
  real(g * conj(x) / |x|), where |x| == 0 is replaced by 1 to avoid dividing
  by zero.
  """
  if not _iscomplex(x):
    return select(ge(x, _zero(x)), g, neg(g))
  magnitude = _replace_zero(convert_element_type(ans, _dtype(x)))
  return _maybe_real(mul(g, div(_maybe_conj(x), magnitude)))
ad.defjvp2(abs_p, _abs_jvp_rule)
_maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
_maybe_real = lambda x: real(x) if _iscomplex(x) else x

# d/dx sqrt(x) = 0.5 / sqrt(x), reusing the primal output `ans`.
sqrt_p = standard_unop(_float | _complex, 'sqrt')
ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))

# d/dx x**-0.5 = -0.5 * x**-1.5 = -0.5 * rsqrt(x) / x. Reusing the primal
# output `ans` replaces a pow (an exp/log round trip) with a single division.
# Both forms agree at 0, inf, and negative reals, and on the principal branch
# for complex x (exp(-0.5 Log x) / x == exp(-1.5 Log x)).
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
ad.defjvp2(rsqrt_p,
           lambda g, ans, x:
           mul(g, mul(_const(x, -0.5), div(ans, x))))

pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')
2424
def _pow_jvp_lhs(g, ans, x, y):
  """JVP of pow w.r.t. the base: g * y * x**(y - 1).

  Where y == 0 the exponent is replaced by 1, so the Jacobian is exactly
  0 * x there. This avoids 0 * inf when x == 0.
  """
  exponent = select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))
  jacobian = mul(y, pow(x, exponent))
  return mul(_brcast(g, y), jacobian)
2428
def _pow_jvp_rhs(g, ans, x, y):
  """JVP of pow w.r.t. the exponent: g * log(x) * x**y (log(0) -> log(1))."""
  log_base = log(_replace_zero(x))
  return mul(_brcast(g, x), mul(log_base, ans))
2431
# Both pow JVP rules receive the primal output `ans` (defjvp2).
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
2433
2434
def _integer_pow_dtype_rule(x, *, y):
  """Dtype rule for integer_pow: the dtype of `x`.

  Raises:
    TypeError: if `x` has an integer dtype and `y` is negative.
  """
  dtype = unop_dtype_rule(_identity, _int | _float | _complex, 'integer_pow', x)
  negative_int_power = y < 0 and dtypes.issubdtype(dtype, np.integer)
  if negative_int_power:
    raise TypeError("Integers cannot be raised to negative powers, got "
                    f"integer_pow({x}, {y})")
  return dtype
2441
def _integer_pow_translation_rule(c, x, *, y):
  """Lower integer_pow by binary (square-and-multiply) exponentiation.

  y == 0 yields a broadcast constant 1. A negative y computes x**(-y) and then
  takes the reciprocal.
  """
  if y == 0:
    shape = c.get_shape(x)
    one = xb.constant(c, np.array(1, dtype=shape.numpy_dtype()))
    return xops.Broadcast(one, shape.dimensions())
  # Not `abs(y)`: the module-level `abs` may refer to the lax primitive.
  remaining = -y if y < 0 else y
  result = None
  power = x
  while remaining:
    if remaining & 1:
      result = power if result is None else xops.Mul(result, power)
    remaining >>= 1
    if remaining:
      power = xops.Mul(power, power)
  return xops.Reciprocal(result) if y < 0 else result
2458
def _integer_pow_jvp(g, x, *, y):
  """JVP of integer_pow: g * y * x**(y - 1), or zeros like `g` when y == 0."""
  if y == 0:
    return _zeros(g)
  return mul(g, mul(_const(x, y), integer_pow(x, y - 1)))
2461
# integer_pow keeps the operand's shape (read via _attrgetter('shape')); the
# exponent `y` is a static Python int parameter.
integer_pow_p = standard_primitive(
  _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow',
  translation_rule=_integer_pow_translation_rule)
batching.defvectorized(integer_pow_p)
masking.defvectorized(integer_pow_p)
ad.defjvp(integer_pow_p, _integer_pow_jvp)

# Replace exact zeros with ones, e.g. before a log or a division.
_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)

# Bitwise/logical ops on bools or ints have zero JVP.
not_p = standard_unop(_bool_or_int, 'not')
ad.defjvp_zero(not_p)

and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
ad.defjvp_zero(and_p)

or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
ad.defjvp_zero(or_p)

xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
ad.defjvp_zero(xor_p)

population_count_p = standard_unop(_int, 'population_count')
2484
def _add_transpose(t, x, y):
  """Transpose of add: the cotangent flows unchanged to both operands."""
  # Both x and y should morally be undefined primals (add is linear in each),
  # but zeros are sometimes instantiated for convenience, so this assertion
  # does not always hold and is left disabled:
  # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
  return [t] * 2
2490
# add is linear in each operand: each JVP is the (broadcast) tangent itself.
add_p = standard_naryop([_num, _num], 'add')
ad.defjvp(add_p, lambda g, x, y: _brcast(g, y), lambda g, x, y: _brcast(g, x))
ad.primitive_transposes[add_p] = _add_transpose
def _add_inverse(r, x, y):
  """Invert r = x + y: recover each operand from the result and the other."""
  return r - y, r - x
# Inverse rule used by invertible AD.
iad.definverse(add_p, _add_inverse)
2499
def _sub_transpose(t, x, y):
  """Transpose of sub: the cotangent t flows as (t, -t) to (x, y).

  A symbolic-zero cotangent yields symbolic zeros for undefined-primal
  operands and None for the others.
  """
  # The following linearity assertion is morally true, but because in some
  # cases we instantiate zeros for convenience, it doesn't always hold.
  # TODO(mattjj): re-enable this assertion, don't return None below
  # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
  if type(t) is not ad_util.Zero:
    return [t, neg(t)]
  zero_or_none = lambda v: (ad_util.Zero(v.aval)
                            if ad.is_undefined_primal(v) else None)
  return [zero_or_none(x), zero_or_none(y)]
2511
# d(x - y) = dx - dy.
sub_p = standard_naryop([_num, _num], 'sub')
ad.defjvp(sub_p,
          lambda g, x, y: _brcast(g, y),
          lambda g, x, y: _brcast(neg(g), x))
ad.primitive_transposes[sub_p] = _sub_transpose

# mul is bilinear; its JVP/transpose rules are derived from mul itself.
mul_p = standard_naryop([_num, _num], 'mul')
ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul)
def _mul_inverse(r, x, y):
  """Invert r = x * y: recover each operand by dividing by the other."""
  return r / y, r / x
# Inverse rule used by invertible AD.
iad.definverse(mul_p, _mul_inverse)
2525
def _div_transpose_rule(cotangent, x, y):
  """Transpose of div w.r.t. the numerator; the denominator must be known.

  Returns a (numerator cotangent, None) pair. A symbolic-zero cotangent stays
  symbolic.
  """
  assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
  if type(cotangent) is ad_util.Zero:
    return ad_util.Zero(x.aval), None
  return div(cotangent, y), None
# d(x / y) = dx / y - x * dy / y**2.
div_p = standard_naryop([_num, _num], 'div')
ad.defjvp(div_p,
          lambda g, x, y: div(_brcast(g, y), y),
          lambda g, x, y: mul(mul(neg(_brcast(g, x)), x), integer_pow(y, -2)))
ad.primitive_transposes[div_p] = _div_transpose_rule

# rem(x, y) = x - y * trunc(x / y): the result takes the sign of the dividend,
# as with C fmod. Hence d rem / dx = 1 and d rem / dy = -trunc(x / y).
# trunc(q) is written sign(q) * floor(|q|). A plain floor(q) would be off by
# one whenever x / y is negative and non-integral (e.g. x=-5, y=3).
rem_p = standard_naryop([_num, _num], 'rem')
ad.defjvp(rem_p,
          lambda g, x, y: _brcast(g, y),
          lambda g, x, y: mul(_brcast(neg(g), x),
                              mul(sign(div(x, y)), floor(abs(div(x, y))))))
2540
2541
def _broadcasting_select(c, which, x, y):
  """Wrapper around XLA `Select` that broadcasts its arguments.

  Each operand is broadcast to the common shape with its dimensions aligned
  to the trailing dimensions of that shape.
  """
  which_shape = c.get_shape(which).dimensions()
  x_shape = c.get_shape(x).dimensions()
  y_shape = c.get_shape(y).dimensions()
  out_shape = broadcast_shapes(which_shape, x_shape, y_shape)
  out_rank = len(out_shape)

  def _expand(operand, shape):
    # Map operand dims onto the last len(shape) dims of out_shape.
    trailing_dims = tuple(range(out_rank - len(shape), out_rank))
    return xops.BroadcastInDim(operand, out_shape, trailing_dims)

  which = _expand(which, which_shape)
  x = _expand(x, x_shape)
  y = _expand(y, y_shape)
  return xops.Select(which, x, y)
2553
2554
def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None):
  """Lower max/min.

  Non-complex operands use the XLA op `minmax` directly. Complex operands are
  ordered lexicographically: real parts are compared with `cmp`, falling back
  to the imaginary parts when the real parts are equal. The chosen operand is
  then selected with broadcasting.
  """
  dtype = c.get_shape(x).numpy_dtype()
  if not dtypes.issubdtype(dtype, np.complexfloating):
    return minmax(x, y)
  x_re = xops.Real(x)
  y_re = xops.Real(y)
  pick_x = xops.Select(xops.Eq(x_re, y_re),
                       cmp(xops.Imag(x), xops.Imag(y)),
                       cmp(x_re, y_re))
  return _broadcasting_select(c, pick_x, x, y)
2565
# max/min JVPs route the tangent through _balanced_eq (presumably splitting it
# between x and y on ties; see its definition).
max_p: core.Primitive = standard_naryop(
  [_any, _any], 'max', translation_rule=partial(
    _minmax_translation_rule, minmax=xops.Max, cmp=xops.Gt))
ad.defjvp2(max_p,
           lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
           lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))

min_p: core.Primitive = standard_naryop(
  [_any, _any], 'min', translation_rule=partial(
    _minmax_translation_rule, minmax=xops.Min, cmp=xops.Lt))
ad.defjvp2(min_p,
           lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
           lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))

# Integer shifts: zero JVP.
shift_left_p = standard_naryop([_int, _int], 'shift_left')
ad.defjvp_zero(shift_left_p)

shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic')
ad.defjvp_zero(shift_right_arithmetic_p)

shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical')
ad.defjvp_zero(shift_right_logical_p)

# Comparisons produce booleans: zero JVP.
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq')
ad.defjvp_zero(eq_p)

ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne')
ad.defjvp_zero(ne_p)

ge_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ge')
ad.defjvp_zero(ge_p)

gt_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'gt')
ad.defjvp_zero(gt_p)

le_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'le')
ad.defjvp_zero(le_p)

lt_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'lt')
ad.defjvp_zero(lt_p)
2606
2607
def _convert_element_type_shape_rule(operand, *, new_dtype):
  """Shape rule: a dtype conversion leaves the operand's shape unchanged."""
  del new_dtype  # The target dtype has no bearing on the shape.
  return operand.shape
2610
def _convert_element_type_dtype_rule(operand, *, new_dtype):
  """Dtype rule: the result has exactly the requested `new_dtype`."""
  del operand  # Only the requested dtype matters.
  return new_dtype
2613
def _convert_element_type_translation_rule(c, operand, *, new_dtype):
  """Lower convert_element_type.

  Converting complex to a non-complex dtype keeps only the real part.
  """
  old_dtype = c.get_shape(operand).numpy_dtype()
  drops_imaginary = (dtypes.issubdtype(old_dtype, np.complexfloating)
                     and not dtypes.issubdtype(new_dtype, np.complexfloating))
  if drops_imaginary:
    operand = xops.Real(operand)
  return xops.ConvertElementType(
      operand, new_element_type=xla_client.dtype_to_etype(new_dtype))
2621
def _convert_element_type_transpose_rule(ct, operand, *, new_dtype):
  """Transpose of convert_element_type: convert the cotangent back.

  A symbolic-zero cotangent stays symbolic. If the operand's dtype has a
  float0 tangent dtype (presumably non-differentiable dtypes such as
  ints/bools), a float0 symbolic zero of the operand's shape is returned.
  """
  assert ad.is_undefined_primal(operand)
  old_dtype = operand.aval.dtype
  if type(ct) is ad_util.Zero:
    return [ad_util.Zero(operand.aval)]
  if core.primal_dtype_to_tangent_dtype(old_dtype) is dtypes.float0:
    float0_aval = ShapedArray(operand.aval.shape, dtype=dtypes.float0)
    return [ad_util.Zero(float0_aval)]
  return [convert_element_type_p.bind(ct, new_dtype=old_dtype)]
2631
def _convert_element_type_jvp_rule(tangent, operand, *, new_dtype):
  """JVP of convert_element_type.

  The tangent is converted to `new_dtype`, unless that dtype's tangent dtype
  is float0, in which case a float0 symbolic zero is returned.
  """
  if core.primal_dtype_to_tangent_dtype(new_dtype) is not dtypes.float0:
    return convert_element_type_p.bind(tangent, new_dtype=new_dtype)
  return ad_util.Zero(ShapedArray(tangent.shape, dtype=dtypes.float0))
2637
# Elementwise dtype conversion: its JVP and transpose convert tangents and
# cotangents, and it is trivially vectorized for batching and masking.
convert_element_type_p = standard_primitive(
    _convert_element_type_shape_rule, _convert_element_type_dtype_rule,
    'convert_element_type', _convert_element_type_translation_rule)
ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule)
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
batching.defvectorized(convert_element_type_p)
masking.defvectorized(convert_element_type_p)
2645
2646
def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
  """Shape rule: bitcasting leaves the operand's shape unchanged."""
  del new_dtype  # The target dtype has no bearing on the shape here.
  return operand.shape
2649
def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
  """Dtype rule: the result has exactly the requested `new_dtype`."""
  del operand  # Only the requested dtype matters.
  return new_dtype
2652
def _bitcast_convert_type_translation_rule(c, operand, *, new_dtype):
  """Lower bitcast_convert_type to XLA's BitcastConvertType."""
  return xops.BitcastConvertType(
      operand, new_element_type=xla_bridge.dtype_to_etype(new_dtype))
2656
# Reinterprets the bits of each element; registered with a zero JVP and as
# vectorized for batching and masking.
bitcast_convert_type_p = standard_primitive(
    _bitcast_convert_type_shape_rule, _bitcast_convert_type_dtype_rule,
    'bitcast_convert_type', _bitcast_convert_type_translation_rule)
ad.defjvp_zero(bitcast_convert_type_p)
batching.defvectorized(bitcast_convert_type_p)
masking.defvectorized(bitcast_convert_type_p)
2663
2664
def _conv_general_dilated_shape_rule(
    lhs: ShapedArray, rhs: ShapedArray, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, **unused_kwargs) -> Tuple[int, ...]:
  """Validate conv_general_dilated arguments and compute the output shape.

  Checks run in this order, and the first failure raises ValueError:
  matching operand ranks; feature_group_count > 0, dividing the lhs feature
  size into the rhs input feature size, and dividing the rhs output feature
  size; batch_group_count > 0, dividing the lhs batch size and the rhs output
  feature size; at most one group count > 1; and the number of rhs spatial
  dims equal to len(window_strides).

  Returns:
    The output shape as a tuple, ordered according to the output spec in
    `dimension_numbers`.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  if len(lhs.shape) != len(rhs.shape):
    msg = ("conv_general_dilated lhs and rhs must have the same number of "
           "dimensions, but got {} and {}.")
    raise ValueError(msg.format(lhs.shape, rhs.shape))
  if not feature_group_count > 0:
    msg = ("conv_general_dilated feature_group_count "
           "must be a positive integer, got {}.")
    raise ValueError(msg.format(feature_group_count))
  # lhs_spec[1] is the lhs feature dim; rhs_spec[1] / rhs_spec[0] are the
  # rhs input / output feature dims.
  lhs_feature_count = lhs.shape[dimension_numbers.lhs_spec[1]]
  quot, rem = divmod(lhs_feature_count, feature_group_count)
  if rem:
    msg = ("conv_general_dilated feature_group_count must divide lhs feature "
           "dimension size, but {} does not divide {}.")
    raise ValueError(msg.format(feature_group_count, lhs_feature_count))
  if quot != rhs.shape[dimension_numbers.rhs_spec[1]]:
    msg = ("conv_general_dilated lhs feature dimension size divided by "
           "feature_group_count must equal the rhs input feature dimension "
           "size, but {} // {} != {}.")
    raise ValueError(msg.format(lhs_feature_count, feature_group_count,
                                rhs.shape[dimension_numbers.rhs_spec[1]]))
  if rhs.shape[dimension_numbers.rhs_spec[0]] % feature_group_count:
    msg = ("conv_general_dilated rhs output feature dimension size must be a "
           "multiple of feature_group_count, but {} is not a multiple of {}.")
    raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
                                feature_group_count))

  if not batch_group_count > 0:
    msg = ("conv_general_dilated batch_group_count "
           "must be a positive integer, got {}.")
    raise ValueError(msg.format(batch_group_count))
  # lhs_spec[0] is the lhs batch dim.
  lhs_batch_count = lhs.shape[dimension_numbers.lhs_spec[0]]
  if lhs_batch_count % batch_group_count != 0:
    msg = ("conv_general_dilated batch_group_count must divide lhs batch "
           "dimension size, but {} does not divide {}.")
    raise ValueError(msg.format(batch_group_count, lhs_batch_count))

  if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
    msg = ("conv_general_dilated rhs output feature dimension size must be a "
           "multiple of batch_group_count, but {} is not a multiple of {}.")
    raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
                                batch_group_count))

  if batch_group_count > 1 and feature_group_count > 1:
    msg = ("At most one of batch_group_count and feature_group_count may be > "
           "1, got batch_group_count={} and feature_group_count={}")
    raise ValueError(msg.format(batch_group_count, feature_group_count))

  if len(_conv_sdims(dimension_numbers.rhs_spec)) != len(window_strides):
    msg = ("conv_general_dilated window and window_strides must have "
           "the same number of dimensions, but got {} and {}")
    raise ValueError(
        msg.format(len(_conv_sdims(dimension_numbers.rhs_spec)), len(window_strides)))

  # Permute each operand shape into spec order (two non-spatial dims, then the
  # spatial dims), apply dilation, compute the output shape in that order, and
  # finally undo the output permutation.
  lhs_perm, rhs_perm, out_perm = dimension_numbers
  lhs_trans = _dilate_shape(np.take(lhs.shape, lhs_perm), lhs_dilation)
  rhs_trans = _dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
                               batch_group_count)
  return tuple(np.take(out_trans, np.argsort(out_perm)))
2729
def _conv_general_dilated_dtype_rule(
    lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
    dimension_numbers, **unused_kwargs):
  """Dtype rule for conv_general_dilated.

  Delegates to naryop_dtype_rule, accepting float or complex operands and
  deriving the result dtype with _input_dtype.
  """
  accepted = [_float | _complex, _float | _complex]
  return naryop_dtype_rule(_input_dtype, accepted, 'conv_general_dilated',
                           lhs, rhs)
2735
def _conv_spec_transpose(spec):
  """Swap the first two entries of a conv spec, keeping the spatial entries."""
  first, second = spec[0], spec[1]
  return (second, first) + spec[2:]

def _conv_sdims(spec):
  """Return the spatial-dimension entries of a conv spec (all after the 2nd)."""
  return spec[2:]
2738
2739# Understanding the convolution transpose rules:
2740# Ignoring the spatial dimensions, let m = batch, j = input feature,
2741# k = output feature.
2742#
2743# Convolution computes the following contraction:
2744# Forward: [m, j] [j, k] -> [m, k]
2745#
2746# The transposes are similar to the rules for transposing a matmul:
2747# LHS transpose: [m, k] [k, j] -> [m, j]
2748# RHS transpose: [j, m] [m, k] -> [j, k]
2749#
2750# With feature grouping, we have the following signatures:
2751# Forward: [m, gj] [j, gk] -> [m, gk]
2752# LHS transpose: [m, gk] [k, gj] -> [m, gj]
2753# --> implemented as feature grouping after transposing the group from the
2754#     kernel input features to the kernel output features.
2755# RHS transpose: [gj, m] [m, gk] -> [j, gk]
2756# --> which is batch grouping.
2757#
# With batch grouping, we have the following signatures:
# Forward: [gm, j] [j, gk] -> [m, gk]
# LHS transpose: [m, gk] [gk, j] -> [gm, j]
# --> implemented as feature grouping after transposing the group on the
#     kernel and the output.
# RHS transpose: [j, gm] [m, gk] -> [j, gk]
# --> which is feature grouping.
2765
def _conv_general_dilated_transpose_lhs(
    g, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
    dimension_numbers, feature_group_count, batch_group_count,
    lhs_shape, rhs_shape, precision):
  """Transpose of conv_general_dilated with respect to the lhs.

  Computes the lhs cotangent as a convolution of the output cotangent `g`
  with the spatially reversed kernel. Window strides and lhs dilation swap
  roles, and padding comes from _conv_general_vjp_lhs_padding. Grouped
  convolutions are rewritten as feature-grouped ones, following the
  "Understanding the convolution transpose rules" notes in this file.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  assert batch_group_count == 1 or feature_group_count == 1
  lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  t_rhs_spec = _conv_spec_transpose(rhs_spec)
  if feature_group_count > 1:
    # in addition to switching the dims in the spec, need to move the feature
    # group axis into the transposed rhs's output feature dim
    rhs = _reshape_axis_out_of(rhs_spec[0], feature_group_count, rhs)
    rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[1], rhs)
  elif batch_group_count > 1:
    # Same kernel regrouping as above, after which the transpose is computed
    # as a feature-grouped convolution.
    rhs = _reshape_axis_out_of(rhs_spec[0], batch_group_count, rhs)
    rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[1], rhs)
    feature_group_count = batch_group_count
  # `g` is laid out like the forward output, and the result like the lhs.
  trans_dimension_numbers = ConvDimensionNumbers(out_spec, t_rhs_spec, lhs_spec)
  padding = _conv_general_vjp_lhs_padding(
      np.take(lhs_shape, lhs_sdims), np.take(rhs_shape, rhs_sdims),
      window_strides, np.take(g.shape, out_sdims), padding, lhs_dilation,
      rhs_dilation)
  revd_weights = rev(rhs, rhs_sdims)
  # Forward strides become lhs dilation here, and vice versa.
  out = conv_general_dilated(
      g, revd_weights, window_strides=lhs_dilation, padding=padding,
      lhs_dilation=window_strides, rhs_dilation=rhs_dilation,
      dimension_numbers=trans_dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=1, precision=precision)
  if batch_group_count > 1:
    # Move the group factor out of the lhs feature dim and back into the lhs
    # batch dim, matching the forward lhs layout.
    out = _reshape_axis_out_of(lhs_spec[1], batch_group_count, out)
    out = _reshape_axis_into(lhs_spec[1], lhs_spec[0], out)
  return out
2800
def _conv_general_dilated_transpose_rhs(
    g, lhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
    dimension_numbers: ConvDimensionNumbers, feature_group_count: int,
    batch_group_count: int, lhs_shape, rhs_shape, precision):
  """Transpose of conv_general_dilated with respect to the rhs (kernel).

  Computes the kernel cotangent as a convolution of the lhs with the output
  cotangent `g`. The spec transposes swap each spec's first two entries, and
  window strides and rhs dilation swap roles. Feature grouping becomes batch
  grouping and vice versa. Returns None when `g` is empty.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  if np.size(g) == 0:
    # Avoids forming degenerate convolutions where the RHS has spatial size 0.
    # Awkwardly, we don't have an aval for the rhs readily available, so instead
    # of returning an ad_util.Zero instance here, representing a symbolic zero
    # value, we instead return a None, which is meant to represent having no
    # cotangent at all (and is thus incorrect for this situation), since the two
    # are treated the same operationally.
    # TODO(mattjj): adjust defbilinear so that the rhs aval is available here
    return None
  lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
  lhs_trans, rhs_trans, out_trans = map(_conv_spec_transpose, dimension_numbers)
  assert batch_group_count == 1 or feature_group_count == 1
  # Swap the grouping kind for the transposed convolution.
  if batch_group_count > 1:
    feature_group_count = batch_group_count
    batch_group_count = 1
  elif feature_group_count > 1:
    batch_group_count = feature_group_count
    feature_group_count = 1
  # lhs plays the lhs, `g` plays the kernel, and the result is laid out like
  # the forward rhs.
  trans_dimension_numbers = ConvDimensionNumbers(lhs_trans, out_trans, rhs_trans)
  padding = _conv_general_vjp_rhs_padding(
      np.take(lhs_shape, lhs_sdims), np.take(rhs_shape, rhs_sdims),
      window_strides, np.take(g.shape, out_sdims), padding, lhs_dilation,
      rhs_dilation)
  return conv_general_dilated(
      lhs, g, window_strides=rhs_dilation, padding=padding,
      lhs_dilation=lhs_dilation, rhs_dilation=window_strides,
      dimension_numbers=trans_dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count, precision=precision)
2835
2836
def _conv_general_dilated_translation_rule(
    c, lhs, rhs, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, precision, expand_complex_convolutions, **unused_kwargs):
  """XLA lowering for conv_general_dilated.

  Emits a single ConvGeneralDilated op. If `expand_complex_convolutions` is
  set and the lhs dtype is complex, it instead emits three real convolutions
  combined with Gauss's multiplication trick.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  dimension_numbers = _conv_general_proto(dimension_numbers)
  precision_config = _precision_config(precision)
  dtype = c.get_shape(lhs).numpy_dtype()
  # All convolutions emitted below share every parameter except the operands.
  conv = lambda x, y: xops.ConvGeneralDilated(
      x, y, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config=precision_config)
  if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating):
    # We use a trick for complex multiplication due to Gauss which uses three
    # multiplications and five additions; instead of the naive method of four
    # multiplications and two additions.
    # https://en.wikipedia.org/wiki/Multiplication_algorithm#Complex_multiplication_algorithm
    #
    # This performance win comes with a trade-off in accuracy; especially in
    # cases when the real and imaginary differ hugely in magnitude. The relative
    # error bound (e.g. 1p-24 in case of float32) would be relative to the
    # maximum of real and imaginary parts of the result instead of being
    # satisfied by the real and imaginary parts independently of each other.
    lhs_real, lhs_imag = xops.Real(lhs), xops.Imag(lhs)
    rhs_real, rhs_imag = xops.Real(rhs), xops.Imag(rhs)
    k1 = conv(xops.Add(lhs_real, lhs_imag), rhs_real)
    k2 = conv(lhs_real, xops.Sub(rhs_imag, rhs_real))
    k3 = conv(lhs_imag, xops.Add(rhs_real, rhs_imag))
    # real = k1 - k3 = lr*rr - li*ri;  imag = k1 + k2 = lr*ri + li*rr.
    return xops.Complex(xops.Sub(k1, k3), xops.Add(k1, k2))
  return conv(lhs, rhs)
2867
def _conv_general_dilated_batch_rule(
    batched_args, batch_dims, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers,
    feature_group_count, batch_group_count, precision, **unused_kwargs):
  """Batching (vmap) rule for ``conv_general_dilated``.

  The vmapped axis is folded into an existing axis of the operand(s), a single
  convolution is run on the reshaped operands, and the vmapped axis is then
  split back out of the result.

  ``dimension_numbers`` unpacks to ``(lhs_spec, rhs_spec, out_spec)``. As used
  here, ``lhs_spec[0]``/``lhs_spec[1]`` are the lhs batch/feature axes,
  ``rhs_spec[0]`` is the rhs output-feature axis, and ``out_spec[0]``/
  ``out_spec[1]`` are the output batch/feature axes. All of these index the
  *unbatched* operands.

  Args:
    batched_args: the ``(lhs, rhs)`` operands.
    batch_dims: ``(lhs_bdim, rhs_bdim)``, the vmapped axis of each operand or
      ``None`` if that operand is not batched. At least one is expected to be
      non-None; if both are None no branch matches and ``None`` is returned.
    window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers,
    feature_group_count, batch_group_count, precision: parameters of the
      primitive being batched, forwarded to ``conv_general_dilated``.
    **unused_kwargs: any other primitive parameters; ignored.

  Returns:
    ``(out, out_bdim)``. Depending on the branch, ``out_bdim`` is either
    ``out_spec[0]`` (the output batch axis) or ``out_spec[1]`` (the output
    feature axis).
  """
  # At most one kind of grouping may be active at a time.
  assert batch_group_count == 1 or feature_group_count == 1
  lhs, rhs = batched_args
  lhs_bdim, rhs_bdim = batch_dims
  lhs_spec, rhs_spec, out_spec = dimension_numbers

  if lhs_bdim is not None and rhs_bdim is not None:
    # Both operands are batched with the same vmapped size. The vmapped axis is
    # folded into a grouped axis and the active group count is multiplied by
    # the vmapped size, so that each vmapped slice of lhs is grouped with the
    # matching slice of rhs.
    assert lhs.shape[lhs_bdim] == rhs.shape[rhs_bdim]
    if batch_group_count > 1:
      # Batch grouping is active: fold the vmapped axis into the lhs batch axis.
      new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
      batch_group_count *= lhs.shape[lhs_bdim]
    else:
      # Otherwise use feature grouping: fold into the lhs feature axis.
      new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[1], lhs)
      feature_group_count *= lhs.shape[lhs_bdim]
    # In both cases the rhs vmapped axis becomes the major part of the rhs
    # output-feature axis.
    new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
    out = conv_general_dilated(
      new_lhs, new_rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      precision=precision)
    # Split the vmapped size back out as the major part of the output feature
    # axis.
    out = _reshape_axis_out_of(out_spec[1], lhs.shape[lhs_bdim], out)
    return out, out_spec[1]

  elif lhs_bdim is not None:
    if batch_group_count == 1:
      # Only lhs is batched and there is no batch grouping: the vmapped axis
      # behaves like extra batch entries. Fold it into the lhs batch axis and
      # split it back out of the output batch axis.
      new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
      out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
                                 lhs_dilation, rhs_dilation, dimension_numbers,
                                 feature_group_count, precision=precision)
      out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
      return out, out_spec[0]
    else:
      # With batch grouping, the groups must stay outermost in the lhs batch
      # axis. Split the batch axis into (batch_group_count, rest), fold the
      # vmapped axis into `rest`, then merge the group axis back in front. The
      # `int(...)` terms shift unbatched axis indices to account for where the
      # vmapped axis sits in the physical array.
      new_lhs = _reshape_axis_out_of(lhs_spec[0] + int(lhs_bdim <= lhs_spec[0]),
                                     batch_group_count, lhs)
      new_lhs = _reshape_axis_into(lhs_bdim + int(lhs_spec[0] < lhs_bdim),
                                   lhs_spec[0] + 1,
                                   new_lhs)
      new_lhs = _reshape_axis_into(lhs_spec[0], lhs_spec[0], new_lhs)
      out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
                                 lhs_dilation, rhs_dilation, dimension_numbers,
                                 feature_group_count, batch_group_count,
                                 precision=precision)
      out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
      return out, out_spec[0]

  elif rhs_bdim is not None:
    if feature_group_count == 1 and batch_group_count == 1:
      # Only rhs is batched and there is no grouping: the vmapped axis behaves
      # like extra output features. Fold it into the rhs output-feature axis
      # and split it back out of the output feature axis.
      new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
      out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
                                 lhs_dilation, rhs_dilation, dimension_numbers,
                                 feature_group_count, batch_group_count,
                                 precision=precision)
      out = _reshape_axis_out_of(out_spec[1], rhs.shape[rhs_bdim], out)
      return out, out_spec[1]
    else:
      # groups need to be outermost, so we need to factor them out of the
      # rhs output feature dim, then factor the batch dim into the remaining rhs
      # output feature dim, then put groups back in. We do something
      # similar on the output. An alternative which would require more FLOPs but
      # fewer reshapes would be to broadcast lhs.
      group_count = (feature_group_count if feature_group_count > 1
                     else batch_group_count)
      new_rhs = _reshape_axis_out_of(rhs_spec[0] + int(rhs_bdim <= rhs_spec[0]),
                                     group_count, rhs)
      new_rhs = _reshape_axis_into(rhs_bdim + int(rhs_spec[0] < rhs_bdim),
                                   rhs_spec[0] + 1,
                                   new_rhs)
      new_rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[0], new_rhs)
      out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
                                 lhs_dilation, rhs_dilation, dimension_numbers,
                                 feature_group_count, batch_group_count,
                                 precision=precision)
      # Undo the grouping on the output feature axis: split off the groups,
      # split the vmapped size out of the remainder, then merge the groups back
      # so the vmapped axis is left standing at out_spec[1].
      out = _reshape_axis_out_of(out_spec[1], group_count, out)
      out = _reshape_axis_out_of(out_spec[1] + 1, rhs.shape[rhs_bdim], out)
      out = _reshape_axis_into(out_spec[1], out_spec[1] + 1, out)
      return out, out_spec[1]
2947
def _masked(padded_value, logical_shape, dimensions, value=0):
  """Replaces out-of-bounds entries of ``padded_value`` with ``value``.

  An entry counts as padding when, along any axis listed in ``dimensions``,
  its index is ``>= logical_shape[axis]``. Axes not listed are left untouched.

  Args:
    padded_value: the physically padded array.
    logical_shape: the logical (unpadded) size of each axis.
    dimensions: the axes to mask; if empty, ``padded_value`` is returned as is.
    value: the fill value for padding entries (defaults to 0).
  """
  if len(dimensions) == 0:
    return padded_value

  # One boolean "index is in bounds" array per masked axis.
  in_bounds_per_axis = [
      broadcasted_iota(np.int32, padded_value.shape, axis) < logical_shape[axis]
      for axis in dimensions]
  # An entry is kept only if it is in bounds along every masked axis.
  keep = functools.reduce(operator.and_, in_bounds_per_axis)
  return select(keep, padded_value, full_like(padded_value, value))
2962
def _conv_general_dilated_masking_rule(
        padded_vals, logical_shapes, window_strides, padding, lhs_dilation,
        rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
        lhs_shape, rhs_shape, precision):
  """Masking rule for ``conv_general_dilated``.

  Before convolving, this zeroes the padding along the lhs spatial axes and
  along the rhs input-feature axis (``_masked``'s default fill value is 0).
  Padded filter window (spatial) axes are not supported and trip an assertion.
  ``lhs_shape`` and ``rhs_shape`` are accepted but unused.
  """
  lhs, rhs = padded_vals
  logical_lhs_shape, logical_rhs_shape = logical_shapes

  # rhs_spec is (output feature, input feature, *window axes).
  _, rhs_in_feature_dim, *window_dims = dimension_numbers.rhs_spec
  assert np.all(np.take(rhs.shape, window_dims)
                == np.take(logical_rhs_shape, window_dims)), \
      "Conv filter masking not yet implemented."

  # lhs_spec is (batch, feature, *spatial axes).
  _, _, *lhs_spatial_dims = dimension_numbers.lhs_spec
  masked_lhs = _masked(lhs, logical_lhs_shape, lhs_spatial_dims)
  masked_rhs = _masked(rhs, logical_rhs_shape, (rhs_in_feature_dim,))

  return conv_general_dilated(
      masked_lhs, masked_rhs,
      window_strides=window_strides, padding=padding,
      lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      precision=precision)
2986
# Primitive for lax.conv_general_dilated. The default (backend-agnostic)
# translation keeps complex convolutions intact
# (expand_complex_convolutions=False).
conv_general_dilated_p = standard_primitive(
    _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
    'conv_general_dilated', partial(_conv_general_dilated_translation_rule,
                                    expand_complex_convolutions=False))

# TODO(b/161124619, b/161126248): XLA does not support complex convolution on
# CPU or GPU; on these backends, lower complex convolutions away.
xla.backend_specific_translations['cpu'][conv_general_dilated_p] = partial(
    _conv_general_dilated_translation_rule, expand_complex_convolutions=True)
xla.backend_specific_translations['gpu'][conv_general_dilated_p] = partial(
    _conv_general_dilated_translation_rule, expand_complex_convolutions=True)

# Convolution is linear in lhs for fixed rhs and vice versa, so it is registered
# as bilinear with one transpose rule per operand.
ad.defbilinear(conv_general_dilated_p,
               _conv_general_dilated_transpose_lhs,
               _conv_general_dilated_transpose_rhs)
batching.primitive_batchers[conv_general_dilated_p] = \
    _conv_general_dilated_batch_rule
masking.masking_rules[conv_general_dilated_p] = \
  _conv_general_dilated_masking_rule
3006
def _reshape_axis_into(src, dst, x):
  """Merges axis ``src`` of ``x`` into axis ``dst``.

  ``dst`` is numbered *after* ``src`` has been removed. ``src`` is moved to sit
  directly before that axis (via ``reshape``'s permutation argument), so it
  becomes the major (outer) factor of the merged axis. The result has one
  fewer axis than ``x``.
  """
  # Axis order with `src` relocated to position `dst`.
  axis_order = [axis for axis in range(x.ndim) if axis != src]
  axis_order.insert(dst, src)

  merged_shape = list(np.delete(x.shape, src))
  merged_shape[dst] = merged_shape[dst] * x.shape[src]
  return reshape(x, merged_shape, axis_order)
3013
3014def _reshape_axis_out_of(src, size1, x):
3015  shape = list(x.shape)
3016  size2, ragged = divmod(shape[src], size1)
3017  assert not ragged
3018  shape[src:src+1] = [size1, size2]
3019  return reshape(x, shape)
3020
3021def _precision_config(precision):
3022  if precision is not None:
3023    config = xla_client.PrecisionConfig()
3024    if isinstance(precision, tuple):
3025      config.operand_precision.extend(precision)
3026    else:
3027      config.operand_precision.extend((precision, precision))
3028    return config
3029  return None
3030
3031
3032def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
3033                            preferred_element_type: Optional[DType]):
3034  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
3035  if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
3036             for d in (lhs_contracting, lhs_batch)):
3037    msg = ("dot_general requires lhs dimension numbers to be nonnegative and "
3038           "less than the number of axes of the lhs value, got "
3039           f"lhs_batch of {lhs_batch} and lhs_contracting of {lhs_contracting} "
3040           f"for lhs of rank {lhs.ndim}")
3041    raise TypeError(msg)
3042  if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, rhs.ndim))
3043             for d in (rhs_contracting, rhs_batch)):
3044    msg = ("dot_general requires rhs dimension numbers to be nonnegative and "
3045           "less than the number of axes of the rhs value, got "
3046           f"rhs_batch of {rhs_batch} and rhs_contracting of {rhs_contracting} "
3047           f"for rhs of rank {rhs.ndim}")
3048    raise TypeError(msg)
3049  if len(lhs_batch) != len(rhs_batch):
3050    msg = ("dot_general requires equal numbers of lhs_batch and rhs_batch "
3051           "dimensions, got lhs_batch {} and rhs_batch {}.")
3052    raise TypeError(msg.format(lhs_batch, rhs_batch))
3053  lhs_contracting_set, lhs_batch_set = set(lhs_contracting), set(lhs_batch)
3054  rhs_contracting_set, rhs_batch_set = set(rhs_contracting), set(rhs_batch)
3055  if len(lhs_batch_set) != len(lhs_batch):
3056    msg = ("dot_general requires lhs batch dimensions to be distinct, got "
3057           f"lhs_batch {lhs_batch}.")
3058    raise TypeError(msg)
3059  if len(rhs_batch_set) != len(rhs_batch):
3060    msg = ("dot_general requires rhs batch dimensions to be distinct, got "
3061           f"rhs_batch {rhs_batch}.")
3062    raise TypeError(msg)
3063  if len(lhs_contracting_set) != len(lhs_contracting):
3064    msg = ("dot_general requires lhs contracting dimensions to be distinct, "
3065           f"got lhs_contracting {lhs_contracting}.")
3066    raise TypeError(msg)
3067  if len(rhs_contracting_set) != len(rhs_contracting):
3068    msg = ("dot_general requires rhs contracting dimensions to be distinct, "
3069           f"got rhs_contracting {rhs_contracting}.")
3070    raise TypeError(msg)
3071  if lhs_contracting_set & lhs_batch_set:
3072    msg = ("dot_general requires lhs batch dimensions to be disjoint from "
3073           "contracting dimensions, got lhs_batch {} and lhs_contracting {}.")
3074    raise TypeError(msg.format(lhs_batch, lhs_contracting))
3075  if rhs_contracting_set & rhs_batch_set:
3076    msg = ("dot_general requires rhs batch dimensions to be disjoint from "
3077           "contracting dimensions, got rhs_batch {} and rhs_contracting {}.")
3078    raise TypeError(msg.format(rhs_batch, rhs_contracting))
3079  lhs_batch_shape = np.take(lhs.shape, lhs_batch)
3080  rhs_batch_shape = np.take(rhs.shape, rhs_batch)
3081  if not np.all(np.equal(lhs_batch_shape, rhs_batch_shape)):
3082    msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
3083           "to have the same shape, got {} and {}.")
3084    raise TypeError(msg.format(lhs_batch_shape, rhs_batch_shape))
3085  lhs_contracting_shape = np.take(lhs.shape, lhs_contracting)
3086  rhs_contracting_shape = np.take(rhs.shape, rhs_contracting)
3087  if not np.all(np.equal(lhs_contracting_shape, rhs_contracting_shape)):
3088    msg = ("dot_general requires contracting dimensions to have the same "
3089           "shape, got {} and {}.")
3090    raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
3091
3092  batch_shape = tuple(lhs_batch_shape)
3093  lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
3094  lhs_tensored_shape = tuple(np.delete(lhs.shape, lhs_contract_or_batch))
3095  rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
3096  rhs_tensored_shape = tuple(np.delete(rhs.shape, rhs_contract_or_batch))
3097  return batch_shape + lhs_tensored_shape + rhs_tensored_shape
3098
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
                            preferred_element_type: Optional[DType]):
  """Computes the output dtype of ``dot_general``.

  Without ``preferred_element_type``, the result has the input dtype. Otherwise
  ``preferred_element_type`` is returned after checking that it is compatible
  with the input dtype:

  * integral inputs need an integral result, and inexact (floating or complex)
    inputs need a non-integral result;
  * signed-integer inputs need a signed-integer result;
  * the result may not be narrower (in bytes) than the input.

  Raises:
    TypeError: if ``preferred_element_type`` is incompatible with the inputs.
  """
  input_dtype = naryop_dtype_rule(_input_dtype, [_any, _any], 'dot_general',
                                  lhs, rhs)
  if preferred_element_type is None:
    return input_dtype

  input_is_integral = dtypes.issubdtype(input_dtype, np.integer)
  preferred_is_integral = dtypes.issubdtype(preferred_element_type, np.integer)
  # Enforce the integral/floating constraint in both directions, as the error
  # message states; an inexact input must not accumulate into an integer type.
  if ((input_is_integral and not preferred_is_integral) or
      (dtypes.issubdtype(input_dtype, np.inexact) and preferred_is_integral)):
    raise TypeError("`preferred_element_type` and the original type must both be integral or both be floating point.")
  if (dtypes.issubdtype(input_dtype, np.signedinteger) and
      not dtypes.issubdtype(preferred_element_type, np.signedinteger)):
    raise TypeError("`preferred_element_type` must have the same signedness as the original type.")
  input_bitwidth = np.dtype(input_dtype).itemsize
  preferred_bitwidth = np.dtype(preferred_element_type).itemsize
  if preferred_bitwidth < input_bitwidth:
    raise TypeError("`preferred_element_type` must not be narrower than the original type.")
  return preferred_element_type
3113
def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
                               preferred_element_type: Optional[DType],
                               swap_ans=False):
  """Transpose rule of ``dot_general`` with respect to its lhs operand ``x``.

  Computes the cotangent of ``x`` as follows. The output cotangent ``g`` is
  contracted with ``y`` over ``y``'s free (non-batch, non-contracting) axes,
  batching over the batch axes. The result is then transposed into ``x``'s
  axis order.

  Args:
    g: cotangent of the ``dot_general`` output, whose layout is
      (batch, lhs free, rhs free).
    y: the other operand.
    dimension_numbers: ``((x_contract, y_contract), (x_batch, y_batch))`` of
      the original ``dot_general``.
    precision, preferred_element_type: forwarded to the inner ``dot_general``.
    swap_ans: set by ``_dot_general_transpose_rhs``, which calls this rule with
      the operands' roles swapped. In that case ``y``'s free axes precede
      ``x``'s free axes in ``g``.
  """
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  # g.ndim = batch + x_free + y_free and y.ndim = batch + contract + y_free, so
  # x.ndim = batch + contract + x_free = g.ndim - y.ndim + batch + 2*contract.
  x_ndim = g.ndim - y.ndim + len(x_batch) + 2 * len(x_contract)
  # Free axes of x and y. Presumably `remaining` returns, as a list, the axes of
  # its first argument not listed in the others (helper defined elsewhere).
  x_kept = remaining(range(x_ndim), x_contract, x_batch)
  y_kept = remaining(range(y.ndim), y_contract, y_batch)
  # Positions of g's batch axes and of y's free axes within g. This assumes
  # `ranges_like` yields consecutive ranges sized like each argument in turn.
  if swap_ans:
    ans_batch, ans_y, _ = ranges_like(x_batch, y_kept, x_kept)
  else:
    ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
  dims = ((ans_y, y_kept), (ans_batch, y_batch))
  # The inner dot_general yields axes ordered (batch, x free, y's remaining
  # axes), and y's remaining axes are y_contract in increasing order. They
  # therefore correspond to x_contract reordered by argsort(y_contract), and
  # argsort of the full ordering is the permutation back to x's layout.
  x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
  out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
  # NOTE(review): g carries the result dtype (preferred_element_type when set)
  # while y keeps the input dtype; confirm the inner dot_general's dtype rule
  # accepts this mix when preferred_element_type is not None.
  return transpose(dot_general(g, y, dims, precision=precision, preferred_element_type=preferred_element_type),
                   tuple(out_axes))
3130
def _dot_general_transpose_rhs(g, x, *, dimension_numbers, precision,
                               preferred_element_type: Optional[DType]):
  """Transpose rule of ``dot_general`` with respect to its rhs operand.

  This reuses the lhs rule by exchanging the roles of the two operands in
  ``dimension_numbers``. ``swap_ans=True`` tells the lhs rule that, in ``g``,
  the free axes of ``x`` (the original lhs) come before those of the operand
  being transposed.
  """
  contracting, batch = dimension_numbers
  x_contract, y_contract = contracting
  x_batch, y_batch = batch
  return _dot_general_transpose_lhs(
      g, x,
      dimension_numbers=((y_contract, x_contract), (y_batch, x_batch)),
      precision=precision,
      preferred_element_type=preferred_element_type,
      swap_ans=True)
3139
3140
def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
                            precision,
                            preferred_element_type: Optional[DType]):
  """Batching (vmap) rule for ``dot_general``.

  The vmapped axes are expressed through rewritten dimension numbers (see
  ``_dot_general_batch_dim_nums``), so a single ``dot_general`` handles the
  batched operands directly.
  """
  lhs, rhs = batched_args
  operand_ndims = (lhs.ndim, rhs.ndim)
  dims, out_bdim = _dot_general_batch_dim_nums(
      operand_ndims, batch_dims, dimension_numbers)
  out = dot_general(lhs, rhs, dims, precision=precision,
                    preferred_element_type=preferred_element_type)
  return out, out_bdim
3151
3152def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
3153  # there are three kinds of dimensions in a dot_general:
3154  # - contraction dimensions appear in lhs and rhs but not the result
3155  # - batch dimensions appear in lhs, rhs, and result
3156  # - tensor product dimensions appear in the result and one of lhs or rhs
3157  lhs_ndim, rhs_ndim = ndims
3158  lbd, rbd = batch_dims
3159  assert lbd is not None or rbd is not None
3160  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
3161
3162  def bump_dims(dims, b):
3163    return tuple(np.add(dims, np.greater_equal(dims, b)))
3164
3165  if lbd is not None and rbd is not None:
3166    # adding a batch dimension
3167    lhs_batch = (lbd,) + bump_dims(lhs_batch, lbd)
3168    rhs_batch = (rbd,) + bump_dims(rhs_batch, rbd)
3169    lhs_contract = bump_dims(lhs_contract, lbd)
3170    rhs_contract = bump_dims(rhs_contract, rbd)
3171    result_batch_dim = 0
3172  else:
3173    # adding a tensor product dimension
3174    if lbd is not None:
3175      other = tuple(d for d in range(lhs_ndim)
3176                    if d not in lhs_batch and d not in lhs_contract)
3177      result_batch_dim = (len(lhs_batch) + sum(np.less(other, lbd)))
3178      lhs_batch = bump_dims(lhs_batch, lbd)
3179      lhs_contract = bump_dims(lhs_contract, lbd)
3180    else:
3181      other = tuple(d for d in range(rhs_ndim)
3182                    if d not in rhs_batch and d not in rhs_contract)
3183      result_batch_dim = (lhs_ndim - len(lhs_contract) +
3184                          sum(np.less(other, rbd)))
3185      rhs_batch = bump_dims(rhs_batch, rbd)
3186      rhs_contract = bump_dims(rhs_contract, rbd)
3187
3188  new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
3189  return new_dimension_numbers, int(result_batch_dim)
3190
def _dot_using_sum_of_products(lhs, rhs, *, dimension_numbers):
  """Computes ``dot_general`` as a broadcasted product followed by a reduction.

  Both operands are transposed to (batch, free, contracting) order. Each one
  gets singleton axes where the other operand's free axes go, so the two
  broadcast against each other. They are then multiplied and reduced over the
  trailing contracting axes. Boolean inputs use ``bitwise_and`` as the product
  and ``bitwise_or`` as the sum. The result layout is
  (batch, lhs free, rhs free).
  """
  (lhs_contract_dims, rhs_contract_dims), (lhs_batch_dims, rhs_batch_dims) = (
      dimension_numbers)

  def free_dims(x, batch, contract):
    # Axes of `x` that are neither batch nor contracting, in increasing order.
    return tuple(sorted(set(range(np.ndim(x))) - set(batch) - set(contract)))

  lhs_free = free_dims(lhs, lhs_batch_dims, lhs_contract_dims)
  rhs_free = free_dims(rhs, rhs_batch_dims, rhs_contract_dims)
  lhs = transpose(lhs, lhs_batch_dims + lhs_free + lhs_contract_dims)
  rhs = transpose(rhs, rhs_batch_dims + rhs_free + rhs_contract_dims)

  n_batch, n_lhs_free, n_rhs_free = (
      len(lhs_batch_dims), len(lhs_free), len(rhs_free))
  # lhs -> (batch, lhs free, <n_rhs_free singletons>, contracting)
  lhs = expand_dims(lhs, tuple(range(n_batch + n_lhs_free,
                                     n_batch + n_lhs_free + n_rhs_free)))
  # rhs -> (batch, <n_lhs_free singletons>, rhs free, contracting)
  rhs = expand_dims(rhs, tuple(range(n_batch, n_batch + n_lhs_free)))

  out_ndim = n_batch + n_lhs_free + n_rhs_free
  if lhs.dtype == np.bool_:
    multiply, combine = bitwise_and, bitwise_or
  else:
    multiply, combine = mul, add
  return reduce(multiply(lhs, rhs), _zero(lhs), combine,
                tuple(range(out_ndim, out_ndim + len(lhs_contract_dims))))
3218
def _dot_general_translation_rule(c, lhs, rhs, *, dimension_numbers, precision,
                                  preferred_element_type: Optional[DType]):
  """Lowers ``dot_general`` to XLA ``DotGeneral``.

  ``preferred_element_type`` is converted from a dtype to an XLA element type
  when it is given.
  """
  element_type = (None if preferred_element_type is None
                  else xla_client.dtype_to_etype(preferred_element_type))
  xla_dims = xc.make_dot_dimension_numbers(dimension_numbers)
  return xops.DotGeneral(lhs, rhs, xla_dims,
                         precision_config=_precision_config(precision),
                         preferred_element_type=element_type)
3227
def _dot_general_masking_rule(padded_vals, logical_shapes, *, dimension_numbers,
                              precision,
                              preferred_element_type: Optional[DType]):
  """Masking rule for ``dot_general``.

  Padding along a contracting axis would otherwise enter the sums, so it is
  zeroed on one side only. The lhs is chosen here, which is arbitrary; masking
  whichever operand is smaller would also work.
  """
  lhs, rhs = padded_vals
  logical_lhs_shape, _ = logical_shapes
  (lhs_contract, _), _ = dimension_numbers
  masked_lhs = _masked(lhs, logical_lhs_shape, lhs_contract)
  return dot_general(masked_lhs, rhs, dimension_numbers, precision=precision,
                     preferred_element_type=preferred_element_type)
3240
# Primitive for lax.dot_general, with shape/dtype rules and XLA lowering.
dot_general_p = standard_primitive(_dot_general_shape_rule,
                                   _dot_general_dtype_rule, 'dot_general',
                                   _dot_general_translation_rule)
# dot_general is bilinear: one transpose rule per operand.
ad.defbilinear(dot_general_p,
               _dot_general_transpose_lhs, _dot_general_transpose_rhs)
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
masking.masking_rules[dot_general_p] = _dot_general_masking_rule
3248
def _broadcast_shape_rule(operand, sizes):
  """Output shape of ``broadcast``: ``sizes`` prepended to the operand shape."""
  _check_shapelike('broadcast', 'sizes', sizes)
  leading = tuple(sizes)
  return leading + operand.shape
3252
def _broadcast_batch_rule(batched_args, batch_dims, *, sizes):
  """Batching rule for ``broadcast``.

  The new axes are prepended, so a vmapped axis moves right by ``len(sizes)``.
  """
  (operand,), (bdim,) = batched_args, batch_dims
  out = broadcast(operand, sizes)
  if bdim is None:
    return out, None
  return out, bdim + len(sizes)
3258
# Primitive for lax.broadcast (prepends new leading axes).
broadcast_p = standard_primitive(
    _broadcast_shape_rule, _input_dtype, 'broadcast')
# broadcast is linear; its transpose sums the cotangent over the len(sizes)
# leading axes that the broadcast prepended.
ad.deflinear2(broadcast_p, lambda t, _, sizes: [_reduce_sum(t, range(len(sizes)))])
batching.primitive_batchers[broadcast_p] = _broadcast_batch_rule
3263
def _broadcast_in_dim_impl(operand, *, shape, broadcast_dimensions):
  """Eager implementation of ``broadcast_in_dim``.

  A NumPy ``ndarray`` operand is first put on device. If the operand is a
  device array whose sizes already equal the target sizes along
  ``broadcast_dimensions`` (so no size-1 expansion is needed), the result
  reuses the operand's device buffer and records the broadcast as a lazy
  expression. In every other case the primitive is applied normally.
  """
  if type(operand) is np.ndarray:
    operand = _device_put_raw(operand)

  lazily_broadcastable = xla.type_is_device_array(operand) and np.all(
      np.equal(operand.shape, np.take(shape, broadcast_dimensions)))
  if not lazily_broadcastable:
    return xla.apply_primitive(broadcast_in_dim_p, operand, shape=shape,
                               broadcast_dimensions=broadcast_dimensions)

  out_shape = _broadcast_in_dim_shape_rule(
      operand, shape=shape, broadcast_dimensions=broadcast_dimensions)
  aval = ShapedArray(out_shape, _dtype(operand))
  # Build on any lazy expression already attached to the operand.
  base_expr = operand._lazy_expr
  if base_expr is None:
    base_expr = lazy.array(operand.shape)
  lazy_expr = lazy.broadcast(base_expr, out_shape, broadcast_dimensions)
  return xla._DeviceArray(aval, operand._device, lazy_expr, operand.device_buffer)
3280
def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
  """Validates ``broadcast_in_dim`` parameters and returns the output shape.

  Operand axis ``k`` maps to output axis ``broadcast_dimensions[k]``. Each
  operand axis must have size 1 or match its output axis, and
  ``broadcast_dimensions`` must be strictly increasing.

  Returns:
    ``shape`` (the requested output shape) unchanged.

  Raises:
    TypeError: on any inconsistency between ``operand``, ``shape`` and
      ``broadcast_dimensions``.
  """
  _check_shapelike('broadcast_in_dim', 'shape', shape)
  _check_shapelike('broadcast_in_dim', 'broadcast_dimensions',
                   broadcast_dimensions)
  operand_ndim = np.ndim(operand)
  out_ndim = len(shape)

  if len(broadcast_dimensions) != operand_ndim:
    raise TypeError(
        'broadcast_in_dim broadcast_dimensions must have length equal to '
        'operand ndim; got broadcast_dimensions {} for operand ndim {}.'
        .format(broadcast_dimensions, operand_ndim))

  if out_ndim < operand_ndim:
    raise TypeError(
        'broadcast_in_dim target broadcast shape must have equal or higher rank '
        'to the operand shape; got operand ndim {} and target broadcast ndim {}.'
        .format(operand_ndim, out_ndim))

  if not set(broadcast_dimensions) <= set(range(out_ndim)):
    raise TypeError(
        'broadcast_in_dim broadcast_dimensions must be a subset of output '
        'dimensions, got {} for operand ndim {} and shape {}.'
        .format(broadcast_dimensions, operand_ndim, shape))

  size_mismatch = any(
      operand.shape[k] != shape[broadcast_dimensions[k]] and
      operand.shape[k] != 1
      for k in range(operand_ndim))
  if size_mismatch:
    raise TypeError(
        "broadcast_in_dim operand dimension sizes must either be 1, or be "
        "equal to their corresponding dimensions in the target broadcast "
        "shape; got operand of shape {}, target broadcast shape {}, "
        "broadcast_dimensions {} "
        .format(operand.shape, shape, broadcast_dimensions))

  strictly_increasing = (
      len(broadcast_dimensions) == len(set(broadcast_dimensions)) and
      tuple(broadcast_dimensions) == tuple(sorted(broadcast_dimensions)))
  if not strictly_increasing:
    raise TypeError(
        "broadcast_in_dim broadcast_dimensions must be strictly increasing; "
        "got broadcast_dimensions {}".format(broadcast_dimensions))

  return shape
3313
def _broadcast_in_dim_transpose_rule(ct, operand, *, shape, broadcast_dimensions):
  """Transpose rule for ``broadcast_in_dim``.

  The cotangent is summed over every output axis that is not fed by a
  non-unit operand axis. The operand's size-1 axes are then re-inserted so the
  result has the operand's shape.
  """
  operand_shape = operand.aval.shape
  size_one_axes = tuple(k for k, size in enumerate(operand_shape) if size == 1)
  # Output axes that carry a real (non-unit) operand axis are kept...
  kept_out_axes = tuple(np.delete(broadcast_dimensions, size_one_axes))
  # ...and every other output axis is summed away.
  summed_axes = tuple(np.delete(range(len(shape)), kept_out_axes))
  return [expand_dims(_reduce_sum(ct, summed_axes), size_one_axes)]
3320
def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape,
                                 broadcast_dimensions):
  """Batching rule for ``broadcast_in_dim``.

  The vmapped axis is moved to the front of the operand and mapped to a new
  leading output axis; all other broadcast dimensions shift right by one.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  batch_size = operand.shape[bdim]
  leading_operand = batching.moveaxis(operand, bdim, 0)
  out_shape = (batch_size,) + shape
  out_dims = (0,) + tuple(np.add(1, broadcast_dimensions))
  return broadcast_in_dim(leading_operand, out_shape, out_dims), 0
3329
3330
# Primitive for lax.broadcast_in_dim. The eager impl is overridden with
# _broadcast_in_dim_impl, which has a lazy fast path for device arrays.
broadcast_in_dim_p = standard_primitive(
    _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_impl(_broadcast_in_dim_impl)
ad.deflinear2(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
3336
3337
3338def _clamp_shape_rule(min, operand, max):
3339  if min.shape and min.shape != operand.shape:
3340    m = "clamp requires min.shape == operand.shape or min.shape == (), got {}."
3341    raise TypeError(m.format(min.shape))
3342  if max.shape and max.shape != operand.shape:
3343    m = "clamp requires max.shape == operand.shape or max.shape == (), got {}."
3344    raise TypeError(m.format(max.shape))
3345  return operand.shape
3346
# dtype rule for clamp's three arguments (min, operand, max): naryop_dtype_rule
# applied with the `_any` dtype class for each argument and the first
# argument's dtype as the result.
_clamp_dtype_rule = partial(naryop_dtype_rule, _input_dtype, [_any] * 3,
                            'clamp')
3349
clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
# One JVP rule per argument (min, operand, max). Each rule passes its tangent
# through where that argument determines the output and is zero elsewhere:
#   min:     where min > operand and min < max
#   operand: where min < operand < max (strict on both sides)
#   max:     where max < operand
# Because all comparisons are strict, ties (e.g. operand == min) get a zero
# derivative for every argument. Tangents of min/max are passed through
# _brcast(g, operand), presumably to broadcast a scalar bound's tangent to the
# operand's shape (the shape rule allows scalar bounds).
ad.defjvp(clamp_p,
          lambda g, min, operand, max:
          select(bitwise_and(gt(min, operand), lt(min, max)),
                 _brcast(g, operand), _zeros(operand)),
          lambda g, min, operand, max:
          select(bitwise_and(gt(operand, min), lt(operand, max)),
                 g, _zeros(operand)),
          lambda g, min, operand, max:
          select(lt(max, operand), _brcast(g, operand), _zeros(operand)))
batching.defbroadcasting(clamp_p)
3361
3362
3363def _concatenate_shape_rule(*operands, **kwargs):
3364  dimension = kwargs.pop('dimension')
3365  if not operands:
3366    msg = "concatenate expects at least one operand, got 0."
3367    raise TypeError(msg)
3368  if not all(isinstance(operand, UnshapedArray) for operand in operands):
3369    msg = "All objects to concatenate must be arrays, got {}."
3370    op = next(op for op in operands if not isinstance(op, UnshapedArray))
3371    raise TypeError(msg.format(type(op)))
3372  if len({operand.ndim for operand in operands}) != 1:
3373    msg = "Cannot concatenate arrays with different ranks, got {}."
3374    raise TypeError(msg.format(", ".join(str(o.ndim) for o in operands)))
3375  if not 0 <= dimension < operands[0].ndim:
3376    msg = "concatenate dimension out of bounds: dimension {} for shapes {}."
3377    raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands])))
3378  shapes = [operand.shape[:dimension] + operand.shape[dimension+1:]
3379            for operand in operands]
3380  if not shapes[:-1] == shapes[1:]:
3381    msg = ("Cannot concatenate arrays with shapes that differ in dimensions "
3382           "other than the one being concatenated: concatenating along "
3383           "dimension {} for shapes {}.")
3384    shapes = [operand.shape for operand in operands]
3385    raise TypeError(msg.format(dimension, ", ".join(map(str, shapes))))
3386
3387  concat_size = sum(o.shape[dimension] for o in operands)
3388  ex_shape = operands[0].shape
3389  return ex_shape[:dimension] + (concat_size,) + ex_shape[dimension+1:]
3390
def _concatenate_dtype_rule(*operands, **kwargs):
  """All operands must share one dtype, which is also the result dtype."""
  operand_dtypes = [op.dtype for op in operands]
  _check_same_dtypes('concatenate', False, *operand_dtypes)
  return operand_dtypes[0]
3394
def _concatenate_translation_rule(c, *operands, **kwargs):
  """Lowers ``concatenate`` to XLA ``ConcatInDim`` along ``dimension``."""
  return xops.ConcatInDim(c, operands, kwargs['dimension'])
3398
def _concatenate_transpose_rule(t, *operands, dimension):
  """Transpose rule for ``concatenate``.

  Each operand being transposed (an undefined primal) receives the slice of
  the cotangent ``t`` that it occupied along ``dimension``; other operands get
  ``None``. A symbolic zero cotangent yields symbolic zeros.
  """
  operand_shapes = [o.aval.shape if ad.is_undefined_primal(o) else o.shape
                    for o in operands]
  is_linear = [ad.is_undefined_primal(o) for o in operands]

  if type(t) is ad_util.Zero:
    return [ad_util.Zero(o.aval) if linear else None
            for o, linear in zip(operands, is_linear)]

  # End offset of each operand along `dimension`.
  ends = np.cumsum([s[dimension] for s in operand_shapes])
  starts = np.zeros((len(operands), t.ndim), dtype=int)
  starts[1:, dimension] = ends[:-1]
  limits = np.tile(t.shape, (len(operands), 1))
  limits[:, dimension] = ends

  return [slice(t, start, limit) if linear else None
          for linear, start, limit in zip(is_linear, starts, limits)]
3414
def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
  """Batching rule for ``concatenate``.

  Batched operands have their vmapped axis moved to the front. Unbatched
  operands are broadcast with a new leading axis of the vmapped size. The
  concatenation axis then shifts right by one.
  """
  pairs = list(zip(batched_args, batch_dims))
  size = next(op.shape[bdim] for op, bdim in pairs if bdim is not None)
  aligned = []
  for op, bdim in pairs:
    if bdim is None:
      aligned.append(broadcast(op, (size,)))
    else:
      aligned.append(batching.moveaxis(op, bdim, 0))
  return concatenate(aligned, dimension + 1), 0
3422
# The concatenate_p masking rule requires use of a while-loop construct and so
# is defined in lax_control_flow.py

concatenate_p = standard_primitive(
    _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate',
    _concatenate_translation_rule)
# deflinear2 registers concatenate as linear. The assignment that follows then
# sets ad.primitive_transposes[concatenate_p] directly to
# _concatenate_transpose_rule, overriding any transpose deflinear2 installed.
# The rule handles a symbolic Zero cotangent itself. The order of these two
# lines matters.
# NOTE(review): presumably intentional; confirm whether the deflinear2 wrapper
# is still needed.
ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
3432
3433
3434def _pad_dtype_rule(operand, padding_value, *, padding_config):
3435  if operand.dtype != padding_value.dtype:
3436    msg = "pad operand and padding_value must be same dtype: got {} and {}."
3437    raise TypeError(msg.format(operand.dtype, padding_value.dtype))
3438
3439  return _input_dtype(operand, padding_value)
3440
3441def _pad_shape_rule(operand, padding_value, *, padding_config):
3442  del padding_value
3443  if not len(padding_config) == np.ndim(operand):
3444    raise ValueError("length of padding_config must equal the number of axes "
3445                     f"of operand, got padding_config {padding_config} "
3446                     f"for operand shape {np.shape(operand)}")
3447  if not all(i >= 0 for _, _, i in padding_config):
3448    raise ValueError("interior padding in padding_config must be nonnegative, "
3449                     f"got padding_config {padding_config}")
3450  return tuple(l + h + d + (_max(0, d - 1) * i if i > 0 else 0)
3451               for (l, h, i), d in zip(padding_config, np.shape(operand)))
3452
def _pad_transpose(t, operand, padding_value, *, padding_config):
  """Transpose rule for ``pad``, which is linear in both of its arguments.

  The operand's cotangent is ``t`` with the edge padding undone and the
  interior padding strided out. The padding value's cotangent is the sum of
  ``t`` over all padded positions. That sum is computed as ``sum(t)`` minus the
  sum over the positions that came from the operand.

  Returns:
    ``[t_operand, t_padv]``. Each entry is ``None`` when the corresponding
    argument is not being transposed (is not an undefined primal).
  """
  operand_is_linear = ad.is_undefined_primal(operand)
  padv_is_linear = ad.is_undefined_primal(padding_value)

  if type(t) is ad_util.Zero:
    t_operand = ad_util.Zero(operand.aval) if operand_is_linear else None
    t_padv = ad_util.Zero(padding_value.aval) if padv_is_linear else None
    return [t_operand, t_padv]

  lo, hi, interior = zip(*padding_config)

  def total(x):
    return _reduce_sum(x, list(range(t.ndim)))

  def t_op():
    # Negating the edge padding undoes it (a negative original edge re-grows
    # with zeros). The strided slice then skips the interior padding.
    unpad_config = safe_zip(np.negative(lo), np.negative(hi),
                            np.zeros_like(interior))
    unpadded = pad(t, np.array(0., t.dtype), unpad_config)
    return slice(unpadded, np.zeros_like(lo), unpadded.shape, np.add(interior, 1))

  t_operand = t_op() if operand_is_linear else None
  t_padv = None
  if padv_is_linear:
    # The operand-position slice is needed even when the operand itself is not
    # being transposed (t_operand is None then), so compute it on demand.
    from_operand = t_operand if t_operand is not None else t_op()
    t_padv = sub(total(t), total(from_operand))
  return [t_operand, t_padv]
3470
3471def _pad_batch_rule(batched_args, batch_dims, *, padding_config):
3472  operand, padding_value = batched_args
3473  operand_bdim, padding_value_bdim = batch_dims
3474  if padding_value_bdim is None:
3475    assert operand_bdim is not None
3476    padding_config = list(padding_config)
3477    padding_config.insert(operand_bdim, (0, 0, 0))
3478    return pad(operand, padding_value, padding_config), operand_bdim
3479  else:
3480    raise NotImplementedError  # loop and stack
3481
def _pad_translation_rule(c, operand, padding_value, *, padding_config):
  """Lowers ``pad`` to XLA ``Pad`` with an equivalent XLA padding config."""
  xla_padding = xc.make_padding_config(padding_config)
  return xops.Pad(operand, padding_value, xla_padding)
3485
def _pad_masking_rule(padded_vals, logical_shapes, padding_config):
  """Masking rule for ``pad``.

  The physically padded operand is padded as usual. Then, along every axis
  whose config entry is not ``(0, 0, 0)``, each position at or beyond
  ``lo + logical_size * (interior + 1)`` is overwritten with the padding value.
  """
  operand, padding_value = padded_vals
  logical_shape, _ = logical_shapes

  out = pad(operand, padding_value, padding_config)
  logical_extent = [lo + logical_shape[axis] * (interior + 1)
                    for axis, (lo, _, interior) in enumerate(padding_config)]
  touched_axes = [axis for axis, entry in enumerate(padding_config)
                  if entry != (0, 0, 0)]
  return _masked(out, logical_extent, touched_axes, padding_value)
3496
# Register the `pad` primitive: shape/dtype rules plus its XLA lowering, a
# transpose rule registered as linear via `ad.deflinear2`, and batching and
# masking rules.
pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad',
                           translation_rule=_pad_translation_rule)
ad.deflinear2(pad_p, _pad_transpose)
batching.primitive_batchers[pad_p] = _pad_batch_rule
masking.masking_rules[pad_p] = _pad_masking_rule
3502
3503
3504# The squeeze primitive exists for the benefit of masking and other
3505# transformations that need to keep track of axis identity.
3506# For example, consider reshaping a 2D array with shape (1, N) into a 1D array
3507# with shape (N,). This results in the following JAXpr:
3508#   reshape[ dimension=None new_sizes=(N,) ]
3509# For N > 1, we can match up the output array axis with the second axis of the
3510# input. But for N = 1, it is not clear how axes match up: all we know from the
3511# JAXpr is that we are reshaping from (1, 1) to (1,).
# In contrast, squeeze[ dimensions=(0,) ] is unambiguous.
3513
def squeeze(array: Array, dimensions: Tuple[int, ...]) -> Array:
  """Removes the listed size-1 axes from `array`.

  Negative axis indices are interpreted relative to the rank of `array`. The
  axes are canonicalized and sorted before binding. An empty `dimensions`
  returns `array` unchanged.
  """
  rank = np.ndim(array)
  canonical_dims = sorted(canonicalize_axis(axis, rank) for axis in dimensions)
  if len(canonical_dims) == 0:
    return array
  return squeeze_p.bind(array, dimensions=tuple(canonical_dims))
3521
3522def _squeeze_dtype_rule(operand, *, dimensions):
3523  return operand.dtype
3524
def _squeeze_shape_rule(operand, *, dimensions):
  """Shape rule for `squeeze`; validation lives in `_compute_squeeze_shape`."""
  operand_shape = np.shape(operand)
  return _compute_squeeze_shape(operand_shape, dimensions)
3527
3528def _compute_squeeze_shape(shape, dimensions):
3529  dims_set = set(dimensions)
3530  if len(dims_set) != len(dimensions):
3531    raise ValueError(f"dimensions are not unique: {dimensions}")
3532  if not all(0 <= d < len(shape) for d in dims_set):
3533    raise ValueError(f"dimensions outside range [0, ndim): {dimensions}")
3534  if any(shape[d] != 1 for d in dimensions):
3535    raise ValueError(
3536        "cannot select an axis to squeeze out which has size not equal to "
3537        f"one, got shape={shape} and dimensions={dimensions}")
3538  return tuple(s for i, s in enumerate(shape) if i not in dims_set)
3539
def _squeeze_translation_rule(c, arg, *, dimensions):
  """Lowers `squeeze` to an XLA Reshape onto the squeezed shape."""
  arg_dims = c.get_shape(arg).dimensions()
  squeezed_shape = _compute_squeeze_shape(arg_dims, dimensions)
  return xops.Reshape(arg, squeezed_shape)
3543
def _squeeze_transpose_rule(t, operand, *, dimensions):
  """Transpose of `squeeze`: re-insert the removed unit axes into `t`."""
  assert ad.is_undefined_primal(operand)
  t_operand = expand_dims(t, dimensions)
  return [t_operand]
3547
def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
  """Batching rule for `squeeze`.

  Moves the batch axis to position 0 and shifts every squeezed axis index up
  by one. The output batch axis is 0.
  """
  (operand,), (bdim,) = batched_args, batch_dims
  operand = batching.moveaxis(operand, bdim, 0)
  shifted_dims = tuple(np.add(1, dimensions))
  return squeeze(operand, dimensions=shifted_dims), 0
3554
# Register the `squeeze` primitive: shape/dtype rules, XLA lowering, a
# transpose rule registered as linear via `ad.deflinear2`, and batching.
# No masking rule is registered here.
squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
                               'squeeze', _squeeze_translation_rule)
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
3559
3560
def expand_dims(array: Array, dimensions: Tuple[int, ...]) -> Array:
  """Inserts size-1 axes into `array` at the positions in `dimensions`.

  Positions refer to the output array. They may be negative and are
  canonicalized against the output rank ``ndim(array) + len(dimensions)``.
  The result is produced with `broadcast_in_dim`, mapping the input axes onto
  the output positions that are not newly inserted.
  """
  out_rank = np.ndim(array) + len(dimensions)
  inserted = {canonicalize_axis(axis, out_rank) for axis in dimensions}
  out_shape = list(np.shape(array))
  for axis in sorted(inserted):
    out_shape.insert(axis, 1)
  kept_axes = [axis for axis in range(out_rank) if axis not in inserted]
  return broadcast_in_dim(array, out_shape, kept_axes)
3570
3571
# We have a nonstandard reshape impl so that we can be lazy about data movement.
def _reshape_impl(operand, *, new_sizes, dimensions):
  """Eager implementation of `reshape`.

  A non-permuting reshape of a device array that only inserts size-1 axes is
  expressed as a lazy broadcast over the existing device buffer. Every other
  case is dispatched to the compiled primitive.
  """
  if dimensions is None and xla.type_is_device_array(operand):
    bcast_dims = _is_singleton_reshape(np.shape(operand), new_sizes)
    if bcast_dims is not None:
      aval = ShapedArray(new_sizes, operand.dtype)
      base_expr = operand._lazy_expr
      if base_expr is None:
        base_expr = lazy.array(operand.shape)
      lazy_expr = lazy.broadcast(base_expr, new_sizes, bcast_dims)
      return xla._DeviceArray(aval, operand._device, lazy_expr,
                              operand.device_buffer)
  return xla.apply_primitive(reshape_p, operand, new_sizes=new_sizes,
                             dimensions=dimensions)
3586
3587def _is_singleton_reshape(old, new):
3588  # A singleton reshape is one where only singleton dimensions are added. We
3589  # want to detect them because they can be expressed as (lazy) broadcasts.
3590  old, new = iter(old), iter(new)
3591  d1, d2 = next(old, None), next(new, None)
3592  bcast_dims = []
3593  i = 0
3594  while True:
3595    if d1 is d2 is None:
3596      return bcast_dims
3597    elif d1 == d2:
3598      bcast_dims.append(i)
3599      i += 1
3600      d1, d2 = next(old, None), next(new, None)
3601    elif d2 == 1:
3602      i += 1
3603      d2 = next(new, None)
3604    else:
3605      return None
3606
3607def _reshape_shape_rule(operand, *, new_sizes, dimensions):
3608  if not np.all(np.greater_equal(new_sizes, 0)):
3609    msg = 'reshape new_sizes must all be positive, got {}.'
3610    raise TypeError(msg.format(new_sizes))
3611  if prod(np.shape(operand)) != prod(new_sizes):
3612    msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
3613    raise TypeError(msg.format(new_sizes, np.shape(operand)))
3614  if dimensions is not None:
3615    if set(dimensions) != set(range(np.ndim(operand))):
3616      msg = ('reshape dimensions must be a permutation of operand dimensions, '
3617             'got dimensions {} for shape {}.')
3618      raise TypeError(msg.format(dimensions, np.shape(operand)))
3619  return tuple(new_sizes)
3620
3621def _reshape_dtype_rule(operand, *, new_sizes, dimensions):
3622  return operand.dtype
3623
def _reshape_translation_rule(c, operand, *, new_sizes, dimensions):
  """Lowers `reshape` to XLA Reshape, passing the permutation when present."""
  if dimensions is not None:
    return xops.Reshape(operand, dimensions, new_sizes)
  return xops.Reshape(operand, new_sizes)
3629
def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions):
  """Transpose of `reshape`: reshape the cotangent back to the operand shape.

  If the forward reshape permuted the operand axes first, reshape to the
  permuted operand shape and then apply the inverse permutation.
  """
  assert ad.is_undefined_primal(operand)
  operand_shape = operand.aval.shape
  if dimensions is None:
    return [reshape(t, operand_shape)]
  permuted = reshape(t, np.take(operand_shape, dimensions))
  return [transpose(permuted, np.argsort(dimensions))]
3637
def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
  """Batching rule for `reshape`: the batch axis is moved to and kept at 0."""
  (operand,), (bdim,) = batched_args, batch_dims
  operand = batching.moveaxis(operand, bdim, 0)
  batch_shape = operand.shape[:1]
  if dimensions is not None:
    # The batch axis stays first; the original axes shift up by one.
    dimensions = (0,) + tuple(np.add(1, dimensions))
  return reshape(operand, batch_shape + new_sizes, dimensions), 0
3645
def _reshape_masking_rule(padded_args, logical_shapes, polymorphic_shapes,
                          new_sizes, dimensions):
  """Masking rule for `reshape`.

  The reshape is supported only when it keeps each polymorphic size together
  with the constant sizes that follow it. Otherwise padding would be
  fragmented across axes, and NotImplementedError is raised.
  """
  (operand,), (old_shape,) = padded_args, polymorphic_shapes

  def _is_poly(size):
    return type(size) is masking.Poly and not size.is_constant

  def _merge_const_sizes(shape):
    # Multiply each run of constant sizes into the polymorphic size before it.
    cuts = [axis for axis, size in enumerate(shape) if _is_poly(size)]
    starts, stops = [0] + cuts, cuts + [len(shape)]
    return [prod(shape[begin:end]) for begin, end in zip(starts, stops)]

  if _merge_const_sizes(old_shape) != _merge_const_sizes(new_sizes):
    raise NotImplementedError(
      "Reshape on padded dimensions causing fragmentation is not supported.")

  padded_sizes = masking.padded_shape_as_value(new_sizes)
  return reshape(operand, new_sizes=padded_sizes, dimensions=dimensions)
3663
# Register the `reshape` primitive. Besides the standard shape/dtype rules and
# lowering, it gets a custom eager impl (`_reshape_impl`), a transpose rule
# registered as linear via `ad.deflinear2`, and batching and masking rules.
reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule,
                               'reshape', _reshape_translation_rule)
reshape_p.def_impl(_reshape_impl)
ad.deflinear2(reshape_p, _reshape_transpose_rule)
batching.primitive_batchers[reshape_p] = _reshape_batch_rule
masking.masking_rules[reshape_p] = _reshape_masking_rule
3670
def _rev_shape_rule(operand, *, dimensions):
  """Validates the axes passed to `rev`; reversal never changes the shape.

  Raises:
    TypeError: if `dimensions` is not shape-like, contains repeats, or names
      an axis at or beyond the operand rank.
  """
  _check_shapelike('rev', 'dimensions', dimensions)
  if len(dimensions) != len(set(dimensions)):
    raise TypeError('rev dimensions must be unique, got {}.'.format(dimensions))
  if dimensions and _max(dimensions) >= operand.ndim:
    raise TypeError(
        ('rev dimensions must all be less than operand ndim, got dimensions '
         '{} for operand ndim {}.').format(dimensions, operand.ndim))
  return operand.shape
3681
def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
  """Batching rule for `rev`: shift reversed axes at/after the batch axis."""
  (operand,), (bdim,) = batched_args, batch_dims
  shifted = [axis + 1 if axis >= bdim else axis for axis in dimensions]
  return rev(operand, shifted), bdim
3687
# Register the `rev` primitive. Its transpose (registered as linear via
# `ad.deflinear2`) reverses the cotangent along the same axes.
rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
batching.primitive_batchers[rev_p] = _rev_batch_rule
3691
3692
def _transpose_impl(operand, *, permutation):
  """Eager `transpose`.

  Device arrays are permuted lazily by wrapping their buffer in a lazy
  transpose expression. Anything else is dispatched to the compiled
  primitive.
  """
  if not xla.type_is_device_array(operand):
    return xla.apply_primitive(transpose_p, operand, permutation=permutation)
  base_expr = operand._lazy_expr
  if base_expr is None:
    base_expr = lazy.array(operand.shape)
  lazy_expr = lazy.transpose(base_expr, permutation)
  aval = ShapedArray(lazy_expr.shape, operand.dtype)
  return xla._DeviceArray(aval, operand._device, lazy_expr, operand.device_buffer)
3703
3704def _transpose_shape_rule(operand, *, permutation):
3705  if not isinstance(permutation, (tuple, list, np.ndarray)):
3706    msg = "transpose permutation must be a tuple/list/ndarray, got {}."
3707    raise TypeError(msg.format(type(permutation)))
3708  if tuple(sorted(permutation)) != tuple(range(operand.ndim)):
3709    msg = ("transpose permutation isn't a permutation of operand dimensions, "
3710           "got permutation {} for operand shape {}.")
3711    raise TypeError(msg.format(permutation, operand.shape))
3712  return tuple(np.take(operand.shape, permutation))
3713
def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
  """Batching rule for `transpose`: the output batch axis is placed first."""
  (operand,), (bdim,) = batched_args, batch_dims
  shifted = tuple(axis + 1 if axis >= bdim else axis for axis in permutation)
  return transpose(operand, (bdim,) + shifted), 0
3719
def _transpose_masking_rule(padded_vals, logical_shapes, permutation):
  """Masking rule for `transpose`: permute the padded value as-is."""
  del logical_shapes  # Unused.
  return transpose(*padded_vals, permutation=permutation)
3722
# Register the `transpose` primitive with a lazy eager impl. Its transpose
# rule (registered as linear via `ad.deflinear2`) applies the inverse
# permutation, computed with `np.argsort`.
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
                                 'transpose')
transpose_p.def_impl(_transpose_impl)
ad.deflinear2(transpose_p,
              lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
masking.masking_rules[transpose_p] = _transpose_masking_rule
3730
3731
3732def _select_shape_rule(pred, on_true, on_false):
3733  if on_true.shape != on_false.shape:
3734    msg = "select on_true and on_false must have the same shape, got {} and {}."
3735    raise TypeError(msg.format(on_true.shape, on_false.shape))
3736  if pred.shape and pred.shape != on_true.shape:
3737    msg = ("select pred must be scalar or have the same shape as on_true and "
3738           "on_false, got pred shape {} for on_true and on_false of shape {}.")
3739    raise TypeError(msg.format(pred.shape, on_true.shape))
3740  return on_true.shape
3741
def _select_dtype_rule(pred, on_true, on_false):
  """Dtype rule for `select`: branches must agree and `pred` must be boolean."""
  _check_same_dtypes("select", False, on_true.dtype, on_false.dtype)
  if dtypes.issubdtype(pred.dtype, np.bool_):
    return on_true.dtype
  msg = "select pred must be boolean type, got {}."
  raise TypeError(msg.format(pred.dtype))
3748
def _select_transpose_rule(t, pred, on_true, on_false):
  """Transpose of `select` with respect to its branches.

  Each linear branch receives the cotangent where it was selected and zeros
  elsewhere. `pred` must be known (it is not differentiable).
  """
  assert not ad.is_undefined_primal(pred)
  true_is_linear = ad.is_undefined_primal(on_true)
  false_is_linear = ad.is_undefined_primal(on_false)
  if type(t) is ad_util.Zero:
    t_true = ad_util.Zero(on_true.aval) if true_is_linear else None
    t_false = ad_util.Zero(on_false.aval) if false_is_linear else None
  else:
    zeros = full_like(t, 0)
    t_true = select(pred, t, zeros) if true_is_linear else None
    t_false = select(pred, zeros, t) if false_is_linear else None
  return [None, t_true, t_false]
3760
def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
  """Batching rule for `select`.

  Returns ``(out, out_bdim)``. Fast paths avoid moving axes when the batch
  dims already line up. Otherwise every batched argument has its batch axis
  brought to position 0, and scalar-vs-array mismatches introduced by vmap
  are fixed with broadcasts; the result then has batch dim 0.
  """
  pred, on_true, on_false, = batched_args
  pred_bdim, ot_bdim, of_bdim = batch_dims
  # Batch size, read from the first argument that is actually batched.
  size = next(x.shape[i] for x, i in zip(batched_args, batch_dims)
              if i is not None)

  # avoid transposes and some broadcasts in special cases
  if pred_bdim == ot_bdim == of_bdim:
    if np.shape(pred) == np.shape(on_true):
      return select(pred, on_true, on_false), pred_bdim
    else:
      # vmapped function had a scalar pred with nonscalar args
      assert np.ndim(pred) == 1
      pred = broadcast_in_dim(pred, on_true.shape, [pred_bdim])
      return select(pred, on_true, on_false), pred_bdim
  elif np.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None:
    # pred is 0-d here, so it cannot carry a batch axis; only the two
    # branches need to agree on where their batch axis is.
    if ot_bdim == of_bdim:
      return select(pred, on_true, on_false), ot_bdim
    elif np.shape(on_true) == np.shape(on_false):
      on_false = batching.moveaxis(on_false, of_bdim, ot_bdim)
      return select(pred, on_true, on_false), ot_bdim
    # Otherwise fall through to the general path below.

  # General path: `batching.bdim_at_front` puts the batch axis at position 0
  # (presumably materializing it with `size` for unbatched arguments).
  pred = batching.bdim_at_front(pred, pred_bdim, size) if np.shape(pred) else pred
  if not np.shape(on_true) == np.shape(on_false) == ():
    on_true = batching.bdim_at_front(on_true, ot_bdim, size)
    on_false = batching.bdim_at_front(on_false, of_bdim, size)
  assert np.shape(on_true) == np.shape(on_false)
  if 0 < np.ndim(pred) < np.ndim(on_true):
    # vmapped function had a scalar pred with nonscalar args
    assert np.ndim(pred) == 1
    pred = broadcast_in_dim(pred, on_true.shape, [0])
  if np.ndim(pred) > np.ndim(on_true):
    # The branches are scalars (left untouched above) while pred is batched.
    assert np.ndim(on_true) == 0
    on_true = broadcast(on_true, pred.shape)
    on_false = broadcast(on_false, pred.shape)
  return select(pred, on_true, on_false), 0
3797
def _select_masking_rule(padded_vals, logical_shapes):
  """Masking rule for `select`: select directly on the padded values.

  All three padded inputs are required to share the same padded shape.
  """
  del logical_shapes  # Unused.
  padded_shapes = [masking.padded_shape_as_value(v.shape) for v in padded_vals]
  pred_shape, true_shape, false_shape = padded_shapes
  assert np.array_equal(pred_shape, true_shape)
  assert np.array_equal(pred_shape, false_shape)
  return select(*padded_vals)
3804
def _select_jvp(primals, tangents):
  """JVP rule for `select`; `pred` is treated as non-differentiable.

  A symbolic-zero branch tangent is instantiated as zeros shaped like the
  other branch's tangent. When both branch tangents are symbolic zeros (for
  example, `pred` depends on a traced value but both branches are constants),
  the output tangent is a symbolic zero as well. There is no concrete tangent
  to instantiate zeros from in that case.

  Returns:
    ``(primal_out, tangent_out)``.
  """
  pred, on_true, on_false = primals
  _, on_true_dot, on_false_dot = tangents
  out = select(pred, on_true, on_false)
  true_is_zero = type(on_true_dot) is ad_util.Zero
  false_is_zero = type(on_false_dot) is ad_util.Zero
  if true_is_zero and false_is_zero:
    out_dot = ad_util.Zero(on_true_dot.aval)
  elif true_is_zero:
    out_dot = select(pred, _zeros(on_false_dot), on_false_dot)
  elif false_is_zero:
    out_dot = select(pred, on_true_dot, _zeros(on_true_dot))
  else:
    out_dot = select(pred, on_true_dot, on_false_dot)
  return out, out_dot
3816
# Register the `select` primitive. It has an explicit JVP and a separate
# transpose rule (which requires `pred` to be a known value) instead of
# `ad.deflinear2`, plus batching and masking rules.
select_p = standard_primitive(_select_shape_rule, _select_dtype_rule, 'select')
ad.primitive_jvps[select_p] = _select_jvp
ad.primitive_transposes[select_p] = _select_transpose_rule
batching.primitive_batchers[select_p] = _select_batch_rule
masking.masking_rules[select_p] = _select_masking_rule
3822
3823
def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
  """Validates the arguments of `slice` and returns the output shape.

  Along each axis the output size is ``ceil((limit - start) / stride)``, with
  `strides` defaulting to all ones. The limit-vs-shape checks are skipped for
  polymorphic (masking) shapes or indices.

  Raises:
    TypeError: if the index or stride sequences have the wrong length, or if
      any start, limit or stride is out of range.
  """
  _check_shapelike("slice", "start_indices", start_indices)
  _check_shapelike("slice", "limit_indices", limit_indices)
  if operand.ndim != len(start_indices):
    msg = ("slice start_indices must have length equal to the number of "
           "dimensions of the operand, got indices {} for operand shape {}.")
    raise TypeError(msg.format(start_indices, operand.shape))
  if len(start_indices) != len(limit_indices):
    msg = ("slice limit_indices must have the same length as start_indices, "
           "got start_indices {} and limit_indices {}.")
    raise TypeError(msg.format(start_indices, limit_indices))
  if (not masking.is_polymorphic(limit_indices) and
      not masking.is_polymorphic(operand.shape) and
      not np.all(np.less_equal(limit_indices, operand.shape))):
    msg = ("slice limit_indices must be less than or equal to operand shape, "
           "got limit_indices {} for operand shape {}.")
    raise TypeError(msg.format(limit_indices, operand.shape))
  if not np.all(np.greater_equal(start_indices, 0)):
    msg = ("slice start_indices must be greater than or equal to zero, "
           "got start_indices of {}.")
    raise TypeError(msg.format(start_indices))
  if (not masking.is_polymorphic(limit_indices) and
      not np.all(np.greater_equal(limit_indices, start_indices))):
    msg = ("slice limit_indices must be greater than or equal to start_indices,"
           " got start_indices {} and limit_indices {}.")
    raise TypeError(msg.format(start_indices, limit_indices))
  if strides is None:
    strides = np.ones(operand.ndim, np.int32)
  else:
    _check_shapelike("slice", "strides", strides)
    if len(strides) != operand.ndim:
      msg = ("slice strides must have length equal to the number of dimensions "
             "of the operand, got strides {} for operand shape {}.")
      raise TypeError(msg.format(strides, operand.shape))
    if not np.all(np.greater(strides, 0)):
      msg = "slice strides must be positive, got {}"
      raise TypeError(msg.format(strides))

  diff = np.subtract(limit_indices, start_indices)
  # Not np.divmod since Poly.__rdivmod__ is ignored by NumPy, breaks poly stride
  return tuple(q + (r > 0) for q, r in map(divmod, diff, strides))
3865
def _slice_translation_rule(c, operand, *, start_indices, limit_indices,
                            strides):
  """Lowers `slice` to XLA Slice; missing/empty `strides` means unit strides."""
  if not strides:
    strides = [1] * len(start_indices)
  return xops.Slice(operand, start_indices, limit_indices, strides)
3870
def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
  """Transpose of `slice`: pad the cotangent back out to the operand shape.

  With unit strides, the low padding is `start_indices` and the high padding
  is the gap between `limit_indices` and the operand extent. With other
  strides, interior padding of ``stride - 1`` is added, and the high padding
  is measured from just past the last element actually read (or from
  `start_indices` when the output is empty along an axis).
  """
  assert ad.is_undefined_primal(operand)
  operand_shape = operand.aval.shape
  unit_strides = strides is None or np.all(np.equal(strides, 1))
  if unit_strides:
    pads = zip(start_indices, np.subtract(operand_shape, limit_indices),
               (0,) * len(start_indices))
  else:
    read_extent = np.where(
        np.array(t.shape) == 0, 0,
        np.add(1, np.multiply(np.subtract(t.shape, 1), strides)))
    real_limits = np.add(start_indices, read_extent)
    pads = safe_zip(start_indices, np.subtract(operand_shape, real_limits),
                    np.subtract(strides, 1))
  result = pad(t, _const(t, 0), pads)
  assert result.shape == operand_shape, (
    f"result.shape={result.shape} operand_shape={operand_shape}")
  return [result]
3888
3889
def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
                         limit_indices, strides):
  """Batching rule for `slice`: take the whole batch axis with unit stride."""
  (operand,), (bdim,) = batched_args, batch_dims

  def with_batch_entry(values, entry):
    # Copy `values` into a list with `entry` inserted at the batch position.
    extended = list(values)
    extended.insert(bdim, entry)
    return extended

  new_start_indices = with_batch_entry(start_indices, 0)
  new_limit_indices = with_batch_entry(limit_indices, operand.shape[bdim])
  new_strides = None if strides is None else with_batch_entry(strides, 1)
  out = slice(operand, new_start_indices, new_limit_indices, new_strides)
  return out, bdim
3909
def _slice_masking_rule(
    padded_vals, logical_shapes, start_indices, limit_indices, strides):
  """Masking rule for `slice`: evaluate polymorphic indices at padded sizes."""
  (operand,) = padded_vals
  as_value = masking.padded_shape_as_value
  if strides:
    strides = as_value(strides)
  else:
    strides = None
  return slice(operand,
               start_indices=as_value(start_indices),
               limit_indices=as_value(limit_indices),
               strides=strides)
3918
# Register the `slice` primitive: shape rule, input-dtype passthrough, XLA
# lowering, a transpose rule registered as linear via `ad.deflinear2`, and
# batching and masking rules.
slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice',
                             _slice_translation_rule)
ad.deflinear2(slice_p, _slice_transpose_rule)
batching.primitive_batchers[slice_p] = _slice_batching_rule
masking.masking_rules[slice_p] = _slice_masking_rule
3924
3925
3926def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes):
3927  if operand.ndim != len(start_indices):
3928    msg = ("dynamic_slice start_indices must have length equal to the number "
3929           "of dimensions of the operand, got indices {} for operand shape {}.")
3930    raise TypeError(msg.format(start_indices, operand.shape))
3931  if len(start_indices) != len(slice_sizes):
3932    msg = ("dynamic_slice slice_sizes must have the same length as "
3933           "start_indices, got start_inidices length {} and slice_sizes {}.")
3934    raise TypeError(msg.format(len(start_indices), slice_sizes))
3935  if not np.all(np.less_equal(slice_sizes, operand.shape)):
3936    msg = ("slice slice_sizes must be less than or equal to operand shape, "
3937           "got slice_sizes {} for operand shape {}.")
3938    raise TypeError(msg.format(slice_sizes, operand.shape))
3939  if not np.all(np.greater_equal(slice_sizes, 0)):
3940    msg = ("slice slice_sizes must be greater than or equal to zero, "
3941           "got slice_sizes of {}.")
3942    raise TypeError(msg.format(slice_sizes))
3943  return tuple(slice_sizes)
3944
3945def _dynamic_slice_dtype_rule(operand, *start_indices, slice_sizes):
3946  if any(i.dtype != start_indices[0].dtype or
3947         not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices):
3948    msg = ("index arguments to dynamic_slice must be integers of the same "
3949           "type, got: {}")
3950    raise TypeError(msg.format(", ".join(i.dtype.name for i in start_indices)))
3951  return operand.dtype
3952
def _dynamic_slice_translation_rule(c, operand, *start_indices, slice_sizes):
  """Lowers `dynamic_slice` directly to XLA DynamicSlice."""
  del c  # Unused.
  return xops.DynamicSlice(operand, start_indices, slice_sizes)
3955
def _dynamic_slice_jvp(primals, tangents, *, slice_sizes):
  """JVP of `dynamic_slice`: slice the operand tangent at the same indices.

  A symbolic-zero operand tangent is passed through unchanged.
  """
  operand, start_indices = primals[0], primals[1:]
  operand_dot = tangents[0]
  if type(operand_dot) is ad_util.Zero:
    tangent_out = operand_dot
  else:
    tangent_out = dynamic_slice(operand_dot, start_indices, slice_sizes)
  primal_out = dynamic_slice(operand, start_indices, slice_sizes)
  return primal_out, tangent_out
3961
def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
  """Transpose of `dynamic_slice`: write `t` into a zero operand-shaped array.

  Start indices are required to be known values and get no cotangent.
  """
  assert ad.is_undefined_primal(operand)
  assert all(not ad.is_undefined_primal(s) for s in start_indices)
  index_cts = [None] * len(start_indices)
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(operand.aval)] + index_cts
  operand_aval = operand.aval
  if config.omnistaging_enabled:
    zeros = full(operand_aval.shape, 0, operand_aval.dtype)
  else:
    zeros = full(operand_aval.shape, tie_in(t, _zero(t)))
  return [dynamic_update_slice(zeros, t, start_indices)] + index_cts
3975
3976def _batch_dynamic_slice_indices(indices, bdims):
3977  if len(indices) == 0:
3978    return np.array([], 'int32'), None
3979  size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None), -1)
3980  if size < 0:
3981    return concatenate([broadcast(i, (1,)) for i in indices], 0), None
3982  indices = concatenate(
3983    [broadcast_in_dim(x, (size, 1),
3984                      broadcast_dimensions=((0,) if i is not None else ()))
3985     for x, i in zip(indices, bdims)],
3986    dimension=1)
3987  return indices, 0
3988
def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
  """Batching rule for `dynamic_slice`, delegated to gather's batching rule.

  A dynamic slice is a gather in which every (unbatched) operand axis is an
  offset dimension, no axes are collapsed, and the stacked start indices map
  onto the operand axes in order.
  """
  # TODO(phawkins): consider removing dynamic_slice entirely and using gather
  # always.
  operand, *start_indices = batched_args
  operand_bd, *start_idx_bds = batch_dims
  if operand_bd is batching.not_mapped:
    unbatched_rank = len(operand.shape)
  else:
    unbatched_rank = len(np.delete(operand.shape, operand_bd))
  dims = tuple(range(unbatched_rank))
  dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(),
                                 start_index_map=dims)
  index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds)
  return _gather_batching_rule(
    [operand, index], [operand_bd, index_bdim], dimension_numbers=dnums,
    slice_sizes=slice_sizes)
4005
4006
# Register the `dynamic_slice` primitive: shape/dtype rules, XLA lowering, an
# explicit JVP and transpose (the start-index operands get no cotangent), and
# batching via gather.
dynamic_slice_p = standard_primitive(
    _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
    _dynamic_slice_translation_rule)
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp  # TODO
ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
4013
4014
4015def _dynamic_update_slice_shape_rule(operand, update, *start_indices):
4016  if operand.ndim != update.ndim:
4017    msg = ("dynamic_update_slice update must have the same rank as operand, "
4018           "got update shape {} for operand shape {}.")
4019    raise TypeError(msg.format(update.shape, operand.shape))
4020  if operand.ndim != len(start_indices):
4021    msg = ("dynamic_update_slice start_indices must have length equal to the "
4022           "rank of operand, got indices {} for operand shape {}.")
4023    raise TypeError(msg.format(start_indices, operand.shape))
4024  if not np.all(np.less_equal(update.shape, operand.shape)):
4025    msg = ("dynamic_update_slice update shape must be smaller than operand "
4026           "shape, got update shape {} for operand shape {}.")
4027    raise TypeError(msg.format(update.shape, operand.shape))
4028  return operand.shape
4029
def _dynamic_update_slice_dtype_rule(operand, update, *start_indices):
  """Dtype rule for `dynamic_update_slice`.

  Requires `update` to match the operand dtype and all start indices to share
  one integer dtype. The output keeps the operand's dtype.
  """
  _check_same_dtypes("dynamic_update_slice", False, operand.dtype, update.dtype)
  index_dtypes = [idx.dtype for idx in start_indices]
  if index_dtypes:
    first = index_dtypes[0]
    well_typed = all(dt == first and dtypes.issubdtype(dt, np.integer)
                     for dt in index_dtypes)
    if not well_typed:
      msg = ("index arguments to dynamic_update_slice must be integers of the "
             "same type, got {}")
      raise TypeError(msg.format(", ".join(dt.name for dt in index_dtypes)))
  return operand.dtype
4038
def _dynamic_update_slice_jvp(primals, tangents):
  """JVP of `dynamic_update_slice`, linear in (operand, update).

  If both tangents are symbolic zeros the output tangent is a symbolic zero.
  Otherwise both are instantiated and combined with the same update.
  """
  (operand, update), start_indices = primals[:2], primals[2:]
  g_operand, g_update = tangents[:2]
  val_out = dynamic_update_slice(operand, update, start_indices)
  if type(g_operand) is ad_util.Zero and type(g_update) is ad_util.Zero:
    return val_out, ad_util.Zero.from_value(val_out)
  tangent_out = dynamic_update_slice(ad.instantiate_zeros(g_operand),
                                     ad.instantiate_zeros(g_update),
                                     start_indices)
  return val_out, tangent_out
4051
def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices):
  """Transpose of `dynamic_update_slice`.

  The operand's cotangent is `t` with the updated window zeroed out. The
  update's cotangent is that window sliced out of `t`. Start indices must be
  known values and get no cotangent.
  """
  assert all(not ad.is_undefined_primal(x) for x in start_indices)
  operand_is_linear = ad.is_undefined_primal(operand)
  update_is_linear = ad.is_undefined_primal(update)
  update_shape = update.aval.shape if update_is_linear else update.shape
  if type(t) is ad_util.Zero:
    operand_t = ad_util.Zero(operand.aval) if operand_is_linear else None
    update_t = ad_util.Zero(update.aval) if update_is_linear else None
  else:
    zeros = _zeros(t, shape=update_shape)
    operand_t = (dynamic_update_slice(t, zeros, start_indices)
                 if operand_is_linear else None)
    update_t = (dynamic_slice(t, start_indices, update_shape)
                if update_is_linear else None)
  return [operand_t, update_t] + [None] * len(start_indices)
4068
def _dynamic_update_slice_translation_rule(c, operand, update, *start_indices):
  """Lowers `dynamic_update_slice` directly to XLA DynamicUpdateSlice."""
  del c  # Unused.
  return xops.DynamicUpdateSlice(operand, update, start_indices)
4071
def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
  """Batching rule for `dynamic_update_slice`, via scatter's batching rule.

  A dynamic update slice is a scatter in which every (unbatched) update axis
  is a window axis, no window axes are inserted, and the stacked start
  indices map onto the operand axes in order. The indices are declared
  sorted and unique.
  """
  # TODO(phawkins): consider removing dynamic_update_slice entirely and using
  # scatter always.
  operand, update, *start_idx = batched_args
  operand_bd, update_bd, *start_idx_bd = batch_dims
  if update_bd is batching.not_mapped:
    update_rank = len(np.shape(update))
  else:
    update_rank = len(np.delete(np.shape(update), update_bd))
  dims = tuple(range(update_rank))
  dnums = ScatterDimensionNumbers(update_window_dims=dims,
                                  inserted_window_dims=(),
                                  scatter_dims_to_operand_dims=dims)
  index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd)
  return _scatter_batching_rule(
    scatter, (operand, index, update), (operand_bd, index_bdim, update_bd),
    update_jaxpr=None, update_consts=None, dimension_numbers=dnums,
    indices_are_sorted=True, unique_indices=True)
4090
4091
# Register the `dynamic_update_slice` primitive: shape/dtype rules, XLA
# lowering, explicit JVP and transpose rules, and batching via scatter.
dynamic_update_slice_p = standard_primitive(
    _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule,
    'dynamic_update_slice', _dynamic_update_slice_translation_rule)
ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp
ad.primitive_transposes[dynamic_update_slice_p] = \
    _dynamic_update_slice_transpose_rule
batching.primitive_batchers[dynamic_update_slice_p] = \
    _dynamic_update_slice_batching_rule
4100
4101
def _gather_dimensions_proto(indices_shape, dimension_numbers):
  """Converts `GatherDimensionNumbers` into XLA's proto message.

  The index vector dimension is always the last axis of the indices, so
  `indices_shape` must have rank at least 1.
  """
  assert type(dimension_numbers) is GatherDimensionNumbers
  proto = xla_client.GatherDimensionNumbers()
  for field in ("offset_dims", "collapsed_slice_dims", "start_index_map"):
    getattr(proto, field).extend(getattr(dimension_numbers, field))
  indices_rank = indices_shape.rank()
  assert indices_rank > 0
  proto.index_vector_dim = indices_rank - 1
  return proto
4111
4112def _gather_dtype_rule(operand, start_indices, **kwargs):
4113  if not dtypes.issubdtype(start_indices.dtype, np.integer):
4114    raise ValueError("start_indices must have an integer type")
4115  return dtypes.canonicalize_dtype(operand.dtype)
4116
4117_rank = lambda arr: len(arr.shape)
4118
4119def _is_sorted(dims, op_name, name):
4120  for i in range(1, len(dims)):
4121    if dims[i] < dims[i - 1]:
4122      raise TypeError(f"{name} in {op_name} op must be sorted; got {dims}")
4123
4124def _sorted_dims_in_range(dims, rank, op_name, name):
4125  if len(dims) == 0:
4126    return
4127  invalid_dim = None
4128  if dims[0] < 0:
4129    invalid_dim = dims[0]
4130  elif dims[-1] >= rank:
4131    invalid_dim = dims[-1]
4132  if invalid_dim:
4133    raise TypeError(f"Invalid {name} set in {op_name} op; valid range is "
4134                    f"[0, {rank}); got: {invalid_dim}.")
4135
4136def _no_duplicate_dims(dims, op_name, name):
4137  if len(set(dims)) != len(dims):
4138    raise TypeError(f"{name} in {op_name} op must not repeat; got: {dims}.")
4139
def _gather_shape_rule(operand, start_indices, *, dimension_numbers,
                       slice_sizes):
  """Validates the well-formedness of the arguments to Gather.

  The code implements the checks based on the detailed operation semantics of
  XLA's `Gather <https://www.tensorflow.org/xla/operation_semantics#gather>`_
  operator and following the outline of the implementation of
  ShapeInference::InferGatherShape in TensorFlow.

  Args:
    operand: abstract value of the array being gathered from.
    start_indices: abstract value of the index array. Its last dimension is
      always treated as the index vector dimension.
    dimension_numbers: a ``GatherDimensionNumbers`` providing ``offset_dims``,
      ``collapsed_slice_dims`` and ``start_index_map``.
    slice_sizes: one slice size per dimension of ``operand``.

  Returns:
    The output shape, a tuple of length
    ``len(offset_dims) + rank(start_indices) - 1``. Positions listed in
    ``offset_dims`` take the non-collapsed slice sizes in order (these may be
    NumPy integers, since they pass through ``np.delete``). The remaining
    positions take the dimensions of ``start_indices`` other than the index
    vector dimension, in order.

  Raises:
    TypeError: if any check fails. When several checks fail, the order of the
      checks below determines which error is reported.
  """

  offset_dims = dimension_numbers.offset_dims
  collapsed_slice_dims = dimension_numbers.collapsed_slice_dims
  start_index_map = dimension_numbers.start_index_map

  # Note: in JAX, index_vector_dim is always computed as below, cf. the
  # documentation of the GatherDimensionNumbers class.
  index_vector_dim = _rank(start_indices) - 1

  # This case should never happen in JAX, due to the implicit construction of
  # index_vector_dim, but is included for completeness.
  # NOTE(review): the `index_vector_dim < 0` half of this condition is
  # reachable when start_indices is a scalar (rank 0).
  if _rank(start_indices) < index_vector_dim or index_vector_dim < 0:
    raise TypeError(f"Gather index leaf dimension must be within [0, rank("
                    f"start_indices) + 1). rank(start_indices) is "
                    f"{_rank(start_indices)} and gather index leaf dimension "
                    f"is {index_vector_dim}.")

  expanded_start_indices_shape = list(start_indices.shape)

  # This case should never happen in JAX, due to the implicit construction of
  # index_vector_dim, but is included for completeness.
  if len(expanded_start_indices_shape) == index_vector_dim:
    expanded_start_indices_shape.append(1)

  # Start ValidateGatherDimensions
  # In the error messages output by XLA, "offset_dims" is called "Output window
  # dimensions" in error messages. For consistency's sake, our error messages
  # stick to "offset_dims".
  _is_sorted(offset_dims, "gather", "offset_dims")
  _no_duplicate_dims(offset_dims, "gather", "offset_dims")

  output_offset_dim_count = len(offset_dims)
  # The output has one dimension per offset dim plus one per non-index-vector
  # dimension of start_indices.
  output_shape_rank = len(offset_dims) + _rank(start_indices) - 1

  # Every offset dim must name a dimension of the output.
  for i in range(output_offset_dim_count):
    offset_dim = offset_dims[i]
    if offset_dim < 0 or offset_dim >= output_shape_rank:
      raise TypeError(f"Offset dimension {i} in gather op is out of bounds; "
                      f"got {offset_dim}, but should have been in "
                      f"[0, {output_shape_rank})")

  # Each component of an index vector needs exactly one start_index_map entry.
  if len(start_index_map) != start_indices.shape[index_vector_dim]:
    raise TypeError(f"Gather op has {len(start_index_map)} elements in "
                    f"start_index_map and the bound of dimension "
                    f"index_vector_dim={index_vector_dim} of start_indices is "
                    f"{start_indices.shape[index_vector_dim]}. These two "
                    f"numbers must be equal.")

  # start_index_map entries must be valid operand dimensions.
  for i in range(len(start_index_map)):
    operand_dim_for_start_index_i = start_index_map[i]
    if (operand_dim_for_start_index_i < 0 or
        operand_dim_for_start_index_i >= _rank(operand)):
      raise TypeError(f"Invalid start_index_map; domain is "
                      f"[0, {_rank(operand)}), got: "
                      f"{i}->{operand_dim_for_start_index_i}.")

  _no_duplicate_dims(start_index_map, "gather", "start_index_map")

  # _is_sorted and _sorted_dims_in_range are checked in the opposite order
  # compared to the XLA implementation. In cases when the input is not sorted
  # AND there are problematic collapsed_slice_dims, the error message will thus
  # be different.
  _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims")
  _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather",
                        "collapsed_slice_dims")
  _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims")
  # End ValidateGatherDimensions

  # Slice-size checks: one size per operand dim, each within the operand's
  # bound, and every operand dim is either an offset dim or collapsed.
  if _rank(operand) != len(slice_sizes):
    raise TypeError(f"Gather op must have one slice size for every input "
                    f"dimension; got: len(slice_sizes)={len(slice_sizes)}, "
                    f"input_shape.rank={_rank(operand)}")

  if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims):
    raise TypeError(f"All components of the offset index in a gather op must "
                    f"either be a offset dimension or explicitly collapsed; "
                    f"got len(slice_sizes)={len(slice_sizes)}, "
                    f"output_slice_sizes={offset_dims}, collapsed_slice_dims="
                    f"{collapsed_slice_dims}.")

  for i in range(len(slice_sizes)):
    slice_size = slice_sizes[i]
    corresponding_input_size = operand.shape[i]

    if slice_size < 0 or slice_size > corresponding_input_size:
      raise TypeError(f"Slice size at index {i} in gather op is out of range, "
                      f"must be within [0, {corresponding_input_size + 1}), "
                      f"got {slice_size}.")

  # Collapsed dims produce no output dimension, so their slice must be 0 or 1.
  for i in range(len(collapsed_slice_dims)):
    bound = slice_sizes[collapsed_slice_dims[i]]
    if bound > 1:
      raise TypeError(f"Gather op can only collapse slice dims with bound 1 "
                      f"or 0, but bound is {bound} for index "
                      f"{collapsed_slice_dims[i]} at position {i}.")

  # The remaining dims of start_indices (all but the index vector dim) fill
  # the non-offset positions of the output, in order.
  expanded_start_indices_shape.pop(index_vector_dim)
  start_indices_shape = iter(expanded_start_indices_shape)

  # Non-collapsed slice sizes fill the offset_dims positions, in order.
  slice_sizes = iter(np.delete(slice_sizes, collapsed_slice_dims))
  return tuple(next(slice_sizes) if i in offset_dims
               else next(start_indices_shape) for i in range(output_shape_rank))
4251
def _gather_translation_rule(c, operand, start_indices, *, dimension_numbers,
                             slice_sizes):
  """Lowers gather_p to an XLA Gather op on the computation builder ``c``.

  Indices are always declared unsorted to XLA.
  """
  dnums_proto = _gather_dimensions_proto(c.get_shape(start_indices),
                                         dimension_numbers)
  return xops.Gather(operand, start_indices, dnums_proto, slice_sizes,
                     indices_are_sorted=False)
4259
def _gather_jvp_rule(g, operand, start_indices, *, dimension_numbers,
                     slice_sizes):
  """JVP of gather w.r.t. its operand: gather the tangent with the same indices."""
  del operand  # The tangent does not depend on the primal operand values.
  return gather(g, start_indices, dimension_numbers, slice_sizes)
4263
def _gather_transpose_rule(t, operand, start_indices, *, dimension_numbers,
                           slice_sizes):
  """Transposes gather w.r.t. its operand.

  The cotangent ``t`` is scatter-added into a zero array shaped like the
  operand, using the scatter dimension numbers that mirror the gather's.
  Returns ``[operand_cotangent, None]``; ``start_indices`` gets no cotangent.
  """
  assert ad.is_undefined_primal(operand)
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(operand.aval), None]

  fill_value = _zero(t)
  if not config.omnistaging_enabled:
    fill_value = tie_in(t, fill_value)
  zeros = full(operand.aval.shape, fill_value)

  mirrored_dnums = ScatterDimensionNumbers(
      update_window_dims=dimension_numbers.offset_dims,
      inserted_window_dims=dimension_numbers.collapsed_slice_dims,
      scatter_dims_to_operand_dims=dimension_numbers.start_index_map)
  operand_t = scatter_add(zeros, start_indices, t, mirrored_dnums,
                          indices_are_sorted=False, unique_indices=False)
  return [operand_t, None]
4283
def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
                          slice_sizes):
  """Batching rule for gather_p; the batched result has its batch dim at 0.

  Three cases are handled:
    * only ``operand`` is batched: the batch dimension becomes an extra
      full-size slice dimension that appears in the output as offset dim 0;
    * only ``start_indices`` is batched: it is just one more index batch
      dimension, so only ``offset_dims`` are renumbered;
    * both are batched: an iota over the batch is prepended to every index
      vector so each batch element gathers from its own operand slice.
  """
  operand, start_indices = batched_args
  operand_bdim, start_indices_bdim = batch_dims

  def shifted(dims):
    # Renumber dimensions to make room for a new leading batch dimension.
    return tuple(np.add(1, dims))

  if start_indices_bdim is None and operand_bdim is not None:
    operand = batching.moveaxis(operand, operand_bdim, 0)
    dnums = GatherDimensionNumbers(
        offset_dims=(0,) + shifted(dimension_numbers.offset_dims),
        collapsed_slice_dims=shifted(dimension_numbers.collapsed_slice_dims),
        start_index_map=shifted(dimension_numbers.start_index_map))
    batched_slice_sizes = (operand.shape[0],) + slice_sizes
    return gather(operand, start_indices, dimension_numbers=dnums,
                  slice_sizes=batched_slice_sizes), 0

  if operand_bdim is None and start_indices_bdim is not None:
    start_indices = batching.moveaxis(start_indices, start_indices_bdim, 0)
    dnums = GatherDimensionNumbers(
        offset_dims=shifted(dimension_numbers.offset_dims),
        collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
        start_index_map=dimension_numbers.start_index_map)
    return gather(operand, start_indices, dimension_numbers=dnums,
                  slice_sizes=slice_sizes), 0

  # Both arguments are batched: move both batch dimensions to the front.
  operand = batching.moveaxis(operand, operand_bdim, 0)
  start_indices = batching.moveaxis(start_indices, start_indices_bdim, 0)

  # Example: user code had start_indices shape (3, 4, 5), and we have to deal
  # with start_indices shape (7, 3, 4, 5). We transform that to a
  # start_indices of shape (7, 3, 4, 6) where we concatenated an iota that
  # counts along our batch dimension to the front of the ndindex.
  counts_shape = list(start_indices.shape)
  counts_shape[-1] = 1
  batch_counts = broadcasted_iota(start_indices.dtype, tuple(counts_shape), 0)
  start_indices = concatenate([batch_counts, start_indices],
                              len(counts_shape) - 1)

  # The batch dim is sliced with size 1 and collapsed. _min(..., 1) keeps the
  # slice size at 0 for an empty batch so it never exceeds the dimension.
  batched_slice_sizes = (_min(operand.shape[0], 1),) + slice_sizes
  dnums = GatherDimensionNumbers(
      offset_dims=shifted(dimension_numbers.offset_dims),
      collapsed_slice_dims=(0,) + shifted(
          dimension_numbers.collapsed_slice_dims),
      start_index_map=(0,) + shifted(dimension_numbers.start_index_map))
  return gather(operand, start_indices, dimension_numbers=dnums,
                slice_sizes=batched_slice_sizes), 0
4337
gather_p = standard_primitive(
    _gather_shape_rule, _gather_dtype_rule, 'gather', _gather_translation_rule)
# Only the operand has a JVP; the integer start_indices get none.
ad.defjvp(gather_p, _gather_jvp_rule, None)

ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
4345
4346
def _scatter_dimensions_proto(indices_shape, dimension_numbers):
  """Builds an XLA ScatterDimensionNumbers proto from JAX dimension numbers.

  The index vector dimension is always the last dimension of the indices,
  which must therefore have rank >= 1.
  """
  assert type(dimension_numbers) is ScatterDimensionNumbers
  proto = xla_client.ScatterDimensionNumbers()
  for field in ("update_window_dims", "inserted_window_dims",
                "scatter_dims_to_operand_dims"):
    getattr(proto, field).extend(getattr(dimension_numbers, field))
  indices_rank = indices_shape.rank()
  assert indices_rank > 0
  proto.index_vector_dim = indices_rank - 1
  return proto
4357
4358def _scatter_dtype_rule(operand, scatter_indices, updates, **kwargs):
4359  if not dtypes.issubdtype(scatter_indices.dtype, np.integer):
4360    raise ValueError("scatter_indices must have an integer type")
4361  _check_same_dtypes("scatter", False, operand.dtype, updates.dtype)
4362  return dtypes.canonicalize_dtype(operand.dtype)
4363
def _scatter_shape_rule(operand, scatter_indices, updates, *, update_jaxpr,
                        update_consts, dimension_numbers, indices_are_sorted,
                        unique_indices):
  """Validates the well-formedness of the ``dimension_numbers`` argument to
  Scatter.

  The code implements the checks based on the detailed operation semantics of
  XLA's `Scatter <https://www.tensorflow.org/xla/operation_semantics#scatter>`_
  operator and following the outline of the implementation of
  ShapeInference::InferScatterShape in TensorFlow.

  Args:
    operand: abstract value of the array being scattered into.
    scatter_indices: abstract value of the index array. Its last dimension is
      always treated as the index vector dimension.
    updates: abstract value of the update array.
    dimension_numbers: a ``ScatterDimensionNumbers`` describing the scatter.
    update_jaxpr, update_consts, indices_are_sorted, unique_indices: part of
      the primitive's parameters; not read by this shape check.

  Returns:
    ``operand.shape``; scatter never changes the operand's shape.

  Raises:
    TypeError: if any check fails. When several checks fail, the order of the
      checks below determines which error is reported.
  """

  update_window_dims = dimension_numbers.update_window_dims
  inserted_window_dims = dimension_numbers.inserted_window_dims
  scatter_dims_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims
  # Note: in JAX, index_vector_dim is always computed as below, cf. the
  # documentation of the ScatterDimensionNumbers class.
  index_vector_dim = _rank(scatter_indices) - 1

  # This case should never happen in JAX, due to the implicit construction of
  # index_vector_dim, but is included for completeness.
  if _rank(scatter_indices) < index_vector_dim or index_vector_dim < 0:
    raise TypeError(f"Scatter index leaf dimension must be within [0, "
                    f"rank(scatter_indices) + 1). rank(scatter_indices) is "
                    f"{_rank(scatter_indices)} and scatter index leaf "
                    f"dimension is {index_vector_dim}.")

  expanded_scatter_indices_shape = list(scatter_indices.shape)
  # This case should never happen in JAX, due to the implicit construction of
  # index_vector_dim, but is included for completeness.
  if len(expanded_scatter_indices_shape) == index_vector_dim:
    expanded_scatter_indices_shape.append(1)

  expected_updates_rank = (len(expanded_scatter_indices_shape) - 1 +
                           len(update_window_dims))

  if _rank(updates) != expected_updates_rank:
    raise TypeError(f"Updates tensor must be of rank {expected_updates_rank}; "
                    f"got {_rank(updates)}.")

  # Validate update_window_dims
  _is_sorted(update_window_dims, "scatter", "update_window_dims")
  _no_duplicate_dims(update_window_dims, "scatter", "update_window_dims")
  _sorted_dims_in_range(update_window_dims, _rank(updates), "scatter",
                        "update_window_dims")

  # Validate inserted_window_dims
  _is_sorted(inserted_window_dims, "scatter", "inserted_window_dims")
  _no_duplicate_dims(inserted_window_dims, "scatter", "inserted_window_dims")
  _sorted_dims_in_range(inserted_window_dims, _rank(operand), "scatter",
                        "inserted_window_dims")

  # Validate window_size
  window_size = len(update_window_dims) + len(inserted_window_dims)
  if _rank(operand) != window_size:
    raise TypeError(f"Scatter op has window of size {window_size}; doesn't "
                    f"match operand of rank {_rank(operand)}.")

  # Validate scatter_dims_to_operand_dims
  if (len(scatter_dims_to_operand_dims) !=
      scatter_indices.shape[index_vector_dim]):
    raise TypeError(f"Scatter op has {len(scatter_dims_to_operand_dims)} "
                    f"elements in scatter_dims_to_operand_dims and the bound "
                    f"of dimension index_vector_dim={index_vector_dim} of "
                    f"scatter_indices is "
                    f"{scatter_indices.shape[index_vector_dim]}. These two "
                    f"numbers must be equal")

  for i, dim in enumerate(scatter_dims_to_operand_dims):
    if dim < 0 or dim >= _rank(operand):
      raise TypeError(f"Invalid scatter_dims_to_operand_dims mapping; domain "
                      f"is [0, {_rank(operand)}), got: {i}->{dim}.")

  _no_duplicate_dims(scatter_dims_to_operand_dims, "scatter",
                     "scatter_dims_to_operand_dims")

  # Build the membership sets once, rather than re-creating them for every
  # element inside the comprehensions below.
  inserted_window_dims_set = set(inserted_window_dims)
  update_window_dims_set = set(update_window_dims)

  # Operand dims that are not inserted pair up, in order, with
  # update_window_dims; they bound the corresponding update window sizes.
  max_update_slice_sizes = [size for dim, size in enumerate(operand.shape)
                            if dim not in inserted_window_dims_set]

  for i, update_window_dim in enumerate(update_window_dims):
    if updates.shape[update_window_dim] > max_update_slice_sizes[i]:
      raise TypeError(f"Bounds of the window dimensions of updates must not "
                      f"exceed the bounds of the corresponding dimensions of "
                      f"operand. For dimension {update_window_dim}, updates "
                      f"bound is {updates.shape[update_window_dim]}, operand "
                      f"bound is {max_update_slice_sizes[i]}.")

  update_scatter_dims = [dim for dim in range(_rank(updates))
                         if dim not in update_window_dims_set]

  # The non-window dims of updates must match scatter_indices' dims, skipping
  # the index vector dimension.
  scatter_dims_seen = 0
  for i in update_scatter_dims:
    if scatter_dims_seen == index_vector_dim:
      scatter_dims_seen += 1
    if updates.shape[i] != expanded_scatter_indices_shape[scatter_dims_seen]:
      raise TypeError(f"Bounds of the scatter dimensions of updates must be "
                      f"the same as the bounds of the corresponding dimensions "
                      f"of scatter indices. For scatter dimension {i}, updates "
                      f"bound is {updates.shape[i]}, scatter_indices bound is "
                      f"{expanded_scatter_indices_shape[scatter_dims_seen]}.")
    scatter_dims_seen += 1

  return operand.shape
4469
def _scatter_translation_rule(c, operand, scatter_indices, updates, *,
                              update_jaxpr, update_consts, dimension_numbers,
                              indices_are_sorted, unique_indices):
  """Lowers a scatter primitive to an XLA Scatter op.

  The combining computation is built from ``update_jaxpr``/``update_consts``
  via ``_reduction_computation``, with a zero of the operand dtype as its
  initial value.
  """
  operand_dtype = c.get_shape(operand).numpy_dtype()
  zero = xb.constant(c, np.array(0, operand_dtype))
  combiner = _reduction_computation(c, update_jaxpr, update_consts, zero)
  dnums_proto = _scatter_dimensions_proto(c.get_shape(scatter_indices),
                                          dimension_numbers)
  return xops.Scatter(operand, scatter_indices, updates, combiner, dnums_proto,
                      indices_are_sorted, unique_indices)
4481
def _scatter_add_translation_rule(
    c, operand, scatter_indices, updates, *, update_jaxpr, update_consts,
    dimension_numbers, indices_are_sorted, unique_indices,
    expand_complex128=False):
  """Lowers scatter-add to an XLA Scatter op with an addition combiner.

  When ``expand_complex128`` is set and the operand is complex128, the scatter
  is split into two float64 scatters, one over the real parts and one over the
  imaginary parts, which are then recombined with ``xops.Complex``.
  """
  operand_dtype = c.get_shape(operand).numpy_dtype()
  dnums_proto = _scatter_dimensions_proto(c.get_shape(scatter_indices),
                                          dimension_numbers)

  def add_computation(elem_dtype):
    # Scalar (x, y) -> x + y computation used as the scatter combiner.
    builder = xla_bridge.make_computation_builder("scatter_add_reducer")
    scalar_shape = xc.Shape.array_shape(np.dtype(elem_dtype), ())
    lhs = xb.parameter(builder, 0, scalar_shape)
    rhs = xb.parameter(builder, 1, scalar_shape)
    return builder.build(xops.Add(lhs, rhs))

  def scatter_with(computation, target, values):
    return xops.Scatter(target, scatter_indices, values, computation,
                        dnums_proto, indices_are_sorted, unique_indices)

  if not (expand_complex128 and operand_dtype == np.complex128):
    return scatter_with(add_computation(operand_dtype), operand, updates)

  f64_add = add_computation(np.float64)
  real_part = scatter_with(f64_add, xops.Real(operand), xops.Real(updates))
  imag_part = scatter_with(f64_add, xops.Imag(operand), xops.Imag(updates))
  return xops.Complex(real_part, imag_part)
4510
def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts,
                     dimension_numbers, indices_are_sorted, unique_indices):
  """JVP for scatter-add.

  The tangent output is the scatter-add of the operand and updates tangents
  (symbolic zeros are instantiated first). If both tangents are symbolic
  zeros, the output tangent is a symbolic zero too.
  """
  operand, scatter_indices, updates = primals
  g_operand, _, g_updates = tangents
  params = dict(update_jaxpr=update_jaxpr, update_consts=update_consts,
                dimension_numbers=dimension_numbers,
                indices_are_sorted=indices_are_sorted,
                unique_indices=unique_indices)
  primal_out = scatter_add_p.bind(operand, scatter_indices, updates, **params)

  both_zero = (type(g_operand) is ad_util.Zero and
               type(g_updates) is ad_util.Zero)
  if both_zero:
    return primal_out, ad_util.Zero.from_value(primal_out)

  tangent_out = scatter_add_p.bind(
      ad.instantiate_zeros(g_operand), scatter_indices,
      ad.instantiate_zeros(g_updates), **params)
  return primal_out, tangent_out
4529
def _scatter_add_transpose_rule(t, operand, scatter_indices, updates, *,
                                update_jaxpr, update_consts, dimension_numbers,
                                indices_are_sorted, unique_indices):
  """Transpose rule for scatter-add.

  The operand cotangent is ``t`` itself. The updates cotangent gathers ``t`` at
  ``scatter_indices`` using the gather that mirrors the scatter's dimension
  numbers. ``scatter_indices`` must be a known (non-linear) input.
  Returns ``[operand_t, None, update_t]``.
  """
  assert not ad.is_undefined_primal(scatter_indices)
  operand_is_linear = ad.is_undefined_primal(operand)
  updates_is_linear = ad.is_undefined_primal(updates)
  updates_shape = updates.aval.shape if updates_is_linear else updates.shape

  if type(t) is ad_util.Zero:
    operand_t = ad_util.Zero(operand.aval) if operand_is_linear else None
    update_t = ad_util.Zero(updates.aval) if updates_is_linear else None
    return [operand_t, None, update_t]

  operand_t = t if operand_is_linear else None
  update_t = None
  if updates_is_linear:
    gather_dnums = GatherDimensionNumbers(
        offset_dims=dimension_numbers.update_window_dims,
        collapsed_slice_dims=dimension_numbers.inserted_window_dims,
        start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
    # Inserted dims are sliced with size 1; every other dim takes the size of
    # the next update window dimension, in order.
    window_dims = iter(dimension_numbers.update_window_dims)
    slice_sizes = [1 if dim in dimension_numbers.inserted_window_dims
                   else updates_shape[next(window_dims)]
                   for dim in range(len(t.shape))]
    update_t = gather(t, scatter_indices, dimension_numbers=gather_dnums,
                      slice_sizes=slice_sizes)
  return [operand_t, None, update_t]
4562
def _scatter_mul_transpose_rule(t, operand, scatter_indices, updates, *,
                                update_jaxpr, update_consts, dimension_numbers,
                                indices_are_sorted, unique_indices):
  """Transpose rule for scatter-mul.

  The operand cotangent re-applies scatter-mul to ``t`` with the (known)
  ``updates``. The updates cotangent gathers ``t * operand`` at
  ``scatter_indices``. ``scatter_indices`` must be a known input.
  Returns ``[operand_t, None, update_t]``.
  """
  assert not ad.is_undefined_primal(scatter_indices)
  operand_is_linear = ad.is_undefined_primal(operand)
  updates_is_linear = ad.is_undefined_primal(updates)
  updates_shape = updates.aval.shape if updates_is_linear else updates.shape

  if type(t) is ad_util.Zero:
    return [ad_util.Zero(operand.aval) if operand_is_linear else None,
            None,
            ad_util.Zero(updates.aval) if updates_is_linear else None]

  operand_t = None
  if operand_is_linear:
    operand_t = scatter_mul(
        t, scatter_indices, updates, dimension_numbers=dimension_numbers,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)

  update_t = None
  if updates_is_linear:
    gather_dnums = GatherDimensionNumbers(
        offset_dims=dimension_numbers.update_window_dims,
        collapsed_slice_dims=dimension_numbers.inserted_window_dims,
        start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
    # Inserted dims are sliced with size 1; every other dim takes the size of
    # the next update window dimension, in order.
    window_dims = iter(dimension_numbers.update_window_dims)
    slice_sizes = [1 if dim in dimension_numbers.inserted_window_dims
                   else updates_shape[next(window_dims)]
                   for dim in range(len(t.shape))]
    update_t = gather(mul(t, operand), scatter_indices,
                      dimension_numbers=gather_dnums, slice_sizes=slice_sizes)
  return [operand_t, None, update_t]
4596
4597
def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
                           update_jaxpr, update_consts, dimension_numbers,
                           indices_are_sorted, unique_indices):
  """Batching rule shared by the scatter primitives.

  ``scatter_op`` is the user-facing scatter function (e.g. ``scatter_add``)
  used to issue the batched scatter. The operand and updates are given a
  leading batch dimension (materialized with ``batching.bdim_at_front`` when
  absent).
    * If the indices are unbatched, every scatter window is widened by that
      leading dimension.
    * Otherwise an iota over the batch is prepended to each index vector, so
      each batch element scatters into its own operand slice.
  The result is batched along dimension 0.
  """
  del update_jaxpr, update_consts  # Unused.
  operand, scatter_indices, updates = batched_args
  operand_bdim, scatter_indices_bdim, updates_bdim = batch_dims

  # Batch size, read from the first argument that has a batch dimension.
  size = next(arg.shape[bdim] for arg, bdim in zip(batched_args, batch_dims)
              if bdim is not None)
  operand = batching.bdim_at_front(operand, operand_bdim, size)
  updates = batching.bdim_at_front(updates, updates_bdim, size)

  def shifted(dims):
    # Renumber dimensions to make room for the new leading batch dimension.
    return tuple(np.add(1, dims))

  if scatter_indices_bdim is None:
    dnums = ScatterDimensionNumbers(
        update_window_dims=(0,) + shifted(dimension_numbers.update_window_dims),
        inserted_window_dims=shifted(dimension_numbers.inserted_window_dims),
        scatter_dims_to_operand_dims=shifted(
            dimension_numbers.scatter_dims_to_operand_dims))
  else:
    # See the both-batched case of _gather_batching_rule for a worked example.
    scatter_indices = batching.bdim_at_front(
        scatter_indices, scatter_indices_bdim, size)
    counts_shape = list(scatter_indices.shape)
    counts_shape[-1] = 1
    batch_counts = broadcasted_iota(scatter_indices.dtype,
                                    tuple(counts_shape), 0)
    scatter_indices = concatenate([batch_counts, scatter_indices],
                                  len(counts_shape) - 1)
    dnums = ScatterDimensionNumbers(
        update_window_dims=shifted(dimension_numbers.update_window_dims),
        inserted_window_dims=(0,) + shifted(
            dimension_numbers.inserted_window_dims),
        scatter_dims_to_operand_dims=(0,) + shifted(
            dimension_numbers.scatter_dims_to_operand_dims))

  return scatter_op(
      operand, scatter_indices, updates, dnums,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices), 0
4648
scatter_add_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
    _scatter_add_translation_rule)
ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
batching.primitive_batchers[scatter_add_p] = partial(_scatter_batching_rule,
                                                     scatter_add)

# On GPU, complex128 scatter-adds are lowered as a pair of float64 scatters
# over the real and imaginary parts.
xla.backend_specific_translations['gpu'][scatter_add_p] = partial(
    _scatter_add_translation_rule, expand_complex128=True)
4659
# scatter-mul lowers through the generic jaxpr-combiner translation rule.
scatter_mul_p = standard_primitive(_scatter_shape_rule, _scatter_dtype_rule,
                                   'scatter-mul', _scatter_translation_rule)
4663
def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers,
                         indices_are_sorted, unique_indices, **kw):
  """JVP contribution of scatter-mul's updates argument ``y``.

  Scatter-adds the tangent ``g`` into zeros shaped like the operand ``x`` at
  indices ``i``, then multiplies the result elementwise by ``x``.
  """
  del y, kw  # Not needed to form this tangent.
  scattered_g = scatter_add(
      zeros_like_array(x), i, g, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
  return mul(x, scattered_g)
4669
ad.defjvp(
    scatter_mul_p,
    # operand: re-apply scatter-mul to the tangent with the same updates.
    lambda g, x, i, y, **kw: scatter_mul_p.bind(g, i, y, **kw),
    None,  # integer scatter indices have no tangent
    _scatter_mul_jvp_rhs)  # updates
ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
batching.primitive_batchers[scatter_mul_p] = partial(_scatter_batching_rule,
                                                     scatter_mul)
4677
def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
                          update_consts, dimension_numbers,
                          indices_are_sorted, unique_indices):
  """JVP shared by the extremal (min/max) scatter primitives.

  Args:
    scatter_op: the scatter primitive being differentiated; the primal output
      is recomputed with ``scatter_op.bind``.
    primals: ``(operand, scatter_indices, updates)``.
    tangents: tangents matching ``primals``; the scatter_indices tangent is
      ignored.
    update_jaxpr, update_consts, dimension_numbers, indices_are_sorted,
      unique_indices: the primitive's parameters, forwarded to
      ``scatter_op.bind``.

  Returns:
    ``(val_out, tangent_out)``. If both the operand and updates tangents are
    symbolic zeros, ``tangent_out`` is a symbolic zero. Otherwise, at each
    written location the tangent is the average of the tangents of every
    contender whose value equals the output: the original operand value
    and/or each matching update.
  """
  operand, scatter_indices, updates = primals
  g_operand, g_scatter_indices, g_updates = tangents

  scatter_dnums = dimension_numbers
  updates_shape = updates.shape

  val_out = scatter_op.bind(
      operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=scatter_dnums,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)

  if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    g_operand = ad.instantiate_zeros(g_operand)
    g_updates = ad.instantiate_zeros(g_updates)

    # gather_dnums and slice_sizes define the gather op that is the inverse of
    # the scatter op specified by scatter_dnums
    gather_dnums = GatherDimensionNumbers(
        offset_dims=scatter_dnums.update_window_dims,
        collapsed_slice_dims=scatter_dnums.inserted_window_dims,
        start_index_map=scatter_dnums.scatter_dims_to_operand_dims)

    # Inserted dims are sliced with size 1; every other operand dim takes the
    # size of the next update window dimension, in order.
    slice_sizes = []
    pos = 0
    for i in range(len(operand.shape)):
      if i in scatter_dnums.inserted_window_dims:
        slice_sizes.append(1)
      else:
        slice_sizes.append(updates_shape[scatter_dnums.update_window_dims[pos]])
        pos += 1

    # For consistency with other max operations, if there are two or more values
    # in updates that are contending to replace the same index location, the
    # resulting tangent at that location will be the average of the associated
    # tangents for the values in updates.

    # Per update element: the operand value and the final output value at the
    # location that update targets.
    initial_vals = gather(
        operand, scatter_indices, gather_dnums, np.array(slice_sizes))

    target_vals = gather(
        val_out, scatter_indices, gather_dnums, np.array(slice_sizes))

    # Which updates produced the output value, and whether the original operand
    # value was kept at the targeted location.
    successful_updates = (updates == target_vals)
    retained_values = (initial_vals == target_vals)

    # For each update: how many successful updates hit the same location.
    num_updates = gather(
        scatter_add(_zeros(operand),
                    scatter_indices,
                    select(successful_updates, _ones(updates), _zeros(updates)),
                    scatter_dnums),
        scatter_indices,
        gather_dnums,
        np.array(slice_sizes))

    # For each update: how many updates in total target the same location.
    num_refs = gather(
        scatter_add(_zeros(operand),
                    scatter_indices,
                    _ones(updates),
                    scatter_dnums),
        scatter_indices,
        gather_dnums,
        np.array(slice_sizes))

    # Averaging weight for one contender: 1 / (number of contenders), where the
    # retained operand value counts as one extra contender.
    # NOTE(review): `select` evaluates both branches, so `1.0 / num_updates`
    # can be inf where num_updates == 0. Such entries are expected to take the
    # other branch or be zeroed by `successful_updates`; confirm for NaN inputs.
    updates_normalizer = select(retained_values,
                                1.0 / (num_updates + 1),
                                1.0 / num_updates)

    updates_coef = select(successful_updates,
                          updates_normalizer,
                          _zeros(updates))

    operand_normalizer = select(retained_values,
                                1.0 / (num_updates + 1),
                                _zeros(num_updates))

    # Each of the num_refs updates at a location contributes this share of the
    # operand tangent. Summed by the scatter_add below, the shares cancel the
    # g_operand term that tangent_out starts from, except for the retained
    # operand's averaging weight.
    operand_coef = (-1.0 + operand_normalizer) / num_refs

    # This can be simplified once scatter has transpose implemented
    target_tangents = gather(
        g_operand, scatter_indices, gather_dnums, np.array(slice_sizes))

    tangent_updates = (target_tangents * operand_coef +
                       g_updates * updates_coef)

    tangent_out = scatter_add(g_operand,
                              scatter_indices,
                              tangent_updates,
                              scatter_dnums,
                              indices_are_sorted=indices_are_sorted,
                              unique_indices=unique_indices)

  return val_out, tangent_out
4776
scatter_min_p = standard_primitive(_scatter_shape_rule, _scatter_dtype_rule,
                                   'scatter-min', _scatter_translation_rule)
# Differentiated with the shared min/max JVP, parameterized by this primitive.
ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p)
batching.primitive_batchers[scatter_min_p] = partial(_scatter_batching_rule,
                                                     scatter_min)
4783
scatter_max_p = standard_primitive(_scatter_shape_rule, _scatter_dtype_rule,
                                   'scatter-max', _scatter_translation_rule)
# Differentiated with the shared min/max JVP, parameterized by this primitive.
ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p)
batching.primitive_batchers[scatter_max_p] = partial(_scatter_batching_rule,
                                                     scatter_max)
4790
def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
                 dimension_numbers, indices_are_sorted, unique_indices):
  """JVP rule for `scatter_p`.

  Args:
    primals: `(operand, scatter_indices, updates)`.
    tangents: tangents matching `primals`. The tangent of `scatter_indices` is
      unpacked but never used.
    update_jaxpr, update_consts, dimension_numbers, indices_are_sorted,
    unique_indices: parameters of the `scatter_p` equation.

  Returns:
    `(val_out, tangent_out)`.

    If both the operand and updates tangents are symbolic zeros, the primal
    is computed by binding `scatter_p` directly and the tangent is a symbolic
    zero. Otherwise both outputs are rebuilt with the ID-masking scheme
    described below. In that case the primal is also computed with
    `scatter_add` rather than `scatter_p`.
  """
  operand, scatter_indices, updates = primals
  g_operand, g_scatter_indices, g_updates = tangents
  dnums = dimension_numbers

  # Fast path: no differentiable input, so emit the plain scatter.
  if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
    val_out = scatter_p.bind(
      operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dnums,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
    return val_out, ad_util.Zero.from_value(val_out)

  g_operand = ad.instantiate_zeros(g_operand)
  g_updates = ad.instantiate_zeros(g_updates)

  # If there are overlapping indices in the scatter, it is unspecified which
  # update "wins". So we use the following perhaps surprising scheme:
  # a) attach a positive ID to each update in updates, and perform the scatter
  #    on the IDs
  # b) perform the inverse gather on the scattered IDs (similar to
  #    _scatter_add_transpose).
  # c) use the gathered IDs to mask the primal and tangent values.
  # d) perform a scatter-add on the masked primal and tangent values. A benefit
  #    of using scatter-add here is that we don't need a `scatter` transpose
  #    rule.


  # a) attach a positive ID to each update in `updates`, and perform a scatter
  #    on the IDs.
  # There is one ID per scatter index: the update window dims are set to 1.
  ids_shape = np.array(updates.shape, dtype=np.int64)
  ids_shape[dnums.update_window_dims,] = 1
  num_ids = np.prod(ids_shape)
  # IDs are 1-based (`+ _ones(...)` below) because 0 is the fill value of the
  # scattered-ID array. A 0 therefore marks operand positions that no update
  # reached. uint64 is used only if uint32 cannot represent num_ids + 1.
  id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
  # NOTE(review): the reshaped iota has size-1 window dims, while
  # `_ones(updates)` presumably has the full updates shape. This relies on
  # `add` broadcasting size-1 dims; confirm against lax's elementwise shape
  # rule.
  update_ids = add(reshape(iota(id_dtype, num_ids), ids_shape),
                   _ones(updates, dtype=id_dtype))

  # Where indices collide, exactly one update's ID is left at the destination.
  scattered_ids = scatter(full(operand.shape, 0, id_dtype),
                          scatter_indices, update_ids, dnums,
                          indices_are_sorted=indices_are_sorted,
                          unique_indices=unique_indices)

  # b) compute the inverse gather that "undoes" the scatter on the id values.
  gather_dnums = GatherDimensionNumbers(
    offset_dims=dnums.update_window_dims,
    collapsed_slice_dims=dnums.inserted_window_dims,
    start_index_map=dnums.scatter_dims_to_operand_dims)
  # Inserted (collapsed) dims gather a size-1 slice. Every other operand dim
  # takes the size of the next update window dim, in order.
  slice_sizes = []
  pos = 0
  for i in range(len(scattered_ids.shape)):
    if i in dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
      pos += 1
  gathered_update_ids = gather(scattered_ids, scatter_indices,
                               dimension_numbers=gather_dnums,
                               slice_sizes=slice_sizes)

  # c) mask off input elements that do not correspond to a primal output.
  # An operand element survives only where no update landed (ID 0). An update
  # survives only if its own ID is the one found at its destination.
  masked_operand = select(eq(scattered_ids, _zeros(scattered_ids)),
                          operand, _zeros(operand))
  masked_updates = select(eq(update_ids,  gathered_update_ids),
                          updates, _zeros(updates))
  masked_g_operand = select(eq(scattered_ids, _zeros(scattered_ids)),
                            g_operand, _zeros(g_operand))
  masked_g_updates = select(eq(update_ids, gathered_update_ids),
                            g_updates, _zeros(g_updates))

  # d) perform scatter-adds to compute the primal and tangent outputs.
  val_out = scatter_add(masked_operand, scatter_indices, masked_updates,
                        dimension_numbers=dnums,
                        indices_are_sorted=indices_are_sorted,
                        unique_indices=unique_indices)
  tangent_out = scatter_add(masked_g_operand, scatter_indices, masked_g_updates,
                            dimension_numbers=dnums,
                            indices_are_sorted=indices_are_sorted,
                            unique_indices=unique_indices)
  return val_out, tangent_out
4870
4871
# Overwrite-style scatter primitive. It shares the generic scatter shape, dtype
# and translation rules. Its JVP is `_scatter_jvp`, and batching delegates to
# `_scatter_batching_rule` with the `scatter` wrapper.
scatter_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter',
    _scatter_translation_rule)
ad.primitive_jvps[scatter_p] = _scatter_jvp
batching.primitive_batchers[scatter_p] = (
  partial(_scatter_batching_rule, scatter))
4878
4879
def _reduce_shape_rule(*args, computation, jaxpr, consts, dimensions):
  """Shape rule for variadic `reduce_p`.

  `args` holds the operands followed by an equal number of init values. Every
  init value must be a scalar. Each output shape is the matching operand's
  shape with `dimensions` deleted.

  Raises:
    ValueError: if any init value is not a scalar. The message lists the
      shapes of all init values.
  """
  num_operands = len(args) // 2
  operands, inits = split_list(args, [num_operands])
  if not all(init.shape == () for init in inits):
    all_init_shapes = [init.shape for init in inits]
    raise ValueError(f'Found non-scalar init_value: {all_init_shapes}')
  out_shapes = []
  for operand in operands:
    out_shapes.append(tuple(np.delete(operand.shape, dimensions)))
  return out_shapes
4889
4890
def _reduce_dtype_rule(*args, computation, jaxpr, consts, dimensions):
  """Dtype rule for variadic `reduce_p`.

  `args` holds the operands followed by their init values. Each operand's
  canonical dtype must equal its init value's canonical dtype. The list of
  operand dtypes is returned.

  Raises:
    TypeError: if any operand/init-value dtype pair differs.
  """
  half = len(args) // 2
  operands, inits = split_list(args, [half])
  canon = dtypes.canonicalize_dtype
  out_dtypes = [canon(x.dtype) for x in operands]
  if out_dtypes != [canon(x.dtype) for x in inits]:
    raise TypeError(f"operand dtypes should match corresponding initial value dtypes; got "
                    f"operands={operands} and initial_values={inits}")
  return out_dtypes
4899
4900
def _reduce_translation_rule(c, *values, computation, jaxpr,
                             consts, dimensions):
  """XLA lowering for variadic `reduce_p`.

  `values` holds the operands followed by their init values.

  Multi-operand path: builds a tuple-valued reducer from `jaxpr` and returns
  the `xops.Reduce` result as-is.

  Single-operand path: uses a scalar reducer and wraps the result in a
  1-tuple, because the primitive is declared with `multiple_results=True`.
  """
  operands, init_values = split_list(values, [len(values) // 2])
  if len(operands) != 1:
    reducer = _reduction_computation(c, jaxpr, consts, init_values,
                                     singleton=False)
    return xops.Reduce(c, operands, init_values, reducer, dimensions)
  reducer = _reduction_computation(c, jaxpr, consts, init_values[0])
  reduced = xops.Reduce(c, operands, init_values, reducer, dimensions)
  return xops.Tuple(c, (reduced,))
4911
4912
def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
                       consts, dimensions):
  """Batching rule for variadic `reduce_p`.

  Only the case where no init value is batched and every operand shares one
  batch dimension is supported. A batched init value raises
  NotImplementedError.

  Returns:
    The rebound reduction outputs and one output batch dim per operand.
  """
  n = len(batched_args) // 2
  operands, init_values = split_list(batched_args, [n])
  operand_bdims, init_value_bdims = split_list(batch_dims, [n])
  if any(bd is not None for bd in init_value_bdims):
    raise NotImplementedError  # loop and stack
  # Assume all batch dims are the same for each of the operands.
  # TODO(sharadmv): handle the case when batch dims are different across
  # operands or when some are unbatched
  assert all(bd is not None for bd in operand_bdims)
  assert all(bd == operand_bdims[0] for bd in operand_bdims)
  bdim = operand_bdims[0]
  # Reduced axes at or after the batch dim shift right by one.
  shifted_dims = [d + 1 if d >= bdim else d for d in dimensions]
  # The batch dim shifts left once for each reduced axis in front of it.
  out_bdim = bdim - int(np.sum(np.less(dimensions, bdim)))
  outs = reduce_p.bind(*(operands + init_values),
                       computation=computation,
                       dimensions=tuple(shifted_dims),
                       consts=consts,
                       jaxpr=jaxpr)
  return outs, [out_bdim] * n
4934
4935
def _reduction_computation(c, jaxpr, consts, init_values, singleton=True):
  """Build an XLA reducer computation from `jaxpr`.

  Used as the reducer for `reduce_p`, `reduce_window_p` and
  `select_and_scatter_p`.

  Args:
    c: the enclosing computation builder. It is used only to look up the XLA
      shapes of `init_values`.
    jaxpr: the reducer jaxpr, lowered with `xla.jaxpr_subcomp`.
    consts: must be empty (asserted).
    init_values: a single XLA op when `singleton` is True, otherwise a list of
      XLA ops.
    singleton: if True, the computation's root is the jaxpr's only output;
      otherwise all outputs are packed into a tuple.

  Returns:
    The built XLA computation. Its parameters take the init-value shapes
    twice (`init_values + init_values`), i.e. two sets of reducer arguments.
  """
  if singleton:
    init_values = [init_values]
  shapes = safe_map(c.get_shape, init_values + init_values)
  axis_env = xla.AxisEnv(1, (), ())  # no parallel primitives inside reductions
  subc = xla_bridge.make_computation_builder("reduction_computation")
  assert len(consts) == 0, "Reduction computations cannot have constants"
  # Parameters must be created in order; their index is the argument position.
  args = [xb.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
  out_nodes = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
  if singleton:
    return subc.build(out_nodes[0])
  out_nodes = xops.Tuple(subc, out_nodes)
  return subc.build(out_nodes)
4949
def _masking_defreducer(prim, identity):
  """Register `_reducer_masking_rule` as the masking rule for `prim`.

  `identity(shape, dtype)` must return the fill value that replaces padded
  (out-of-bounds) elements before the reduction.
  """
  rule = partial(_reducer_masking_rule, prim, identity)
  masking.masking_rules[prim] = rule
4952
def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
                          axes, input_shape=None, **reduce_kwargs):
  """Masking rule shared by reducers registered via `_masking_defreducer`.

  Elements of the single padded operand that fall outside its logical shape
  along any reduced axis are replaced by `identity(padded_shape, dtype)`.
  `prim` is then rebound on the result. If the primitive carries an
  `input_shape` parameter, it is rebound with the padded shape.
  """
  (padded_val,), (logical_shape,) = padded_vals, logical_shapes
  padded_shape = masking.padded_shape_as_value(padded_val.shape)
  in_bounds = [broadcasted_iota(np.int32, padded_shape, dim) < size
               for dim, size in enumerate(logical_shape) if dim in axes]
  keep = _reduce(operator.and_, in_bounds)
  fill = identity(padded_shape, padded_val.dtype)
  masked_val = select(keep, padded_val, fill)
  bind_kwargs = dict(reduce_kwargs)
  if input_shape is not None:
    bind_kwargs['input_shape'] = padded_shape
  return prim.bind(masked_val, axes=axes, **bind_kwargs)
4964
# Variadic reduction primitive. Its arguments are N operands followed by N
# scalar init values, and it produces N results (`multiple_results=True`).
# The batching rule is registered here.
reduce_p = standard_primitive(_reduce_shape_rule, _reduce_dtype_rule,
                              'reduce', translation_rule=_reduce_translation_rule,
                              multiple_results=True)
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
4969
4970
def _reduce_number_dtype_rule(name, operand, *args, **kw):
  """Dtype rule for numeric reductions.

  Args:
    name: primitive name used in the error message.
    operand: the operand aval. Extra positional/keyword arguments are
      accepted and ignored.

  Returns:
    The canonicalized operand dtype.

  Raises:
    TypeError: if the operand dtype is not a subtype of `np.number`.
  """
  operand_dtype = operand.dtype
  if dtypes.issubdtype(operand_dtype, np.number):
    return dtypes.canonicalize_dtype(operand_dtype)
  raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes "
                  "of number.".format(name, np.dtype(operand_dtype).name))
4976
def _reduce_sum_shape_rule(operand, *, axes):
  """Shape rule for `reduce_sum_p`; defers to `_reduce_op_shape_rule`."""
  out_shape = _reduce_op_shape_rule(operand, axes=axes)
  return out_shape
4979
def _reduce_sum_translation_rule(c, operand, *, axes):
  """Lower `reduce_sum_p` to an XLA Reduce with an `add` reducer and init 0."""
  np_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), np_dtype)
  zero = xb.constant(c, np.array(0, np_dtype))
  adder = xla.primitive_subcomputation(add_p, scalar_aval, scalar_aval)
  return xops.Reduce(c, [operand], [zero], adder, axes)
4987
def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
  """Transpose of `reduce_sum_p`: broadcast the cotangent back along `axes`.

  Returns:
    A one-element list holding the operand cotangent, whose shape equals the
    operand's shape.
  """
  assert ad.is_undefined_primal(operand)
  in_shape = operand.aval.shape
  # The cotangent's dims map to the operand dims that were not reduced.
  kept_dims = tuple(np.delete(np.arange(len(in_shape)), axes))
  operand_ct = broadcast_in_dim(cotangent, in_shape, kept_dims)
  assert operand_ct.shape == in_shape
  return [operand_ct]
4995
# Sum reduction over `axes`. It is linear, so its JVP/transpose come from
# `ad.deflinear2` with `_reduce_sum_transpose_rule`. Masking fills padded
# elements with 0.
reduce_sum_p = standard_primitive(
  _reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
  'reduce_sum', _reduce_sum_translation_rule)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p)
_masking_defreducer(reduce_sum_p,
                    lambda shape, dtype: np.broadcast_to(np.array(0, dtype), shape))
5003
5004
def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
  """Shape rule shared by single-operand reductions.

  Checks that `axes` has no duplicates and that every axis lies in
  `[0, operand.ndim)`. Returns the operand shape with `axes` removed.
  `input_shape` is accepted for signature compatibility and ignored.

  Raises:
    ValueError: on duplicate or out-of-bounds axes.
  """
  del input_shape  # Unused.
  if len(set(axes)) < len(axes):
    raise ValueError(f"duplicate value in 'axes' of reduction: {axes}")
  if any(a < 0 or a >= operand.ndim for a in axes):
    raise ValueError(f"reduction axes {axes} contains out-of-bounds indices for {operand}.")
  remaining = np.delete(operand.shape, axes)
  return tuple(remaining)
5012
def _reduce_prod_translation_rule(c, operand, *, axes):
  """Lower `reduce_prod_p` to an XLA Reduce with a `mul` reducer and init 1."""
  np_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), np_dtype)
  one = xb.constant(c, np.array(1, np_dtype))
  multiplier = xla.primitive_subcomputation(mul_p, scalar_aval, scalar_aval)
  return xops.Reduce(c, [operand], [one], multiplier, axes)
5018
def _reduce_prod_jvp_rule(primals, tangents, *, axes):
  """JVP rule for `reduce_prod_p`.

  The reduced axes are moved to the front and flattened into a single
  leading axis. The product is then re-expressed as a pairwise
  multiplication tree (`_reduce_prod_tree`), whose JVP is taken with
  `api.jvp`. Presumably this is to avoid a formula that divides by the
  operand, which would break on zeros.

  Returns:
    `(primal_out, tangent_out)` as produced by `api.jvp`.
  """
  operand, = primals
  tangent, = tangents
  input_shape = np.array(operand.shape)

  # `n` is the number of elements in the reduced axes.
  n = np.prod(input_shape[list(axes)])
  non_axes = np.delete(np.arange(len(input_shape)), axes)

  # Move the reduced axes to the front, and flatten them to 1D.
  permutation = axes + tuple(non_axes)
  new_shape = (n,) + tuple(input_shape[non_axes])
  operand = reshape(operand, new_shape, permutation)
  tangent = reshape(tangent, new_shape, permutation)

  def _reduce_prod_tree(x, axis=0):
    """Reduce by repeatedly splitting the array and multiplying."""
    while x.shape[axis] > 1:
      n = x.shape[axis]
      n1 = (n + 1) // 2
      n2 = n - n1
      x1 = slice_in_dim(x, 0, n1)
      x2 = slice_in_dim(x, n1, None)
      # For an odd length, pad the shorter second half with one trailing 1 so
      # the halves align elementwise.
      if n2 != n1:
        paddings = [(0, 0, 0)] * len(x.shape)
        paddings[axis] = (0, 1, 0)
        x2 = pad(x2, _const(x, 1), paddings)
      x = x1 * x2
    # Zero-length reduced axis: the empty product is 1 over the kept dims.
    if x.shape[axis] == 0:
      return full(input_shape[non_axes], _one(x))
    return squeeze(x, (axis,))

  return api.jvp(_reduce_prod_tree, (operand,), (tangent,))
5051
5052
# Product reduction over `axes`. The JVP uses the tree-based
# `_reduce_prod_jvp_rule`. Masking fills padded elements with 1, the
# multiplicative identity.
reduce_prod_p = standard_primitive(
  _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),
  'reduce_prod', _reduce_prod_translation_rule)
ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
batching.defreducer(reduce_prod_p)
_masking_defreducer(reduce_prod_p,
                    lambda shape, dtype: np.broadcast_to(np.array(1, dtype), shape))
5060
5061
def _reduce_chooser_shape_rule(operand, *, axes):
  """Return `operand.shape` with the dims in `axes` removed (no validation)."""
  remaining = np.delete(operand.shape, axes)
  return tuple(remaining)
5064
def _reduce_chooser_translation_rule(prim, identity, c, operand, *, axes):
  """Lower a max/min-style reduction to an XLA Reduce.

  `prim` is the binary primitive used as the scalar reducer, and
  `identity(dtype)` supplies the init value.
  """
  np_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), np_dtype)
  init = xb.constant(c, identity(np_dtype))
  reducer = xla.primitive_subcomputation(prim, scalar_aval, scalar_aval)
  return xops.Reduce(c, [operand], [init], reducer, axes)
5070
def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
  """JVP for max/min reductions (registered via `ad.defjvp2`).

  The output tangent is the mean of `g` over every position of `operand`
  that equals the reduced result `ans`. Ties therefore split the tangent
  evenly.
  """
  # TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
  # locations in a single pass (rather than comparing equality) and use a
  # gather, and/or even push along the chosen elements of g (b/112040122)
  keepdims_shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
  is_chosen = convert_element_type(
      _eq_meet(operand, reshape(ans, keepdims_shape)), g.dtype)
  num_chosen = _reduce_sum(is_chosen, axes)
  chosen_sum = _reduce_sum(mul(g, is_chosen), axes)
  return div(chosen_sum, num_chosen)
5080
_reduce_max_translation_rule = partial(_reduce_chooser_translation_rule, max_p,
                                       _get_max_identity)
reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
                                  'reduce_max', _reduce_max_translation_rule)
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p)
# Masking fills padded elements with the dtype's max identity, the same value
# the XLA lowering above uses as its init value. A hard-coded `-np.inf` has
# two problems:
# - it cannot be cast to integer dtypes (numpy raises OverflowError);
# - it casts to `True` for bool, which is not the identity for max.
_masking_defreducer(
    reduce_max_p,
    lambda shape, dtype: np.broadcast_to(
        np.asarray(_get_max_identity(dtype), dtype), shape))
5089
5090
_reduce_min_translation_rule = partial(
    _reduce_chooser_translation_rule, min_p, _get_min_identity)
reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
                                  'reduce_min', _reduce_min_translation_rule)
ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_min_p)
# Masking fills padded elements with the dtype's min identity, the same value
# the XLA lowering above uses as its init value. A hard-coded `np.inf` cannot
# be cast to integer dtypes (numpy raises OverflowError).
_masking_defreducer(
    reduce_min_p,
    lambda shape, dtype: np.broadcast_to(
        np.asarray(_get_min_identity(dtype), dtype), shape))
5099
5100
5101
def _argminmax_shape_rule(operand, *, axes, index_dtype):
  """Shape rule for argmin/argmax.

  `axes` must name exactly one axis, which is removed from the operand shape.
  """
  (reduced_axis,) = axes
  kept = np.delete(operand.shape, reduced_axis)
  return tuple(kept)
5105
def _argminmax_dtype_rule(operand, *, axes, index_dtype):
  """Dtype rule for argmin/argmax: the output dtype is `index_dtype`.

  Raises:
    TypeError: if `index_dtype` is not an integer dtype.
  """
  if dtypes.issubdtype(index_dtype, np.integer):
    return index_dtype
  raise TypeError("index_dtype must be an integer type, but got {}"
                  .format(np.dtype(index_dtype).name))
5111
def _argminmax_translation_rule(value_comparator, identity,
                                c, operand, *, axes, index_dtype):
  """Lower argmin/argmax as a variadic XLA Reduce over (value, index) pairs.

  Args:
    value_comparator: XLA op builder that is True when the left value should
      win (`xops.Lt` for argmin, `xops.Gt` for argmax).
    identity: maps the operand's numpy dtype to the init value for the value
      half of the reduction.
    c: the computation builder.
    operand: the XLA operand.
    axes: a 1-tuple naming the reduced axis.
    index_dtype: dtype of the returned indices.

  Returns:
    Element 1 (the index half) of the reduction's output tuple.
  """
  axis, = axes
  shape = c.get_shape(operand)
  dtype = shape.numpy_dtype()

  # Comparator over pairs: (x_value, x_index) vs (y_value, y_index).
  subc = xb.make_computation_builder("argminmax_comparator")
  value_shape = xc.Shape.array_shape(shape.xla_element_type(), ())
  index_shape = xc.Shape.array_shape(index_dtype, ())
  x_value = xb.parameter(subc, 0, value_shape)
  x_index = xb.parameter(subc, 1, index_shape)
  y_value = xb.parameter(subc, 2, value_shape)
  y_index = xb.parameter(subc, 3, index_shape)
  which_value = value_comparator(x_value, y_value)
  # The index follows the winning value; on equal values the smaller index
  # wins.
  which_index = xops.Or(which_value, xops.And(xops.Eq(x_value, y_value),
                                              xops.Lt(x_index, y_index)))
  xops.Tuple(subc, [xops.Select(which_value, x_value, y_value),
                    xops.Select(which_index, x_index, y_index)])
  comparator = subc.build()

  # Pair every element with its position along `axis`; the index init is 0.
  iota_shape = xc.Shape.array_shape(index_dtype, shape.dimensions())
  iota = xc.ops.Iota(c, iota_shape, axis)
  out = xops.Reduce(
    c, [operand, iota],
    [xb.constant(c, identity(dtype)),
     xb.constant(c, np.array(0, index_dtype))], comparator, [axis])
  return xops.GetTupleElement(out, 1)
5139
def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
  """GPU lowering of argmin/argmax in terms of other lax reductions.

  `op` is `_reduce_min` or `_reduce_max`.

  Steps:
  - Positions whose value equals the reduced extremum keep their index.
  - Every other position gets the largest value representable in
    `index_dtype`.
  - The minimum over that array is the first matching index.
  """
  (reduced_axis,) = axes
  positions = tie_in(a, broadcasted_iota(index_dtype, a.shape, reduced_axis))
  sentinel = np.array(dtypes.iinfo(index_dtype).max, dtype=index_dtype)
  sentinel = broadcast(tie_in(a, sentinel), a.shape)
  extremum = expand_dims(op(a, (reduced_axis,)), (reduced_axis,))
  candidates = select(eq(a, extremum), positions, sentinel)
  return _reduce_min(candidates, (reduced_axis,))
5148
# Default lowering: a variadic XLA reduce over (value, index) pairs.
# - argmin compares with `xops.Lt` and starts from the min identity.
# - argmax compares with `xops.Gt` and starts from the max identity.
_argmin_translation_rule = partial(_argminmax_translation_rule, xops.Lt,
                                   _get_min_identity)
_argmax_translation_rule = partial(_argminmax_translation_rule, xops.Gt,
                                   _get_max_identity)

# argmin/argmax are batched with the generic reducer rule. Their JVP is
# registered as zero (the outputs are integer indices). On GPU they are
# lowered via `_argminmax_gpu_translation_rule`, built from
# `_reduce_min` / `_reduce_max`.
argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
                              'argmin', _argmin_translation_rule)
batching.defreducer(argmin_p)
ad.defjvp_zero(argmin_p)
xla.backend_specific_translations['gpu'][argmin_p] = xla.lower_fun(
  partial(_argminmax_gpu_translation_rule, _reduce_min),
  multiple_results=False)

argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
                              'argmax', _argmax_translation_rule)
batching.defreducer(argmax_p)
ad.defjvp_zero(argmax_p)
xla.backend_specific_translations['gpu'][argmax_p] = xla.lower_fun(
  partial(_argminmax_gpu_translation_rule, _reduce_max),
  multiple_results=False)
5169
5170
def _reduce_logical_shape_rule(operand, *, axes):
  """Shape rule for `reduce_or_p` / `reduce_and_p`.

  Requires a boolean operand and returns its shape with `axes` removed.

  Raises:
    TypeError: if `operand.dtype` is not bool.
  """
  if operand.dtype == np.bool_:
    return tuple(np.delete(operand.shape, axes))
  msg = "logical reduction requires operand dtype bool, got {}."
  raise TypeError(msg.format(operand.dtype))
5176
def _reduce_logical_translation_rule(prim, identity, c, operand, *, axes):
  """Lower a boolean reduction to an XLA Reduce.

  `prim` is the scalar reducer and `identity(np.bool_)` is the init value.
  """
  bool_scalar = ShapedArray((), np.bool_)
  init = xb.constant(c, identity(np.bool_))
  reducer = xla.primitive_subcomputation(prim, bool_scalar, bool_scalar)
  return xops.Reduce(c, [operand], [init], reducer, axes)
5181
# Boolean OR reduction. Its init value is the max identity for bool.
_reduce_or_translation_rule = partial(_reduce_logical_translation_rule,
                                      or_p, _get_max_identity)
reduce_or_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bool_),
                                 'reduce_or', _reduce_or_translation_rule)
batching.defreducer(reduce_or_p)


# Boolean AND reduction. Its init value is the min identity for bool.
_reduce_and_translation_rule = partial(_reduce_logical_translation_rule,
                                       and_p, _get_min_identity)
reduce_and_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bool_),
                                  'reduce_and', _reduce_and_translation_rule)
batching.defreducer(reduce_and_p)
5194
def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts,
                              window_dimensions, window_strides, padding,
                              base_dilation, window_dilation):
  """Shape rule for the generic `reduce_window_p`.

  Requires `init_value` to be a scalar with the operand's dtype, then defers
  to `_common_reduce_window_shape_rule`.

  Raises:
    TypeError: on a dtype mismatch or a non-scalar `init_value`.
  """
  operand_dtype, init_dtype = operand.dtype, init_value.dtype
  if operand_dtype != init_dtype:
    template = ("reduce_window got inconsistent dtypes for operand and init_value: "
                " got operand dtype {} and init_value dtype {}.")
    raise TypeError(template.format(operand_dtype, init_dtype))
  if init_value.shape != ():
    template = ("reduce_window expected init_value to be a scalar but init_value "
                "has shape {}.")
    raise TypeError(template.format(init_value.shape))
  return _common_reduce_window_shape_rule(
      operand, window_dimensions, window_strides, padding, base_dilation,
      window_dilation)
5209
def _reduce_window_translation_rule(c, operand, init_value, *, jaxpr, consts,
                                    window_dimensions, window_strides, padding,
                                    base_dilation, window_dilation):
  """Lower the generic `reduce_window_p`, using `jaxpr` as the reducer."""
  reducer = _reduction_computation(c, jaxpr, consts, init_value)
  return xops.ReduceWindowWithGeneralPadding(
      operand, init_value, reducer, window_dimensions, window_strides,
      base_dilation, window_dilation, padding)
5217
def _generic_reduce_window_batch_rule(
    batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
    window_strides, padding, base_dilation, window_dilation):
  """Batching rule for `reduce_window_p` when the init value is unbatched.

  Delegates to `_reduce_window_batch_rule` with a closure that rebinds
  `reduce_window_p` using the unbatched `init`, `jaxpr` and `consts`.

  Raises:
    NotImplementedError: if the init value is batched.
  """
  operand, init = batched_args
  operand_bdim, init_bdim = batch_dims
  if init_bdim is not None:
    raise NotImplementedError("reduce_window batching is not implemented for "
                              "initial values")

  def rebind(x, window_dimensions, window_strides, padding, base_dilation,
             window_dilation):
    return reduce_window_p.bind(
        x, init, jaxpr=jaxpr, consts=consts,
        window_dimensions=window_dimensions, window_strides=window_strides,
        padding=padding, base_dilation=base_dilation,
        window_dilation=window_dilation)

  return _reduce_window_batch_rule(
      rebind, (operand,), (operand_bdim,),
      window_dimensions=window_dimensions, window_strides=window_strides,
      padding=padding, base_dilation=base_dilation,
      window_dilation=window_dilation)
5237
5238
# Generic windowed reduction with a user-supplied reducer jaxpr and an
# explicit init value.
reduce_window_p = standard_primitive(
    _reduce_window_shape_rule, _input_dtype, 'reduce_window',
    _reduce_window_translation_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
5243
5244
def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
                                  padding, base_dilation, window_dilation):
  """Shape rule for `reduce_window_sum_p`.

  Requires a numeric operand dtype, then defers to
  `_common_reduce_window_shape_rule`.

  Raises:
    TypeError: if the operand dtype is not a subtype of `np.number`.
  """
  if dtypes.issubdtype(operand.dtype, np.number):
    return _common_reduce_window_shape_rule(operand, window_dimensions,
                                            window_strides, padding,
                                            base_dilation, window_dilation)
  msg = "operand to reduce_window_sum must have a number dtype, got {}"
  raise TypeError(msg.format(np.dtype(operand.dtype).name))
5253
def _reduce_window_sum_translation_rule(c, operand, *, window_dimensions,
                                        window_strides, padding, base_dilation,
                                        window_dilation):
  """Lower `reduce_window_sum_p` with an `add` reducer and init 0."""
  np_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), np_dtype)
  zero = xb.constant(c, np.array(0, np_dtype))
  adder = xla.primitive_subcomputation(add_p, scalar_aval, scalar_aval)
  return xops.ReduceWindowWithGeneralPadding(
      operand, zero, adder, window_dimensions, window_strides,
      base_dilation, window_dilation, padding)
5263
def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions,
                                      window_strides, padding, base_dilation,
                                      window_dilation):
  """Transpose rule for `reduce_window_sum_p`, which is linear in `operand`.

  Steps:
  1. Pad the cotangent. Edge padding comes from
     `_conv_general_vjp_lhs_padding`; interior padding of `stride - 1`
     between elements undoes the forward strides.
  2. Sum the padded cotangent again over the same window.

  In step 2 the forward `base_dilation` is passed as the third positional
  argument of `_reduce_window_sum`. Presumably that is the window-strides
  slot, matching the positional order used when `_reduce_window_sum` is
  called through `_reduce_window_batch_rule`. The base dilation becomes all
  ones, so strides and input dilation swap roles under transposition.

  Returns:
    A one-element list holding the operand cotangent. Its shape is asserted
    to equal the operand's shape.
  """
  assert ad.is_undefined_primal(operand)
  input_shape = operand.aval.shape
  pads = _conv_general_vjp_lhs_padding(
      input_shape, window_dimensions, window_strides, cotangent.shape, padding,
      base_dilation, window_dilation)
  ones = [1] * len(input_shape)
  # Interior padding of `stride - 1` re-inserts the positions the forward
  # strides skipped.
  padding_config = [(lo, hi, stride - 1)
                    for (lo, hi), stride in zip(pads, window_strides)]
  pad_cotangent = pad(cotangent, _zero(cotangent), padding_config)
  result = _reduce_window_sum(pad_cotangent, window_dimensions, base_dilation,
                              [(0, 0)] * len(input_shape),
                              base_dilation=ones,
                              window_dilation=window_dilation)
  assert result.shape == input_shape, (result.shape, input_shape)
  return [result]
5282
def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
                              window_dimensions, window_strides, padding,
                              base_dilation, window_dilation):
  """Batching rule shared by the reduce_window family.

  When the operand is batched, a trivial window entry (size 1, stride 1, no
  padding, no dilation) is inserted at the batch dimension, so batch elements
  never mix.

  `reduce_window` is called positionally as `reduce_window(operand,
  window_dimensions, window_strides, padding, base_dilation,
  window_dilation)`. The batch dimension is returned unchanged.
  """
  (operand,) = batched_args
  (bdim,) = bdims

  def _insert(params, trivial):
    return params[:bdim] + (trivial,) + params[bdim:]

  if bdim is not None:
    window_dimensions = _insert(window_dimensions, 1)
    window_strides = _insert(window_strides, 1)
    padding = _insert(padding, (0, 0))
    base_dilation = _insert(base_dilation, 1)
    window_dilation = _insert(window_dilation, 1)

  out = reduce_window(operand, window_dimensions, window_strides, padding,
                      base_dilation, window_dilation)
  return out, bdim
5300
# Windowed sum. It is linear, so the JVP/transpose come from `ad.deflinear2`
# with `_reduce_window_sum_transpose_rule`. Batching inserts a trivial window
# dim via `_reduce_window_batch_rule`.
reduce_window_sum_p = standard_primitive(
    _reduce_window_sum_shape_rule, _input_dtype, 'reduce_window_sum',
    _reduce_window_sum_translation_rule)
ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule)
batching.primitive_batchers[reduce_window_sum_p] = partial(
  _reduce_window_batch_rule, _reduce_window_sum)
5307
def _reduce_window_chooser_translation_rule(
    prim, identity, c, operand, *, window_dimensions, window_strides, padding,
    base_dilation, window_dilation):
  """Lower a windowed max/min reduction.

  `prim` is the scalar reducer, and `identity(dtype)` supplies the init value.
  """
  np_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), np_dtype)
  init = xb.constant(c, identity(np_dtype))
  reducer = xla.primitive_subcomputation(prim, scalar_aval, scalar_aval)
  return xops.ReduceWindowWithGeneralPadding(
      operand, init, reducer, window_dimensions, window_strides,
      base_dilation, window_dilation, padding)
5317
def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions,
                                    window_strides, padding, base_dilation,
                                    window_dilation):
  """JVP of reduce_window_max/min, registered via `ad.defjvp`.

  Delegates to `_select_and_gather_add`. The selection predicate is `ge_p`
  when `prim` is `max_p` and `le_p` otherwise (`min_p`).
  """
  assert prim is max_p or prim is min_p
  if prim is max_p:
    select_prim = ge_p
  else:
    select_prim = le_p
  return _select_and_gather_add(g, operand, select_prim, window_dimensions,
                                window_strides, padding, base_dilation,
                                window_dilation)
5326
5327
def _common_reduce_window_shape_rule(operand, window_dimensions,
                                     window_strides, padding, base_dilation,
                                     window_dilation):
  """Validate reduce_window parameters and compute the output shape.

  `window_dimensions` must have one entry per operand dimension. The strides
  and both dilations must match its length. The output shape itself comes
  from `reduce_window_shape_tuple`.

  Raises:
    TypeError: if `window_dimensions` does not match the operand rank, or if
      `window_strides`, `base_dilation` or `window_dilation` disagree in
      length with `window_dimensions`.
  """
  _check_shapelike("reduce_window", "window_dimensions", window_dimensions,
                   non_zero_shape=True)
  _check_shapelike("reduce_window", "window_strides", window_strides,
                   non_zero_shape=True)
  _check_shapelike("reduce_window", "base_dilation", base_dilation)
  _check_shapelike("reduce_window", "window_dilation", window_dilation)
  num_window_dims = len(window_dimensions)
  if operand.ndim != num_window_dims:
    raise TypeError(
        "reduce_window got the wrong number of window_dimensions for operand: "
        f"got operand shape {operand.shape} with window_dimensions "
        f"{window_dimensions}.")
  per_dim_params = (("window_strides", window_strides),
                    ("base_dilation", base_dilation),
                    ("window_dilation", window_dilation))
  for param_name, param in per_dim_params:
    if len(param) != num_window_dims:
      raise TypeError(
          f"reduce_window got inconsistent {param_name} and window_dimensions: "
          f"got {param_name} {param} and window_dimensions "
          f"{window_dimensions}.")
  return reduce_window_shape_tuple(operand.shape, window_dimensions,
                                   window_strides, padding, base_dilation,
                                   window_dilation)
5358
def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                              padding, base_dilation=None,
                              window_dilation=None):
  """Compute the output shape of a windowed reduction.

  Args:
    operand_shape: shape of the operand.
    window_dimensions: window size per dimension.
    window_strides: stride per dimension.
    padding: sequence of `(low, high)` padding pairs, one per dimension.
    base_dilation: optional operand dilation per dimension, applied with
      `_dilate_shape` before padding.
    window_dilation: optional window dilation per dimension, applied with
      `_dilate_shape` to `window_dimensions`.

  Returns:
    A tuple with, per dimension,
    `(operand + low + high - window) // stride + 1`, computed on the dilated
    operand and window sizes. Rank-0 inputs (all sequences empty) yield `()`.
  """
  if base_dilation is not None:
    operand_shape = _dilate_shape(operand_shape, base_dilation)
  if window_dilation is not None:
    window_dimensions = _dilate_shape(window_dimensions, window_dilation)
  # Sum each (low, high) pair explicitly. `np.add(*zip(*padding))` would call
  # `np.add()` with no arguments, and raise, when `padding` is empty (rank-0
  # operand).
  pad_totals = [lo + hi for lo, hi in padding]
  operand_padded = np.add(operand_shape, pad_totals)
  t = np.floor_divide(
      np.subtract(operand_padded, window_dimensions), window_strides) + 1
  return tuple(t)
5370
# Windowed max. The init value is the max identity for the dtype. The JVP
# selects with `ge_p` via `_reduce_window_chooser_jvp_rule`.
_reduce_window_max_translation_rule = partial(
    _reduce_window_chooser_translation_rule, max_p, _get_max_identity)
reduce_window_max_p = standard_primitive(
    _common_reduce_window_shape_rule, _input_dtype, 'reduce_window_max',
    _reduce_window_max_translation_rule)
ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, max_p))
batching.primitive_batchers[reduce_window_max_p] = partial(
  _reduce_window_batch_rule, _reduce_window_max)
5379
# Windowed min. The init value is the min identity for the dtype. The JVP
# selects with `le_p` via `_reduce_window_chooser_jvp_rule`.
_reduce_window_min_translation_rule = partial(
    _reduce_window_chooser_translation_rule, min_p, _get_min_identity)
reduce_window_min_p = standard_primitive(
    _common_reduce_window_shape_rule, _input_dtype, 'reduce_window_min',
    _reduce_window_min_translation_rule)
ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, min_p))

# Register the named batching rule itself rather than building a second,
# identical partial (the name was previously defined but unused).
_reduce_window_min_batch_rule = partial(_reduce_window_batch_rule,
                                        _reduce_window_min)
batching.primitive_batchers[reduce_window_min_p] = _reduce_window_min_batch_rule
5391
5392
def _select_and_scatter_shape_rule(
    operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr,
    scatter_consts, window_dimensions, window_strides, padding):
  """Shape rule for `select_and_scatter_p`: the output has the operand's shape.

  Raises:
    TypeError: if `window_strides` and `window_dimensions` differ in length.
  """
  for param_name, param in (("window_dimensions", window_dimensions),
                            ("window_strides", window_strides)):
    _check_shapelike("select_and_scatter", param_name, param)
  if len(window_dimensions) == len(window_strides):
    return operand.shape
  raise TypeError(
      "select_and_scatter got inconsistent window_strides and "
      f"window_dimensions: got window_strides {window_strides} and "
      f"window_dimensions {window_dimensions}.")
5403
def _select_and_scatter_translation(
  c, operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr,
  scatter_consts, window_dimensions, window_strides, padding):
  """Lower `select_and_scatter_p`.

  The select and scatter reducers are built from their jaxprs (select first,
  then scatter).
  """
  select_comp = _reduction_computation(c, select_jaxpr, select_consts,
                                       init_value)
  scatter_comp = _reduction_computation(c, scatter_jaxpr, scatter_consts,
                                        init_value)
  return xops.SelectAndScatterWithGeneralPadding(
      operand, select_comp, window_dimensions, window_strides, padding, source,
      init_value, scatter_comp)
5412
# Generic select-and-scatter with user-supplied select/scatter jaxprs. Only
# the shape, dtype and translation rules are registered here.
select_and_scatter_p = standard_primitive(
    _select_and_scatter_shape_rule, _input_dtype, 'select_and_scatter',
    _select_and_scatter_translation)
5416
5417
5418def _select_and_scatter_add_shape_rule(
5419    source, operand, *, select_prim, window_dimensions, window_strides,
5420    padding):
5421  return operand.shape
5422
def _select_and_scatter_add_translation(
    c, source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  """XLA lowering for select_and_scatter_add.

  Uses `select_prim` as the scalar selector and add_p as the scalar scatter
  combiner, starting from a zero of the operand's dtype.
  """
  operand_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), operand_dtype)
  select_comp, scatter_comp = (
      xla.primitive_subcomputation(prim, scalar_aval, scalar_aval)
      for prim in (select_prim, add_p))
  zero = xb.constant(c, np.array(0, operand_dtype))
  return xops.SelectAndScatterWithGeneralPadding(
      operand, select_comp, window_dimensions, window_strides, padding,
      source, zero, scatter_comp)
5434
def _select_and_scatter_add_jvp(
    primals, tangents, *, select_prim, window_dimensions, window_strides,
    padding):
  """JVP of select_and_scatter_add.

  The tangent output applies the same op to the tangent of `source`. The
  operand's tangent is ignored because `operand` only drives the selection.
  """
  source, operand = primals
  source_dot, _ = tangents
  window_args = (select_prim, window_dimensions, window_strides, padding)
  primal_out = _select_and_scatter_add(source, operand, *window_args)
  if type(source_dot) is ad_util.Zero:
    return primal_out, ad_util.Zero.from_value(primal_out)
  return primal_out, _select_and_scatter_add(source_dot, operand, *window_args)
5451
def _select_and_scatter_add_transpose(
    t, source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  """Transpose of select_and_scatter_add with respect to `source`.

  Only `source` may be the linear (undefined) input. Its cotangent is
  select_and_gather_add of `t` with unit base and window dilation.
  """
  assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand)
  unit_dilation = tuple(1 for _ in window_dimensions)
  source_t = _select_and_gather_add(
      t, operand, select_prim, window_dimensions, window_strides, padding,
      unit_dilation, unit_dilation)
  return [source_t, None]
5460
def _select_and_scatter_add_batch_rule(
    batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
    padding):
  """Batching rule for select_and_scatter_add.

  Moves both inputs' batch dimension to the front via batching.bdim_at_front
  (given the batch size, presumably to broadcast an unbatched input). It then
  prepends a trivial window entry for that dimension: size 1, stride 1, no
  padding. The output is batched along dimension 0.
  """
  source, operand = batched_args
  source_bdim, operand_bdim = batch_dims
  batch_size = next(arg.shape[dim] for arg, dim in zip(batched_args, batch_dims)
                    if dim is not None)
  source = batching.bdim_at_front(source, source_bdim, batch_size)
  operand = batching.bdim_at_front(operand, operand_bdim, batch_size)
  out = _select_and_scatter_add(
      source, operand, select_prim, (1,) + window_dimensions,
      (1,) + window_strides, ((0, 0),) + padding)
  return out, 0
5477
# Primitive for select_and_scatter_add: select with `select_prim`, scatter with
# addition. Registers its transpose, JVP and batching rules.
select_and_scatter_add_p = standard_primitive(
    _select_and_scatter_add_shape_rule, _input_dtype, 'select_and_scatter_add',
    _select_and_scatter_add_translation)
ad.primitive_transposes[select_and_scatter_add_p] = \
    _select_and_scatter_add_transpose
ad.primitive_jvps[select_and_scatter_add_p] = _select_and_scatter_add_jvp
batching.primitive_batchers[select_and_scatter_add_p] = \
    _select_and_scatter_add_batch_rule
5486
5487def _select_and_gather_add_shape_rule(
5488    tangents, operand, *, select_prim, window_dimensions, window_strides,
5489    padding, base_dilation, window_dilation):
5490  if tangents.shape != operand.shape:
5491    msg = ("select_and_gather_add tangents and operand shapes must match, "
5492           "got {} and {}.")
5493    raise TypeError(msg.format(tangents.shape, operand.shape))
5494  return _common_reduce_window_shape_rule(
5495    operand, window_dimensions, window_strides, padding, base_dilation,
5496    window_dilation)
5497
5498
5499_UINT_DTYPES = {
5500  16: np.uint16,
5501  32: np.uint32,
5502  64: np.uint64,
5503}
5504
5505_INT_DTYPES = {
5506  16: np.int16,
5507  32: np.int32,
5508  64: np.int64,
5509}
5510
def _select_and_gather_add_translation(
    c, tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation, max_bits=64):
  """XLA lowering for select_and_gather_add.

  Runs a reduce-window over (operand, tangents) pairs. In each window, the
  pair whose operand value wins under `select_prim` is kept (xops.Ge for
  ge_p, xops.Le for le_p), and its tangent value is returned.

  Each pair is packed into one unsigned integer so a scalar ReduceWindow can
  carry it:
    * if 2 * nbits <= max_bits, both values are bitcast and packed exactly
      into a double-width unsigned word, with the operand in the high half;
    * otherwise both values are first passed through ReducePrecision and
      packed into a single word of the original width. This loses precision,
      and a warning is emitted.

  Args:
    c: the computation builder.
    tangents: values to gather; packed as the second element of each pair.
    operand: values compared by `select_prim`; packed as the first element.
    select_prim: must be ge_p or le_p (asserted).
    window_dimensions, window_strides, padding, base_dilation,
      window_dilation: reduce-window parameters passed through to XLA.
    max_bits: widest integer width usable for packing. A backend-specific
      registration may pass a smaller value, e.g. 32.

  Returns:
    The gathered tangent values, bitcast back to operand's element type.
  """
  shape = c.get_shape(operand)
  dtype = shape.numpy_dtype()
  etype = shape.xla_element_type()
  # Bit width of the operand's floating-point dtype.
  nbits = dtypes.finfo(dtype).bits

  assert nbits <= max_bits
  double_word_reduction = nbits * 2 <= max_bits

  # Scalar constant of exactly `dtype` (dtype canonicalization disabled).
  const = lambda c, dtype, x: xb.constant(c, np.array(x, dtype=dtype),
                                          canonicalize_types=False)

  if double_word_reduction:
    # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
    # we implement a pair-wise ReduceWindow by packing two k-bit values into
    # 2k-bit unsigned integer using bit tricks.
    word_dtype = _UINT_DTYPES[nbits]
    double_word_dtype = _UINT_DTYPES[nbits * 2]
    word_type = xla_client.dtype_to_etype(word_dtype)
    double_word_type = xla_client.dtype_to_etype(double_word_dtype)

    # Packs two values into a tuple.
    # `a` ends up in the high nbits and `b` in the low nbits.
    def pack(a, b):
      a = xops.BitcastConvertType(a, word_type)
      b = xops.BitcastConvertType(b, word_type)
      a = xops.ConvertElementType(a, double_word_type)
      b = xops.ConvertElementType(b, double_word_type)
      a = xops.ShiftLeft(a, const(c, double_word_dtype, nbits))
      return xops.Or(a, b)

    # Unpacks the first element of a tuple.
    # Takes a builder argument because it is also used inside the reducer's
    # own computation builder.
    def fst(c, t):
      st = xops.ShiftRightLogical(t, const(c, double_word_dtype, nbits))
      return xops.BitcastConvertType(xops.ConvertElementType(st, word_type), etype)

    # Unpacks the second element of a tuple.
    # Converting to the single-word type keeps only the low nbits.
    def snd(t):
      return xops.BitcastConvertType(xops.ConvertElementType(t, word_type), etype)

  else:
    # The double-word trick above only works if we have a sufficiently large
    # type. As an alternative, we can pack two half words into a single word,
    # at the cost of precision.
    # TODO(b/73062247): add support for tuple reductions and remove this case.
    warnings.warn("Using reduced precision for gradient of reduce-window "
                  "min/max operator to work around missing XLA support for "
                  "pair-reductions. This is likely from a second or "
                  "higher derivative of a max-pooling operation.")
    r_nbits = nbits // 2
    # Drop/round the bottom mantissa bits.
    nexp = dtypes.finfo(dtype).nexp
    nmant = r_nbits - nexp - 1

    double_word_dtype = word_dtype = _UINT_DTYPES[nbits]
    word_type = xla_client.dtype_to_etype(word_dtype)

    # Packs two values into a tuple.
    # Both values are rounded to nexp exponent / nmant mantissa bits. `a` keeps
    # its top half in place; `b`'s top half is shifted into the low r_nbits.
    # NOTE(review): this relies on ReducePrecision leaving the low bits of `a`
    # zero so the Or does not mix the halves -- confirm against XLA semantics.
    def pack(a, b):
      a = xops.ReducePrecision(a, exponent_bits=nexp, mantissa_bits=nmant)
      b = xops.ReducePrecision(b, exponent_bits=nexp, mantissa_bits=nmant)
      a = xops.BitcastConvertType(a, word_type)
      b = xops.BitcastConvertType(b, word_type)
      b = xops.ShiftRightLogical(b, const(c, word_dtype, r_nbits))
      return xops.Or(a, b)

    # Unpacks the first element of a tuple.
    # Masks off the low r_nbits, leaving the high half in place.
    def fst(c, t):
      st = xops.And(t, const(c, word_dtype, ((1 << r_nbits) - 1) << r_nbits))
      return xops.BitcastConvertType(st, etype)

    # Unpacks the second element of a tuple.
    # Shifts the low half back up into the high bits.
    def snd(t):
      return xops.BitcastConvertType(xops.ShiftLeft(t, const(c, word_dtype, r_nbits)),
                                  etype)

  # Builds the scalar pair reducer. It keeps `x` when its unpacked operand
  # value wins the comparison (Ge for ge_p, Le for le_p); otherwise it keeps `y`.
  def reducer():
    c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer")
    x = xb.parameter(c, 0,
      xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
    y = xb.parameter(c, 1,
      xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
    assert select_prim is ge_p or select_prim is le_p
    which = xops.Ge if select_prim is ge_p else xops.Le
    xops.Select(which(fst(c, x), fst(c, y)), x, y)
    return c.build()


  assert select_prim is ge_p or select_prim is le_p, select_prim
  # The initial pair is (-inf for ge_p / +inf for le_p, tangent 0).
  init = -np.inf if select_prim is ge_p else np.inf
  out = xops.ReduceWindowWithGeneralPadding(
    pack(operand, tangents), pack(const(c, dtype, init), const(c, dtype, 0)),
    reducer(), window_dimensions, window_strides, base_dilation,
    window_dilation, padding)
  return snd(out)
5607
def _select_and_gather_add_jvp(
    primals, tangents, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """JVP of select_and_gather_add.

  The tangent output applies the same op to the tangent of the first input.
  The operand's tangent is unused because the operand only decides which
  element is selected.
  """
  source, operand = primals
  source_dot, _ = tangents
  params = (select_prim, window_dimensions, window_strides, padding,
            base_dilation, window_dilation)
  primal_out = _select_and_gather_add(source, operand, *params)
  if type(source_dot) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(primal_out)
  else:
    tangent_out = _select_and_gather_add(source_dot, operand, *params)
  return primal_out, tangent_out
5624
def _select_and_gather_add_transpose(
    t, tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Transpose of select_and_gather_add with respect to `tangents`.

  The cotangent `t` is scattered back, via select_and_scatter_add, onto the
  positions the forward pass selected from `operand`.

  Base dilation is handled in three steps:
    1. Interior-pad `operand` with the selection identity (max identity for
       ge_p, min identity for le_p) so the padded positions should not win.
    2. Scatter onto the padded shape.
    3. Strided-slice the scattered *result* back to the undilated shape.

  Raises:
    NotImplementedError: if any window dilation is not 1.
  """
  assert select_prim in (le_p, ge_p)
  assert ad.is_undefined_primal(tangents) and not ad.is_undefined_primal(operand)
  if any(d != 1 for d in window_dilation):
    msg = ("VJP not implemented for select_and_gather (MaxPool) with window "
           "dilation, got window_dilation={}.")
    raise NotImplementedError(msg.format(window_dilation))
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(tangents.aval), None]
  has_base_dilation = any(d != 1 for d in base_dilation)
  if has_base_dilation:
    select_identity = (_get_max_identity if select_prim is ge_p
                       else _get_min_identity)
    operand = pad(operand, select_identity(operand.dtype),
                  tuple((0, 0, d - 1) for d in base_dilation))
  result = _select_and_scatter_add(t, operand, select_prim, window_dimensions,
                                   window_strides, padding)
  if has_base_dilation:
    # Undo the interior padding on the scattered *result*: keep every d-th
    # element so the cotangent has the original (undilated) input shape.
    # Slicing `operand` here would throw away the cotangent entirely.
    result = slice(result, (0,) * len(result.shape), result.shape,
                   base_dilation)
  return [result, None]
5648
def _select_and_gather_add_batching_rule(
    batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Batching rule for select_and_gather_add.

  Moves both inputs' batch dimension to the front. Every window parameter
  gets a leading trivial entry for it (size/stride/dilation 1, no padding).
  The output is batched along dimension 0.
  """
  tangents, operand = batched_args
  tangents_bdim, operand_bdim = batch_dims
  batch_size = next(arg.shape[dim] for arg, dim in zip(batched_args, batch_dims)
                    if dim is not None)
  tangents = batching.bdim_at_front(tangents, tangents_bdim, batch_size)
  operand = batching.bdim_at_front(operand, operand_bdim, batch_size)
  out = _select_and_gather_add(
      tangents, operand, select_prim,
      (1,) + window_dimensions, (1,) + window_strides, ((0, 0),) + padding,
      (1,) + base_dilation, (1,) + window_dilation)
  return out, 0
5667
5668
# Primitive for select_and_gather_add with JVP, transpose and batching rules.
select_and_gather_add_p = standard_primitive(
    _select_and_gather_add_shape_rule, _input_dtype, 'select_and_gather_add',
    _select_and_gather_add_translation)
ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
ad.primitive_transposes[select_and_gather_add_p] = \
  _select_and_gather_add_transpose
batching.primitive_batchers[select_and_gather_add_p] = \
  _select_and_gather_add_batching_rule
# On TPU, packing is limited to 32-bit words (max_bits=32). In the lowering,
# 32-bit inputs therefore take the reduced-precision packing path.
xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
  _select_and_gather_add_translation,
  max_bits=32)
5680
def _sort_abstract_eval(*args, **kwargs):
  """Abstract evaluation for sort: one shaped aval per operand, unchanged.

  Raises:
    TypeError: if the operands do not all share one shape.
  """
  avals = tuple(map(raise_to_shaped, args))
  if any(aval.shape != avals[0].shape for aval in avals[1:]):
    shapes = " ".join(str(aval.shape) for aval in avals)
    raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
  return avals
5687
5688
def _float_to_int_for_sort(x):
  """Maps floats to signed integers whose integer order is a total float order.

  With y = bitcast<int>(f) and key = (y < 0 ? int_max - y : y), comparing
  keys as integers orders finite values as usual, places -0 before 0, and
  puts -NaN first and NaN last. `int_max - y` is evaluated in unsigned
  arithmetic to avoid signed overflow, then bitcast back to signed.
  """
  if x.dtype == dtypes.bfloat16:
    # bfloat16 inputs are widened to float32 before the bitcast.
    x = convert_element_type(x, np.float32)
  nbits = np.finfo(x.dtype).bits
  int_t, uint_t = _INT_DTYPES[nbits], _UINT_DTYPES[nbits]
  as_signed = bitcast_convert_type(x, int_t)
  as_unsigned = bitcast_convert_type(x, uint_t)
  mirrored = bitcast_convert_type(
      sub(uint_t(np.iinfo(int_t).max), as_unsigned), int_t)
  is_negative = lt(as_signed, _zero(as_signed))
  return select(is_negative, mirrored, as_signed)
5713
# Default "less than" comparator for sort_p. Operands arrive as interleaved
# (x0, y0, x1, y1, ...) pairs; the first `num_keys` pairs are compared
# lexicographically.
# Floating-point keys use a total order:
#   -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN.
# Complex keys compare their (real, imag) parts lexicographically, following
# NumPy's semantics.
# Extends the algorithm in XLA's comparators.h with complex-number support and
# lexicographic ordering:
# https://github.com/tensorflow/tensorflow/blob/ba43780830f09da72081fe5061c436f1c6203a92/tensorflow/compiler/xla/client/lib/comparators.h#L33
def _sort_lt_comparator(*operands, num_keys=1):
  assert len(operands) >= 2 and len(operands) % 2 == 0, operands
  assert len(operands) // 2 >= num_keys, (operands, num_keys)
  lhs_keys, rhs_keys = [], []
  key_pairs = zip(operands[:2 * num_keys:2], operands[1:2 * num_keys:2])
  for lhs, rhs in key_pairs:
    assert lhs.dtype == rhs.dtype, (lhs.dtype, rhs.dtype)
    if np.issubdtype(lhs.dtype, np.complexfloating):
      lhs_keys += [_float_to_int_for_sort(real(lhs)),
                   _float_to_int_for_sort(imag(lhs))]
      rhs_keys += [_float_to_int_for_sort(real(rhs)),
                   _float_to_int_for_sort(imag(rhs))]
    elif np.issubdtype(lhs.dtype, np.floating):
      lhs_keys.append(_float_to_int_for_sort(lhs))
      rhs_keys.append(_float_to_int_for_sort(rhs))
    else:
      lhs_keys.append(lhs)
      rhs_keys.append(rhs)

  # Fold from the least significant key upwards:
  #   result = lt(k) | (eq(k) & result_of_less_significant_keys)
  result = None
  for lhs_key, rhs_key in zip(reversed(lhs_keys), reversed(rhs_keys)):
    less = lt(lhs_key, rhs_key)
    if result is None:
      result = less
    else:
      result = bitwise_or(less, bitwise_and(eq(lhs_key, rhs_key), result))
  return result
5743
5744
def _sort_translation_rule(c, *operands, dimension, is_stable, num_keys):
  """XLA lowering for sort_p.

  Builds a scalar comparator subcomputation with parameters (x_i, y_i) for
  each operand, lowered from _sort_lt_comparator, and emits one XLA Sort over
  all operands. A single-operand Sort is wrapped in a tuple so the result is
  always tuple-shaped.
  """
  element_types = [c.get_shape(x).xla_element_type() for x in operands]
  subc = xla_bridge.make_computation_builder("sort_lt_comparator")
  params = []
  for i, element_type in enumerate(element_types):
    for j in (0, 1):
      params.append(xb.parameter(subc, 2 * i + j,
                                 xc.Shape.array_shape(element_type, ())))
  lowered_comparator = xla.lower_fun(
      partial(_sort_lt_comparator, num_keys=num_keys), multiple_results=False)
  comparator = subc.build(lowered_comparator(subc, *params))
  sorted_out = xops.Sort(c, operands, dimension=dimension, is_stable=is_stable,
                         comparator=comparator)
  if len(operands) == 1:
    return xops.Tuple(c, [sorted_out])
  return sorted_out
5756
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
  """JVP of sort_p.

  Sorts an extra iota along `dimension` together with the primals to recover
  the permutation. Every non-Zero tangent is then indexed with that
  permutation (and plain iotas on the other axes).
  """
  shape = primals[0].shape
  iotas = [
      broadcasted_iota(np.int32 if size < np.iinfo(np.int32).max else np.int64,
                       shape, dim)
      for dim, size in enumerate(shape)]
  sorted_outs = sort_p.bind(*primals, iotas[dimension], dimension=dimension,
                            is_stable=is_stable, num_keys=num_keys)
  permutation = sorted_outs[-1]
  gather_idx = tuple(permutation if axis == dimension else axis_iota
                     for axis, axis_iota in enumerate(iotas))
  tangents_out = tuple(t if type(t) is ad_util.Zero else t[gather_idx]
                       for t in tangents)
  return tuple(sorted_outs[:-1]), tangents_out
5769
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):
  """Batching rule for sort_p.

  Aligns every operand to the batch dimension of the first batched operand.
  Unbatched operands are broadcast to its shape. The sort dimension is
  shifted by one when the batch dimension precedes or equals it.
  """
  prototype_arg, new_bdim = next(
      (arg, bdim) for arg, bdim in zip(batched_args, batch_dims)
      if bdim is not None)

  def _align(arg, bdim):
    if bdim is None:
      kept_dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
      return broadcast_in_dim(arg, prototype_arg.shape, kept_dims)
    return batching.moveaxis(arg, bdim, new_bdim)

  new_args = [_align(arg, bdim) for arg, bdim in zip(batched_args, batch_dims)]
  new_dimension = dimension + (new_bdim <= dimension)
  outs = sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable,
                     num_keys=num_keys)
  return outs, (new_bdim,) * len(new_args)
5784
5785
# sort_p: sorts any number of same-shaped operands together along
# `dimension`. It returns one result per operand (multiple_results).
sort_p = Primitive('sort')
sort_p.multiple_results = True
sort_p.def_impl(partial(xla.apply_primitive, sort_p))
sort_p.def_abstract_eval(_sort_abstract_eval)
xla.translations[sort_p] = _sort_translation_rule
ad.primitive_jvps[sort_p] = _sort_jvp
batching.primitive_batchers[sort_p] = _sort_batch_rule
5793
5794
5795def _top_k_abstract_eval(operand, *, k):
5796  if k < 0:
5797    raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
5798  if len(operand.shape) == 0:
5799    raise TypeError("top_k operand must have >= 1 dimension, got {}"
5800                    .format(operand.shape))
5801  shape = list(operand.shape)
5802  if shape[-1] < k:
5803    msg = "k argument to top_k must be no larger than minor dimension; {} vs {}"
5804    raise ValueError(msg.format(k, shape))
5805  shape[-1] = k
5806  return (ShapedArray(shape, operand.dtype),
5807          ShapedArray(shape, np.dtype(np.int32)))
5808
def _top_k_jvp(primals, tangents, *, k):
  """JVP of top_k.

  The values' tangent is gathered from the operand tangent at the top-k
  indices, using iotas for the leading dimensions. The indices output always
  has a zero tangent.
  """
  operand, = primals
  tangent, = tangents
  primals_out = top_k(operand, k)
  if type(tangent) is ad_util.Zero:
    return primals_out, (ad_util.Zero.from_value(primals_out[0]),
                         ad_util.Zero.from_value(primals_out[1]))

  _, k_idxs = primals_out
  idx_shape = k_idxs.shape
  rank = len(idx_shape)
  gather_index_shape = idx_shape + (1,)
  index_columns = []
  for axis, axis_size in enumerate(idx_shape[:-1]):
    axis_iota = iota(k_idxs.dtype, axis_size)
    if not config.omnistaging_enabled:
      axis_iota = tie_in(operand, axis_iota)
    index_columns.append(
        broadcast_in_dim(axis_iota, gather_index_shape, (axis,)))
  index_columns.append(reshape(k_idxs, gather_index_shape))
  gather_indices = concatenate(index_columns, dimension=rank)
  dnums = GatherDimensionNumbers(
      offset_dims=(),
      collapsed_slice_dims=tuple(range(rank)),
      start_index_map=tuple(range(rank)))
  values_dot = gather(tangent, gather_indices, dnums, (1,) * rank)
  return primals_out, (values_dot, ad_util.Zero.from_value(primals_out[1]))
5836
def _top_k_batch_rule(batched_args, batch_dims, *, k):
  """Batching rule for top_k.

  top_k acts on the last axis. When the batch dimension is last, it is
  swapped with the preceding axis before the call and swapped back after.
  """
  operand, = batched_args
  bdim, = batch_dims
  out_bdims = (bdim, bdim)
  if bdim != operand.ndim - 1:
    return top_k(operand, k=k), out_bdims
  perm = np.arange(operand.ndim)
  perm[[bdim - 1, bdim]] = perm[[bdim, bdim - 1]]
  values, indices = top_k(transpose(operand, perm), k=k)
  return (transpose(values, perm), transpose(indices, perm)), out_bdims
5848
# top_k_p: returns (values, indices) of the k largest entries along the last
# axis. It is lowered through standard_translate under the name 'top_k'.
top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
xla.translations[top_k_p] = partial(standard_translate, 'top_k')
ad.primitive_jvps[top_k_p] = _top_k_jvp
batching.primitive_batchers[top_k_p] = _top_k_batch_rule
5856
def _stop_gradient_jvp_rule(primals, tangents):
  """JVP of stop_gradient: the output tangent is always zero."""
  # Re-apply stop_gradient to the primal rather than returning it as-is;
  # otherwise only one autodiff tracer would be peeled off.
  del tangents  # The incoming tangent is discarded.
  x, = primals
  return stop_gradient(x), ad_util.Zero.from_value(x)
5861
def _stop_gradient_batch_rule(batched_args, batch_dims):
  """Batching rule for stop_gradient: apply it and keep the batch dim."""
  (x,), (dim,) = batched_args, batch_dims
  return stop_gradient(x), dim
5866
# stop_gradient_p lives in ad_util; this module only registers its JVP and
# batching rules.
ad.primitive_jvps[ad_util.stop_gradient_p] = _stop_gradient_jvp_rule
batching.primitive_batchers[ad_util.stop_gradient_p] = _stop_gradient_batch_rule
5869
5870
def create_token(_=None):
  """Creates an XLA token value with no preconditions for sequencing effects.

  Experimental.

  The argument exists for backward compatibility. It is ignored when
  omnistaging is enabled. Without omnistaging it is required: it serves as a
  tie-in operand (passed through stop_gradient), and omitting it raises
  ValueError.
  """
  if not config.omnistaging_enabled:
    tie_in_operand = _
    if tie_in_operand is None:
      raise ValueError(
          'create_token needs a tie-in operand unless omnistaging is enabled.')
    return create_token_p.bind(stop_gradient(tie_in_operand))
  return create_token_p.bind()
5886
# create_token_p: abstract evaluation ignores all inputs and yields
# abstract_token. It is lowered to XLA CreateToken.
create_token_p = Primitive("create_token")
create_token_p.def_impl(partial(xla.apply_primitive, create_token_p))
create_token_p.def_abstract_eval(lambda *_: abstract_token)
xla.translations[create_token_p] = lambda c, *_: xops.CreateToken(c)
5891
def after_all(*operands):
  """Merges one or more XLA token values. Experimental.

  Binds after_all_p to every token in `operands`; it is lowered to the XLA
  AfterAll operator.
  """
  tokens = operands
  return after_all_p.bind(*tokens)
5897
def _after_all_abstract_eval(*operands):
  """Abstract evaluation for after_all: every input must be a token."""
  if not all(operand is abstract_token for operand in operands):
    raise TypeError("Arguments to after_all must be tokens")
  return abstract_token
5902
5903
def _after_all_translation_rule(c, *operands):
  """XLA lowering for after_all: one AfterAll op over the token operands."""
  token_ops = operands
  return xops.AfterAll(c, token_ops)
5906
# after_all_p: merges token operands into a single token via XLA AfterAll.
after_all_p = Primitive("after_all")
after_all_p.def_impl(partial(xla.apply_primitive, after_all_p))
after_all_p.def_abstract_eval(_after_all_abstract_eval)
xla.translations[after_all_p] = _after_all_translation_rule
5911
5912
def infeed(token, shape=None, partitions=None):
  """Consumes an infeed value of `shape` from the host. Experimental.

  `token` is used to sequence infeed and outfeed effects.
  `shape` is a pytree of ShapedArray values describing what to receive.
  `partitions` may be specified inside a `sharded_jit` function; it must be a
  tuple, and the token is always replicated.

  Returns:
    A pair (values, token): `values` has the pytree structure of `shape`.
  """
  flat_shapes, treedef = pytree.flatten(shape)
  for leaf in flat_shapes:
    if not isinstance(leaf, ShapedArray):
      raise TypeError("shape argument to infeed must be a pytree of "
                      "ShapedArray values, got {}".format(leaf))
  if partitions is not None:
    # type() rather than isinstance() so that PartitionSpecs are rejected.
    if type(partitions) != tuple:  # pylint: disable=unidiomatic-typecheck
      raise ValueError(f"'partitions' argument to infeed should be a tuple, "
                       f"got {partitions}")
    # The trailing None replicates the output token.
    partitions += (None,)
  outs = infeed_p.bind(token, shapes=tuple(flat_shapes), partitions=partitions)
  return treedef.unflatten(outs[:-1]), outs[-1]
5934
def _infeed_abstract_eval(token, *, shapes, partitions):
  """Abstract evaluation for infeed: the given shapes followed by a token."""
  del partitions  # Does not affect the output avals.
  if token is abstract_token:
    return shapes + (abstract_token,)
  raise TypeError("First argument to infeed must be a token")
5939
5940
def _infeed_translation_rule(c, token, *, shapes, partitions):
  """XLA lowering for infeed.

  Emits InfeedWithToken over a tuple of the XLA shapes of every aval in
  `shapes` (each given a major-to-minor layout if it has none). The op is
  wrapped in a sharding when `partitions` is given. The (values, token)
  result is flattened into one tuple: the values followed by the token.
  """
  xla_shapes = []
  for aval in shapes:
    for xla_shape in xla.aval_to_xla_shapes(aval):
      xla_shapes.append(xla_shape.with_major_to_minor_layout_if_absent())
  build_infeed = partial(xops.InfeedWithToken, token,
                         xla_client.Shape.tuple_shape(tuple(xla_shapes)))
  # Without an explicit sharding, infeed defaults to replication inside a
  # sharded computation.
  xs_and_token = (xb.with_sharding(c, partitions, build_infeed) if partitions
                  else build_infeed())
  xs = xops.GetTupleElement(xs_and_token, 0)
  out_token = xops.GetTupleElement(xs_and_token, 1)
  values = [xops.GetTupleElement(xs, i) for i in range(len(shapes))]
  return xops.Tuple(c, values + [out_token])
5956
# infeed_p: multiple results -- the received values followed by a new token.
infeed_p = Primitive("infeed")
infeed_p.multiple_results = True
infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
infeed_p.def_abstract_eval(_infeed_abstract_eval)
xla.translations[infeed_p] = _infeed_translation_rule
5962
def outfeed(token, xs):
  """Outfeeds value `xs` to the host. Experimental.

  `token` is used to sequence infeed and outfeed effects. `xs` may be any
  pytree; its leaves are sent in flattened order. Returns a new token.
  """
  leaves = pytree.flatten(xs)[0]
  return outfeed_p.bind(token, *leaves)
5970
def _outfeed_abstract_eval(token, *xs):
  """Abstract evaluation for outfeed: checks the token and returns a token."""
  del xs  # Only the token is validated.
  if token is abstract_token:
    return abstract_token
  raise TypeError("First argument to outfeed must be a token")
5975
5976
def _outfeed_translation_rule(c, token, *xs):
  """XLA lowering for outfeed: sends all operands as a single tuple."""
  payload = xops.Tuple(c, xs)
  payload_shape = c.get_shape(payload)
  return xops.OutfeedWithToken(payload, token, payload_shape)
5980
# outfeed_p: consumes a token plus values and produces a new token.
outfeed_p = Primitive("outfeed")
outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
xla.translations[outfeed_p] = _outfeed_translation_rule
5985
def rng_uniform(a, b, shape):
  """Stateful PRNG generator. Experimental and its use is discouraged.

  Returns uniformly distributed random numbers in the range [a, b).
  `a` and `b` must be scalars of the same dtype; `shape` is converted to a
  tuple.

  You should use jax.random for most purposes; this function exists only for
  niche use cases with special performance requirements.

  This API may be removed at any time.
  """
  out_shape = tuple(shape)
  return rng_uniform_p.bind(a, b, shape=out_shape)
5997
5998def _rng_uniform_abstract_eval(a, b, *, shape):
5999  if a.dtype != b.dtype:
6000    raise ValueError(
6001      "Arguments to rng_uniform must have identical dtypes, got {} "
6002      "and {}.".format(a.dtype, b.dtype))
6003  if a.shape != () or b.shape != ():
6004    raise ValueError(
6005      "Arguments to rng_uniform must be scalars; got shapes {} and {}."
6006      .format(a.shape, b.shape))
6007  return ShapedArray(shape, a.dtype)
6008
def _rng_uniform_translation_rule(c, a, b, *, shape):
  """XLA lowering for rng_uniform: RngUniform using `a`'s element type."""
  element_type = c.get_shape(a).xla_element_type()
  return xops.RngUniform(a, b, xc.Shape.array_shape(element_type, shape))
6012
# rng_uniform_p: stateful uniform RNG lowered to XLA RngUniform.
rng_uniform_p = Primitive("rng_uniform")
rng_uniform_p.def_impl(partial(xla.apply_primitive, rng_uniform_p))
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
xla.translations[rng_uniform_p] = _rng_uniform_translation_rule
6017
6018
def _iota_abstract_eval(*, dtype, shape, dimension):
  """Abstract evaluation for iota.

  Validates `shape`, checks that `dtype` is a subtype of one of the types in
  `_num`, and checks that `dimension` indexes into `shape`.

  Raises:
    TypeError: for an unsupported dtype.
    ValueError: if `dimension` is out of range.
  """
  _check_shapelike("iota", "shape", shape)
  if all(not dtypes.issubdtype(dtype, t) for t in _num):
    accepted = ', '.join(t.__name__ for t in _num)
    raise TypeError(
        'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.'
        .format(str(np.dtype(dtype).name), accepted))
  if dimension < 0 or dimension >= len(shape):
    raise ValueError("iota dimension must be between 0 and len(shape), got "
                     f"dimension={dimension} for shape {shape}")
  return ShapedArray(shape, dtype)
6030
def _iota_translation_rule(c, dtype, shape, dimension):
  """XLA lowering for iota: an Iota op of the given dtype and shape."""
  xla_shape = xc.Shape.array_shape(xla_client.dtype_to_etype(dtype), shape)
  return xops.Iota(c, xla_shape, dimension)
6035
# iota_p: counts up along `dimension` of `shape`; lowered to XLA Iota.
iota_p = Primitive('iota')
iota_p.def_impl(partial(xla.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)
xla.translations[iota_p] = _iota_translation_rule
6040
6041
6042### util
6043
# Short module-level alias for np.ndim.
_ndim = np.ndim
6045
6046
6047def _dilate_shape(shape, dilation):
6048  """Utility function for computing the shape resulting from a dilation."""
6049  if not np.all(np.greater(dilation, 0)):
6050    msg = "All dilations must be positive, got {}."
6051    raise TypeError(msg.format(dilation))
6052  dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation)
6053  return np.where(shape == 0, 0,
6054                   np.multiply(dilation, np.subtract(shape, 1)) + 1)
6055
6056def _ceil_divide(x1, x2):
6057  return -np.floor_divide(np.negative(x1), x2)
6058
def padtype_to_pads(in_shape, window_shape, window_strides, padding):
  """Convert padding string to list of pairs of pad values.

  `padding` is 'VALID' or 'SAME' (case-insensitive) or an
  `xla_client.PaddingType`. 'SAME' pads so each output dimension has
  ceil(in / stride) elements, splitting the padding with any odd element on
  the high side. 'VALID' pads nothing.

  Raises:
    RuntimeError: for an unrecognized padding string.
    TypeError: for any other unrecognized padding value.
  """
  PaddingType = xla_client.PaddingType
  if isinstance(padding, str):
    name_to_type = {'VALID': PaddingType.VALID, 'SAME': PaddingType.SAME}
    try:
      padding = name_to_type[padding.upper()]
    except KeyError as err:
      msg = "Unrecognized padding type: expected 'VALID' or 'SAME', got {}."
      raise RuntimeError(msg.format(padding)) from err

  if padding == PaddingType.SAME:
    out_shape = _ceil_divide(in_shape, window_strides)
    total_pads = np.maximum(
        0, (out_shape - 1) * window_strides + window_shape - in_shape)
    return [(total // 2, total - total // 2) for total in total_pads]
  if padding == PaddingType.VALID:
    return [(0, 0)] * len(in_shape)
  raise TypeError("Unknown padding type: {}.".format(padding))
6081
6082
6083def _check_same_dtypes(name, ignore_fp_precision, *ttypes):
6084  """Check that dtypes agree, possibly ignoring float precision."""
6085  # the `ignore_fp_precision` flag exists because the XLA shape inference logic
6086  # allows mixed floating point precision, but the HLO verifier often rejects it
6087  types = list(map(np.dtype, ttypes))  # canonicalize
6088  if ignore_fp_precision:
6089    types = [
6090        np.floating if dtypes.issubdtype(dtype, np.floating)
6091        else np.complexfloating if dtypes.issubdtype(dtype, np.complexfloating)
6092        else dtype for dtype in types]
6093  if len({dtypes.canonicalize_dtype(t) for t in types}) != 1:
6094    if ignore_fp_precision:
6095      msg = ("{} requires arguments to have same dtypes up to floating point "
6096             "precision, got {}.")
6097    else:
6098      msg = "{} requires arguments to have the same dtypes, got {}."
6099    raise TypeError(msg.format(name, ", ".join(map(str, types))))
6100
6101
6102def _check_conv_shapes(name, lhs_shape, rhs_shape, window_strides):
6103  """Check that conv shapes are valid and are consistent with window_strides."""
6104  if len(lhs_shape) != len(rhs_shape):
6105    msg = "Arguments to {} must have same rank, got {} and {}."
6106    raise TypeError(msg.format(name, len(lhs_shape), len(rhs_shape)))
6107  if len(lhs_shape) < 2:
6108    msg = "Arguments to {} must have rank at least 2, got {} and {}."
6109    raise TypeError(msg.format(name, len(lhs_shape), len(rhs_shape)))
6110  if lhs_shape[1] != rhs_shape[1]:
6111    msg = "Arguments to {} must agree on input feature size, got {} and {}."
6112    raise TypeError(msg.format(name, lhs_shape[1], rhs_shape[1]))
6113  _check_shapelike(name, "window_strides", window_strides)
6114  if not np.all(np.greater(window_strides, 0)):
6115    msg = "All elements of window_strides must be positive, got {}."
6116    raise TypeError(msg.format(window_strides))
6117  if len(window_strides) != len(lhs_shape) - 2:
6118    msg = "{} window_strides has wrong length: expected {}, got {}."
6119    expected_length = len(lhs_shape) - 2
6120    raise TypeError(msg.format(name, expected_length, len(window_strides)))
6121
6122
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
  """Compute the shape tuple of a conv given input shapes in canonical order.

  Dimension 0 of `lhs_shape` is the batch, dimension 0 of `rhs_shape` is the
  output feature count, and dimensions 2: of both are spatial. `pads` is
  either a padding string (resolved with padtype_to_pads) or one (low, high)
  pair per spatial dimension. Each spatial output size is
  max(0, (in + low + high - window) // stride + 1). The batch is divided by
  `batch_group_count`.

  Raises:
    TypeError: if the number of explicit pads does not match the number of
      spatial dimensions.
  """
  if isinstance(pads, str):
    pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
  num_spatial = len(lhs_shape) - 2
  if len(pads) != num_spatial:
    raise TypeError(
        "Wrong number of explicit pads for convolution: expected {}, got {}."
        .format(num_spatial, len(pads)))

  pad_totals = np.sum(np.array(pads).reshape(-1, 2), axis=1)
  lhs_padded = np.add(lhs_shape[2:], pad_totals)
  out_space = np.maximum(
      0, np.floor_divide(np.subtract(lhs_padded, rhs_shape[2:]), strides) + 1)
  assert lhs_shape[0] % batch_group_count == 0
  out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0])
  return tuple(out_shape + tuple(out_space))
6139
6140
def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
                             dimension_numbers):
  """Output shape of a convolution whose layouts are given by dimension specs.

  The operand shapes are permuted into canonical order according to
  ``conv_general_permutations(dimension_numbers)``. The canonical output shape
  is computed with ``conv_shape_tuple`` and then permuted back into the output
  layout by inverting the output permutation.
  """
  lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
  canonical_out = conv_shape_tuple(np.take(lhs_shape, lhs_perm),
                                   np.take(rhs_shape, rhs_perm),
                                   window_strides, padding)
  # argsort of a permutation is its inverse.
  inverse_out_perm = np.argsort(out_perm)
  return tuple(np.take(canonical_out, inverse_out_perm))
6148
6149
def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
                               dimension_numbers):
  """Output shape of a transposed convolution in the output layout.

  For each spatial dimension the size is
  ``(in - 1) * stride - kernel + 2 + sum(padding_pair)``. A string ``padding``
  is first turned into explicit per-dimension padding via
  ``_conv_transpose_padding``.
  """
  lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
  lhs_trans = np.take(lhs_shape, lhs_perm)
  rhs_trans = np.take(rhs_shape, rhs_perm)
  kernel_spatial = rhs_trans[2:]
  if isinstance(padding, str):
    padding = [_conv_transpose_padding(kernel, stride, padding)
               for kernel, stride in zip(kernel_spatial, window_strides)]
  pad_totals = [np.sum(pad) for pad in padding]
  unpadded = [(size - 1) * stride - kernel + 2
              for size, kernel, stride in zip(lhs_trans[2:], kernel_spatial,
                                              window_strides)]
  out_space = np.sum([unpadded, pad_totals], axis=0).tolist()
  out_trans = (lhs_trans[0], rhs_trans[0]) + tuple(out_space)
  return tuple(np.take(out_trans, np.argsort(out_perm)))
6166
6167
def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
  """Check that `obj` is a shape-like value (e.g. tuple of nonnegative ints).

  Args:
    fun_name: name of the calling function, used in error messages.
    arg_name: name of the argument being validated, used in error messages.
    obj: the value to validate; must be a tuple, list or ndarray.
    non_zero_shape: if True, every element must be >= 1 ("strictly
      positive"); otherwise every element must be >= 0 ("nonnegative").

  Raises:
    TypeError: if `obj` is not a tuple/list/ndarray, is not rank 1, has an
      element that `canonicalize_shape` rejects, or has an element below the
      lower bound. An empty `obj` is accepted without further checks.
  """
  if not isinstance(obj, (tuple, list, np.ndarray)):
    msg = "{} {} must be of type tuple/list/ndarray, got {}."
    raise TypeError(msg.format(fun_name, arg_name, type(obj)))
  # bool(obj) for an ndarray raises an error, so we check len
  if not len(obj):  # pylint: disable=g-explicit-length-test
    return
  obj_arr = np.array(obj)
  if obj_arr.ndim != 1:
    msg = "{} {} must be rank 1, got {}."
    # Fill all three placeholders; passing only the rank made str.format
    # raise IndexError instead of this TypeError.
    raise TypeError(msg.format(fun_name, arg_name, obj_arr.ndim))
  try:
    canonicalize_shape(obj_arr)
  except TypeError as err:
    msg = "{} {} must have every element be an integer type, got {}."
    raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj)))) from err
  lower_bound, bound_error = (
      (1, "strictly positive") if non_zero_shape else (0, "nonnegative"))
  if not (obj_arr >= lower_bound).all():
    msg = "{} {} must have every element be {}, got {}."
    raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))
6190
6191
def _dynamic_slice_indices(operand, start_indices):
  """Normalize dynamic-slice start indices, wrapping negative ones around.

  Args:
    operand: the array being sliced; its ``ndim`` and ``shape`` are read.
    start_indices: either a tuple/list with one index per operand dimension,
      or an array-like object exposing ``ndim``/``shape``, which must be 1-D.

  Returns:
    For a tuple/list, a list where Python/NumPy integer indices are adjusted
    eagerly (``i + dim`` when negative) and other indices are adjusted with a
    ``select``. For an array, a single ``select`` that adds ``operand.shape``
    to the negative entries.

  Raises:
    ValueError: if the number of indices differs from ``operand.ndim`` or an
      array of indices is not 1-D.
  """
  if len(start_indices) != operand.ndim:
    msg = ("Length of slice indices must match number of operand dimensions ({} "
          "vs {})")
    raise ValueError(msg.format(len(start_indices), operand.shape))
  # map int over operand.shape to raise any dynamic-shape errors
  safe_map(int, operand.shape)

  if isinstance(start_indices, (tuple, list)):
    normalized = []
    for idx, dim in zip(start_indices, operand.shape):
      if isinstance(idx, (int, np.integer)):
        wrapped = idx + dim if idx < 0 else idx
        normalized.append(
            np.asarray(wrapped, getattr(idx, 'dtype', dtypes.int_)))
      else:
        normalized.append(
            select(lt(idx, _const(idx, 0)), add(idx, _const(idx, dim)), idx))
    return normalized

  if start_indices.ndim != 1:
    raise ValueError("Slice indices must be a 1D sequence, got {}"
                     .format(start_indices.shape))
  is_negative = lt(start_indices, _zeros(start_indices))
  return select(is_negative,
                add(start_indices, _const(start_indices, operand.shape)),
                start_indices)
6211
6212
def _const(example, val):
  """Build a constant ``val`` whose type matches ``example``.

  If ``example`` is a Python scalar, ``val`` is converted with that scalar's
  type (``dtypes.scalar_type_of``). Otherwise ``val`` becomes an ``np.array``
  with ``example``'s dtype.
  """
  if not dtypes.is_python_scalar(example):
    return np.array(val, _dtype(example))
  scalar_type = dtypes.scalar_type_of(example)
  return scalar_type(val)
6217
# ``full_like`` shortcuts. The plural forms leave ``shape`` to full_like's
# default; the singular forms pass ``shape=()``.
_zeros: Callable = partial(full_like, fill_value=0)
_ones: Callable = partial(full_like, fill_value=1)
_twos: Callable = partial(full_like, fill_value=2)
_zero: Callable = partial(full_like, shape=(), fill_value=0)
_one: Callable = partial(full_like, shape=(), fill_value=1)
_two: Callable = partial(full_like, shape=(), fill_value=2)

# Both names alias dtypes.result_type; ``_dtype`` is the internal spelling.
dtype: Callable = dtypes.result_type
_dtype: Callable = dtypes.result_type
6227
def _iscomplex(x) -> bool:
  """Return True when the result dtype of ``x`` is a complex floating type."""
  x_dtype = _dtype(x)
  return dtypes.issubdtype(x_dtype, np.complexfloating)
6230
6231
def ranges_like(*xs):
  """Lazily yield consecutive, back-to-back ranges sized like each of ``xs``.

  For example ``ranges_like([a, b], [c])`` yields ``range(0, 2)`` and then
  ``range(2, 3)``.
  """
  offset = 0
  for seq in xs:
    end = offset + len(seq)
    yield range(offset, end)
    offset = end
6238
6239
def remaining(original, *removed_lists):
  """Return the items of ``original``, in order, that occur in none of
  ``removed_lists``."""
  excluded = set().union(*removed_lists)
  return [item for item in original if item not in excluded]
6243
6244
def _canonicalize_precision(precision):
  """Validate a ``precision`` argument and return it unchanged.

  Accepted values are ``None``, a single ``Precision``, or a 2-tuple of
  ``Precision`` values.

  Raises:
    ValueError: for any other value.
  """
  if precision is None or isinstance(precision, Precision):
    return precision
  is_precision_pair = (isinstance(precision, tuple)
                       and len(precision) == 2
                       and all(isinstance(p, Precision) for p in precision))
  if is_precision_pair:
    return precision
  raise ValueError("Precision argument must be None, a lax.Precision value "
                   f"or a tuple of two lax.Precision values; got {precision}")
6257
6258
def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers
                           ) -> ConvDimensionNumbers:
  """Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`.

  Args:
    lhs_shape: tuple of nonnegative integers, shape of the convolution input.
    rhs_shape: tuple of nonnegative integers, shape of the convolution kernel.
    dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers
      object following the convolution dimension number specification format in
      xla_client.py.

  Returns:
    A `ConvDimensionNumbers` object that represents `dimension_numbers` in the
    canonical form used by lax functions. ``None`` maps to the identity layout
    on every operand; a string spec is converted with
    `conv_general_permutations`.

  Raises:
    TypeError: if the operand ranks differ, or `dimension_numbers` is of an
      unsupported type, has the wrong length, contains non-strings, or contains
      a spec whose length differs from the operand rank.
  """
  if isinstance(dimension_numbers, ConvDimensionNumbers):
    return dimension_numbers
  ndim = len(lhs_shape)
  if ndim != len(rhs_shape):
    msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}."
    raise TypeError(msg.format(ndim, len(rhs_shape)))

  if dimension_numbers is None:
    identity = tuple(range(ndim))
    return ConvDimensionNumbers(identity, identity, identity)

  if not isinstance(dimension_numbers, (list, tuple)):
    msg = "convolution dimension_numbers must be tuple/list or None, got {}."
    raise TypeError(msg.format(type(dimension_numbers)))

  if len(dimension_numbers) != 3:
    msg = "convolution dimension_numbers list/tuple must be length 3, got {}."
    raise TypeError(msg.format(len(dimension_numbers)))
  if any(not isinstance(spec, str) for spec in dimension_numbers):
    msg = "convolution dimension_numbers elements must be strings, got {}."
    raise TypeError(msg.format(tuple(map(type, dimension_numbers))))
  len_msg = ("convolution dimension_numbers[{}] must have len equal to the ndim "
             "of lhs and rhs, got {} for lhs and rhs shapes {} and {}.")
  for i, spec in enumerate(dimension_numbers):
    if len(spec) != ndim:
      raise TypeError(len_msg.format(i, len(spec), lhs_shape, rhs_shape))

  return ConvDimensionNumbers(*conv_general_permutations(dimension_numbers))
6301
6302
def conv_general_permutations(dimension_numbers):
  """Utility for convolution dimension permutations relative to Conv HLO.

  Args:
    dimension_numbers: a 3-sequence ``(lhs_spec, rhs_spec, out_spec)`` of
      dimension strings. ``lhs_spec`` and ``out_spec`` must contain 'N' and 'C'
      exactly once; ``rhs_spec`` must contain 'O' and 'I' exactly once. All
      three must share the same set of remaining (spatial) characters, with no
      duplicates.

  Returns:
    ``(lhs_perm, rhs_perm, out_perm)``. Each is a tuple whose first two entries
    are the positions of the special characters (N, C or O, I). The remaining
    entries are the spatial positions, ordered as the spatial characters appear
    in ``rhs_spec``.

  Raises:
    TypeError: if any of the above constraints is violated.
  """
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  charpairs = (("N", "C"), ("O", "I"), ("N", "C"))
  for i, (first, second) in enumerate(charpairs):
    spec = dimension_numbers[i]
    if spec.count(first) != 1 or spec.count(second) != 1:
      msg = ("convolution dimension_numbers[{}] must contain the characters "
             "'{}' and '{}' exactly once, got {}.")
      raise TypeError(msg.format(i, first, second, spec))
    if len(set(spec)) != len(spec):
      msg = ("convolution dimension_numbers[{}] cannot have duplicate "
             "characters, got {}.")
      raise TypeError(msg.format(i, spec))

  lhs_spatial = set(lhs_spec) - set(charpairs[0])
  rhs_spatial = set(rhs_spec) - set(charpairs[1])
  out_spatial = set(out_spec) - set(charpairs[2])
  if not lhs_spatial == rhs_spatial == out_spatial:
    msg = ("convolution dimension_numbers elements must each have the same "
           "set of spatial characters, got {}.")
    raise TypeError(msg.format(dimension_numbers))

  def permutation(spec, pair):
    # Spatial positions are ordered by where their character sits in
    # rhs_spec. For rhs_spec itself this is the identity ordering, because
    # duplicate characters were rejected above.
    spatial = sorted((i for i, c in enumerate(spec) if c not in pair),
                     key=lambda i: rhs_spec.index(spec[i]))
    return (spec.index(pair[0]), spec.index(pair[1])) + tuple(spatial)

  return (permutation(lhs_spec, charpairs[0]),
          permutation(rhs_spec, charpairs[1]),
          permutation(out_spec, charpairs[2]))
6330
6331
def _conv_general_proto(dimension_numbers):
  """Translate a ConvDimensionNumbers into xla_client.ConvolutionDimensionNumbers.

  In each spec, position 0 is the batch (or kernel output-feature) dimension,
  position 1 the feature (or kernel input-feature) dimension, and the rest are
  spatial dimensions.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  proto = xla_client.ConvolutionDimensionNumbers()

  # Input (lhs) layout.
  proto.input_batch_dimension = lhs_spec[0]
  proto.input_feature_dimension = lhs_spec[1]
  proto.input_spatial_dimensions.extend(lhs_spec[2:])

  # Kernel (rhs) layout.
  proto.kernel_output_feature_dimension = rhs_spec[0]
  proto.kernel_input_feature_dimension = rhs_spec[1]
  proto.kernel_spatial_dimensions.extend(rhs_spec[2:])

  # Output layout.
  proto.output_batch_dimension = out_spec[0]
  proto.output_feature_dimension = out_spec[1]
  proto.output_spatial_dimensions.extend(out_spec[2:])
  return proto
6346
6347
def _conv_general_vjp_lhs_padding(
    in_shape, window_dimensions, window_strides, out_shape, padding,
    lhs_dilation, rhs_dilation) -> List[Tuple[int, int]]:
  """Per-spatial-dimension ``(before, after)`` padding for the lhs VJP conv.

  With all shapes dilated (input by ``lhs_dilation``, window by
  ``rhs_dilation``, output by ``window_strides``):

    before = window - pad_low - 1
    after  = input + window - 1 - output - before

  ``padding`` is the forward pass's sequence of ``(low, high)`` pairs; only
  the low entries are read.
  """
  lhs_dilated_shape = _dilate_shape(in_shape, lhs_dilation)
  rhs_dilated_shape = _dilate_shape(window_dimensions, rhs_dilation)
  out_dilated_shape = _dilate_shape(out_shape, window_strides)
  pads_low = [lo for lo, _ in padding]
  before = np.subtract(rhs_dilated_shape, pads_low) - 1
  full_extent = np.add(lhs_dilated_shape, rhs_dilated_shape) - 1
  after = full_extent - out_dilated_shape - before
  return safe_zip(before, after)
6358
6359
def _conv_general_vjp_rhs_padding(
    in_shape, window_dimensions, window_strides, out_shape, padding,
    lhs_dilation, rhs_dilation):
  """Per-spatial-dimension ``(low, high)`` padding for the rhs VJP conv.

  The total padding per dimension is
  ``output + window - input - 1`` (all dilated). The forward pass's low pad is
  kept, and the high pad takes the remainder.
  """
  lhs_dilated_shape = _dilate_shape(in_shape, lhs_dilation)
  rhs_dilated_shape = _dilate_shape(window_dimensions, rhs_dilation)
  out_dilated_shape = _dilate_shape(out_shape, window_strides)
  total_in_pad = out_dilated_shape + rhs_dilated_shape - lhs_dilated_shape - 1
  result = []
  for pad, total in zip(padding, total_in_pad):
    low = pad[0]
    result.append((low, total - low))
  return result
6368
6369
def _balanced_eq(x, z, y):
  """Elementwise weight for ``x`` matching ``z``, split evenly with ``y``.

  Returns 1 where only ``x == z``, 1/2 where both ``x == z`` and ``y == z``,
  and 0 where ``x != z``. Equality is tested with ``_eq_meet``.
  """
  numerator = select(_eq_meet(x, z), _ones(z), _zeros(z))
  denominator = select(_eq_meet(y, z), _twos(z), _ones(z))
  return div(numerator, denominator)
6373
6374
def _eq_meet(a, b):
  """Elementwise ``eq(a, b)`` after reconciling mismatched dtypes.

  If the dtypes differ and ``a``'s dtype is the promoted type of the pair,
  ``a`` is cast to ``b``'s dtype. Otherwise ``b`` is cast to ``a``'s dtype.
  """
  a_dtype, b_dtype = _dtype(a), _dtype(b)
  if a_dtype == b_dtype:
    return eq(a, b)
  if dtypes.promote_types(a_dtype, b_dtype) == a_dtype:
    a = convert_element_type(a, b_dtype)
  else:
    b = convert_element_type(b, a_dtype)
  return eq(a, b)
6384
6385
def _abstractify(x):
  """Return the abstract value of ``x``, raised to a ShapedArray level."""
  aval = core.get_aval(x)
  return raise_to_shaped(aval)
6388
6389
def _check_user_dtype_supported(dtype, fun_name=None):
  """Validate a dtype that a user explicitly requested.

  The builtin types ``bool``, ``int``, ``float`` and ``complex`` are accepted
  immediately. Anything else is converted with ``np.dtype``. Non-number,
  non-bool dtypes (other than bfloat16) raise ``TypeError``. When a non-None
  dtype differs from ``dtypes.canonicalize_dtype(dtype)``, a warning is issued
  that it will be truncated.

  Args:
    dtype: the requested dtype (anything ``np.dtype`` accepts, or None).
    fun_name: optional name of the calling API, included in messages.
  """
  # Avoid using `dtype in [...]` because of numpy dtype equality overloading.
  if isinstance(dtype, type) and dtype in {bool, int, float, complex}:
    return
  np_dtype = np.dtype(dtype)
  is_number_or_bool = np_dtype.kind in "biufc"
  if not is_number_or_bool and np_dtype.type != dtypes.bfloat16:
    suffix = f" in {fun_name}" if fun_name else ""
    raise TypeError(
        f"JAX only supports number and bool dtypes, got dtype {dtype}" + suffix)
  if dtype is None:
    return
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  if np_dtype != canonical_dtype:
    msg = ("Explicitly requested dtype {} {} is not available, "
           "and will be truncated to dtype {}. To enable more dtypes, set the "
           "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
           "environment variable. "
           "See https://github.com/google/jax#current-gotchas for more.")
    location = f"requested in {fun_name}" if fun_name else ""
    warnings.warn(msg.format(dtype, location, canonical_dtype.name),
                  stacklevel=2)
6408
6409
def _canonicalize_axis(axis, num_dims):
  """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).

  Raises:
    TypeError: if ``axis`` is not an integer (via ``operator.index``).
    ValueError: if ``axis`` is outside [-num_dims, num_dims).
  """
  axis = operator.index(axis)
  if axis < -num_dims or axis >= num_dims:
    raise ValueError(
        "axis {} is out of bounds for array of dimension {}".format(
            axis, num_dims))
  return axis + num_dims if axis < 0 else axis
6420
6421
# Primitive for the legacy ``tie_in``: its impl, abstract eval and XLA
# translation (registered below) all return the second operand ``y``; the
# first operand ``x`` only contributes a fake data dependence.
tie_in_p = Primitive('tie_in')

@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Install the legacy ``tie_in`` function and its primitive rules.

  Registered via ``config.register_omnistaging_disabler``; presumably invoked
  when omnistaging is switched off (TODO confirm in jax.config). It rebinds
  the module-global ``tie_in``, monkey-patches ``jax.lax.tie_in`` when
  ``jax.lax`` already exists, and registers impl, abstract-eval, XLA
  translation, JVP, transpose, batching and masking rules for ``tie_in_p``.
  """
  global tie_in

  def tie_in(x: Array, y: Array) -> Array:
    """Returns the value of ``y`` but with a fake data dependence on ``x``.

    When staging to XLA (e.g. running under jit or pmap), values that don't depend
    on computation inputs are computed op-by-op, and folded into the XLA
    computation as constants.

    ``tie_in`` provides a way to explicitly stage values into the computation.
    When staging to XLA and ``x`` is already staged, then the result of ``tie_in``
    is ``y``, but staged to XLA. Downstream use of the result will also be staged
    to XLA.

    For example, ``lax.sin(const)`` would be constant-folded if ``const`` is
    a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to
    XLA as long as ``x`` is staged to XLA.
    """
    # With omnistaging enabled tie_in is a no-op; otherwise bind the primitive.
    if config.omnistaging_enabled:
      return y
    else:
      return tie_in_p.bind(x, y)

  # If lax has already been imported, we need to monkey-patch the
  # lax/__init__.py import of tie_in. If not (i.e. if this is running at lax
  # module creation time) then accessing jax.lax raises AttributeError, which
  # is deliberately ignored here.
  try:
    jax.lax.tie_in = tie_in
  except AttributeError:
    pass

  # Transpose: x only carries a fake dependence, so it receives a zero
  # cotangent; the incoming cotangent t flows to y unchanged.
  def _tie_in_transpose_rule(t, x, y):
    if ad.is_undefined_primal(x):
      return [ad_util.Zero(x.aval), t]
    else:
      return [ad_util.Zero.from_value(x), t]

  # Batching: re-apply tie_in to the batched operands; the result is batched
  # along y's batch dimension.
  def _tie_in_batch_rule(batched_args, batch_dims):
    y = tie_in(*batched_args)
    _, bdim_y = batch_dims
    return y, bdim_y

  # Eager implementation: validate both operands as JAX types, return y.
  def _tie_in_impl(x, y):
    core.check_valid_jaxtype(x)
    core.check_valid_jaxtype(y)
    return y

  # JVP: when y's tangent is a symbolic Zero or has dtype float0, pass it
  # through untouched; otherwise treat tie_in as a linear primitive.
  # NOTE(review): the float0 test uses identity (`is`) on a dtype object;
  # confirm `==` is not needed for dtypes constructed elsewhere.
  def _tie_in_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    if type(y_dot) is ad_util.Zero or core.get_aval(y_dot).dtype is dtypes.float0:
      return y, y_dot  # skip tying in in this case
    else:
      return ad.linear_jvp(tie_in_p, primals, tangents)

  # Rule registration. The abstract eval, XLA translation and masking rule
  # all simply forward y.
  tie_in_p.def_impl(_tie_in_impl)
  tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
  xla.translations[tie_in_p] = lambda c, x, y: y
  ad.primitive_jvps[tie_in_p] = _tie_in_jvp
  ad.primitive_transposes[tie_in_p] = partial(ad.linear_transpose2, _tie_in_transpose_rule)
  batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
  masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
6488