# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines test inputs and invocations for JAX primitives.

The idea is to list all the JAX numeric primitives, and for each primitive a
set of inputs that covers its use cases. We keep these separate from any
particular test suite so that they can be reused to build multiple kinds of
tests. For example, we can use the harnesses to check that each primitive is
compiled correctly, or that we can apply a certain transformation, e.g., `vmap`.

See the `Harness` class below for how to define a harness, describing one
use case of one primitive.

Some use cases are known to be only partially implemented in JAX, e.g.,
because of an implementation limitation. We have harnesses for those cases
too, but we filter them out. Instead of encoding this information as
conditions inside one particular test, we write it as `Limitation` objects
that can be reused in multiple tests and can also be used to generate
documentation, e.g., the report of [unsupported and
partially-implemented JAX
primitives](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md)

The limitations are used to filter out from tests the harnesses that are known
to fail. A Limitation is specific to a harness.
"""

import operator
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, NamedTuple, Sequence, Tuple, Union

from functools import partial

from absl import testing
import jax
from jax import config
from jax import dtypes
from jax import ad_util
from jax import test_util as jtu
from jax import lax
from jax import numpy as jnp
from jax._src.lax import control_flow as lax_control_flow
from jax.interpreters import xla

from jax.lib import xla_client

import numpy as np

FLAGS = config.FLAGS

Rng = Any  # A random number generator
DType = Any


class RandArg(NamedTuple):
  """Descriptor for a randomly generated argument.

  See the description of `Harness`.
  """
  shape: Tuple[int, ...]
  dtype: DType


class StaticArg(NamedTuple):
  """Descriptor for a static argument.

  See the description of `Harness`.
  """
  value: Any


class CustomArg(NamedTuple):
  """Descriptor for a dynamic argument generated from a PRNG factory.

  See the description of `Harness`.
  """
  make: Callable[[Rng], Any]  # Called with a Rng to make a tensor


class Harness:
  """Specifies inputs and callable for a primitive.

  See the module docstring for an introduction to harnesses.

  A harness is conceptually a callable and a list of arguments that together
  exercise a use case. The harness can optionally have additional parameters
  that can be used by the test.

  The arguments are specified through argument descriptors. An argument
  descriptor can be:

    * a numeric value or ndarray, or
    * an instance of ``RandArg(shape, dtype)``, to be used with a PRNG to
      generate a random tensor of the given shape and dtype, or
    * an instance of ``CustomArg(fun)``, to be used with a PRNG, or
    * an instance of ``StaticArg(value)``. These are values that specialize
      the callable but are not exposed as external arguments.

  For example, a harness for ``lax.take(arr, indices, axis=None)`` may want
  to expose as external (dynamic) arguments the array and the indices, and
  to keep the axis as a static argument (technically specializing the `take`
  to an axis):

    Harness(lax.gather_p,
            f"take_axis={axis}",
            lax.take,
            [RandArg((2, 4), np.float32), np.array([-1, 0, 1]),
             StaticArg(axis)],
            axis=axis)

  Each harness can have a list of Limitations that describe the cases when
  the harness may not be fully implemented.
  """
  # The group name; most often it is the primitive name.
  group_name: str
  # Descriptive name of the harness, used as a testcase_name. Unique in a group.
  name: str
  # The function taking all arguments (static and dynamic).
  fun: Callable
  # Describes how to construct arguments, see the class docstring.
  arg_descriptors: Sequence[Union[RandArg, StaticArg, CustomArg, Any]]
  dtype: DType
  # A set of limitations describing the cases that are not supported or
  # partially implemented in JAX for this harness.
  jax_unimplemented: Sequence["Limitation"]
  rng_factory: Callable
  # Carries some arbitrary parameters that the test can access.
  params: Dict[str, Any]

  def __init__(self,
               group_name,
               name,
               fun,
               arg_descriptors,
               *,
               dtype,
               rng_factory=jtu.rand_default,
               jax_unimplemented: Sequence["Limitation"] = (),
               **params):
    """See the class docstring."""
    self.group_name = group_name
    self.name = name
    self.fun = fun  # type: ignore[assignment]
    self.arg_descriptors = arg_descriptors
    self.rng_factory = rng_factory  # type: ignore[assignment]
    self.jax_unimplemented = jax_unimplemented
    self.dtype = dtype
    self.params = params

  def __str__(self):
    return self.fullname

  @property
  def fullname(self):
    return self.name if self.group_name is None else f"{self.group_name}_{self.name}"

  def _arg_maker(self, arg_descriptor, rng: Rng):
    if isinstance(arg_descriptor, StaticArg):
      return arg_descriptor.value
    if isinstance(arg_descriptor, RandArg):
      return self.rng_factory(rng)(arg_descriptor.shape, arg_descriptor.dtype)
    if isinstance(arg_descriptor, CustomArg):
      return arg_descriptor.make(rng)

    return arg_descriptor

  def args_maker(self, rng: Rng) -> Sequence:
    """All-argument maker, including the static ones."""
    return [self._arg_maker(ad, rng) for ad in self.arg_descriptors]

  def dyn_args_maker(self, rng: Rng) -> Sequence:
    """A dynamic-argument maker, for use with `dyn_fun`."""
    return [
        self._arg_maker(ad, rng)
        for ad in self.arg_descriptors
        if not isinstance(ad, StaticArg)
    ]

  def dyn_fun(self, *dyn_args):
    """Invokes `fun` given just the dynamic arguments."""
    all_args = self._args_from_dynargs(dyn_args)
    return self.fun(*all_args)

  def _args_from_dynargs(self, dyn_args: Sequence) -> Sequence:
    """All arguments, including the static ones."""
    next_dynamic_argnum = 0
    all_args = []
    for ad in self.arg_descriptors:
      if isinstance(ad, StaticArg):
        all_args.append(ad.value)
      else:
        all_args.append(dyn_args[next_dynamic_argnum])
        next_dynamic_argnum += 1
    return all_args

  def filter(self,
             device_under_test: str,
             *,
             include_jax_unimpl: bool = False,
             one_containing: Optional[str] = None) -> bool:
    if not include_jax_unimpl:
      if any([
          device_under_test in l.devices
          for l in self.jax_unimplemented
          if l.filter(device=device_under_test, dtype=self.dtype)
      ]):
        return False

    if one_containing is not None and one_containing not in self.fullname:
      return False
    return True
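
# Illustrative sketch of how a test can consume a single harness: materialize
# the dynamic arguments with a PRNG, then call the function through `dyn_fun`,
# which fills the static arguments back in. (`rng` and `check_result` are
# hypothetical placeholders, not part of this module.)
#
#   dyn_args = harness.dyn_args_maker(rng)  # RandArg/CustomArg/ndarray only
#   result = harness.dyn_fun(*dyn_args)     # StaticArgs supplied internally
#   check_result(harness, result)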


def dtypes_to_str(dtype_list: Sequence[DType], empty_means_all=False) -> str:
  """User-friendly description of a set of dtypes."""
  if not dtype_list and empty_means_all:
    return "all"

  names = set([np.dtype(dt).name for dt in dtype_list])
  signed = {"int8", "int16", "int32", "int64"}
  if all([t in names for t in signed]):
    names = (names - signed) | {"signed"}
  unsigned = {"uint8", "uint16", "uint32", "uint64"}
  if all([t in names for t in unsigned]):
    names = (names - unsigned) | {"unsigned"}
  integer = {"signed", "unsigned"}
  if all([t in names for t in integer]):
    names = (names - integer) | {"integer"}

  floating = {"bfloat16", "float16", "float32", "float64"}
  if all([t in names for t in floating]):
    names = (names - floating) | {"floating"}

  complex = {"complex64", "complex128"}
  if all([t in names for t in complex]):
    names = (names - complex) | {"complex"}

  inexact = {"floating", "complex"}
  if all([t in names for t in inexact]):
    names = (names - inexact) | {"inexact"}

  all_types = {"integer", "inexact", "bool"}
  if all([t in names for t in all_types]):
    names = (names - all_types) | {"all"}

  return ", ".join(sorted(list(names)))


##### All harnesses in this file.
all_harnesses: List[Harness] = []


def define(
    group_name,  # Should be the primitive name, as much as possible
    name,
    fun,
    arg_descriptors,
    *,
    dtype,
    rng_factory=jtu.rand_default,
    jax_unimplemented: Sequence["Limitation"] = (),
    **params):
  """Defines a harness and stores it in `all_harnesses`. See `Harness`."""
  group_name = str(group_name)
  h = Harness(group_name, name,
              fun,
              arg_descriptors,
              rng_factory=rng_factory,
              jax_unimplemented=jax_unimplemented,
              dtype=dtype,
              **params)
  all_harnesses.append(h)


class Limitation:
  """Encodes conditions under which a harness is limited, e.g., not runnable in JAX.

  See the module docstring for an introduction to harnesses and limitations.
  """

  def __init__(
      self,
      description: str,
      *,
      enabled: bool = True,
      devices: Sequence[str] = ("cpu", "gpu", "tpu"),
      dtypes: Sequence[DType] = (),
      skip_run: bool = False,
  ):
    """See the class docstring.

    Args:
      description: text to augment the harness group name with the description
        of the limitation. Used for reports.
      enabled: whether this limitation is enabled for the harness in which
        it appears. This is only used during testing to know whether to ignore
        harness errors. Use this sparingly; prefer `devices` and `dtypes` for
        enabling conditions that are included in reports.
      devices: the list of device types for which this limitation applies.
        Used for filtering during harness execution, and for reports.
      dtypes: the list of dtypes for which this limitation applies. Used for
        filtering during harness execution, and for reports.
    """
    assert isinstance(description, str), f"{description}"
    self.description = description
    self.skip_run = skip_run
    if isinstance(devices, str):
      devices = (devices,)
    else:
      devices = tuple(devices)
    self.devices = devices
    if not isinstance(dtypes, Iterable):
      dtypes = (dtypes,)
    else:
      dtypes = tuple(dtypes)
    self.dtypes = dtypes
    self.enabled = enabled  # Does it apply to the current harness?

  def __str__(self):
    return (f"\"{self.description}\" devices={self.devices} "
            f"dtypes={[np.dtype(dt).name for dt in self.dtypes]}" +
            (" (skip_run) " if self.skip_run else ""))

  __repr__ = __str__

  def filter(self,
             device: Optional[str] = None,
             dtype: Optional[DType] = None) -> bool:
    """Checks whether this limitation is enabled for the given dtype and device."""
    return (self.enabled and
            (not self.dtypes or dtype is None or dtype in self.dtypes) and
            (device is None or device in self.devices))
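
# Illustrative example of defining a harness with a limitation. The harness
# below is hypothetical (it mirrors the pattern used throughout this file) and
# is commented out so as not to add an extra entry to `all_harnesses`:
#
#   define(
#       lax.exp_p, "example_shape=(3,4)_float32",
#       lax.exp, [RandArg((3, 4), np.float32)],
#       jax_unimplemented=[
#           Limitation("unimplemented", devices="tpu", dtypes=[np.float16])
#       ],
#       dtype=np.float32)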


def parameterized(harnesses: Iterable[Harness],
                  *,
                  one_containing: Optional[str] = None,
                  include_jax_unimpl: bool = False):
  """Decorator for tests.

  The tests receive a `harness` argument.

  The `JAX_TEST_HARNESS_ONE_CONTAINING` environment variable is useful for
  debugging: if set, only the one harness whose name contains the given string
  is picked. The whole set of parameterized tests is then reduced to one test,
  whose name is not decorated, to make it easier to pick from an IDE for
  running.
  """
  one_containing = one_containing or os.environ.get(
      "JAX_TEST_HARNESS_ONE_CONTAINING")
  cases = tuple(
      # Change the testcase name to include the harness name.
      dict(
          testcase_name=harness.fullname if one_containing is None else "",
          harness=harness) for harness in harnesses if harness.filter(
              jtu.device_under_test(),
              one_containing=one_containing,
              include_jax_unimpl=include_jax_unimpl))
  if one_containing is not None:
    if not cases:
      raise ValueError(
          f"Cannot find test case with name containing {one_containing}. "
          "Names are:\n" +
          "\n".join([harness.fullname for harness in harnesses]))
    cases = cases[0:1]
  if not cases:
    # We filtered out all the harnesses.
    return jtu.skip_on_devices(jtu.device_under_test())
  return testing.parameterized.named_parameters(*cases)
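
# Illustrative sketch of a test that consumes the harnesses (the test class
# and `run_one_harness` are hypothetical, not part of this module):
#
#   class PrimitiveTest(jtu.JaxTestCase):
#
#     @parameterized(all_harnesses, include_jax_unimpl=False)
#     def test_prim(self, harness: Harness):
#       args = harness.dyn_args_maker(self.rng())
#       run_one_harness(harness.dyn_fun, args)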


###############################################################################
### Harness definitions                                                    ###
###############################################################################


def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
  define(
      str(prim),
      f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
      prim.bind, [RandArg(shape, dtype)],
      prim=prim,
      dtype=dtype,
      shape=shape)


for dtype in (set(jtu.dtypes.all) -
              set(jtu.dtypes.all_unsigned + jtu.dtypes.boolean)):
  _make_unary_elementwise_harness(prim=lax.abs_p, dtype=dtype)

for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
  _make_unary_elementwise_harness(prim=lax.acosh_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.asinh_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.atanh_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype)

for dtype in jtu.dtypes.all_floating:
  _make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.bessel_i1e_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.ceil_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.erf_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.erf_inv_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.erfc_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.floor_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.is_finite_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.lgamma_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.digamma_p, dtype=dtype)

for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
  _make_unary_elementwise_harness(prim=lax.neg_p, dtype=dtype)
  _make_unary_elementwise_harness(prim=lax.sign_p, dtype=dtype)


def _make_round_harness(name,
                        *,
                        shape=(100, 100),
                        dtype=np.float32,
                        rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO,
                        operand=None):
  operand = operand if operand is not None else RandArg(shape, dtype)
  define(
      "round",
      f"{name}_shape={jtu.format_shape_dtype_string(operand.shape, operand.dtype)}_roundingmethod={rounding_method}",
      lax.round, [operand, StaticArg(rounding_method)],
      dtype=dtype,
      operand=operand,
      rounding_method=rounding_method)


# Validate dtypes
for dtype in jtu.dtypes.all_floating:
  _make_round_harness("dtypes", dtype=dtype)

for rounding_method in [
    lax.RoundingMethod.AWAY_FROM_ZERO, lax.RoundingMethod.TO_NEAREST_EVEN
]:
  operand = np.array([[0.5, 1.5, 2.5], [-0.5, -1.5, -2.5]], dtype=np.float32)
  _make_round_harness(
      "rounding_methods", operand=operand, rounding_method=rounding_method)

# Validate edge cases
for name, operand in [
    # Checks that https://github.com/google/jax/issues/4952 is resolved
    ("round_away_from_0",
     np.array([[0.5, 1.5, 2.5], [-0.5, -1.5, -2.5]], dtype=np.float32)),
]:
  _make_round_harness(f"edge_case_{name}", operand=operand)


def _make_convert_element_type_harness(name,
                                       *,
                                       shape=(100, 100),
                                       dtype=np.float32,
                                       new_dtype=np.float32):
  define(
      "convert_element_type",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_olddtype={jtu.dtype_str(dtype)}_newdtype={jtu.dtype_str(new_dtype)}",
      lambda arg: (lax.convert_element_type_p.bind(arg, new_dtype=new_dtype)),
      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      new_dtype=new_dtype)


for old_dtype in jtu.dtypes.all:
  # TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating
  # point numbers and new_dtype is an unsigned integer. See issue
  # https://github.com/google/jax/issues/5082 for details.
  for new_dtype in (jtu.dtypes.all
                    if not (dtypes.issubdtype(old_dtype, np.floating) or
                            dtypes.issubdtype(old_dtype, np.complexfloating))
                    else set(jtu.dtypes.all) - set(jtu.dtypes.all_unsigned)):
    _make_convert_element_type_harness(
        "dtypes_to_dtypes", dtype=old_dtype, new_dtype=new_dtype)


def _make_integer_pow_harness(name, *, shape=(20, 30), dtype=np.int32, y=3):
  define(
      "integer_pow",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_y={y}",
      lax.integer_pow,
      [RandArg(shape, dtype), StaticArg(y)],
      shape=shape,
      dtype=dtype,
      y=y)


for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
  # Validate dtypes and y values
  _make_integer_pow_harness("dtypes", dtype=dtype)
  # Validate overflow behavior by dtype
  _make_integer_pow_harness("overflow", y=1000, dtype=dtype)

for dtype in jtu.dtypes.all_inexact:
  # Validate negative y by dtype
  _make_integer_pow_harness("negative_exp", y=-1000, dtype=dtype)


def _make_pow_harness(name,
                      *,
                      shapes=((20, 30), (20, 30)),
                      dtype=np.float32,
                      lhs=None,
                      rhs=None):
  lhs = RandArg(shapes[0], dtype) if lhs is None else lhs
  rhs = RandArg(shapes[1], dtype) if rhs is None else rhs
  define(
      "pow",
      f"{name}_lhs={jtu.format_shape_dtype_string(lhs.shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs.shape, dtype)}",
      lax.pow, [lhs, rhs],
      lhs=lhs,
      rhs=rhs,
      dtype=dtype)


for dtype in jtu.dtypes.all_inexact:
  # Validate dtypes
  _make_pow_harness("dtypes", dtype=dtype)

# Validate broadcasting behavior
for shapes in [
    ((), (4, 5, 6)),  # broadcasting lhs
    ((4, 5, 6), ()),  # broadcasting rhs
    ((4, 1, 6), (4, 5, 6)),  # broadcasting lhs on a specific axis
    ((4, 5, 6), (4, 1, 6)),  # broadcasting rhs on a specific axis
]:
  _make_pow_harness("broadcast", shapes=shapes)


def _make_reshape_harness(name,
                          *,
                          shape=(2, 3),
                          new_sizes=(3, 2),
                          dimensions=(0, 1),
                          dtype=np.float32):
  define(
      "reshape",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_newsizes={new_sizes}_dimensions={dimensions}",
      lax.reshape,
      [RandArg(shape, dtype),
       StaticArg(new_sizes),
       StaticArg(dimensions)],
      shape=shape,
      dtype=dtype,
      new_sizes=new_sizes,
      dimensions=dimensions)


# Validate dtypes
for dtype in jtu.dtypes.all:
  _make_reshape_harness("dtypes", dtype=dtype)

# Validate new_sizes
for shape, new_sizes, dimensions in [
    ((3, 4, 5), (3, 20), (0, 1, 2)),  # merging two dimensions
    ((3, 4, 5), (4, 15), (0, 1, 2)),  # changing leading dimension
]:
  _make_reshape_harness(
      "new_sizes", shape=shape, new_sizes=new_sizes, dimensions=dimensions)

# Validate dimension collapsing order
for shape, new_sizes, dimensions in [
    ((3, 4, 5), (3, 20), (2, 1, 0)),  # transpose shape (0, 1, 2) into (2, 1, 0)
    ((3, 4, 5), (3, 20), (2, 0, 1)),  # transpose shape (0, 1, 2) into (2, 0, 1)
]:
  _make_reshape_harness(
      "dimensions", shape=shape, new_sizes=new_sizes, dimensions=dimensions)


def _make_rev_harness(name, *, shape=(4, 5), dtype=np.float32, dimensions=(0,)):
  define(
      "rev",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_dimensions={dimensions}",
      lax.rev,
      [RandArg(shape, dtype), StaticArg(dimensions)],
      shape=shape,
      dtype=dtype,
      dimensions=dimensions)


# Validate dtypes
for dtype in jtu.dtypes.all:
  _make_rev_harness("dtypes", dtype=dtype)

# Validate dimensions
for shape, dimensions in [
    ((3, 4, 5), ()),  # empty dimensions
    ((3, 4, 5), (0, 2)),  # some dimensions
    ((3, 4, 5), (0, 1, 2)),  # all dimensions (ordered)
    ((3, 4, 5), (2, 0, 1)),  # all dimensions (unordered)
]:
  _make_rev_harness("dimensions", shape=shape, dimensions=dimensions)


def _make_device_put_harness(name,
                             *,
                             shape=(3, 4),
                             dtype=np.float32,
                             device=None):
  _device_fn = lambda: jax.devices(device)[0] if device is not None else None
  define(
      "device_put",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_device={device}",
      lambda x: xla.device_put_p.bind(x, device=_device_fn()),
      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      device=device)


# Validate dtypes
for dtype in jtu.dtypes.all:
  _make_device_put_harness("dtypes", dtype=dtype)

# Validate devices
_make_device_put_harness("devices", device="cpu")


def _make_bitcast_convert_type_harness(name,
                                       *,
                                       shape=(2, 3),
                                       dtype=np.float32,
                                       new_dtype=np.float32):
  define(
      "bitcast_convert_type",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_newdtype={np.dtype(new_dtype).name}",
      lambda x: (lax.bitcast_convert_type_p.bind(x, new_dtype=new_dtype)),
      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      new_dtype=new_dtype)


def _can_bitcast(dtype, target_dtype):

  def _to_equivalence_class(dtype):
    if dtypes.issubdtype(dtype, np.integer):
      return dtypes.iinfo(dtype).bits
    elif dtypes.issubdtype(dtype, np.floating):
      return dtypes.finfo(dtype).bits
    else:
      assert dtype == np.bool_ or dtypes.issubdtype(dtype, np.complexfloating)
      # Complex and boolean types can only be cast to themselves
      return np.dtype(dtype).name

  return _to_equivalence_class(dtype) == _to_equivalence_class(target_dtype)


# Validate dtype combinations
for dtype in jtu.dtypes.all:
  for new_dtype in filter(partial(_can_bitcast, dtype), jtu.dtypes.all):
    _make_bitcast_convert_type_harness(
        "dtypes_to_new_dtypes", dtype=dtype, new_dtype=new_dtype)
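
# A few illustrative checks of `_can_bitcast` (these are not harnesses; they
# just mirror the equivalence classes above and run as cheap sanity checks at
# import time):
assert _can_bitcast(np.float32, np.int32)  # both 32 bits wide
assert not _can_bitcast(np.float32, np.float64)  # bit widths differ
assert not _can_bitcast(np.bool_, np.uint8)  # bool only bitcasts to itself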
def _make_add_any_harness(name, *, shapes=((2,), (2,)), dtype=np.float32):
  define(
      ad_util.add_any_p,
      f"{name}_lhs={jtu.format_shape_dtype_string(shapes[0], dtype)}_rhs={jtu.format_shape_dtype_string(shapes[1], dtype)}",
      ad_util.add_jaxvals_p.bind,
      list(map(lambda s: RandArg(s, dtype), shapes)),
      dtype=dtype,
      shapes=shapes)


for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
  _make_add_any_harness("dtypes", dtype=dtype)

for rhs_dtype in jtu.dtypes.all:
  lhs_dtype = np.float32
  lhs_shape = (2, 3)
  rhs_shape = (4, 5)
  define(
      lax.tie_in_p,
      f"lhs={jtu.format_shape_dtype_string(lhs_shape, lhs_dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)}",
      lax.tie_in_p.bind,
      [RandArg(lhs_shape, lhs_dtype),
       RandArg(rhs_shape, rhs_dtype)],
      jax_unimplemented=[
          Limitation(
              "requires omnistaging to be disabled",
              enabled=config.omnistaging_enabled)
      ],
      dtype=rhs_dtype,
      lhs_shape=lhs_shape,
      lhs_dtype=lhs_dtype,
      rhs_shape=rhs_shape,
      rhs_dtype=rhs_dtype,
      primitive=lax.tie_in_p)

for dtype in jtu.dtypes.all:
  shape: Tuple[int, ...] = (20, 20)
  define(
      ad_util.stop_gradient_p,
      f"{jtu.format_shape_dtype_string(shape, dtype)}",
      ad_util.stop_gradient_p.bind, [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype)

_LAX_COMPARATORS = (lax.eq_p, lax.ge_p, lax.gt_p, lax.le_p, lax.lt_p, lax.ne_p)


def _make_comparator_harness(name,
                             *,
                             dtype=np.float32,
                             op=lax.eq_p,
                             lhs_shape=(),
                             rhs_shape=()):
  define(
      op.name,
      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}",
      lambda *args: op.bind(*args),
      [RandArg(lhs_shape, dtype),
       RandArg(rhs_shape, dtype)],
      op=op,
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dtype=dtype)


for op in _LAX_COMPARATORS:
  for dtype in (jtu.dtypes.all if op in [lax.eq_p, lax.ne_p] else
                set(jtu.dtypes.all) - set(jtu.dtypes.complex)):
    # Validate dtypes
    _make_comparator_harness("dtypes", dtype=dtype, op=op)

  # Validate broadcasting behavior
  for lhs_shape, rhs_shape in [
      ((), (2, 3)),  # broadcast scalar
      ((1, 2), (3, 2)),  # broadcast along specific axis
  ]:
    _make_comparator_harness(
        "broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape, op=op)

for dtype in jtu.dtypes.all:
  shape = (3, 4, 5)
  define(
      "zeros_like",
      f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
      ad_util.zeros_like_p.bind, [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype)

for dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
  arg = np.array([-1, -2, 0, 1], dtype=dtype)
  define(
      "population_count",
      f"{jtu.dtype_str(dtype)}",
      lax.population_count, [arg],
      dtype=dtype)


def _get_max_identity(dtype):
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(-np.inf, dtype)
  elif dtypes.issubdtype(dtype, np.integer):
    return np.array(dtypes.iinfo(dtype).min, dtype)
  elif dtypes.issubdtype(dtype, np.bool_):
    return np.array(False, np.bool_)


def _get_min_identity(dtype):
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(np.inf, dtype)
  elif dtypes.issubdtype(dtype, np.integer):
    return np.array(dtypes.iinfo(dtype).max, dtype)
  elif dtypes.issubdtype(dtype, np.bool_):
    return np.array(True, np.bool_)
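
# For reference, the identities produced above (intended to seed max/min
# reductions): _get_max_identity(np.float32) is -inf, the identity for `max`;
# _get_min_identity(np.int32) is iinfo(int32).max == 2**31 - 1, the identity
# for `min`.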


def _make_argminmax_harness(prim,
                            name,
                            *,
                            shape=(15,),
                            dtype=jnp.float32,
                            axes=(0,),
                            index_dtype=np.int32,
                            arr=None):
  arr = arr if arr is not None else RandArg(shape, dtype)
  dtype, shape = arr.dtype, arr.shape
  index_dtype = dtypes.canonicalize_dtype(index_dtype)
  define(
      prim,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_indexdtype={index_dtype}",
      lambda arg: prim.bind(arg, axes=axes, index_dtype=index_dtype), [arr],
      shape=shape,
      dtype=dtype,
      axes=axes,
      index_dtype=index_dtype,
      prim=prim)


for prim in [lax.argmin_p, lax.argmax_p]:
  for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex):
    # Validate dtypes for each primitive
    _make_argminmax_harness(prim, "dtypes", dtype=dtype)

  # Validate axes for each primitive; non-major axis
  _make_argminmax_harness(prim, "axes", shape=(18, 12), axes=(1,))

  # Validate index dtype for each primitive
  for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
    _make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype)

# TODO(bchetioui): the below documents a limitation of argmin and argmax when a
# dimension of the input is too large. However, it is not categorizable as it
# seems that the converter fails before reaching the actual primitive call. This
# suggests that we may need to harden the converter to handle inputs this big.
# + tuple(  # Document limitation in case of too large axis
#     _make_argminmax_harness("overflow_axis", prim=prim,
#                             arr=np.ones((2**31,), dtype=np.uint8))
#     for prim in [lax.argmin_p, lax.argmax_p]
# )


def _make_iota_harness(name, *, shape=(2, 3), dtype=np.float32, dimension=0):
  define(
      lax.iota_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_dimension={dimension}",
      lambda dtype, shape, dim:
      (lax.iota_p.bind(dtype=dtype, shape=shape, dimension=dim)),
      [StaticArg(dtype),
       StaticArg(shape),
       StaticArg(dimension)],
      shape=shape,
      dtype=dtype,
      dimension=dimension)


for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
  _make_iota_harness("dtypes", dtype=dtype)

# Validate broadcasting
for shape, dimension in [
    ((4, 8, 1, 1), 1),  # broadcasting along non-major dimension
    ((4, 8, 1, 1), 2),  # broadcasting along dimension == 1
]:
  _make_iota_harness("broadcasting", shape=shape, dimension=dimension)


def _make_div_rem_harness(prim,
                          name,
                          *,
                          shapes=((2,), (2,)),
                          dtype=np.float32,
                          arrs=(None, None)):
  lhs, rhs = arrs

  def _make_non_zero(rng):
    return jtu.rand_nonzero(rng)(shapes[1], dtype)

  lhs = RandArg(shapes[0], dtype) if lhs is None else lhs
  rhs_shape = rhs.shape if rhs is not None else shapes[1]
  rhs_dtype = rhs.dtype if rhs is not None else dtype
  rhs = CustomArg(_make_non_zero) if rhs is None else rhs

  define(
      prim,
      f"{name}_lhs={jtu.format_shape_dtype_string(lhs.shape, lhs.dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)}",
      prim.bind, [lhs, rhs],
      dtype=dtype,
      lhs=lhs,
      rhs=rhs,
      prim=prim)


for prim in [lax.div_p, lax.rem_p]:
  for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean) - (
      set() if prim is lax.div_p else set(jtu.dtypes.complex)):
    _make_div_rem_harness(prim, "dtypes", dtype=dtype)

  # Validate broadcasting
  for shapes in [
      ((2, 1, 3), (2, 4, 3)),  # broadcast dividend
      ((2, 4, 3), (2, 1, 3)),  # broadcast divisor
  ]:
    _make_div_rem_harness(prim, "broadcast", shapes=shapes)

  # Validate singularity points
  for name, arrs in [
      ("positive_by_0", (np.ones(
          (2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))),
      # TODO: this fails on CPU, different result
      # ("positive_by_0_int32", (np.ones((2,), dtype=np.int32),
      #                          np.zeros((2,), dtype=np.int32))),
      ("negative_by_0", (-np.ones(
          (2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))),
      ("0_by_0", (np.zeros(
          (2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))),
      ("inf_by_inf", (np.array([np.inf], dtype=np.float32),
                      np.array([np.inf], dtype=np.float32))),
  ]:
    _make_div_rem_harness(prim, f"singularity_{name}", arrs=arrs)


def _make_binary_elementwise_harnesses(prim,
                                       dtypes,
                                       default_dtype=np.float32,
                                       jax_unimplemented=lambda **kwargs: []):

  def _make(name, *, shapes=((20, 20), (20, 20)), dtype):
    lhs_shape, rhs_shape = shapes
    define(
        prim,
        f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}",
        prim.bind, [RandArg(lhs_shape, dtype),
                    RandArg(rhs_shape, dtype)],
        jax_unimplemented=jax_unimplemented(
            dtype=dtype, prim=prim, shapes=shapes),
        prim=prim,
        dtype=dtype,
        shapes=shapes)

  return (tuple(  # Validate dtypes
      _make("dtypes", dtype=dtype)
      for dtype in dtypes) + tuple(  # Validate broadcasting
          _make("broadcasting", dtype=default_dtype, shapes=shapes)
          for shapes in [
              ((20, 20), (1, 20)),  # broadcasting rhs
              ((1, 20), (20, 20)),  # broadcasting lhs
          ]))


_make_binary_elementwise_harnesses(
    prim=lax.add_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean))

_make_binary_elementwise_harnesses(
    prim=lax.mul_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean))

_make_binary_elementwise_harnesses(
    prim=lax.atan2_p, dtypes=jtu.dtypes.all_floating)

_make_binary_elementwise_harnesses(
    prim=lax.igamma_p,
    dtypes=jtu.dtypes.all_floating,
    jax_unimplemented=lambda *_, dtype, **kwargs: [
        Limitation(
            "XLA internal error", dtypes=[np.float16, dtypes.bfloat16]),
    ])

_make_binary_elementwise_harnesses(
    prim=lax.igammac_p,
    dtypes=jtu.dtypes.all_floating,
    jax_unimplemented=lambda *_, dtype, **kwargs: [
        Limitation(
            "XLA internal error", dtypes=[np.float16, dtypes.bfloat16]),
    ])

_make_binary_elementwise_harnesses(
    prim=lax.nextafter_p, dtypes=jtu.dtypes.all_floating)

_make_binary_elementwise_harnesses(
    prim=lax.and_p,
    default_dtype=np.int32,
    dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned +
    jtu.dtypes.boolean)

_make_binary_elementwise_harnesses(
    prim=lax.or_p,
    default_dtype=np.int32,
    dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned +
    jtu.dtypes.boolean)

_make_binary_elementwise_harnesses(
    prim=lax.xor_p,
    default_dtype=np.int32,
    dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned +
    jtu.dtypes.boolean)

_make_binary_elementwise_harnesses(
    prim=lax.shift_left_p,
    default_dtype=np.int32,
    dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned)

_make_binary_elementwise_harnesses(
    prim=lax.shift_right_logical_p,
    default_dtype=np.int32,
    dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned)

_make_binary_elementwise_harnesses(
    prim=lax.shift_right_arithmetic_p,
    default_dtype=np.int32,
    dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned)

_make_binary_elementwise_harnesses(
    prim=lax.sub_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean))

_min_max_special_cases = tuple(
    (lhs, rhs)
    for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex
    for lhs, rhs in [(np.array([np.inf, np.inf], dtype=dtype),
                      np.array([np.nan, np.nan], dtype=dtype)),
                     (np.array([-np.inf, -np.inf], dtype=dtype),
                      np.array([np.nan, np.nan], dtype=dtype))])

_make_binary_elementwise_harnesses(prim=lax.min_p, dtypes=jtu.dtypes.all)
# Validate special cases
for lhs, rhs in _min_max_special_cases:
  define(
      lax.min_p,
      f"inf_nan_{jtu.dtype_str(lhs.dtype)}_{lhs[0]}_{rhs[0]}",
      lax.min_p.bind, [lhs, rhs],
      prim=lax.min_p,
      dtype=lhs.dtype)

_make_binary_elementwise_harnesses(prim=lax.max_p, dtypes=jtu.dtypes.all)
# Validate special cases
for lhs, rhs in _min_max_special_cases:
  define(
      lax.max_p,
      f"inf_nan_{jtu.dtype_str(lhs.dtype)}_{lhs[0]}_{rhs[0]}",
      lax.max_p.bind, [lhs, rhs],
      prim=lax.max_p,
      dtype=lhs.dtype)


def _make_broadcast_in_dim_harness(name,
                                   *,
                                   dtype=np.float32,
                                   shape=(2,),
                                   outshape=(2,),
                                   broadcast_dimensions=(0,)):
  define(
      lax.broadcast_in_dim_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_outshape={outshape}_broadcastdimensions={broadcast_dimensions}",
      lambda operand: lax.broadcast_in_dim_p.bind(
          operand, shape=outshape, broadcast_dimensions=broadcast_dimensions),
      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      outshape=outshape,
      broadcast_dimensions=broadcast_dimensions)


for dtype in jtu.dtypes.all:
  _make_broadcast_in_dim_harness("dtypes", dtype=dtype)

# Validate parameter combinations
for shape, outshape, broadcast_dimensions in [
    [(2,), (3, 2), (1,)],  # add major dimension
    [(2,), (2, 3), (0,)],  # add inner dimension
    [(), (2, 3), ()],  # use scalar shape
    [(1, 2), (4, 3, 2), (0, 2)],  # map size 1 dim to different output dim value
]:
  _make_broadcast_in_dim_harness(
      "parameter_combinations",
      shape=shape,
      outshape=outshape,
      broadcast_dimensions=broadcast_dimensions)


def _make_broadcast_harness(name, *, dtype=np.float32, shape=(2,), sizes=()):
  define(
      lax.broadcast_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_sizes={sizes}",
      lambda operand: lax.broadcast_p.bind(operand, sizes=sizes),
      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      sizes=sizes)


for dtype in jtu.dtypes.all:
  _make_broadcast_harness("dtypes", dtype=dtype)

# Validate sizes
for sizes in [
    (2,),  # broadcast 1 dim
    (1, 2, 3),  # broadcast n > 1 dims
]:
  _make_broadcast_harness("sizes", sizes=sizes)

for dtype in jtu.dtypes.all_floating:
  for arg1, arg2, arg3 in [
      (np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6], dtype=dtype),
       np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6], dtype=dtype),
       np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6], dtype=dtype))
  ]:
    define(
        lax.regularized_incomplete_beta_p,
        f"_{jtu.dtype_str(dtype)}",
        lax.betainc, [arg1, arg2, arg3],
        dtype=dtype)

## GATHER
# Validate dtypes
for dtype in set(jtu.dtypes.all):
  indices = np.array(2, dtype=np.int32)
  shape = (10,)
  axis = 0
  define(
      lax.gather_p,
      f"dtypes_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={axis}",
      lambda a, i, axis: jnp.take(a, i, axis=axis),
      [RandArg(shape, dtype), indices,
       StaticArg(axis)],
      dtype=dtype)

# Construct gather harnesses using take
_gather_input = np.arange(1000, dtype=np.float32).reshape((10, 10, 10))
for indices in [
    # Ensure each set of indices has a distinct shape
    np.array(2, dtype=np.int32),
    np.array([2], dtype=np.int32),
    np.array([2, 4], dtype=np.int32),
    np.array([[2, 4], [5, 6]], dtype=np.int32),
    np.array([0, 1, 10], dtype=np.int32),  # Index out of bounds
    np.array([0, 1, 2, -1], dtype=np.int32),  # Index out of bounds
]:
  for axis in [0, 1, 2]:
    define(
        lax.gather_p,
        f"from_take_indices_shape={indices.shape}_axis={axis}",
        lambda a, i, axis: jnp.take(a, i, axis=axis),
        [_gather_input, indices, StaticArg(axis)],
        dtype=_gather_input.dtype)

# Directly from lax.gather in lax_test.py.
for shape, idxs, dnums, slice_sizes in [
    ((5,), np.array([[0], [2]]),
     lax.GatherDimensionNumbers(
         offset_dims=(), collapsed_slice_dims=(0,),
         start_index_map=(0,)), (1,)),
    ((10,), np.array([[0], [0], [0]]),
     lax.GatherDimensionNumbers(
         offset_dims=(1,), collapsed_slice_dims=(),
         start_index_map=(0,)), (2,)),
    ((10, 5), np.array([[0], [2], [1]]),
     lax.GatherDimensionNumbers(
         offset_dims=(1,), collapsed_slice_dims=(0,),
         start_index_map=(0,)), (1, 3)),
    ((10, 5), np.array([[0, 2], [1, 0]]),
     lax.GatherDimensionNumbers(
         offset_dims=(1,), collapsed_slice_dims=(0,),
         start_index_map=(0, 1)), (1, 3)),
]:
  dtype = np.float32
  define(
      lax.gather_p,
      f"_shape={shape}_idxs_shape={idxs.shape}_dnums={dnums}_slice_sizes={slice_sizes}",
      lambda op, idxs, dnums, slice_sizes: lax.gather(
          op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes),
      [RandArg(shape, dtype), idxs,
       StaticArg(dnums),
       StaticArg(slice_sizes)],
      dtype=dtype)


def _make_scatter_harness(name,
                          *,
                          shape=(5,),
                          f_lax=lax.scatter_min,
                          indices_are_sorted=False,
                          unique_indices=False,
                          scatter_indices=np.array([[0], [2]]),
                          update_shape=(2,),
                          dtype=np.float32,
                          dimension_numbers=((), (0,), (0,))):
  dimension_numbers = lax.ScatterDimensionNumbers(*dimension_numbers)
  define(
      f_lax.__name__,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_updatewindowdims={dimension_numbers.update_window_dims}_insertedwindowdims={dimension_numbers.inserted_window_dims}_scatterdimstooperanddims={dimension_numbers.scatter_dims_to_operand_dims}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}"
      .replace(" ", ""),
      partial(
          f_lax,
          indices_are_sorted=indices_are_sorted,
          unique_indices=unique_indices), [
              RandArg(shape, dtype),
              StaticArg(scatter_indices),
              RandArg(update_shape, dtype),
              StaticArg(dimension_numbers)
          ],
      jax_unimplemented=[
          Limitation(
              "unimplemented", devices="tpu", dtypes=np.complex64,
              enabled=(f_lax in [lax.scatter_max, lax.scatter_min]))
      ],
      f_lax=f_lax,
      shape=shape,
      dtype=dtype,
      scatter_indices=scatter_indices,
      update_shape=update_shape,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)


# Validate dtypes
for dtype in jtu.dtypes.all:
  for f_lax in [
      lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min
  ]:
    if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_:
      continue
    _make_scatter_harness("dtypes", dtype=dtype, f_lax=f_lax)

# Validate f_lax/update_jaxpr
# We explicitly decide against testing lax.scatter, as its reduction function
# is lambda x, y: y, which is not commutative and thus makes results
# non-deterministic when an index into the operand is updated several times.

# Validate shapes, dimension numbers and scatter indices
for shape, scatter_indices, update_shape, dimension_numbers in [
    ((10,), [[0], [0], [0]], (3, 2), ((1,), (), (0,))),
    ((10, 5), [[0], [2], [1]], (3, 3), ((1,), (0,), (0,)))
]:
  _make_scatter_harness(
      "shapes_and_dimension_numbers",
      shape=shape,
      update_shape=update_shape,
      scatter_indices=np.array(scatter_indices),
      dimension_numbers=dimension_numbers)

# Validate sorted indices
_make_scatter_harness("indices_are_sorted", indices_are_sorted=True)
# Validate unique_indices
# `unique_indices` does not affect correctness, only performance, and thus
# does not need to be tested here. If/when it makes sense to add a test
# with `unique_indices` = True, particular care will have to be taken with
# regards to the choice of parameters, as the results are only predictable
# when all the indices to be updated are pairwise non-overlapping. Identifying
# such cases is non-trivial.
_make_scatter_harness("unique_indices", unique_indices=False)
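
# Illustrative sketch of the non-determinism mentioned above (hypothetical
# values; `dnums` stands for the ScatterDimensionNumbers in use): with
# duplicate indices, the overwrite "reduction" lambda x, y: y is applied in an
# unspecified order, so either update may win.
#
#   operand = np.zeros((1,), np.float32)
#   indices = np.array([[0], [0]], np.int32)   # the same location, twice
#   updates = np.array([1., 2.], np.float32)
#   # lax.scatter(operand, indices, updates, dnums) may return [1.] or [2.]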

for dtype in jtu.dtypes.all:
  arg_shape = (2, 3)
  for pads in [
      [(0, 0, 0), (0, 0, 0)],  # no padding
      [(1, 1, 0), (2, 2, 0)],  # only positive edge padding
      [(1, 2, 1), (0, 1, 0)],  # edge padding and interior padding
      [(0, 0, 0), (-1, -1, 0)],  # negative padding
      [(0, 0, 0), (-2, -2, 4)],  # add big dilation then remove from edges
      [(0, 0, 0), (-2, -3, 1)],  # remove everything in one dimension
  ]:
    define(
        lax.pad_p,
        f"inshape={jtu.format_shape_dtype_string(arg_shape, dtype)}_pads={pads}",
        lax.pad,
        [RandArg(arg_shape, dtype),
         np.array(0, dtype),
         StaticArg(pads)],
        rng_factory=jtu.rand_small,
        arg_shape=arg_shape,
        dtype=dtype,
        pads=pads)


def _make_select_harness(name,
                         *,
                         shape_pred=(2, 3),
                         shape_args=(2, 3),
                         dtype=np.float32):
  define(
      lax.select_p,
      f"{name}_shapepred={jtu.format_shape_dtype_string(shape_pred, np.bool_)}_shapeargs={jtu.format_shape_dtype_string(shape_args, dtype)}",
      lax.select, [
          RandArg(shape_pred, np.bool_),
          RandArg(shape_args, dtype),
          RandArg(shape_args, dtype)
      ],
      shape_pred=shape_pred,
      shape_args=shape_args,
      dtype=dtype)


for dtype in jtu.dtypes.all:
  _make_select_harness("dtypes", dtype=dtype)

# Validate shapes
_make_select_harness("shapes", shape_pred=(), shape_args=(18,))


def _make_transpose_harness(name,
                            *,
                            shape=(2, 3),
                            permutation=(1, 0),
                            dtype=np.float32):
  define(
      lax.transpose_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_permutation={permutation}"
      .replace(" ", ""),
      lambda x: lax.transpose_p.bind(x, permutation=permutation),
      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      permutation=permutation)


for dtype in jtu.dtypes.all:
  _make_transpose_harness("dtypes", dtype=dtype)

# Validate permutations
for shape, permutation in [
    ((2, 3, 4), (0, 1, 2)),  # identity
    ((2, 3, 4), (1, 2, 0)),  # transposition
]:
  _make_transpose_harness("permutations", shape=shape, permutation=permutation)


## CUMREDUCE
def _make_cumreduce_harness(name,
                            *,
                            f_jax=lax_control_flow.cummin,
                            shape=(8, 9),
                            dtype=np.float32,
                            axis=0,
                            reverse=False):
  limitations = []
  if f_jax.__name__ != "cumsum":
    limitations.append(
        Limitation(
            "unimplemented",
            devices="tpu",
            dtypes=np.complex64,
        ))
  define(
      f_jax.__name__,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={axis}_reverse={reverse}",
      f_jax, [RandArg(shape, dtype),
              StaticArg(axis),
              StaticArg(reverse)],
      jax_unimplemented=limitations,
      f_jax=f_jax,
      shape=shape,
      dtype=dtype,
      axis=axis,
      reverse=reverse)


# Validate dtypes for each function
for f_jax in [
    lax_control_flow.cummin, lax_control_flow.cummax, lax_control_flow.cumsum,
    lax_control_flow.cumprod
]:
  for dtype in jtu.dtypes.all:
    if dtype == np.bool_:
      continue
    _make_cumreduce_harness("dtype_by_fun", dtype=dtype, f_jax=f_jax)

  # Validate axis for each function
  shape = (8, 9)
  for axis in range(len(shape)):
    _make_cumreduce_harness("axis_by_fun", axis=axis, f_jax=f_jax, shape=shape)

  # Validate reverse for each function
  _make_cumreduce_harness("reverse", reverse=True, f_jax=f_jax)


### TOP_K
def _make_top_k_harness(name,
                        *,
                        operand=None,
                        shape=(5, 3),
                        dtype=np.float32,
                        k=2):
  if operand is None:
    operand = RandArg(shape, dtype)
  define(
      lax.top_k_p,
      f"{name}_inshape={jtu.format_shape_dtype_string(operand.shape, operand.dtype)}_k={k}",
      lax.top_k, [operand, StaticArg(k)],
      shape=operand.shape,
      dtype=operand.dtype,
      k=k)


for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex):
  # Validate dtypes
  _make_top_k_harness("dtypes", dtype=dtype)

# Validate implicit properties of the sort
for name, operand, k in [("stability",
                          np.array([5, 7, 5, 8, 8, 5], dtype=np.int32), 3),
                         ("sort_inf_nan",
                          np.array([+np.inf, np.nan, -np.nan, -np.inf, 3],
                                   dtype=np.float32), 5)]:
  _make_top_k_harness(name, operand=operand, k=k)


### SORT
def _make_sort_harness(name,
                       *,
                       operands=None,
                       shape=(5, 7),
                       dtype=np.float32,
                       dimension=0,
                       is_stable=False,
                       num_keys=1):
  if operands is None:
    operands = [RandArg(shape, dtype)]
  define(
      lax.sort_p,
      f"{name}_num_arrays={len(operands)}_shape={jtu.format_shape_dtype_string(operands[0].shape, operands[0].dtype)}_axis={dimension}_isstable={is_stable}_num_keys={num_keys}",
      lambda *args: lax.sort_p.bind(
          *args[:-3], dimension=args[-3], is_stable=args[-2], num_keys=args[-1]
      ), [
          *operands,
          StaticArg(dimension),
          StaticArg(is_stable),
          StaticArg(num_keys)
      ],
      shape=operands[0].shape,
      dimension=dimension,
      dtype=operands[0].dtype,
      is_stable=is_stable,
      num_keys=num_keys,
      num_arrays=len(operands))


_lax_sort_multiple_array_shape = (100,)
# In order to test lexicographic ordering and sorting stability, the first
# array contains only integers 0 and 1
_lax_sort_multiple_array_first_arg = (
    np.random.uniform(0, 2, _lax_sort_multiple_array_shape).astype(np.int32))

# Validate dtypes
for dtype in jtu.dtypes.all:
  _make_sort_harness("dtypes", dtype=dtype)

# Validate dimensions
for dimension in [0, 1]:
  _make_sort_harness("dimensions", dimension=dimension)
# Validate stable sort
_make_sort_harness("is_stable", is_stable=True)
# Potential edge cases
for operands, dimension in [
    ([np.array([+np.inf, np.nan, -np.nan, -np.inf, 2], dtype=np.float32)], 0)
]:
  _make_sort_harness("edge_cases", operands=operands, dimension=dimension)

# Validate multiple arrays, num_keys, and is_stable
for is_stable in [False, True]:
  for operands in (
      [
          _lax_sort_multiple_array_first_arg,
          RandArg(_lax_sort_multiple_array_shape, np.int32)
      ],
      [
          _lax_sort_multiple_array_first_arg,
          RandArg(_lax_sort_multiple_array_shape, np.int32),
          RandArg(_lax_sort_multiple_array_shape, np.float32)
      ],
  ):
    for num_keys in range(1, len(operands) + 1):
      _make_sort_harness(
          "multiple_arrays",
          operands=operands,
          num_keys=num_keys,
          is_stable=is_stable,
          shape=_lax_sort_multiple_array_first_arg.shape,
          dtype=_lax_sort_multiple_array_first_arg.dtype)


def _make_cholesky_arg(shape, dtype, rng):
  a = jtu.rand_default(rng)(shape, dtype)
  return np.matmul(a, jnp.conj(np.swapaxes(a, -1, -2)))


for dtype in jtu.dtypes.all_inexact:
  for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]:
    define(
        lax.linalg.cholesky_p,
        f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
        lambda *args: lax.linalg.cholesky_p.bind(*args),
        [CustomArg(partial(_make_cholesky_arg, shape, dtype))],
        jax_unimplemented=[
            Limitation(
                "unimplemented",
                dtypes=[np.float16],
                devices=("cpu", "gpu"))
        ],
        shape=shape,
        dtype=dtype)

for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
  for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]:
    for full_matrices in [False, True]:
      define(
          lax.linalg.qr_p,
          f"multi_array_shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}",
          lax.linalg.qr,
          [RandArg(shape, dtype),
           StaticArg(full_matrices)],
          # See jax.lib.lapack.geqrf for the list of compatible types
          jax_unimplemented=[
              Limitation(
                  "unimplemented",
                  devices=("cpu", "gpu"),
                  dtypes=[np.float16, dtypes.bfloat16]),
          ],
          shape=shape,
          dtype=dtype,
          full_matrices=full_matrices)


def _make_fft_harness(name,
                      *,
                      shape=(14, 15, 16, 17),
                      dtype=np.float32,
                      fft_type=xla_client.FftType.FFT,
                      fft_lengths=(17,)):

  def _fft_rng_factory(dtype):
    _all_integers = (
        jtu.dtypes.all_integer + jtu.dtypes.all_unsigned + jtu.dtypes.boolean)
    # For integer types, use small values to keep the errors small
    if dtype in _all_integers:
      return jtu.rand_small
    else:
      return jtu.rand_default

  define(
      lax.fft_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_ffttype={fft_type}_fftlengths={fft_lengths}",
      lambda *args: lax.fft_p.bind(
          args[0], fft_type=args[1], fft_lengths=args[2]),
      [RandArg(shape, dtype),
       StaticArg(fft_type),
       StaticArg(fft_lengths)],
      jax_unimplemented=[
          Limitation(
              "only 1D FFT is currently supported b/140351181.",
              devices="tpu",
              enabled=len(fft_lengths) > 1),
      ],
      rng_factory=_fft_rng_factory(dtype),
      shape=shape,
      dtype=dtype,
      fft_type=fft_type,
      fft_lengths=fft_lengths)


# FFT, IFFT, RFFT, IRFFT
for fft_type in list(map(xla_client.FftType, [0, 1, 2, 3])):
  # Validate dtypes per FFT type
  for dtype in (jtu.dtypes.floating
                if fft_type == xla_client.FftType.RFFT else jtu.dtypes.complex):
    shape = (14, 15, 16, 17)
    for fft_lengths in [
        (shape[-1],) if fft_type != xla_client.FftType.IRFFT else
        ((shape[-1] - 1) * 2,)
    ]:
      _make_fft_harness(
          "dtypes",
          shape=shape,
          dtype=dtype,
          fft_type=fft_type,
          fft_lengths=fft_lengths)

  # Validate dimensions per FFT type
  for dtype in [
      np.float32 if fft_type == xla_client.FftType.RFFT else np.complex64
  ]:
    for dims in [1, 2, 3]:
      for fft_lengths in [
          shape[-dims:] if fft_type != xla_client.FftType.IRFFT else
          shape[-dims:-1] + ((shape[-1] - 1) * 2,)
      ]:
        _make_fft_harness(
            "dims",
            shape=shape,
            fft_type=fft_type,
            fft_lengths=fft_lengths,
            dtype=dtype)
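
# Note on the IRFFT cases above: an inverse real FFT takes a complex input of
# length n on the transformed axis to a real output of length (n - 1) * 2, so
# for the (14, 15, 16, 17) inputs used here the last-axis fft_length is
# (17 - 1) * 2 = 32.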
Works in JAX for CPU and GPU with custom kernels", 1604 devices="tpu", 1605 dtypes=[np.complex64, np.complex128]) 1606 ], 1607 shape=shape, 1608 dtype=dtype, 1609 full_matrices=full_matrices, 1610 compute_uv=compute_uv) 1611 1612for dtype in jtu.dtypes.all_inexact: 1613 for shape in [(0, 0), (5, 5), (2, 6, 6)]: 1614 for compute_left_eigenvectors in [False, True]: 1615 for compute_right_eigenvectors in [False, True]: 1616 define( 1617 lax.linalg.eig_p, 1618 f"shape={jtu.format_shape_dtype_string(shape, dtype)}_computelefteigenvectors={compute_left_eigenvectors}_computerighteigenvectors={compute_right_eigenvectors}", 1619 lax.linalg.eig, [ 1620 RandArg(shape, dtype), 1621 StaticArg(compute_left_eigenvectors), 1622 StaticArg(compute_right_eigenvectors) 1623 ], 1624 jax_unimplemented=[ 1625 Limitation( 1626 "only supported on CPU in JAX", devices=("tpu", "gpu")), 1627 Limitation( 1628 "unimplemented", 1629 devices="cpu", 1630 dtypes=[np.float16, dtypes.bfloat16]) 1631 ], 1632 shape=shape, 1633 dtype=dtype, 1634 compute_left_eigenvectors=compute_left_eigenvectors, 1635 compute_right_eigenvectors=compute_right_eigenvectors) 1636 1637 1638def _make_triangular_eigh_operand(shape, dtype, lower: bool, rng: Rng): 1639 # For testing eigh we use triangular matrices 1640 operand = jtu.rand_default(rng)(shape, dtype) 1641 # Make operand self-adjoint 1642 operand = (operand + np.conj(np.swapaxes(operand, -1, -2))) / 2 1643 # Make operand lower/upper triangular 1644 return operand # np.tril(operand) if lower else np.triu(operand) 1645 1646 1647for dtype in jtu.dtypes.all_inexact: 1648 for shape in [(0, 0), (50, 50), (2, 20, 20)]: 1649 for lower in [False, True]: 1650 define( 1651 lax.linalg.eigh_p, 1652 f"shape={jtu.format_shape_dtype_string(shape, dtype)}_lower={lower}", 1653 # Make operand lower/upper triangular 1654 lambda operand, lower, symmetrize_input: (lax.linalg.eigh( 1655 jnp.tril(operand) 1656 if lower else jnp.triu(operand), lower, symmetrize_input)), 1657 # lax.linalg.eigh, 1658 [ 1659 CustomArg( 1660 partial(_make_triangular_eigh_operand, shape, dtype, lower)), 1661 StaticArg(lower), 1662 StaticArg(False) 1663 ], 1664 jax_unimplemented=[ 1665 Limitation( 1666 "complex eigh not supported ", 1667 devices="tpu", 1668 dtypes=[np.complex64, np.complex128]), 1669 Limitation( 1670 "unimplemented", devices="cpu", dtypes=[np.float16]), 1671 Limitation( 1672 "unimplemented", devices="gpu", dtypes=[np.float16]), 1673 ], 1674 shape=shape, 1675 dtype=dtype, 1676 lower=lower) 1677 1678for dtype in jtu.dtypes.all_inexact: 1679 for shape in [ 1680 (5, 5), # square 1681 (3, 5, 5), # batched 1682 (3, 5), # non-square 1683 ]: 1684 define( 1685 lax.linalg.lu_p, 1686 f"shape={jtu.format_shape_dtype_string(shape, dtype)}", 1687 lax.linalg.lu, [RandArg(shape, dtype)], 1688 jax_unimplemented=[ 1689 Limitation( 1690 "unimplemented", dtypes=[np.float16, dtypes.bfloat16]) 1691 ], 1692 shape=shape, 1693 dtype=dtype) 1694 1695 1696def _make_triangular_solve_harness(name, 1697 *, 1698 left_side=True, 1699 lower=False, 1700 ab_shapes=((4, 4), (4, 1)), 1701 dtype=np.float32, 1702 transpose_a=False, 1703 conjugate_a=False, 1704 unit_diagonal=False): 1705 a_shape, b_shape = ab_shapes 1706 f_lax = lambda a, b: ( 1707 lax.linalg.triangular_solve_p.bind( 1708 a, 1709 b, 1710 left_side=left_side, 1711 lower=lower, 1712 transpose_a=transpose_a, 1713 conjugate_a=conjugate_a, 1714 unit_diagonal=unit_diagonal)) 1715 1716 define( 1717 lax.linalg.triangular_solve_p, 1718 f"{name}_a={jtu.format_shape_dtype_string(a_shape, 

for dtype in jtu.dtypes.all_inexact:
  for shape in [
      (5, 5),  # square
      (3, 5, 5),  # batched
      (3, 5),  # non-square
  ]:
    define(
        lax.linalg.lu_p,
        f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
        lax.linalg.lu, [RandArg(shape, dtype)],
        jax_unimplemented=[
            Limitation(
                "unimplemented", dtypes=[np.float16, dtypes.bfloat16])
        ],
        shape=shape,
        dtype=dtype)


def _make_triangular_solve_harness(name,
                                   *,
                                   left_side=True,
                                   lower=False,
                                   ab_shapes=((4, 4), (4, 1)),
                                   dtype=np.float32,
                                   transpose_a=False,
                                   conjugate_a=False,
                                   unit_diagonal=False):
  a_shape, b_shape = ab_shapes
  f_lax = lambda a, b: (
      lax.linalg.triangular_solve_p.bind(
          a,
          b,
          left_side=left_side,
          lower=lower,
          transpose_a=transpose_a,
          conjugate_a=conjugate_a,
          unit_diagonal=unit_diagonal))

  define(
      lax.linalg.triangular_solve_p,
      f"{name}_a={jtu.format_shape_dtype_string(a_shape, dtype)}_b={jtu.format_shape_dtype_string(b_shape, dtype)}_leftside={left_side}_lower={lower}_transposea={transpose_a}_conjugatea={conjugate_a}_unitdiagonal={unit_diagonal}",
      f_lax, [RandArg(a_shape, dtype),
              RandArg(b_shape, dtype)],
      jax_unimplemented=[
          Limitation(
              "unimplemented", devices="gpu", dtypes=[np.float16]),
      ],
      dtype=dtype,
      a_shape=a_shape,
      b_shape=b_shape,
      left_side=left_side,
      lower=lower,
      transpose_a=transpose_a,
      conjugate_a=conjugate_a,
      unit_diagonal=unit_diagonal)


# Validate dtypes
# This first harness runs the tests for all dtypes using default values for
# all the other parameters, except unit_diagonal (to ensure that
# tf.linalg.set_diag works reliably for all dtypes). Variations of other
# parameters can thus safely skip testing their corresponding default value.
# Note that this validates solving on the left.
for dtype in jtu.dtypes.all_inexact:
  for unit_diagonal in [False, True]:
    _make_triangular_solve_harness(
        "dtypes", dtype=dtype, unit_diagonal=unit_diagonal)

# Validate shapes when solving on the right
for ab_shapes in [
    ((4, 4), (1, 4)),  # standard
    ((2, 8, 8), (2, 10, 8)),  # batched
]:
  _make_triangular_solve_harness(
      "shapes_right", ab_shapes=ab_shapes, left_side=False)

# Validate transformations of a complex matrix
for lower in [False, True]:
  for transpose_a in [False, True]:
    for conjugate_a in [False, True]:
      _make_triangular_solve_harness(
          "complex_transformations",
          dtype=np.complex64,
          lower=lower,
          transpose_a=transpose_a,
          conjugate_a=conjugate_a)

# Validate transformations of a real matrix
for lower in [False, True]:
  for transpose_a in [False, True]:
    # conjugate_a is irrelevant for real dtypes, and is thus omitted
    _make_triangular_solve_harness(
        "real_transformations",
        dtype=np.float32,
        lower=lower,
        transpose_a=transpose_a)
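
# Semantics sketch (not a harness) of the primitive exercised above, via the
# public wrapper lax.linalg.triangular_solve: left_side=True solves
# a @ x == b for x; left_side=False solves x @ a == b.
#   a = np.tril(np.ones((4, 4), np.float32))
#   b = np.ones((4, 1), np.float32)
#   x = lax.linalg.triangular_solve(a, b, left_side=True, lower=True)
#   np.allclose(np.dot(a, x), b)  # -> True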


def _make_linear_solve_harnesses():
  def linear_solve(a, b, solve, transpose_solve=None, symmetric=False):
    matvec = partial(lax.dot, a, precision=lax.Precision.HIGHEST)
    return lax.custom_linear_solve(matvec, b, solve, transpose_solve, symmetric)

  def explicit_jacobian_solve(matvec, b):
    return lax.stop_gradient(jnp.linalg.solve(jax.api.jacobian(matvec)(b), b))

  def _make_harness(name,
                    *,
                    shape=(4, 4),
                    dtype=np.float32,
                    symmetric=False,
                    solvers=(explicit_jacobian_solve, explicit_jacobian_solve)):
    solve, transpose_solve = solvers
    transpose_solve_name = transpose_solve.__name__ if transpose_solve else None

    def _make_first_argument(rng):
      a = jtu.rand_default(rng)(shape, dtype)
      if symmetric:
        a = a + a.T
      return a

    define(
        lax.linear_solve_p,
        f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_b={jtu.format_shape_dtype_string(shape[:-1], dtype)}_solve={solve.__name__}_transposesolve={transpose_solve_name}_symmetric={symmetric}",
        linear_solve, [
            CustomArg(_make_first_argument),
            RandArg(shape[:-1], dtype),
            StaticArg(solve),
            StaticArg(transpose_solve),
            StaticArg(symmetric)
        ],
        shape=shape,
        dtype=dtype,
        solve=solve,
        transpose_solve=transpose_solve,
        symmetric=symmetric)

  for dtype in jtu.dtypes.all_floating:
    if dtype not in [np.float16, dtypes.bfloat16]:
      _make_harness("dtypes", dtype=dtype)
  # Validate the symmetric case
  _make_harness("symmetric", symmetric=True)
  # Validate omitting transpose_solve
  _make_harness("transpose_solve", solvers=(explicit_jacobian_solve, None))


_make_linear_solve_harnesses()


def _make_slice_harness(name,
                        shape=(3,),
                        start_indices=(1,),
                        limit_indices=(2,),
                        strides=None,
                        dtype=np.float32):
  define(
      lax.slice_p,
      f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}",  # type: ignore
      lax.slice,
      [
          RandArg(shape, dtype),  # type: ignore
          StaticArg(start_indices),  # type: ignore
          StaticArg(limit_indices),  # type: ignore
          StaticArg(strides)
      ],  # type: ignore
      dtype=dtype,
      shape=shape,  # type: ignore
      start_indices=start_indices,  # type: ignore
      limit_indices=limit_indices)  # type: ignore


# First test all dtypes
for dtype in jtu.dtypes.all:
  _make_slice_harness("dtypes", dtype=dtype)

# Now test many shapes
for shape, start_indices, limit_indices, strides in [
    ((3,), (1,), (2,), None),
    ((7,), (4,), (7,), None),
    ((5,), (1,), (5,), (2,)),
    ((8,), (1,), (6,), (2,)),
    ((5, 3), (1, 1), (3, 2), None),
    ((5, 3), (1, 1), (3, 1), None),
    ((7, 5, 3), (4, 0, 1), (7, 1, 3), None),
    ((5, 3), (1, 1), (2, 1), (1, 1)),
    ((5, 3), (1, 1), (5, 3), (2, 1)),
]:
  _make_slice_harness(
      "shapes",
      shape=shape,
      start_indices=start_indices,
      limit_indices=limit_indices,
      strides=strides)


def _make_complex_harness(name, *, shapes=((3, 4), (3, 4)), dtype=np.float32):
  define(
      lax.complex_p,
      f"{name}_lhs={jtu.format_shape_dtype_string(shapes[0], dtype)}_rhs={jtu.format_shape_dtype_string(shapes[1], dtype)}",
      lax.complex_p.bind,
      [RandArg(shapes[0], dtype),
       RandArg(shapes[1], dtype)],
      shapes=shapes,
      dtype=dtype)


for dtype in jtu.dtypes.floating:
  _make_complex_harness("dtypes", dtype=dtype)

# Validate broadcasting
for shapes in [
    ((3, 2), (3, 1)),  # broadcast imaginary part
    ((3, 1), (3, 2)),  # broadcast real part
]:
  _make_complex_harness("broadcast", shapes=shapes)


def _make_conj_harness(name, *, shape=(3, 4), dtype=np.float32, **kwargs):
  define(
      lax.conj_p,
      f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}_kwargs={kwargs}"
      .replace(" ", ""),
      lambda x: lax.conj_p.bind(x, **kwargs), [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      **kwargs)


for dtype in jtu.dtypes.floating + jtu.dtypes.complex:
  _make_conj_harness("dtypes", dtype=dtype)

# Validate kwargs
_make_conj_harness("kwargs", _input_dtype=np.float32)


def _make_real_imag_harness(prim, name, *, shape=(2, 3), dtype=np.float32):
  define(
      prim,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}",
      prim.bind, [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      prim=prim)


for dtype in jtu.dtypes.complex:
  for prim in [lax.real_p, lax.imag_p]:
    _make_real_imag_harness(prim, "dtypes", dtype=dtype)


def _make_dynamic_slice_harness(name,
                                shape=(3,),
                                start_indices=(1,),
                                limit_indices=(2,),
                                dtype=np.float32):
  define(
      lax.dynamic_slice_p,
      f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_start_indices={start_indices}_limit_indices={limit_indices}",  # type: ignore
      lax.dynamic_slice,
      [
          RandArg(shape, dtype),  # type: ignore
          np.array(list(start_indices)),
          StaticArg(tuple(map(operator.sub, limit_indices, start_indices)))
      ],  # type: ignore
      dtype=dtype,
      shape=shape,  # type: ignore
      start_indices=start_indices,  # type: ignore
      limit_indices=limit_indices)  # type: ignore


# First test all dtypes
for dtype in jtu.dtypes.all:
  _make_dynamic_slice_harness("dtypes", dtype=dtype)

# Now test many shapes
for shape, start_indices, limit_indices in [
    ((3,), (1,), (2,)),
    ((7,), (4,), (7,)),
    ((5,), (1,), (5,)),
    ((8,), (1,), (6,)),
    ((5, 3), (1, 1), (3, 2)),
    ((7, 5, 3), (4, 0, 1), (7, 1, 3)),
    ((5, 3), (1, 1), (2, 1)),
    # out-of-bounds cases, allowed for dynamic_slice
    ((5,), (-1,), (0,)),
    ((5,), (-1,), (1,)),
    ((5,), (-4,), (-2,)),
    ((5,), (-5,), (-2,)),
    ((5,), (-6,), (-5,)),
    ((5,), (-10,), (-9,)),
    ((5,), (-100,), (-99,)),
    ((5,), (5,), (6,)),
    ((5,), (10,), (11,)),
    ((5,), (3,), (6,)),
]:
  _make_dynamic_slice_harness(
      "shapes",
      shape=shape,
      start_indices=start_indices,
      limit_indices=limit_indices)
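
# Semantics sketch (not a harness): the out-of-bounds start indices above are
# legal because lax.dynamic_slice clamps them so the slice stays in bounds.
#   lax.dynamic_slice(np.arange(5.), np.array([3]), (3,))
#   # start index clamped from 3 to 2 -> [2., 3., 4.]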


def _make_dynamic_update_slice_harness(name,
                                       shape=(3,),
                                       start_indices=(1,),
                                       dtype=np.float32,
                                       update_shape=(1,)):
  define(
      lax.dynamic_update_slice_p,
      (
          f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}"  # type: ignore
          f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
          f"_start_indices={start_indices}"),
      lax.dynamic_update_slice,
      [
          RandArg(shape, dtype),  # type: ignore
          RandArg(update_shape, dtype),  # type: ignore
          np.array(start_indices)
      ],  # type: ignore
      dtype=dtype,
      shape=shape,  # type: ignore
      start_indices=start_indices,  # type: ignore
      update_shape=update_shape)  # type: ignore


# First test all dtypes
for dtype in jtu.dtypes.all:
  _make_dynamic_update_slice_harness("dtypes", dtype=dtype)

# Now test many shapes
for shape, start_indices, update_shape in [
    ((3,), (1,), (1,)),
    ((5, 3), (1, 1), (3, 1)),
    ((7, 5, 3), (4, 1, 0), (2, 0, 1)),
    ((3,), (-1,), (1,)),  # out-of-bounds
    ((3,), (10,), (1,)),  # out-of-bounds
    ((3,), (10,), (2,)),  # out-of-bounds
]:
  _make_dynamic_update_slice_harness(
      "shapes",
      shape=shape,
      start_indices=start_indices,
      update_shape=update_shape)
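
# Semantics sketch (not a harness): as for dynamic_slice, out-of-bounds start
# indices are clamped so that the update fits entirely inside the operand.
#   lax.dynamic_update_slice(np.zeros(3), np.ones(1), np.array([10]))
#   # start index clamped from 10 to 2 -> [0., 0., 1.]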


def _make_squeeze_harness(name, shape=(1, 2), dimensions=(0,), dtype=np.float32):
  define(
      lax.squeeze_p,
      f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_dimensions={dimensions}",  # type: ignore
      lax.squeeze,
      [RandArg(shape, dtype), StaticArg(dimensions)],  # type: ignore[has-type]
      dtype=dtype,
      arg_shape=shape,
      dimensions=dimensions)  # type: ignore[has-type]


# First test all dtypes
for dtype in jtu.dtypes.all:
  _make_squeeze_harness("dtypes", dtype=dtype)

# Now test many shapes
for shape, dimensions in [
    ((1,), (0,)),
    ((1,), (-1,)),
    ((2, 1, 4), (1,)),
    ((2, 1, 4), (-2,)),
    ((2, 1, 3, 1), (1,)),
    ((2, 1, 3, 1), (1, 3)),
    ((2, 1, 3, 1), (3,)),
    ((2, 1, 3, 1), (1, -1)),
]:
  _make_squeeze_harness("shapes", shape=shape, dimensions=dimensions)


def _make_select_and_scatter_add_harness(name,
                                         *,
                                         shape=(2, 4, 6),
                                         dtype=np.float32,
                                         select_prim=lax.ge_p,
                                         window_dimensions=(2, 2, 2),
                                         window_strides=(1, 1, 1),
                                         padding=((0, 0), (0, 0), (0, 0)),
                                         nb_inactive_dims=0):
  ones = (1,) * len(shape)
  cotangent_shape = jax.api.eval_shape(
      lambda x: lax._select_and_gather_add(x, x, lax.ge_p, window_dimensions,
                                           window_strides, padding, ones, ones),
      np.ones(shape, dtype)).shape
  define(
      lax.select_and_scatter_add_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}",
      lax._select_and_scatter_add, [
          RandArg(cotangent_shape, dtype),
          RandArg(shape, dtype),
          StaticArg(select_prim),
          StaticArg(window_dimensions),
          StaticArg(window_strides),
          StaticArg(padding)
      ],
      jax_unimplemented=[
          Limitation(
              "works only for 2 or more inactive dimensions",
              devices="tpu",
              enabled=(nb_inactive_dims < 2))
      ],
      shape=shape,
      dtype=dtype,
      select_prim=select_prim,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      padding=padding)
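
# Minimal sketch (not a harness) of the eval_shape trick used above: abstract
# evaluation yields the windowed-reduction output shape without running it.
# Mirrors the call in _make_select_and_scatter_add_harness for the defaults:
#   jax.api.eval_shape(
#       lambda x: lax._select_and_gather_add(x, x, lax.ge_p, (2, 2, 2),
#                                            (1, 1, 1), ((0, 0),) * 3,
#                                            (1, 1, 1), (1, 1, 1)),
#       np.ones((2, 4, 6), np.float32)).shape  # -> (1, 3, 5)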


for dtype in set(jtu.dtypes.all) - set([np.complex64, np.complex128]):
  _make_select_and_scatter_add_harness("dtypes", dtype=dtype)

# Validate a different selection primitive
_make_select_and_scatter_add_harness("select_prim", select_prim=lax.le_p)

# Validate padding
for padding in [
    # TODO(bchetioui): commented out the test based on
    # https://github.com/google/jax/issues/4690
    # ((1, 2), (2, 3), (3, 4))  # non-zero padding
    ((1, 1), (1, 1), (1, 1))  # non-zero padding
]:
  _make_select_and_scatter_add_harness("padding", padding=padding)

# Validate window_dimensions; uneven dimensions
_make_select_and_scatter_add_harness(
    "window_dimensions", window_dimensions=(1, 2, 3))

# Validate window_strides
# smaller than/same as/bigger than corresponding window dimension
_make_select_and_scatter_add_harness("window_strides", window_strides=(1, 2, 3))

# Validate dtypes on TPU
for dtype in set(jtu.dtypes.all) - set(
    [np.bool_, np.complex64, np.complex128, np.int8, np.uint8]):
  for window_strides, window_dimensions, nb_inactive_dims in [((1, 2, 1),
                                                               (1, 3, 1), 2)]:
    _make_select_and_scatter_add_harness(
        "tpu_dtypes",
        dtype=dtype,
        nb_inactive_dims=nb_inactive_dims,
        window_strides=window_strides,
        window_dimensions=window_dimensions)


def _make_select_and_gather_add_harness(name,
                                        *,
                                        shape=(4, 6),
                                        dtype=np.float32,
                                        select_prim=lax.le_p,
                                        padding="VALID",
                                        window_dimensions=(2, 2),
                                        window_strides=(1, 1),
                                        base_dilation=(1, 1),
                                        window_dilation=(1, 1)):
  if isinstance(padding, str):
    padding = tuple(
        lax.padtype_to_pads(shape, window_dimensions, window_strides, padding))
  define(
      lax.select_and_gather_add_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
      lax._select_and_gather_add, [
          RandArg(shape, dtype),
          RandArg(shape, dtype),
          StaticArg(select_prim),
          StaticArg(window_dimensions),
          StaticArg(window_strides),
          StaticArg(padding),
          StaticArg(base_dilation),
          StaticArg(window_dilation)
      ],
      shape=shape,
      dtype=dtype,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      padding=padding,
      base_dilation=base_dilation,
      window_dilation=window_dilation)


for dtype in jtu.dtypes.all_floating:
  for select_prim in [lax.ge_p, lax.le_p]:
    _make_select_and_gather_add_harness(
        "dtypes", dtype=dtype, select_prim=select_prim)

# Validate selection primitives
_make_select_and_gather_add_harness("select_prim", select_prim=lax.ge_p)

# Validate window dimensions
_make_select_and_gather_add_harness(
    "window_dimensions", window_dimensions=(2, 3))

# Validate window strides
_make_select_and_gather_add_harness("window_strides", window_strides=(2, 3))

# Validate padding
_make_select_and_gather_add_harness("padding", padding="SAME")

# Validate dilations
for base_dilation, window_dilation in [
    ((2, 3), (1, 1)),  # base dilation, no window dilation
    ((1, 1), (2, 3)),  # no base dilation, window dilation
    ((2, 3), (3, 2))  # base dilation, window dilation
]:
  _make_select_and_gather_add_harness(
      "dilations", base_dilation=base_dilation, window_dilation=window_dilation)
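
# Semantics sketch (not a harness) of the string-padding conversion done in
# _make_select_and_gather_add_harness above, for its default parameters:
#   lax.padtype_to_pads((4, 6), (2, 2), (1, 1), "VALID")  # -> [(0, 0), (0, 0)]
#   lax.padtype_to_pads((4, 6), (2, 2), (1, 1), "SAME")   # -> [(0, 1), (0, 1)]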


def _make_reduce_window_harness(name,
                                *,
                                shape=(4, 6),
                                base_dilation=(1, 1),
                                computation=lax.add,
                                window_dimensions=(2, 2),
                                window_dilation=(1, 1),
                                init_value=0,
                                window_strides=(1, 1),
                                dtype=np.float32,
                                padding=((0, 0), (0, 0))):
  prim_name = f"reduce_window_{computation.__name__}"
  limitations = []
  if computation.__name__ in ("max", "mul", "min"):
    limitations.append(
        Limitation(
            "unimplemented in XLA", devices="tpu", dtypes=[np.complex64]))
  define(
      prim_name,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_initvalue={init_value}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}"
      .replace(" ", ""),
      lax.reduce_window,
      [
          RandArg(shape, dtype),
          # Must be static to trigger the picking of the reducers
          StaticArg(np.array(init_value, dtype=dtype)),
          StaticArg(computation),
          StaticArg(window_dimensions),
          StaticArg(window_strides),
          StaticArg(padding),
          StaticArg(base_dilation),
          StaticArg(window_dilation)
      ],
      jax_unimplemented=limitations,
      shape=shape,
      dtype=dtype,
      init_value=np.array(init_value, dtype=dtype),
      computation=computation,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      padding=padding,
      base_dilation=base_dilation,
      window_dilation=window_dilation)


# Validate dtypes across all execution paths
# This first harness runs the tests for all dtypes using default values for
# the other parameters (outside of computation and its init_value), through
# several execution paths. Variations of other parameters can thus safely
# skip testing their corresponding default value.
for dtype in jtu.dtypes.all:
  for computation, init_value in [
      (lax.min, _get_min_identity(dtype)),  # path through reduce_window_min
      (lax.max, _get_max_identity(dtype)),  # path through TF reduce_window_max
      (lax.max, 1),  # path through reduce_window
      (lax.add, 0),  # path through reduce_window_sum
      (lax.add, 1),  # path through reduce_window
      (lax.mul, 0),  # path through reduce_window
      (lax.mul, 1),  # path through reduce_window
      (lax.mul, 2),  # path through reduce_window
  ]:
    if dtype == np.bool_ and computation in [lax.add, lax.mul]:
      continue
    _make_reduce_window_harness(
        "dtypes", dtype=dtype, computation=computation, init_value=init_value)

# Validate window_dimensions
_make_reduce_window_harness("window_dimensions", window_dimensions=(1, 1))
# Validate window_strides
_make_reduce_window_harness("window_strides", window_strides=(1, 2))
# Validate padding
_make_reduce_window_harness("padding", padding=((1, 2), (0, 3)))
# Validate base_dilation
_make_reduce_window_harness("base_dilation", base_dilation=(1, 2))
# Validate window_dilation
_make_reduce_window_harness("window_dilation", window_dilation=(1, 2))

# Validate squeezing behavior and dimensions in tf.nn.max_pool
for shape, window_dimensions in [
    ((2,), (2,)),  # 1 spatial dimension, left and right squeeze
    ((2, 1), (2, 1)),  # 1 spatial dimension, left squeeze
    ((1, 2), (1, 2)),  # 1 spatial dimension, right squeeze
    ((1, 2, 1), (1, 2, 1)),  # 1 spatial dimension, no squeeze
    ((2, 4), (2, 2)),  # 2 spatial dimensions, left and right squeeze
    ((2, 4, 3), (2, 2, 2)),  # 3 spatial dimensions, left and right squeeze
    ((1, 4, 3, 2, 1), (1, 2, 2, 2, 1))  # 3 spatial dimensions, no squeeze
]:
  _make_reduce_window_harness(
      "squeeze_dim",
      computation=lax.max,
      shape=shape,
      dtype=np.float32,
      init_value=-np.inf,
      base_dilation=tuple([1] * len(shape)),
      window_dilation=tuple([1] * len(shape)),
      padding=tuple([(0, 0)] * len(shape)),
      window_strides=tuple([1] * len(shape)),
      window_dimensions=window_dimensions)


def _make_reducer_harness(prim,
                          name,
                          *,
                          shape=(2, 3),
                          axes=(0,),
                          dtype=np.int32):
  define(
      prim,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}",
      lambda arg: prim.bind(arg, axes=axes), [RandArg(shape, dtype)],
      prim=prim,
      shape=shape,
      dtype=dtype,
      axes=axes)


for prim in [
    lax.reduce_sum_p, lax.reduce_prod_p, lax.reduce_max_p, lax.reduce_min_p,
    lax.reduce_or_p, lax.reduce_and_p
]:
  for dtype in {
      lax.reduce_sum_p: set(jtu.dtypes.all) - set(jtu.dtypes.boolean),
      lax.reduce_prod_p: set(jtu.dtypes.all) - set(jtu.dtypes.boolean),
      lax.reduce_max_p: jtu.dtypes.all,
      lax.reduce_min_p: jtu.dtypes.all,
      lax.reduce_or_p: jtu.dtypes.boolean,
      lax.reduce_and_p: jtu.dtypes.boolean
  }[prim]:
    _make_reducer_harness(prim, "dtypes", dtype=dtype)

for dtype in (np.float32, np.float64):
  for shape in ((), (3,)):
    define(
        "random_gamma",
        f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
        jax.jit(jax.random.gamma),
        [np.array([42, 43], dtype=np.uint32),
         RandArg(shape, dtype)],
        dtype=dtype)

for key_i, key in enumerate([
    np.array([0, 0], dtype=np.uint32),
    np.array([42, 43], dtype=np.uint32),
    np.array([0xFFFFFFFF, 0], dtype=np.uint32),
    np.array([0, 0xFFFFFFFF], dtype=np.uint32),
    np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)
]):
  define(
      "random_split",
      f"_i={key_i}",
      jax.jit(lambda key: jax.random.split(key, 2)), [key],
      dtype=key.dtype)
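
# Context sketch (not a harness): the raw uint32 pairs above stand in for PRNG
# keys in the format produced by jax.random.PRNGKey.
#   jax.random.PRNGKey(42)  # -> array([ 0, 42], dtype=uint32)
#   jax.random.split(jax.random.PRNGKey(42), 2).shape  # -> (2, 2)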


def _make_clamp_harness(name,
                        *,
                        min_shape=(),
                        operand_shape=(2, 3),
                        max_shape=(),
                        dtype=np.float32,
                        min_max=None):
  min_arr, max_arr = (
      min_max if min_max is not None else
      [RandArg(min_shape, dtype),
       RandArg(max_shape, dtype)])
  define(
      lax.clamp_p,
      f"{name}_min={jtu.format_shape_dtype_string(min_arr.shape, min_arr.dtype)}_operand={jtu.format_shape_dtype_string(operand_shape, dtype)}_max={jtu.format_shape_dtype_string(max_arr.shape, max_arr.dtype)}",
      lax.clamp, [min_arr, RandArg(operand_shape, dtype), max_arr],
      min_shape=min_arr.shape,
      operand_shape=operand_shape,
      max_shape=max_arr.shape,
      dtype=dtype)


for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex + [np.bool_]):
  _make_clamp_harness("dtypes", dtype=dtype)

# Validate broadcasting of the min/max arrays
for min_shape, operand_shape, max_shape in [
    ((), (2, 3), (2, 3)),  # broadcast min
    ((2, 3), (2, 3), ()),  # broadcast max
    ((2, 3), (2, 3), (2, 3)),  # no broadcasting
]:
  _make_clamp_harness(
      "broadcasting",
      min_shape=min_shape,
      max_shape=max_shape,
      operand_shape=operand_shape)

# Validate clamping when minval > maxval, and when minval < maxval
for is_ordered, min_arr, max_arr in [
    (False, np.array(4., dtype=np.float32), np.array(1., dtype=np.float32)),
    (True, np.array(1., dtype=np.float32), np.array(4., dtype=np.float32))
]:
  _make_clamp_harness(
      f"order={is_ordered}", min_max=(min_arr, max_arr), dtype=np.float32)


def _make_dot_general_harness(name,
                              *,
                              lhs_shape=(3, 4),
                              rhs_shape=(4, 2),
                              dtype=np.float32,
                              precision=None,
                              dimension_numbers=(((1,), (0,)), ((), ()))):
  define(
      lax.dot_general_p,
      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}_precision={precision}"
      .replace(" ", ""),
      lax.dot_general, [
          RandArg(lhs_shape, dtype),
          RandArg(rhs_shape, dtype),
          StaticArg(dimension_numbers),
          StaticArg(precision)
      ],
      dtype=dtype,
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dimension_numbers=dimension_numbers,
      precision=precision)
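
# Semantics sketch (not a harness): dimension_numbers is
# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)). The default
# above contracts lhs axis 1 with rhs axis 0, i.e. a plain matrix product.
#   lax.dot_general(np.ones((3, 4), np.float32), np.ones((4, 2), np.float32),
#                   (((1,), (0,)), ((), ()))).shape  # -> (3, 2)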


# There are two execution paths in the conversion of dot_general: the main
# path uses tf.einsum, while special cases use tf.linalg.matmul. For that
# reason, the tests below are designed to perform the same checks on both
# execution paths.

# Validate dtypes and precision
# This first harness runs the tests for all dtypes and precisions using
# default values for all the other parameters. Variations of other parameters
# can thus safely skip testing their corresponding default value.
for dtype in jtu.dtypes.all:
  for precision in [
      None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST
  ]:
    for lhs_shape, rhs_shape, dimension_numbers in [
        ((3, 4), (4, 2), (((1,), (0,)), ((), ()))),
        ((1, 3, 4), (1, 4, 3), (((2, 1), (1, 2)), ((0,), (0,))))
    ]:
      _make_dot_general_harness(
          "dtypes_and_precision",
          precision=precision,
          lhs_shape=lhs_shape,
          rhs_shape=rhs_shape,
          dimension_numbers=dimension_numbers,
          dtype=dtype)

# Validate batch dimensions
for lhs_shape, rhs_shape, dimension_numbers in [
    # Unique pattern that can go through tf.linalg.matmul
    ((4, 4, 3, 3, 4), (4, 4, 3, 4, 2), (((4,), (3,)), ((0, 1, 2), (0, 1, 2)))),
    # Main path with out-of-order batch dimensions
    ((8, 4, 3, 3, 4), (4, 8, 3, 4, 2), (((4, 3), (3, 2)), ((0, 1), (1, 0))))
]:
  _make_dot_general_harness(
      "batch_dimensions",
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dimension_numbers=dimension_numbers)

# Validate squeezing behavior for the matmul path
for lhs_shape, rhs_shape, dimension_numbers in [
    ((4,), (4, 4), (((0,), (0,)), ((), ()))),  # (1, 4) -> (4,)
    ((4, 4), (4,), (((1,), (0,)), ((), ()))),  # (4, 1) -> (4,)
    ((4,), (4,), (((0,), (0,)), ((), ()))),  # (1, 1) -> ()
]:
  _make_dot_general_harness(
      "squeeze",
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dimension_numbers=dimension_numbers)


def _make_concatenate_harness(name,
                              *,
                              shapes=[(2, 3), (2, 3)],
                              dimension=0,
                              dtype=np.float32):
  shapes_str = "_".join(jtu.format_shape_dtype_string(s, dtype) for s in shapes)
  define(
      lax.concatenate_p,
      f"{name}_shapes={shapes_str}_dimension={dimension}",
      lambda *args: lax.concatenate_p.bind(*args, dimension=dimension),
      [RandArg(shape, dtype) for shape in shapes],
      shapes=shapes,
      dtype=dtype,
      dimension=dimension)


for dtype in jtu.dtypes.all:
  _make_concatenate_harness("dtypes", dtype=dtype)

# Validate dimension; non-major axis
_make_concatenate_harness("dimension", dimension=1)

# Validate > 2 operands
for shapes in [
    [(2, 3, 4), (3, 3, 4), (4, 3, 4)],  # 3 operands
]:
  _make_concatenate_harness("nb_operands", shapes=shapes)


def _make_conv_harness(name,
                       *,
                       lhs_shape=(2, 3, 9, 10),
                       rhs_shape=(3, 3, 4, 5),
                       dtype=np.float32,
                       window_strides=(1, 1),
                       precision=None,
                       padding=((0, 0), (0, 0)),
                       lhs_dilation=(1, 1),
                       rhs_dilation=(1, 1),
                       feature_group_count=1,
                       dimension_numbers=("NCHW", "OIHW", "NCHW"),
                       batch_group_count=1,
                       enable_xla=True):
  define(
      lax.conv_general_dilated_p,
      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_windowstrides={window_strides}_padding={padding}_lhsdilation={lhs_dilation}_rhsdilation={rhs_dilation}_dimensionnumbers={dimension_numbers}_featuregroupcount={feature_group_count}_batchgroupcount={batch_group_count}_precision={precision}_enablexla={enable_xla}"
      .replace(" ", ""),
      lax.conv_general_dilated, [
          RandArg(lhs_shape, dtype),
          RandArg(rhs_shape, dtype),
          StaticArg(window_strides),
          StaticArg(padding),
          StaticArg(lhs_dilation),
          StaticArg(rhs_dilation),
          StaticArg(dimension_numbers),
          StaticArg(feature_group_count),
          StaticArg(batch_group_count),
          StaticArg(precision)
      ],
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dtype=dtype,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      precision=precision,
      enable_xla=enable_xla)
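
# Shape sketch (not a harness) for the default parameters above, assuming the
# standard convolution shape rule with unit strides and no padding/dilation:
#   lhs (2, 3, 9, 10) in NCHW conv rhs (3, 3, 4, 5) in OIHW
#   -> output (2, 3, 6, 6), since H: 9 - 4 + 1 and W: 10 - 5 + 1.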


# Validate dtypes and precision
# This first harness runs the tests for all dtypes and precisions using
# default values for all the other parameters. Variations of other parameters
# can thus safely skip testing their corresponding default value.
for dtype in jtu.dtypes.all_inexact:
  for precision in [
      None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST
  ]:
    _make_conv_harness("dtype_precision", dtype=dtype, precision=precision)

# Validate variations of feature_group_count and batch_group_count
for batch_group_count, feature_group_count in [
    (1, 2),  # feature_group_count != 1
    (2, 1),  # batch_group_count != 1
]:
  for lhs_shape, rhs_shape in [
      ((2 * batch_group_count, 3 * feature_group_count, 9, 10),
       (3 * feature_group_count * batch_group_count, 3, 4, 5))
  ]:
    _make_conv_harness(
        "group_counts",
        lhs_shape=lhs_shape,
        rhs_shape=rhs_shape,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count)

# Validate variations of window_strides
for window_strides in [(2, 3)]:
  _make_conv_harness("window_strides", window_strides=window_strides)

# Validate variations of padding
for padding in [
    ((1, 2), (0, 0)),  # padding only one spatial axis
    ((1, 2), (2, 1))  # padding on both spatial axes
]:
  _make_conv_harness("padding", padding=padding)

# Validate variations of dilations
for lhs_dilation, rhs_dilation in [
    ((2, 2), (1, 1)),  # dilation only on LHS (transposed)
    ((1, 1), (2, 3)),  # dilation only on RHS (atrous)
    ((2, 3), (3, 2))  # dilation on both LHS and RHS (transposed & atrous)
]:
  _make_conv_harness(
      "dilations", lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)

# Dimension numbers and corresponding permutation
for dimension_numbers, lhs_shape, rhs_shape in [
    (("NHWC", "HWIO", "NHWC"), (2, 9, 10, 3), (4, 5, 3, 3)),  # TF default
    (("NCHW", "HWIO", "NHWC"), (2, 3, 9, 10), (4, 5, 3, 3)),  # custom
]:
  _make_conv_harness(
      "dimension_numbers",
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dimension_numbers=dimension_numbers)

for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1,), (1,)),  # no dilation with "VALID" padding
    ("SAME", (1,), (1,)),  # no dilation with "SAME" padding
    ("VALID", (1,), (2,)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1,), (2,)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
]:
  for dimension_numbers, lhs_shape, rhs_shape in [
      (("NWC", "WIO", "NWC"), (1, 28, 1), (3, 1, 16)),  # TF default
      # TODO(bchetioui): the NCW data format is not supported on CPU for TF
      # for now. That path is thus disabled to allow the code to use XLA instead.
  ]:
    for enable_xla in [False, True]:
      _make_conv_harness(
          "tf_conversion_path_1d",
          lhs_shape=lhs_shape,
          padding=padding,
          rhs_shape=rhs_shape,
          dimension_numbers=dimension_numbers,
          window_strides=(1,),
          lhs_dilation=lhs_dilation,
          rhs_dilation=rhs_dilation,
          enable_xla=enable_xla)

for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1, 1), (1, 1)),  # no dilation with "VALID" padding
    ("SAME", (1, 1), (1, 1)),  # no dilation with "SAME" padding
    ("VALID", (1, 1), (1, 2)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1, 1), (1, 2)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
]:
  for dimension_numbers, lhs_shape, rhs_shape in [
      (("NHWC", "HWIO", "NHWC"), (1, 28, 28, 1), (3, 3, 1, 16)),  # TF default
      # TODO(bchetioui): the NCHW data format is not supported on CPU for TF
      # for now. That path is thus disabled to allow the code to use XLA instead.
  ]:
    for enable_xla in [False, True]:
      _make_conv_harness(
          "tf_conversion_path_2d",
          lhs_shape=lhs_shape,
          padding=padding,
          rhs_shape=rhs_shape,
          dimension_numbers=dimension_numbers,
          window_strides=(1, 1),
          lhs_dilation=lhs_dilation,
          rhs_dilation=rhs_dilation,
          enable_xla=enable_xla)

for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1, 1, 1), (1, 1, 1)),  # no dilation with "VALID" padding
    ("SAME", (1, 1, 1), (1, 1, 1)),  # no dilation with "SAME" padding
    ("VALID", (1, 1, 1), (1, 1, 2)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1, 1, 1), (1, 1, 2)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
]:
  for dimension_numbers, lhs_shape, rhs_shape in [
      # TF default
      (("NDHWC", "DHWIO", "NDHWC"), (1, 4, 28, 28, 1), (2, 3, 3, 1, 16)),
      # TODO(bchetioui): the NCDHW data format is not supported on CPU for TF
      # for now. That path is thus disabled to allow the code to use XLA instead.
  ]:
    for enable_xla in [False, True]:
      _make_conv_harness(
          "tf_conversion_path_3d",
          lhs_shape=lhs_shape,
          padding=padding,
          rhs_shape=rhs_shape,
          dimension_numbers=dimension_numbers,
          window_strides=(1, 1, 1),
          lhs_dilation=lhs_dilation,
          rhs_dilation=rhs_dilation,
          enable_xla=enable_xla)