# Copyright 2018 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict, deque import itertools as it import operator as op from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type, Tuple, Union, NamedTuple) from warnings import warn from absl import logging import numpy as np from ..config import flags, bool_env, config from .. import core from .. import ad_util from .. import dtypes from .. import lazy from .. import linear_util as lu from jax._src import source_info_util from ..abstract_arrays import (make_shaped_array, array_types) from ..core import (ConcreteArray, ShapedArray, AbstractToken, Literal, pp_eqn_compact, raise_to_shaped, abstract_token) from jax._src.pprint_util import pp from .._src.util import (partial, partialmethod, cache, prod, unzip2, extend_name_stack, wrap_name, safe_zip, safe_map) from ..lib import xla_bridge as xb from ..lib import xla_client as xc from . import partial_eval as pe from . import ad from . import masking map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip xe = xc._xla xops = xc._xla.ops # Types Backend = Any # xc.LocalBackend (why does mypy not like this?) Device = Any # xc.Device PyLocalBuffer = Any XlaOp = Any # xla_extension.XlaOp XlaShape = Any # xla_client.Shape XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder XlaExecutable = Any # xla_extension.LocalExecutable FLAGS = flags.FLAGS flags.DEFINE_bool('jax_debug_nans', bool_env('JAX_DEBUG_NANS', False), 'Add nan checks to every operation.') flags.DEFINE_bool('jax_debug_infs', bool_env('JAX_DEBUG_INFS', False), 'Add inf checks to every operation.') flags.DEFINE_bool('jax_log_compiles', bool_env('JAX_LOG_COMPILES', False), 'Print a message each time a `jit` computation is compiled.') # This flag is set on exit; no logging should be attempted _on_exit = False def identity(x): return x _scalar_types = dtypes.python_scalar_dtypes.keys() # unit representation def _make_unit_constant(c): return xb.constant(c, np.zeros((), dtype=np.dtype('bool'))) def _make_unit_shape(_): return (xc.Shape.array_shape(np.dtype('bool'), ()),) def _device_put_unit(_, device): backend = xb.get_device_backend(device) return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')), device),) def _make_array_shape(a): if a.dtype is dtypes.float0: return (xc.Shape.array_shape(np.dtype('bool'), a.shape),) else: return (xc.Shape.array_shape(a.dtype, a.shape),) ### handlers xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c)) def aval_to_xla_shapes(aval): try: return xla_shape_handlers[type(aval)](aval) except KeyError as err: raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err xla_shape_handlers: Dict[Type[core.AbstractValue], Callable] = { core.AbstractUnit: _make_unit_shape, ShapedArray: _make_array_shape, ConcreteArray: _make_array_shape, } def aval_to_result_handler(device: Optional[Device], aval: core.AbstractValue) -> Callable: try: return xla_result_handlers[type(aval)](device, aval) except KeyError 
as err: raise TypeError(f"No xla_result_handler for type: {type(aval)}") from err def array_result_handler(device: Optional[Device], aval: core.ShapedArray): if aval.dtype is dtypes.float0: return lambda _: np.zeros(aval.shape, dtypes.float0) return partial(make_device_array, raise_to_shaped(aval), device, lazy.array(aval.shape)) xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = { core.AbstractUnit: lambda _, __: lambda _: core.unit, ShapedArray: array_result_handler, ConcreteArray: array_result_handler, } def device_put(x, device: Optional[Device] = None) -> Tuple[Any]: x = canonicalize_dtype(x) try: return device_put_handlers[type(x)](x, device) except KeyError as err: raise TypeError(f"No device_put handler for type: {type(x)}") from err def _device_put_array(x, device: Optional[Device]): backend = xb.get_device_backend(device) if x.dtype is dtypes.float0: x = np.zeros(x.shape, dtype=np.dtype(bool)) return (backend.buffer_from_pyval(x, device),) def _device_put_scalar(x, device): return _device_put_array(dtypes.coerce_to_array(x), device) device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = { core.Unit: _device_put_unit } device_put_handlers.update((t, _device_put_array) for t in array_types) device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types) # TODO(mattjj): try to remove this canonicalize_dtype stuff def canonicalize_dtype(x): typ = type(x) handler = canonicalize_dtype_handlers.get(typ) if handler: return handler(x) for typ in typ.mro(): handler = canonicalize_dtype_handlers.get(typ) if handler: return handler(x) raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}") def _canonicalize_ndarray_dtype(x): return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x))) def _canonicalize_python_scalar_dtype(typ, x): return np.asarray( x, dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ])) canonicalize_dtype_handlers: Dict[Any, Callable] = {core.Unit: identity} canonicalize_dtype_handlers.update( (t, _canonicalize_ndarray_dtype) for t in array_types) canonicalize_dtype_handlers.update( (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types) def abstractify(x) -> core.AbstractValue: typ = type(x) aval_fn = pytype_aval_mappings.get(typ) if aval_fn: return aval_fn(x) for typ in typ.mro(): aval_fn = pytype_aval_mappings.get(typ) if aval_fn: return aval_fn(x) raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type") def _make_abstract_python_scalar(typ, _): return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True) pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = { core.Unit: lambda _: core.abstract_unit, } pytype_aval_mappings.update((t, make_shaped_array) for t in array_types) pytype_aval_mappings.update( (t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types) # We can optionally set a Jaxpr rewriter that can be applied just before # compilation. This mechanism is used for compiling id_tap, we can # remove it once we bring the id_tap implementation into the core. 
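# A minimal sketch of the hook's contract (illustrative only; the real rewriter
# is expected to be installed by the id_tap / host_callback machinery): any
# callable mapping a Jaxpr to a Jaxpr may be assigned here, e.g.
#
#   def _noop_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
#     return jaxpr
#   outfeed_rewriter = _noop_rewriter
#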
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None

def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
  if outfeed_rewriter is not None:
    return outfeed_rewriter(jaxpr)
  else:
    return jaxpr

outfeed_primitives: Set[core.Primitive] = set()

def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
  """Finds if there are outfeed primitives anywhere inside a Jaxpr."""
  return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
             for eqn in jaxpr.eqns)

def _param_uses_outfeed(param):
  if type(param) is core.Jaxpr:
    if jaxpr_uses_outfeed(param):
      return True
  elif type(param) is core.ClosedJaxpr:
    if jaxpr_uses_outfeed(param.jaxpr):
      return True
  return False

def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
  if prim in outfeed_primitives:
    return True
  for param in params.values():
    if isinstance(param, tuple):
      if any(unsafe_map(_param_uses_outfeed, param)):
        return True
    elif _param_uses_outfeed(param):
      return True
  return False

### op-by-op execution

def arg_spec(x):
  aval = abstractify(x)
  try:
    return aval, x._device
  except AttributeError:
    return aval, None

def apply_primitive(prim, *args, **params):
  """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
  compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
                                        **params)
  return compiled_fun(*args)

def _partition_outputs(avals, outs):
  nouts = [aval._num_buffers for aval in avals]
  if not core.skip_checks:
    assert sum(nouts) == len(outs), (
        f"Internal error: sum(nouts)={sum(nouts)} should equal "
        f"len(outs)={len(outs)}.")
  outs = iter(outs)
  return [[next(outs) for _ in range(nout)] for nout in nouts]

@cache()
def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
                                                   Optional[Device]], **params):
  avals, arg_devices = unzip2(arg_specs)
  donated_invars = (False,) * len(arg_specs)
  device = _device_from_arg_devices(arg_devices)
  backend = xb.get_device_backend(device)
  if primitive_uses_outfeed(prim, params):
    # We use the _xla_callable path, where we pre-process the primitives
    def prim_fun(*args):
      return prim.bind(*args, **params)
    return _xla_callable(lu.wrap_init(prim_fun), device, None, "prim",
                         donated_invars, *arg_specs)
  aval_out = prim.abstract_eval(*avals, **params)
  if not prim.multiple_results:
    handle_result = aval_to_result_handler(device, aval_out)
  else:
    handlers = map(partial(aval_to_result_handler, device), aval_out)
    handle_result = lambda *bufs: tuple(
        handler(*bs) for handler, bs in
        zip(handlers, _partition_outputs(aval_out, bufs)))
  tuple_args = len(avals) > 100
  if prim in initial_style_translations:
    nreps = initial_style_primitive_replicas(params)
  else:
    nreps = 1
  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling a primitive computation `{prim}` that requires {nreps} "
        f"replicas, but only {xb.device_count(backend)} XLA devices are "
        f"available on backend {backend.platform}.")
  built_c = primitive_computation(prim, AxisEnv(nreps, (), ()), backend,
                                  tuple_args, *avals, **params)
  options = xb.get_compile_options(
      num_replicas=nreps,
      num_partitions=1,
      device_assignment=device and (device.id,))
  options.parameter_is_tupled_arguments = tuple_args
  compiled = backend_compile(backend, built_c, options)
  if nreps == 1:
    return partial(_execute_compiled_primitive, prim, compiled, handle_result)
  else:
    return partial(_execute_replicated_primitive, prim, compiled,
                   handle_result)

def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[Device]:
  """Given devices of inputs, determine where to perform a computation.

  Args:
    devices: list where each element is either a `Device` instance or `None`.
  Returns:
    A `Device` instance or None.
  Raises:
    ValueError if input devices are inconsistent.
  """
  try:
    device, = {d for d in devices if d is not None} or (None,)
    return device
  except ValueError as err:
    msg = "primitive arguments must be colocated on the same device, got {}"
    raise ValueError(msg.format(", ".join(map(str, devices)))) from err

@cache()
def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params):
  c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
  c.set_op_metadata(xc.OpMetadata(
      op_type=prim.name,
      op_name=str(pp_eqn_compact(prim.name, params))))
  platform = xb.get_backend(backend).platform
  xla_args, _ = _xla_callable_args(c, avals, tuple_args)
  # return val always set as a side-effect on c
  if prim in backend_specific_translations[platform]:
    rule = backend_specific_translations[platform][prim]
    ans = rule(c, *xla_args, **params)
  elif prim in translations:
    rule = translations[prim]
    ans = rule(c, *xla_args, **params)
  elif prim in translations_with_avals:
    rule = translations_with_avals[prim]
    ans = rule(c, avals, xla_args, params)
  elif prim in initial_style_translations:
    rule = initial_style_translations[prim]
    ans = rule(c, axis_env, extend_name_stack(prim.name), avals, backend,
               *xla_args, **params)
  else:
    raise NotImplementedError(f"XLA translation rule for {prim} not found")
  assert isinstance(ans, xe.XlaOp)
  c.clear_op_metadata()
  try:
    return c.build(ans)
  except RuntimeError as e:
    msg = (" ".join(map(str, e.args)) + "\n"
           "This is a bug in JAX's shape-checking rules; please report it!\n"
           "https://github.com/google/jax/issues\n")
    raise RuntimeError(msg) from e

def primitive_subcomputation(prim, *avals, **params):
  axis_env = AxisEnv(1, (), ())
  return primitive_computation(prim, axis_env, None, False, *avals, **params)

def backend_compile(backend, built_c, options):
  # we use a separate function call to ensure that XLA compilation appears
  # separately in Python profiling results
  return backend.compile(built_c, compile_options=options)

def _execute_compiled_primitive(prim, compiled, result_handler, *args):
  device, = compiled.local_devices()
  input_bufs = list(it.chain.from_iterable(
      device_put(x, device) for x in args if x is not token))
  out_bufs = compiled.execute(input_bufs)
  check_special(prim, out_bufs)
  return result_handler(*out_bufs)

def _execute_replicated_primitive(prim, compiled, result_handler, *args):
  input_bufs = [
      list(it.chain.from_iterable(
          device_put(x, device) for x in args if x is not token))
      for device in compiled.local_devices()]
  out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
  return result_handler(*out_bufs)

def check_special(prim, bufs):
  if FLAGS.jax_debug_infs or FLAGS.jax_debug_nans:
    for buf in bufs:
      _check_special(prim.name, buf.xla_shape(), buf)

def _check_special(name, xla_shape, buf):
  assert not xla_shape.is_tuple()
  if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
    if FLAGS.jax_debug_nans and np.any(np.isnan(buf.to_py())):
      raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    if FLAGS.jax_debug_infs and np.any(np.isinf(buf.to_py())):
      raise FloatingPointError(f"invalid value (inf) encountered in {name}")

### compiling jaxprs

def prefetch(x):
  if isinstance(x, DeviceArray):
    x.copy_to_host_async()
  return x

def jaxpr_literals(jaxpr):
  """Generates all the literals inside a jaxpr, including nested subjaxprs."""
  for eqn in jaxpr.eqns:
    for v in eqn.invars:
      if type(v) is core.Literal:
        yield v.val
  for subjaxpr in
core.subjaxprs(jaxpr): yield from jaxpr_literals(subjaxpr) def _flatmap(func: Callable, vars: Sequence): return list(it.chain.from_iterable(map(func, vars))) def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence): return map(func, vars, _partition_outputs([v.aval for v in vars], nodes)) def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args): if backend not in ('cpu', 'gpu', 'tpu'): platform = xb.get_backend(backend).platform # canonicalize else: platform = backend def read(v): if type(v) is Literal: return [xb.constant(c, canonicalize_dtype(v.val))] else: return env[v] def aval(v): if type(v) is Literal: return abstractify(v.val) else: return v.aval def write(v, node): assert node is not None env[v] = node env = {} _partitionmap(write, [core.unitvar], [_make_unit_constant(c)]) _partitionmap(write, jaxpr.constvars, consts) _partitionmap(write, jaxpr.invars, args) for eqn in jaxpr.eqns: frame = source_info_util.user_frame(eqn.source_info) c.set_op_metadata(xc.OpMetadata( op_type=eqn.primitive.name, op_name=str(pp(name_stack) >> pp_eqn_compact( eqn.primitive.name, eqn.params)), source_file=frame.file_name if frame else None, source_line=frame.line_num if frame else None)) in_nodes = _flatmap(read, eqn.invars) # TODO(jakevdp): migrate `translations` table to `translations_with_avals` if eqn.primitive in backend_specific_translations[platform]: rule = backend_specific_translations[platform][eqn.primitive] ans = rule(c, *in_nodes, **eqn.params) elif eqn.primitive in translations: ans = translations[eqn.primitive](c, *in_nodes, **eqn.params) elif eqn.primitive in translations_with_avals: rule = translations_with_avals[eqn.primitive] ans = rule(c, map(aval, eqn.invars), in_nodes, eqn.params) elif eqn.primitive in initial_style_translations: new_params = check_backend_params(eqn.params, backend) rule = initial_style_translations[eqn.primitive] ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name), map(aval, eqn.invars), backend, *in_nodes, **new_params) elif eqn.primitive in parallel_translations: rule = parallel_translations[eqn.primitive] ans = rule(c, *in_nodes, axis_env=axis_env, platform=platform, **eqn.params) elif eqn.primitive in call_translations: new_params = check_backend_params(eqn.params, backend) rule = call_translations[eqn.primitive] ans = rule(c, axis_env, in_nodes, name_stack, backend=backend, **new_params) else: raise NotImplementedError( f"XLA translation rule for primitive '{eqn.primitive.name}' not found") assert isinstance(ans, xe.XlaOp) c.get_shape(ans) # force xla to do shape error checking if eqn.primitive.multiple_results or any(v.aval._num_buffers > 1 for v in eqn.outvars): out_nodes = xla_destructure(c, ans) else: out_nodes = [ans] c.clear_op_metadata() _partitionmap(write, eqn.outvars, out_nodes) return _flatmap(read, jaxpr.outvars) def xla_destructure(c, ans): num_elements = len(c.get_shape(ans).tuple_shapes()) return [xops.GetTupleElement(ans, i) for i in range(num_elements)] def check_backend_params(params, outer_backend): # For nested calls, the outermost call sets the backend for all inner calls; # it's an error if the inner call has a conflicting explicit backend spec. 
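  # For example (illustrative): an outer jit(f, backend='cpu') that calls an
  # inner jit(g, backend='gpu') reaches this check with outer_backend='cpu' and
  # params['backend'] == 'gpu' and raises the ValueError below; an inner jit
  # with no explicit backend simply inherits the outer one.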
  inner_backend = params.get('backend', None)
  if inner_backend and inner_backend != outer_backend:
    raise ValueError(
        f"Outer-jit backend specification {outer_backend} must match explicit "
        f"inner-jit backend specification {inner_backend}.")
  return {k: params[k] for k in params if k != 'backend'}


class AxisEnv(NamedTuple):
  """Represents a pmap mesh (only along the replica axes)."""
  nreps: int
  names: Tuple[Any, ...]
  sizes: Tuple[int, ...]

def extend_axis_env(env: AxisEnv, name, size: int):
  return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))

def axis_read(axis_env, axis_name):
  try:
    return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
  except ValueError:
    raise NameError("unbound axis name: {}".format(axis_name)) from None

def axis_groups(axis_env: AxisEnv, name):
  if not isinstance(name, (list, tuple)):
    name = (name,)
  mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
  trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes))
  assert not ragged
  mesh_spec = axis_env.sizes + (trailing_size,)
  return _axis_groups(mesh_spec, mesh_axes)

def _axis_groups(mesh_spec, mesh_axes):
  """Computes replica group ids for a collective over a subset of the mesh.

  Args:
    mesh_spec: A sequence of integers representing the mesh shape.
    mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
      indicating over which axes the collective is performed.
  Returns:
    A tuple of replica groups (i.e. tuples containing replica ids).
  """
  iota = np.arange(prod(mesh_spec)).reshape(mesh_spec)
  groups = np.reshape(
      np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
      (prod(np.take(mesh_spec, mesh_axes)), -1))
  return tuple(unsafe_map(tuple, groups.T))

def jaxpr_replicas(jaxpr: core.Jaxpr) -> int:
  """The number of replicas needed for a jaxpr.

  For an eqn, multiply the `axis_size` with the `jaxpr_replicas` of its
  subjaxprs. For a list of eqns, take the maximum number of replicas.
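  For example, a jaxpr whose only equation is an `xla_pmap` over an axis of
  size 8 (with a trivial inner jaxpr) needs 8 replicas, while a jaxpr with no
  mapped equations needs 1.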
""" return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1) # TODO(mattjj): this function assumes that only pmap has a parameter named # axis_size, and that it corresponds to cross-replica mapping def eqn_replicas(eqn): call_jaxpr = eqn.params.get("call_jaxpr") if call_jaxpr: return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr) elif eqn.primitive in initial_style_translations: return initial_style_primitive_replicas(eqn.params) else: return 1 def initial_style_primitive_replicas(params): return max(core.traverse_jaxpr_params(jaxpr_replicas, params), default=1) # TODO(mattjj,skyewm): the functions here are utilities for checking if # not-yet-supported features are used with multi-host programming def jaxpr_has_pmap(jaxpr): """Whether there is an xla_pmap primitive anywhere inside a Jaxpr.""" for eqn in jaxpr.eqns: if 'xla_pmap' in eqn.primitive.name: return True for subjaxpr in core.subjaxprs(jaxpr): if jaxpr_has_pmap(subjaxpr): return True return False def jaxpr_collectives(jaxpr): """Generates all the collective primitives anywhere inside a Jaxpr.""" for eqn in jaxpr.eqns: if eqn.primitive in parallel_translations: yield eqn.primitive for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr) ### xla_call underlying jit def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars): compiled_fun = _xla_callable(fun, device, backend, name, donated_invars, *unsafe_map(arg_spec, args)) try: return compiled_fun(*args) except FloatingPointError: assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case print("Invalid value encountered in the output of a jit function. " "Calling the de-optimized version.") # We want to run the wrapped function again (after _xla_callable already ran # it), but linear_util.WrappedFun instances are meant to be run only once. # In addition to re-executing the Python code, which is usually undesirable # but which FLAGS.jax_debug_nans is meant to opt into, we'll be re-executing # any linear_util.py-style side effects, i.e. re-populating Stores created # by any transformation_with_aux's applied to fun. Since this is # intentional here, to avoid "Store occupied" errors we reset the stores to # be empty. for store in fun.stores: store and store.reset() return fun.call_wrapped(*args) # probably won't return def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]: """Expands a given shape tree into a flat list of indices to arrays. Given the following computation: >>> c = xc.XlaBuilder("example") >>> p0 = xb.parameter(c, 1, xc.shape_from_pyval(jnp.ones([1]))) >>> p1 = xb.parameter(c, 2, xc.shape_from_pyval(jnp.ones([2]))) >>> p2 = xb.parameter(c, 3, xc.shape_from_pyval(jnp.ones([3]))) >>> o = xops.Tuple(c, [p0, p1, p2]) We can query the arrays in the output tuple: >>> flatten_shape(c.GetShape(o)) (((0,), f32[1]{0}), ((1,), f32[2]{0}), ((2,), f32[3]{0})) Or the arrays in one of the parameters (which is itself an array): >>> flatten_shape(c.GetShape(p0)) (((), f32[1]{0}),) Args s: The input shape. Returns: An iterable of pairs of indices and shapes for each array within the shape tree. 
""" def _flatten_shape(s, index): if s.is_array(): yield index, s else: assert s.is_tuple() for i, sub in enumerate(s.tuple_shapes()): subindex = index + (i,) if sub.is_tuple(): yield from _flatten_shape(sub, subindex) else: yield subindex, sub return tuple(_flatten_shape(s, index=())) def _xla_consts(c, consts): unique_consts = {id(const): const for const in consts} xla_consts = { id_: xb.constant(c, const) for id_, const in unique_consts.items()} return [xla_consts[id(const)] for const in consts] @lu.cache def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *arg_specs): if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " "got device={} and backend={}".format(device, backend)) abstract_args, arg_devices = unzip2(arg_specs) if config.omnistaging_enabled: jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args) if any(isinstance(c, core.Tracer) for c in consts): raise core.UnexpectedTracerError("Encountered an unexpected tracer.") else: pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args] jaxpr, pvals, consts = pe.trace_to_jaxpr( # type: ignore fun, pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore map(prefetch, it.chain(consts, jaxpr_literals(jaxpr))) jaxpr = apply_outfeed_rewriter(jaxpr) nreps = jaxpr_replicas(jaxpr) device = _xla_callable_device(nreps, backend, device, arg_devices) backend = device.platform if device else backend if config.omnistaging_enabled: result_handlers = map(partial(aval_to_result_handler, device), out_avals) else: out_avals = [pval.get_aval() for pval in pvals] result_handlers = map(partial(_pval_to_result_handler, device), pvals) # type: ignore # Computations that only produce constants and/or only rearrange their inputs, # which are often produced from partial evaluation, don't need compilation, # and don't need to force their (potentially lazy) arguments. if not jaxpr.eqns: return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers) if not _on_exit: log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG logging.log(log_priority, "Compiling %s for args %s.", fun.__name__, abstract_args) if nreps > 1: warn(f"The jitted function {fun.__name__} includes a pmap. Using " "jit-of-pmap can lead to inefficient data movement, as the outer jit " "does not preserve sharded data representations and instead collects " "input and output arrays onto a single device. " "Consider removing the outer jit unless you know what you're doing. 
" "See https://github.com/google/jax/issues/2926.") if nreps > xb.device_count(backend): raise ValueError( f"compiling computation that requires {nreps} replicas, but only " f"{xb.device_count(backend)} XLA devices are available") if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)): raise NotImplementedError( "jit of multi-host pmap not implemented (and jit-of-pmap can cause " "extra data movement anyway, so maybe you don't want it after all).") tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU c = xb.make_computation_builder("jit_{}".format(fun.__name__)) xla_consts = _xla_consts(c, consts) xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args, donated_invars=donated_invars) out_nodes = jaxpr_subcomp( c, jaxpr, backend, AxisEnv(nreps, (), ()), xla_consts, extend_name_stack(wrap_name(name, 'jit')), *xla_args) out_tuple = xops.Tuple(c, out_nodes) backend = xb.get_backend(backend) if backend.platform in ("gpu", "tpu"): donated_invars = set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args) if any(donated_invars): # TODO(tomhennigan): At call time we should mark these buffers as deleted. unused_donations = [str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d] warn("Some donated buffers were not usable: {}".format(", ".join(unused_donations))) built = c.build(out_tuple) options = xb.get_compile_options( num_replicas=nreps, num_partitions=1, device_assignment=(device.id,) if device else None) options.parameter_is_tupled_arguments = tuple_args compiled = backend_compile(backend, built, options) if nreps == 1: return partial(_execute_compiled, compiled, out_avals, result_handlers) else: return partial(_execute_replicated, compiled, out_avals, result_handlers) def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args): """Configures input/output "must" aliasing based on `donated_args`.""" # First for every input array add it to `donations` iff it is a member of # `donated_args`. donations = defaultdict(deque) for arg_index, arg in enumerate(xla_args): if donated_args[arg_index]: for param_index, element in flatten_shape(c.GetShape(arg)): key = (element.dimensions(), element.numpy_dtype()) if tuple_args: param_number = 0 param_index = (arg_index,) + tuple(param_index) donations[key].append((param_number, param_index, arg_index)) else: param_number = arg_index donations[key].append((param_number, param_index, arg_index)) # Consume donations for outputs. 
out_donated_args = list(donated_args) for output_index, element in flatten_shape(c.GetShape(out_tuple)): key = (element.dimensions(), element.numpy_dtype()) if donations.get(key, ()): param_number, param_index, arg_index = donations[key].popleft() out_donated_args[arg_index] = False c.setup_alias(output_index, param_number, param_index) return tuple(out_donated_args) def _xla_callable_device(nreps, backend, device, arg_devices): if nreps > 1: if device is not None or backend is not None: raise ValueError(f"can't specify device or backend for jit-of-pmap, " f"got device={device} and backend={backend}") return None else: if device is None and backend is None: return _device_from_arg_devices(arg_devices) elif device is not None and backend is None: return device elif device is None and backend is not None: return xb.get_backend(backend).get_default_device_assignment(1)[0] else: assert False # Unreachable given the error check in _xla_callable # Used within _xla_callable_args and _xla_param to distinguish between None (no # sharding annotation set) and replicated. _replicated_param = object() def _xla_callable_args( c, avals, tuple_args, *, replicated=None, partitions=None, partitions_proto: bool = False, donated_invars=None): assert partitions is None or len(partitions) == len(avals) if not tuple_args: if replicated is None: replicated = [None] * len(avals) if partitions is None: parts: List[object] = [None] * len(avals) elif partitions_proto: parts = partitions else: parts = [_replicated_param if part is None else part for part in partitions] counts = it.count() xla_args = [_xla_param(c, next(counts), xla_shape, r, p, partitions_proto) if a is not abstract_token else xops.CreateToken(c) for (a, r, p) in safe_zip(avals, replicated, parts) for xla_shape in aval_to_xla_shapes(a)] if donated_invars is not None: donated_invars = [d for (a, r, p, d) in safe_zip(avals, replicated, parts, donated_invars) for xla_shape in aval_to_xla_shapes(a)] return xla_args, donated_invars else: if replicated is not None: replicated = [r for a, r in zip(avals, replicated) if a is not abstract_token] if partitions is None: tuple_parts = None elif partitions_proto: tuple_parts = xb.tuple_sharding_proto(partitions) else: tuple_parts = tuple(partitions) tuple_shape = xc.Shape.tuple_shape( [shape for a in avals for shape in aval_to_xla_shapes(a) if a is not abstract_token]) tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts, partitions_proto) xla_inputs = iter(xla_destructure(c, tuple_param)) xla_args = [next(xla_inputs) if a is not abstract_token else xops.CreateToken(c) for a in avals] assert next(xla_inputs, None) is None return xla_args, donated_invars def _xla_param(builder, param_num, xla_shape, replicated, partitions, parts_proto): make_param = partial(xb.parameter, builder, param_num, xla_shape, replicated=replicated) with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding if partitions is None: return make_param() elif partitions is _replicated_param: return with_sharding(builder, None, make_param) else: return with_sharding(builder, partitions, make_param) def _execute_compiled(compiled: XlaExecutable, avals, handlers, *args): device, = compiled.local_devices() input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token)) out_bufs = compiled.execute(input_bufs) check_special(xla_call_p, out_bufs) return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))] def _execute_replicated(compiled: XlaExecutable, 
avals, handlers, *args): input_bufs = [ list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token)) for device in compiled.local_devices()] out_bufs = compiled.execute_on_local_devices(input_bufs)[0] check_special(xla_call_p, out_bufs) return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))] def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, *args): env = {core.unitvar: core.unit} map(env.setdefault, jaxpr.invars, args) map(env.setdefault, jaxpr.constvars, consts) outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v] for v in jaxpr.outvars] return [_copy_device_array_to_device(x, device) if type_is_device_array(x) else h(*device_put(x, device)) for h, x in zip(handlers, outs)] xla_call_p = core.CallPrimitive('xla_call') xla_call = xla_call_p.bind xla_call_p.def_impl(_xla_call_impl) def _xla_call_partial_eval_update_params(params, in_unknowns): call_jaxpr = params['call_jaxpr'] donated_invars = params['donated_invars'] if not in_unknowns and donated_invars: # JaxprTrace.post_process_call creates a call with no input tracers new_donated_invars = (False,) * len(call_jaxpr.invars) else: # JaxprTrace.process_call drops known input tracers donated_invars = [d for d, uk in zip(donated_invars, in_unknowns) if uk] new_donated_invars = ((False,) * (len(call_jaxpr.invars) - len(donated_invars)) + tuple(donated_invars)) return dict(params, donated_invars=new_donated_invars) pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params def _xla_call_jvp_update_params(params, nz_tangents): donated_invars = params['donated_invars'] donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz] new_donated_invars = (*donated_invars, *donated_tangents) return dict(params, donated_invars=new_donated_invars) ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts): donated_invars = params['donated_invars'] donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u] donated_cotangents = [False for nz in nonzero_cts if nz] return dict(params, donated_invars=(*donated_primals, *donated_cotangents)) ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params def _xla_call_translation_rule(c, axis_env, in_nodes, name_stack, backend, name, call_jaxpr, donated_invars, device=None): del device, donated_invars # Ignored. 
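  # The callee is lowered as its own named subcomputation plus an xops.Call, so
  # a function jitted under the name 'f' appears in the HLO as a computation
  # called 'jit_f' (mirroring the builder name chosen below).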
subc = xb.make_computation_builder(f"jit_{name}") args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)] out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (), extend_name_stack(name_stack, wrap_name(name, 'jit')), *args) subc = subc.build(xops.Tuple(subc, out_nodes)) return xops.Call(c, subc, list(in_nodes)) ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p) ### translation tables translations: Dict[core.Primitive, Callable] = {} translations_with_avals: Dict[core.Primitive, Callable] = {} parallel_translations: Dict[core.Primitive, Callable] = {} initial_style_translations: Dict[core.Primitive, Callable] = {} call_translations: Dict[core.Primitive, Callable] = {} backend_specific_translations: Dict[str, Dict[core.Primitive, Callable]] = defaultdict(dict) call_translations[xla_call_p] = _xla_call_translation_rule def zeros_like_translation_rule(c, x): shape = c.get_shape(x) assert not shape.is_tuple() zero = xb.constant(c, np.array(0, shape.element_type())) return xops.Broadcast(zero, shape.dimensions()) translations[ad_util.zeros_like_p] = zeros_like_translation_rule def add_jaxvals_translation_rule(c, x, y): shape = c.get_shape(x) assert not shape.is_tuple() return xops.Add(x, y) translations[ad_util.add_jaxvals_p] = add_jaxvals_translation_rule translations[ad_util.stop_gradient_p] = lambda c, x: x @lu.transformation def _tuple_output(*args, **kwargs): ans = yield args, kwargs yield (ans,) def lower_fun(fun, multiple_results, parallel=False, with_avals=False): # TODO(jakevdp): migrate dependent code & always use the with_avals=True. def f(c, *xla_args, **params): avals = [_array_aval_from_xla_shape(c.get_shape(x)) for x in xla_args] return f_with_avals(c, avals, xla_args, params) def f_with_avals(c, avals, xla_args, params): if parallel: axis_env = params.pop('axis_env') del params['platform'] else: axis_env = AxisEnv(1, (), ()) wrapped_fun = lu.wrap_init(fun, params) if not multiple_results: wrapped_fun = _tuple_output(wrapped_fun) if config.omnistaging_enabled: jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) outs = jaxpr_subcomp(c, jaxpr, None, axis_env, _xla_consts(c, consts), '', *xla_args) else: pvals = [pe.PartialVal.unknown(a) for a in avals] jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True, stage_out=True) # type: ignore xla_consts = _xla_consts(c, consts) outs = jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts, '', *xla_args) if multiple_results or any(v.aval._num_buffers > 1 for v in jaxpr.outvars): return xops.Tuple(c, outs) else: assert len(outs) == 1, outs return outs[0] return f_with_avals if with_avals else f def _array_aval_from_xla_shape(xla_shape): # This function instantiates the assumption that we can map fro XLA array # types to JAX array types. 
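  # For example (illustrative): an XLA array shape f32[2, 3] maps to
  # ShapedArray((2, 3), np.dtype('float32')); tuple shapes are rejected by the
  # assertion below.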
# TODO(mattjj): remove assumption can map XLA array types to JAX array types assert not xla_shape.is_tuple() return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype()) def lower_fun_initial_style(fun): def f(c, axis_env, name_stack, avals, backend, *xla_args, **params): if config.omnistaging_enabled: jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals) outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts), name_stack, *xla_args) else: pvals = [pe.PartialVal.unknown(a) for a in avals] jaxpr, _, consts = pe.trace_to_jaxpr( lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True) # type: ignore xla_consts = _xla_consts(c, consts) outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, name_stack, *xla_args) return xops.Tuple(c, outs) return f ### device-persistent data class Token(object): pass token = Token() pytype_aval_mappings[Token] = lambda _: abstract_token core.pytype_aval_mappings[Token] = lambda _: abstract_token xla_shape_handlers[AbstractToken] = lambda _: (xc.Shape.token_shape(),) xla_result_handlers[AbstractToken] = lambda _, __: lambda _: token canonicalize_dtype_handlers[Token] = identity def _forward_method(attrname, self, fun, *args): return fun(getattr(self, attrname), *args) _forward_to_value = partial(_forward_method, "_value") # The following is used for the type _CppDeviceArray or _DeviceArray. DeviceArrayProtocol = Any if hasattr(xc, "DeviceArrayBase"): DeviceArray = xc.DeviceArrayBase else: # prior to jaxlib version 0.1.58. class DeviceArray: # type: ignore pass _CppDeviceArray: DeviceArrayProtocol = xc.Buffer _EXPERIMENTAL_CPP_DEVICE_ARRAY = False def make_device_array( aval: core.ShapedArray, device: Optional[Device], lazy_expr: Optional[lazy.LazyExpr], device_buffer: Union[PyLocalBuffer, "DeviceConstant"], ) -> Union[PyLocalBuffer, "_DeviceArray"]: """Returns a DeviceArray implementation based on arguments. This is to be used only within JAX. It will return either a PythonDeviceArray or a C++ equivalent implementation. """ if (_EXPERIMENTAL_CPP_DEVICE_ARRAY and lazy.is_trivial(lazy_expr) and not isinstance(device_buffer, DeviceConstant)): assert isinstance(device_buffer, _CppDeviceArray) device_buffer._device = device # pylint: disable=protected-access device_buffer.aval = aval return device_buffer return _DeviceArray(aval, device, lazy_expr, device_buffer) def type_is_device_array(x): """Returns `True` if `x` is a non-sharded DeviceArray. Use this function instead of `type(x) is Devicearray`. """ type_x = type(x) return type_x is _DeviceArray or type_x is _CppDeviceArray class _DeviceArray(DeviceArray): # type: ignore """A DeviceArray is an ndarray backed by a single device memory buffer.""" # We don't subclass ndarray because that would open up a host of issues, # but lax_numpy.py overrides isinstance behavior and attaches ndarray methods. __slots__ = [ "aval", "device_buffer", "_npy_value", "_device", "_lazy_expr" ] __array_priority__ = 100 # DeviceArray has methods that are dynamically populated in lax_numpy.py, # and this annotation is needed to make pytype happy. _HAS_DYNAMIC_ATTRIBUTES = True def __init__(self, aval: core.ShapedArray, device: Optional[Device], lazy_expr: Optional[lazy.LazyExpr], device_buffer: PyLocalBuffer): """Initializer. Args: aval: The abstract value associated to this array (shape+dtype+weak_type). device: The optional sticky device. 
See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices lazy_expr: An optional `LayExpr`. `None` is equivalent to a trivial `LazyExpr`. device_buffer: The underlying buffer owning the on-device data. """ DeviceArray.__init__(self) self.aval = aval self.device_buffer = device_buffer self._device = device self._lazy_expr = lazy_expr self._npy_value = None if not core.skip_checks: assert type(aval) is ShapedArray npy_value = self._value assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape assert (device is None) or device is device_buffer.device() def _check_if_deleted(self): if self.device_buffer is deleted_buffer: raise RuntimeError("DeviceArray has been deleted.") def block_until_ready(self): """Blocks the caller until the buffer's value has been computed on device. This method is mostly useful for timing microbenchmarks that wish to time how long a computation takes, without transferring the result back to the host. Returns the buffer object (`self`). """ self._check_if_deleted() self.device_buffer.block_host_until_ready() # pytype: disable=attribute-error return self @property def _value(self): self._check_if_deleted() if self._npy_value is None: if is_device_constant(self): self._npy_value = lazy.eval_lexpr(self._lazy_expr, None) else: self._npy_value = _force(self).device_buffer.to_py() self._npy_value.flags.writeable = False return self._npy_value @property def shape(self): return self.aval.shape @property def dtype(self): return self.aval.dtype @property def size(self): return prod(self.aval.shape) @property def ndim(self): return len(self.aval.shape) def copy_to_host_async(self): """Requests a copy of the buffer to the host.""" self._check_if_deleted() if self._npy_value is None and not is_device_constant(self): self.device_buffer.copy_to_host_async() # pytype: disable=attribute-error def delete(self): """Deletes the device array and any cached copy on the host. It is an error to access the contents of a `DeviceArray` after it has been deleted. Use of this method is optional; device buffers will be reclaimed automatically by Python when a DeviceArray object is garbage collected. However, it is sometimes useful to have more explicit control over the time of deletion. 
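    For example (illustrative): after calling `delete()`, a later
    `block_until_ready()` or read of `_value` raises
    `RuntimeError("DeviceArray has been deleted.")` via `_check_if_deleted`.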
""" self.device_buffer.delete() # pytype: disable=attribute-error self.device_buffer = deleted_buffer self._npy_value = None @property def __cuda_array_interface__(self): return _force(self).device_buffer.__cuda_array_interface__ # Adding methods dynamically to both _DeviceArray and _CppDeviceArray # pylint: disable=protected-access for device_array in [_DeviceArray, _CppDeviceArray]: def copy(self): """Returns an ndarray (backed by host memory, not device memory).""" return np.asarray(self) setattr(device_array, "copy", copy) def __repr__(self): line_width = np.get_printoptions()["linewidth"] prefix = '{}('.format(self.__class__.__name__.lstrip('_')) s = np.array2string(self._value, prefix=prefix, suffix=',', separator=', ', max_line_width=line_width) dtype_str = 'dtype={})'.format(self.dtype.name) last_line_len = len(s) - s.rfind('\n') + 1 sep = ' ' if last_line_len + len(dtype_str) + 1 > line_width: sep = ' ' * len(prefix) return "{}{},{}{}".format(prefix, s, sep, dtype_str) setattr(device_array, "__repr__", __repr__) def item(self): if dtypes.issubdtype(self.dtype, np.complexfloating): return complex(self) elif dtypes.issubdtype(self.dtype, np.floating): return float(self) elif dtypes.issubdtype(self.dtype, np.integer): return int(self) elif dtypes.issubdtype(self.dtype, np.bool_): return bool(self) else: raise TypeError(self.dtype) setattr(device_array, "item", item) def __len__(self): try: return self.aval.shape[0] except IndexError as err: raise TypeError("len() of unsized object") from err # same as numpy error setattr(device_array, "__len__", __len__) def __iter__(self): if self.ndim == 0: raise TypeError("iteration over a 0-d array") # same as numpy error else: return self._value.__iter__() setattr(device_array, "__iter__", __iter__) def __reversed__(self): if self.ndim == 0: raise TypeError("iteration over a 0-d array") else: return reversed(self._value) setattr(device_array, "__reversed__", __reversed__) def __format__(self, format_spec): # Simulates behavior of https://github.com/numpy/numpy/pull/9883 if self.ndim == 0: return format(self._value[()], format_spec) else: return format(self._value, format_spec) setattr(device_array, "__format__", __format__) def __array__(self, dtype=None, context=None): return np.asarray(self._value, dtype=dtype) setattr(device_array, "__array__", __array__) setattr(device_array, "__str__", partialmethod(_forward_to_value, str)) setattr(device_array, "__bool__", partialmethod(_forward_to_value, bool)) setattr(device_array, "__nonzero__", partialmethod(_forward_to_value, bool)) setattr(device_array, "__float__", lambda self: self._value.__float__()) setattr(device_array, "__int__", lambda self: self._value.__int__()) setattr(device_array, "__complex__", lambda self: self._value.__complex__()) setattr(device_array, "__hex__", partialmethod(_forward_to_value, hex)) setattr(device_array, "__oct__", partialmethod(_forward_to_value, oct)) setattr(device_array, "__index__", partialmethod(_forward_to_value, op.index)) to_bytes = lambda self, order="C": self._value.tobytes(order) setattr(device_array, "tobytes", to_bytes) del to_bytes setattr(device_array, "tolist", lambda self: self._value.tolist()) # pickle saves and loads just like an ndarray setattr(device_array, "__reduce__", partialmethod(_forward_to_value, op.methodcaller("__reduce__"))) # clobbered when jax.numpy is imported, but useful in tests setattr(device_array, "__eq__", lambda self, other: self._value == other) def __hash__(self): raise TypeError("JAX DeviceArray, like numpy.ndarray, is 
not hashable.") setattr(device_array, "__hash__", __hash__) # The following methods are dynamically overridden in lax_numpy.py. def raise_not_implemented(): raise NotImplementedError setattr(device_array, "__getitem__", lambda self, i: raise_not_implemented()) # pylint: enable=protected-access class DeletedBuffer(object): pass deleted_buffer = DeletedBuffer() class DeviceConstant(object): __slots__ = ["_device"] def __init__(self, device=None): self._device = device def device(self): return self._device def to_py(self): return None def is_device_constant(x): return type_is_device_array(x) and type(x.device_buffer) is DeviceConstant for device_array in [_CppDeviceArray, _DeviceArray]: core.literalable_types.add(device_array) core.pytype_aval_mappings[device_array] = ConcreteArray pytype_aval_mappings[device_array] = op.attrgetter('aval') canonicalize_dtype_handlers[device_array] = identity def _device_array_constant_handler(c, val, canonicalize_types=True): if is_device_constant(val): return lazy.stage_lexpr(c, val._lazy_expr, None) else: base_val = xb.constant(c, val.device_buffer.to_py()) return lazy.stage_lexpr(c, val._lazy_expr, base_val) xb.register_constant_handler(_DeviceArray, _device_array_constant_handler) xb.register_constant_handler(_CppDeviceArray, _device_array_constant_handler) def _device_put_device_array(x: Union[DeviceArrayProtocol, _DeviceArray], device: Optional[Device]): x = _copy_device_array_to_device(x, device) return (_force(x).device_buffer,) device_put_handlers[_CppDeviceArray] = _device_put_device_array device_put_handlers[_DeviceArray] = _device_put_device_array def _copy_device_array_to_device(x: Union[DeviceArrayProtocol, _DeviceArray], device: Optional[xc.Device]) -> Union[DeviceArrayProtocol, _DeviceArray]: if device is None: # no copying to be done because there's no target specified return x elif is_device_constant(x): # create a new DeviceArray with the same lazy expr, no copying return make_device_array(x.aval, device, x._lazy_expr, DeviceConstant(device)) elif xb.get_device_backend(device).platform == x.device_buffer.platform(): # source and target platforms are the same if x.device_buffer.device() == device: # no copying to be done because source equals target if x._device == device: return x else: moved_buf = x.device_buffer # We need to change stickyness else: # move the buffer with a device-to-device copy moved_buf = x.device_buffer.copy_to_device(device) else: # buffers from different XLA backends are passed through the host. 
backend = xb.get_device_backend(device) moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device) return _DeviceArray(x.aval, device, x._lazy_expr, moved_buf) def _force(x: DeviceArrayProtocol) -> DeviceArrayProtocol: if lazy.is_trivial(x._lazy_expr): return x else: # force x on the device where it lives, but preserve stickiness on result if x._device: device = x._device else: device = x.device_buffer.device() force_fun = _lazy_force_computation(x.aval, device, x._lazy_expr) result = force_fun(x) return make_device_array(x.aval, x._device, lazy.array(x.aval.shape), result) @cache() def _lazy_force_computation(aval: core.ShapedArray, device: Device, lexpr: lazy.LazyExpr ) -> Callable[[_DeviceArray], PyLocalBuffer]: c = xb.make_computation_builder("lazy_force") if lazy.is_constant(lexpr): param = None else: idxs = [(src, dst) for dst, src in enumerate(lexpr.dims) if src is not None] param_shape = [None] * len(idxs) for src, dst in idxs: param_shape[src] = aval.shape[dst] param = xb.parameter(c, 0, xc.Shape.array_shape(aval.dtype, param_shape)) xla_out = lazy.stage_lexpr(c, lexpr, param) built_c = c.build(xla_out) device = _device_from_arg_devices([device]) options = xb.get_compile_options( num_replicas=1, num_partitions=1, device_assignment=device and (device.id,)) compiled = backend_compile(xb.get_device_backend(device), built_c, options) force_fun: Callable[[_DeviceArray], PyLocalBuffer] if lazy.is_constant(lexpr): def force_fun(_): return compiled.execute([])[0] else: def force_fun(x): return compiled.execute([x.device_buffer])[0] return force_fun def _device_put_impl(x, device: Optional[Device] = None): if type_is_device_array(x): return _copy_device_array_to_device(x, device) try: a = abstractify(x) except TypeError as err: raise TypeError( f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err return aval_to_result_handler(device, a)(*device_put(x, device)) device_put_p = core.Primitive('device_put') device_put_p.def_impl(_device_put_impl) device_put_p.def_abstract_eval(lambda x, device=None: x) translations[device_put_p] = lambda c, x, device=None: x ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent]) masking.defvectorized(device_put_p) def _remat_translation_rule(c, axis_env, in_nodes, name_stack, backend, name, call_jaxpr, device=None, concrete=None): """Lower remat to a Conditional which always returns true. This: 1. Circumvents common subexpression elimination. 2. In common case of `jax.grad(jax.remat(f))`, ensures the remat blocks occur after the primal blocks, because cotangent is an input to the Conditional.""" del device, concrete # Unused. # Fake condition which always selects True branch. 
rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)), xb.constant(c, np.array(1, dtype=np.float32)), xc.Shape.array_shape(xc.PrimitiveType.F32, [])) pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32))) true_op = xops.Tuple(c, in_nodes) remat_subc = xb.make_computation_builder("remat_call_subcomputation") input_op = xb.parameter(remat_subc, 0, c.get_shape(true_op), replicated=[]) args = [xops.GetTupleElement(input_op, i) for i in range(len(in_nodes))] out_nodes = jaxpr_subcomp(remat_subc, call_jaxpr, backend, axis_env, (), extend_name_stack(name_stack, wrap_name(name, 'remat')), *args) out_node_shapes = [remat_subc.get_shape(o) for o in out_nodes] remat_subc = remat_subc.build(xops.Tuple(remat_subc, out_nodes)) false_op = true_op dummy_subc = xb.make_computation_builder("remat_call_dummy_subcomputation") xb.parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[]) def zeros(xla_shape): if xla_shape.is_array(): shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype() zero = xb.constant(dummy_subc, np.array(0, dtype=dtype)) return xops.Broadcast(zero, shape) else: # It is a token return xops.CreateToken(dummy_subc) out_nodes = [zeros(s) for s in out_node_shapes] dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes)) return xops.Conditional(pred, true_op, remat_subc, false_op, dummy_subc) call_translations[pe.remat_call_p] = _remat_translation_rule # type: ignore ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose, core.named_call_p) def _named_call_translation_rule(c, axis_env, in_nodes, name_stack, *, name="core_call", backend, call_jaxpr): subc = xb.make_computation_builder(name) args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)] out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (), extend_name_stack(name_stack, name), *args) subc = subc.Build(xops.Tuple(subc, out_nodes)) return xops.Call(c, subc, list(in_nodes)) call_translations[core.named_call_p] = _named_call_translation_rule def _call_translation_rule(c, axis_env, in_nodes, name_stack, *, backend, call_jaxpr): return _named_call_translation_rule( c, axis_env, in_nodes, name_stack, name="core_call", backend=backend, call_jaxpr=call_jaxpr) call_translations[core.call_p] = _call_translation_rule @config.register_omnistaging_disabler def omnistaging_disabler() -> None: global _pval_to_result_handler def _pval_to_result_handler(device, pval): pv, const = pval if pv is None: const = _device_put_impl(const, device) if device else const return lambda _: const else: return aval_to_result_handler(device, pv) pe.staged_out_calls.add(xla_call_p) # type: ignore
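# Example (illustrative, not executed): a library-defined primitive can be
# lowered by reusing existing JAX primitives through `lower_fun` and the
# translation tables above. The primitive `my_double_p` and its rule below are
# hypothetical; they only show how `apply_primitive`, `lower_fun`, and
# `translations` in this module fit together.
#
#   my_double_p = core.Primitive('my_double')
#   my_double_p.def_impl(partial(apply_primitive, my_double_p))
#   my_double_p.def_abstract_eval(lambda x: x)
#   translations[my_double_p] = lower_fun(lambda x: x + x,
#                                         multiple_results=False)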