1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""See primitives_test docstring for how the Jax2TfLimitations are used"""
15
16import itertools
17import numpy as np
18from typing import Any, Callable, Optional, Sequence
19
20from jax import dtypes
21from jax import lax
22from jax import numpy as jnp
23
24from jax.experimental.jax2tf.tests import primitive_harness
25
26DType = Any
27
28
29class Jax2TfLimitation(primitive_harness.Limitation):
30  """Specific primitive limitations for jax2tf.
31
32  See the primitive_test module docstring for details.
33  """
34  def __init__(
35      self,
36      description: str,
37      *,
38      devices: Sequence[str] = ("cpu", "gpu", "tpu"),
39      dtypes: Sequence[DType] = (),
40      enabled: bool = True,
41      # jax2tf specific
42      modes=("eager", "graph", "compiled"),
43      skip_tf_run=False,
44      expect_tf_error: bool = True,
45      skip_comparison=False,
46      custom_assert: Optional[Callable] = None,
47      tol=None):
48    """See the primitive_harness.Limitation common arguments.
49
50    Args :
51      modes: one of "eager", "graph", "compiled"
52      skip_tf_run: if set will skip the TF execution. Use this sparingly,
53        prefer `expect_tf_error`. Use only when the test cannot recover from
54        the TF error.
55      expect_tf_error: if set, then expect a TF error in the given mode when
56        executing the result of jax2tf conversion. If not set, then the
57        limitation must have a custom_assert or non-default tol.
58      skip_comparison: skips the numeric comparison.
59      tol: a tolerance to use for both atol and rtol. We will use the maximum
60        tolerance over all the applicable limitations, irrespective of their
61        order.
62      custom_assert: if given, then execute as
63        `custom_assert(tst, result_jax, result_tf, args=args, tol=tol)`, where
64        `tst` is the current TestCase instance, and args are the input
65        arguments that the harness created. The `tol` is the maximum tolerance
66        based on the applicable limitations.
67        `result_tf` is already converted to NumPy arrays.
68    """
69    super().__init__(
70        description,
71        devices=devices,
72        dtypes=dtypes,
73        enabled=enabled)
74    if isinstance(modes, str):
75      modes = (modes,)
76    assert all(m in ["eager", "graph", "compiled"] for m in modes)
77    self.modes = modes
78    self.expect_tf_error = expect_tf_error
79    self.skip_tf_run = skip_tf_run
80    self.custom_assert = custom_assert
81    self.tol = tol
82    self.skip_comparison = skip_comparison
83
84
85  def get_max_tolerance_limitation(
86      self, limitations: Sequence["Jax2TfLimitation"]) -> Optional["Jax2TfLimitation"]:
87    """Pick the tolerance limitation that establishes the maximum tolerance"""
88    # TODO: it would be best if the limitations with tolerance are mutually exclusive
89    # and we don't have to compute the maximum
90    # TODO: we made this an instance method only so that we don't have to import
91    # this module from tf_test.util.
92    max_tol_lim = None
93    for l in limitations:
94      if l.tol is not None:
95        if max_tol_lim is None or l.tol > max_tol_lim.tol:
96          max_tol_lim = l
97    return max_tol_lim
98
99  def filter(self,  # type: ignore[override]
100             dtype: Optional[DType] = None,
101             device: Optional[str] = None,
102             mode: Optional[str] = None) -> bool:
103    return ((mode is None or mode in self.modes) and
104            super().filter(device=device, dtype=dtype))
105
106
107  @classmethod
108  def limitations_for_harness(
109      cls, harness: primitive_harness.Harness) -> Sequence["Jax2TfLimitation"]:
110    group_method = getattr(cls, harness.group_name, None)
111    if harness.group_name in cls.harness_groups_no_limitations:
112      assert group_method is None, (
113          f"Harness group {harness.group_name} is both in "
114          f"'harness_groups_no_limitations' and has a custom "
115          f"Jax2TfLimitation.classmethod defined (see module docstring)"
116      )
117      return []
118    else:
119      assert group_method is not None, (
120          f"Harness group {harness.group_name} must be either part of "
121          f"'harness_groups_no_limitations' or must have a custom "
122          f"Jax2TfLimitation.classmethod defined (see module docstring)"
123      )
124      limitations = group_method(harness)
125      assert isinstance(limitations, (list, tuple))
126      return limitations
127
128
129  # We keep here the explicit set of groups for which we don't have limitations
130  harness_groups_no_limitations = {
131      "abs", "and", "argmin", "argmax", "broadcast", "broadcast_in_dim", "ceil",
132      "concatenate", "cos", "complex", "conj", "device_put", "dynamic_slice",
133      "dynamic_update_slice", "exp", "eq", "floor", "log", "gather", "imag",
134      "iota", "is_finite", "ne", "not", "or", "pad", "random_split",
135      "reduce_and", "reduce_prod", "reduce_or", "reduce_sum", "real", "reshape",
136      "select", "shift_left", "shift_right_logical", "shift_right_arithmetic",
137      "sin", "slice", "sqrt", "squeeze", "stop_gradient", "tie_in", "transpose",
138      "xor", "zeros_like"
139  }
140
141
142
143  @classmethod
144  def helper_get_trig_custom_limitation(cls, np_inverse):
145
146    def custom_assert(tst, result_jax, result_tf, *, args, tol):
147      operand, = args
148      tst.assertAllClose(operand, np_inverse(result_tf), atol=tol, rtol=tol)
149
150    return custom_numeric(
151        description="May return different but still correct results",
152        dtypes=[np.complex64, np.complex128],
153        custom_assert=custom_assert,
154        modes=("eager", "graph"))
155
156  @classmethod
157  def acos(cls, harness: primitive_harness.Harness):
158    return [
159        missing_tf_kernel(
160            dtypes=[np.float16, dtypes.bfloat16, np.complex64],
161            devices=("cpu", "gpu"),
162            modes=("eager", "graph")),
163        missing_tf_kernel(
164            dtypes=[np.complex128],
165            devices=("cpu", "gpu"),
166            modes=("eager", "graph")),
167        custom_numeric(dtypes=np.complex128, tol=1e-13),
168        custom_numeric(dtypes=np.complex64, devices="tpu", tol=1e-3),
169        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-4),
170        cls.helper_get_trig_custom_limitation(np.cos),
171    ]
172
173  @classmethod
174  def acosh(cls, harness: primitive_harness.Harness):
175    return [
176        missing_tf_kernel(
177            dtypes=[dtypes.bfloat16, np.float16],
178            devices=("cpu", "gpu"),
179            modes=("eager", "graph")),
180        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3),
181        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
182        cls.helper_get_trig_custom_limitation(np.cosh)
183    ]
184
185  @classmethod
186  def add(cls, harness: primitive_harness.Harness):
187    return [
188        missing_tf_kernel(dtypes=[np.uint16]),
189        missing_tf_kernel(dtypes=[np.uint64], devices=("cpu", "gpu"))
190    ]
191
192  @classmethod
193  # Also called add_jaxvals
194  def add_any(cls, harness: primitive_harness.Harness):
195    return [missing_tf_kernel(dtypes=[np.uint16, np.uint64])]
196
197  @classmethod
198  def asin(cls, harness: primitive_harness.Harness):
199    return [
200        missing_tf_kernel(
201            dtypes=[np.float16, dtypes.bfloat16],
202            devices=("cpu", "gpu"),
203            modes=("eager", "graph")),
204        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
205        cls.helper_get_trig_custom_limitation(np.sin)
206    ]
207
208  @classmethod
209  def asinh(cls, harness: primitive_harness.Harness):
210    return [
211        missing_tf_kernel(
212            dtypes=[np.float16, dtypes.bfloat16],
213            devices=("cpu", "gpu"),
214            modes=("eager", "graph")),
215        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3),
216        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
217        cls.helper_get_trig_custom_limitation(np.sinh)
218    ]
219
220  @classmethod
221  def atan(cls, harness: primitive_harness.Harness):
222    return [
223        missing_tf_kernel(
224            dtypes=[np.float16, dtypes.bfloat16],
225            devices=("cpu", "gpu"),
226            modes=("eager", "graph")),
227        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
228        cls.helper_get_trig_custom_limitation(np.tan)
229    ]
230
231  @classmethod
232  def atanh(cls, harness: primitive_harness.Harness):
233    return [
234        missing_tf_kernel(
235            dtypes=[np.float16, dtypes.bfloat16],
236            devices=("cpu", "gpu"),
237            modes=("eager", "graph")),
238        custom_numeric(dtypes=np.float64, tol=1e-14),
239        custom_numeric(dtypes=np.complex64, tol=1e-3),
240        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
241        cls.helper_get_trig_custom_limitation(np.tanh)
242    ]
243
244  @classmethod
245  def atan2(cls, harness: primitive_harness.Harness):
246    return [
247        missing_tf_kernel(
248            dtypes=[np.float16, dtypes.bfloat16],
249            devices=("cpu", "gpu"),
250            modes=("eager", "graph"))
251    ]
252
253  @classmethod
254  def bessel_i0e(cls, harness: primitive_harness.Harness):
255    return [
256        missing_tf_kernel(
257            dtypes=[dtypes.bfloat16],
258            devices=("cpu", "gpu"),
259            modes=("eager", "graph"))
260    ]
261
262  @classmethod
263  def bessel_i1e(cls, harness: primitive_harness.Harness):
264    return cls.bessel_i0e(harness)
265
266  @classmethod
267  def bitcast_convert_type(cls, harness: primitive_harness.Harness):
268    return [missing_tf_kernel(dtypes=[np.bool_])]
269
270  @classmethod
271  def cholesky(cls, harness: primitive_harness.Harness):
272
273    def custom_assert(tst, result_jax, result_tf, *, tol, **_):
274      # cholesky_p returns garbage in the strictly upper triangular part of the
275      # result, so we can safely ignore that part.
276      tst.assertAllClose(jnp.tril(result_jax), result_tf, atol=tol)
277
278    return [
279        # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
280        Jax2TfLimitation(
281            "function not compilable",
282            dtypes=[np.complex64, np.complex128],
283            devices=("cpu", "gpu"),
284            modes="compiled"),
285        missing_tf_kernel(
286            # Interesting: on TPU, complex64 works in eager
287            # mode, but fails otherwise.
288            dtypes=[np.complex64, np.complex128],
289            devices="tpu",
290            modes=("graph", "compiled")),
291        # TODO(bchetioui): very high discrepancy in the float32/complex64 case
292        custom_numeric(dtypes=[np.float32, np.complex64], tol=1e-2),
293        custom_numeric(dtypes=[np.float64, np.complex128], tol=1e-6),
294        custom_numeric(dtypes=[dtypes.bfloat16, np.float16], tol=5e-2),
295        custom_numeric(
296            custom_assert=custom_assert,
297            description=(
298                "May return different values in the strictly upper triangular "
299                "part of the result. This does not matter for correctness, "
300                "because this part of the matrix is not considered in the result."
301            ))
302    ]
303
304  @classmethod
305  def clamp(cls, harness: primitive_harness.Harness):
306    return [
307        missing_tf_kernel(dtypes=[np.int8, np.uint16, np.uint32, np.uint64])
308    ]
309
310  @classmethod
311  def convert_element_type(cls, harness: primitive_harness.Harness):
312    return []
313
314  @classmethod
315  def conv_general_dilated(cls, harness: primitive_harness.Harness):
316    return [
317        Jax2TfLimitation(
318            "jax2tf BUG: batch_group_count > 1 not yet converted",
319            enabled=(harness.params["batch_group_count"] > 1)),
320        custom_numeric(devices="gpu", tol=1e-4),
321        custom_numeric(devices="tpu", tol=1e-3),
322        # TODO(bchetioui): significant discrepancies in some float16 cases.
323        custom_numeric(dtypes=np.float16, tol=1),
324        # TODO(bchetioui): slight occasional discrepancy in float32 cases.
325        custom_numeric(dtypes=np.float32, devices="tpu", tol=0.5),
326        custom_numeric(dtypes=np.float32, devices="gpu", tol=1e-3),
327        custom_numeric(dtypes=np.float32, devices="cpu", tol=1e-4),
328        custom_numeric(dtypes=np.complex64, devices="tpu", tol=0.1),
329        custom_numeric(dtypes=[np.complex64, np.complex128], devices=("cpu", "gpu"), tol=5e-4),
330        # TODO(bchetioui): slight discrepancy when going through the path using
331        # tf.nn.convolution.
332        custom_numeric(dtypes=np.float64, devices="cpu", tol=1e-13),
333    ]
334
335    # TODO(bchetioui): unidentified bug in compiled mode. The test that fails is
336    #
337    # test_conv_general_dilated_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False
338    #
339    # with the following assertion error in TensorFlowTrace.process_primitive:
340    #
341    # AssertionError: conv_general_dilated: out.aval = ShapedArray(float32[1,3,24,26,16]); expected ShapedArray(float32[1,3,26,24,16])
342    #
343    # Deactivating this assertion is enough to pass the test, which suggests
344    # that the end shape is indeed the correct one (i.e. (1,3,26,24,16)).
345    # Further investigation is required to really understand this behavior,
346    # which we have not managed to reproduce as a pure TF test.
347    #
348    # This bug is low priority since it only occurs when using a non-TFXLA
349    # conversion path in compiled mode, i.e. in a context where using the
350    # TFXLA path is possible.
351    # if harness.name == "_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False":
352    #  raise unittest.SkipTest("TODO: known but unidentified bug in compiled "
353    #                          "mode")
354
355  @classmethod
356  def cosh(cls, harness: primitive_harness.Harness):
357    return [
358        missing_tf_kernel(
359            dtypes=[np.float16],
360            devices=("cpu", "gpu"),
361            modes=("eager", "graph"))
362    ]
363
364  @classmethod
365  def cummax(cls, harness):
366    return [
367        missing_tf_kernel(
368            dtypes=[np.uint64, np.complex128],
369            devices=("cpu", "gpu"),
370        ),
371        missing_tf_kernel(
372            dtypes=[np.uint16, np.uint32, np.int8, np.complex64],),
373        custom_numeric(dtypes=np.float16, tol=0.1),
374        custom_numeric(dtypes=dtypes.bfloat16, tol=0.5)
375    ]
376
377  @classmethod
378  def cummin(cls, harness):
379    return [
380        missing_tf_kernel(
381            dtypes=[np.uint64, np.complex128],
382            devices=("cpu", "gpu"),
383        ),
384        missing_tf_kernel(
385            dtypes=[np.uint16, np.uint32, np.int8, np.complex64],),
386        custom_numeric(dtypes=np.float16, tol=0.1),
387        custom_numeric(dtypes=dtypes.bfloat16, tol=0.5),
388    ]
389
390  @classmethod
391  def cumprod(cls, harness):
392    return [
393        missing_tf_kernel(
394            dtypes=[np.uint64],
395            devices=("cpu", "gpu"),
396        ),
397        missing_tf_kernel(dtypes=[np.uint32]),
398        custom_numeric(dtypes=np.float16, tol=0.1),
399        custom_numeric(dtypes=dtypes.bfloat16, tol=0.5),
400    ]
401
402  @classmethod
403  def cumsum(cls, harness):
404    return [
405        missing_tf_kernel(
406            dtypes=[np.uint64],
407            devices=("cpu", "gpu"),
408        ),
409        missing_tf_kernel(dtypes=[np.complex64], devices="tpu"),
410        missing_tf_kernel(dtypes=[np.uint16]),
411        custom_numeric(dtypes=np.float16, tol=0.1),
412        custom_numeric(dtypes=dtypes.bfloat16, tol=0.5),
413    ]
414
415  @classmethod
416  def custom_linear_solve(cls, harness: primitive_harness.Harness):
417    return [
418        Jax2TfLimitation(
419            "TODO: large numerical discrepancy",
420            dtypes=np.float32,
421            devices="tpu",
422            expect_tf_error=False,
423            skip_comparison=True),
424        custom_numeric(dtypes=np.float32, devices="tpu", tol=0.01),
425        custom_numeric(tol=1e-3),
426    ]
427
428  @classmethod
429  def digamma(cls, harness: primitive_harness.Harness):
430    dtype = harness.dtype
431
432    # In the bfloat16 case, TF and lax both return NaN in undefined cases.
433    # digamma is not defined at 0 and -1
434    def custom_assert(tst, result_jax, result_tf, *, args, tol):
435      # lax.digamma returns NaN and tf.math.digamma returns inf
436      arg, = args
437      special_cases = (arg == 0.) | (arg == -1.)
438      nr_special_cases = np.count_nonzero(special_cases)
439      tst.assertAllClose(
440          np.full((nr_special_cases,), dtype(np.nan)),
441          result_jax[special_cases])
442      tst.assertAllClose(
443          np.full((nr_special_cases,), dtype(np.inf)), result_tf[special_cases])
444      # non-special cases are equal
445      tst.assertAllClose(
446          result_jax[~special_cases],
447          result_tf[~special_cases],
448          atol=tol,
449          rtol=tol)
450
451    return [
452        missing_tf_kernel(
453            dtypes=[dtypes.bfloat16],
454            devices=("cpu", "gpu"),
455            modes=("eager", "graph")),
456        custom_numeric(dtypes=np.float64, tol=1e-13),
457        custom_numeric(dtypes=np.float32, devices=["cpu", "gpu"], tol=1e-3),
458        custom_numeric(
459            dtypes=dtypes.bfloat16,
460            custom_assert=custom_assert,
461            description=(
462                "May return different results at singularity points 0 and -1."
463                "JAX returns nan and TF returns inf"),
464            modes=("eager", "graph"))
465    ]
466
467  @classmethod
468  def div(cls, harness: primitive_harness.Harness):
469    return [
470        missing_tf_kernel(
471            dtypes=[
472                np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16
473            ],),
474        Jax2TfLimitation(
475            "TF integer division fails if divisor contains 0; JAX returns NaN",
476            dtypes=[
477                np.uint8, np.int8, np.uint16, np.uint32, np.uint64, np.int8,
478                np.int16, np.int32, np.int64
479            ],
480            # Only the harnesses with "singularity" will have divide by 0
481            enabled=("singularity" in harness.name))
482    ]
483
484  @classmethod
485  def dot_general(cls, harness: primitive_harness.Harness):
486    return [
487        missing_tf_kernel(
488            dtypes=[
489                np.bool_, np.uint8, np.uint16, np.uint32, np.uint64, np.int8,
490                np.int16
491            ],),
492        missing_tf_kernel(
493            dtypes=[np.int64], devices=("cpu", "gpu"), modes="compiled"),
494        custom_numeric(dtypes=dtypes.bfloat16, tol=0.3),
495        custom_numeric(
496            dtypes=[np.complex64, np.float32], devices=("cpu", "gpu"),
497            tol=1e-5),
498        custom_numeric(dtypes=np.float32, devices="tpu", tol=0.1),
499        custom_numeric(dtypes=np.complex64, devices="tpu", tol=0.3),
500        custom_numeric(dtypes=np.float16, devices=("gpu", "tpu"), tol=0.1),
501        custom_numeric(dtypes=np.float16, devices="cpu", tol=0.01)
502    ]
503
504  @classmethod
505  def eig(cls, harness: primitive_harness.Harness):
506    compute_left_eigenvectors = harness.params["compute_left_eigenvectors"]
507    compute_right_eigenvectors = harness.params["compute_right_eigenvectors"]
508    dtype = harness.dtype
509
510    def custom_assert(tst, result_jax, result_tf, *, args, tol):
511      operand, = args
512      inner_dimension = operand.shape[-1]
513
514      # Test ported from tests.linlag_test.testEig
515      # Norm, adjusted for dimension and type.
516      def norm(x):
517        norm = np.linalg.norm(x, axis=(-2, -1))
518        return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps)
519
520      def check_right_eigenvectors(a, w, vr):
521        tst.assertTrue(
522            np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))
523
524      def check_left_eigenvectors(a, w, vl):
525        rank = len(a.shape)
526        aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
527        wC = jnp.conj(w)
528        check_right_eigenvectors(aH, wC, vl)
529
530      def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
531        tol = None
532        # TODO(bchetioui): numerical discrepancies
533        if dtype in [np.float32, np.complex64]:
534          tol = 1e-4
535        elif dtype in [np.float64, np.complex128]:
536          tol = 1e-13
537        closest_diff = min(abs(eigenvalues_array - eigenvalue))
538        tst.assertAllClose(
539            closest_diff, np.array(0., closest_diff.dtype), atol=tol)
540
541      all_w_jax, all_w_tf = result_jax[0], result_tf[0]
542      for idx in itertools.product(*map(range, operand.shape[:-2])):
543        w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
544        for i in range(inner_dimension):
545          check_eigenvalue_is_in_array(w_jax[i], w_tf)
546          check_eigenvalue_is_in_array(w_tf[i], w_jax)
547
548      if compute_left_eigenvectors:
549        check_left_eigenvectors(operand, all_w_tf, result_tf[1])
550      if compute_right_eigenvectors:
551        check_right_eigenvectors(operand, all_w_tf,
552                                 result_tf[1 + compute_left_eigenvectors])
553
554    return [
555        # Eig does not work in JAX on gpu or tpu
556        Jax2TfLimitation("function not compilable", modes="compiled",
557                         devices="cpu"),
558        Jax2TfLimitation(
559            "TF Conversion of eig is not implemented when both compute_left_eigenvectors and compute_right_eigenvectors are set to True",
560            enabled=(compute_left_eigenvectors and compute_right_eigenvectors)),
561        custom_numeric(
562            custom_assert=custom_assert,
563            description=("May return the eigenvalues and eigenvectors in a "
564                         "potentially different order. The eigenvectors may "
565                         "also be different, but equally valid."),
566            modes=("eager", "graph"))
567    ]
568
569  @classmethod
570  def eigh(cls, harness: primitive_harness.Harness):
571    dtype = harness.dtype
572    shape = harness.params["shape"]
573
574    def custom_assert(tst, result_jax, result_tf, *, args, tol):
575      operand, = args
576      inner_dimension = operand.shape[-1]
577
578      def check_right_eigenvectors(a, w, vr):
579        tol = 1e-16
580        # TODO(bchetioui): tolerance needs to be very high in compiled mode,
581        # specifically for eigenvectors.
582        if dtype == np.float64:
583          tol = 1e-6
584        elif dtype == np.float32:
585          tol = 1e-2
586        elif dtype in [dtypes.bfloat16, np.complex64]:
587          tol = 1e-3
588        elif dtype == np.complex128:
589          tol = 1e-13
590        tst.assertAllClose(
591            np.matmul(a, vr) - w[..., None, :] * vr,
592            np.zeros(a.shape, dtype=vr.dtype),
593            atol=tol)
594
595      def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
596        tol = None
597        if dtype in [dtypes.bfloat16, np.float32, np.complex64]:
598          tol = 1e-3
599        elif dtype in [np.float64, np.complex128]:
600          tol = 1e-11
601        closest_diff = min(abs(eigenvalues_array - eigenvalue))
602        tst.assertAllClose(
603            closest_diff, np.array(0., closest_diff.dtype), atol=tol)
604
605      _, all_w_jax = result_jax
606      all_vr_tf, all_w_tf = result_tf
607
608      for idx in itertools.product(*map(range, operand.shape[:-2])):
609        w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
610        for i in range(inner_dimension):
611          check_eigenvalue_is_in_array(w_jax[i], w_tf)
612          check_eigenvalue_is_in_array(w_tf[i], w_jax)
613
614      check_right_eigenvectors(operand, all_w_tf, all_vr_tf)
615
616    return [
617        # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
618        Jax2TfLimitation(
619            "function not compilable",
620            dtypes=[np.complex64, np.complex128],
621            modes="compiled",
622            enabled=(shape[0] > 0)),
623        Jax2TfLimitation(
624            "TODO: numeric discrepancies",
625            dtypes=[np.float64],
626            modes="compiled",
627            devices=("cpu", "gpu"),
628            expect_tf_error=False,
629            skip_comparison=True),
630        Jax2TfLimitation(
631            "TODO: numeric discrepancies",
632            dtypes=[np.float16],
633            devices=("tpu",),
634            expect_tf_error=False,
635            skip_comparison=True),
636        custom_numeric(
637            custom_assert=custom_assert,
638            description=("May return the eigenvalues and eigenvectors in a "
639                         "potentially different order. The eigenvectors may "
640                         "also be different, but equally valid."))
641    ]
642
643  @classmethod
644  def ge(cls, harness: primitive_harness.Harness):
645    return [
646        missing_tf_kernel(dtypes=[np.bool_]),
647        missing_tf_kernel(
648            dtypes=[np.uint16, np.uint32],
649            devices=("cpu", "gpu"),
650            modes=("eager", "graph")),
651        missing_tf_kernel(
652            dtypes=[np.uint64],
653            devices=("cpu", "gpu"),
654            modes=("eager", "graph"))
655    ]
656
657  @classmethod
658  def gt(cls, harness: primitive_harness.Harness):
659    return cls.ge(harness)
660
661  @classmethod
662  def erf(cls, harness: primitive_harness.Harness):
663    return [
664        missing_tf_kernel(
665            dtypes=[dtypes.bfloat16],
666            devices=("cpu", "gpu"),
667            modes=("eager", "graph"))
668    ]
669
670  @classmethod
671  def erfc(cls, harness: primitive_harness.Harness):
672    return [
673        missing_tf_kernel(
674            dtypes=[dtypes.bfloat16],
675            devices=("cpu", "gpu"),
676            modes=("eager", "graph"))
677    ]
678
679  @classmethod
680  def erf_inv(cls, harness: primitive_harness.Harness):
681    # erf_inv is not defined for arg <= -1 or arg >= 1
682    def custom_assert(tst, result_jax, result_tf, *, args, tol):  # noqa: F811
683      arg, = args
684      # for arg < -1 or arg > 1
685      # lax.erf_inv returns NaN; tf.math.erf_inv return +/- inf
686      special_cases = (arg < -1.) | (arg > 1.)
687      # non-special cases are equal
688      tst.assertAllClose(
689          result_jax[~special_cases],
690          result_tf[~special_cases],
691          atol=tol,
692          rtol=tol)
693
694    return [
695        missing_tf_kernel(
696            dtypes=[dtypes.bfloat16, np.float16],
697            devices=("cpu", "gpu"),
698            modes=("eager", "graph")),
699        custom_numeric(dtypes=[np.float32, np.float64], tol=1e-4),
700        custom_numeric(
701            dtypes=[np.float32, np.float64],
702            custom_assert=custom_assert,
703            description=(
704                "May return different results at undefined points (< -1 or > 1):"
705                " JAX returns `NaN` and TF returns `+inf` or `-inf`."))
706    ]
707
708  @classmethod
709  def expm1(cls, harness: primitive_harness.Harness):
710    return [custom_numeric(dtypes=np.float64, tol=1e-5)]
711
712  @classmethod
713  def fft(cls, harness):
714    return [
715        Jax2TfLimitation(
716            "TF function not compileable",
717            devices=("cpu", "gpu"),
718            dtypes=[np.float64, np.complex128],
719            modes="compiled"),
720        custom_numeric(tol=1e-3)
721    ]
722
723  @classmethod
724  def _pow_test_util(cls, harness: primitive_harness.Harness):
725
726    def custom_assert(tst, result_jax, result_tf, *, args, tol):
727      # NaNs are mismatched, but assertAllClose will also behave weirdly for
728      # complex numbers containing np.inf as one of their components. See
729      # https://github.com/numpy/numpy/issues/15959 for more details.
730      mask = (
731          np.isnan(result_jax) + np.isnan(result_tf) + np.isinf(result_jax) +
732          np.isinf(result_tf))
733      tst.assertAllClose(result_jax[~mask], result_tf[~mask], rtol=tol)
734
735    return [
736        custom_numeric(
737            dtypes=[np.float32, np.complex64], devices="tpu", tol=1e-2),
738        custom_numeric(
739            dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"),
740            tol=1e-3),
741        custom_numeric(dtypes=[np.float64, np.complex128], tol=1e-12),
742        custom_numeric(dtypes=np.float16, tol=1),
743        # Values get really small for large negative powers.
744        custom_numeric(dtypes=dtypes.bfloat16, tol=3),
745        custom_numeric(
746            dtypes=[np.complex64, np.complex128],
747            custom_assert=custom_assert,
748        )
749    ]
750
751  @classmethod
752  def igamma(cls, harness: primitive_harness.Harness):
753    dtype = harness.dtype
754
755    # igamma is not defined when the first argument is <=0
756    def custom_assert(tst, result_jax, result_tf, *, args, tol):
757      arg1, arg2 = args
758      # lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0
759      special_cases = (arg1 == 0.) & (arg2 == 0.)
760      nr_special_cases = np.count_nonzero(special_cases)
761      tst.assertAllClose(
762          np.full((nr_special_cases,), np.nan, dtype=dtype),
763          result_jax[special_cases])
764      tst.assertAllClose(
765          np.full((nr_special_cases,), 0., dtype=dtype),
766          result_tf[special_cases])
767      # non-special cases are equal
768      tst.assertAllClose(result_jax[~special_cases], result_tf[~special_cases])
769
770    return [
771        custom_numeric(
772            custom_assert=custom_assert,
773            description=(
774                "May return different results at undefined points "
775                "(both arguments 0). JAX returns `NaN` and TF returns 0 or "
776                "JAX returns 1 and TF returns `NaN`"),
777            modes=("eager", "graph"))
778    ]
779
780  @classmethod
781  def igammac(cls, harness: primitive_harness.Harness):
782    dtype = harness.dtype
783
784    # igammac is not defined when the first argument is <=0
785    def custom_assert(tst, result_jax, result_tf, *, args, tol):  # noqa: F811
786      arg1, arg2 = args
787      # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
788      special_cases = (arg1 <= 0.) | (arg2 <= 0)
789      nr_special_cases = np.count_nonzero(special_cases)
790      tst.assertAllClose(
791          np.full((nr_special_cases,), 1., dtype=dtype),
792          result_jax[special_cases])
793      tst.assertAllClose(
794          np.full((nr_special_cases,), np.nan, dtype=dtype),
795          result_tf[special_cases])
796      # non-special cases are equal
797      tst.assertAllClose(
798          result_jax[~special_cases],
799          result_tf[~special_cases],
800          atol=tol,
801          rtol=tol)
802
803    return [
804        custom_numeric(dtypes=np.float64, tol=1e-9),
805        custom_numeric(devices="gpu", tol=1e-3),
806        custom_numeric(
807            custom_assert=custom_assert,
808            devices=("cpu", "gpu"),
809            modes=("eager", "graph"),
810            description=(
811                "May return different results at undefined points "
812                "(both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or "
813                "JAX returns 1 and TF returns `NaN`")),
814    ]
815
816  @classmethod
817  def integer_pow(cls, harness: primitive_harness.Harness):
818    y = harness.params["y"]
819    return [
820        missing_tf_kernel(
821            dtypes=[
822                np.uint8, np.uint16, np.int8, np.int16, np.uint32, np.uint64
823            ],),
824        # hitting rtol = nan
825        Jax2TfLimitation(("Different overflow behavior for large exponents. It "
826                          "and `+inf`/`-inf` differently in JAX and TF."),
827                         devices="tpu",
828                         dtypes=np.complex64,
829                         enabled=(y in [1000, -1000]),
830                         expect_tf_error=False,
831                         skip_comparison=True),
832        Jax2TfLimitation(
833            "Different overflow behavior for large exponents. ",
834            dtypes=[np.int32, np.int64, np.float32],
835            enabled=(y > 10),
836            expect_tf_error=False,
837            skip_comparison=True)
838    ] + list(cls._pow_test_util(harness))
839
840  @classmethod
841  def pow(cls, harness: primitive_harness.Harness):
842    return cls._pow_test_util(harness)
843
844  @classmethod
845  def le(cls, harness: primitive_harness.Harness):
846    return [
847        missing_tf_kernel(dtypes=[np.bool_]),
848        missing_tf_kernel(
849            dtypes=[np.uint16, np.uint32],
850            devices=("cpu", "gpu"),
851            modes=("eager", "graph")),
852        missing_tf_kernel(
853            dtypes=[np.uint64],
854            devices=("cpu", "gpu"),
855            modes=("eager", "graph"))
856    ]
857
858  @classmethod
859  def lt(cls, harness: primitive_harness.Harness):
860    return cls.ge(harness)
861
862  @classmethod
863  def lgamma(cls, harness: primitive_harness.Harness):
864    return [
865        missing_tf_kernel(
866            dtypes=[dtypes.bfloat16],
867            devices=("cpu", "gpu"),
868            modes=("eager", "graph")),
869        custom_numeric(dtypes=np.float64, tol=1e-11),
870        custom_numeric(dtypes=np.float32, tol=1e-3)
871    ]
872
873  @classmethod
874  def log1p(cls, harness: primitive_harness.Harness):
875    return [
876        custom_numeric(dtypes=np.float64, tol=1e-10),
877        custom_numeric(dtypes=np.float32, tol=1e-3)
878    ]
879
880  @classmethod
881  def lu(cls, harness: primitive_harness.Harness):
882    dtype = harness.dtype
883
884    def custom_assert(tst, result_jax, result_tf, *, args, tol):
885      operand, = args
886      lu, pivots, perm = result_tf
887      batch_dims = operand.shape[:-2]
888      m, n = operand.shape[-2], operand.shape[-1]
889
890      def _make_permutation_matrix(perm):
891        result = []
892        for idx in itertools.product(*map(range, operand.shape[:-1])):
893          result += [0 if c != perm[idx] else 1 for c in range(m)]
894        result = np.reshape(np.array(result, dtype=dtype), [*batch_dims, m, m])
895        return result
896
897      k = min(m, n)
898      l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
899      u = jnp.triu(lu)[..., :k, :]
900      p_mat = _make_permutation_matrix(perm)
901
902      tst.assertArraysEqual(
903          lax.linalg.lu_pivots_to_permutation(pivots, m), perm)
904      tst.assertAllClose(
905          jnp.matmul(p_mat, operand), jnp.matmul(l, u), atol=tol, rtol=tol)
906
907    return [
908        missing_tf_kernel(dtypes=[np.complex64], devices="tpu"),
909        custom_numeric(
910            dtypes=[np.float32, np.complex64], devices="tpu", tol=0.1),
911        custom_numeric(
912            dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"),
913            tol=1e-5),
914        custom_numeric(dtypes=[np.float64, np.complex128], tol=1e-13),
915        custom_numeric(
916            custom_assert=custom_assert,
917            description=("May return different, but also correct, results when "
918                         "the decomposition is not unique")),
919    ]
920
921  @classmethod
922  def _min_max_test_util(cls, harness: primitive_harness.Harness):
923    # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN;
924    # JAX always returns NaN, while TF returns the value NaN is compared with.
925    def custom_assert(tst, result_jax, result_tf, **_):
926      mask = np.isnan(result_jax)
927      tst.assertAllClose(result_jax[~mask], result_tf[~mask])
928
929    return [
930        missing_tf_kernel(
931            dtypes=[
932                np.bool_, np.int8, np.complex64, np.uint16, np.uint32, np.uint64
933            ],),
934        missing_tf_kernel(
935            dtypes=[np.complex128],
936            devices=("cpu", "gpu"),
937        ),
938        custom_numeric(
939            custom_assert=custom_assert,
940            description=(
941                "May return different values when one of the values is NaN. "
942                "JAX always returns NaN, while TF returns the value NaN is compared with."
943            ))
944    ]
945
946  @classmethod
947  def max(cls, harness: primitive_harness.Harness):
948    return cls._min_max_test_util(harness)
949
950  @classmethod
951  def min(cls, harness: primitive_harness.Harness):
952    return cls._min_max_test_util(harness)
953
954  @classmethod
955  def mul(cls, harness: primitive_harness.Harness):
956    return [missing_tf_kernel(dtypes=[np.uint32, np.uint64])]
957
958  @classmethod
959  def neg(cls, harness: primitive_harness.Harness):
960    return [
961        missing_tf_kernel(dtypes=[np.uint8, np.uint16, np.uint32, np.uint64],)
962    ]
963
964  @classmethod
965  def nextafter(cls, harness: primitive_harness.Harness):
966    return [missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])]
967
968  @classmethod
969  def population_count(cls, harness: primitive_harness.Harness):
970    return [
971        missing_tf_kernel(
972            dtypes=[np.uint32, np.uint64],
973            devices=("cpu", "gpu"),
974            modes=("eager", "graph"))
975    ]
976
977  @classmethod
978  def qr(cls, harness: primitive_harness.Harness):
979    # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
980    #     # jit_compile=True breaks for complex types.
981    # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
982    # - for now, the performance of the HLO QR implementation called when
983    #   compiling with TF is expected to have worse performance than the
984    #   custom calls made in JAX.
985    return [
986        custom_numeric(tol=1e-5),
987        missing_tf_kernel(
988            dtypes=[dtypes.bfloat16],
989            devices="tpu",
990        )
991    ]
992
993  @classmethod
994  def random_gamma(cls, harness: primitive_harness.Harness):
995    return [custom_numeric(devices="tpu", tol=1e-3)]
996
997  @classmethod
998  def reduce_max(cls, harness: primitive_harness.Harness):
999    return [
1000        missing_tf_kernel(dtypes=[np.complex64]),
1001        missing_tf_kernel(dtypes=[np.complex128])
1002    ]
1003
1004  @classmethod
1005  def reduce_min(cls, harness: primitive_harness.Harness):
1006    return [
1007        missing_tf_kernel(dtypes=[np.complex64]),
1008        missing_tf_kernel(dtypes=[np.complex128])
1009    ]
1010
1011  @classmethod
1012  def reduce_window_add(cls, harness):
1013    assert "add" == harness.params["computation"].__name__
1014    return [
1015        missing_tf_kernel(dtypes=[np.uint16]),
1016        missing_tf_kernel(dtypes=[np.complex64], devices="tpu"),
1017        missing_tf_kernel(dtypes=[np.uint64], devices=("cpu", "gpu"))
1018    ]
1019
1020  @classmethod
1021  def reduce_window_mul(cls, harness):
1022    assert "mul" == harness.params["computation"].__name__
1023    return [
1024        missing_tf_kernel(dtypes=[np.uint32]),
1025        missing_tf_kernel(dtypes=[np.uint64], devices=("cpu", "gpu"))
1026    ]
1027
1028  @classmethod
1029  def reduce_window_min(cls, harness):
1030    assert "min" == harness.params["computation"].__name__
1031    return [
1032        missing_tf_kernel(
1033            dtypes=[np.uint32, np.uint16, np.bool_, np.complex64, np.int8],),
1034        missing_tf_kernel(
1035            dtypes=[np.uint64, np.complex128],
1036            devices=("cpu", "gpu"),
1037        )
1038    ]
1039
1040  @classmethod
1041  def reduce_window_max(cls, harness):
1042    assert "max" == harness.params["computation"].__name__
1043    dtype = harness.dtype
1044    init_value = harness.params["init_value"]
1045    return [
1046        missing_tf_kernel(dtypes=[np.uint32, np.bool_, np.complex64]),
1047        missing_tf_kernel(
1048            dtypes=[np.uint64, np.complex128],
1049            devices=("cpu", "gpu"),
1050        ),
1051        Jax2TfLimitation(
1052            "TF kernel missing, except when the initial_value is the minimum for the dtype",
1053            dtypes=[np.uint16, np.int8],
1054            enabled=((dtype == np.uint16 and init_value != 0) or
1055                     (dtype == np.int8 and init_value != -128)))
1056    ]
1057
1058  @classmethod
1059  def regularized_incomplete_beta(cls, harness: primitive_harness.Harness):
1060    return [
1061        custom_numeric(dtypes=np.float64, tol=1e-14),
1062        missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])
1063    ]
1064
1065  @classmethod
1066  def rem(cls, harness: primitive_harness.Harness):
1067    return [
1068        missing_tf_kernel(
1069            dtypes=[
1070                np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16
1071            ],),
1072        Jax2TfLimitation(
1073            "TF integer division fails if divisor contains 0; JAX returns NaN",
1074            dtypes=[
1075                np.uint8, np.int8, np.uint16, np.uint32, np.uint64, np.int8,
1076                np.int16, np.int32, np.int64
1077            ],
1078            # Only the harnesses with "singularity" will have divide by 0
1079            enabled=("singularity" in harness.name)),
1080        missing_tf_kernel(
1081            dtypes=[np.float16],
1082            devices=("cpu", "gpu"),
1083            modes=("eager", "graph")),
1084    ]
1085
1086  @classmethod
1087  def rev(cls, harness: primitive_harness.Harness):
1088    return [missing_tf_kernel(dtypes=[np.uint32, np.uint64])]
1089
1090  @classmethod
1091  def round(cls, harness: primitive_harness.Harness):
1092    return [
1093        missing_tf_kernel(
1094            dtypes=[dtypes.bfloat16],
1095            devices=("cpu", "gpu"),
1096            modes=("eager", "graph"))
1097    ]
1098
1099  @classmethod
1100  def rsqrt(cls, harness: primitive_harness.Harness):
1101    return [
1102        missing_tf_kernel(
1103            dtypes=[dtypes.bfloat16],
1104            devices=("cpu", "gpu"),
1105            modes=("eager", "graph"))
1106    ]
1107
1108  @classmethod
1109  def scatter_add(cls, harness):
1110    return [
1111        missing_tf_kernel(dtypes=[np.uint16, np.uint64, np.bool_],),
1112        missing_tf_kernel(
1113            dtypes=[np.complex64],
1114            devices="tpu",
1115        ),
1116    ]
1117
1118  @classmethod
1119  def scatter_max(cls, harness):
1120    return [
1121        missing_tf_kernel(
1122            dtypes=[
1123                np.int8, np.uint16, np.uint32, np.uint64, np.complex64,
1124                np.complex128, np.bool_
1125            ],)
1126    ]
1127
1128  @classmethod
1129  def scatter_min(cls, harness):
1130    return [
1131        missing_tf_kernel(
1132            dtypes=[
1133                np.int8, np.uint16, np.uint32, np.complex64, np.bool_,
1134                np.uint64, np.complex128
1135            ],)
1136    ]
1137
1138  @classmethod
1139  def scatter_mul(cls, harness):
1140    return [
1141        missing_tf_kernel(dtypes=[np.uint32, np.uint64, np.bool_],),
1142        missing_tf_kernel(
1143            dtypes=[np.complex64],
1144            devices="tpu",
1145        ),
1146    ]
1147
1148  @classmethod
1149  def select_and_gather_add(cls, harness):
1150    return [
1151        missing_tf_kernel(
1152            dtypes=[np.float32],
1153            devices="tpu",
1154            description=(
1155                "This JAX primitives is not not exposed directly in the JAX API "
1156                "but arises from JVP of `lax.reduce_window` for reducers "
1157                "`lax.max` or `lax.min`. It also arises from second-order "
1158                "VJP of the same. Implemented using XlaReduceWindow")),
1159        Jax2TfLimitation((
1160            "jax2tf unimplemented for 64-bit inputs because the current implementation "
1161            "relies on packing two values into a single value. This can be "
1162            "fixed by using a variadic XlaReduceWindow, when available"),
1163                         dtypes=[np.float64],
1164                         devices=("cpu", "gpu"))
1165    ]
1166
1167  @classmethod
1168  def select_and_scatter_add(cls, harness):
1169    return [
1170        missing_tf_kernel(dtypes=[np.uint16]),
1171        missing_tf_kernel(
1172            dtypes=[np.uint64],
1173            devices=("cpu", "gpu"),
1174        )
1175    ]
1176
1177  @classmethod
1178  def sign(cls, harness: primitive_harness.Harness):
1179    return [
1180        missing_tf_kernel(
1181            dtypes=[
1182                np.uint32, np.uint16, np.int16, np.int8, np.uint8, np.uint64
1183            ],)
1184    ]
1185
1186  @classmethod
1187  def sinh(cls, harness: primitive_harness.Harness):
1188    return [
1189        missing_tf_kernel(
1190            dtypes=[np.float16],
1191            devices=("cpu", "gpu"),
1192            modes=("eager", "graph"))
1193    ]
1194
1195  @classmethod
1196  def sort(cls, harness: primitive_harness.Harness):
1197    return [
1198        Jax2TfLimitation(
1199            # I think that this is because TF is running on CPU even for GPU tests?
1200            "TODO: TF non-stable multiple-array sort",
1201            devices="gpu",
1202            enabled=(harness.params["num_arrays"] > 1 and
1203                     not harness.params["is_stable"]),
1204            expect_tf_error=False,
1205            skip_comparison=True),
1206        missing_tf_kernel(
1207            dtypes=[np.complex128, np.float64], devices=("cpu", "gpu")),
1208        missing_tf_kernel(dtypes=[np.bool_],),
1209    ]
1210
1211  @classmethod
1212  def sub(cls, harness):
1213    return [missing_tf_kernel(dtypes=[np.uint64])]
1214
1215  @classmethod
1216  def svd(cls, harness: primitive_harness.Harness):
1217    # TODO: slow test
1218
1219    def custom_assert(tst, r_jax, r_tf, *, args, tol):
1220
1221      def _reconstruct_operand(result, is_tf: bool):
1222        # Reconstructing operand as documented in numpy.linalg.svd (see
1223        # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
1224        s, u, v = result
1225        U = u[..., :s.shape[-1]]
1226        V = v[..., :s.shape[-1], :]
1227        S = s[..., None, :]
1228        return jnp.matmul(U * S, V), s.shape, u.shape, v.shape
1229
1230      if harness.params["compute_uv"]:
1231        r_jax_reconstructed = _reconstruct_operand(r_jax, False)
1232        r_tf_reconstructed = _reconstruct_operand(r_tf, True)
1233        tst.assertAllClose(
1234            r_jax_reconstructed, r_tf_reconstructed, atol=tol, rtol=tol)
1235      else:
1236        tst.assertAllClose(r_jax, r_tf, atol=tol, rtol=tol)
1237
1238    return [
1239        # Works in JAX for complex due to custom calls on cpu and gpu
1240        Jax2TfLimitation(
1241            "function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint`",
1242            dtypes=[np.complex64, np.complex128],
1243            devices=("cpu", "gpu"),
1244            modes=("compiled",)),
1245        missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"),
1246        custom_numeric(tol=1e-4),
1247        custom_numeric(custom_assert=custom_assert)
1248    ]
1249
1250  @classmethod
1251  def tan(cls, harness):
1252    return [
1253        custom_numeric(dtypes=np.complex64, devices="tpu", tol=1e-4),
1254        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3),
1255        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12)]
1256
1257  @classmethod
1258  def tanh(cls, harness):
1259    return [
1260        custom_numeric(dtypes=np.complex128, tol=1e-7),
1261        custom_numeric(dtypes=np.complex64, tol=1e-4)]
1262
1263  @classmethod
1264  def top_k(cls, harness):
1265
1266    def custom_assert(tst, result_jax, result_tf, **_):
1267      assert len(result_jax) == len(result_tf)
1268      # TODO: TF and JAX sort [inf, nan] differently.
1269      first_arr_jax, first_arr_tf = result_jax[0], result_tf[0]
1270      if np.all(first_arr_jax == first_arr_tf):
1271        for arr_jax, arr_tf in zip(result_jax, result_tf):
1272          tst.assertArraysEqual(arr_jax, arr_tf)
1273      else:
1274        mask_jax, mask_tf = np.isnan(first_arr_jax), np.isnan(first_arr_tf)
1275        tst.assertArraysEqual(first_arr_jax[~mask_jax], first_arr_tf[~mask_tf])
1276
1277    return [
1278        missing_tf_kernel(
1279            dtypes=[np.uint64, np.int64],
1280            devices=("cpu", "gpu"),
1281            modes="compiled"),
1282        custom_numeric(
1283            dtypes=[np.float16, dtypes.bfloat16, np.float32, np.float64],
1284            custom_assert=custom_assert,
1285            description=(
1286               "Produces different results when the array contains `inf` and `NaN`"
1287               " (they are sorted differently in TF vs. XLA).")
1288        )]
1289
1290  @classmethod
1291  def triangular_solve(cls, harness: primitive_harness.Harness):
1292    return [
1293        missing_tf_kernel(dtypes=[dtypes.bfloat16]),
1294        missing_tf_kernel(
1295            dtypes=[np.float16],
1296            devices=("gpu", "cpu"),
1297            modes=("eager", "graph")),
1298        custom_numeric(dtypes=np.float32, tol=5e-3)
1299    ]
1300
1301
1302def custom_numeric(
1303    *,
1304    description="custom numeric comparison",
1305    dtypes=(),  # All
1306    modes=("eager", "graph", "compiled"),
1307    devices=("cpu", "gpu", "tpu"),
1308    custom_assert=None,
1309    tol=None) -> Jax2TfLimitation:
1310
1311  return Jax2TfLimitation(
1312      description,
1313      expect_tf_error=False,
1314      dtypes=dtypes,
1315      devices=devices,
1316      modes=modes,
1317      custom_assert=custom_assert,
1318      tol=tol)
1319
1320
1321def missing_tf_kernel(
1322    *,
1323    description="op not defined for dtype",
1324    dtypes,
1325    modes=("eager", "graph", "compiled"),
1326    devices=("cpu", "gpu", "tpu")
1327) -> Jax2TfLimitation:
1328
1329  return Jax2TfLimitation(
1330      description,
1331      dtypes=dtypes,
1332      devices=devices,
1333      modes=modes)
1334