# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
import functools
import re
import os
import textwrap
from typing import Dict, Sequence, Union
import unittest
import warnings
import zlib

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
import numpy.random as npr

from . import api
from . import core
from . import dtypes as _dtypes
from . import lax
from .config import flags, bool_env
from ._src.util import partial, prod
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from .lib import xla_bridge
from .interpreters import xla


FLAGS = flags.FLAGS
flags.DEFINE_enum(
    'jax_test_dut', '',
    enum_values=['', 'cpu', 'gpu', 'tpu'],
    help=
    'Describes the device under test in case special consideration is required.'
)

flags.DEFINE_integer(
  'num_generated_cases',
  int(os.getenv('JAX_NUM_GENERATED_CASES', 10)),
  help='Number of generated cases to test')

flags.DEFINE_bool(
    'jax_skip_slow_tests',
    bool_env('JAX_SKIP_SLOW_TESTS', False),
    help='Skip tests marked as slow (> 5 sec).'
)

flags.DEFINE_string(
  'test_targets', '',
  'Regular expression specifying which tests to run, matched via re.search '
  'on the test name. If empty or unspecified, run all tests.'
)
flags.DEFINE_string(
  'exclude_test_targets', '',
  'Regular expression specifying which tests NOT to run, matched via '
  're.search on the test name. If empty or unspecified, no tests are '
  'excluded.'
)

EPS = 1e-4

def _dtype(x):
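  # Prefer an explicit `.dtype` attribute; otherwise fall back to the Python
  # scalar dtype mapping, and finally to NumPy's own inference.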
  return (getattr(x, 'dtype', None) or
          np.dtype(_dtypes.python_scalar_dtypes.get(type(x), None)) or
          np.asarray(x).dtype)


def num_float_bits(dtype):
  return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits


def is_sequence(x):
  try:
    iter(x)
  except TypeError:
    return False
  else:
    return True

_default_tolerance = {
  np.dtype(np.bool_): 0,
  np.dtype(np.int8): 0,
  np.dtype(np.int16): 0,
  np.dtype(np.int32): 0,
  np.dtype(np.int64): 0,
  np.dtype(np.uint8): 0,
  np.dtype(np.uint16): 0,
  np.dtype(np.uint32): 0,
  np.dtype(np.uint64): 0,
  np.dtype(_dtypes.bfloat16): 1e-2,
  np.dtype(np.float16): 1e-3,
  np.dtype(np.float32): 1e-6,
  np.dtype(np.float64): 1e-15,
  np.dtype(np.complex64): 1e-6,
  np.dtype(np.complex128): 1e-15,
}

def default_tolerance():
  if device_under_test() != "tpu":
    return _default_tolerance
  tol = _default_tolerance.copy()
  tol[np.dtype(np.float32)] = 1e-3
  tol[np.dtype(np.complex64)] = 1e-3
  return tol

default_gradient_tolerance = {
  np.dtype(_dtypes.bfloat16): 1e-1,
  np.dtype(np.float16): 1e-2,
  np.dtype(np.float32): 2e-3,
  np.dtype(np.float64): 1e-5,
  np.dtype(np.complex64): 1e-3,
  np.dtype(np.complex128): 1e-5,
}

def _assert_numpy_allclose(a, b, atol=None, rtol=None):
  a = a.astype(np.float32) if a.dtype == _dtypes.bfloat16 else a
  b = b.astype(np.float32) if b.dtype == _dtypes.bfloat16 else b
  kw = {}
  if atol: kw["atol"] = atol
  if rtol: kw["rtol"] = rtol
  np.testing.assert_allclose(a, b, **kw)

def tolerance(dtype, tol=None):
  tol = {} if tol is None else tol
  if not isinstance(tol, dict):
    return tol
  tol = {np.dtype(key): value for key, value in tol.items()}
  dtype = _dtypes.canonicalize_dtype(np.dtype(dtype))
  return tol.get(dtype, default_tolerance()[dtype])

def _normalize_tolerance(tol):
  tol = tol or 0
  if isinstance(tol, dict):
    return {np.dtype(k): v for k, v in tol.items()}
  else:
    return {k: tol for k in _default_tolerance}

def join_tolerance(tol1, tol2):
  tol1 = _normalize_tolerance(tol1)
  tol2 = _normalize_tolerance(tol2)
  out = tol1
  for k, v in tol2.items():
    out[k] = max(v, tol1.get(k, 0))
  return out
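
# Illustrative example: join_tolerance(1e-3, {np.float32: 1e-2}) produces a
# per-dtype dict whose float32 entry is 1e-2 and whose other entries are 1e-3
# (the elementwise maximum of the two normalized tolerance maps).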

def _assert_numpy_close(a, b, atol=None, rtol=None):
  assert a.shape == b.shape
  atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol))
  rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol))
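  # Scale the tolerances by array size, allowing elementwise error to
  # accumulate over large arrays.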
  _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size)


def check_eq(xs, ys):
  tree_all(tree_multimap(_assert_numpy_allclose, xs, ys))


def check_close(xs, ys, atol=None, rtol=None):
  assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol)
  tree_all(tree_multimap(assert_close, xs, ys))

def _check_dtypes_match(xs, ys):
  def _assert_dtypes_match(x, y):
    if FLAGS.jax_enable_x64:
      assert _dtype(x) == _dtype(y)
    else:
      assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
              _dtypes.canonicalize_dtype(_dtype(y)))
  tree_all(tree_multimap(_assert_dtypes_match, xs, ys))


def inner_prod(xs, ys):
  def contract(x, y):
    return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1)))
  return tree_reduce(np.add, tree_multimap(contract, xs, ys))


def _safe_subtract(x, y, *, dtype):
  """Subtraction with `inf - inf == 0` semantics."""
  with np.errstate(invalid='ignore'):
    return np.where(np.equal(x, y), np.array(0, dtype),
                    np.subtract(x, y, dtype=dtype))

add = partial(tree_multimap, lambda x, y: np.add(x, y, dtype=_dtype(x)))
sub = partial(tree_multimap, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
safe_sub = partial(tree_multimap,
                   lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))

def scalar_mul(xs, a):
  return tree_map(lambda x: np.multiply(x, a, dtype=_dtype(x)), xs)


def rand_like(rng, x):
  shape = np.shape(x)
  dtype = _dtype(x)
  randn = lambda: np.asarray(rng.randn(*shape), dtype=dtype)
  if _dtypes.issubdtype(dtype, np.complexfloating):
    return randn() + dtype.type(1.0j) * randn()
  else:
    return randn()


def numerical_jvp(f, primals, tangents, eps=EPS):
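  # Central difference approximation of the JVP:
  # (f(x + eps*t) - f(x - eps*t)) / (2 * eps), second-order accurate in eps.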
  delta = scalar_mul(tangents, eps)
  f_pos = f(*add(primals, delta))
  f_neg = f(*sub(primals, delta))
  return scalar_mul(safe_sub(f_pos, f_neg), 0.5 / eps)


def _merge_tolerance(tol, default):
  if tol is None:
    return default
  if not isinstance(tol, dict):
    return tol
  out = default.copy()
  for k, v in tol.items():
    out[np.dtype(k)] = v
  return out

def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS):
  atol = _merge_tolerance(atol, default_gradient_tolerance)
  rtol = _merge_tolerance(rtol, default_gradient_tolerance)
  rng = np.random.RandomState(0)
  tangent = tree_map(partial(rand_like, rng), args)
  v_out, t_out = f_jvp(args, tangent)
  _check_dtypes_match(v_out, t_out)
  v_out_expected = f(*args)
  _check_dtypes_match(v_out, v_out_expected)
  t_out_expected = numerical_jvp(f, args, tangent, eps=eps)
  # In principle we should expect exact equality of v_out and v_out_expected,
  # but due to nondeterminism especially on GPU (e.g., due to convolution
  # autotuning) we only require "close".
  check_close(v_out, v_out_expected, atol=atol, rtol=rtol)
  check_close(t_out, t_out_expected, atol=atol, rtol=rtol)


def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS):
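  # Rather than comparing full Jacobians, check the adjoint identity
  # <J t, c> == <t, J^T c> for a random tangent t and cotangent c, with the
  # left-hand side estimated by finite differences.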
  atol = _merge_tolerance(atol, default_gradient_tolerance)
  rtol = _merge_tolerance(rtol, default_gradient_tolerance)
  _rand_like = partial(rand_like, np.random.RandomState(0))
  v_out, vjpfun = f_vjp(*args)
  v_out_expected = f(*args)
  check_close(v_out, v_out_expected, atol=atol, rtol=rtol)
  tangent = tree_map(_rand_like, args)
  tangent_out = numerical_jvp(f, args, tangent, eps=eps)
  cotangent = tree_map(_rand_like, v_out)
  cotangent_out = conj(vjpfun(conj(cotangent)))
  ip = inner_prod(tangent, cotangent_out)
  ip_expected = inner_prod(tangent_out, cotangent)
  check_close(ip, ip_expected, atol=atol, rtol=rtol)


def check_grads(f, args, order,
                modes=["fwd", "rev"], atol=None, rtol=None, eps=None):
  """Check gradients from automatic differentiation against finite differences.

  Gradients are only checked in a single randomly chosen direction, which
  ensures that the finite difference calculation does not become prohibitively
  expensive even for large input/output spaces.

  Args:
    f: function to check at ``f(*args)``.
    args: tuple of argument values.
    order: forward and reverse gradients up to this order are checked.
    modes: list of gradient modes to check ('fwd' and/or 'rev').
    atol: absolute tolerance for gradient equality.
    rtol: relative tolerance for gradient equality.
    eps: step size used for finite differences.

  Raises:
    AssertionError: if gradients do not match.
  """
  args = tuple(args)
  eps = eps or EPS

  _check_jvp = partial(check_jvp, atol=atol, rtol=rtol, eps=eps)
  _check_vjp = partial(check_vjp, atol=atol, rtol=rtol, eps=eps)

  def _check_grads(f, args, order):
    if "fwd" in modes:
      _check_jvp(f, partial(api.jvp, f), args)
      if order > 1:
        _check_grads(partial(api.jvp, f), (args, args), order - 1)

    if "rev" in modes:
      _check_vjp(f, partial(api.vjp, f), args)
      if order > 1:
        def f_vjp(*args):
          out_primal_py, vjp_py = api.vjp(f, *args)
          return vjp_py(out_primal_py)
        _check_grads(f_vjp, args, order - 1)

  _check_grads(f, args, order)

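# Illustrative usage (a sketch, with `f` any JAX-traceable function):
#
#   check_grads(f, (np.ones((3,), np.float32),), order=2, modes=["rev"])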

@contextmanager
def count_primitive_compiles():
  xla.xla_primitive_callable.cache_clear()

  # We count how many times we call primitive_computation (which is called
  # inside xla_primitive_callable) instead of xla_primitive_callable so we don't
  # count cache hits.
  primitive_computation = xla.primitive_computation
  count = [0]

  def primitive_computation_and_count(*args, **kwargs):
    count[0] += 1
    return primitive_computation(*args, **kwargs)

  xla.primitive_computation = primitive_computation_and_count
  try:
    yield count
  finally:
    xla.primitive_computation = primitive_computation


@contextmanager
def count_jit_and_pmap_compiles():
  # No need to clear any caches since we generally jit and pmap fresh callables
  # in tests.

  jaxpr_subcomp = xla.jaxpr_subcomp
  count = [0]

  def jaxpr_subcomp_and_count(*args, **kwargs):
    count[0] += 1
    return jaxpr_subcomp(*args, **kwargs)

  xla.jaxpr_subcomp = jaxpr_subcomp_and_count
  try:
    yield count
  finally:
    xla.jaxpr_subcomp = jaxpr_subcomp
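
# Illustrative usage (a sketch; the counter tracks calls to xla.jaxpr_subcomp,
# so one freshly jit-compiled call is expected to bump it once):
#
#   with count_jit_and_pmap_compiles() as count:
#     api.jit(lambda x: x + 1)(1.0)
#   assert count[0] == 1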

@contextmanager
def assert_num_jit_and_pmap_compilations(times):
  with count_jit_and_pmap_compiles() as count:
    yield
  if count[0] != times:
    raise AssertionError(f"Expected exactly {times} XLA compilations, "
                         f"but executed {count[0]}")

def device_under_test():
  return FLAGS.jax_test_dut or xla_bridge.get_backend().platform

def if_device_under_test(device_type: Union[str, Sequence[str]],
                         if_true, if_false):
  """Chooses `if_true` or `if_false` based on device_under_test()."""
  if device_under_test() in ([device_type] if isinstance(device_type, str)
                             else device_type):
    return if_true
  else:
    return if_false

def supported_dtypes():
  if device_under_test() == "tpu":
    types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
             np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64}
  else:
    types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64,
             _dtypes.bfloat16, np.float16, np.float32, np.float64,
             np.complex64, np.complex128}
  if not FLAGS.jax_enable_x64:
    types -= {np.uint64, np.int64, np.float64, np.complex128}
  return types

def skip_if_unsupported_type(dtype):
  dtype = np.dtype(dtype)
  if dtype.type not in supported_dtypes():
    raise unittest.SkipTest(
      f"Type {dtype.name} not supported on {device_under_test()}")

def skip_on_devices(*disabled_devices):
  """A decorator for test methods to skip the test on certain devices."""
  def skip(test_method):
    @functools.wraps(test_method)
    def test_method_wrapper(self, *args, **kwargs):
      device = device_under_test()
      if device in disabled_devices:
        test_name = getattr(test_method, '__name__', '[unknown test]')
        raise unittest.SkipTest(
          f"{test_name} not supported on {device.upper()}.")
      return test_method(self, *args, **kwargs)
    return test_method_wrapper
  return skip
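
# Illustrative usage:
#
#   @skip_on_devices("gpu", "tpu")
#   def testCpuOnlyFeature(self):
#     ...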

def set_host_platform_device_count(nr_devices: int):
  """Sets XLA_FLAGS so the host (CPU) platform exposes `nr_devices` devices.

  Returns a closure that undoes the operation.
  """
  prev_xla_flags = os.getenv("XLA_FLAGS")
  flags_str = prev_xla_flags or ""
  # Don't override user-specified device count, or other XLA flags.
  if "xla_force_host_platform_device_count" not in flags_str:
    os.environ["XLA_FLAGS"] = (flags_str +
                               f" --xla_force_host_platform_device_count={nr_devices}")
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()
  def undo():
    if prev_xla_flags is None:
      del os.environ["XLA_FLAGS"]
    else:
      os.environ["XLA_FLAGS"] = prev_xla_flags
    xla_bridge.get_backend.cache_clear()
  return undo
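
# Illustrative usage (a sketch; must run before the backend is first used):
#
#   undo = set_host_platform_device_count(4)
#   try:
#     ...  # tests that expect 4 host devices
#   finally:
#     undo()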

def skip_on_flag(flag_name, skip_value):
  """A decorator for test methods to skip the test when flags are set."""
  def skip(test_method):        # pylint: disable=missing-docstring
    @functools.wraps(test_method)
    def test_method_wrapper(self, *args, **kwargs):
      flag_value = getattr(FLAGS, flag_name)
      if flag_value == skip_value:
        test_name = getattr(test_method, '__name__', '[unknown test]')
        raise unittest.SkipTest(
          f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}")
      return test_method(self, *args, **kwargs)
    return test_method_wrapper
  return skip


def format_test_name_suffix(opname, shapes, dtypes):
  arg_descriptions = (format_shape_dtype_string(shape, dtype)
                      for shape, dtype in zip(shapes, dtypes))
  return '{}_{}'.format(opname.capitalize(), '_'.join(arg_descriptions))


# We use special symbols, represented as singleton objects, to distinguish
# between NumPy scalars, Python scalars, and 0-D arrays.
class ScalarShape(object):
  def __len__(self): return 0
class _NumpyScalar(ScalarShape): pass
class _PythonScalar(ScalarShape): pass
NUMPY_SCALAR_SHAPE = _NumpyScalar()
PYTHON_SCALAR_SHAPE = _PythonScalar()


def _dims_of_shape(shape):
  """Converts `shape` to a tuple of dimensions."""
  if type(shape) in (list, tuple):
    return shape
  elif isinstance(shape, ScalarShape):
    return ()
  else:
    raise TypeError(type(shape))


def _cast_to_shape(value, shape, dtype):
  """Casts `value` to the correct Python type for `shape` and `dtype`."""
  if shape is NUMPY_SCALAR_SHAPE:
    # explicitly cast to NumPy scalar in case `value` is a Python scalar.
    return np.dtype(dtype).type(value)
  elif shape is PYTHON_SCALAR_SHAPE:
    # explicitly cast to Python scalar via https://stackoverflow.com/a/11389998
    return np.asarray(value).item()
  elif type(shape) in (list, tuple):
    assert np.shape(value) == tuple(shape)
    return value
  else:
    raise TypeError(type(shape))


def dtype_str(dtype):
  return np.dtype(dtype).name

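# Illustrative outputs: format_shape_dtype_string((2, 3), np.float32) gives
# 'float32[2,3]'; NUMPY_SCALAR_SHAPE gives 'float32'; PYTHON_SCALAR_SHAPE
# gives 'pyfloat32'.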
def format_shape_dtype_string(shape, dtype):
  if isinstance(shape, np.ndarray):
    return f'{dtype_str(dtype)}[{shape}]'
  elif isinstance(shape, list):
    shape = tuple(shape)
  return _format_shape_dtype_string(shape, dtype)

@functools.lru_cache(maxsize=64)
def _format_shape_dtype_string(shape, dtype):
  if shape is NUMPY_SCALAR_SHAPE:
    return dtype_str(dtype)
  elif shape is PYTHON_SCALAR_SHAPE:
    return 'py' + dtype_str(dtype)
  elif type(shape) is tuple:
    shapestr = ','.join(str(dim) for dim in shape)
    return '{}[{}]'.format(dtype_str(dtype), shapestr)
  elif type(shape) is int:
    return '{}[{},]'.format(dtype_str(dtype), shape)
  else:
    raise TypeError(type(shape))


def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
  """Produce random values given shape, dtype, scale, and post-processor.

  Args:
    rand: a function for producing random values of a given shape, e.g. a
      bound version of either np.random.RandomState.randn or
      np.random.RandomState.rand.
    shape: a shape value as a tuple of positive integers.
    dtype: a numpy dtype.
    scale: optional, a multiplicative scale for the random values (default 1).
    post: optional, a callable for post-processing the random values (default
      identity).

  Returns:
    An ndarray of the given shape and dtype using random values based on a call
    to rand but scaled, converted to the appropriate dtype, and post-processed.
  """
  r = lambda: np.asarray(scale * rand(*_dims_of_shape(shape)), dtype)
  if _dtypes.issubdtype(dtype, np.complexfloating):
    vals = r() + 1.0j * r()
  else:
    vals = r()
  return _cast_to_shape(np.asarray(post(vals), dtype), shape, dtype)


def rand_fullrange(rng, standardize_nans=False):
  """Random numbers that span the full range of available bits."""
  def gen(shape, dtype, post=lambda x: x):
    dtype = np.dtype(dtype)
    size = dtype.itemsize * np.prod(_dims_of_shape(shape))
    vals = rng.randint(0, np.iinfo(np.uint8).max, size=size, dtype=np.uint8)
    vals = post(vals).view(dtype).reshape(shape)
    # Non-standard NaNs cause errors in numpy equality assertions.
    if standardize_nans and np.issubdtype(dtype, np.floating):
      vals[np.isnan(vals)] = np.nan
    return _cast_to_shape(vals, shape, dtype)
  return gen


def rand_default(rng, scale=3):
  return partial(_rand_dtype, rng.randn, scale=scale)


def rand_nonzero(rng):
  post = lambda x: np.where(x == 0, np.array(1, dtype=x.dtype), x)
  return partial(_rand_dtype, rng.randn, scale=3, post=post)


def rand_positive(rng):
  post = lambda x: x + 1
  return partial(_rand_dtype, rng.rand, scale=2, post=post)


def rand_small(rng):
  return partial(_rand_dtype, rng.randn, scale=1e-3)


def rand_not_small(rng, offset=10.):
  post = lambda x: x + np.where(x > 0, offset, -offset)
  return partial(_rand_dtype, rng.randn, scale=3., post=post)


def rand_small_positive(rng):
  return partial(_rand_dtype, rng.rand, scale=2e-5)

def rand_uniform(rng, low=0.0, high=1.0):
  assert low < high
  post = lambda x: x * (high - low) + low
  return partial(_rand_dtype, rng.rand, post=post)


def rand_some_equal(rng):

  def post(x):
    x_ravel = x.ravel()
    if len(x_ravel) == 0:
      return x
    flips = rng.rand(*np.shape(x)) < 0.5
    return np.where(flips, x_ravel[0], x)

  return partial(_rand_dtype, rng.randn, scale=100., post=post)


def rand_some_inf(rng):
  """Return a random sampler that produces infinities in floating types."""
  base_rand = rand_default(rng)

  # TODO: Complex numbers are not correctly tested: the floating-point check
  # below short-circuits before the complex branch is reached. The if-blocks
  # should be switched in order, and the relevant tests fixed.
  def rand(shape, dtype):
    """The random sampler function."""
    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have inf
      return base_rand(shape, dtype)

    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    dims = _dims_of_shape(shape)
    posinf_flips = rng.rand(*dims) < 0.1
    neginf_flips = rng.rand(*dims) < 0.1

    vals = base_rand(shape, dtype)
    vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
    vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand

def rand_some_nan(rng):
  """Return a random sampler that produces nans in floating types."""
  base_rand = rand_default(rng)

  def rand(shape, dtype):
    """The random sampler function."""
    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have nan
      return base_rand(shape, dtype)

    dims = _dims_of_shape(shape)
    nan_flips = rng.rand(*dims) < 0.1

    vals = base_rand(shape, dtype)
    vals = np.where(nan_flips, np.array(np.nan, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand

def rand_some_inf_and_nan(rng):
  """Return a random sampler producing infinities and nans in floating types."""
  base_rand = rand_default(rng)

  # TODO: Complex numbers are not correctly tested: the floating-point check
  # below short-circuits before the complex branch is reached. The if-blocks
  # should be switched in order, and the relevant tests fixed.
  def rand(shape, dtype):
    """The random sampler function."""
    if not _dtypes.issubdtype(dtype, np.floating):
      # only float types have inf and nan
      return base_rand(shape, dtype)

    if _dtypes.issubdtype(dtype, np.complexfloating):
      base_dtype = np.real(np.array(0, dtype=dtype)).dtype
      out = (rand(shape, base_dtype) +
             np.array(1j, dtype) * rand(shape, base_dtype))
      return _cast_to_shape(out, shape, dtype)

    dims = _dims_of_shape(shape)
    posinf_flips = rng.rand(*dims) < 0.1
    neginf_flips = rng.rand(*dims) < 0.1
    nan_flips = rng.rand(*dims) < 0.1

    vals = base_rand(shape, dtype)
    vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
    vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)
    vals = np.where(nan_flips, np.array(np.nan, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand

# TODO(mattjj): doesn't handle complex types
def rand_some_zero(rng):
  """Return a random sampler that produces some zeros."""
  base_rand = rand_default(rng)

  def rand(shape, dtype):
    """The random sampler function."""
    dims = _dims_of_shape(shape)
    zeros = rng.rand(*dims) < 0.5

    vals = base_rand(shape, dtype)
    vals = np.where(zeros, np.array(0, dtype=dtype), vals)

    return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)

  return rand


def rand_int(rng, low=0, high=None):
  def fn(shape, dtype):
    # Compute a dtype-dependent default locally rather than via `nonlocal`,
    # so a `high` chosen for one dtype does not leak into later calls.
    gen_high = high
    if low == 0 and gen_high is None:
      if np.issubdtype(dtype, np.integer):
        gen_high = np.iinfo(dtype).max
      else:
        raise ValueError("rand_int requires an explicit `high` value for "
                         "non-integer types.")
    return rng.randint(low, high=gen_high, size=shape, dtype=dtype)
  return fn

def rand_unique_int(rng, high=None):
  def fn(shape, dtype):
    return rng.choice(np.arange(high or prod(shape), dtype=dtype),
                      size=shape, replace=False)
  return fn

def rand_bool(rng):
  def generator(shape, dtype):
    return _cast_to_shape(rng.rand(*_dims_of_shape(shape)) < 0.5, shape, dtype)
  return generator

def check_raises(thunk, err_type, msg):
  try:
    thunk()
    assert False, f"Expected {err_type.__name__} to be raised"
  except err_type as e:
    assert str(e).startswith(msg), "\n{}\n\n{}\n".format(e, msg)

def check_raises_regexp(thunk, err_type, pattern):
  try:
    thunk()
    assert False, f"Expected {err_type.__name__} to be raised"
  except err_type as e:
    assert re.match(pattern, str(e)), "{}\n\n{}\n".format(e, pattern)


def iter_eqns(jaxpr):
  # TODO(necula): why doesn't this search in params?
  for eqn in jaxpr.eqns:
    yield eqn
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from iter_eqns(subjaxpr)

def assert_dot_precision(expected_precision, fun, *args):
  jaxpr = api.make_jaxpr(fun)(*args)
  precisions = [eqn.params['precision'] for eqn in iter_eqns(jaxpr.jaxpr)
                if eqn.primitive == lax.dot_general_p]
  for precision in precisions:
    msg = "Unexpected precision: {} != {}".format(expected_precision, precision)
    assert precision == expected_precision, msg


_CACHED_INDICES: Dict[int, Sequence[int]] = {}

def cases_from_list(xs):
  xs = list(xs)
  n = len(xs)
  k = min(n, FLAGS.num_generated_cases)
  # Random sampling for every parameterized test is expensive. Do it once and
  # cache the result.
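  # The sample is deterministic (fixed seed) and cached per list length, so
  # equal-length case lists across tests share the same sampled indices.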
  indices = _CACHED_INDICES.get(n)
  if indices is None:
    rng = npr.RandomState(42)
    _CACHED_INDICES[n] = indices = rng.permutation(n)
  return [xs[i] for i in indices[:k]]

def cases_from_gens(*gens):
  sizes = [1, 3, 10]
  cases_per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1
  for size in sizes:
    for i in range(cases_per_size):
      yield ('_{}_{}'.format(size, i),) + tuple(gen(size) for gen in gens)


class JaxTestLoader(absltest.TestLoader):
  def getTestCaseNames(self, testCaseClass):
    names = super().getTestCaseNames(testCaseClass)
    if FLAGS.test_targets:
      pattern = re.compile(FLAGS.test_targets)
      names = [name for name in names
               if pattern.search(f"{testCaseClass.__name__}.{name}")]
    if FLAGS.exclude_test_targets:
      pattern = re.compile(FLAGS.exclude_test_targets)
      names = [name for name in names
               if not pattern.search(f"{testCaseClass.__name__}.{name}")]
    return names


class JaxTestCase(parameterized.TestCase):
  """Base class for JAX tests including numerical checks and boilerplate."""

  # TODO(mattjj): this obscures the error messages from failures, figure out how
  # to re-enable it
  # def tearDown(self) -> None:
  #   assert core.reset_trace_state()

  def setUp(self):
    super(JaxTestCase, self).setUp()
    core.skip_checks = False
    # We use the adler32 hash for two reasons.
    # a) it is deterministic run to run, unlike hash() which is randomized.
    # b) it returns values in int32 range, which RandomState requires.
    self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))

  def rng(self):
    return self._rng

  def assertArraysEqual(self, x, y, *, check_dtypes=True):
    """Assert that x and y arrays are exactly equal."""
    if check_dtypes:
      self.assertDtypesMatch(x, y)
    np.testing.assert_array_equal(x, y)

  def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None,
                           rtol=None):
    """Assert that x and y are close (up to numerical tolerances)."""
    self.assertEqual(x.shape, y.shape)
    atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
    rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol))

    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol)

    if check_dtypes:
      self.assertDtypesMatch(x, y)

  def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
    if not FLAGS.jax_enable_x64 and canonicalize_dtypes:
      self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x)),
                       _dtypes.canonicalize_dtype(_dtype(y)))
    else:
      self.assertEqual(_dtype(x), _dtype(y))

  def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None,
                     canonicalize_dtypes=True):
    """Assert that x and y, either arrays or nested tuples/lists, are close."""
    if isinstance(x, dict):
      self.assertIsInstance(y, dict)
      self.assertEqual(set(x.keys()), set(y.keys()))
      for k in x.keys():
        self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol,
                            rtol=rtol, canonicalize_dtypes=canonicalize_dtypes)
    elif is_sequence(x) and not hasattr(x, '__array__'):
      self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
      self.assertEqual(len(x), len(y))
      for x_elt, y_elt in zip(x, y):
        self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol,
                            rtol=rtol, canonicalize_dtypes=canonicalize_dtypes)
    elif hasattr(x, '__array__') or np.isscalar(x):
      self.assertTrue(hasattr(y, '__array__') or np.isscalar(y))
      if check_dtypes:
        self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes)
      x = np.asarray(x)
      y = np.asarray(y)
      self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol)
    elif x == y:
      return
    else:
      raise TypeError((type(x), type(y)))

  def assertMultiLineStrippedEqual(self, expected, what):
    """Asserts two strings are equal, after dedenting and stripping each line."""
    expected = textwrap.dedent(expected)
    what = textwrap.dedent(what)
    ignore_space_re = re.compile(r'\s*\n\s*')
    expected_clean = re.sub(ignore_space_re, '\n', expected.strip())
    what_clean = re.sub(ignore_space_re, '\n', what.strip())
    self.assertMultiLineEqual(expected_clean, what_clean,
                              msg="Found\n{}\nExpecting\n{}".format(what, expected))

  def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True,
                       rtol=None, atol=None):
    """Helper method for running JAX compilation and allclose assertions."""
    args = args_maker()

    def wrapped_fun(*args):
      self.assertTrue(python_should_be_executing)
      return fun(*args)

    python_should_be_executing = True
    python_ans = fun(*args)

    python_shapes = tree_map(lambda x: np.shape(x), python_ans)
    np_shapes = tree_map(lambda x: np.shape(np.asarray(x)), python_ans)
    self.assertEqual(python_shapes, np_shapes)

    cache_misses = xla.xla_primitive_callable.cache_info().misses
    python_ans = fun(*args)
    self.assertEqual(
        cache_misses, xla.xla_primitive_callable.cache_info().misses,
        "Compilation detected during second call of {} in op-by-op "
        "mode.".format(fun))

    cfun = api.jit(wrapped_fun)
    python_should_be_executing = True
    monitored_ans = cfun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes,
                        atol=atol, rtol=rtol)
    self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                        atol=atol, rtol=rtol)

    args = args_maker()

    python_should_be_executing = True
    python_ans = fun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                        atol=atol, rtol=rtol)

  def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
                         check_dtypes=True, tol=None,
                         canonicalize_dtypes=True):
    args = args_maker()
    lax_ans = lax_op(*args)
    numpy_ans = numpy_reference_op(*args)
    self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes,
                        atol=tol, rtol=tol,
                        canonicalize_dtypes=canonicalize_dtypes)


@contextmanager
def ignore_warning(**kw):
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", **kw)
    yield
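
# Illustrative usage:
#
#   with ignore_warning(category=DeprecationWarning):
#     ...  # code whose warnings should be suppressed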


class _cached_property:
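  # The cached value lives on the descriptor itself and is therefore shared
  # across instances; that is fine here because _LazyDtypes is only
  # instantiated once (as the module-level `dtypes` singleton below).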
  null = object()

  def __init__(self, method):
    self._method = method
    self._value = self.null

  def __get__(self, obj, cls):
    if self._value is self.null:
      self._value = self._method(obj)
    return self._value


class _LazyDtypes:
  """A class that unifies lists of supported dtypes.

  These could be module-level constants, but device_under_test() is not always
  known at import time, so we need to define these lists lazily.
  """
  def supported(self, dtypes):
    supported = supported_dtypes()
    return type(dtypes)(d for d in dtypes if d in supported)

  @_cached_property
  def floating(self):
    return self.supported([np.float32, np.float64])

  @_cached_property
  def all_floating(self):
    return self.supported([_dtypes.bfloat16, np.float16, np.float32, np.float64])

  @_cached_property
  def integer(self):
    return self.supported([np.int32, np.int64])

  @_cached_property
  def all_integer(self):
    return self.supported([np.int8, np.int16, np.int32, np.int64])

  @_cached_property
  def unsigned(self):
    return self.supported([np.uint32, np.uint64])

  @_cached_property
  def all_unsigned(self):
    return self.supported([np.uint8, np.uint16, np.uint32, np.uint64])

  @_cached_property
  def complex(self):
    return self.supported([np.complex64, np.complex128])

  @_cached_property
  def boolean(self):
    return self.supported([np.bool_])

  @_cached_property
  def inexact(self):
    return self.floating + self.complex

  @_cached_property
  def all_inexact(self):
    return self.all_floating + self.complex

  @_cached_property
  def numeric(self):
    return self.floating + self.integer + self.unsigned + self.complex

  @_cached_property
  def all(self):
    return (self.all_floating + self.all_integer + self.all_unsigned +
            self.complex + self.boolean)


dtypes = _LazyDtypes()