# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import logging
import numpy as np
from typing import Any, Callable, List, Optional, Sequence
import tensorflow as tf  # type: ignore[import]

import jax
from jax.config import config
from jax import dtypes
from jax.experimental import jax2tf
from jax.interpreters import masking
from jax import test_util as jtu
from jax import tree_util
from jax import numpy as jnp


DType = Any


def _make_tf_args(args):
  def _convert_to_tensor(v):
    # Convert array-like values (anything with a dtype) to TF tensors;
    # leave other values unchanged.
    if hasattr(v, "dtype"):
      return tf.convert_to_tensor(v)
    return v

  return tf.nest.map_structure(_convert_to_tensor, args)


def _make_tf_input_signature(*tf_args) -> List[tf.TensorSpec]:
  # tf_args can be PyTrees.
  def _make_one_arg_signature(tf_arg):
    return tf.TensorSpec(np.shape(tf_arg), tf_arg.dtype)

  return tf.nest.map_structure(_make_one_arg_signature, list(tf_args))


def _run_tf_function(func_tf: Callable, *tf_args, mode: str):
  if mode == "eager":
    return func_tf(*tf_args)  # EAGER
  elif mode == "graph":
    return tf.function(
        func_tf,
        autograph=False,
        input_signature=_make_tf_input_signature(*tf_args))(*tf_args)  # GRAPH
  elif mode == "compiled":
    # An explicit input_signature prevents TF from constant-folding the
    # computation eagerly before compilation.
    return tf.function(
        func_tf,
        autograph=False,
        jit_compile=True,
        input_signature=_make_tf_input_signature(*tf_args))(
            *tf_args)  # COMPILED
  else:
    assert False, (
        f"Expected 'eager', 'graph', or 'compiled' for mode; got '{mode}'")
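
# A minimal sketch (illustrative only, not used by the tests) of how the
# helpers above compose: convert a JAX function with jax2tf, then run it in
# all three TF modes. The function and the input value are assumptions made
# for the example, not part of the harness.
def _example_run_all_modes():
  func_tf = jax2tf.convert(lambda x: jnp.sin(jnp.cos(x)))
  tf_args = _make_tf_args((np.float32(0.7),))
  return {mode: _run_tf_function(func_tf, *tf_args, mode=mode)
          for mode in ("eager", "graph", "compiled")}
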
class JaxToTfTestCase(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    # Ensure that all TF ops are created on the proper device (TPU, GPU, or CPU).
    # TODO(necula): why doesn't TF do this automatically?
    tf_preferred_devices = (
        tf.config.list_logical_devices("TPU") +
        tf.config.list_logical_devices("GPU") +
        tf.config.list_logical_devices())
    self.tf_default_device = tf_preferred_devices[0]
    logging.info(f"Running jax2tf converted code on {self.tf_default_device}.")
    if jtu.device_under_test() != "gpu":
      # TODO(necula): change the build flags to ensure the GPU is seen by TF.
      # It seems that we need the --config=cuda build flag for this to work.
      self.assertEqual(jtu.device_under_test().upper(),
                       self.tf_default_device.device_type)

    with contextlib.ExitStack() as stack:
      stack.enter_context(tf.device(self.tf_default_device))
      self.addCleanup(stack.pop_all().close)

  def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
    """Compares the dtypes of x (JAX) and y (TF). Overrides the super method."""

    def to_numpy_dtype(dt):
      return dt if isinstance(dt, np.dtype) else dt.as_numpy_dtype

    if not config.FLAGS.jax_enable_x64 and canonicalize_dtypes:
      self.assertEqual(
          dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))),
          dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y))))
    else:
      self.assertEqual(
          to_numpy_dtype(jtu._dtype(x)), to_numpy_dtype(jtu._dtype(y)))

  def ConvertAndCompare(self,
                        func_jax: Callable,
                        *args,
                        enable_xla: bool = True,
                        limitations: Sequence = ()):
    """Compares func_jax(*args) with jax2tf.convert(func_jax)(*args).

    It compares the result of running func_jax in JAX against the results of
    running the converted function in TF "eager" mode, in TF with tf.function
    ("graph" mode), and in TF with tf.function(jit_compile=True) ("compiled"
    mode). In each mode, either we expect to encounter a known limitation, or
    the value must match the value from the JAX execution.

    Args:
      func_jax: the function to invoke (``func_jax(*args)``).
      args: the arguments.
      enable_xla: if True, allows the use of XLA ops in jax2tf.convert
        (default: True).
      limitations: the set of limitations for this harness (not yet filtered
        by mode).
    """
    # Run JAX. This should not fail; we assume the harness has already been
    # filtered for primitives that are unimplemented in JAX.
    result_jax = func_jax(*args)  # JAX
    result_tf = None

    func_tf = jax2tf.convert(func_jax, enable_xla=enable_xla)
    tf_args = _make_tf_args(args)

    unexpected_successes: List[str] = []
    # Run the "compiled" mode first; it is the most important.
    for mode in ("compiled", "eager", "graph"):
      def log_message(extra):
        return f"[{self._testMethodName}] mode={mode}: {extra}"

      jax2tf_limits = tuple(filter(lambda l: l.filter(mode=mode), limitations))

      skip_tf_run = [l for l in jax2tf_limits if l.skip_tf_run]
      if skip_tf_run:
        logging.info(log_message(
            f"Skip TF run due to limitations {skip_tf_run}"))
        continue

      try:
        result_tf = _run_tf_function(func_tf, *tf_args, mode=mode)
        tf_exception = None
      except Exception as e:
        tf_exception = e

      expect_tf_error = [l for l in jax2tf_limits if l.expect_tf_error]
      if tf_exception:
        if expect_tf_error:
          logging.info(log_message(
              "Found expected TF error with enabled limitations "
              f"{expect_tf_error}; TF error is {tf_exception}"))
          continue
        else:
          raise tf_exception
      else:
        if expect_tf_error:
          # It is more ergonomic to report all unexpectedly successful modes
          # together, once, at the end.
          logging.warning(log_message(
              f"Unexpected success with known limitations {expect_tf_error}"))
          unexpected_successes.append(f"{mode}: {expect_tf_error}")

      skip_comparison = [l for l in jax2tf_limits if l.skip_comparison]
      if skip_comparison:
        logging.warning(log_message(
            f"Skip result comparison due to {skip_comparison}"))
        continue

      max_tol = None
      max_tol_lim = (None if not jax2tf_limits else
                     jax2tf_limits[0].get_max_tolerance_limitation(jax2tf_limits))
      if max_tol_lim is not None:
        max_tol = max_tol_lim.tol
        logging.info(log_message(f"Using tol={max_tol} due to {max_tol_lim}"))

      # Convert the TF results to np.ndarrays for comparison.
      result_tf = tf.nest.map_structure(lambda t: t.numpy(), result_tf)  # type: ignore

      custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
      assert len(custom_assert_lim) <= 1, (
          "Expecting at most one applicable limitation with custom_assert, "
          f"found {custom_assert_lim}")

      if custom_assert_lim:
        logging.info(log_message(
            f"Running custom_assert with tol={max_tol} due to "
            f"{custom_assert_lim[0]}"))
        custom_assert_lim[0].custom_assert(self, result_jax, result_tf,
                                           args=args, tol=max_tol)
      else:
        logging.info(log_message(f"Running default assert with tol={max_tol}"))
        # By default we expect the same result as from the JAX execution.
        self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol)

    # end "for mode"

    if unexpected_successes:
      msg = (f"[{self._testMethodName}] The following modes succeeded "
             "unexpectedly:\n" + "\n".join(unexpected_successes))
      logging.warning(msg)
      # Uncomment the line below to turn these warnings into failures.
      # self.assertEmpty(msg)
    return result_jax, result_tf
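  # A minimal sketch (illustrative only, not a real test) of how a test
  # method might use ConvertAndCompare; jnp.sin and the input value are
  # assumptions made for the example.
  def _example_convert_and_compare(self):
    # Runs jnp.sin in JAX and, via jax2tf, in TF eager, graph, and compiled
    # modes, comparing every TF result against the JAX result.
    self.ConvertAndCompare(jnp.sin, np.float32(0.7))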
  def TransformConvertAndCompare(self, func: Callable, arg,
                                 transform: Optional[str]):
    """Like ConvertAndCompare but first applies a transformation.

    `func` must be a function from one argument to one result. `arg` is
    the argument before the transformation.

    `transform` can be None, "jit", "jvp", "grad", "vmap", "jvp_vmap", or
    "grad_vmap".
    """
    if transform is None:
      return self.ConvertAndCompare(func, arg)
    if transform == "jit":
      return self.ConvertAndCompare(jax.jit(func), arg)
    if transform == "jvp":
      t_func = lambda x, xt: jax.jvp(func, (x,), (xt,))
      return self.ConvertAndCompare(t_func, arg, np.full_like(arg, 0.1))
    if transform == "grad":
      return self.ConvertAndCompare(jax.grad(func), arg)
    if transform == "vmap":
      t_arg = np.stack([arg] * 4)
      return self.ConvertAndCompare(jax.vmap(func), t_arg)
    if transform == "jvp_vmap":
      jvp_func = lambda x, xt: jax.jvp(jax.vmap(func), (x,), (xt,))
      t_arg = np.stack([arg] * 4)
      return self.ConvertAndCompare(jvp_func, t_arg, np.full_like(t_arg, 0.1))
    if transform == "grad_vmap":
      grad_func = jax.grad(lambda x: jnp.sum(jax.vmap(func)(x)))
      t_arg = np.stack([arg] * 4)
      return self.ConvertAndCompare(grad_func, t_arg)
    assert False, transform

  def CheckShapePolymorphism(self, f_jax: Callable, *,
                             input_signature: Sequence[tf.TensorSpec],
                             in_shapes: Optional[Sequence[Any]],
                             expected_output_signature: tf.TensorSpec):
    """Converts a function using polymorphic shapes.

    Args:
      f_jax: a JAX function of `n` arguments.
      input_signature: used as the input signature for the tf.function.
      in_shapes: if given, it must be a sequence of `n` shape specifications
        that matches the `input_signature` (see jax2tf.convert).
      expected_output_signature: if given, the shape of the (single) result
        is checked against it.
    """
    f_tf = tf.function(
        jax2tf.convert(f_jax, in_shapes=in_shapes),
        autograph=False,
        input_signature=input_signature)
    concrete_f_tf = f_tf.get_concrete_function(*input_signature)
    if expected_output_signature:
      concrete_output_tf_shape = concrete_f_tf.output_shapes
      assert not isinstance(concrete_output_tf_shape, tuple)  # A single result
      self.assertEqual(
          tuple(expected_output_signature.shape),
          tuple(concrete_output_tf_shape))
    return f_tf
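  # A minimal sketch (illustrative only, not a real test) of
  # CheckShapePolymorphism: the batch dimension `b` is left polymorphic and
  # the inferred output shape is checked. The function, shapes, and spec
  # strings are assumptions made for the example.
  def _example_shape_polymorphism(self):
    f_jax = lambda x: jnp.sum(x, axis=0)  # (b, 3) -> (3,)
    self.CheckShapePolymorphism(
        f_jax,
        input_signature=[tf.TensorSpec([None, 3], tf.float32)],
        in_shapes=["(b, 3)"],
        expected_output_signature=tf.TensorSpec([3], tf.float32))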
  def MakeInputSignature(self, *in_shapes):
    """From a pytree of in_shape string specifications, makes a pytree of tf.TensorSpec.

    Dimension variables are replaced with None.
    """

    def in_shape_to_tensorspec(in_shape: str) -> tf.TensorSpec:
      in_spec = masking.parse_spec(in_shape)
      return tf.TensorSpec(
          tuple(
              int(dim_spec) if dim_spec.is_constant else None
              for dim_spec in in_spec),
          dtype=tf.float32)

    return tree_util.tree_multimap(in_shape_to_tensorspec, in_shapes)
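
# A minimal sketch (illustrative only, not used by the tests) of
# MakeInputSignature: a dimension variable such as `b` in "(b, 4)" becomes a
# None dimension in the resulting tf.TensorSpec.
def _example_make_input_signature(test: JaxToTfTestCase):
  specs = test.MakeInputSignature("(b, 4)", "(8,)")
  # specs == (tf.TensorSpec((None, 4), tf.float32),
  #           tf.TensorSpec((8,), tf.float32))
  return specs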