# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import logging
import numpy as np
from typing import Any, Callable, List, Optional, Sequence
import tensorflow as tf  # type: ignore[import]

import jax
from jax.config import config
from jax import dtypes
from jax.experimental import jax2tf
from jax.interpreters import masking
from jax import test_util as jtu
from jax import tree_util
from jax import numpy as jnp


DType = Any

def _make_tf_args(args):
  def _convert_to_tensor(v):
    if hasattr(v, "dtype"):
      return tf.convert_to_tensor(v)
    return v

  return tf.nest.map_structure(_convert_to_tensor, args)
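# A minimal usage sketch (illustrative only; the argument values are made up):
# leaves that carry a `dtype` (e.g. np.ndarray) become tf.Tensor, while plain
# Python values are passed through unchanged.
#
#   np_args = (np.ones((2, 3), np.float32), 3)
#   tf_args = _make_tf_args(np_args)
#   # tf_args[0] is a tf.Tensor of shape (2, 3); tf_args[1] is still the int 3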


def _make_tf_input_signature(*tf_args) -> List[tf.TensorSpec]:
  # tf_args can be PyTrees
  def _make_one_arg_signature(tf_arg):
    return tf.TensorSpec(np.shape(tf_arg), tf_arg.dtype)

  return tf.nest.map_structure(_make_one_arg_signature, list(tf_args))
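# A usage sketch (illustrative only): each argument is mapped to a TensorSpec
# mirroring its shape and dtype, so tf.function can trace a generic graph.
#
#   sig = _make_tf_input_signature(np.zeros((2, 3), np.float32))
#   # sig is [tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]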


def _run_tf_function(func_tf: Callable, *tf_args, mode: str):
  if mode == "eager":
    return func_tf(*tf_args)  # EAGER
  elif mode == "graph":
    return tf.function(
        func_tf,
        autograph=False,
        input_signature=_make_tf_input_signature(*tf_args))(*tf_args)  # GRAPH
  elif mode == "compiled":
    # Adding an explicit input_signature prevents TF from constant-folding
    # the computation eagerly before compilation
    return tf.function(
        func_tf,
        autograph=False,
        jit_compile=True,
        input_signature=_make_tf_input_signature(*tf_args))(
            *tf_args)  # COMPILED
  else:
    assert False, (
        f"Expected 'eager', 'graph', or 'compiled' for mode: got '{mode}'")
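# A usage sketch (illustrative only): the same converted function can be run
# in all three modes against identical arguments.
#
#   func_tf = jax2tf.convert(jnp.sin)
#   tf_args = _make_tf_args((np.float32(0.5),))
#   for mode in ("eager", "graph", "compiled"):
#     _run_tf_function(func_tf, *tf_args, mode=mode)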


class JaxToTfTestCase(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    # Ensure that all TF ops are created on the proper device (TPU or GPU or CPU)
    # TODO(necula): why doesn't TF do this automatically?
    tf_preferred_devices = (
        tf.config.list_logical_devices("TPU") +
        tf.config.list_logical_devices("GPU") +
        tf.config.list_logical_devices())
    self.tf_default_device = tf_preferred_devices[0]
    logging.info(f"Running jax2tf converted code on {self.tf_default_device}.")
    if jtu.device_under_test() != "gpu":
      # TODO(necula): Change the build flags to ensure the GPU is seen by TF
      # It seems that we need --config=cuda build flag for this to work?
      self.assertEqual(jtu.device_under_test().upper(),
                       self.tf_default_device.device_type)

    with contextlib.ExitStack() as stack:
      stack.enter_context(tf.device(self.tf_default_device))
      self.addCleanup(stack.pop_all().close)

  def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
    """Compares dtypes across JAX and TF values. Overrides the super method."""

    def to_numpy_dtype(dt):
      return dt if isinstance(dt, np.dtype) else dt.as_numpy_dtype

    if not config.FLAGS.jax_enable_x64 and canonicalize_dtypes:
      self.assertEqual(
          dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))),
          dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y))))
    else:
      self.assertEqual(
          to_numpy_dtype(jtu._dtype(x)), to_numpy_dtype(jtu._dtype(y)))
  def ConvertAndCompare(self,
                        func_jax: Callable,
                        *args,
                        enable_xla: bool = True,
                        limitations: Sequence = ()):
    """Compares func_jax(*args) with jax2tf.convert(func_jax)(*args).

    It compares the result of JAX with that of the converted function run
    with TF eagerly ("eager" mode), under tf.function ("graph" mode), and
    under tf.function(jit_compile=True) ("compiled" mode). In each mode we
    either expect to encounter a known limitation, or the value must match
    the value from the JAX execution.

    Args:
      func_jax: the function to invoke (``func_jax(*args)``).
      args: the arguments.
      enable_xla: if True, allows the use of XLA ops in jax2tf.convert
        (default: True).
      limitations: the set of limitations for this harness (not yet filtered
        by mode).
    """
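    # A usage sketch from a test method (illustrative only):
    #   self.ConvertAndCompare(jnp.sin, np.float32(0.5))
    # which checks that jax2tf.convert(jnp.sin) agrees with jnp.sin on 0.5
    # in eager, graph and compiled modes.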
    # Run JAX. This should not fail: we assume that the harness has already
    # been filtered for JAX unimplemented primitives.
    result_jax = func_jax(*args)  # JAX
    result_tf = None

    func_tf = jax2tf.convert(func_jax, enable_xla=enable_xla)
    tf_args = _make_tf_args(args)

    unexpected_successes: List[str] = []
    # Run the "compiled" mode first; it is the most important
    for mode in ("compiled", "eager", "graph"):
      def log_message(extra):
        return f"[{self._testMethodName}] mode={mode}: {extra}"

      jax2tf_limits = tuple(filter(lambda l: l.filter(mode=mode), limitations))

      skip_tf_run = [l for l in jax2tf_limits if l.skip_tf_run]
      if skip_tf_run:
        logging.info(log_message(
            f"Skip TF run due to limitations {skip_tf_run}"))
        continue

      try:
        result_tf = _run_tf_function(func_tf, *tf_args, mode=mode)
        tf_exception = None
      except Exception as e:
        tf_exception = e

      expect_tf_error = [l for l in jax2tf_limits if l.expect_tf_error]
      if tf_exception:
        if expect_tf_error:
          logging.info(log_message(
            "Found expected TF error with enabled limitations "
            f"{expect_tf_error}; TF error is {tf_exception}"))
          continue
        else:
          raise tf_exception
      else:
        if expect_tf_error:
          # It is more ergonomic to print all successful modes once
          logging.warning(log_message(
            f"Unexpected success with known limitations {expect_tf_error}"))
          unexpected_successes.append(f"{mode}: {expect_tf_error}")

      skip_comparison = [l for l in jax2tf_limits if l.skip_comparison]
      if skip_comparison:
        logging.warning(log_message(
            f"Skip result comparison due to {skip_comparison}"))
        continue

      max_tol = None
      max_tol_lim = (None if not jax2tf_limits else
                     jax2tf_limits[0].get_max_tolerance_limitation(
                         jax2tf_limits))
      if max_tol_lim is not None:
        max_tol = max_tol_lim.tol
        logging.info(log_message(f"Using tol={max_tol} due to {max_tol_lim}"))

      # Convert results to np.arrays
      result_tf = tf.nest.map_structure(lambda t: t.numpy(), result_tf)  # type: ignore

      custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
      assert len(custom_assert_lim) <= 1, (
          "Expecting at most one applicable limitation with custom_assert, "
          f"found {custom_assert_lim}")

      if custom_assert_lim:
        logging.info(log_message(
            f"Running custom_assert with tol={max_tol} due to "
            f"{custom_assert_lim[0]}"))
        custom_assert_lim[0].custom_assert(self, result_jax, result_tf,
                                           args=args, tol=max_tol)
      else:
        logging.info(log_message(f"Running default assert with tol={max_tol}"))
        # In compiled mode we expect the same result as JAX by default
        self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol)

    # end "for mode"

    if unexpected_successes:
      msg = (f"[{self._testMethodName}] The following are unexpected "
             "successful modes:\n" + "\n".join(unexpected_successes))
      logging.warning(msg)
      # Uncomment the below if you want to see warnings as failures
      #self.assertEmpty(msg)
    return result_jax, result_tf

  def TransformConvertAndCompare(self, func: Callable, arg,
                                 transform: Optional[str]):
    """Like ConvertAndCompare but first applies a transformation.

    `func` must be a function from one argument to one result. `arg` is
    the argument before the transformation.

    `transform` can be None, "jit", "jvp", "grad", "vmap", "jvp_vmap", or
    "grad_vmap".
    """
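    # For example (an illustrative sketch), with func=jnp.sin and
    # arg=np.float32(0.5), transform="vmap" stacks four copies of arg and
    # compares jax.vmap(jnp.sin) with its jax2tf-converted counterpart.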
    if transform is None:
      return self.ConvertAndCompare(func, arg)
    if transform == "jit":
      return self.ConvertAndCompare(jax.jit(func), arg)
    if transform == "jvp":
      t_func = lambda x, xt: jax.jvp(func, (x,), (xt,))
      return self.ConvertAndCompare(t_func, arg, np.full_like(arg, 0.1))
    if transform == "grad":
      return self.ConvertAndCompare(jax.grad(func), arg)
    if transform == "vmap":
      t_arg = np.stack([arg] * 4)
      return self.ConvertAndCompare(jax.vmap(func), t_arg)
    if transform == "jvp_vmap":
      jvp_func = lambda x, xt: jax.jvp(jax.vmap(func), (x,), (xt,))
      t_arg = np.stack([arg] * 4)
      return self.ConvertAndCompare(jvp_func, t_arg, np.full_like(t_arg, 0.1))
    if transform == "grad_vmap":
      grad_func = jax.grad(lambda x: jnp.sum(jax.vmap(func)(x)))
      t_arg = np.stack([arg] * 4)
      return self.ConvertAndCompare(grad_func, t_arg)
    assert False, transform

  def CheckShapePolymorphism(self, f_jax: Callable, *,
                             input_signature: Sequence[tf.TensorSpec],
                             in_shapes: Optional[Sequence[Any]],
                             expected_output_signature: tf.TensorSpec):
    """Converts a function using polymorphic shapes.

    Args:
      f_jax: a JAX function of `n` arguments.
      input_signature: used as the input signature for the tf.function.
      in_shapes: if given, it must be a sequence of `n` shape specifications
        and must match the `input_signature` (see jax2tf.convert).
      expected_output_signature: if given, the shape of the single result is
        checked against the shape of this TensorSpec.
    """
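    # A usage sketch (illustrative only): convert a function of one argument
    # with a batch-polymorphic first dimension.
    #   self.CheckShapePolymorphism(
    #       lambda x: x + 1.,
    #       input_signature=[tf.TensorSpec([None, 4], tf.float32)],
    #       in_shapes=["(b, 4)"],
    #       expected_output_signature=tf.TensorSpec([None, 4], tf.float32))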
    f_tf = tf.function(
        jax2tf.convert(f_jax, in_shapes=in_shapes),
        autograph=False,
        input_signature=input_signature)
    concrete_f_tf = f_tf.get_concrete_function(*input_signature)
    if expected_output_signature:
      concrete_output_tf_shape = concrete_f_tf.output_shapes
      assert not isinstance(concrete_output_tf_shape, tuple)  # A single result
      self.assertEqual(
          tuple(expected_output_signature.shape),
          tuple(concrete_output_tf_shape))
    return f_tf

  def MakeInputSignature(self, *in_shapes):
    """Makes a pytree of tf.TensorSpec from a pytree of in_shape string specs.

    Dimension variables are replaced with None.
    """
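    # For example (illustrative only), "(b, 4)" contains one dimension
    # variable `b`, so it becomes tf.TensorSpec([None, 4], tf.float32):
    #   signatures = self.MakeInputSignature("(b, 4)", "(b, 8)")
    #   # signatures is a pair of TensorSpecs with shapes [None, 4], [None, 8]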

    def in_shape_to_tensorspec(in_shape: str) -> tf.TensorSpec:
      in_spec = masking.parse_spec(in_shape)
      return tf.TensorSpec(
          tuple(
              int(dim_spec) if dim_spec.is_constant else None
              for dim_spec in in_spec),
          dtype=tf.float32)

    return tree_util.tree_multimap(in_shape_to_tensorspec, in_shapes)