# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
import functools
import re
import os
import textwrap
from typing import Dict, Sequence, Union
import unittest
import warnings
import zlib

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
import numpy.random as npr

from . import api
from . import core
from . import dtypes as _dtypes
from . import lax
from .config import flags, bool_env
from ._src.util import partial, prod
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from .lib import xla_bridge
from .interpreters import xla


FLAGS = flags.FLAGS
flags.DEFINE_enum(
    'jax_test_dut', '',
    enum_values=['', 'cpu', 'gpu', 'tpu'],
    help=
    'Describes the device under test in case special consideration is required.'
)

flags.DEFINE_integer(
  'num_generated_cases',
  int(os.getenv('JAX_NUM_GENERATED_CASES', 10)),
  help='Number of generated cases to test')

flags.DEFINE_bool(
  'jax_skip_slow_tests',
  bool_env('JAX_SKIP_SLOW_TESTS', False),
  help='Skip tests marked as slow (> 5 sec).'
)

# NOTE: JaxTestLoader applies these patterns with `re.search` against
# "TestClassName.test_method_name"; the help text documents exactly that.
flags.DEFINE_string(
  'test_targets', '',
  'Regular expression specifying which tests to run, called via re.search on '
  'the "TestClass.test_method" name. If empty or unspecified, run all tests.'
)
flags.DEFINE_string(
  'exclude_test_targets', '',
  'Regular expression specifying which tests NOT to run, called via re.search '
  'on the "TestClass.test_method" name. If empty or unspecified, run all '
  'tests.'
)

# Default finite-difference step size used by the gradient checkers below.
EPS = 1e-4


def _dtype(x):
  """Returns the NumPy dtype of ``x``.

  Uses ``x.dtype`` when ``x`` has one, then the dtype registered for ``type(x)``
  in ``_dtypes.python_scalar_dtypes`` (Python scalars), and otherwise falls
  back to ``np.asarray(x).dtype``.
  """
  # np.dtype objects are falsy (len() is their number of fields, 0 for plain
  # dtypes), so these lookups must be explicit `is not None` checks rather
  # than an `a or b or c` chain, which would always fall through to asarray.
  dtype = getattr(x, 'dtype', None)
  if dtype is not None:
    return dtype
  scalar_dtype = _dtypes.python_scalar_dtypes.get(type(x))
  if scalar_dtype is not None:
    return np.dtype(scalar_dtype)
  return np.asarray(x).dtype


def num_float_bits(dtype):
  """Returns the bit width of the canonicalized floating-point ``dtype``."""
  return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits


def is_sequence(x):
  """Returns True if ``iter(x)`` succeeds (note: strings and arrays count)."""
  try:
    iter(x)
  except TypeError:
    return False
  else:
    return True


# Per-dtype default tolerances; boolean and integer types get zero tolerance.
_default_tolerance = {
  np.dtype(np.bool_): 0,
  np.dtype(np.int8): 0,
  np.dtype(np.int16): 0,
  np.dtype(np.int32): 0,
  np.dtype(np.int64): 0,
  np.dtype(np.uint8): 0,
  np.dtype(np.uint16): 0,
  np.dtype(np.uint32): 0,
  np.dtype(np.uint64): 0,
  np.dtype(_dtypes.bfloat16): 1e-2,
  np.dtype(np.float16): 1e-3,
  np.dtype(np.float32): 1e-6,
  np.dtype(np.float64): 1e-15,
  np.dtype(np.complex64): 1e-6,
  np.dtype(np.complex128): 1e-15,
}


def default_tolerance():
  """Returns the default per-dtype tolerances for the device under test.

  On TPU, float32 and complex64 are loosened to 1e-3; other devices use
  ``_default_tolerance`` unchanged.
  """
  if device_under_test() != "tpu":
    return _default_tolerance
  tol = _default_tolerance.copy()
  tol[np.dtype(np.float32)] = 1e-3
  tol[np.dtype(np.complex64)] = 1e-3
  return tol


# Per-dtype default tolerances used by check_jvp / check_vjp.
default_gradient_tolerance = {
  np.dtype(_dtypes.bfloat16): 1e-1,
  np.dtype(np.float16): 1e-2,
  np.dtype(np.float32): 2e-3,
  np.dtype(np.float64): 1e-5,
  np.dtype(np.complex64): 1e-3,
  np.dtype(np.complex128): 1e-5,
}


def _assert_numpy_allclose(a, b, atol=None, rtol=None):
  """Calls ``np.testing.assert_allclose`` after upcasting bfloat16 to float32.

  Falsy ``atol``/``rtol`` (None or 0) are not forwarded, so NumPy's own
  defaults apply for them.
  """
  a = a.astype(np.float32) if a.dtype == _dtypes.bfloat16 else a
  b = b.astype(np.float32) if b.dtype == _dtypes.bfloat16 else b
  kw = {}
  if atol: kw["atol"] = atol
  if rtol: kw["rtol"] = rtol
  np.testing.assert_allclose(a, b, **kw)


def tolerance(dtype, tol=None):
  """Resolves the tolerance to use for ``dtype``.

  Args:
    dtype: the dtype being compared; it is canonicalized before lookup.
    tol: None (use the defaults), a non-dict value (returned unchanged), or a
      dict mapping dtypes to tolerances (missing dtypes use the defaults).
  """
  tol = {} if tol is None else tol
  if not isinstance(tol, dict):
    return tol
  tol = {np.dtype(key): value for key, value in tol.items()}
  dtype = _dtypes.canonicalize_dtype(np.dtype(dtype))
  return tol.get(dtype, default_tolerance()[dtype])


def _normalize_tolerance(tol):
  """Converts a falsy, scalar or dict tolerance into a dict keyed by np.dtype."""
  tol = tol or 0
  if isinstance(tol, dict):
    return {np.dtype(k): v for k, v in tol.items()}
  else:
    return {k: tol for k in _default_tolerance}


def join_tolerance(tol1, tol2):
  """Returns the per-dtype maximum of two tolerance specifications."""
  tol1 = _normalize_tolerance(tol1)
  tol2 = _normalize_tolerance(tol2)
  out = tol1  # a fresh dict from _normalize_tolerance, safe to mutate
  for k, v in tol2.items():
    out[k] = max(v, tol1.get(k, 0))
  return out


def _assert_numpy_close(a, b, atol=None, rtol=None):
  """Asserts that same-shaped arrays are close.

  The resolved per-dtype tolerances are multiplied by the element count.
  """
  assert a.shape == b.shape
  atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol))
  rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol))
  _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size)


def check_eq(xs, ys):
  """Asserts that two pytrees of arrays are allclose with NumPy defaults."""
  tree_all(tree_multimap(_assert_numpy_allclose, xs, ys))


def check_close(xs, ys, atol=None, rtol=None):
  """Asserts that two pytrees of arrays are close within the given tolerances."""
  assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol)
  tree_all(tree_multimap(assert_close, xs, ys))


def _check_dtypes_match(xs, ys):
  """Asserts matching leaf dtypes (canonicalized when x64 is disabled)."""
  def _assert_dtypes_match(x, y):
    if FLAGS.jax_enable_x64:
      assert _dtype(x) == _dtype(y)
    else:
      assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
              _dtypes.canonicalize_dtype(_dtype(y)))
  tree_all(tree_multimap(_assert_dtypes_match, xs, ys))


def inner_prod(xs, ys):
  """Returns the real part of the conjugated inner product of two pytrees."""
  def contract(x, y):
    return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1)))
  return tree_reduce(np.add, tree_multimap(contract, xs, ys))


def _safe_subtract(x, y, *, dtype):
  """Subtraction with `inf - inf == 0` semantics."""
  # Positions where x == y (including equal infinities) yield exactly 0; the
  # invalid-value warning from evaluating inf - inf is suppressed.
  with np.errstate(invalid='ignore'):
    return np.where(np.equal(x, y), np.array(0, dtype),
                    np.subtract(x, y, dtype=dtype))


# Leaf-wise pytree arithmetic; results take the dtype of the first operand.
add = partial(tree_multimap, lambda x, y: np.add(x, y, dtype=_dtype(x)))
sub = partial(tree_multimap, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
safe_sub = partial(tree_multimap,
                   lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))


def scalar_mul(xs, a):
  """Multiplies every leaf of ``xs`` by ``a``, keeping each leaf's dtype."""
  return tree_map(lambda x: np.multiply(x, a, dtype=_dtype(x)), xs)


def rand_like(rng, x):
  """Returns ``rng.randn`` samples with the shape and dtype of ``x``.

  For complex dtypes both the real and the imaginary parts are random.
  """
  shape = np.shape(x)
  dtype = _dtype(x)
  randn = lambda: np.asarray(rng.randn(*shape), dtype=dtype)
  if _dtypes.issubdtype(dtype, np.complexfloating):
    return randn() + dtype.type(1.0j) * randn()
  else:
    return randn()


def numerical_jvp(f, primals, tangents, eps=EPS):
  """Central finite-difference approximation of the JVP of ``f``."""
  delta = scalar_mul(tangents, eps)
  f_pos = f(*add(primals, delta))
  f_neg = f(*sub(primals, delta))
  return scalar_mul(safe_sub(f_pos, f_neg), 0.5 / eps)


def _merge_tolerance(tol, default):
  """Overlays a user tolerance (None, scalar or dict) on a default dict."""
  if tol is None:
    return default
  if not isinstance(tol, dict):
    return tol
  out = default.copy()
  for k, v in tol.items():
    out[np.dtype(k)] = v
  return out


def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS):
  """Checks ``f_jvp`` against ``f`` and a finite-difference JVP.

  The tangent is drawn from a fixed-seed RandomState, so the check is
  deterministic.
  """
  atol = _merge_tolerance(atol, default_gradient_tolerance)
  rtol = _merge_tolerance(rtol, default_gradient_tolerance)
  rng = np.random.RandomState(0)
  tangent = tree_map(partial(rand_like, rng), args)
  v_out, t_out = f_jvp(args, tangent)
  _check_dtypes_match(v_out, t_out)
  v_out_expected = f(*args)
  _check_dtypes_match(v_out, v_out_expected)
  t_out_expected = numerical_jvp(f, args, tangent, eps=eps)
  # In principle we should expect exact equality of v_out and v_out_expected,
  # but due to nondeterminism especially on GPU (e.g., due to convolution
  # autotuning) we only require "close".
  check_close(v_out, v_out_expected, atol=atol, rtol=rtol)
  check_close(t_out, t_out_expected, atol=atol, rtol=rtol)


def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS):
  """Checks ``f_vjp`` via the adjoint identity <t, vjp(c)> == <jvp(t), c>.

  The tangent t and cotangent c are drawn from a fixed-seed RandomState, and
  jvp(t) is computed with finite differences.
  """
  atol = _merge_tolerance(atol, default_gradient_tolerance)
  rtol = _merge_tolerance(rtol, default_gradient_tolerance)
  _rand_like = partial(rand_like, np.random.RandomState(0))
  v_out, vjpfun = f_vjp(*args)
  v_out_expected = f(*args)
  check_close(v_out, v_out_expected, atol=atol, rtol=rtol)
  tangent = tree_map(_rand_like, args)
  tangent_out = numerical_jvp(f, args, tangent, eps=eps)
  cotangent = tree_map(_rand_like, v_out)
  cotangent_out = conj(vjpfun(conj(cotangent)))
  ip = inner_prod(tangent, cotangent_out)
  ip_expected = inner_prod(tangent_out, cotangent)
  check_close(ip, ip_expected, atol=atol, rtol=rtol)


def check_grads(f, args, order,
                modes=("fwd", "rev"), atol=None, rtol=None, eps=None):
  """Check gradients from automatic differentiation against finite differences.

  Gradients are only checked in a single randomly chosen direction, which
  ensures that the finite difference calculation does not become prohibitively
  expensive even for large input/output spaces.

  Args:
    f: function to check at ``f(*args)``.
    args: tuple of argument values.
    order: forward and backwards gradients up to this order are checked.
    modes: sequence of gradient modes to check ('fwd' and/or 'rev').
    atol: absolute tolerance for gradient equality.
    rtol: relative tolerance for gradient equality.
    eps: step size used for finite differences (defaults to ``EPS``).

  Raises:
    AssertionError: if gradients do not match.
  """
  args = tuple(args)
  eps = eps or EPS

  _check_jvp = partial(check_jvp, atol=atol, rtol=rtol, eps=eps)
  _check_vjp = partial(check_vjp, atol=atol, rtol=rtol, eps=eps)

  def _check_grads(f, args, order):
    if "fwd" in modes:
      _check_jvp(f, partial(api.jvp, f), args)
      if order > 1:
        # Differentiate the JVP itself to check the next order.
        _check_grads(partial(api.jvp, f), (args, args), order - 1)

    if "rev" in modes:
      _check_vjp(f, partial(api.vjp, f), args)
      if order > 1:
        def f_vjp(*args):
          out_primal_py, vjp_py = api.vjp(f, *args)
          return vjp_py(out_primal_py)
        _check_grads(f_vjp, args, order - 1)

  _check_grads(f, args, order)


@contextmanager
def count_primitive_compiles():
  """Context manager yielding a one-element list counting primitive compiles."""
  xla.xla_primitive_callable.cache_clear()

  # We count how many times we call primitive_computation (which is called
  # inside xla_primitive_callable) instead of xla_primitive_callable so we don't
  # count cache hits.
  primitive_computation = xla.primitive_computation
  count = [0]

  def primitive_computation_and_count(*args, **kwargs):
    count[0] += 1
    return primitive_computation(*args, **kwargs)

  xla.primitive_computation = primitive_computation_and_count
  try:
    yield count
  finally:
    xla.primitive_computation = primitive_computation


@contextmanager
def count_jit_and_pmap_compiles():
  """Context manager yielding a one-element list counting jaxpr_subcomp calls."""
  # No need to clear any caches since we generally jit and pmap fresh callables
  # in tests.

  jaxpr_subcomp = xla.jaxpr_subcomp
  count = [0]

  def jaxpr_subcomp_and_count(*args, **kwargs):
    count[0] += 1
    return jaxpr_subcomp(*args, **kwargs)

  xla.jaxpr_subcomp = jaxpr_subcomp_and_count
  try:
    yield count
  finally:
    xla.jaxpr_subcomp = jaxpr_subcomp


@contextmanager
def assert_num_jit_and_pmap_compilations(times):
  """Asserts that the body triggers exactly ``times`` jit/pmap compilations."""
  with count_jit_and_pmap_compiles() as count:
    yield
  if count[0] != times:
    raise AssertionError(f"Expected exactly {times} XLA compilations, "
                         f"but executed {count[0]}")


def device_under_test():
  """Returns --jax_test_dut if set, else the default backend's platform."""
  return FLAGS.jax_test_dut or xla_bridge.get_backend().platform


def if_device_under_test(device_type: Union[str, Sequence[str]],
                         if_true, if_false):
  """Chooses `if_true` or `if_false` based on device_under_test."""
  if device_under_test() in ([device_type] if isinstance(device_type, str)
                             else device_type):
    return if_true
  else:
    return if_false


def supported_dtypes():
  """Returns the set of NumPy scalar types usable on the device under test.

  64-bit types are excluded on TPU and whenever jax_enable_x64 is off.
  """
  if device_under_test() == "tpu":
    types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
             np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64}
  else:
    types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64,
             _dtypes.bfloat16, np.float16, np.float32, np.float64,
             np.complex64, np.complex128}
  if not FLAGS.jax_enable_x64:
    types -= {np.uint64, np.int64, np.float64, np.complex128}
  return types


def skip_if_unsupported_type(dtype):
  """Raises unittest.SkipTest if ``dtype`` is not in supported_dtypes()."""
  dtype = np.dtype(dtype)
  if dtype.type not in supported_dtypes():
    raise unittest.SkipTest(
      f"Type {dtype.name} not supported on {device_under_test()}")


def skip_on_devices(*disabled_devices):
  """A decorator for test methods to skip the test on certain devices."""
  def skip(test_method):
    @functools.wraps(test_method)
    def test_method_wrapper(self, *args, **kwargs):
      device = device_under_test()
      if device in disabled_devices:
        test_name = getattr(test_method, '__name__', '[unknown test]')
        raise unittest.SkipTest(
          f"{test_name} not supported on {device.upper()}.")
      return test_method(self, *args, **kwargs)
    return test_method_wrapper
  return skip


def set_host_platform_device_count(nr_devices: int):
  """Requests ``nr_devices`` host devices via the XLA_FLAGS environment variable.

  The flag is appended only if the user has not already set a host device
  count; cached backends are cleared so a new CPU backend picks it up.

  Returns:
    A closure that undoes the operation.
  """
  prev_xla_flags = os.getenv("XLA_FLAGS")
  flags_str = prev_xla_flags or ""
  # Don't override user-specified device count, or other XLA flags.
  if "xla_force_host_platform_device_count" not in flags_str:
    os.environ["XLA_FLAGS"] = (flags_str +
                               f" --xla_force_host_platform_device_count={nr_devices}")
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()
  def undo():
    if prev_xla_flags is None:
      # Tolerate the variable having been removed by someone else meanwhile.
      os.environ.pop("XLA_FLAGS", None)
    else:
      os.environ["XLA_FLAGS"] = prev_xla_flags
    xla_bridge.get_backend.cache_clear()
  return undo


def skip_on_flag(flag_name, skip_value):
  """A decorator for test methods to skip the test when flags are set."""
  def skip(test_method):  # pylint: disable=missing-docstring
    @functools.wraps(test_method)
    def test_method_wrapper(self, *args, **kwargs):
      flag_value = getattr(FLAGS, flag_name)
      if flag_value == skip_value:
        test_name = getattr(test_method, '__name__', '[unknown test]')
        raise unittest.SkipTest(
          f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}")
      return test_method(self, *args, **kwargs)
    return test_method_wrapper
  return skip


def format_test_name_suffix(opname, shapes, dtypes):
  """Builds a test-name suffix like ``Opname_float32[2,3]_int32[]``."""
  arg_descriptions = (format_shape_dtype_string(shape, dtype)
                      for shape, dtype in zip(shapes, dtypes))
  return '{}_{}'.format(opname.capitalize(), '_'.join(arg_descriptions))
# We use special symbols, represented as singleton objects, to distinguish
# between NumPy scalars, Python scalars, and 0-D arrays.
class ScalarShape(object):
  """Base class for the scalar shape markers; has length 0 like ``()``."""
  def __len__(self): return 0
class _NumpyScalar(ScalarShape): pass
class _PythonScalar(ScalarShape): pass
NUMPY_SCALAR_SHAPE = _NumpyScalar()
PYTHON_SCALAR_SHAPE = _PythonScalar()


def _dims_of_shape(shape):
  """Converts `shape` to a tuple of dimensions."""
  if type(shape) in (list, tuple):
    return shape
  elif isinstance(shape, ScalarShape):
    return ()
  else:
    raise TypeError(type(shape))


def _cast_to_shape(value, shape, dtype):
  """Casts `value` to the correct Python type for `shape` and `dtype`."""
  if shape is NUMPY_SCALAR_SHAPE:
    # explicitly cast to NumPy scalar in case `value` is a Python scalar.
    return np.dtype(dtype).type(value)
  elif shape is PYTHON_SCALAR_SHAPE:
    # explicitly cast to Python scalar via https://stackoverflow.com/a/11389998
    return np.asarray(value).item()
  elif type(shape) in (list, tuple):
    assert np.shape(value) == tuple(shape)
    return value
  else:
    raise TypeError(type(shape))


def dtype_str(dtype):
  """Returns the NumPy name of ``dtype`` (e.g. ``'float32'``)."""
  return np.dtype(dtype).name


def format_shape_dtype_string(shape, dtype):
  """Formats a shape/dtype pair for test names, e.g. ``float32[2,3]``."""
  if isinstance(shape, np.ndarray):
    return f'{dtype_str(dtype)}[{shape}]'
  elif isinstance(shape, list):
    # Lists are unhashable; convert so the lru_cache'd helper can be used.
    shape = tuple(shape)
  return _format_shape_dtype_string(shape, dtype)


@functools.lru_cache(maxsize=64)
def _format_shape_dtype_string(shape, dtype):
  if shape is NUMPY_SCALAR_SHAPE:
    return dtype_str(dtype)
  elif shape is PYTHON_SCALAR_SHAPE:
    return 'py' + dtype_str(dtype)
  elif type(shape) is tuple:
    shapestr = ','.join(str(dim) for dim in shape)
    return '{}[{}]'.format(dtype_str(dtype), shapestr)
  elif type(shape) is int:
    return '{}[{},]'.format(dtype_str(dtype), shape)
  else:
    raise TypeError(type(shape))


def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
  """Produce random values given shape, dtype, scale, and post-processor.

  Args:
    rand: a function for producing random values of a given shape, e.g. a
      bound version of either np.RandomState.randn or np.RandomState.rand.
    shape: a shape value as a tuple of positive integers.
    dtype: a numpy dtype.
    scale: optional, a multiplicative scale for the random values (default 1).
    post: optional, a callable for post-processing the random values (default
      identity).

  Returns:
    An ndarray of the given shape and dtype using random values based on a call
    to rand but scaled, converted to the appropriate dtype, and post-processed.
  """
  r = lambda: np.asarray(scale * rand(*_dims_of_shape(shape)), dtype)
  if _dtypes.issubdtype(dtype, np.complexfloating):
    vals = r() + 1.0j * r()
  else:
    vals = r()
  return _cast_to_shape(np.asarray(post(vals), dtype), shape, dtype)


def rand_fullrange(rng, standardize_nans=False):
  """Random numbers that span the full range of available bits."""
  def gen(shape, dtype, post=lambda x: x):
    dtype = np.dtype(dtype)
    dims = _dims_of_shape(shape)
    # int(): np.prod(()) is the float 1.0, which `randint(size=...)` rejects.
    size = dtype.itemsize * int(np.prod(dims))
    # `randint`'s upper bound is exclusive; use max + 1 so every byte value,
    # including 0xFF, can be produced.
    vals = rng.randint(0, np.iinfo(np.uint8).max + 1, size=size,
                       dtype=np.uint8)
    # Reshape to the dims (not `shape`) so scalar shape markers work too.
    vals = post(vals).view(dtype).reshape(dims)
    # Non-standard NaNs cause errors in numpy equality assertions.
    if standardize_nans and np.issubdtype(dtype, np.floating):
      vals[np.isnan(vals)] = np.nan
    return _cast_to_shape(vals, shape, dtype)
  return gen


def rand_default(rng, scale=3):
  """Sampler of ``rng.randn`` values multiplied by ``scale``."""
  return partial(_rand_dtype, rng.randn, scale=scale)


def rand_nonzero(rng):
  """Sampler like rand_default(scale=3) with zeros replaced by 1."""
  post = lambda x: np.where(x == 0, np.array(1, dtype=x.dtype), x)
  return partial(_rand_dtype, rng.randn, scale=3, post=post)


def rand_positive(rng):
  """Sampler of ``2 * rng.rand`` values shifted up by 1 (floats in [1, 3))."""
  post = lambda x: x + 1
  return partial(_rand_dtype, rng.rand, scale=2, post=post)


def rand_small(rng):
  """Sampler of ``rng.randn`` values multiplied by 1e-3."""
  return partial(_rand_dtype, rng.randn, scale=1e-3)


def rand_not_small(rng, offset=10.):
  """Sampler of ``3 * rng.randn`` values pushed away from zero by ``offset``."""
  post = lambda x: x + np.where(x > 0, offset, -offset)
  return partial(_rand_dtype, rng.randn, scale=3., post=post)


def rand_small_positive(rng):
  """Sampler of ``rng.rand`` values multiplied by 2e-5."""
  return partial(_rand_dtype, rng.rand, scale=2e-5)


def rand_uniform(rng, low=0.0, high=1.0):
  """Sampler of ``rng.rand`` values mapped affinely onto [low, high)."""
  assert low < high
  post = lambda x: x * (high - low) + low
  return partial(_rand_dtype, rng.rand, post=post)


def rand_some_equal(rng):
  """Sampler of ``100 * rng.randn`` values, about half set to the first entry."""

  def post(x):
    x_ravel = x.ravel()
    if len(x_ravel) == 0:
      return x
    flips = rng.rand(*np.shape(x)) < 0.5
    return np.where(flips, x_ravel[0], x)

  return partial(_rand_dtype, rng.randn, scale=100., post=post)


def rand_some_inf(rng):
  """Return a random sampler that produces infinities in floating types."""
  base_rand = rand_default(rng)

  # TODO: Complex numbers are not correctly tested: complex dtypes fail the
  # np.floating check first, so the complex branch below is not reached.
  # The `if` blocks should be switched in order, and relevant tests fixed.
  def rand(shape, dtype):
    """The random sampler function."""
    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have inf
      return base_rand(shape, dtype)

    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    dims = _dims_of_shape(shape)
    posinf_flips = rng.rand(*dims) < 0.1
    neginf_flips = rng.rand(*dims) < 0.1

    vals = base_rand(shape, dtype)
    vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
    vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand


def rand_some_nan(rng):
  """Return a random sampler that produces nans in floating types."""
  base_rand = rand_default(rng)

  def rand(shape, dtype):
    """The random sampler function."""
    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have nan
      return base_rand(shape, dtype)

    dims = _dims_of_shape(shape)
    nan_flips = rng.rand(*dims) < 0.1

    vals = base_rand(shape, dtype)
    vals = np.where(nan_flips, np.array(np.nan, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand


def rand_some_inf_and_nan(rng):
  """Return a random sampler that produces infinities and nans in floating types."""
  base_rand = rand_default(rng)

  # TODO: Complex numbers are not correctly tested: complex dtypes fail the
  # np.floating check first, so the complex branch below is not reached.
  # The `if` blocks should be switched in order, and relevant tests fixed.
  def rand(shape, dtype):
    """The random sampler function."""
    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have inf
      return base_rand(shape, dtype)

    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    dims = _dims_of_shape(shape)
    posinf_flips = rng.rand(*dims) < 0.1
    neginf_flips = rng.rand(*dims) < 0.1
    nan_flips = rng.rand(*dims) < 0.1

    vals = base_rand(shape, dtype)
    vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
    vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)
    vals = np.where(nan_flips, np.array(np.nan, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand


# TODO(mattjj): doesn't handle complex types
def rand_some_zero(rng):
  """Return a random sampler that produces some zeros."""
  base_rand = rand_default(rng)

  def rand(shape, dtype):
    """The random sampler function."""
    dims = _dims_of_shape(shape)
    zeros = rng.rand(*dims) < 0.5

    vals = base_rand(shape, dtype)
    vals = np.where(zeros, np.array(0, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand


def rand_int(rng, low=0, high=None):
  """Sampler of ``rng.randint(low, high)`` integers.

  When ``low == 0`` and ``high`` is None, ``high`` defaults per call to the
  maximum of the requested integer dtype (a ValueError is raised for
  non-integer dtypes).
  """
  def fn(shape, dtype):
    # Resolve the bound per call: caching it in the closure (as `nonlocal`
    # did) would reuse the first dtype's maximum for later, narrower dtypes.
    upper = high
    if low == 0 and upper is None:
      if np.issubdtype(dtype, np.integer):
        upper = np.iinfo(dtype).max
      else:
        raise ValueError("rand_int requires an explicit `high` value for "
                         "non-integer types.")
    return rng.randint(low, high=upper, size=shape, dtype=dtype)
  return fn


def rand_unique_int(rng, high=None):
  """Sampler of distinct integers from ``range(high or prod(shape))``."""
  def fn(shape, dtype):
    return rng.choice(np.arange(high or prod(shape), dtype=dtype),
                      size=shape, replace=False)
  return fn


def rand_bool(rng):
  """Sampler of booleans that are True with probability about 0.5."""
  def generator(shape, dtype):
    return _cast_to_shape(rng.rand(*_dims_of_shape(shape)) < 0.5, shape, dtype)
  return generator


def check_raises(thunk, err_type, msg):
  """Asserts ``thunk()`` raises ``err_type`` whose message starts with ``msg``.

  Raises:
    AssertionError: if nothing is raised or the message does not match.
  """
  # The no-raise failure lives in `else`, outside the `try`, so it cannot be
  # swallowed when err_type is AssertionError (or a base class of it).
  try:
    thunk()
  except err_type as e:
    if not str(e).startswith(msg):
      raise AssertionError("\n{}\n\n{}\n".format(e, msg))
  else:
    raise AssertionError("Expected {} to be raised".format(err_type))


def check_raises_regexp(thunk, err_type, pattern):
  """Asserts ``thunk()`` raises ``err_type`` whose message re.match-es ``pattern``.

  Raises:
    AssertionError: if nothing is raised or the message does not match.
  """
  # See check_raises for why the no-raise failure is in the `else` clause.
  try:
    thunk()
  except err_type as e:
    if not re.match(pattern, str(e)):
      raise AssertionError("{}\n\n{}\n".format(e, pattern))
  else:
    raise AssertionError("Expected {} to be raised".format(err_type))


def iter_eqns(jaxpr):
  """Yields the equations of ``jaxpr`` and, recursively, of its subjaxprs."""
  # TODO(necula): why doesn't this search in params?
  for eqn in jaxpr.eqns:
    yield eqn
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from iter_eqns(subjaxpr)


def assert_dot_precision(expected_precision, fun, *args):
  """Asserts every dot_general traced from ``fun(*args)`` has the given precision."""
  jaxpr = api.make_jaxpr(fun)(*args)
  precisions = [eqn.params['precision'] for eqn in iter_eqns(jaxpr.jaxpr)
                if eqn.primitive == lax.dot_general_p]
  for precision in precisions:
    msg = "Unexpected precision: {} != {}".format(expected_precision, precision)
    assert precision == expected_precision, msg


# Cache of fixed-seed permutations, keyed by the number of cases.
_CACHED_INDICES: Dict[int, Sequence[int]] = {}


def cases_from_list(xs):
  """Returns up to FLAGS.num_generated_cases items of ``xs``, chosen pseudo-randomly."""
  xs = list(xs)
  n = len(xs)
  k = min(n, FLAGS.num_generated_cases)
  # Random sampling for every parameterized test is expensive. Do it once and
  # cache the result.
  indices = _CACHED_INDICES.get(n)
  if indices is None:
    rng = npr.RandomState(42)
    _CACHED_INDICES[n] = indices = rng.permutation(n)
  return [xs[i] for i in indices[:k]]


def cases_from_gens(*gens):
  """Yields ``(name, *values)`` test cases from generators at sizes 1, 3, 10."""
  sizes = [1, 3, 10]
  cases_per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1
  for size in sizes:
    for i in range(cases_per_size):
      yield ('_{}_{}'.format(size, i),) + tuple(gen(size) for gen in gens)


class JaxTestLoader(absltest.TestLoader):
  """Test loader honoring --test_targets / --exclude_test_targets.

  Both patterns are applied with ``re.search`` to "TestClassName.test_name".
  """

  def getTestCaseNames(self, testCaseClass):
    names = super().getTestCaseNames(testCaseClass)
    if FLAGS.test_targets:
      pattern = re.compile(FLAGS.test_targets)
      names = [name for name in names
               if pattern.search(f"{testCaseClass.__name__}.{name}")]
    if FLAGS.exclude_test_targets:
      pattern = re.compile(FLAGS.exclude_test_targets)
      names = [name for name in names
               if not pattern.search(f"{testCaseClass.__name__}.{name}")]
    return names


class JaxTestCase(parameterized.TestCase):
  """Base class for JAX tests including numerical checks and boilerplate."""

  # TODO(mattjj): this obscures the error messages from failures, figure out how
  # to re-enable it
  # def tearDown(self) -> None:
  #   assert core.reset_trace_state()

  def setUp(self):
    super().setUp()
    core.skip_checks = False
    # We use the adler32 hash for two reasons.
    # a) it is deterministic run to run, unlike hash() which is randomized.
    # b) it returns values in int32 range, which RandomState requires.
    self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))

  def rng(self):
    """Returns this test's RandomState, seeded from the test method name."""
    return self._rng

  def assertArraysEqual(self, x, y, *, check_dtypes=True):
    """Assert that x and y arrays are exactly equal."""
    if check_dtypes:
      self.assertDtypesMatch(x, y)
    np.testing.assert_array_equal(x, y)

  def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None,
                           rtol=None):
    """Assert that x and y are close (up to numerical tolerances)."""
    self.assertEqual(x.shape, y.shape)
    atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
    rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol))

    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol)

    if check_dtypes:
      self.assertDtypesMatch(x, y)

  def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
    """Asserts equal dtypes, canonicalized when x64 is disabled."""
    if not FLAGS.jax_enable_x64 and canonicalize_dtypes:
      self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x)),
                       _dtypes.canonicalize_dtype(_dtype(y)))
    else:
      self.assertEqual(_dtype(x), _dtype(y))

  def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None,
                     canonicalize_dtypes=True):
    """Assert that x and y, either arrays or nested tuples/lists, are close."""
    if isinstance(x, dict):
      self.assertIsInstance(y, dict)
      self.assertEqual(set(x.keys()), set(y.keys()))
      for k in x.keys():
        self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol,
                            rtol=rtol, canonicalize_dtypes=canonicalize_dtypes)
    elif is_sequence(x) and not hasattr(x, '__array__'):
      self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
      self.assertEqual(len(x), len(y))
      for x_elt, y_elt in zip(x, y):
        self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol,
                            rtol=rtol, canonicalize_dtypes=canonicalize_dtypes)
    elif hasattr(x, '__array__') or np.isscalar(x):
      self.assertTrue(hasattr(y, '__array__') or np.isscalar(y))
      if check_dtypes:
        self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes)
      x = np.asarray(x)
      y = np.asarray(y)
      self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol)
    elif x == y:
      return
    else:
      raise TypeError((type(x), type(y)))

  def assertMultiLineStrippedEqual(self, expected, what):
    """Asserts two strings are equal, after dedenting and stripping each line."""
    expected = textwrap.dedent(expected)
    what = textwrap.dedent(what)
    ignore_space_re = re.compile(r'\s*\n\s*')
    expected_clean = re.sub(ignore_space_re, '\n', expected.strip())
    what_clean = re.sub(ignore_space_re, '\n', what.strip())
    self.assertMultiLineEqual(expected_clean, what_clean,
                              msg="Found\n{}\nExpecting\n{}".format(what, expected))

  def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True,
                       rtol=None, atol=None):
    """Helper method for running JAX compilation and allclose assertions."""
    args = args_maker()

    def wrapped_fun(*args):
      # Fails if Python is re-traced when a cached compilation is expected.
      self.assertTrue(python_should_be_executing)
      return fun(*args)

    python_should_be_executing = True
    python_ans = fun(*args)

    python_shapes = tree_map(lambda x: np.shape(x), python_ans)
    np_shapes = tree_map(lambda x: np.shape(np.asarray(x)), python_ans)
    self.assertEqual(python_shapes, np_shapes)

    cache_misses = xla.xla_primitive_callable.cache_info().misses
    python_ans = fun(*args)
    self.assertEqual(
        cache_misses, xla.xla_primitive_callable.cache_info().misses,
        "Compilation detected during second call of {} in op-by-op "
        "mode.".format(fun))

    cfun = api.jit(wrapped_fun)
    python_should_be_executing = True
    monitored_ans = cfun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes,
                        atol=atol, rtol=rtol)
    self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                        atol=atol, rtol=rtol)

    args = args_maker()

    python_should_be_executing = True
    python_ans = fun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                        atol=atol, rtol=rtol)

  def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
                         check_dtypes=True, tol=None,
                         canonicalize_dtypes=True):
    """Asserts lax_op and numpy_reference_op agree on the same random args."""
    args = args_maker()
    lax_ans = lax_op(*args)
    numpy_ans = numpy_reference_op(*args)
    self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes,
                        atol=tol, rtol=tol,
                        canonicalize_dtypes=canonicalize_dtypes)


@contextmanager
def ignore_warning(**kw):
  """Context manager ignoring warnings that match ``warnings.filterwarnings(**kw)``."""
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", **kw)
    yield


class _cached_property:
  """Descriptor that computes ``method(obj)`` once and caches the result.

  NOTE: the value is stored on the descriptor itself, so it is shared by all
  instances of the owning class (fine for the single ``_LazyDtypes`` instance).
  """
  null = object()

  def __init__(self, method):
    self._method = method
    self._value = self.null

  def __get__(self, obj, cls):
    if self._value is self.null:
      self._value = self._method(obj)
    return self._value


class _LazyDtypes:
  """A class that unifies lists of supported dtypes.

  These could be module-level constants, but device_under_test() is not always
  known at import time, so we need to define these lists lazily.
  """
  def supported(self, dtypes):
    """Filters ``dtypes`` to supported_dtypes(), keeping its container type."""
    supported = supported_dtypes()
    return type(dtypes)(d for d in dtypes if d in supported)

  @_cached_property
  def floating(self):
    return self.supported([np.float32, np.float64])

  @_cached_property
  def all_floating(self):
    return self.supported([_dtypes.bfloat16, np.float16, np.float32, np.float64])

  @_cached_property
  def integer(self):
    return self.supported([np.int32, np.int64])

  @_cached_property
  def all_integer(self):
    return self.supported([np.int8, np.int16, np.int32, np.int64])

  @_cached_property
  def unsigned(self):
    return self.supported([np.uint32, np.uint64])

  @_cached_property
  def all_unsigned(self):
    return self.supported([np.uint8, np.uint16, np.uint32, np.uint64])

  @_cached_property
  def complex(self):
    return self.supported([np.complex64, np.complex128])

  @_cached_property
  def boolean(self):
    return self.supported([np.bool_])

  @_cached_property
  def inexact(self):
    return self.floating + self.complex

  @_cached_property
  def all_inexact(self):
    return self.all_floating + self.complex

  @_cached_property
  def numeric(self):
    return self.floating + self.integer + self.unsigned + self.complex

  @_cached_property
  def all(self):
    return (self.all_floating + self.all_integer + self.all_unsigned +
            self.complex + self.boolean)


dtypes = _LazyDtypes()