# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import defaultdict, deque
import itertools as it
import operator as op
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
                    Tuple, Union, NamedTuple)
from warnings import warn

from absl import logging
import numpy as np

from ..config import flags, bool_env, config
from .. import core
from .. import ad_util
from .. import dtypes
from .. import lazy
from .. import linear_util as lu
from jax._src import source_info_util
from ..abstract_arrays import (make_shaped_array, array_types)
from ..core import (ConcreteArray, ShapedArray, AbstractToken,
                    Literal, pp_eqn_compact, raise_to_shaped, abstract_token)
from jax._src.pprint_util import pp
from .._src.util import (partial, partialmethod, cache, prod, unzip2,
                    extend_name_stack, wrap_name, safe_zip, safe_map)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from . import partial_eval as pe
from . import ad
from . import masking

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

xe = xc._xla
xops = xc._xla.ops

# Types
Backend = Any  # xc.LocalBackend (why does mypy not like this?)
Device = Any  # xc.Device
PyLocalBuffer = Any

XlaOp = Any  # xla_extension.XlaOp
XlaShape = Any # xla_client.Shape
XlaComputationBuilder = Any  # xla_bridge._JaxComputationBuilder
XlaExecutable = Any # xla_extension.LocalExecutable

FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_debug_nans',
                  bool_env('JAX_DEBUG_NANS', False),
                  'Add nan checks to every operation.')
flags.DEFINE_bool('jax_debug_infs',
                  bool_env('JAX_DEBUG_INFS', False),
                  'Add inf checks to every operation.')
flags.DEFINE_bool('jax_log_compiles',
                  bool_env('JAX_LOG_COMPILES', False),
                  'Print a message each time a `jit` computation is compiled.')

# This flag is set on exit; no logging should be attempted
_on_exit = False

def identity(x): return x

_scalar_types = dtypes.python_scalar_dtypes.keys()

# unit representation
def _make_unit_constant(c): return xb.constant(c, np.zeros((), dtype=np.dtype('bool')))
def _make_unit_shape(_): return (xc.Shape.array_shape(np.dtype('bool'), ()),)
def _device_put_unit(_, device):
  backend = xb.get_device_backend(device)
  return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
                                    device),)
def _make_array_shape(a):
  if a.dtype is dtypes.float0:
    return (xc.Shape.array_shape(np.dtype('bool'), a.shape),)
  else:
    return (xc.Shape.array_shape(a.dtype, a.shape),)

### handlers

xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))

def aval_to_xla_shapes(aval):
  try:
    return xla_shape_handlers[type(aval)](aval)
  except KeyError as err:
    raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err

xla_shape_handlers: Dict[Type[core.AbstractValue], Callable] = {
    core.AbstractUnit: _make_unit_shape,
    ShapedArray: _make_array_shape,
    ConcreteArray: _make_array_shape,
}

def aval_to_result_handler(device: Optional[Device], aval: core.AbstractValue) -> Callable:
  try:
    return xla_result_handlers[type(aval)](device, aval)
  except KeyError as err:
    raise TypeError(f"No xla_result_handler for type: {type(aval)}") from err

def array_result_handler(device: Optional[Device], aval: core.ShapedArray):
  if aval.dtype is dtypes.float0:
    return lambda _: np.zeros(aval.shape, dtypes.float0)
  return partial(make_device_array, raise_to_shaped(aval), device,
                 lazy.array(aval.shape))


xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
    core.AbstractUnit: lambda _, __: lambda _: core.unit,
    ShapedArray: array_result_handler,
    ConcreteArray: array_result_handler,
}

def device_put(x, device: Optional[Device] = None) -> Tuple[Any]:
  x = canonicalize_dtype(x)
  try:
    return device_put_handlers[type(x)](x, device)
  except KeyError as err:
    raise TypeError(f"No device_put handler for type: {type(x)}") from err

def _device_put_array(x, device: Optional[Device]):
  backend = xb.get_device_backend(device)
  if x.dtype is dtypes.float0:
    x = np.zeros(x.shape, dtype=np.dtype(bool))
  return (backend.buffer_from_pyval(x, device),)

def _device_put_scalar(x, device):
  return _device_put_array(dtypes.coerce_to_array(x), device)

device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {
  core.Unit: _device_put_unit
}
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)

# TODO(mattjj): try to remove this canonicalize_dtype stuff
def canonicalize_dtype(x):
  typ = type(x)
  handler = canonicalize_dtype_handlers.get(typ)
  if handler: return handler(x)
  for typ in typ.mro():
    handler = canonicalize_dtype_handlers.get(typ)
    if handler: return handler(x)
  raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")

def _canonicalize_ndarray_dtype(x):
  return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))

def _canonicalize_python_scalar_dtype(typ, x):
  return np.asarray(
      x, dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ]))

canonicalize_dtype_handlers: Dict[Any, Callable] = {core.Unit: identity}
canonicalize_dtype_handlers.update(
    (t, _canonicalize_ndarray_dtype) for t in array_types)
canonicalize_dtype_handlers.update(
    (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)

def abstractify(x) -> core.AbstractValue:
  typ = type(x)
  aval_fn = pytype_aval_mappings.get(typ)
  if aval_fn: return aval_fn(x)
  for typ in typ.mro():
    aval_fn = pytype_aval_mappings.get(typ)
    if aval_fn: return aval_fn(x)
  raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
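
# For illustration (a rough sketch, not executed here): a Python scalar maps to
# a weakly-typed scalar abstract value and an ndarray to a ShapedArray, e.g.
#   abstractify(3.0)             ~ ShapedArray((), float, weak_type=True)
#   abstractify(np.ones((2, 3))) ~ ShapedArray((2, 3), float)
# with the concrete float dtype depending on the x64 configuration.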

def _make_abstract_python_scalar(typ, _):
  return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True)

pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {
    core.Unit: lambda _: core.abstract_unit,
}
pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
pytype_aval_mappings.update(
    (t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)

# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap; we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
  if outfeed_rewriter is not None:
    return outfeed_rewriter(jaxpr)
  else:
    return jaxpr
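
# For example (a sketch of intended use): the id_tap implementation mentioned
# above installs its own rewriting function here, roughly
#   xla.outfeed_rewriter = lambda jaxpr: <jaxpr rewritten to manage outfeeds>
# so that every jaxpr passes through it via apply_outfeed_rewriter just before
# compilation.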

outfeed_primitives: Set[core.Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
  """Returns True if there are outfeed primitives anywhere inside a Jaxpr."""
  return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
             for eqn in jaxpr.eqns)

def _param_uses_outfeed(param):
  if type(param) is core.Jaxpr:
    if jaxpr_uses_outfeed(param):
      return True
  elif type(param) is core.ClosedJaxpr:
    if jaxpr_uses_outfeed(param.jaxpr):
      return True
  return False

def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
  if prim in outfeed_primitives:
    return True
  for param in params.values():
    if isinstance(param, tuple):
      if any(unsafe_map(_param_uses_outfeed, param)):
        return True
    elif _param_uses_outfeed(param):
      return True
  return False

### op-by-op execution

def arg_spec(x):
  aval = abstractify(x)
  try:
    return aval, x._device
  except AttributeError:
    return aval, None

def apply_primitive(prim, *args, **params):
  """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
  compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  return compiled_fun(*args)


def _partition_outputs(avals, outs):
  nouts = [aval._num_buffers for aval in avals]
  if not core.skip_checks:
    assert sum(nouts) == len(outs), (
        f"Internal error: sum(nouts)={sum(nouts)} should equal len(outs)={len(outs)}.")
  outs = iter(outs)
  return [[next(outs) for _ in range(nout)] for nout in nouts]
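
# A small sketch of the partitioning above (not executed here): with avals whose
# _num_buffers are [1, 2] and buffers [b0, b1, b2], the result is
# [[b0], [b1, b2]], i.e. each aval receives its own contiguous slice of buffers.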


@cache()
def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
                                                   Optional[Device]], **params):
  avals, arg_devices = unzip2(arg_specs)
  donated_invars = (False,) * len(arg_specs)
  device = _device_from_arg_devices(arg_devices)
  backend = xb.get_device_backend(device)
  if primitive_uses_outfeed(prim, params):
    # We use the _xla_callable path, where we pre-process the primitives
    def prim_fun(*args):
      return prim.bind(*args, **params)
    return _xla_callable(lu.wrap_init(prim_fun), device, None, "prim", donated_invars,
                         *arg_specs)
  aval_out = prim.abstract_eval(*avals, **params)
  if not prim.multiple_results:
    handle_result = aval_to_result_handler(device, aval_out)
  else:
    handlers = map(partial(aval_to_result_handler, device), aval_out)
    handle_result = lambda *bufs: tuple(
        handler(*bs) for handler, bs in
        zip(handlers, _partition_outputs(aval_out, bufs)))
  tuple_args = len(avals) > 100
  if prim in initial_style_translations:
    nreps = initial_style_primitive_replicas(params)
  else:
    nreps = 1

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling a primitive computation `{prim}` that requires {nreps} "
        f"replicas, but only {xb.device_count(backend)} XLA devices are "
        f"available on backend {backend.platform}.")
  built_c = primitive_computation(prim, AxisEnv(nreps, (), ()), backend,
                                  tuple_args, *avals, **params)
  options = xb.get_compile_options(
      num_replicas=nreps,
      num_partitions=1,
      device_assignment=device and (device.id,))
  options.parameter_is_tupled_arguments = tuple_args
  compiled = backend_compile(backend, built_c, options)
  if nreps == 1:
    return partial(_execute_compiled_primitive, prim, compiled, handle_result)
  else:
    return partial(_execute_replicated_primitive, prim, compiled, handle_result)

def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[Device]:
  """Given devices of inputs, determine where to perform a computation.

  Args:
    devices: list where each element is either a `Device` instance or `None`.
  Returns:
    A `Device` instance or None.
  Raises:
    ValueError if input devices are inconsistent.
  """
  try:
    device, = {d for d in devices if d is not None} or (None,)
    return device
  except ValueError as err:
    msg = "primitive arguments must be colocated on the same device, got {}"
    raise ValueError(msg.format(", ".join(map(str, devices)))) from err
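
# For example (a sketch): if every committed argument lives on the same device
# d, d is returned; if no argument is committed (all entries are None), None is
# returned and the default device is used downstream; two distinct committed
# devices raise the ValueError above.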

@cache()
def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params):
  c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
  c.set_op_metadata(xc.OpMetadata(
      op_type=prim.name,
      op_name=str(pp_eqn_compact(prim.name, params))))
  platform = xb.get_backend(backend).platform
  xla_args, _ = _xla_callable_args(c, avals, tuple_args)
  # return val always set as a side-effect on c
  if prim in backend_specific_translations[platform]:
    rule = backend_specific_translations[platform][prim]
    ans = rule(c, *xla_args, **params)
  elif prim in translations:
    rule = translations[prim]
    ans = rule(c, *xla_args, **params)
  elif prim in translations_with_avals:
    rule = translations_with_avals[prim]
    ans = rule(c, avals, xla_args, params)
  elif prim in initial_style_translations:
    rule = initial_style_translations[prim]
    ans = rule(c, axis_env, extend_name_stack(prim.name), avals, backend,
         *xla_args, **params)
  else:
    raise NotImplementedError(f"XLA translation rule for {prim} not found")
  assert isinstance(ans, xe.XlaOp)
  c.clear_op_metadata()
  try:
    return c.build(ans)
  except RuntimeError as e:
    msg = (" ".join(map(str, e.args)) + "\n"
           "This is a bug in JAX's shape-checking rules; please report it!\n"
           "https://github.com/google/jax/issues\n")
    raise RuntimeError(msg) from e

def primitive_subcomputation(prim, *avals, **params):
  axis_env = AxisEnv(1, (), ())
  return primitive_computation(prim, axis_env, None, False, *avals, **params)

def backend_compile(backend, built_c, options):
  # we use a separate function call to ensure that XLA compilation appears
  # separately in Python profiling results
  return backend.compile(built_c, compile_options=options)

def _execute_compiled_primitive(prim, compiled, result_handler, *args):
  device, = compiled.local_devices()
  input_bufs = list(it.chain.from_iterable(
      device_put(x, device) for x in args if x is not token))
  out_bufs = compiled.execute(input_bufs)
  check_special(prim, out_bufs)
  return result_handler(*out_bufs)

def _execute_replicated_primitive(prim, compiled, result_handler, *args):
  input_bufs = [
      list(it.chain.from_iterable(
          device_put(x, device) for x in args if x is not token))
      for device in compiled.local_devices()]
  out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
  return result_handler(*out_bufs)


def check_special(prim, bufs):
  if FLAGS.jax_debug_infs or FLAGS.jax_debug_nans:
    for buf in bufs:
      _check_special(prim.name, buf.xla_shape(), buf)

def _check_special(name, xla_shape, buf):
  assert not xla_shape.is_tuple()
  if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
    if FLAGS.jax_debug_nans and np.any(np.isnan(buf.to_py())):
      raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    if FLAGS.jax_debug_infs and np.any(np.isinf(buf.to_py())):
      raise FloatingPointError(f"invalid value (inf) encountered in {name}")

### compiling jaxprs

def prefetch(x):
  if isinstance(x, DeviceArray):
    x.copy_to_host_async()
  return x

def jaxpr_literals(jaxpr):
  """Generates all the literals inside a jaxpr, including nested subjaxprs."""
  for eqn in jaxpr.eqns:
    for v in eqn.invars:
      if type(v) is core.Literal:
        yield v.val
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from jaxpr_literals(subjaxpr)


def _flatmap(func: Callable, vars: Sequence):
  return list(it.chain.from_iterable(map(func, vars)))

def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence):
  return map(func, vars, _partition_outputs([v.aval for v in vars], nodes))

def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
  if backend not in ('cpu', 'gpu', 'tpu'):
    platform = xb.get_backend(backend).platform  # canonicalize
  else:
    platform = backend

  def read(v):
    if type(v) is Literal:
      return [xb.constant(c, canonicalize_dtype(v.val))]
    else:
      return env[v]

  def aval(v):
    if type(v) is Literal:
      return abstractify(v.val)
    else:
      return v.aval

  def write(v, node):
    assert node is not None
    env[v] = node

  env = {}
  _partitionmap(write, [core.unitvar], [_make_unit_constant(c)])
  _partitionmap(write, jaxpr.constvars, consts)
  _partitionmap(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    frame = source_info_util.user_frame(eqn.source_info)
    c.set_op_metadata(xc.OpMetadata(
        op_type=eqn.primitive.name,
        op_name=str(pp(name_stack) >> pp_eqn_compact(
            eqn.primitive.name, eqn.params)),
        source_file=frame.file_name if frame else None,
        source_line=frame.line_num if frame else None))
    in_nodes = _flatmap(read, eqn.invars)
    # TODO(jakevdp): migrate `translations` table to `translations_with_avals`
    if eqn.primitive in backend_specific_translations[platform]:
      rule = backend_specific_translations[platform][eqn.primitive]
      ans = rule(c, *in_nodes, **eqn.params)
    elif eqn.primitive in translations:
      ans = translations[eqn.primitive](c, *in_nodes, **eqn.params)
    elif eqn.primitive in translations_with_avals:
      rule = translations_with_avals[eqn.primitive]
      ans = rule(c, map(aval, eqn.invars), in_nodes, eqn.params)
    elif eqn.primitive in initial_style_translations:
      new_params = check_backend_params(eqn.params, backend)
      rule = initial_style_translations[eqn.primitive]
      ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name),
                 map(aval, eqn.invars), backend, *in_nodes, **new_params)
    elif eqn.primitive in parallel_translations:
      rule = parallel_translations[eqn.primitive]
      ans = rule(c, *in_nodes, axis_env=axis_env, platform=platform, **eqn.params)
    elif eqn.primitive in call_translations:
      new_params = check_backend_params(eqn.params, backend)
      rule = call_translations[eqn.primitive]
      ans = rule(c, axis_env, in_nodes,
                 name_stack, backend=backend, **new_params)
    else:
      raise NotImplementedError(
          f"XLA translation rule for primitive '{eqn.primitive.name}' not found")

    assert isinstance(ans, xe.XlaOp)
    c.get_shape(ans)  # force xla to do shape error checking
    if eqn.primitive.multiple_results or any(v.aval._num_buffers > 1 for v in eqn.outvars):
      out_nodes = xla_destructure(c, ans)
    else:
      out_nodes = [ans]
    c.clear_op_metadata()
    _partitionmap(write, eqn.outvars, out_nodes)
  return _flatmap(read, jaxpr.outvars)


def xla_destructure(c, ans):
  num_elements = len(c.get_shape(ans).tuple_shapes())
  return [xops.GetTupleElement(ans, i) for i in range(num_elements)]

def check_backend_params(params, outer_backend):
  # For nested calls, the outermost call sets the backend for all inner calls;
  # it's an error if the inner call has a conflicting explicit backend spec.
  inner_backend = params.get('backend', None)
  if inner_backend and inner_backend != outer_backend:
    raise ValueError(
        f"Outer-jit backend specification {outer_backend} must match explicit "
        f"inner-jit backend specification {inner_backend}.")
  return {k: params[k] for k in params if k != 'backend'}
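
# For example (a sketch): lowering under an outer backend of 'cpu', an inner
# call whose params carry an explicit backend='gpu' fails this check, while an
# inner call with no explicit backend simply inherits 'cpu'; in either case the
# 'backend' entry is dropped from the returned params.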


class AxisEnv(NamedTuple):
  """Represents a pmap mesh (only along the replica axes)."""
  nreps: int
  names: Tuple[Any, ...]
  sizes: Tuple[int, ...]

def extend_axis_env(env: AxisEnv, name, size: int):
  return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))

def axis_read(axis_env, axis_name):
  try:
    return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
  except ValueError:
    raise NameError("unbound axis name: {}".format(axis_name)) from None

def axis_groups(axis_env: AxisEnv, name):
  if not isinstance(name, (list, tuple)):
    name = (name,)
  mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
  trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes))
  assert not ragged
  mesh_spec = axis_env.sizes + (trailing_size,)
  return _axis_groups(mesh_spec, mesh_axes)

def _axis_groups(mesh_spec, mesh_axes):
  """Computes replica group ids for a collective performed over a subset of the mesh.

  Args:
    mesh_spec: A sequence of integers representing the mesh shape.
    mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
      indicating over which axes the collective is performed.
  Returns:
    A tuple of replica groups (i.e. tuples containing replica ids).
  """
  iota = np.arange(prod(mesh_spec)).reshape(mesh_spec)
  groups = np.reshape(
      np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
      (prod(np.take(mesh_spec, mesh_axes)), -1))
  return tuple(unsafe_map(tuple, groups.T))
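
# A worked example (a sketch, not executed here): for a 2x2 mesh,
#   _axis_groups((2, 2), (0,)) -> ((0, 2), (1, 3))
#   _axis_groups((2, 2), (1,)) -> ((0, 1), (2, 3))
# i.e. each group holds the replica ids that differ only along the axes the
# collective is performed over.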

def jaxpr_replicas(jaxpr: core.Jaxpr) -> int:
  """The number of replicas needed for a jaxpr.

  For an eqn, multiply the `axis_size` by the `jaxpr_replicas` of the
  subjaxprs. For a list of eqns, take the maximum number of replicas.
  """
  return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1)

# TODO(mattjj): this function assumes that only pmap has a parameter named
# axis_size, and that it corresponds to cross-replica mapping
def eqn_replicas(eqn):
  call_jaxpr = eqn.params.get("call_jaxpr")
  if call_jaxpr:
    return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
  elif eqn.primitive in initial_style_translations:
    return initial_style_primitive_replicas(eqn.params)
  else:
    return 1
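
# For example (a sketch): an xla_pmap eqn with axis_size=4 whose call_jaxpr
# itself needs 2 replicas contributes 4 * 2 = 8 to the enclosing jaxpr's
# replica count.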

def initial_style_primitive_replicas(params):
  return max(core.traverse_jaxpr_params(jaxpr_replicas, params), default=1)

# TODO(mattjj,skyewm): the functions here are utilities for checking if
# not-yet-supported features are used with multi-host programming

def jaxpr_has_pmap(jaxpr):
  """Whether there is an xla_pmap primitive anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if 'xla_pmap' in eqn.primitive.name:
      return True
  for subjaxpr in core.subjaxprs(jaxpr):
    if jaxpr_has_pmap(subjaxpr):
      return True
  return False


def jaxpr_collectives(jaxpr):
  """Generates all the collective primitives anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if eqn.primitive in parallel_translations:
      yield eqn.primitive
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from jaxpr_collectives(subjaxpr)


### xla_call underlying jit

def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
  compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
                               *unsafe_map(arg_spec, args))
  try:
    return compiled_fun(*args)
  except FloatingPointError:
    assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs  # compiled_fun can only raise in this case
    print("Invalid value encountered in the output of a jit function. "
          "Calling the de-optimized version.")
    # We want to run the wrapped function again (after _xla_callable already ran
    # it), but linear_util.WrappedFun instances are meant to be run only once.
    # In addition to re-executing the Python code, which is usually undesirable
    # but which FLAGS.jax_debug_nans is meant to opt into, we'll be re-executing
    # any linear_util.py-style side effects, i.e. re-populating Stores created
    # by any transformation_with_aux's applied to fun. Since this is
    # intentional here, to avoid "Store occupied" errors we reset the stores to
    # be empty.
    for store in fun.stores: store and store.reset()
    return fun.call_wrapped(*args)  # probably won't return

def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]:
  """Expands a given shape tree into a flat list of indices to arrays.

  Given the following computation:

  >>> c = xc.XlaBuilder("example")
  >>> p0 = xb.parameter(c, 1, xc.shape_from_pyval(jnp.ones([1])))
  >>> p1 = xb.parameter(c, 2, xc.shape_from_pyval(jnp.ones([2])))
  >>> p2 = xb.parameter(c, 3, xc.shape_from_pyval(jnp.ones([3])))
  >>> o = xops.Tuple(c, [p0, p1, p2])

  We can query the arrays in the output tuple:

  >>> flatten_shape(c.GetShape(o))
  (((0,), f32[1]{0}),
   ((1,), f32[2]{0}),
   ((2,), f32[3]{0}))

  Or the arrays in one of the parameters (which is itself an array):

  >>> flatten_shape(c.GetShape(p0))
  (((), f32[1]{0}),)

  Args:
    s: The input shape.

  Returns:
    An iterable of pairs of indices and shapes for each array within the shape
    tree.
  """
  def _flatten_shape(s, index):
    if s.is_array():
      yield index, s
    else:
      assert s.is_tuple()
      for i, sub in enumerate(s.tuple_shapes()):
        subindex = index + (i,)
        if sub.is_tuple():
          yield from _flatten_shape(sub, subindex)
        else:
          yield subindex, sub
  return tuple(_flatten_shape(s, index=()))

def _xla_consts(c, consts):
  unique_consts = {id(const): const for const in consts}
  xla_consts = {
      id_: xb.constant(c, const) for id_, const in unique_consts.items()}
  return [xla_consts[id(const)] for const in consts]

@lu.cache
def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *arg_specs):
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))

  abstract_args, arg_devices = unzip2(arg_specs)
  if config.omnistaging_enabled:
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
    if any(isinstance(c, core.Tracer) for c in consts):
      raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
  else:
    pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
    jaxpr, pvals, consts = pe.trace_to_jaxpr(  # type: ignore
        fun, pvals, instantiate=False, stage_out=True, bottom=True)  # type: ignore
  map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
  jaxpr = apply_outfeed_rewriter(jaxpr)

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = device.platform if device else backend
  if config.omnistaging_enabled:
    result_handlers = map(partial(aval_to_result_handler, device), out_avals)
  else:
    out_avals = [pval.get_aval() for pval in pvals]
    result_handlers = map(partial(_pval_to_result_handler, device), pvals)  # type: ignore

  # Computations that only produce constants and/or only rearrange their inputs,
  # which are often produced from partial evaluation, don't need compilation,
  # and don't need to force their (potentially lazy) arguments.
  if not jaxpr.eqns:
    return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers)

  if not _on_exit:
    log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for args %s.", fun.__name__, abstract_args)

  if nreps > 1:
    warn(f"The jitted function {fun.__name__} includes a pmap. Using "
         "jit-of-pmap can lead to inefficient data movement, as the outer jit "
         "does not preserve sharded data representations and instead collects "
         "input and output arrays onto a single device. "
         "Consider removing the outer jit unless you know what you're doing. "
         "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation that requires {nreps} replicas, but only "
        f"{xb.device_count(backend)} XLA devices are available")

  if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  tuple_args = len(abstract_args) > 100  # pass long arg lists as tuple for TPU

  c = xb.make_computation_builder("jit_{}".format(fun.__name__))
  xla_consts = _xla_consts(c, consts)
  xla_args, donated_invars = _xla_callable_args(
      c, abstract_args, tuple_args, donated_invars=donated_invars)
  out_nodes = jaxpr_subcomp(
      c, jaxpr, backend, AxisEnv(nreps, (), ()), xla_consts,
      extend_name_stack(wrap_name(name, 'jit')), *xla_args)
  out_tuple = xops.Tuple(c, out_nodes)
  backend = xb.get_backend(backend)
  if backend.platform in ("gpu", "tpu"):
    donated_invars = set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
  if any(donated_invars):
    # TODO(tomhennigan): At call time we should mark these buffers as deleted.
    unused_donations = [str(c.GetShape(a))
                        for a, d in zip(xla_args, donated_invars) if d]
    warn("Some donated buffers were not usable: {}".format(", ".join(unused_donations)))
  built = c.build(out_tuple)

  options = xb.get_compile_options(
      num_replicas=nreps,
      num_partitions=1,
      device_assignment=(device.id,) if device else None)
  options.parameter_is_tupled_arguments = tuple_args
  compiled = backend_compile(backend, built, options)
  if nreps == 1:
    return partial(_execute_compiled, compiled, out_avals, result_handlers)
  else:
    return partial(_execute_replicated, compiled, out_avals, result_handlers)

def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
  """Configures input/output "must" aliasing based on `donated_args`."""
  # First for every input array add it to `donations` iff it is a member of
  # `donated_args`.
  donations = defaultdict(deque)
  for arg_index, arg in enumerate(xla_args):
    if donated_args[arg_index]:
      for param_index, element in flatten_shape(c.GetShape(arg)):
        key = (element.dimensions(), element.numpy_dtype())
        if tuple_args:
          param_number = 0
          param_index = (arg_index,) + tuple(param_index)
          donations[key].append((param_number, param_index, arg_index))
        else:
          param_number = arg_index
          donations[key].append((param_number, param_index, arg_index))

  # Consume donations for outputs.
  out_donated_args = list(donated_args)
  for output_index, element in flatten_shape(c.GetShape(out_tuple)):
    key = (element.dimensions(), element.numpy_dtype())
    if donations.get(key, ()):
      param_number, param_index, arg_index = donations[key].popleft()
      out_donated_args[arg_index] = False
      c.setup_alias(output_index, param_number, param_index)

  return tuple(out_donated_args)
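
# A small sketch of the donation bookkeeping above (hypothetical shapes): if
# argument 0 is donated and is an f32[100], and the output tuple also contains
# an f32[100] element, that output is aliased to argument 0's buffer and the
# argument is removed from the returned donation mask; donated inputs with no
# matching output shape/dtype stay marked and trigger the "not usable" warning
# in _xla_callable.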

def _xla_callable_device(nreps, backend, device, arg_devices):
  if nreps > 1:
    if device is not None or backend is not None:
      raise ValueError(f"can't specify device or backend for jit-of-pmap, "
                       f"got device={device} and backend={backend}")
    return None
  else:
    if device is None and backend is None:
      return _device_from_arg_devices(arg_devices)
    elif device is not None and backend is None:
      return device
    elif device is None and backend is not None:
      return xb.get_backend(backend).get_default_device_assignment(1)[0]
    else:
      assert False  # Unreachable given the error check in _xla_callable

# Used within _xla_callable_args and _xla_param to distinguish between None (no
# sharding annotation set) and replicated.
_replicated_param = object()

def _xla_callable_args(
    c, avals, tuple_args, *,
    replicated=None,
    partitions=None,
    partitions_proto: bool = False,
    donated_invars=None):
  assert partitions is None or len(partitions) == len(avals)
  if not tuple_args:
    if replicated is None:
      replicated = [None] * len(avals)
    if partitions is None:
      parts: List[object] = [None] * len(avals)
    elif partitions_proto:
      parts = partitions
    else:
      parts = [_replicated_param if part is None else part
               for part in partitions]
    counts = it.count()
    xla_args = [_xla_param(c, next(counts), xla_shape, r, p, partitions_proto)
                if a is not abstract_token else xops.CreateToken(c)
                for (a, r, p) in safe_zip(avals, replicated, parts)
                for xla_shape in aval_to_xla_shapes(a)]
    if donated_invars is not None:
      donated_invars = [d
                        for (a, r, p, d) in safe_zip(avals, replicated, parts, donated_invars)
                        for xla_shape in aval_to_xla_shapes(a)]
    return xla_args, donated_invars
  else:
    if replicated is not None:
      replicated = [r for a, r in zip(avals, replicated)
                    if a is not abstract_token]
    if partitions is None:
      tuple_parts = None
    elif partitions_proto:
      tuple_parts = xb.tuple_sharding_proto(partitions)
    else:
      tuple_parts = tuple(partitions)
    tuple_shape = xc.Shape.tuple_shape(
        [shape for a in avals for shape in aval_to_xla_shapes(a) if a is not abstract_token])
    tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts, partitions_proto)
    xla_inputs = iter(xla_destructure(c, tuple_param))
    xla_args = [next(xla_inputs) if a is not abstract_token else
                xops.CreateToken(c) for a in avals]
    assert next(xla_inputs, None) is None
    return xla_args, donated_invars

def _xla_param(builder, param_num, xla_shape, replicated, partitions, parts_proto):
  make_param = partial(xb.parameter, builder, param_num, xla_shape,
                       replicated=replicated)
  with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding
  if partitions is None:
    return make_param()
  elif partitions is _replicated_param:
    return with_sharding(builder, None, make_param)
  else:
    return with_sharding(builder, partitions, make_param)

def _execute_compiled(compiled: XlaExecutable, avals, handlers, *args):
  device, = compiled.local_devices()
  input_bufs = list(it.chain.from_iterable(
      device_put(x, device) for x in args if x is not token))
  out_bufs = compiled.execute(input_bufs)
  check_special(xla_call_p, out_bufs)
  return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]

def _execute_replicated(compiled: XlaExecutable, avals, handlers, *args):
  input_bufs = [
      list(it.chain.from_iterable(
          device_put(x, device) for x in args if x is not token))
      for device in compiled.local_devices()]
  out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
  check_special(xla_call_p, out_bufs)
  return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]

def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, *args):
  env = {core.unitvar: core.unit}
  map(env.setdefault, jaxpr.invars, args)
  map(env.setdefault, jaxpr.constvars, consts)
  outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
          for v in jaxpr.outvars]
  return [_copy_device_array_to_device(x, device) if type_is_device_array(x)
          else h(*device_put(x, device)) for h, x in zip(handlers, outs)]

xla_call_p = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind
xla_call_p.def_impl(_xla_call_impl)

def _xla_call_partial_eval_update_params(params, in_unknowns):
  call_jaxpr = params['call_jaxpr']
  donated_invars = params['donated_invars']
  if not in_unknowns and donated_invars:
    # JaxprTrace.post_process_call creates a call with no input tracers
    new_donated_invars = (False,) * len(call_jaxpr.invars)
  else:
    # JaxprTrace.process_call drops known input tracers
    donated_invars = [d for d, uk in zip(donated_invars, in_unknowns) if uk]
    new_donated_invars = ((False,) * (len(call_jaxpr.invars) - len(donated_invars))
                          + tuple(donated_invars))
  return dict(params, donated_invars=new_donated_invars)
pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params

def _xla_call_jvp_update_params(params, nz_tangents):
  donated_invars = params['donated_invars']
  donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
  new_donated_invars = (*donated_invars, *donated_tangents)
  return dict(params, donated_invars=new_donated_invars)
ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params

def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
  donated_invars = params['donated_invars']
  donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
  donated_cotangents = [False for nz in nonzero_cts if nz]
  return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params


def _xla_call_translation_rule(c, axis_env,
                               in_nodes, name_stack, backend, name,
                               call_jaxpr, donated_invars, device=None):
  del device, donated_invars  # Ignored.
  subc = xb.make_computation_builder(f"jit_{name}")
  args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
  out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
                            extend_name_stack(name_stack, wrap_name(name, 'jit')), *args)
  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xops.Call(c, subc, list(in_nodes))
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)


### translation tables

translations: Dict[core.Primitive, Callable] = {}
translations_with_avals: Dict[core.Primitive, Callable] = {}
parallel_translations: Dict[core.Primitive, Callable] = {}
initial_style_translations: Dict[core.Primitive, Callable] = {}
call_translations: Dict[core.Primitive, Callable] = {}
backend_specific_translations: Dict[str, Dict[core.Primitive, Callable]] = defaultdict(dict)

call_translations[xla_call_p] = _xla_call_translation_rule

def zeros_like_translation_rule(c, x):
  shape = c.get_shape(x)
  assert not shape.is_tuple()
  zero = xb.constant(c, np.array(0, shape.element_type()))
  return xops.Broadcast(zero, shape.dimensions())
translations[ad_util.zeros_like_p] = zeros_like_translation_rule

def add_jaxvals_translation_rule(c, x, y):
  shape = c.get_shape(x)
  assert not shape.is_tuple()
  return xops.Add(x, y)
translations[ad_util.add_jaxvals_p] = add_jaxvals_translation_rule

translations[ad_util.stop_gradient_p] = lambda c, x: x


@lu.transformation
def _tuple_output(*args, **kwargs):
  ans = yield args, kwargs
  yield (ans,)

def lower_fun(fun, multiple_results, parallel=False, with_avals=False):
  # TODO(jakevdp): migrate dependent code & always use with_avals=True.
  def f(c, *xla_args, **params):
    avals = [_array_aval_from_xla_shape(c.get_shape(x)) for x in xla_args]
    return f_with_avals(c, avals, xla_args, params)

  def f_with_avals(c, avals, xla_args, params):
    if parallel:
      axis_env = params.pop('axis_env')
      del params['platform']
    else:
      axis_env = AxisEnv(1, (), ())
    wrapped_fun = lu.wrap_init(fun, params)
    if not multiple_results:
      wrapped_fun = _tuple_output(wrapped_fun)
    if config.omnistaging_enabled:
      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
      outs = jaxpr_subcomp(c, jaxpr, None, axis_env, _xla_consts(c, consts), '',
                          *xla_args)
    else:
      pvals = [pe.PartialVal.unknown(a) for a in avals]
      jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True,
                                          stage_out=True)  # type: ignore
      xla_consts = _xla_consts(c, consts)
      outs = jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts, '', *xla_args)
    if multiple_results or any(v.aval._num_buffers > 1 for v in jaxpr.outvars):
      return xops.Tuple(c, outs)
    else:
      assert len(outs) == 1, outs
      return outs[0]

  return f_with_avals if with_avals else f

def _array_aval_from_xla_shape(xla_shape):
  # This function instantiates the assumption that we can map from XLA array
  # types to JAX array types.
  # TODO(mattjj): remove assumption can map XLA array types to JAX array types
  assert not xla_shape.is_tuple()
  return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())

def lower_fun_initial_style(fun):
  def f(c, axis_env, name_stack, avals, backend, *xla_args, **params):
    if config.omnistaging_enabled:
      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
      outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
                          name_stack, *xla_args)
    else:
      pvals = [pe.PartialVal.unknown(a) for a in avals]
      jaxpr, _, consts = pe.trace_to_jaxpr(
          lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True)  # type: ignore
      xla_consts = _xla_consts(c, consts)
      outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, name_stack,
                          *xla_args)
    return xops.Tuple(c, outs)
  return f


### device-persistent data

class Token(object): pass
token = Token()

pytype_aval_mappings[Token] = lambda _: abstract_token
core.pytype_aval_mappings[Token] = lambda _: abstract_token
xla_shape_handlers[AbstractToken] = lambda _: (xc.Shape.token_shape(),)
xla_result_handlers[AbstractToken] = lambda _, __: lambda _: token
canonicalize_dtype_handlers[Token] = identity


def _forward_method(attrname, self, fun, *args):
  return fun(getattr(self, attrname), *args)
_forward_to_value = partial(_forward_method, "_value")


# The following is used for the type _CppDeviceArray or _DeviceArray.
DeviceArrayProtocol = Any
if hasattr(xc, "DeviceArrayBase"):
  DeviceArray = xc.DeviceArrayBase
else:
  # prior to jaxlib version 0.1.58.
  class DeviceArray:  # type: ignore
    pass

_CppDeviceArray: DeviceArrayProtocol = xc.Buffer

_EXPERIMENTAL_CPP_DEVICE_ARRAY = False


def make_device_array(
    aval: core.ShapedArray,
    device: Optional[Device],
    lazy_expr: Optional[lazy.LazyExpr],
    device_buffer: Union[PyLocalBuffer, "DeviceConstant"],
) -> Union[PyLocalBuffer, "_DeviceArray"]:
  """Returns a DeviceArray implementation based on arguments.

  This is to be used only within JAX. It will return either a Python
  `_DeviceArray` or a C++ equivalent implementation.
  """
  if (_EXPERIMENTAL_CPP_DEVICE_ARRAY and lazy.is_trivial(lazy_expr) and
      not isinstance(device_buffer, DeviceConstant)):
    assert isinstance(device_buffer, _CppDeviceArray)
    device_buffer._device = device    # pylint: disable=protected-access
    device_buffer.aval = aval
    return device_buffer

  return _DeviceArray(aval, device, lazy_expr, device_buffer)


def type_is_device_array(x):
  """Returns `True` if `x` is a non-sharded DeviceArray.

  Use this function instead of `type(x) is DeviceArray`.
  """
  type_x = type(x)
  return type_x is _DeviceArray or type_x is _CppDeviceArray


class _DeviceArray(DeviceArray):  # type: ignore
  """A DeviceArray is an ndarray backed by a single device memory buffer."""
  # We don't subclass ndarray because that would open up a host of issues,
  # but lax_numpy.py overrides isinstance behavior and attaches ndarray methods.
  __slots__ = [
      "aval", "device_buffer", "_npy_value", "_device", "_lazy_expr"
  ]
  __array_priority__ = 100

  # DeviceArray has methods that are dynamically populated in lax_numpy.py,
  # and this annotation is needed to make pytype happy.
  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self, aval: core.ShapedArray, device: Optional[Device],
               lazy_expr: Optional[lazy.LazyExpr],
               device_buffer: PyLocalBuffer):
    """Initializer.

    Args:
      aval: The abstract value associated to this array (shape+dtype+weak_type).
      device: The optional sticky device. See
        https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
      lazy_expr: An optional `LazyExpr`. `None` is equivalent to a trivial
        `LazyExpr`.
      device_buffer: The underlying buffer owning the on-device data.
    """
    DeviceArray.__init__(self)
    self.aval = aval
    self.device_buffer = device_buffer
    self._device = device
    self._lazy_expr = lazy_expr

    self._npy_value = None
    if not core.skip_checks:
      assert type(aval) is ShapedArray
      npy_value = self._value
      assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape
      assert (device is None) or device is device_buffer.device()

  def _check_if_deleted(self):
    if self.device_buffer is deleted_buffer:
      raise RuntimeError("DeviceArray has been deleted.")

  def block_until_ready(self):
    """Blocks the caller until the buffer's value has been computed on device.

    This method is mostly useful for timing microbenchmarks that wish to
    time how long a computation takes, without transferring the result back
    to the host.

    Returns the buffer object (`self`).
    """
    self._check_if_deleted()
    self.device_buffer.block_host_until_ready()  # pytype: disable=attribute-error
    return self

  @property
  def _value(self):
    self._check_if_deleted()
    if self._npy_value is None:
      if is_device_constant(self):
        self._npy_value = lazy.eval_lexpr(self._lazy_expr, None)
      else:
        self._npy_value = _force(self).device_buffer.to_py()
      self._npy_value.flags.writeable = False
    return self._npy_value

  @property
  def shape(self):
    return self.aval.shape

  @property
  def dtype(self):
    return self.aval.dtype

  @property
  def size(self):
    return prod(self.aval.shape)

  @property
  def ndim(self):
    return len(self.aval.shape)

  def copy_to_host_async(self):
    """Requests a copy of the buffer to the host."""
    self._check_if_deleted()
    if self._npy_value is None and not is_device_constant(self):
      self.device_buffer.copy_to_host_async()  # pytype: disable=attribute-error

  def delete(self):
    """Deletes the device array and any cached copy on the host.

    It is an error to access the contents of a `DeviceArray` after it has
    been deleted.

    Use of this method is optional; device buffers will be reclaimed
    automatically by Python when a DeviceArray object is garbage collected.
    However, it is sometimes useful to have more explicit control over the
    time of deletion.
    """
    self.device_buffer.delete()  # pytype: disable=attribute-error
    self.device_buffer = deleted_buffer
    self._npy_value = None

  @property
  def __cuda_array_interface__(self):
    return _force(self).device_buffer.__cuda_array_interface__


# Adding methods dynamically to both _DeviceArray and _CppDeviceArray
# pylint: disable=protected-access
for device_array in [_DeviceArray, _CppDeviceArray]:


  def copy(self):
    """Returns an ndarray (backed by host memory, not device memory)."""
    return np.asarray(self)
  setattr(device_array, "copy", copy)

  def __repr__(self):
    line_width = np.get_printoptions()["linewidth"]
    prefix = '{}('.format(self.__class__.__name__.lstrip('_'))
    s = np.array2string(self._value, prefix=prefix, suffix=',',
                        separator=', ', max_line_width=line_width)
    dtype_str = 'dtype={})'.format(self.dtype.name)
    last_line_len = len(s) - s.rfind('\n') + 1
    sep = ' '
    if last_line_len + len(dtype_str) + 1 > line_width:
      sep = ' ' * len(prefix)
    return "{}{},{}{}".format(prefix, s, sep, dtype_str)

  setattr(device_array, "__repr__", __repr__)

  def item(self):
    if dtypes.issubdtype(self.dtype, np.complexfloating):
      return complex(self)
    elif dtypes.issubdtype(self.dtype, np.floating):
      return float(self)
    elif dtypes.issubdtype(self.dtype, np.integer):
      return int(self)
    elif dtypes.issubdtype(self.dtype, np.bool_):
      return bool(self)
    else:
      raise TypeError(self.dtype)

  setattr(device_array, "item", item)

  def __len__(self):
    try:
      return self.aval.shape[0]
    except IndexError as err:
      raise TypeError("len() of unsized object") from err  # same as numpy error

  setattr(device_array, "__len__", __len__)

  def __iter__(self):
    if self.ndim == 0:
      raise TypeError("iteration over a 0-d array")  # same as numpy error
    else:
      return self._value.__iter__()

  setattr(device_array, "__iter__", __iter__)

  def __reversed__(self):
    if self.ndim == 0:
      raise TypeError("iteration over a 0-d array")
    else:
      return reversed(self._value)

  setattr(device_array, "__reversed__", __reversed__)

  def __format__(self, format_spec):
    # Simulates behavior of https://github.com/numpy/numpy/pull/9883
    if self.ndim == 0:
      return format(self._value[()], format_spec)
    else:
      return format(self._value, format_spec)

  setattr(device_array, "__format__", __format__)

  def __array__(self, dtype=None, context=None):
    return np.asarray(self._value, dtype=dtype)

  setattr(device_array, "__array__", __array__)

  setattr(device_array, "__str__", partialmethod(_forward_to_value, str))
  setattr(device_array, "__bool__", partialmethod(_forward_to_value, bool))
  setattr(device_array, "__nonzero__", partialmethod(_forward_to_value, bool))
  setattr(device_array, "__float__", lambda self: self._value.__float__())
  setattr(device_array, "__int__", lambda self: self._value.__int__())
  setattr(device_array, "__complex__", lambda self: self._value.__complex__())
  setattr(device_array, "__hex__", partialmethod(_forward_to_value, hex))
  setattr(device_array, "__oct__", partialmethod(_forward_to_value, oct))
  setattr(device_array, "__index__", partialmethod(_forward_to_value, op.index))
  to_bytes = lambda self, order="C": self._value.tobytes(order)
  setattr(device_array, "tobytes", to_bytes)
  del to_bytes
  setattr(device_array, "tolist", lambda self: self._value.tolist())

  # pickle saves and loads just like an ndarray
  setattr(device_array, "__reduce__",
          partialmethod(_forward_to_value, op.methodcaller("__reduce__")))

  # clobbered when jax.numpy is imported, but useful in tests
  setattr(device_array, "__eq__", lambda self, other: self._value == other)

  def __hash__(self):
    raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")

  setattr(device_array, "__hash__", __hash__)

  # The following methods are dynamically overridden in lax_numpy.py.
  def raise_not_implemented():
    raise NotImplementedError

  setattr(device_array, "__getitem__", lambda self, i: raise_not_implemented())
# pylint: enable=protected-access


class DeletedBuffer(object): pass
deleted_buffer = DeletedBuffer()

class DeviceConstant(object):
  __slots__ = ["_device"]
  def __init__(self, device=None): self._device = device
  def device(self): return self._device
  def to_py(self): return None

def is_device_constant(x):
  return type_is_device_array(x) and type(x.device_buffer) is DeviceConstant

for device_array in [_CppDeviceArray, _DeviceArray]:
  core.literalable_types.add(device_array)
  core.pytype_aval_mappings[device_array] = ConcreteArray
  pytype_aval_mappings[device_array] = op.attrgetter('aval')
  canonicalize_dtype_handlers[device_array] = identity

def _device_array_constant_handler(c, val, canonicalize_types=True):
  if is_device_constant(val):
    return lazy.stage_lexpr(c, val._lazy_expr, None)
  else:
    base_val = xb.constant(c, val.device_buffer.to_py())
    return lazy.stage_lexpr(c, val._lazy_expr, base_val)
xb.register_constant_handler(_DeviceArray, _device_array_constant_handler)
xb.register_constant_handler(_CppDeviceArray, _device_array_constant_handler)

def _device_put_device_array(x: Union[DeviceArrayProtocol, _DeviceArray],
                             device: Optional[Device]):
  x = _copy_device_array_to_device(x, device)
  return (_force(x).device_buffer,)
device_put_handlers[_CppDeviceArray] = _device_put_device_array
device_put_handlers[_DeviceArray] = _device_put_device_array

def _copy_device_array_to_device(
    x: Union[DeviceArrayProtocol, _DeviceArray],
    device: Optional[xc.Device]) -> Union[DeviceArrayProtocol, _DeviceArray]:
  if device is None:
    # no copying to be done because there's no target specified
    return x
  elif is_device_constant(x):
    # create a new DeviceArray with the same lazy expr, no copying
    return make_device_array(x.aval, device, x._lazy_expr,
                               DeviceConstant(device))
  elif xb.get_device_backend(device).platform == x.device_buffer.platform():
    # source and target platforms are the same
    if x.device_buffer.device() == device:
      # no copying to be done because source equals target
      if x._device == device:
        return x
      else:
        moved_buf = x.device_buffer  # We need to change stickiness
    else:
      # move the buffer with a device-to-device copy
      moved_buf = x.device_buffer.copy_to_device(device)
  else:
    # buffers from different XLA backends are passed through the host.
    backend = xb.get_device_backend(device)
    moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
  return _DeviceArray(x.aval, device, x._lazy_expr, moved_buf)

def _force(x: DeviceArrayProtocol) -> DeviceArrayProtocol:
  if lazy.is_trivial(x._lazy_expr):
    return x
  else:
    # force x on the device where it lives, but preserve stickiness on result
    if x._device:
      device = x._device
    else:
      device = x.device_buffer.device()
    force_fun = _lazy_force_computation(x.aval, device, x._lazy_expr)
    result = force_fun(x)
    return make_device_array(x.aval, x._device, lazy.array(x.aval.shape), result)

@cache()
def _lazy_force_computation(aval: core.ShapedArray,
                            device: Device, lexpr: lazy.LazyExpr
                            ) -> Callable[[_DeviceArray], PyLocalBuffer]:
  c = xb.make_computation_builder("lazy_force")
  if lazy.is_constant(lexpr):
    param = None
  else:
    idxs = [(src, dst) for dst, src in enumerate(lexpr.dims) if src is not None]
    param_shape = [None] * len(idxs)
    for src, dst in idxs:
      param_shape[src] = aval.shape[dst]
    param = xb.parameter(c, 0, xc.Shape.array_shape(aval.dtype, param_shape))
  xla_out = lazy.stage_lexpr(c, lexpr, param)
  built_c = c.build(xla_out)

  device = _device_from_arg_devices([device])
  options = xb.get_compile_options(
      num_replicas=1,
      num_partitions=1,
      device_assignment=device and (device.id,))
  compiled = backend_compile(xb.get_device_backend(device), built_c, options)

  force_fun: Callable[[_DeviceArray], PyLocalBuffer]
  if lazy.is_constant(lexpr):
    def force_fun(_):
      return compiled.execute([])[0]
  else:
    def force_fun(x):
      return compiled.execute([x.device_buffer])[0]
  return force_fun


def _device_put_impl(x, device: Optional[Device] = None):
  if type_is_device_array(x):
    return _copy_device_array_to_device(x, device)

  try:
    a = abstractify(x)
  except TypeError as err:
    raise TypeError(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
  return aval_to_result_handler(device, a)(*device_put(x, device))

device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
device_put_p.def_abstract_eval(lambda x, device=None: x)
translations[device_put_p] = lambda c, x, device=None: x
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
masking.defvectorized(device_put_p)


def _remat_translation_rule(c, axis_env, in_nodes,
                            name_stack, backend, name, call_jaxpr,
                            device=None, concrete=None):
  """Lower remat to a Conditional whose predicate is always true. This:
    1. Circumvents common subexpression elimination.
    2. In the common case of `jax.grad(jax.remat(f))`, ensures the remat blocks
       occur after the primal blocks, because cotangent is an input to the
       Conditional."""
  del device, concrete  # Unused.
  # Fake condition which always selects True branch.
  rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)),
                        xb.constant(c, np.array(1, dtype=np.float32)),
                        xc.Shape.array_shape(xc.PrimitiveType.F32, []))
  pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32)))

  true_op = xops.Tuple(c, in_nodes)
  remat_subc = xb.make_computation_builder("remat_call_subcomputation")
  input_op = xb.parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
  args = [xops.GetTupleElement(input_op, i) for i in range(len(in_nodes))]
  out_nodes = jaxpr_subcomp(remat_subc, call_jaxpr, backend, axis_env, (),
                            extend_name_stack(name_stack, wrap_name(name, 'remat')),
                            *args)
  out_node_shapes = [remat_subc.get_shape(o) for o in out_nodes]
  remat_subc = remat_subc.build(xops.Tuple(remat_subc, out_nodes))

  false_op = true_op
  dummy_subc = xb.make_computation_builder("remat_call_dummy_subcomputation")
  xb.parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])

  def zeros(xla_shape):
    if xla_shape.is_array():
      shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
      zero = xb.constant(dummy_subc, np.array(0, dtype=dtype))
      return xops.Broadcast(zero, shape)
    else:
      # It is a token
      return xops.CreateToken(dummy_subc)
  out_nodes = [zeros(s) for s in out_node_shapes]
  dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))

  return xops.Conditional(pred, true_op, remat_subc, false_op, dummy_subc)
call_translations[pe.remat_call_p] = _remat_translation_rule  # type: ignore


ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,
                                                     core.named_call_p)


def _named_call_translation_rule(c, axis_env, in_nodes, name_stack, *,
                                 name="core_call", backend, call_jaxpr):
  subc = xb.make_computation_builder(name)
  args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
  out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
                            extend_name_stack(name_stack, name), *args)
  subc = subc.Build(xops.Tuple(subc, out_nodes))
  return xops.Call(c, subc, list(in_nodes))
call_translations[core.named_call_p] = _named_call_translation_rule


def _call_translation_rule(c, axis_env, in_nodes, name_stack, *, backend,
                           call_jaxpr):
  return _named_call_translation_rule(
      c, axis_env, in_nodes, name_stack, name="core_call",
      backend=backend, call_jaxpr=call_jaxpr)
call_translations[core.call_p] = _call_translation_rule


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  global _pval_to_result_handler

  def _pval_to_result_handler(device, pval):
    pv, const = pval
    if pv is None:
      const = _device_put_impl(const, device) if device else const
      return lambda _: const
    else:
      return aval_to_result_handler(device, pv)

  pe.staged_out_calls.add(xla_call_p)  # type: ignore
