1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Defines test inputs and invocations for JAX primitives.
15
16The idea is that we want to list all the JAX numeric primitives and for
17each a set of inputs that should cover their use cases. We want these separate
18from any particular test suite so we can reuse it to build multiple kinds of
19tests. For example, we can use the harnesses to check that each primitive is
20compiled correctly, or that we can apply a certain transformation, e.g., `vmap`.
21
22See the `Harness` class below for how to define a harness, describing one
23use case of one primitive.
24
25Some use cases are known to be partially implemented
26in JAX, e.g., because of an implementation limitation. We do have harnesses
27for those cases too, but we filter them out.
28Instead of writing this information as conditions inside one
29particular test, we write them as `Limitation` objects that can be reused in
30multiple tests and can also be used to generate documentation, e.g.,
31the report of [unsupported and
32partially-implemented JAX
33primitives](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md)
34
35The limitations are used to filter out from tests the harnesses that are known
36to fail. A Limitation is specific to a harness.
37
38"""
39
40import operator
41import os
42from typing import Any, Callable, Dict, Iterable, List, Optional, NamedTuple, Sequence, Tuple, Union
43
44from functools import partial
45
46from absl import testing
47import jax
48from jax import config
49from jax import dtypes
50from jax import ad_util
51from jax import test_util as jtu
52from jax import lax
53from jax import numpy as jnp
54from jax._src.lax import control_flow as lax_control_flow
55from jax.interpreters import xla
56
57from jax.lib import xla_client
58
59import numpy as np
60
61FLAGS = config.FLAGS
62
63Rng = Any  # A random number generator
64DType = Any
65
66class RandArg(NamedTuple):
67  """Descriptor for a randomly generated argument.
68
69  See description of `Harness`.
70  """
71  shape: Tuple[int, ...]
72  dtype: DType
73
74
75class StaticArg(NamedTuple):
76  """Descriptor for a static argument.
77
78  See description of `Harness`.
79  """
80  value: Any
81
82
83class CustomArg(NamedTuple):
84  """Descriptor for a dynamic argument generated from a PRNG factory.
85
86  See description of `Harness`.
87  """
88  make: Callable[[Rng], Any]  # Called with a Rng to make a tensor
89
90
91class Harness:
92  """Specifies inputs and callable for a primitive.
93
94  See the module docstring for an introduction to harnesses.
95
96  A harness is conceptually a callable and a list of arguments, that together
97  exercise a use case. The harness can optionally have additional parameters
98  that can be used by the test.
99
100  The arguments are specified through argument descriptors. An argument
101  descriptor can be:
102    * a numeric value or ndarray, or
103    * an instance of ``RandArg(shape, dtype)`` to be used with a PRNG to
104    generate
105      random tensor of the given shape and type, or
106    * an instance of ``CustomArg(fun)`` to be used with a PRNG, or
107    * an instance of ``StaticArg(value)``. These are values that specialize the
108      callable, but are not exposed as external arguments.
109
110  For example, a harness for ``lax.take(arr, indices, axis=None)`` may want
111  to expose as external (dynamic) argument the array and the indices, and
112  keep the axis as a static argument (technically specializing the `take` to
113  a axis):
114
115    Harness(lax.slice_p,
116            f"take_axis={axis}",
117            lax.take,
118            [RandArg((2, 4), np.float32), np.array([-1, 0, 1]),
119            StaticArg(axis)],
120            axis=axis)
121
122  Each harness can have a list of Limitations that describe the cases when
123  the harness may not be fully implemented.
124  """
125  # The group name most often is the primitive name.
126  group_name: str
127  # Descriptive name of the harness, used as a testcase_name. Unique in a group.
128  name: str
129  # The function taking all arguments (static and dynamic).
130  fun: Callable
131  # Describes how to construct arguments, see the class docstring.
132  arg_descriptors: Sequence[Union[RandArg, StaticArg, CustomArg, Any]]
133  dtype: DType
134  # A set of limitations describing the cases that are not supported or
135  # partially implemented in JAX for this harness.
136  jax_unimplemented: Sequence["Limitation"]
137  rng_factory: Callable
138  # Carry some arbitrary parameters that the test can access.
139  params: Dict[str, Any]
140
141  def __init__(self,
142               group_name,
143               name,
144               fun,
145               arg_descriptors,
146               *,
147               dtype,
148               rng_factory=jtu.rand_default,
149               jax_unimplemented: Sequence["Limitation"] = (),
150               **params):
151    """See class docstring."""
152    self.group_name = group_name
153    self.name = name
154    self.fun = fun  # type: ignore[assignment]
155    self.arg_descriptors = arg_descriptors
156    self.rng_factory = rng_factory  # type: ignore[assignment]
157    self.jax_unimplemented = jax_unimplemented
158    self.dtype = dtype
159    self.params = params
160
161  def __str__(self):
162    return self.fullname
163
164  @property
165  def fullname(self):
166    return self.name if self.group_name is None else f"{self.group_name}_{self.name}"
167
168  def _arg_maker(self, arg_descriptor, rng: Rng):
169    if isinstance(arg_descriptor, StaticArg):
170      return arg_descriptor.value
171    if isinstance(arg_descriptor, RandArg):
172      return self.rng_factory(rng)(arg_descriptor.shape, arg_descriptor.dtype)
173    if isinstance(arg_descriptor, CustomArg):
174      return arg_descriptor.make(rng)
175
176    return arg_descriptor
177
178  def args_maker(self, rng: Rng) -> Sequence:
179    """All-argument maker, including the static ones."""
180    return [self._arg_maker(ad, rng) for ad in self.arg_descriptors]
181
182  def dyn_args_maker(self, rng: Rng) -> Sequence:
183    """A dynamic-argument maker, for use with `dyn_fun`."""
184    return [
185      self._arg_maker(ad, rng)
186      for ad in self.arg_descriptors
187      if not isinstance(ad, StaticArg)
188    ]
189
190  def dyn_fun(self, *dyn_args):
191    """Invokes `fun` given just the dynamic arguments."""
192    all_args = self._args_from_dynargs(dyn_args)
193    return self.fun(*all_args)
194
195  def _args_from_dynargs(self, dyn_args: Sequence) -> Sequence:
196    """All arguments, including the static ones."""
197    next_dynamic_argnum = 0
198    all_args = []
199    for ad in self.arg_descriptors:
200      if isinstance(ad, StaticArg):
201        all_args.append(ad.value)
202      else:
203        all_args.append(dyn_args[next_dynamic_argnum])
204        next_dynamic_argnum += 1
205    return all_args
206
207  def filter(self,
208             device_under_test: str,
209             *,
210             include_jax_unimpl: bool = False,
211             one_containing: Optional[str] = None) -> bool:
212    if not include_jax_unimpl:
213      if any([
214        device_under_test in l.devices
215        for l in self.jax_unimplemented
216        if l.filter(device=device_under_test, dtype=self.dtype)
217      ]):
218        return False
219
220    if one_containing is not None and one_containing not in self.fullname:
221      return False
222    return True
223
224def dtypes_to_str(dtype_list: Sequence[DType], empty_means_all=False) -> str:
225  """User-friendly description of a set of dtypes"""
226  if not dtype_list and empty_means_all:
227    return "all"
228
229  names = set([np.dtype(dt).name for dt in dtype_list])
230  signed = {"int8", "int16", "int32", "int64"}
231  if all([t in names for t in signed]):
232    names = (names - signed) | {"signed"}
233  integers = {"uint8", "uint16", "uint32", "uint64"}
234  if all([t in names for t in integers]):
235    names = (names - integers) | {"unsigned"}
236  integer = {"signed", "unsigned"}
237  if all([t in names for t in integer]):
238    names = (names - integer) | {"integer"}
239
240  floating = {"bfloat16", "float16", "float32", "float64"}
241  if all([t in names for t in floating]):
242    names = (names - floating) | {"floating"}
243
244  complex = {"complex64", "complex128"}
245  if all([t in names for t in complex]):
246    names = (names - complex) | {"complex"}
247
248  inexact = {"floating", "complex"}
249  if all([t in names for t in inexact]):
250    names = (names - inexact) | {"inexact"}
251
252  all_types = {"integer", "inexact", "bool"}
253  if all([t in names for t in all_types]):
254    names = (names - all_types) | {"all"}
255
256  return ", ".join(sorted(list(names)))
257
258
259##### All harnesses in this file.
260all_harnesses: List[Harness] = []
261
262
263def define(
264    group_name,  # Should be the primitive name, as much as possible
265    name,
266    fun,
267    arg_descriptors,
268    *,
269    dtype,
270    rng_factory=jtu.rand_default,
271    jax_unimplemented: Sequence["Limitation"] = (),
272    **params):
273  """Defines a harness and stores it in `all_harnesses`. See Harness."""
274  group_name = str(group_name)
275  h = Harness(group_name, name,
276    fun,
277    arg_descriptors,
278    rng_factory=rng_factory,
279    jax_unimplemented=jax_unimplemented,
280    dtype=dtype,
281    **params)
282  all_harnesses.append(h)
283
284
285class Limitation:
286  """Encodes conditions under which a harness is limited, e.g., not runnable in JAX.
287
288  See the module docstring for an introduction to harnesses and limitations.
289  """
290
291  def __init__(
292      self,
293      description: str,
294      *,
295      enabled: bool = True,
296      devices: Sequence[str] = ("cpu", "gpu", "tpu"),
297      dtypes: Sequence[DType] = (),
298      skip_run: bool = False,
299  ):
300    """Args:
301      description: text to augment the harness group name with the description
302      of the limitation. Used for reports.
303      enabled: whether this limitation is enabled for the harness in which
304        it appears. This is only used during testing to know whether to ignore
305        harness errors. Use this sparingly, prefer `devices` and
306        `dtypes` for enabled conditions that are included in reports.
307      devices: the list of device types for which this applies. Used for
308        filtering during harness execution, and for reports.
309      dtypes: the list of dtypes for which this applies. Used for filtering
310        during harness execution, and for reports.
311    """
312    assert isinstance(description, str), f"{description}"
313    self.description = description
314    self.skip_run = skip_run
315    if isinstance(devices, str):
316      devices = (devices,)
317    else:
318      devices = tuple(devices)
319    self.devices = devices
320    if not isinstance(dtypes, Iterable):
321      dtypes = (dtypes,)
322    else:
323      dtypes = tuple(dtypes)
324    self.dtypes = dtypes
325    self.enabled = enabled  # Does it apply to the current harness?
326
327  def __str__(self):
328    return (f"\"{self.description}\" devices={self.devices} "
329            f"dtypes={[np.dtype(dt).name for dt in self.dtypes]}" +
330            (" (skip_run) " if self.skip_run else ""))
331  __repr__ = __str__
332
333  def filter(self,
334             device: Optional[str] = None,
335             dtype: Optional[DType] = None) -> bool:
336    """Check that a limitation is enabled for the given dtype and device."""
337    return (self.enabled and
338            (not self.dtypes or dtype is None or dtype in self.dtypes) and
339            (device is None or device in self.devices))
340
341
342def parameterized(harnesses: Iterable[Harness],
343                  *,
344                  one_containing: Optional[str] = None,
345                  include_jax_unimpl: bool = False):
346  """Decorator for tests.
347
348  The tests receive a `harness` argument.
349
350  The `JAX_TEST_HARNESS_ONE_CONTAINING` environment variable is useful for
351  debugging. If given, then picks only one harness whose name contains the
352  string. The whole set of parameterized tests is reduced to one test,
353  whose name is not decorated to make it easier to pick with an IDE for
354  running.
355  """
356  one_containing = one_containing or os.environ.get(
357    "JAX_TEST_HARNESS_ONE_CONTAINING")
358  cases = tuple(
359    # Change the testcase name to include the harness name.
360    dict(
361      testcase_name=harness.fullname if one_containing is None else "",
362      harness=harness) for harness in harnesses if harness.filter(
363      jtu.device_under_test(),
364      one_containing=one_containing,
365      include_jax_unimpl=include_jax_unimpl))
366  if one_containing is not None:
367    if not cases:
368      raise ValueError(
369        f"Cannot find test case with name containing {one_containing}."
370        "Names are:"
371        "\n".join([harness.fullname for harness in harnesses]))
372    cases = cases[0:1]
373  if not cases:
374    # We filtered out all the harnesses.
375    return jtu.skip_on_devices(jtu.device_under_test())
376  return testing.parameterized.named_parameters(*cases)
377
378
379###############################################################################
380###                    Harness definitions                                  ###
381###############################################################################
382
383def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
384  define(
385    str(prim),
386    f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
387    prim.bind, [RandArg(shape, dtype)],
388    prim=prim,
389    dtype=dtype,
390    shape=shape)
391
392
393for dtype in (set(jtu.dtypes.all) -
394              set(jtu.dtypes.all_unsigned + jtu.dtypes.boolean)):
395  _make_unary_elementwise_harness(prim=lax.abs_p, dtype=dtype)
396
397for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
398  _make_unary_elementwise_harness(prim=lax.acosh_p, dtype=dtype)
399  _make_unary_elementwise_harness(prim=lax.asinh_p, dtype=dtype)
400  _make_unary_elementwise_harness(prim=lax.atanh_p, dtype=dtype)
401  _make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype)
402  _make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype)
403  _make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype)
404  _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype)
405  _make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype)
406  _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype)
407  _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype)
408  _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype)
409  _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype)
410  _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype)
411  _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype)
412  _make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype)
413  _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype)
414  _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype)
415  _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype)
416
417for dtype in jtu.dtypes.all_floating:
418  _make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype)
419  _make_unary_elementwise_harness(prim=lax.bessel_i1e_p, dtype=dtype)
420  _make_unary_elementwise_harness(prim=lax.ceil_p, dtype=dtype)
421  _make_unary_elementwise_harness(prim=lax.erf_p, dtype=dtype)
422  _make_unary_elementwise_harness(prim=lax.erf_inv_p, dtype=dtype)
423  _make_unary_elementwise_harness(prim=lax.erfc_p, dtype=dtype)
424  _make_unary_elementwise_harness(prim=lax.floor_p, dtype=dtype)
425  _make_unary_elementwise_harness(prim=lax.is_finite_p, dtype=dtype)
426  _make_unary_elementwise_harness(prim=lax.lgamma_p, dtype=dtype)
427  _make_unary_elementwise_harness(prim=lax.digamma_p, dtype=dtype)
428
429for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
430  _make_unary_elementwise_harness(prim=lax.neg_p, dtype=dtype)
431  _make_unary_elementwise_harness(prim=lax.sign_p, dtype=dtype)
432
433
434def _make_round_harness(name,
435                        *,
436                        shape=(100, 100),
437                        dtype=np.float32,
438                        rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO,
439                        operand=None):
440  operand = operand if operand is not None else RandArg(shape, dtype)
441  define(
442    "round",
443    f"{name}_shape={jtu.format_shape_dtype_string(operand.shape, operand.dtype)}_roundingmethod={rounding_method}",
444    lax.round, [operand, StaticArg(rounding_method)],
445    dtype=dtype,
446    operand=operand,
447    rounding_method=rounding_method)
448
449
450# Validate dtypes
451for dtype in jtu.dtypes.all_floating:
452  _make_round_harness("dtypes", dtype=dtype)
453
454for rounding_method in [
455  lax.RoundingMethod.AWAY_FROM_ZERO, lax.RoundingMethod.TO_NEAREST_EVEN
456]:
457  operand = np.array([[0.5, 1.5, 2.5], [-0.5, -1.5, -2.5]], dtype=np.float32)
458  _make_round_harness(
459    "rounding_methods", operand=operand, rounding_method=rounding_method)
460
461# Validate edge cases
462for name, operand in [
463  # Checks that https://github.com/google/jax/issues/4952 is resolved
464  ("round_away_from_0",
465   np.array([[0.5, 1.5, 2.5], [-0.5, -1.5, -2.5]], dtype=np.float32)),
466]:
467  _make_round_harness(f"edge_case_{name}", operand=operand)
468
469
470def _make_convert_element_type_harness(name,
471                                       *,
472                                       shape=(100, 100),
473                                       dtype=np.float32,
474                                       new_dtype=np.float32):
475  define(
476    "convert_element_type",
477    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_olddtype={jtu.dtype_str(dtype)}_newdtype={jtu.dtype_str(new_dtype)}",
478    lambda arg: (lax.convert_element_type_p.bind(arg, new_dtype=new_dtype)),
479    [RandArg(shape, dtype)],
480    shape=shape,
481    dtype=dtype,
482    new_dtype=new_dtype)
483
484
485for old_dtype in jtu.dtypes.all:
486  # TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating
487  # point numbers and new_dtype is an unsigned integer. See issue
488  # https://github.com/google/jax/issues/5082 for details.
489  for new_dtype in (jtu.dtypes.all
490  if not (dtypes.issubdtype(old_dtype, np.floating) or
491          dtypes.issubdtype(old_dtype, np.complexfloating))
492  else set(jtu.dtypes.all) - set(jtu.dtypes.all_unsigned)):
493    _make_convert_element_type_harness(
494      "dtypes_to_dtypes", dtype=old_dtype, new_dtype=new_dtype)
495
496
497def _make_integer_pow_harness(name, *, shape=(20, 30), dtype=np.int32, y=3):
498  define(
499    "integer_pow",
500    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_y={y}",
501    lax.integer_pow,
502    [RandArg(shape, dtype), StaticArg(y)],
503    shape=shape,
504    dtype=dtype,
505    y=y)
506
507
508for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
509  # Validate dtypes and y values
510  _make_integer_pow_harness("dtypes", dtype=dtype)
511  # Validate overflow behavior by dtype
512  _make_integer_pow_harness("overflow", y=1000, dtype=dtype)
513
514for dtype in jtu.dtypes.all_inexact:
515  # Validate negative y by dtype
516  _make_integer_pow_harness("negative_exp", y=-1000, dtype=dtype)
517
518
519def _make_pow_harness(name,
520                      *,
521                      shapes=((20, 30), (20, 30)),
522                      dtype=np.float32,
523                      lhs=None,
524                      rhs=None):
525  lhs = RandArg(shapes[0], dtype) if lhs is None else lhs
526  rhs = RandArg(shapes[1], dtype) if rhs is None else rhs
527  define(
528    "pow",
529    f"{name}_lhs={jtu.format_shape_dtype_string(lhs.shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs.shape, dtype)}",
530    lax.pow, [lhs, rhs],
531    lhs=lhs,
532    rhs=rhs,
533    dtype=dtype)
534
535
536for dtype in jtu.dtypes.all_inexact:
537  # Validate dtypes
538  _make_pow_harness("dtypes", dtype=dtype)
539
540# Validate broadcasting behavior
541for shapes in [
542  ((), (4, 5, 6)),  # broadcasting lhs
543  ((4, 5, 6), ()),  # broadcasting rhs
544  ((4, 1, 6), (4, 5, 6)),  # broadcasting lhs on a specific axis
545  ((4, 5, 6), (4, 1, 6)),  # broadcasting rhs on a specific axis
546]:
547  _make_pow_harness("broadcast", shapes=shapes)
548
549
550def _make_reshape_harness(name,
551                          *,
552                          shape=(2, 3),
553                          new_sizes=(3, 2),
554                          dimensions=(0, 1),
555                          dtype=np.float32):
556  define(
557    "reshape",
558    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_newsizes={new_sizes}_dimensions={dimensions}",
559    lax.reshape,
560    [RandArg(shape, dtype),
561     StaticArg(new_sizes),
562     StaticArg(dimensions)],
563    shape=shape,
564    dtype=dtype,
565    new_sizes=new_sizes,
566    dimensions=dimensions)
567
568
569# Validate dtypes
570for dtype in jtu.dtypes.all:
571  _make_reshape_harness("dtypes", dtype=dtype)
572
573# Validate new_sizes
574for shape, new_sizes, dimensions in [
575  ((3, 4, 5), (3, 20), (0, 1, 2)),  # merging two dimensions
576  ((3, 4, 5), (4, 15), (0, 1, 2)),  # changing leading dimension
577]:
578  _make_reshape_harness(
579    "new_sizes", shape=shape, new_sizes=new_sizes, dimensions=dimensions)
580# Validate dimensions collapsing order
581for shape, new_sizes, dimensions in [
582  ((3, 4, 5), (3, 20), (2, 1, 0)),  # transpose shape (0, 1, 2) into (2, 1, 0)
583  ((3, 4, 5), (3, 20), (2, 0, 1)),  # transpose shape (0, 1, 2) into (2, 0, 1)
584]:
585  _make_reshape_harness(
586    "dimensions", shape=shape, new_sizes=new_sizes, dimensions=dimensions)
587
588
589def _make_rev_harness(name, *, shape=(4, 5), dtype=np.float32, dimensions=(0,)):
590  define(
591    "rev",
592    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_dimensions={dimensions}",
593    lax.rev,
594    [RandArg(shape, dtype), StaticArg(dimensions)],
595    shape=shape,
596    dtype=dtype,
597    dimensions=dimensions)
598
599
600# Validate dtypes
601for dtype in jtu.dtypes.all:
602  _make_rev_harness("dtypes", dtype=dtype)
603# Validate dimensions
604for shape, dimensions in [
605  ((3, 4, 5), ()),  # empty dimensions
606  ((3, 4, 5), (0, 2)),  # some dimensions
607  ((3, 4, 5), (0, 1, 2)),  # all dimensions (ordered)
608  ((3, 4, 5), (2, 0, 1)),  # all dimensions (unordered)
609]:
610  _make_rev_harness("dimensions", shape=shape, dimensions=dimensions)
611
612
613def _make_device_put_harness(name,
614                             *,
615                             shape=(3, 4),
616                             dtype=np.float32,
617                             device=None):
618  _device_fn = lambda: jax.devices(device)[0] if device is not None else None
619  define(
620    "device_put",
621    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_device={device}",
622    lambda x: xla.device_put_p.bind(x, device=_device_fn()),
623    [RandArg(shape, dtype)],
624    shape=shape,
625    dtype=dtype,
626    device=device)
627
628
629# Validate dtypes
630for dtype in jtu.dtypes.all:
631  _make_device_put_harness("dtypes", dtype=dtype)
632# Validate devices
633_make_device_put_harness("devices", device="cpu")
634
635
636def _make_bitcast_convert_type_harness(name,
637                                       *,
638                                       shape=(2, 3),
639                                       dtype=np.float32,
640                                       new_dtype=np.float32):
641  define(
642    "bitcast_convert_type",
643    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_newdtype={np.dtype(new_dtype).name}",
644    lambda x: (lax.bitcast_convert_type_p.bind(x, new_dtype=new_dtype)),
645    [RandArg(shape, dtype)],
646    shape=shape,
647    dtype=dtype,
648    new_dtype=new_dtype)
649
650
651def _can_bitcast(dtype, target_dtype):
652  def _to_equivalence_class(dtype):
653    if dtypes.issubdtype(dtype, np.integer):
654      return dtypes.iinfo(dtype).bits
655    elif dtypes.issubdtype(dtype, np.floating):
656      return dtypes.finfo(dtype).bits
657    else:
658      assert dtype == np.bool_ or dtypes.issubdtype(dtype, np.complexfloating)
659      # Complex and boolean types can only be cast to themselves
660      return np.dtype(dtype).name
661
662  return _to_equivalence_class(dtype) == _to_equivalence_class(target_dtype)
663
664
665# Validate dtypes combinations
666for dtype in jtu.dtypes.all:
667  for new_dtype in filter(partial(_can_bitcast, dtype), jtu.dtypes.all):
668    _make_bitcast_convert_type_harness(
669      "dtypes_to_new_dtypes", dtype=dtype, new_dtype=new_dtype)
670
671
672def _make_add_any_harness(name, *, shapes=((2,), (2,)), dtype=np.float32):
673  define(
674    ad_util.add_any_p,
675    f"{name}_lhs={jtu.format_shape_dtype_string(shapes[0], dtype)}_rhs={jtu.format_shape_dtype_string(shapes[1], dtype)}",
676    ad_util.add_jaxvals_p.bind,
677    list(map(lambda s: RandArg(s, dtype), shapes)),
678    dtype=dtype,
679    shapes=shapes)
680
681
682for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
683  _make_add_any_harness("dtypes", dtype=dtype)
684
685for rhs_dtype in jtu.dtypes.all:
686  lhs_dtype = np.float32
687  lhs_shape = (2, 3)
688  rhs_shape = (4, 5)
689  define(
690    lax.tie_in_p,
691    f"lhs={jtu.format_shape_dtype_string(lhs_shape, lhs_dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)}",
692    lax.tie_in_p.bind,
693    [RandArg(lhs_shape, lhs_dtype),
694     RandArg(rhs_shape, rhs_dtype)],
695    jax_unimplemented=[
696      Limitation(
697        "requires omnistaging to be disabled",
698        enabled=config.omnistaging_enabled)
699    ],
700    dtype=rhs_dtype,
701    lhs_shape=lhs_shape,
702    lhs_dtype=lhs_dtype,
703    rhs_shape=rhs_shape,
704    rhs_dtype=rhs_dtype,
705    primitive=lax.tie_in_p)
706
707for dtype in jtu.dtypes.all:
708  shape: Tuple[int, ...] = (20, 20)
709  define(
710    ad_util.stop_gradient_p,
711    f"{jtu.format_shape_dtype_string(shape, dtype)}",
712    ad_util.stop_gradient_p.bind, [RandArg(shape, dtype)],
713    shape=shape,
714    dtype=dtype)
715
716_LAX_COMPARATORS = (lax.eq_p, lax.ge_p, lax.gt_p, lax.le_p, lax.lt_p, lax.ne_p)
717
718
719def _make_comparator_harness(name,
720                             *,
721                             dtype=np.float32,
722                             op=lax.eq_p,
723                             lhs_shape=(),
724                             rhs_shape=()):
725  define(
726    op.name,
727    f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}",
728    lambda *args: op.bind(*args),
729    [RandArg(lhs_shape, dtype),
730     RandArg(rhs_shape, dtype)],
731    op=op,
732    lhs_shape=lhs_shape,
733    rhs_shape=rhs_shape,
734    dtype=dtype)
735
736
737for op in _LAX_COMPARATORS:
738  for dtype in (jtu.dtypes.all if op in [lax.eq_p, lax.ne_p] else
739                set(jtu.dtypes.all) - set(jtu.dtypes.complex)):
740    # Validate dtypes
741    _make_comparator_harness("dtypes", dtype=dtype, op=op)
742
743  # Validate broadcasting behavior
744  for lhs_shape, rhs_shape in [
745    ((), (2, 3)),  # broadcast scalar
746    ((1, 2), (3, 2)),  # broadcast along specific axis
747  ]:
748    _make_comparator_harness(
749      "broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape, op=op)
750
751for dtype in jtu.dtypes.all:
752  shape = (3, 4, 5)
753  define(
754    "zeros_like",
755    f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
756    ad_util.zeros_like_p.bind, [RandArg(shape, dtype)],
757    shape=shape,
758    dtype=dtype)
759
760for dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
761  arg = np.array([-1, -2, 0, 1], dtype=dtype)
762  define(
763    "population_count",
764    f"{jtu.dtype_str(dtype)}",
765    lax.population_count, [arg],
766    dtype=dtype)
767
768
769def _get_max_identity(dtype):
770  if dtypes.issubdtype(dtype, np.inexact):
771    return np.array(-np.inf, dtype)
772  elif dtypes.issubdtype(dtype, np.integer):
773    return np.array(dtypes.iinfo(dtype).min, dtype)
774  elif dtypes.issubdtype(dtype, np.bool_):
775    return np.array(False, np.bool_)
776
777
778def _get_min_identity(dtype):
779  if dtypes.issubdtype(dtype, np.inexact):
780    return np.array(np.inf, dtype)
781  elif dtypes.issubdtype(dtype, np.integer):
782    return np.array(dtypes.iinfo(dtype).max, dtype)
783  elif dtypes.issubdtype(dtype, np.bool_):
784    return np.array(True, np.bool_)
785
786
787def _make_argminmax_harness(prim,
788                            name,
789                            *,
790                            shape=(15,),
791                            dtype=jnp.float32,
792                            axes=(0,),
793                            index_dtype=np.int32,
794                            arr=None):
795  arr = arr if arr is not None else RandArg(shape, dtype)
796  dtype, shape = arr.dtype, arr.shape
797  index_dtype = dtypes.canonicalize_dtype(index_dtype)
798  define(
799    prim,
800    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_indexdtype={index_dtype}",
801    lambda arg: prim.bind(arg, axes=axes, index_dtype=index_dtype), [arr],
802    shape=shape,
803    dtype=dtype,
804    axes=axes,
805    index_dtype=index_dtype,
806    prim=prim)
807
808
809for prim in [lax.argmin_p, lax.argmax_p]:
810  for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex):
811    # Validate dtypes for each primitive
812    _make_argminmax_harness(prim, "dtypes", dtype=dtype)
813
814  # Validate axes for each primitive; non major axis
815  _make_argminmax_harness(prim, "axes", shape=(18, 12), axes=(1,))
816
817  # Validate index dtype for each primitive
818  for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
819    _make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype)
820
821
822# TODO(bchetioui): the below documents a limitation of argmin and argmax when a
823# dimension of the input is too large. However, it is not categorizable as it
824# seems that the converter fails before reaching the actual primitive call. This
825# suggests that we may need to harden the converter to handle inputs this big.
826# + tuple( # Document limitation in case of too large axis
827#  _make_argminmax_harness("overflow_axis", prim=prim,
828#                          arr=np.ones((2**31,), dtype=np.uint8))
829#  for prim in [lax.argmin_p, lax.argmax_p]
830# )
831
832
833def _make_iota_harness(name, *, shape=(2, 3), dtype=np.float32, dimension=0):
834  define(
835    lax.iota_p,
836    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_dimension={dimension}",
837    lambda dtype, shape, dim:
838    (lax.iota_p.bind(dtype=dtype, shape=shape, dimension=dim)),
839    [StaticArg(dtype),
840     StaticArg(shape),
841     StaticArg(dimension)],
842    shape=shape,
843    dtype=dtype,
844    dimension=dimension)
845
846
847for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
848  _make_iota_harness("dtypes", dtype=dtype)
849
850# Validate broadcasting
851for shape, dimension in [
852  ((4, 8, 1, 1), 1),  # broadcasting along non-major dimension
853  ((4, 8, 1, 1), 2),  # broadcasting along dimension == 1
854]:
855  _make_iota_harness("broadcasting", shape=shape, dimension=dimension)
856
857
858def _make_div_rem_harness(prim,
859                          name,
860                          *,
861                          shapes=((2,), (2,)),
862                          dtype=np.float32,
863                          arrs=(None, None)):
864  lhs, rhs = arrs
865
866  def _make_non_zero(rng):
867    return jtu.rand_nonzero(rng)(shapes[1], dtype)
868
869  lhs = RandArg(shapes[0], dtype) if lhs is None else lhs
870  rhs_shape = rhs.shape if rhs is not None else shapes[1]
871  rhs_dtype = rhs.dtype if rhs is not None else dtype
872  rhs = CustomArg(_make_non_zero) if rhs is None else rhs
873
874  define(
875    prim,
876    f"{name}_lhs={jtu.format_shape_dtype_string(lhs.shape, lhs.dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)}",
877    prim.bind, [lhs, rhs],
878    dtype=dtype,
879    lhs=lhs,
880    rhs=rhs,
881    prim=prim)
882
883
884for prim in [lax.div_p, lax.rem_p]:
885  for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean) - (
886      set() if prim is lax.div_p else set(jtu.dtypes.complex)):
887    _make_div_rem_harness(prim, "dtypes", dtype=dtype)
888
889  # Validate broadcasting
890  for shapes in [
891    ((2, 1, 3), (2, 4, 3)),  # broadcast dividend
892    ((2, 4, 3), (2, 1, 3)),  # broadcast divisor
893  ]:
894    _make_div_rem_harness(prim, "broadcast", shapes=shapes)
895
896  # Validate singularity points
897  for name, arrs in [
898    ("positive_by_0", (np.ones(
899      (2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))),
900    # TODO: this fails on CPU, different result
901    # ("positive_by_0_int32", (np.ones((2,), dtype=np.int32),
902    #                         np.zeros((2,), dtype=np.int32))),
903    ("negative_by_0", (-np.ones(
904      (2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))),
905    ("0_by_0", (np.zeros(
906      (2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))),
907    ("inf_by_inf", (np.array([np.inf], dtype=np.float32),
908                    np.array([np.inf], dtype=np.float32))),
909  ]:
910    _make_div_rem_harness(prim, f"singularity_{name}", arrs=arrs)
911
912
913def _make_binary_elementwise_harnesses(prim,
914                                       dtypes,
915                                       default_dtype=np.float32,
916                                       jax_unimplemented=lambda **kwargs: []):
917  def _make(name, *, shapes=((20, 20), (20, 20)), dtype):
918    lhs_shape, rhs_shape = shapes
919    define(
920      prim,
921      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}",
922      prim.bind, [RandArg(lhs_shape, dtype),
923                  RandArg(rhs_shape, dtype)],
924      jax_unimplemented=jax_unimplemented(
925        dtype=dtype, prim=prim, shapes=shapes),
926      prim=prim,
927      dtype=dtype,
928      shapes=shapes)
929
930  return (tuple(  # Validate dtypes
931    _make("dtypes", dtype=dtype)
932    for dtype in dtypes) + tuple(  # Validate broadcasting
933    _make("broadcasting", dtype=default_dtype, shapes=shapes)
934    for shapes in [
935      ((20, 20), (1, 20)),  # broadcasting rhs
936      ((1, 20), (20, 20)),  # broadcasting lhs
937    ]))
938
939
940_make_binary_elementwise_harnesses(
941  prim=lax.add_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean))
942
943_make_binary_elementwise_harnesses(
944  prim=lax.mul_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean))
945
946_make_binary_elementwise_harnesses(
947  prim=lax.atan2_p, dtypes=jtu.dtypes.all_floating)
948
949_make_binary_elementwise_harnesses(
950  prim=lax.igamma_p,
951  dtypes=jtu.dtypes.all_floating,
952  jax_unimplemented=lambda *_, dtype, **kwargs: [
953    Limitation(
954      "XLA internal error", dtypes=[np.float16, dtypes.bfloat16]),
955  ])
956
957_make_binary_elementwise_harnesses(
958  prim=lax.igammac_p,
959  dtypes=jtu.dtypes.all_floating,
960  jax_unimplemented=lambda *_, dtype, **kwargs: [
961    Limitation(
962      "XLA internal error", dtypes=[np.float16, dtypes.bfloat16]),
963  ])
964
965_make_binary_elementwise_harnesses(
966  prim=lax.nextafter_p,
967  dtypes=jtu.dtypes.all_floating)
968
969_make_binary_elementwise_harnesses(
970  prim=lax.and_p,
971  default_dtype=np.int32,
972  dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned +
973         jtu.dtypes.boolean)
974
975_make_binary_elementwise_harnesses(
976  prim=lax.or_p,
977  default_dtype=np.int32,
978  dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned +
979         jtu.dtypes.boolean)
980
981_make_binary_elementwise_harnesses(
982  prim=lax.xor_p,
983  default_dtype=np.int32,
984  dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned +
985         jtu.dtypes.boolean)
986
987_make_binary_elementwise_harnesses(
988  prim=lax.shift_left_p,
989  default_dtype=np.int32,
990  dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned)
991
992_make_binary_elementwise_harnesses(
993  prim=lax.shift_right_logical_p,
994  default_dtype=np.int32,
995  dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned)
996
997_make_binary_elementwise_harnesses(
998  prim=lax.shift_right_arithmetic_p,
999  default_dtype=np.int32,
1000  dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned)
1001
1002_make_binary_elementwise_harnesses(
1003  prim=lax.sub_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean))
1004
1005_min_max_special_cases = tuple(
1006  (lhs, rhs)
1007  for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex
1008  for lhs, rhs in [(np.array([np.inf, np.inf], dtype=dtype),
1009                    np.array([np.nan, np.nan], dtype=dtype)),
1010                   (np.array([-np.inf, -np.inf], dtype=dtype),
1011                    np.array([np.nan, np.nan], dtype=dtype))])
1012
1013_make_binary_elementwise_harnesses(prim=lax.min_p, dtypes=jtu.dtypes.all)
1014# Validate special cases
1015for lhs, rhs in _min_max_special_cases:
1016  define(
1017    lax.min_p,
1018    f"inf_nan_{jtu.dtype_str(lhs.dtype)}_{lhs[0]}_{rhs[0]}",
1019    lax.min_p.bind, [lhs, rhs],
1020    prim=lax.min_p,
1021    dtype=lhs.dtype)
1022
1023_make_binary_elementwise_harnesses(prim=lax.max_p, dtypes=jtu.dtypes.all)
1024# Validate special cases
1025for lhs, rhs in _min_max_special_cases:
1026  define(
1027    lax.max_p,
1028    f"inf_nan_{jtu.dtype_str(lhs.dtype)}_{lhs[0]}_{rhs[0]}",
1029    lax.max_p.bind, [lhs, rhs],
1030    prim=lax.max_p,
1031    dtype=lhs.dtype)
1032
1033
1034def _make_broadcast_in_dim_harness(name,
1035                                   *,
1036                                   dtype=np.float32,
1037                                   shape=(2,),
1038                                   outshape=(2,),
1039                                   broadcast_dimensions=(0,)):
1040  define(
1041    lax.broadcast_in_dim_p,
1042    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_outshape={outshape}_broadcastdimensions={broadcast_dimensions}",
1043    lambda operand: lax.broadcast_in_dim_p.bind(
1044      operand, shape=outshape, broadcast_dimensions=broadcast_dimensions),
1045    [RandArg(shape, dtype)],
1046    shape=shape,
1047    dtype=dtype,
1048    outshape=outshape,
1049    broadcast_dimensions=broadcast_dimensions)
1050
1051
1052for dtype in jtu.dtypes.all:
1053  _make_broadcast_in_dim_harness("dtypes", dtype=dtype)
1054
1055# Validate parameter combinations
1056for shape, outshape, broadcast_dimensions in [
1057  [(2,), (3, 2), (1,)],  # add major dimension
1058  [(2,), (2, 3), (0,)],  # add inner dimension
1059  [(), (2, 3), ()],  # use scalar shape
1060  [(1, 2), (4, 3, 2), (0, 2)],  # map size 1 dim to different output dim value
1061]:
1062  _make_broadcast_in_dim_harness(
1063    "parameter_combinations",
1064    shape=shape,
1065    outshape=outshape,
1066    broadcast_dimensions=broadcast_dimensions)
1067
1068
1069def _make_broadcast_harness(name, *, dtype=np.float32, shape=(2,), sizes=()):
1070  define(
1071    lax.broadcast_p,
1072    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_sizes={sizes}",
1073    lambda operand: lax.broadcast_p.bind(operand, sizes=sizes),
1074    [RandArg(shape, dtype)],
1075    shape=shape,
1076    dtype=dtype,
1077    sizes=sizes)
1078
1079
1080for dtype in jtu.dtypes.all:
1081  _make_broadcast_harness("dtypes", dtype=dtype)
1082
1083# Validate sizes
1084for sizes in [
1085  (2,),  # broadcast 1 dim
1086  (1, 2, 3),  # broadcast n > 1 dims
1087]:
1088  _make_broadcast_harness("sizes", sizes=sizes)
1089
1090for dtype in jtu.dtypes.all_floating:
1091  for arg1, arg2, arg3 in [
1092    (np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6], dtype=dtype),
1093     np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6], dtype=dtype),
1094     np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6], dtype=dtype))
1095  ]:
1096    define(
1097      lax.regularized_incomplete_beta_p,
1098      f"_{jtu.dtype_str(dtype)}",
1099      lax.betainc, [arg1, arg2, arg3],
1100      dtype=dtype)
1101
1102## GATHER
1103# Validate dtypes
1104for dtype in set(jtu.dtypes.all):
1105  indices = np.array(2, dtype=np.int32)
1106  shape = (10,)
1107  axis = 0
1108  define(
1109    lax.gather_p,
1110    f"dtypes_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={axis}",
1111    lambda a, i, axis: jnp.take(a, i, axis=axis),
1112    [
1113      RandArg(shape, dtype),
1114      indices,
1115      StaticArg(axis)
1116    ],
1117    dtype=dtype)
1118
1119# Construct gather harnesses using take
1120_gather_input = np.arange(1000, dtype=np.float32).reshape((10, 10, 10))
1121for indices in [
1122  # Ensure each set of indices has a distinct shape
1123  np.array(2, dtype=np.int32),
1124  np.array([2], dtype=np.int32),
1125  np.array([2, 4], dtype=np.int32),
1126  np.array([[2, 4], [5, 6]], dtype=np.int32),
1127  np.array([0, 1, 10], dtype=np.int32),  # Index out of bounds
1128  np.array([0, 1, 2, -1], dtype=np.int32),  # Index out of bounds
1129]:
1130  for axis in [0, 1, 2]:
1131    define(
1132      lax.gather_p,
1133      f"from_take_indices_shape={indices.shape}_axis={axis}",
1134      lambda a, i, axis: jnp.take(a, i, axis=axis),
1135      [_gather_input, indices, StaticArg(axis)],
1136      dtype=_gather_input.dtype)
1137
1138# Directly from lax.gather in lax_test.py.
1139for shape, idxs, dnums, slice_sizes in [
1140  ((5,), np.array([[0], [2]]),
1141   lax.GatherDimensionNumbers(
1142     offset_dims=(), collapsed_slice_dims=(0,),
1143     start_index_map=(0,)), (1,)),
1144  ((10,), np.array([[0], [0], [0]]),
1145   lax.GatherDimensionNumbers(
1146     offset_dims=(1,), collapsed_slice_dims=(),
1147     start_index_map=(0,)), (2,)),
1148  ((
1149       10,
1150       5,
1151   ), np.array([[0], [2], [1]]),
1152   lax.GatherDimensionNumbers(
1153     offset_dims=(1,), collapsed_slice_dims=(0,),
1154     start_index_map=(0,)), (1, 3)),
1155  ((10, 5), np.array([[0, 2], [1, 0]]),
1156   lax.GatherDimensionNumbers(
1157     offset_dims=(1,), collapsed_slice_dims=(0,),
1158     start_index_map=(0, 1)), (1, 3)),
1159]:
1160  dtype = np.float32
1161  define(
1162    lax.gather_p,
1163    f"_shape={shape}_idxs_shape={idxs.shape}_dnums={dnums}_slice_sizes={slice_sizes}",
1164    lambda op, idxs, dnums, slice_sizes: lax.gather(
1165      op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes),
1166    [RandArg(shape, dtype), idxs,
1167     StaticArg(dnums),
1168     StaticArg(slice_sizes)],
1169    dtype=dtype)
1170
1171
1172def _make_scatter_harness(name,
1173                          *,
1174                          shape=(5,),
1175                          f_lax=lax.scatter_min,
1176                          indices_are_sorted=False,
1177                          unique_indices=False,
1178                          scatter_indices=np.array([[0], [2]]),
1179                          update_shape=(2,),
1180                          dtype=np.float32,
1181                          dimension_numbers=((), (0,), (0,))):
1182  dimension_numbers = lax.ScatterDimensionNumbers(*dimension_numbers)
1183  define(
1184    f_lax.__name__,
1185    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_updatewindowdims={dimension_numbers.update_window_dims}_insertedwindowdims={dimension_numbers.inserted_window_dims}_scatterdimstooperanddims={dimension_numbers.scatter_dims_to_operand_dims}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}"
1186      .replace(" ", ""),
1187    partial(
1188      f_lax,
1189      indices_are_sorted=indices_are_sorted,
1190      unique_indices=unique_indices), [
1191      RandArg(shape, dtype),
1192      StaticArg(scatter_indices),
1193      RandArg(update_shape, dtype),
1194      StaticArg(dimension_numbers)
1195    ],
1196    jax_unimplemented=[
1197      Limitation(
1198        "unimplemented", devices="tpu", dtypes=np.complex64,
1199        enabled=(f_lax in [lax.scatter_max, lax.scatter_min]))
1200    ],
1201    f_lax=f_lax,
1202    shape=shape,
1203    dtype=dtype,
1204    scatter_indices=scatter_indices,
1205    update_shape=update_shape,
1206    dimension_numbers=dimension_numbers,
1207    indices_are_sorted=indices_are_sorted,
1208    unique_indices=unique_indices)
1209
1210
1211# Validate dtypes
1212for dtype in jtu.dtypes.all:
1213  for f_lax in [lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min]:
1214    if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_:
1215      continue
1216    _make_scatter_harness("dtypes", dtype=dtype, f_lax=f_lax)
1217
1218# Validate f_lax/update_jaxpr
1219# We explicitly decide against testing lax.scatter, as its reduction function
1220# is lambda x, y: y, which is not commutative and thus makes results
1221# non-deterministic when an index into the operand is updated several times.
1222
1223# Validate shapes, dimension numbers and scatter indices
1224for shape, scatter_indices, update_shape, dimension_numbers in [
1225  ((10,), [[0], [0], [0]], (3, 2), ((1,), (), (0,))),
1226  ((10, 5), [[0], [2], [1]], (3, 3), ((1,), (0,), (0,)))
1227]:
1228  _make_scatter_harness(
1229    "shapes_and_dimension_numbers",
1230    shape=shape,
1231    update_shape=update_shape,
1232    scatter_indices=np.array(scatter_indices),
1233    dimension_numbers=dimension_numbers)
1234
1235# Validate sorted indices
1236_make_scatter_harness("indices_are_sorted", indices_are_sorted=True)
1237# Validate unique_indices
1238# `unique_indices` does not affect correctness, only performance, and thus
1239# does not need to be tested here. If/when it will make sense to add a test
1240# with `unique_indices` = True, particular care will have to be taken with
1241# regards to the choice of parameters, as the results are only predictable
1242# when all the indices to be updated are pairwise non-overlapping. Identifying
1243# such cases is non-trivial.
1244_make_scatter_harness("unique_indices", unique_indices=False)
1245
1246for dtype in jtu.dtypes.all:
1247  arg_shape = (2, 3)
1248  for pads in [
1249    [(0, 0, 0), (0, 0, 0)],  # no padding
1250    [(1, 1, 0), (2, 2, 0)],  # only positive edge padding
1251    [(1, 2, 1), (0, 1, 0)],  # edge padding and interior padding
1252    [(0, 0, 0), (-1, -1, 0)],  # negative padding
1253    [(0, 0, 0), (-2, -2, 4)],  # add big dilation then remove from edges
1254    [(0, 0, 0), (-2, -3, 1)],  # remove everything in one dimension
1255  ]:
1256    define(
1257      lax.pad_p,
1258      f"inshape={jtu.format_shape_dtype_string(arg_shape, dtype)}_pads={pads}",
1259      lax.pad,
1260      [RandArg(arg_shape, dtype),
1261       np.array(0, dtype),
1262       StaticArg(pads)],
1263      rng_factory=jtu.rand_small,
1264      arg_shape=arg_shape,
1265      dtype=dtype,
1266      pads=pads)
1267
1268
1269def _make_select_harness(name,
1270                         *,
1271                         shape_pred=(2, 3),
1272                         shape_args=(2, 3),
1273                         dtype=np.float32):
1274  define(
1275    lax.select_p,
1276    f"{name}_shapepred={jtu.format_shape_dtype_string(shape_pred, np.bool_)}_shapeargs={jtu.format_shape_dtype_string(shape_args, dtype)}",
1277    lax.select, [
1278      RandArg(shape_pred, np.bool_),
1279      RandArg(shape_args, dtype),
1280      RandArg(shape_args, dtype)
1281    ],
1282    shape_pred=shape_pred,
1283    shape_args=shape_args,
1284    dtype=dtype)
1285
1286
1287for dtype in jtu.dtypes.all:
1288  _make_select_harness("dtypes", dtype=dtype)
1289
1290# Validate shapes
1291_make_select_harness("shapes", shape_pred=(), shape_args=(18,))
1292
1293
1294def _make_transpose_harness(name,
1295                            *,
1296                            shape=(2, 3),
1297                            permutation=(1, 0),
1298                            dtype=np.float32):
1299  define(
1300    lax.transpose_p,
1301    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_permutation={permutation}"
1302      .replace(" ", ""),
1303    lambda x: lax.transpose_p.bind(x, permutation=permutation),
1304    [RandArg(shape, dtype)],
1305    shape=shape,
1306    dtype=dtype,
1307    permutation=permutation)
1308
1309
1310for dtype in jtu.dtypes.all:
1311  _make_transpose_harness("dtypes", dtype=dtype)
1312
1313# Validate permutations
1314for shape, permutation in [
1315  ((2, 3, 4), (0, 1, 2)),  # identity
1316  ((2, 3, 4), (1, 2, 0)),  # transposition
1317]:
1318  _make_transpose_harness("permutations", shape=shape, permutation=permutation)
1319
1320
1321## CUMREDUCE
1322def _make_cumreduce_harness(name,
1323                            *,
1324                            f_jax=lax_control_flow.cummin,
1325                            shape=(8, 9),
1326                            dtype=np.float32,
1327                            axis=0,
1328                            reverse=False):
1329  limitations = []
1330  if f_jax.__name__ != "cumsum":
1331    limitations.append(
1332      Limitation(
1333        "unimplemented",
1334        devices="tpu",
1335        dtypes=np.complex64,
1336      ))
1337  define(
1338    f_jax.__name__,
1339    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={axis}_reverse={reverse}",
1340    f_jax, [RandArg(shape, dtype),
1341            StaticArg(axis),
1342            StaticArg(reverse)],
1343    jax_unimplemented=limitations,
1344    f_jax=f_jax,
1345    shape=shape,
1346    dtype=dtype,
1347    axis=axis,
1348    reverse=reverse)
1349
1350
1351# Validate dtypes for each function
1352for f_jax in [
1353  lax_control_flow.cummin, lax_control_flow.cummax, lax_control_flow.cumsum,
1354  lax_control_flow.cumprod
1355]:
1356  for dtype in jtu.dtypes.all:
1357    if dtype == np.bool_:
1358      continue
1359    _make_cumreduce_harness("dtype_by_fun", dtype=dtype, f_jax=f_jax)
1360
1361  # Validate axis for each function
1362  shape = (8, 9)
1363  for axis in range(len(shape)):
1364    _make_cumreduce_harness("axis_by_fun", axis=axis, f_jax=f_jax, shape=shape)
1365
1366  # Validate reverse for each function
1367  _make_cumreduce_harness("reverse", reverse=True, f_jax=f_jax)
1368
1369### TOP_K
1370def _make_top_k_harness(name,
1371                        *,
1372                        operand=None,
1373                        shape=(5, 3),
1374                        dtype=np.float32,
1375                        k=2):
1376  if operand is None:
1377    operand = RandArg(shape, dtype)
1378  define(
1379    lax.top_k_p,
1380    f"{name}_inshape={jtu.format_shape_dtype_string(operand.shape, operand.dtype)}_k={k}",
1381    lax.top_k, [operand, StaticArg(k)],
1382    shape=operand.shape,
1383    dtype=operand.dtype,
1384    k=k)
1385
1386
1387for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex):
1388  # Validate dtypes
1389  _make_top_k_harness("dtypes", dtype=dtype)
1390
1391# Validate implicit properties of the sort
1392for name, operand, k in [("stability",
1393                          np.array([5, 7, 5, 8, 8, 5], dtype=np.int32), 3),
1394                         ("sort_inf_nan",
1395                          np.array([+np.inf, np.nan, -np.nan, -np.inf, 3],
1396                                   dtype=np.float32), 5)]:
1397  _make_top_k_harness(name, operand=operand, k=k)
1398
1399### SORT
1400def _make_sort_harness(name,
1401                       *,
1402                       operands=None,
1403                       shape=(5, 7),
1404                       dtype=np.float32,
1405                       dimension=0,
1406                       is_stable=False,
1407                       num_keys=1):
1408  if operands is None:
1409    operands = [RandArg(shape, dtype)]
1410  define(
1411      lax.sort_p,
1412      f"{name}_num_arrays={len(operands)}_shape={jtu.format_shape_dtype_string(operands[0].shape, operands[0].dtype)}_axis={dimension}_isstable={is_stable}_num_keys={num_keys}",
1413      lambda *args: lax.sort_p.bind(
1414          *args[:-3], dimension=args[-3], is_stable=args[-2], num_keys=args[-1]
1415      ), [
1416          *operands,
1417          StaticArg(dimension),
1418          StaticArg(is_stable),
1419          StaticArg(num_keys)
1420      ],
1421      shape=operands[0].shape,
1422      dimension=dimension,
1423      dtype=operands[0].dtype,
1424      is_stable=is_stable,
1425      num_keys=num_keys,
1426      num_arrays=len(operands))
1427
1428
1429_lax_sort_multiple_array_shape = (100,)
1430# In order to test lexicographic ordering and sorting stability, the first
1431# array contains only integers 0 and 1
1432_lax_sort_multiple_array_first_arg = (
1433    np.random.uniform(0, 2, _lax_sort_multiple_array_shape).astype(np.int32))
1434
1435# Validate dtypes
1436for dtype in jtu.dtypes.all:
1437  _make_sort_harness("dtypes", dtype=dtype)
1438
1439# Validate dimensions
1440for dimension in [0, 1]:
1441  _make_sort_harness("dimensions", dimension=dimension)
1442# Validate stable sort
1443_make_sort_harness("is_stable", is_stable=True)
1444# Potential edge cases
1445for operands, dimension in [
1446  ([np.array([+np.inf, np.nan, -np.nan, -np.inf, 2], dtype=np.float32)], 0)
1447]:
1448  _make_sort_harness("edge_cases", operands=operands, dimension=dimension)
1449
1450# Validate multiple arrays, num_keys, and is_stable
1451for is_stable in [False, True]:
1452  for operands in (
1453      [
1454          _lax_sort_multiple_array_first_arg,
1455          RandArg(_lax_sort_multiple_array_shape, np.int32)
1456      ],
1457      [
1458          _lax_sort_multiple_array_first_arg,
1459          RandArg(_lax_sort_multiple_array_shape, np.int32),
1460          RandArg(_lax_sort_multiple_array_shape, np.float32)
1461      ],
1462  ):
1463    for num_keys in range(1, len(operands) + 1):
1464      _make_sort_harness(
1465          "multiple_arrays",
1466          operands=operands,
1467          num_keys=num_keys,
1468          is_stable=is_stable,
1469          shape=_lax_sort_multiple_array_first_arg.shape,
1470          dtype=_lax_sort_multiple_array_first_arg.dtype)
1471
1472
1473def _make_cholesky_arg(shape, dtype, rng):
1474  a = jtu.rand_default(rng)(shape, dtype)
1475  return np.matmul(a, jnp.conj(np.swapaxes(a, -1, -2)))
1476
1477
1478for dtype in jtu.dtypes.all_inexact:
1479  for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]:
1480    define(
1481      lax.linalg.cholesky_p,
1482      f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
1483      lambda *args: lax.linalg.cholesky_p.bind(*args),
1484      [CustomArg(partial(_make_cholesky_arg, shape, dtype))],
1485      jax_unimplemented=[
1486        Limitation(
1487          "unimplemented",
1488          dtypes=[np.float16],
1489          devices=("cpu", "gpu"))
1490      ],
1491      shape=shape,
1492      dtype=dtype)
1493
1494for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
1495  for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]:
1496    for full_matrices in [False, True]:
1497      define(
1498        lax.linalg.qr_p,
1499        f"multi_array_shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}",
1500        lax.linalg.qr,
1501        [RandArg(shape, dtype),
1502         StaticArg(full_matrices)],
1503        # See jax.lib.lapack.geqrf for the list of compatible types
1504        jax_unimplemented=[
1505          Limitation(
1506            "unimplemented",
1507            devices=("cpu", "gpu"),
1508            dtypes=[np.float16, dtypes.bfloat16]),
1509        ],
1510        shape=shape,
1511        dtype=dtype,
1512        full_matrices=full_matrices)
1513
1514
1515def _make_fft_harness(name,
1516                      *,
1517                      shape=(14, 15, 16, 17),
1518                      dtype=np.float32,
1519                      fft_type=xla_client.FftType.FFT,
1520                      fft_lengths=(17,)):
1521  def _fft_rng_factory(dtype):
1522    _all_integers = (
1523        jtu.dtypes.all_integer + jtu.dtypes.all_unsigned + jtu.dtypes.boolean)
1524    # For integer types, use small values to keep the errors small
1525    if dtype in _all_integers:
1526      return jtu.rand_small
1527    else:
1528      return jtu.rand_default
1529
1530  define(
1531    lax.fft_p,
1532    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_ffttype={fft_type}_fftlengths={fft_lengths}",
1533    lambda *args: lax.fft_p.bind(
1534      args[0], fft_type=args[1], fft_lengths=args[2]),
1535    [RandArg(shape, dtype),
1536     StaticArg(fft_type),
1537     StaticArg(fft_lengths)],
1538    jax_unimplemented=[
1539      Limitation(
1540        "only 1D FFT is currently supported b/140351181.",
1541        devices="tpu",
1542        enabled=len(fft_lengths) > 1),
1543    ],
1544    rng_factory=_fft_rng_factory(dtype),
1545    shape=shape,
1546    dtype=dtype,
1547    fft_type=fft_type,
1548    fft_lengths=fft_lengths)
1549
1550
1551# FFT, IFFT, RFFT, IRFFT
1552for fft_type in list(map(xla_client.FftType, [0, 1, 2, 3])):
1553  # Validate dtypes per FFT type
1554  for dtype in (jtu.dtypes.floating
1555  if fft_type == xla_client.FftType.RFFT else jtu.dtypes.complex):
1556    shape = (14, 15, 16, 17)
1557    for fft_lengths in [
1558      (shape[-1],) if fft_type != xla_client.FftType.IRFFT else
1559      ((shape[-1] - 1) * 2,)
1560    ]:
1561      _make_fft_harness(
1562        "dtypes",
1563        shape=shape,
1564        dtype=dtype,
1565        fft_type=fft_type,
1566        fft_lengths=fft_lengths)
1567
1568  # Validate dimensions per FFT type
1569  for dtype in [
1570    np.float32 if fft_type == xla_client.FftType.RFFT else np.complex64
1571  ]:
1572    for dims in [1, 2, 3]:
1573      for fft_lengths in [
1574        shape[-dims:] if fft_type != xla_client.FftType.IRFFT else
1575        shape[-dims:-1] + ((shape[-1] - 1) * 2,)
1576      ]:
1577        _make_fft_harness(
1578          "dims",
1579          shape=shape,
1580          fft_type=fft_type,
1581          fft_lengths=fft_lengths,
1582          dtype=dtype)
1583
1584for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
1585  for shape in [(2, 2), (2, 7), (29, 29), (2, 3, 53), (2, 3, 29, 7)]:
1586    for full_matrices in [False, True]:
1587      for compute_uv in [False, True]:
1588        define(
1589          lax.linalg.svd_p,
1590          f"shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}_computeuv={compute_uv}",
1591          lambda *args: lax.linalg.svd_p.bind(
1592            args[0], full_matrices=args[1], compute_uv=args[2]), [
1593            RandArg(shape, dtype),
1594            StaticArg(full_matrices),
1595            StaticArg(compute_uv)
1596          ],
1597          jax_unimplemented=[
1598            Limitation(
1599              "unimplemented",
1600              devices=("cpu", "gpu"),
1601              dtypes=[np.float16, dtypes.bfloat16]),
1602            Limitation(
1603              "complex not implemented. Works in JAX for CPU and GPU with custom kernels",
1604              devices="tpu",
1605              dtypes=[np.complex64, np.complex128])
1606          ],
1607          shape=shape,
1608          dtype=dtype,
1609          full_matrices=full_matrices,
1610          compute_uv=compute_uv)
1611
1612for dtype in jtu.dtypes.all_inexact:
1613  for shape in [(0, 0), (5, 5), (2, 6, 6)]:
1614    for compute_left_eigenvectors in [False, True]:
1615      for compute_right_eigenvectors in [False, True]:
1616        define(
1617          lax.linalg.eig_p,
1618          f"shape={jtu.format_shape_dtype_string(shape, dtype)}_computelefteigenvectors={compute_left_eigenvectors}_computerighteigenvectors={compute_right_eigenvectors}",
1619          lax.linalg.eig, [
1620            RandArg(shape, dtype),
1621            StaticArg(compute_left_eigenvectors),
1622            StaticArg(compute_right_eigenvectors)
1623          ],
1624          jax_unimplemented=[
1625            Limitation(
1626              "only supported on CPU in JAX", devices=("tpu", "gpu")),
1627            Limitation(
1628              "unimplemented",
1629              devices="cpu",
1630              dtypes=[np.float16, dtypes.bfloat16])
1631          ],
1632          shape=shape,
1633          dtype=dtype,
1634          compute_left_eigenvectors=compute_left_eigenvectors,
1635          compute_right_eigenvectors=compute_right_eigenvectors)
1636
1637
1638def _make_triangular_eigh_operand(shape, dtype, lower: bool, rng: Rng):
1639  # For testing eigh we use triangular matrices
1640  operand = jtu.rand_default(rng)(shape, dtype)
1641  # Make operand self-adjoint
1642  operand = (operand + np.conj(np.swapaxes(operand, -1, -2))) / 2
1643  # Make operand lower/upper triangular
1644  return operand  # np.tril(operand) if lower else np.triu(operand)
1645
1646
1647for dtype in jtu.dtypes.all_inexact:
1648  for shape in [(0, 0), (50, 50), (2, 20, 20)]:
1649    for lower in [False, True]:
1650      define(
1651        lax.linalg.eigh_p,
1652        f"shape={jtu.format_shape_dtype_string(shape, dtype)}_lower={lower}",
1653        # Make operand lower/upper triangular
1654        lambda operand, lower, symmetrize_input: (lax.linalg.eigh(
1655          jnp.tril(operand)
1656          if lower else jnp.triu(operand), lower, symmetrize_input)),
1657        # lax.linalg.eigh,
1658        [
1659          CustomArg(
1660            partial(_make_triangular_eigh_operand, shape, dtype, lower)),
1661          StaticArg(lower),
1662          StaticArg(False)
1663        ],
1664        jax_unimplemented=[
1665          Limitation(
1666            "complex eigh not supported ",
1667            devices="tpu",
1668            dtypes=[np.complex64, np.complex128]),
1669          Limitation(
1670            "unimplemented", devices="cpu", dtypes=[np.float16]),
1671          Limitation(
1672            "unimplemented", devices="gpu", dtypes=[np.float16]),
1673        ],
1674        shape=shape,
1675        dtype=dtype,
1676        lower=lower)
1677
1678for dtype in jtu.dtypes.all_inexact:
1679  for shape in [
1680    (5, 5),  # square
1681    (3, 5, 5),  # batched
1682    (3, 5),  # non-square
1683  ]:
1684    define(
1685      lax.linalg.lu_p,
1686      f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
1687      lax.linalg.lu, [RandArg(shape, dtype)],
1688      jax_unimplemented=[
1689        Limitation(
1690          "unimplemented", dtypes=[np.float16, dtypes.bfloat16])
1691      ],
1692      shape=shape,
1693      dtype=dtype)
1694
1695
1696def _make_triangular_solve_harness(name,
1697                                   *,
1698                                   left_side=True,
1699                                   lower=False,
1700                                   ab_shapes=((4, 4), (4, 1)),
1701                                   dtype=np.float32,
1702                                   transpose_a=False,
1703                                   conjugate_a=False,
1704                                   unit_diagonal=False):
1705  a_shape, b_shape = ab_shapes
1706  f_lax = lambda a, b: (
1707    lax.linalg.triangular_solve_p.bind(
1708      a,
1709      b,
1710      left_side=left_side,
1711      lower=lower,
1712      transpose_a=transpose_a,
1713      conjugate_a=conjugate_a,
1714      unit_diagonal=unit_diagonal))
1715
1716  define(
1717    lax.linalg.triangular_solve_p,
1718    f"{name}_a={jtu.format_shape_dtype_string(a_shape, dtype)}_b={jtu.format_shape_dtype_string(b_shape, dtype)}_leftside={left_side}_lower={lower}_transposea={transpose_a}_conjugatea={conjugate_a}_unitdiagonal={unit_diagonal}",
1719    f_lax, [RandArg(a_shape, dtype),
1720            RandArg(b_shape, dtype)],
1721    jax_unimplemented=[
1722      Limitation(
1723        "unimplemented", devices="gpu", dtypes=[np.float16]),
1724    ],
1725    dtype=dtype,
1726    a_shape=a_shape,
1727    b_shape=b_shape,
1728    left_side=left_side,
1729    lower=lower,
1730    tranpose_a=transpose_a,
1731    conjugate_a=conjugate_a,
1732    unit_diagonal=unit_diagonal)
1733
1734
1735# Validate dtypes
1736# This first harness runs the tests for all dtypes using default values for
1737# all the other parameters, except unit_diagonal (to ensure that
1738# tf.linalg.set_diag works reliably for all dtypes). Variations of other
1739# parameters can thus safely skip testing their corresponding default value.
1740# Note that this validates solving on the left.
1741for dtype in jtu.dtypes.all_inexact:
1742  for unit_diagonal in [False, True]:
1743    _make_triangular_solve_harness(
1744      "dtypes", dtype=dtype, unit_diagonal=unit_diagonal)
1745
1746# Validate shapes when solving on the right
1747for ab_shapes in [
1748  ((4, 4), (1, 4)),  # standard
1749  ((2, 8, 8), (2, 10, 8)),  # batched
1750]:
1751  _make_triangular_solve_harness(
1752    "shapes_right", ab_shapes=ab_shapes, left_side=False)
1753# Validate transformations of a complex matrix
1754for lower in [False, True]:
1755  for transpose_a in [False, True]:
1756    for conjugate_a in [False, True]:
1757      _make_triangular_solve_harness(
1758        "complex_transformations",
1759        dtype=np.complex64,
1760        lower=lower,
1761        transpose_a=transpose_a,
1762        conjugate_a=conjugate_a)
1763
1764# Validate transformations of a real matrix
1765for lower in [False, True]:
1766  for transpose_a in [False, True]:
1767    # conjugate_a is irrelevant for real dtypes, and is thus omitted
1768    _make_triangular_solve_harness(
1769      "real_transformations",
1770      dtype=np.float32,
1771      lower=lower,
1772      transpose_a=transpose_a)
1773
1774
1775def _make_linear_solve_harnesses():
1776  def linear_solve(a, b, solve, transpose_solve=None, symmetric=False):
1777    matvec = partial(lax.dot, a, precision=lax.Precision.HIGHEST)
1778    return lax.custom_linear_solve(matvec, b, solve, transpose_solve, symmetric)
1779
1780  def explicit_jacobian_solve(matvec, b):
1781    return lax.stop_gradient(jnp.linalg.solve(jax.api.jacobian(matvec)(b), b))
1782
1783  def _make_harness(name,
1784                    *,
1785                    shape=(4, 4),
1786                    dtype=np.float32,
1787                    symmetric=False,
1788                    solvers=(explicit_jacobian_solve, explicit_jacobian_solve)):
1789    solve, transpose_solve = solvers
1790    transpose_solve_name = transpose_solve.__name__ if transpose_solve else None
1791
1792    def _make_first_argument(rng):
1793      a = jtu.rand_default(rng)(shape, dtype)
1794      if symmetric:
1795        a = a + a.T
1796      return a
1797
1798    define(
1799      lax.linear_solve_p,
1800      f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_b={jtu.format_shape_dtype_string(shape[:-1], dtype)}_solve={solve.__name__}_transposesolve={transpose_solve_name}_symmetric={symmetric}",
1801      linear_solve, [
1802        CustomArg(_make_first_argument),
1803        RandArg(shape[:-1], dtype),
1804        StaticArg(solve),
1805        StaticArg(transpose_solve),
1806        StaticArg(symmetric)
1807      ],
1808      shape=shape,
1809      dtype=dtype,
1810      solve=solve,
1811      transpose_solve=transpose_solve,
1812      symmetric=symmetric)
1813
1814  for dtype in jtu.dtypes.all_floating:
1815    if not dtype in [np.float16, dtypes.bfloat16]:
1816      _make_harness("dtypes", dtype=dtype)
1817  # Validate symmetricity
1818  _make_harness("symmetric", symmetric=True)
1819  # Validate removing transpose_solve
1820  _make_harness("transpose_solve", solvers=(explicit_jacobian_solve, None))
1821
1822
1823_make_linear_solve_harnesses()
1824
1825
1826def _make_slice_harness(name,
1827                        shape=(3,),
1828                        start_indices=(1,),
1829                        limit_indices=(2,),
1830                        strides=None,
1831                        dtype=np.float32):
1832  define(
1833    lax.slice_p,
1834    f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}",
1835    # type: ignore
1836    lax.slice,
1837    [
1838      RandArg(shape, dtype),  # type: ignore
1839      StaticArg(start_indices),  # type: ignore
1840      StaticArg(limit_indices),  # type: ignore
1841      StaticArg(strides)
1842    ],  # type: ignore
1843    dtype=dtype,
1844    shape=shape,  # type: ignore
1845    start_indices=start_indices,  # type: ignore
1846    limit_indices=limit_indices)  # type: ignore
1847
1848
1849# Test first all dtypes
1850for dtype in jtu.dtypes.all:
1851  _make_slice_harness("dtypes", dtype=dtype)
1852# Now test many shapes
1853for shape, start_indices, limit_indices, strides in [
1854  ((3,), (1,), (2,), None),
1855  ((7,), (4,), (7,), None),
1856  ((5,), (1,), (5,), (2,)),
1857  ((8,), (1,), (6,), (2,)),
1858  ((5, 3), (1, 1), (3, 2), None),
1859  ((5, 3), (1, 1), (3, 1), None),
1860  ((7, 5, 3), (4, 0, 1), (7, 1, 3), None),
1861  ((5, 3), (1, 1), (2, 1), (1, 1)),
1862  ((5, 3), (1, 1), (5, 3), (2, 1)),
1863]:
1864  _make_slice_harness(
1865    "shapes",
1866    shape=shape,
1867    start_indices=start_indices,
1868    limit_indices=limit_indices,
1869    strides=strides)
1870
1871
1872def _make_complex_harness(name, *, shapes=((3, 4), (3, 4)), dtype=np.float32):
1873  define(
1874    lax.complex_p,
1875    f"{name}_lhs={jtu.format_shape_dtype_string(shapes[0], dtype)}_rhs={jtu.format_shape_dtype_string(shapes[1], dtype)}",
1876    lax.complex_p.bind,
1877    [RandArg(shapes[0], dtype),
1878     RandArg(shapes[1], dtype)],
1879    shapes=shapes,
1880    dtype=dtype)
1881
1882
1883for dtype in jtu.dtypes.floating:
1884  _make_complex_harness("dtypes", dtype=dtype)
1885
1886# Validate broadcasting
1887for shapes in [
1888  ((3, 2), (3, 1)),  # broadcast imaginary part
1889  ((3, 1), (3, 2)),  # broadcast real part
1890]:
1891  _make_complex_harness("broadcast", shapes=shapes)
1892
1893
1894def _make_conj_harness(name, *, shape=(3, 4), dtype=np.float32, **kwargs):
1895  define(
1896    lax.conj_p,
1897    f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}_kwargs={kwargs}"
1898      .replace(" ", ""),
1899    lambda x: lax.conj_p.bind(x, **kwargs), [RandArg(shape, dtype)],
1900    shape=shape,
1901    dtype=dtype,
1902    **kwargs)
1903
1904
1905for dtype in jtu.dtypes.floating + jtu.dtypes.complex:
1906  _make_conj_harness("dtypes", dtype=dtype)
1907
1908# Validate kwargs
1909_make_conj_harness("kwargs", _input_dtype=np.float32)
1910
1911
1912def _make_real_imag_harness(prim, name, *, shape=(2, 3), dtype=np.float32):
1913  define(
1914    prim,
1915    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}",
1916    prim.bind, [RandArg(shape, dtype)],
1917    shape=shape,
1918    dtype=dtype,
1919    prim=prim)
1920
1921
1922for dtype in jtu.dtypes.complex:
1923  for prim in [lax.real_p, lax.imag_p]:
1924    _make_real_imag_harness(prim, "dtypes", dtype=dtype)
1925
1926
1927def _make_dynamic_slice_harness(name,
1928                                shape=(3,),
1929                                start_indices=(1,),
1930                                limit_indices=(2,),
1931                                dtype=np.float32):
1932  define(
1933    lax.dynamic_slice_p,
1934    f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_start_indices={start_indices}_limit_indices={limit_indices}",
1935    # type: ignore
1936    lax.dynamic_slice,
1937    [
1938      RandArg(shape, dtype),  # type: ignore
1939      np.array(list(start_indices)),
1940      StaticArg(tuple(map(operator.sub, limit_indices, start_indices)))
1941    ],  # type: ignore
1942    dtype=dtype,
1943    shape=shape,  # type: ignore
1944    start_indices=start_indices,  # type: ignore
1945    limit_indices=limit_indices)  # type: ignore
1946
1947
1948# Test first all dtypes
1949for dtype in jtu.dtypes.all:
1950  _make_dynamic_slice_harness("dtypes", dtype=dtype)
1951# Now test many shapes
1952for shape, start_indices, limit_indices in [
1953  ((3,), (1,), (2,)),
1954  ((7,), (4,), (7,)),
1955  ((5,), (1,), (5,)),
1956  ((8,), (1,), (6,)),
1957  ((5, 3), (1, 1), (3, 2)),
1958  ((7, 5, 3), (4, 0, 1), (7, 1, 3)),
1959  ((5, 3), (1, 1), (2, 1)),
1960  # out-of-bounds cases, allowed for dynamic_slice
1961  ((5,), (-1,), (0,)),
1962  ((5,), (-1,), (1,)),
1963  ((5,), (-4,), (-2,)),
1964  ((5,), (-5,), (-2,)),
1965  ((5,), (-6,), (-5,)),
1966  ((5,), (-10,), (-9,)),
1967  ((5,), (-100,), (-99,)),
1968  ((5,), (5,), (6,)),
1969  ((5,), (10,), (11,)),
1970  ((5,), (3,), (6,))
1971]:
1972  _make_dynamic_slice_harness(
1973    "shapes",
1974    shape=shape,
1975    start_indices=start_indices,
1976    limit_indices=limit_indices)
1977
1978
1979def _make_dynamic_update_slice_harness(name,
1980                                       shape=(3,),
1981                                       start_indices=(1,),
1982                                       dtype=np.float32,
1983                                       update_shape=(1,)):
1984  define(
1985    lax.dynamic_update_slice_p,
1986    (
1987      f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}"  # type: ignore
1988      f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
1989      f"_start_indices={start_indices}"),
1990    lax.dynamic_update_slice,
1991    [
1992      RandArg(shape, dtype),  # type: ignore
1993      RandArg(update_shape, dtype),  # type: ignore
1994      np.array(start_indices)
1995    ],  # type: ignore
1996    dtype=dtype,
1997    shape=shape,  # type: ignore
1998    start_indices=start_indices,  # type: ignore
1999    update_shape=update_shape)  # type: ignore
2000
2001
2002# Test first all dtypes
2003for dtype in jtu.dtypes.all:
2004  _make_dynamic_update_slice_harness("dtypes", dtype=dtype)
2005# Now test many shapes
2006for shape, start_indices, update_shape in [
2007  ((3,), (1,), (1,)),
2008  ((5, 3), (1, 1), (3, 1)),
2009  ((7, 5, 3), (4, 1, 0), (2, 0, 1)),
2010  ((3,), (-1,), (1,)),  # out-of-bounds
2011  ((3,), (10,), (1,)),  # out-of-bounds
2012  ((3,), (10,), (2,)),  # out-of-bounds
2013]:
2014  _make_dynamic_update_slice_harness(
2015    "shapes",
2016    shape=shape,
2017    start_indices=start_indices,
2018    update_shape=update_shape)
2019
2020
2021def _make_squeeze_harness(name, shape=(1, 2), dimensions=(0,), dtype=np.float32):
2022  define(
2023    lax.squeeze_p,
2024    f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_dimensions={dimensions}",  # type: ignore
2025    lax.squeeze,
2026    [RandArg(shape, dtype), StaticArg(dimensions)],  # type: ignore[has-type]
2027    dtype=dtype,
2028    arg_shape=shape,
2029    dimensions=dimensions)  # type: ignore[has-type]
2030
2031
2032# Test first all dtypes
2033for dtype in set(jtu.dtypes.all):
2034  _make_squeeze_harness("dtypes", dtype=dtype)
2035# Now test many shapes
2036for shape, dimensions in [
2037  ((1,), (0,)),
2038  ((1,), (-1,)),
2039  ((2, 1, 4), (1,)),
2040  ((2, 1, 4), (-2,)),
2041  ((2, 1, 3, 1), (1,)),
2042  ((2, 1, 3, 1), (1, 3)),
2043  ((2, 1, 3, 1), (3,)),
2044  ((2, 1, 3, 1), (1, -1)),
2045]:
2046  _make_squeeze_harness("shapes", shape=shape, dimensions=dimensions)
2047
2048
2049def _make_select_and_scatter_add_harness(name,
2050                                         *,
2051                                         shape=(2, 4, 6),
2052                                         dtype=np.float32,
2053                                         select_prim=lax.ge_p,
2054                                         window_dimensions=(2, 2, 2),
2055                                         window_strides=(1, 1, 1),
2056                                         padding=((0, 0), (0, 0), (0, 0)),
2057                                         nb_inactive_dims=0):
2058  ones = (1,) * len(shape)
2059  cotangent_shape = jax.api.eval_shape(
2060    lambda x: lax._select_and_gather_add(x, x, lax.ge_p, window_dimensions,
2061                                         window_strides, padding, ones, ones),
2062    np.ones(shape, dtype)).shape
2063  define(
2064    lax.select_and_scatter_add_p,
2065    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}",
2066    lax._select_and_scatter_add, [
2067      RandArg(cotangent_shape, dtype),
2068      RandArg(shape, dtype),
2069      StaticArg(select_prim),
2070      StaticArg(window_dimensions),
2071      StaticArg(window_strides),
2072      StaticArg(padding)
2073    ],
2074    jax_unimplemented=[
2075      Limitation(
2076        "works only for 2 or more inactive dimensions",
2077        devices="tpu",
2078        enabled=(nb_inactive_dims < 2))
2079    ],
2080    shape=shape,
2081    dtype=dtype,
2082    select_prim=select_prim,
2083    window_dimensions=window_dimensions,
2084    window_strides=window_strides,
2085    padding=padding)
2086
2087
2088for dtype in set(jtu.dtypes.all) - set([np.complex64, np.complex128]):
2089  _make_select_and_scatter_add_harness("dtypes", dtype=dtype)
2090
2091# Validate different reduction primitives
2092_make_select_and_scatter_add_harness("select_prim", select_prim=lax.le_p)
2093
2094# Validate padding
2095for padding in [
2096  # TODO(bchetioui): commented out the test based on
2097  # https://github.com/google/jax/issues/4690
2098  # ((1, 2), (2, 3), (3, 4)) # non-zero padding
2099  ((1, 1), (1, 1), (1, 1))  # non-zero padding
2100]:
2101  _make_select_and_scatter_add_harness("padding", padding=padding)
2102
2103# Validate window_dimensions; uneven dimensions
2104_make_select_and_scatter_add_harness(
2105  "window_dimensions", window_dimensions=(1, 2, 3))
2106
2107# Validate window_strides
2108# smaller than/same as/bigger than corresponding window dimension
2109_make_select_and_scatter_add_harness("window_strides", window_strides=(1, 2, 3))
2110
2111# Validate dtypes on TPU
2112for dtype in set(jtu.dtypes.all) - set(
2113    [np.bool_, np.complex64, np.complex128, np.int8, np.uint8]):
2114  for window_strides, window_dimensions, nb_inactive_dims in [((1, 2, 1),
2115                                                               (1, 3, 1), 2)]:
2116    _make_select_and_scatter_add_harness(
2117      "tpu_dtypes",
2118      dtype=dtype,
2119      nb_inactive_dims=nb_inactive_dims,
2120      window_strides=window_strides,
2121      window_dimensions=window_dimensions)
2122
2123
2124def _make_select_and_gather_add_harness(name,
2125                                        *,
2126                                        shape=(4, 6),
2127                                        dtype=np.float32,
2128                                        select_prim=lax.le_p,
2129                                        padding="VALID",
2130                                        window_dimensions=(2, 2),
2131                                        window_strides=(1, 1),
2132                                        base_dilation=(1, 1),
2133                                        window_dilation=(1, 1)):
2134  if isinstance(padding, str):
2135    padding = tuple(
2136      lax.padtype_to_pads(shape, window_dimensions, window_strides, padding))
2137  define(
2138    lax.select_and_gather_add_p,
2139    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
2140    lax._select_and_gather_add, [
2141      RandArg(shape, dtype),
2142      RandArg(shape, dtype),
2143      StaticArg(select_prim),
2144      StaticArg(window_dimensions),
2145      StaticArg(window_strides),
2146      StaticArg(padding),
2147      StaticArg(base_dilation),
2148      StaticArg(window_dilation)
2149    ],
2150    shape=shape,
2151    dtype=dtype,
2152    window_dimensions=window_dimensions,
2153    window_strides=window_strides,
2154    padding=padding,
2155    base_dilation=base_dilation,
2156    window_dilation=window_dilation)
2157
2158
2159for dtype in jtu.dtypes.all_floating:
2160  for select_prim in [lax.ge_p, lax.le_p]:
2161    _make_select_and_gather_add_harness("dtypes", dtype=dtype, select_prim=select_prim)
2162
2163# Validate selection primitives
2164_make_select_and_gather_add_harness("select_prim", select_prim=lax.ge_p)
2165# Validate window dimensions
2166_make_select_and_gather_add_harness(
2167  "window_dimensions", window_dimensions=(2, 3))
2168
2169# Validate window strides
2170_make_select_and_gather_add_harness("window_strides", window_strides=(2, 3))
2171# Validate padding
2172_make_select_and_gather_add_harness("padding", padding="SAME")
2173
2174# Validate dilations
2175for base_dilation, window_dilation in [
2176  ((2, 3), (1, 1)),  # base dilation, no window dilation
2177  ((1, 1), (2, 3)),  # no base dilation, window dilation
2178  ((2, 3), (3, 2))  # base dilation, window dilation
2179]:
2180  _make_select_and_gather_add_harness(
2181    "dilations", base_dilation=base_dilation, window_dilation=window_dilation)
2182
2183
2184def _make_reduce_window_harness(name,
2185                                *,
2186                                shape=(4, 6),
2187                                base_dilation=(1, 1),
2188                                computation=lax.add,
2189                                window_dimensions=(2, 2),
2190                                window_dilation=(1, 1),
2191                                init_value=0,
2192                                window_strides=(1, 1),
2193                                dtype=np.float32,
2194                                padding=((0, 0), (0, 0))):
2195  prim_name = f"reduce_window_{computation.__name__}"
2196  limitations = []
2197  if computation.__name__ in ("max", "mul", "min"):
2198    limitations.append(
2199      Limitation("unimplemented in XLA", devices="tpu", dtypes=np.complex64))
2200  define(
2201    prim_name,
2202    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_initvalue={init_value}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}"
2203      .replace(" ", ""),
2204    lax.reduce_window,
2205    [
2206      RandArg(shape, dtype),
2207      # Must be static to trigger the picking of the reducers
2208      StaticArg(np.array(init_value, dtype=dtype)),
2209      StaticArg(computation),
2210      StaticArg(window_dimensions),
2211      StaticArg(window_strides),
2212      StaticArg(padding),
2213      StaticArg(base_dilation),
2214      StaticArg(window_dilation)
2215    ],
2216    jax_unimplemented=limitations,
2217    shape=shape,
2218    dtype=dtype,
2219    init_value=np.array(init_value, dtype=dtype),
2220    computation=computation,
2221    window_dimensions=window_dimensions,
2222    window_strides=window_strides,
2223    padding=padding,
2224    base_dilation=base_dilation,
2225    window_dilation=window_dilation)
2226
2227
2228# Validate dtypes across all execution paths
2229# This first harness runs the tests for all dtypes using default values for
2230# the other parameters (outside of computation and its init_value), through
2231# several execution paths. Variations of other parameters can thus safely
2232# skip testing their corresponding default value.
2233for dtype in jtu.dtypes.all:
2234  for computation, init_value in [
2235    (lax.min, _get_min_identity(dtype)),  # path through reduce_window_min
2236    (lax.max, _get_max_identity(dtype)),  # path through TF reduce_window_max
2237    (lax.max, 1),  # path through reduce_window
2238    (lax.add, 0),  # path_through reduce_window_sum
2239    (lax.add, 1),  # path through reduce_window
2240    (lax.mul, 0),  # path through reduce_window
2241    (lax.mul, 1),  # path through reduce_window
2242    (lax.mul, 2),  # path through reduce_window
2243  ]:
2244    if dtype == np.bool_ and (computation in [lax.add, lax.mul]):
2245      continue
2246    _make_reduce_window_harness(
2247      "dtypes", dtype=dtype, computation=computation, init_value=init_value)
2248# Validate window_dimensions
2249_make_reduce_window_harness("window_dimensions", window_dimensions=(1, 1))
2250# Validate window_strides
2251_make_reduce_window_harness("window_strides", window_strides=(1, 2))
2252# Validate padding
2253_make_reduce_window_harness("padding", padding=((1, 2), (0, 3)))
2254# Validate base_dilation
2255_make_reduce_window_harness("base_dilation", base_dilation=(1, 2))
2256# Validate window_dilation
2257_make_reduce_window_harness("window_dilation", window_dilation=(1, 2))
2258# Validate squeezing behavior and dimensions in tf.nn.max_pool
2259for shape, window_dimensions in [
2260  ((2,), (2,)),  # 1 spatial dimension, left and right squeeze
2261  ((2, 1), (2, 1)),  # 1 spatial dimension, left squeeze
2262  ((1, 2), (1, 2)),  # 1 spatial dimension, right squeeze
2263  ((1, 2, 1), (1, 2, 1)),  # 1 spatial dimension no squeeze
2264  ((2, 4), (2, 2)),  # 2 spatial dimensions, left and right squeeze
2265  ((2, 4, 3), (2, 2, 2)),  # 3 spatial dimensions, left and right squeeze
2266  ((1, 4, 3, 2, 1), (1, 2, 2, 2, 1))  # 3 spatial dimensions, no squeeze
2267]:
2268  _make_reduce_window_harness(
2269    "squeeze_dim",
2270    computation=lax.max,
2271    shape=shape,
2272    dtype=np.float32,
2273    init_value=-np.inf,
2274    base_dilation=tuple([1] * len(shape)),
2275    window_dilation=tuple([1] * len(shape)),
2276    padding=tuple([(0, 0)] * len(shape)),
2277    window_strides=tuple([1] * len(shape)),
2278    window_dimensions=window_dimensions)
2279
2280
2281def _make_reducer_harness(prim,
2282                          name,
2283                          *,
2284                          shape=(2, 3),
2285                          axes=(0,),
2286                          dtype=np.int32):
2287  define(
2288    prim,
2289    f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}",
2290    lambda arg: prim.bind(arg, axes=axes), [RandArg(shape, dtype)],
2291    prim=prim,
2292    shape=shape,
2293    dtype=dtype,
2294    axes=axes)
2295
2296
2297for prim in [
2298  lax.reduce_sum_p, lax.reduce_prod_p, lax.reduce_max_p, lax.reduce_min_p,
2299  lax.reduce_or_p, lax.reduce_and_p
2300]:
2301  for dtype in {
2302    lax.reduce_sum_p: set(jtu.dtypes.all) - set(jtu.dtypes.boolean),
2303    lax.reduce_prod_p: set(jtu.dtypes.all) - set(jtu.dtypes.boolean),
2304    lax.reduce_max_p: jtu.dtypes.all,
2305    lax.reduce_min_p: jtu.dtypes.all,
2306    lax.reduce_or_p: jtu.dtypes.boolean,
2307    lax.reduce_and_p: jtu.dtypes.boolean
2308  }[prim]:
2309    _make_reducer_harness(prim, "dtypes", dtype=dtype)
2310
2311for dtype in (np.float32, np.float64):
2312  for shape in ((), (3,)):
2313    define(
2314      "random_gamma",
2315      f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
2316      jax.jit(jax.random.gamma),
2317      [np.array([42, 43], dtype=np.uint32),
2318       RandArg(shape, dtype)],
2319      dtype=dtype)
2320
2321for key_i, key in enumerate([
2322  np.array([0, 0], dtype=np.uint32),
2323  np.array([42, 43], dtype=np.uint32),
2324  np.array([0xFFFFFFFF, 0], dtype=np.uint32),
2325  np.array([0, 0xFFFFFFFF], dtype=np.uint32),
2326  np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)
2327]):
2328  define(
2329    "random_split",
2330    f"_i={key_i}",
2331    jax.jit(lambda key: jax.random.split(key, 2)), [key],
2332    dtype=key.dtype)
2333
2334
2335def _make_clamp_harness(name,
2336                        *,
2337                        min_shape=(),
2338                        operand_shape=(2, 3),
2339                        max_shape=(),
2340                        dtype=np.float32,
2341                        min_max=None):
2342  min_arr, max_arr = (
2343    min_max if min_max is not None else
2344    [RandArg(min_shape, dtype),
2345     RandArg(max_shape, dtype)])
2346  define(
2347    lax.clamp_p,
2348    f"{name}_min={jtu.format_shape_dtype_string(min_arr.shape, min_arr.dtype)}_operand={jtu.format_shape_dtype_string(operand_shape, dtype)}_max={jtu.format_shape_dtype_string(max_arr.shape, max_arr.dtype)}",
2349    lax.clamp, [min_arr, RandArg(operand_shape, dtype), max_arr],
2350    min_shape=min_arr.shape,
2351    operand_shape=operand_shape,
2352    max_shape=max_arr.shape,
2353    dtype=dtype)
2354
2355
2356for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex + [np.bool_]):
2357  _make_clamp_harness("dtypes", dtype=dtype)
2358
2359# Validate broadcasting of min/max arrays
2360for min_shape, operand_shape, max_shape in [
2361  ((), (2, 3), (2, 3)),  # no broadcasting for max
2362  ((2, 3), (2, 3), ()),  # no broadcasting for min
2363  ((2, 3), (2, 3), (2, 3)),  # no broadcasting
2364]:
2365  _make_clamp_harness(
2366    "broadcasting",
2367    min_shape=min_shape,
2368    max_shape=max_shape,
2369    operand_shape=operand_shape)
2370
2371# Validate clamping when minval > maxval, and when minval < maxval
2372for is_ordered, min_arr, max_arr in [
2373  (False, np.array(4., dtype=np.float32), np.array(1., dtype=np.float32)),
2374  (True, np.array(1., dtype=np.float32), np.array(4., dtype=np.float32))
2375]:
2376  _make_clamp_harness(
2377    f"order={is_ordered}", min_max=(min_arr, max_arr), dtype=np.float32)
2378
2379
2380def _make_dot_general_harness(name,
2381                              *,
2382                              lhs_shape=(3, 4),
2383                              rhs_shape=(4, 2),
2384                              dtype=np.float32,
2385                              precision=None,
2386                              dimension_numbers=(((1,), (0,)), ((), ()))):
2387  define(
2388    lax.dot_general_p,
2389    f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}_precision={precision}"
2390      .replace(" ", ""),
2391    lax.dot_general, [
2392      RandArg(lhs_shape, dtype),
2393      RandArg(rhs_shape, dtype),
2394      StaticArg(dimension_numbers),
2395      StaticArg(precision)
2396    ],
2397    dtype=dtype,
2398    lhs_shape=lhs_shape,
2399    rhs_shape=rhs_shape,
2400    dimension_numbers=dimension_numbers,
2401    precision=precision)
2402
2403
2404# There are two execution paths in the conversion of dot_general. The main path
2405# uses tf.einsum, while special cases use tf.linalg.matmul. For that reason,
2406# the below tests are designed to perform the same checks on both execution
2407# paths.
2408# Validate dtypes and precision
2409# This first harness runs the tests for all dtypes and precisions using
2410# default values for all the other parameters. Variations of other parameters
2411# can thus safely skip testing their corresponding default value.
2412
2413for dtype in jtu.dtypes.all:
2414  for precision in [
2415    None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST
2416  ]:
2417    for lhs_shape, rhs_shape, dimension_numbers in [
2418      ((3, 4), (4, 2), (((1,), (0,)), ((), ()))),
2419      ((1, 3, 4), (1, 4, 3), (((2, 1), (1, 2)), ((0,), (0,))))
2420    ]:
2421      _make_dot_general_harness(
2422        "dtypes_and_precision",
2423        precision=precision,
2424        lhs_shape=lhs_shape,
2425        rhs_shape=rhs_shape,
2426        dimension_numbers=dimension_numbers,
2427        dtype=dtype)
2428
2429# Validate batch dimensions
2430for lhs_shape, rhs_shape, dimension_numbers in [
2431  # Unique pattern that can go through tf.linalg.matmul
2432  ((4, 4, 3, 3, 4), (4, 4, 3, 4, 2), (((4,), (3,)), ((0, 1, 2), (0, 1, 2)))),
2433  # Main path with out of order batch dimensions
2434  ((8, 4, 3, 3, 4), (4, 8, 3, 4, 2), (((4, 3), (3, 2)), ((0, 1), (1, 0))))
2435]:
2436  _make_dot_general_harness(
2437    "batch_dimensions",
2438    lhs_shape=lhs_shape,
2439    rhs_shape=rhs_shape,
2440    dimension_numbers=dimension_numbers)
2441
2442# Validate squeezing behavior for matmul path
2443for lhs_shape, rhs_shape, dimension_numbers in [
2444  ((4,), (4, 4), (((0,), (0,)), ((), ()))),  # (1, 4) -> (4,)
2445  ((4, 4), (4,), (((1,), (0,)), ((), ()))),  # (4, 1) -> (4,)
2446  ((4,), (4,), (((0,), (0,)), ((), ()))),  # (1, 1) -> ()
2447]:
2448  _make_dot_general_harness(
2449    "squeeze",
2450    lhs_shape=lhs_shape,
2451    rhs_shape=rhs_shape,
2452    dimension_numbers=dimension_numbers)
2453
2454
2455def _make_concatenate_harness(name,
2456                              *,
2457                              shapes=[(2, 3), (2, 3)],
2458                              dimension=0,
2459                              dtype=np.float32):
2460  shapes_str = "_".join(jtu.format_shape_dtype_string(s, dtype) for s in shapes)
2461  define(
2462    lax.concatenate_p,
2463    f"{name}_shapes={shapes_str}_dimension={dimension}",
2464    lambda *args: lax.concatenate_p.bind(*args, dimension=dimension),
2465    [RandArg(shape, dtype) for shape in shapes],
2466    shapes=shapes,
2467    dtype=dtype,
2468    dimension=dimension)
2469
2470
2471for dtype in jtu.dtypes.all:
2472  _make_concatenate_harness("dtypes", dtype=dtype)
2473
2474# Validate dimension; non-major axis
2475_make_concatenate_harness("dimension", dimension=1)
2476
2477# Validate > 2 operands
2478for shapes in [
2479  [(2, 3, 4), (3, 3, 4), (4, 3, 4)],  # 3 operands
2480]:
2481  _make_concatenate_harness("nb_operands", shapes=shapes)
2482
2483
2484def _make_conv_harness(name,
2485                       *,
2486                       lhs_shape=(2, 3, 9, 10),
2487                       rhs_shape=(3, 3, 4, 5),
2488                       dtype=np.float32,
2489                       window_strides=(1, 1),
2490                       precision=None,
2491                       padding=((0, 0), (0, 0)),
2492                       lhs_dilation=(1, 1),
2493                       rhs_dilation=(1, 1),
2494                       feature_group_count=1,
2495                       dimension_numbers=("NCHW", "OIHW", "NCHW"),
2496                       batch_group_count=1,
2497                       enable_xla=True):
2498  define(
2499    lax.conv_general_dilated_p,
2500    f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_windowstrides={window_strides}_padding={padding}_lhsdilation={lhs_dilation}_rhsdilation={rhs_dilation}_dimensionnumbers={dimension_numbers}_featuregroupcount={feature_group_count}_batchgroupcount={batch_group_count}_precision={precision}_enablexla={enable_xla}"
2501      .replace(" ", ""),
2502    lax.conv_general_dilated, [
2503      RandArg(lhs_shape, dtype),
2504      RandArg(rhs_shape, dtype),
2505      StaticArg(window_strides),
2506      StaticArg(padding),
2507      StaticArg(lhs_dilation),
2508      StaticArg(rhs_dilation),
2509      StaticArg(dimension_numbers),
2510      StaticArg(feature_group_count),
2511      StaticArg(batch_group_count),
2512      StaticArg(precision)
2513    ],
2514    lhs_shape=lhs_shape,
2515    rhs_shape=rhs_shape,
2516    dtype=dtype,
2517    window_strides=window_strides,
2518    padding=padding,
2519    lhs_dilation=lhs_dilation,
2520    rhs_dilation=rhs_dilation,
2521    dimension_numbers=dimension_numbers,
2522    feature_group_count=feature_group_count,
2523    batch_group_count=batch_group_count,
2524    precision=precision,
2525    enable_xla=enable_xla)
2526
2527
2528# Validate dtypes and precision
2529for dtype in jtu.dtypes.all_inexact:
2530  for precision in [
2531    None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST
2532  ]:
2533    # This first harness runs the tests for all dtypes and precisions using
2534    # default values for all the other parameters. Variations of other parameters
2535    # can thus safely skip testing their corresponding default value.
2536    _make_conv_harness("dtype_precision", dtype=dtype, precision=precision)
2537# Validate variations of feature_group_count and batch_group_count
2538for batch_group_count, feature_group_count in [
2539  (1, 2),  # feature_group_count != 1
2540  (2, 1),  # batch_group_count != 1
2541]:
2542  for lhs_shape, rhs_shape in [
2543    ((2 * batch_group_count, 3 * feature_group_count, 9, 10),
2544     (3 * feature_group_count * batch_group_count, 3, 4, 5))
2545  ]:
2546    _make_conv_harness(
2547      "group_counts",
2548      lhs_shape=lhs_shape,
2549      rhs_shape=rhs_shape,
2550      feature_group_count=feature_group_count,
2551      batch_group_count=batch_group_count)
2552
2553### XXX
2554
2555# Validate variations of window_strides
2556for window_strides in [(2, 3)]:
2557  _make_conv_harness("window_strides", window_strides=window_strides)
2558
2559# Validate variations of padding
2560for padding in [
2561  ((1, 2), (0, 0)),  # padding only one spatial axis
2562  ((1, 2), (2, 1))  # padding on both spatial axes
2563]:
2564  _make_conv_harness("padding", padding=padding)
2565
2566# Validate variations of dilations
2567for lhs_dilation, rhs_dilation in [
2568  ((2, 2), (1, 1)),  # dilation only on LHS (transposed)
2569  ((1, 1), (2, 3)),  # dilation only on RHS (atrous)
2570  ((2, 3), (3, 2))  # dilation on both LHS and RHS (transposed & atrous)
2571]:
2572  _make_conv_harness(
2573    "dilations", lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)
2574
2575# Dimension numbers and corresponding permutation
2576for dimension_numbers, lhs_shape, rhs_shape in [
2577  (("NHWC", "HWIO", "NHWC"), (2, 9, 10, 3), (4, 5, 3, 3)),  # TF default
2578  (("NCHW", "HWIO", "NHWC"), (2, 3, 9, 10), (4, 5, 3, 3)),  # custom
2579]:
2580  _make_conv_harness(
2581    "dimension_numbers",
2582    lhs_shape=lhs_shape,
2583    rhs_shape=rhs_shape,
2584    dimension_numbers=dimension_numbers)
2585
2586for padding, lhs_dilation, rhs_dilation in [
2587  ("VALID", (1,), (1,)),  # no dilation with "VALID" padding
2588  ("SAME", (1,), (1,)),  # no dilation with "SAME" padding
2589  ("VALID", (1,), (2,)),  # dilation only on RHS with "VALID" padding
2590  ("SAME", (1,), (2,)),  # dilation only on RHS with "SAME" padding
2591  # TODO(bchetioui): LHS dilation with string padding can never be done using
2592  # TF convolution functions for now.
2593]:
2594  for dimension_numbers, lhs_shape, rhs_shape in [
2595    (("NWC", "WIO", "NWC"), (1, 28, 1), (3, 1, 16)),  # TF default
2596    # TODO(bchetioui): the NCW data format is not supported on CPU for TF
2597    # for now. That path is thus disabled to allow the code to use XLA instead.
2598  ]:
2599    for enable_xla in [False, True]:
2600      _make_conv_harness(
2601        "tf_conversion_path_1d",
2602        lhs_shape=lhs_shape,
2603        padding=padding,
2604        rhs_shape=rhs_shape,
2605        dimension_numbers=dimension_numbers,
2606        window_strides=(1,),
2607        lhs_dilation=lhs_dilation,
2608        rhs_dilation=rhs_dilation,
2609        enable_xla=enable_xla)
2610
2611for padding, lhs_dilation, rhs_dilation in [
2612  ("VALID", (1, 1), (1, 1)),  # no dilation with "VALID" padding
2613  ("SAME", (1, 1), (1, 1)),  # no dilation with "SAME" padding
2614  ("VALID", (1, 1), (1, 2)),  # dilation only on RHS with "VALID" padding
2615  ("SAME", (1, 1), (1, 2)),  # dilation only on RHS with "SAME" padding
2616  # TODO(bchetioui): LHS dilation with string padding can never be done using
2617  # TF convolution functions for now.
2618]:
2619  for dimension_numbers, lhs_shape, rhs_shape in [
2620    (("NHWC", "HWIO", "NHWC"), (1, 28, 28, 1), (3, 3, 1, 16)),  # TF default
2621    # TODO(bchetioui): the NCHW data format is not supported on CPU for TF
2622    # for now. That path is thus disabled to allow the code to use XLA instead.
2623  ]:
2624    for enable_xla in [False, True]:
2625      _make_conv_harness(
2626        "tf_conversion_path_2d",
2627        lhs_shape=lhs_shape,
2628        padding=padding,
2629        rhs_shape=rhs_shape,
2630        dimension_numbers=dimension_numbers,
2631        window_strides=(1, 1),
2632        lhs_dilation=lhs_dilation,
2633        rhs_dilation=rhs_dilation,
2634        enable_xla=enable_xla)
2635
2636for padding, lhs_dilation, rhs_dilation in [
2637  ("VALID", (1, 1, 1), (1, 1, 1)),  # no dilation with "VALID" padding
2638  ("SAME", (1, 1, 1), (1, 1, 1)),  # no dilation with "SAME" padding
2639  ("VALID", (1, 1, 1), (1, 1,
2640                        2)),  # dilation only on RHS with "VALID" padding
2641  ("SAME", (1, 1, 1), (1, 1, 2)),  # dilation only on RHS with "SAME" padding
2642  # TODO(bchetioui): LHS dilation with string padding can never be done using
2643  # TF convolution functions for now.
2644]:
2645  for dimension_numbers, lhs_shape, rhs_shape in [
2646    # TF default
2647    (("NDHWC", "DHWIO", "NDHWC"), (1, 4, 28, 28, 1), (2, 3, 3, 1, 16)),
2648    # TODO(bchetioui): the NCDHW data format is not supported on CPU for TF
2649    # for now. That path is thus disabled to allow the code to use XLA instead.
2650  ]:
2651    for enable_xla in [False, True]:
2652      _make_conv_harness(
2653        "tf_conversion_path_3d",
2654        lhs_shape=lhs_shape,
2655        padding=padding,
2656        rhs_shape=rhs_shape,
2657        dimension_numbers=dimension_numbers,
2658        window_strides=(1, 1, 1),
2659        lhs_dilation=lhs_dilation,
2660        rhs_dilation=rhs_dilation,
2661        enable_xla=enable_xla)
2662