# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""See primitives_test docstring for how the Jax2TfLimitations are used."""

import itertools
import numpy as np
from typing import Any, Callable, Optional, Sequence

from jax import dtypes
from jax import lax
from jax import numpy as jnp

from jax.experimental.jax2tf.tests import primitive_harness

DType = Any


class Jax2TfLimitation(primitive_harness.Limitation):
  """Specific primitive limitations for jax2tf.

  See the primitive_test module docstring for details.
  """

  def __init__(
      self,
      description: str,
      *,
      devices: Sequence[str] = ("cpu", "gpu", "tpu"),
      dtypes: Sequence[DType] = (),
      enabled: bool = True,
      # jax2tf specific
      modes=("eager", "graph", "compiled"),
      skip_tf_run=False,
      expect_tf_error: bool = True,
      skip_comparison=False,
      custom_assert: Optional[Callable] = None,
      tol=None):
    """See the primitive_harness.Limitation common arguments.

    Args:
      modes: one of "eager", "graph", "compiled".
      skip_tf_run: if set will skip the TF execution. Use this sparingly,
        prefer `expect_tf_error`. Use only when the test cannot recover from
        the TF error.
      expect_tf_error: if set, then expect a TF error in the given mode when
        executing the result of jax2tf conversion. If not set, then the
        limitation must have a custom_assert or non-default tol.
      skip_comparison: skips the numeric comparison.
      tol: a tolerance to use for both atol and rtol. We will use the maximum
        tolerance over all the applicable limitations, irrespective of their
        order.
      custom_assert: if given, then execute as
        `custom_assert(tst, result_jax, result_tf, args=args, tol=tol)`, where
        `tst` is the current TestCase instance, and args are the input
        arguments that the harness created. The `tol` is the maximum tolerance
        based on the applicable limitations. `result_tf` is already converted
        to NumPy arrays.
    """
    super().__init__(
        description,
        devices=devices,
        dtypes=dtypes,
        enabled=enabled)
    if isinstance(modes, str):
      modes = (modes,)
    assert all(m in ["eager", "graph", "compiled"] for m in modes)
    self.modes = modes
    self.expect_tf_error = expect_tf_error
    self.skip_tf_run = skip_tf_run
    self.custom_assert = custom_assert
    self.tol = tol
    self.skip_comparison = skip_comparison

  def get_max_tolerance_limitation(
      self, limitations: Sequence["Jax2TfLimitation"]
  ) -> Optional["Jax2TfLimitation"]:
    """Pick the tolerance limitation that establishes the maximum tolerance."""
    # TODO: it would be best if the limitations with tolerance are mutually
    # exclusive and we don't have to compute the maximum.
    # TODO: we made this an instance method only so that we don't have to
    # import this module from tf_test.util.
    max_tol_lim = None
    for l in limitations:
      if l.tol is not None:
        if max_tol_lim is None or l.tol > max_tol_lim.tol:
          max_tol_lim = l
    return max_tol_lim

  def filter(self,  # type: ignore[override]
             dtype: Optional[DType] = None,
             device: Optional[str] = None,
             mode: Optional[str] = None) -> bool:
    return ((mode is None or mode in self.modes) and
            super().filter(device=device, dtype=dtype))

  @classmethod
  def limitations_for_harness(
      cls, harness: primitive_harness.Harness) -> Sequence["Jax2TfLimitation"]:
    group_method = getattr(cls, harness.group_name, None)
    if harness.group_name in cls.harness_groups_no_limitations:
      assert group_method is None, (
          f"Harness group {harness.group_name} is both in "
          f"'harness_groups_no_limitations' and has a custom "
          f"Jax2TfLimitation.classmethod defined (see module docstring)")
      return []
    else:
      assert group_method is not None, (
          f"Harness group {harness.group_name} must be either part of "
          f"'harness_groups_no_limitations' or must have a custom "
          f"Jax2TfLimitation.classmethod defined (see module docstring)")
      limitations = group_method(harness)
      assert isinstance(limitations, (list, tuple))
      return limitations

  # We keep here the explicit set of groups for which we don't have limitations
  harness_groups_no_limitations = {
      "abs", "and", "argmin", "argmax", "broadcast", "broadcast_in_dim",
      "ceil", "concatenate", "cos", "complex", "conj", "device_put",
      "dynamic_slice", "dynamic_update_slice", "exp", "eq", "floor", "log",
      "gather", "imag", "iota", "is_finite", "ne", "not", "or", "pad",
      "random_split", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
      "real", "reshape", "select", "shift_left", "shift_right_logical",
      "shift_right_arithmetic", "sin", "slice", "sqrt", "squeeze",
      "stop_gradient", "tie_in", "transpose", "xor", "zeros_like"
  }

  @classmethod
  def helper_get_trig_custom_limitation(cls, np_inverse):

    def custom_assert(tst, result_jax, result_tf, *, args, tol):
      operand, = args
      tst.assertAllClose(operand, np_inverse(result_tf), atol=tol, rtol=tol)

    return custom_numeric(
        description="May return different but still correct results",
        dtypes=[np.complex64, np.complex128],
        custom_assert=custom_assert,
        modes=("eager", "graph"))

  @classmethod
  def acos(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[np.float16, dtypes.bfloat16, np.complex64],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        missing_tf_kernel(
            dtypes=[np.complex128],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.complex128, tol=1e-13),
        custom_numeric(dtypes=np.complex64, devices="tpu", tol=1e-3),
        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-4),
        cls.helper_get_trig_custom_limitation(np.cos),
    ]

  @classmethod
  def acosh(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16, np.float16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3),
        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
        cls.helper_get_trig_custom_limitation(np.cosh)
    ]

  @classmethod
  def add(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[np.uint16]),
        missing_tf_kernel(dtypes=[np.uint64], devices=("cpu", "gpu"))
    ]

  @classmethod
  # Also called add_jaxvals
  def add_any(cls, harness: primitive_harness.Harness):
    return [missing_tf_kernel(dtypes=[np.uint16, np.uint64])]

  @classmethod
  def asin(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[np.float16, dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
        cls.helper_get_trig_custom_limitation(np.sin)
    ]

  @classmethod
  def asinh(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[np.float16, dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3),
        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
        cls.helper_get_trig_custom_limitation(np.sinh)
    ]

  @classmethod
  def atan(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[np.float16, dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
        cls.helper_get_trig_custom_limitation(np.tan)
    ]

  @classmethod
  def atanh(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[np.float16, dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.float64, tol=1e-14),
        custom_numeric(dtypes=np.complex64, tol=1e-3),
        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
        cls.helper_get_trig_custom_limitation(np.tanh)
    ]

  @classmethod
  def atan2(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[np.float16, dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def bessel_i0e(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def bessel_i1e(cls, harness: primitive_harness.Harness):
    return cls.bessel_i0e(harness)

  @classmethod
  def bitcast_convert_type(cls, harness: primitive_harness.Harness):
    return [missing_tf_kernel(dtypes=[np.bool_])]

  @classmethod
  def cholesky(cls, harness: primitive_harness.Harness):

    def custom_assert(tst, result_jax, result_tf, *, tol, **_):
      # cholesky_p returns garbage in the strictly upper triangular part of the
      # result, so we can safely ignore that part.
      tst.assertAllClose(jnp.tril(result_jax), result_tf, atol=tol)

    return [
        # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
        Jax2TfLimitation(
            "function not compilable",
            dtypes=[np.complex64, np.complex128],
            devices=("cpu", "gpu"),
            modes="compiled"),
        missing_tf_kernel(
            # Interesting: on TPU, complex64 works in eager
            # mode, but fails otherwise.
            dtypes=[np.complex64, np.complex128],
            devices="tpu",
            modes=("graph", "compiled")),
        # TODO(bchetioui): very high discrepancy in the float32/complex64 case
        custom_numeric(dtypes=[np.float32, np.complex64], tol=1e-2),
        custom_numeric(dtypes=[np.float64, np.complex128], tol=1e-6),
        custom_numeric(dtypes=[dtypes.bfloat16, np.float16], tol=5e-2),
        custom_numeric(
            custom_assert=custom_assert,
            description=(
                "May return different values in the strictly upper triangular "
                "part of the result. This does not matter for correctness, "
                "because this part of the matrix is not considered in the "
                "result."))
    ]

  @classmethod
  def clamp(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[np.int8, np.uint16, np.uint32, np.uint64])
    ]

  @classmethod
  def convert_element_type(cls, harness: primitive_harness.Harness):
    return []

  @classmethod
  def conv_general_dilated(cls, harness: primitive_harness.Harness):
    return [
        Jax2TfLimitation(
            "jax2tf BUG: batch_group_count > 1 not yet converted",
            enabled=(harness.params["batch_group_count"] > 1)),
        custom_numeric(devices="gpu", tol=1e-4),
        custom_numeric(devices="tpu", tol=1e-3),
        # TODO(bchetioui): significant discrepancies in some float16 cases.
        custom_numeric(dtypes=np.float16, tol=1),
        # TODO(bchetioui): slight occasional discrepancy in float32 cases.
        custom_numeric(dtypes=np.float32, devices="tpu", tol=0.5),
        custom_numeric(dtypes=np.float32, devices="gpu", tol=1e-3),
        custom_numeric(dtypes=np.float32, devices="cpu", tol=1e-4),
        custom_numeric(dtypes=np.complex64, devices="tpu", tol=0.1),
        custom_numeric(
            dtypes=[np.complex64, np.complex128], devices=("cpu", "gpu"),
            tol=5e-4),
        # TODO(bchetioui): slight discrepancy when going through the path using
        # tf.nn.convolution.
        custom_numeric(dtypes=np.float64, devices="cpu", tol=1e-13),
    ]

  # TODO(bchetioui): unidentified bug in compiled mode. The test that fails is
  #
  #   test_conv_general_dilated_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False
  #
  # with the following assertion error in TensorFlowTrace.process_primitive:
  #
  #   AssertionError: conv_general_dilated: out.aval = ShapedArray(float32[1,3,24,26,16]); expected ShapedArray(float32[1,3,26,24,16])
  #
  # Deactivating this assertion is enough to pass the test, which suggests
  # that the end shape is indeed the correct one (i.e. (1,3,26,24,16)).
  # Further investigation is required to really understand this behavior,
  # which we have not managed to reproduce as a pure TF test.
  #
  # This bug is low priority since it only occurs when using a non-TFXLA
  # conversion path in compiled mode, i.e. in a context where using the
  # TFXLA path is possible.
  # if harness.name == "_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False":
  #   raise unittest.SkipTest("TODO: known but unidentified bug in compiled "
  #                           "mode")

  @classmethod
  def cosh(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[np.float16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def cummax(cls, harness):
    return [
        missing_tf_kernel(
            dtypes=[np.uint64, np.complex128],
            devices=("cpu", "gpu"),
        ),
        missing_tf_kernel(
            dtypes=[np.uint16, np.uint32, np.int8, np.complex64],),
        custom_numeric(dtypes=np.float16, tol=0.1),
        custom_numeric(dtypes=dtypes.bfloat16, tol=0.5)
    ]

  @classmethod
  def cummin(cls, harness):
    return [
        missing_tf_kernel(
            dtypes=[np.uint64, np.complex128],
            devices=("cpu", "gpu"),
        ),
        missing_tf_kernel(
            dtypes=[np.uint16, np.uint32, np.int8, np.complex64],),
        custom_numeric(dtypes=np.float16, tol=0.1),
        custom_numeric(dtypes=dtypes.bfloat16, tol=0.5),
    ]

  @classmethod
  def cumprod(cls, harness):
    return [
        missing_tf_kernel(
            dtypes=[np.uint64],
            devices=("cpu", "gpu"),
        ),
        missing_tf_kernel(dtypes=[np.uint32]),
        custom_numeric(dtypes=np.float16, tol=0.1),
        custom_numeric(dtypes=dtypes.bfloat16, tol=0.5),
    ]

  @classmethod
  def cumsum(cls, harness):
    return [
        missing_tf_kernel(
            dtypes=[np.uint64],
            devices=("cpu", "gpu"),
        ),
        missing_tf_kernel(dtypes=[np.complex64], devices="tpu"),
        missing_tf_kernel(dtypes=[np.uint16]),
        custom_numeric(dtypes=np.float16, tol=0.1),
        custom_numeric(dtypes=dtypes.bfloat16, tol=0.5),
    ]

  @classmethod
  def custom_linear_solve(cls, harness: primitive_harness.Harness):
    return [
        Jax2TfLimitation(
            "TODO: large numerical discrepancy",
            dtypes=np.float32,
            devices="tpu",
            expect_tf_error=False,
            skip_comparison=True),
        custom_numeric(dtypes=np.float32, devices="tpu", tol=0.01),
        custom_numeric(tol=1e-3),
    ]

  @classmethod
  def digamma(cls, harness: primitive_harness.Harness):
    dtype = harness.dtype

    # In the bfloat16 case, TF and lax both return NaN in undefined cases.
    # digamma is not defined at 0 and -1
    def custom_assert(tst, result_jax, result_tf, *, args, tol):
      # lax.digamma returns NaN and tf.math.digamma returns inf
      arg, = args
      special_cases = (arg == 0.) | (arg == -1.)
      nr_special_cases = np.count_nonzero(special_cases)
      tst.assertAllClose(
          np.full((nr_special_cases,), dtype(np.nan)),
          result_jax[special_cases])
      tst.assertAllClose(
          np.full((nr_special_cases,), dtype(np.inf)),
          result_tf[special_cases])
      # non-special cases are equal
      tst.assertAllClose(
          result_jax[~special_cases],
          result_tf[~special_cases],
          atol=tol,
          rtol=tol)

    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.float64, tol=1e-13),
        custom_numeric(dtypes=np.float32, devices=["cpu", "gpu"], tol=1e-3),
        custom_numeric(
            dtypes=dtypes.bfloat16,
            custom_assert=custom_assert,
            description=(
                "May return different results at singularity points 0 and -1. "
                "JAX returns nan and TF returns inf"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def div(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[
                np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16
            ],),
        Jax2TfLimitation(
            "TF integer division fails if divisor contains 0; JAX returns NaN",
            dtypes=[
                np.uint8, np.int8, np.uint16, np.uint32, np.uint64,
                np.int16, np.int32, np.int64
            ],
            # Only the harnesses with "singularity" will have divide by 0
            enabled=("singularity" in harness.name))
    ]

  @classmethod
  def dot_general(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[
                np.bool_, np.uint8, np.uint16, np.uint32, np.uint64, np.int8,
                np.int16
            ],),
        missing_tf_kernel(
            dtypes=[np.int64], devices=("cpu", "gpu"), modes="compiled"),
        custom_numeric(dtypes=dtypes.bfloat16, tol=0.3),
        custom_numeric(
            dtypes=[np.complex64, np.float32], devices=("cpu", "gpu"),
            tol=1e-5),
        custom_numeric(dtypes=np.float32, devices="tpu", tol=0.1),
        custom_numeric(dtypes=np.complex64, devices="tpu", tol=0.3),
        custom_numeric(dtypes=np.float16, devices=("gpu", "tpu"), tol=0.1),
        custom_numeric(dtypes=np.float16, devices="cpu", tol=0.01)
    ]

  @classmethod
  def eig(cls, harness: primitive_harness.Harness):
    compute_left_eigenvectors = harness.params["compute_left_eigenvectors"]
    compute_right_eigenvectors = harness.params["compute_right_eigenvectors"]
    dtype = harness.dtype

    def custom_assert(tst, result_jax, result_tf, *, args, tol):
      operand, = args
      inner_dimension = operand.shape[-1]

      # Test ported from tests.linalg_test.testEig
      # Norm, adjusted for dimension and type.
      def norm(x):
        norm = np.linalg.norm(x, axis=(-2, -1))
        return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps)

      def check_right_eigenvectors(a, w, vr):
        tst.assertTrue(
            np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

      def check_left_eigenvectors(a, w, vl):
        rank = len(a.shape)
        aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
        wC = jnp.conj(w)
        check_right_eigenvectors(aH, wC, vl)

      def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
        tol = None
        # TODO(bchetioui): numerical discrepancies
        if dtype in [np.float32, np.complex64]:
          tol = 1e-4
        elif dtype in [np.float64, np.complex128]:
          tol = 1e-13
        closest_diff = min(abs(eigenvalues_array - eigenvalue))
        tst.assertAllClose(
            closest_diff, np.array(0., closest_diff.dtype), atol=tol)

      all_w_jax, all_w_tf = result_jax[0], result_tf[0]
      for idx in itertools.product(*map(range, operand.shape[:-2])):
        w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
        for i in range(inner_dimension):
          check_eigenvalue_is_in_array(w_jax[i], w_tf)
          check_eigenvalue_is_in_array(w_tf[i], w_jax)

      if compute_left_eigenvectors:
        check_left_eigenvectors(operand, all_w_tf, result_tf[1])
      if compute_right_eigenvectors:
        check_right_eigenvectors(operand, all_w_tf,
                                 result_tf[1 + compute_left_eigenvectors])

    return [
        # Eig does not work in JAX on gpu or tpu
        Jax2TfLimitation(
            "function not compilable", modes="compiled", devices="cpu"),
        Jax2TfLimitation(
            "TF Conversion of eig is not implemented when both "
            "compute_left_eigenvectors and compute_right_eigenvectors are "
            "set to True",
            enabled=(compute_left_eigenvectors and compute_right_eigenvectors)),
        custom_numeric(
            custom_assert=custom_assert,
            description=("May return the eigenvalues and eigenvectors in a "
                         "potentially different order. The eigenvectors may "
                         "also be different, but equally valid."),
            modes=("eager", "graph"))
    ]

  @classmethod
  def eigh(cls, harness: primitive_harness.Harness):
    dtype = harness.dtype
    shape = harness.params["shape"]

    def custom_assert(tst, result_jax, result_tf, *, args, tol):
      operand, = args
      inner_dimension = operand.shape[-1]

      def check_right_eigenvectors(a, w, vr):
        tol = 1e-16
        # TODO(bchetioui): tolerance needs to be very high in compiled mode,
        # specifically for eigenvectors.
        if dtype == np.float64:
          tol = 1e-6
        elif dtype == np.float32:
          tol = 1e-2
        elif dtype in [dtypes.bfloat16, np.complex64]:
          tol = 1e-3
        elif dtype == np.complex128:
          tol = 1e-13
        tst.assertAllClose(
            np.matmul(a, vr) - w[..., None, :] * vr,
            np.zeros(a.shape, dtype=vr.dtype),
            atol=tol)

      def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
        tol = None
        if dtype in [dtypes.bfloat16, np.float32, np.complex64]:
          tol = 1e-3
        elif dtype in [np.float64, np.complex128]:
          tol = 1e-11
        closest_diff = min(abs(eigenvalues_array - eigenvalue))
        tst.assertAllClose(
            closest_diff, np.array(0., closest_diff.dtype), atol=tol)

      _, all_w_jax = result_jax
      all_vr_tf, all_w_tf = result_tf

      for idx in itertools.product(*map(range, operand.shape[:-2])):
        w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
        for i in range(inner_dimension):
          check_eigenvalue_is_in_array(w_jax[i], w_tf)
          check_eigenvalue_is_in_array(w_tf[i], w_jax)

      check_right_eigenvectors(operand, all_w_tf, all_vr_tf)

    return [
        # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
        Jax2TfLimitation(
            "function not compilable",
            dtypes=[np.complex64, np.complex128],
            modes="compiled",
            enabled=(shape[0] > 0)),
        Jax2TfLimitation(
            "TODO: numeric discrepancies",
            dtypes=[np.float64],
            modes="compiled",
            devices=("cpu", "gpu"),
            expect_tf_error=False,
            skip_comparison=True),
        Jax2TfLimitation(
            "TODO: numeric discrepancies",
            dtypes=[np.float16],
            devices=("tpu",),
            expect_tf_error=False,
            skip_comparison=True),
        custom_numeric(
            custom_assert=custom_assert,
            description=("May return the eigenvalues and eigenvectors in a "
                         "potentially different order. The eigenvectors may "
                         "also be different, but equally valid."))
    ]

  @classmethod
  def ge(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[np.bool_]),
        missing_tf_kernel(
            dtypes=[np.uint16, np.uint32],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        missing_tf_kernel(
            dtypes=[np.uint64],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def gt(cls, harness: primitive_harness.Harness):
    return cls.ge(harness)

  @classmethod
  def erf(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def erfc(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def erf_inv(cls, harness: primitive_harness.Harness):
    # erf_inv is not defined for arg <= -1 or arg >= 1
    def custom_assert(tst, result_jax, result_tf, *, args, tol):  # noqa: F811
      arg, = args
      # for arg < -1 or arg > 1
      # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
      special_cases = (arg < -1.) | (arg > 1.)
      # non-special cases are equal
      tst.assertAllClose(
          result_jax[~special_cases],
          result_tf[~special_cases],
          atol=tol,
          rtol=tol)

    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16, np.float16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=[np.float32, np.float64], tol=1e-4),
        custom_numeric(
            dtypes=[np.float32, np.float64],
            custom_assert=custom_assert,
            description=(
                "May return different results at undefined points (< -1 or > 1):"
                " JAX returns `NaN` and TF returns `+inf` or `-inf`."))
    ]

  @classmethod
  def expm1(cls, harness: primitive_harness.Harness):
    return [custom_numeric(dtypes=np.float64, tol=1e-5)]

  @classmethod
  def fft(cls, harness):
    return [
        Jax2TfLimitation(
            "TF function not compilable",
            devices=("cpu", "gpu"),
            dtypes=[np.float64, np.complex128],
            modes="compiled"),
        custom_numeric(tol=1e-3)
    ]

  @classmethod
  def _pow_test_util(cls, harness: primitive_harness.Harness):

    def custom_assert(tst, result_jax, result_tf, *, args, tol):
      # NaNs are mismatched, but assertAllClose will also behave weirdly for
      # complex numbers containing np.inf as one of their components. See
      # https://github.com/numpy/numpy/issues/15959 for more details.
      mask = (
          np.isnan(result_jax) + np.isnan(result_tf) + np.isinf(result_jax) +
          np.isinf(result_tf))
      tst.assertAllClose(result_jax[~mask], result_tf[~mask], rtol=tol)

    return [
        custom_numeric(
            dtypes=[np.float32, np.complex64], devices="tpu", tol=1e-2),
        custom_numeric(
            dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"),
            tol=1e-3),
        custom_numeric(dtypes=[np.float64, np.complex128], tol=1e-12),
        custom_numeric(dtypes=np.float16, tol=1),
        # Values get really small for large negative powers.
        custom_numeric(dtypes=dtypes.bfloat16, tol=3),
        custom_numeric(
            dtypes=[np.complex64, np.complex128],
            custom_assert=custom_assert,
        )
    ]

  @classmethod
  def igamma(cls, harness: primitive_harness.Harness):
    dtype = harness.dtype

    # igamma is not defined when the first argument is <=0
    def custom_assert(tst, result_jax, result_tf, *, args, tol):
      arg1, arg2 = args
      # lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0
      special_cases = (arg1 == 0.) & (arg2 == 0.)
      nr_special_cases = np.count_nonzero(special_cases)
      tst.assertAllClose(
          np.full((nr_special_cases,), np.nan, dtype=dtype),
          result_jax[special_cases])
      tst.assertAllClose(
          np.full((nr_special_cases,), 0., dtype=dtype),
          result_tf[special_cases])
      # non-special cases are equal
      tst.assertAllClose(result_jax[~special_cases], result_tf[~special_cases])

    return [
        custom_numeric(
            custom_assert=custom_assert,
            description=(
                "May return different results at undefined points "
                "(both arguments 0). JAX returns `NaN` and TF returns 0 or "
                "JAX returns 1 and TF returns `NaN`"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def igammac(cls, harness: primitive_harness.Harness):
    dtype = harness.dtype

    # igammac is not defined when the first argument is <=0
    def custom_assert(tst, result_jax, result_tf, *, args, tol):  # noqa: F811
      arg1, arg2 = args
      # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
      special_cases = (arg1 <= 0.) | (arg2 <= 0)
      nr_special_cases = np.count_nonzero(special_cases)
      tst.assertAllClose(
          np.full((nr_special_cases,), 1., dtype=dtype),
          result_jax[special_cases])
      tst.assertAllClose(
          np.full((nr_special_cases,), np.nan, dtype=dtype),
          result_tf[special_cases])
      # non-special cases are equal
      tst.assertAllClose(
          result_jax[~special_cases],
          result_tf[~special_cases],
          atol=tol,
          rtol=tol)

    return [
        custom_numeric(dtypes=np.float64, tol=1e-9),
        custom_numeric(devices="gpu", tol=1e-3),
        custom_numeric(
            custom_assert=custom_assert,
            devices=("cpu", "gpu"),
            modes=("eager", "graph"),
            description=(
                "May return different results at undefined points "
                "(both arguments less or equal 0). JAX returns `NaN` and TF "
                "returns 0 or JAX returns 1 and TF returns `NaN`")),
    ]

  @classmethod
  def integer_pow(cls, harness: primitive_harness.Harness):
    y = harness.params["y"]
    return [
        missing_tf_kernel(
            dtypes=[
                np.uint8, np.uint16, np.int8, np.int16, np.uint32, np.uint64
            ],),
        # hitting rtol = nan
        Jax2TfLimitation(
            ("Different overflow behavior for large exponents: overflow is "
             "mapped to `NaN` and `+inf`/`-inf` differently in JAX and TF."),
            devices="tpu",
            dtypes=np.complex64,
            enabled=(y in [1000, -1000]),
            expect_tf_error=False,
            skip_comparison=True),
        Jax2TfLimitation(
            "Different overflow behavior for large exponents.",
            dtypes=[np.int32, np.int64, np.float32],
            enabled=(y > 10),
            expect_tf_error=False,
            skip_comparison=True)
    ] + list(cls._pow_test_util(harness))

  @classmethod
  def pow(cls, harness: primitive_harness.Harness):
    return cls._pow_test_util(harness)

  @classmethod
  def le(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[np.bool_]),
        missing_tf_kernel(
            dtypes=[np.uint16, np.uint32],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        missing_tf_kernel(
            dtypes=[np.uint64],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def lt(cls, harness: primitive_harness.Harness):
    return cls.ge(harness)

  @classmethod
  def lgamma(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.float64, tol=1e-11),
        custom_numeric(dtypes=np.float32, tol=1e-3)
    ]

  @classmethod
  def log1p(cls, harness: primitive_harness.Harness):
    return [
        custom_numeric(dtypes=np.float64, tol=1e-10),
        custom_numeric(dtypes=np.float32, tol=1e-3)
    ]

  @classmethod
  def lu(cls, harness: primitive_harness.Harness):
    dtype = harness.dtype

    def custom_assert(tst, result_jax, result_tf, *, args, tol):
      operand, = args
      lu, pivots, perm = result_tf
      batch_dims = operand.shape[:-2]
      m, n = operand.shape[-2], operand.shape[-1]

      def _make_permutation_matrix(perm):
        result = []
        for idx in itertools.product(*map(range, operand.shape[:-1])):
          result += [0 if c != perm[idx] else 1 for c in range(m)]
        result = np.reshape(np.array(result, dtype=dtype), [*batch_dims, m, m])
        return result

      k = min(m, n)
      l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
      u = jnp.triu(lu)[..., :k, :]
      p_mat = _make_permutation_matrix(perm)

      tst.assertArraysEqual(
          lax.linalg.lu_pivots_to_permutation(pivots, m), perm)
      tst.assertAllClose(
          jnp.matmul(p_mat, operand), jnp.matmul(l, u), atol=tol, rtol=tol)

    return [
        missing_tf_kernel(dtypes=[np.complex64], devices="tpu"),
        custom_numeric(
            dtypes=[np.float32, np.complex64], devices="tpu", tol=0.1),
        custom_numeric(
            dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"),
            tol=1e-5),
        custom_numeric(dtypes=[np.float64, np.complex128], tol=1e-13),
        custom_numeric(
            custom_assert=custom_assert,
            description=("May return different, but also correct, results "
                         "when the decomposition is not unique")),
    ]

  @classmethod
  def _min_max_test_util(cls, harness: primitive_harness.Harness):
    # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN;
    # JAX always returns NaN, while TF returns the value NaN is compared with.
    def custom_assert(tst, result_jax, result_tf, **_):
      mask = np.isnan(result_jax)
      tst.assertAllClose(result_jax[~mask], result_tf[~mask])

    return [
        missing_tf_kernel(
            dtypes=[
                np.bool_, np.int8, np.complex64, np.uint16, np.uint32,
                np.uint64
            ],),
        missing_tf_kernel(
            dtypes=[np.complex128],
            devices=("cpu", "gpu"),
        ),
        custom_numeric(
            custom_assert=custom_assert,
            description=(
                "May return different values when one of the values is NaN. "
                "JAX always returns NaN, while TF returns the value NaN is "
                "compared with."))
    ]

  @classmethod
  def max(cls, harness: primitive_harness.Harness):
    return cls._min_max_test_util(harness)

  @classmethod
  def min(cls, harness: primitive_harness.Harness):
    return cls._min_max_test_util(harness)

  @classmethod
  def mul(cls, harness: primitive_harness.Harness):
    return [missing_tf_kernel(dtypes=[np.uint32, np.uint64])]

  @classmethod
  def neg(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[np.uint8, np.uint16, np.uint32, np.uint64],)
    ]

  @classmethod
  def nextafter(cls, harness: primitive_harness.Harness):
    return [missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])]

  @classmethod
  def population_count(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[np.uint32, np.uint64],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def qr(cls, harness: primitive_harness.Harness):
    # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
    # jit_compile=True breaks for complex types.
    # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
    # - for now, the HLO QR implementation called when compiling with TF is
    #   expected to have worse performance than the custom calls made in JAX.
    return [
        custom_numeric(tol=1e-5),
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices="tpu",
        )
    ]

  @classmethod
  def random_gamma(cls, harness: primitive_harness.Harness):
    return [custom_numeric(devices="tpu", tol=1e-3)]

  @classmethod
  def reduce_max(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[np.complex64]),
        missing_tf_kernel(dtypes=[np.complex128])
    ]

  @classmethod
  def reduce_min(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[np.complex64]),
        missing_tf_kernel(dtypes=[np.complex128])
    ]

  @classmethod
  def reduce_window_add(cls, harness):
    assert "add" == harness.params["computation"].__name__
    return [
        missing_tf_kernel(dtypes=[np.uint16]),
        missing_tf_kernel(dtypes=[np.complex64], devices="tpu"),
        missing_tf_kernel(dtypes=[np.uint64], devices=("cpu", "gpu"))
    ]

  @classmethod
  def reduce_window_mul(cls, harness):
    assert "mul" == harness.params["computation"].__name__
    return [
        missing_tf_kernel(dtypes=[np.uint32]),
        missing_tf_kernel(dtypes=[np.uint64], devices=("cpu", "gpu"))
    ]

  @classmethod
  def reduce_window_min(cls, harness):
    assert "min" == harness.params["computation"].__name__
    return [
        missing_tf_kernel(
            dtypes=[np.uint32, np.uint16, np.bool_, np.complex64, np.int8],),
        missing_tf_kernel(
            dtypes=[np.uint64, np.complex128],
            devices=("cpu", "gpu"),
        )
    ]

  @classmethod
  def reduce_window_max(cls, harness):
    assert "max" == harness.params["computation"].__name__
    dtype = harness.dtype
    init_value = harness.params["init_value"]
    return [
        missing_tf_kernel(dtypes=[np.uint32, np.bool_, np.complex64]),
        missing_tf_kernel(
            dtypes=[np.uint64, np.complex128],
            devices=("cpu", "gpu"),
        ),
        Jax2TfLimitation(
            "TF kernel missing, except when the initial_value is the minimum "
            "for the dtype",
            dtypes=[np.uint16, np.int8],
            enabled=((dtype == np.uint16 and init_value != 0) or
                     (dtype == np.int8 and init_value != -128)))
    ]

  @classmethod
  def regularized_incomplete_beta(cls, harness: primitive_harness.Harness):
    return [
        custom_numeric(dtypes=np.float64, tol=1e-14),
        missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])
    ]

  @classmethod
  def rem(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[
                np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16
            ],),
        Jax2TfLimitation(
            "TF integer division fails if divisor contains 0; JAX returns NaN",
            dtypes=[
                np.uint8, np.int8, np.uint16, np.uint32, np.uint64,
                np.int16, np.int32, np.int64
            ],
            # Only the harnesses with "singularity" will have divide by 0
            enabled=("singularity" in harness.name)),
        missing_tf_kernel(
            dtypes=[np.float16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph")),
    ]

  @classmethod
  def rev(cls, harness: primitive_harness.Harness):
    return [missing_tf_kernel(dtypes=[np.uint32, np.uint64])]

  @classmethod
  def round(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def rsqrt(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[dtypes.bfloat16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def scatter_add(cls, harness):
    return [
        missing_tf_kernel(dtypes=[np.uint16, np.uint64, np.bool_],),
        missing_tf_kernel(
            dtypes=[np.complex64],
            devices="tpu",
        ),
    ]

  @classmethod
  def scatter_max(cls, harness):
    return [
        missing_tf_kernel(
            dtypes=[
                np.int8, np.uint16, np.uint32, np.uint64, np.complex64,
                np.complex128, np.bool_
            ],)
    ]

  @classmethod
  def scatter_min(cls, harness):
    return [
        missing_tf_kernel(
            dtypes=[
                np.int8, np.uint16, np.uint32, np.complex64, np.bool_,
                np.uint64, np.complex128
            ],)
    ]

  @classmethod
  def scatter_mul(cls, harness):
    return [
        missing_tf_kernel(dtypes=[np.uint32, np.uint64, np.bool_],),
        missing_tf_kernel(
            dtypes=[np.complex64],
            devices="tpu",
        ),
    ]

  @classmethod
  def select_and_gather_add(cls, harness):
    return [
        missing_tf_kernel(
            dtypes=[np.float32],
            devices="tpu",
            description=(
                "This JAX primitive is not exposed directly in the JAX API "
                "but arises from JVP of `lax.reduce_window` for reducers "
                "`lax.max` or `lax.min`. It also arises from second-order "
                "VJP of the same. Implemented using XlaReduceWindow.")),
        Jax2TfLimitation(
            ("jax2tf unimplemented for 64-bit inputs because the current "
             "implementation relies on packing two values into a single "
             "value. This can be fixed by using a variadic XlaReduceWindow, "
             "when available."),
            dtypes=[np.float64],
            devices=("cpu", "gpu"))
    ]

  @classmethod
  def select_and_scatter_add(cls, harness):
    return [
        missing_tf_kernel(dtypes=[np.uint16]),
        missing_tf_kernel(
            dtypes=[np.uint64],
            devices=("cpu", "gpu"),
        )
    ]

  @classmethod
  def sign(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[
                np.uint32, np.uint16, np.int16, np.int8, np.uint8, np.uint64
            ],)
    ]

  @classmethod
  def sinh(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(
            dtypes=[np.float16],
            devices=("cpu", "gpu"),
            modes=("eager", "graph"))
    ]

  @classmethod
  def sort(cls, harness: primitive_harness.Harness):
    return [
        Jax2TfLimitation(
            # I think that this is because TF is running on CPU even for GPU
            # tests?
            "TODO: TF non-stable multiple-array sort",
            devices="gpu",
            enabled=(harness.params["num_arrays"] > 1 and
                     not harness.params["is_stable"]),
            expect_tf_error=False,
            skip_comparison=True),
        missing_tf_kernel(
            dtypes=[np.complex128, np.float64], devices=("cpu", "gpu")),
        missing_tf_kernel(dtypes=[np.bool_],),
    ]

  @classmethod
  def sub(cls, harness):
    return [missing_tf_kernel(dtypes=[np.uint64])]

  @classmethod
  def svd(cls, harness: primitive_harness.Harness):
    # TODO: slow test

    def custom_assert(tst, r_jax, r_tf, *, args, tol):

      def _reconstruct_operand(result, is_tf: bool):
        # Reconstructing operand as documented in numpy.linalg.svd (see
        # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
        s, u, v = result
        U = u[..., :s.shape[-1]]
        V = v[..., :s.shape[-1], :]
        S = s[..., None, :]
        return jnp.matmul(U * S, V), s.shape, u.shape, v.shape

      if harness.params["compute_uv"]:
        r_jax_reconstructed = _reconstruct_operand(r_jax, False)
        r_tf_reconstructed = _reconstruct_operand(r_tf, True)
        tst.assertAllClose(
            r_jax_reconstructed, r_tf_reconstructed, atol=tol, rtol=tol)
      else:
        tst.assertAllClose(r_jax, r_tf, atol=tol, rtol=tol)

    return [
        # Works in JAX for complex due to custom calls on cpu and gpu
        Jax2TfLimitation(
            "function not compilable. Implemented using `tf.linalg.svd` and "
            "`tf.linalg.adjoint`",
            dtypes=[np.complex64, np.complex128],
            devices=("cpu", "gpu"),
            modes=("compiled",)),
        missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"),
        custom_numeric(tol=1e-4),
        custom_numeric(custom_assert=custom_assert)
    ]

  @classmethod
  def tan(cls, harness):
    return [
        custom_numeric(dtypes=np.complex64, devices="tpu", tol=1e-4),
        custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-3),
        custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12)
    ]

  @classmethod
  def tanh(cls, harness):
    return [
        custom_numeric(dtypes=np.complex128, tol=1e-7),
        custom_numeric(dtypes=np.complex64, tol=1e-4)
    ]

  @classmethod
  def top_k(cls, harness):

    def custom_assert(tst, result_jax, result_tf, **_):
      assert len(result_jax) == len(result_tf)
      # TODO: TF and JAX sort [inf, nan] differently.
      first_arr_jax, first_arr_tf = result_jax[0], result_tf[0]
      if np.all(first_arr_jax == first_arr_tf):
        for arr_jax, arr_tf in zip(result_jax, result_tf):
          tst.assertArraysEqual(arr_jax, arr_tf)
      else:
        mask_jax, mask_tf = np.isnan(first_arr_jax), np.isnan(first_arr_tf)
        tst.assertArraysEqual(first_arr_jax[~mask_jax], first_arr_tf[~mask_tf])

    return [
        missing_tf_kernel(
            dtypes=[np.uint64, np.int64],
            devices=("cpu", "gpu"),
            modes="compiled"),
        custom_numeric(
            dtypes=[np.float16, dtypes.bfloat16, np.float32, np.float64],
            custom_assert=custom_assert,
            description=(
                "Produces different results when the array contains `inf` and "
                "`NaN` (they are sorted differently in TF vs. XLA)."))
    ]

  @classmethod
  def triangular_solve(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[dtypes.bfloat16]),
        missing_tf_kernel(
            dtypes=[np.float16],
            devices=("gpu", "cpu"),
            modes=("eager", "graph")),
        custom_numeric(dtypes=np.float32, tol=5e-3)
    ]


def custom_numeric(
    *,
    description="custom numeric comparison",
    dtypes=(),  # All
    modes=("eager", "graph", "compiled"),
    devices=("cpu", "gpu", "tpu"),
    custom_assert=None,
    tol=None) -> Jax2TfLimitation:

  return Jax2TfLimitation(
      description,
      expect_tf_error=False,
      dtypes=dtypes,
      devices=devices,
      modes=modes,
      custom_assert=custom_assert,
      tol=tol)


def missing_tf_kernel(
    *,
    description="op not defined for dtype",
    dtypes,
    modes=("eager", "graph", "compiled"),
    devices=("cpu", "gpu", "tpu")) -> Jax2TfLimitation:

  return Jax2TfLimitation(
      description,
      dtypes=dtypes,
      devices=devices,
      modes=modes)
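
# Illustrative sketch (not part of the test suite): roughly how a test might
# consult these limitations for one harness, based on the API above. The names
# `harness`, `device`, `dtype`, `mode`, `tst`, `args`, `result_jax` and
# `result_tf` are hypothetical placeholders; the actual wiring lives in
# primitives_test / tf_test_util.
#
#   limitations = Jax2TfLimitation.limitations_for_harness(harness)
#   applicable = [l for l in limitations
#                 if l.filter(device=device, dtype=dtype, mode=mode)]
#   if applicable:
#     # Any applicable limitation can be used to pick the maximum tolerance
#     # (get_max_tolerance_limitation is an instance method; see its TODO).
#     tol_lim = applicable[0].get_max_tolerance_limitation(applicable)
#     tol = tol_lim.tol if tol_lim is not None else None
#     for lim in applicable:
#       if lim.custom_assert is not None:
#         lim.custom_assert(tst, result_jax, result_tf, args=args, tol=tol)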