1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementation of pmap and related functionality."""
15
16# A ShardingSpec describes at a high level how a logical array is sharded across
17# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
18# describe how to shard inputs to a parallel computation). spec_to_indices()
19# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
20# how the sharded array is "laid out" across devices. Given a sequence of
21# devices, we shard the data across the devices in row-major order, with
22# replication treated as an extra inner dimension.
23#
24# For example, given the logical data array [1, 2, 3, 4], if we were to
25# partition this array 4 ways with a replication factor of 2, for a total of 8
# devices, the data on each device would be: [1], [1], [2], [2], [3], [3], [4], [4].
27#
28# This encoding is assumed by various parts of the system, e.g. generating
29# replica groups for collective operations.
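#
# As a rough illustration of this encoding (a sketch, not executed here), the
# [1, 2, 3, 4] example above corresponds to something like:
#
#   spec = ShardingSpec(sharding=(Unstacked(4),),
#                       mesh_mapping=(ShardedAxis(0), Replicated(2)))
#   spec_to_indices((4,), spec)
#   # ~> ((0,), (0,), (1,), (1,), (2,), (2,), (3,), (3,))
#
# i.e. devices 0 and 1 hold element 1, devices 2 and 3 hold element 2, and so
# on, with replication as the innermost (fastest-varying) dimension.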
30
31import sys
32from contextlib import contextmanager
33from collections import defaultdict, OrderedDict
34import itertools as it
35import operator as op
36import threading
37from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
38                    Type, Union, Iterable, no_type_check, NamedTuple, TYPE_CHECKING)
39
40from absl import logging
41import numpy as np
42
43from ..config import flags, config
44from .. import core
45from .. import linear_util as lu
46from .. import lazy
47from ..abstract_arrays import array_types
48from ..core import ConcreteArray, ShapedArray, Var, Literal
49from .._src.util import (partial, unzip2, unzip3, prod, safe_map, safe_zip,
50                         extend_name_stack, wrap_name, assert_unreachable,
51                         tuple_insert, tuple_delete, taggedtuple, curry)
52from ..lib import xla_bridge as xb
53from ..lib import xla_client as xc
54from ..tree_util import tree_flatten, tree_map
55from .batching import broadcast, not_mapped, moveaxis
56from . import batching
57from . import partial_eval as pe
58from . import xla
59from . import ad
60
61if sys.version_info >= (3, 9):
62  OrderedDictType = OrderedDict
63else:
64  OrderedDictType = Dict
65
66xops = xc.ops
67
68FLAGS = flags.FLAGS
69
70unsafe_map, map = map, safe_map  # type: ignore
71
72Index = Union[int, slice, Tuple[Union[int, slice], ...]]
73
74
75class NoSharding:
76
77  def __eq__(self, other):
78    return isinstance(other, NoSharding)
79
80  def __repr__(self):
81    return "NoSharding()"
82
83
84_UNSHARDED_INSTANCE = NoSharding()
85
86# mypy is very unhappy about taggedtuple
87if TYPE_CHECKING:
88  class Unstacked(NamedTuple):
89    size: int
90else:
91  Unstacked = taggedtuple('Unstacked', ('size',))
92
93class Chunked:
94  chunks: Tuple[int, ...]
95
96  def __init__(self, chunks: Union[int, Tuple[int, ...]]):
97    if not isinstance(chunks, tuple):
98      chunks = (chunks,)
99    object.__setattr__(self, 'chunks', chunks)
100
101  def __setattr__(self, name, value):
102    raise RuntimeError("Chunked is immutable")
103
104  def __delattr__(self, name):
105    raise RuntimeError("Chunked is immutable")
106
107  def __hash__(self):
108    return hash(self.chunks)
109
110  def __eq__(self, other):
111    return type(other) is Chunked and self.chunks == other.chunks
112
113  def __repr__(self):
114    return f'Chunked({self.chunks})'
115
116"""
117Represents all the ways we can shard a dimension.
- `NoSharding` means the dimension is not sharded;
- `Chunked` means that the dimension is split into the specified chunk counts
  (one per mesh axis it is distributed over), but the split dimension itself is
  preserved inside the map;
121- `Unstacked` means that the dimension is split into chunks of size 1, and doesn't
122  appear inside the map.
123"""
124AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
125
126# mypy is very unhappy about taggedtuple
127if TYPE_CHECKING:
128  class ShardedAxis(NamedTuple):
129    axis: int
130  class Replicated(NamedTuple):
131    replicas: int
132else:
133  ShardedAxis = taggedtuple('ShardedAxis', ('axis',))
134  Replicated = taggedtuple('Replicated', ('replicas',))
135
136"""
137Assigns sharded axes to mesh dimensions.
138
139When no axis is assigned, the data is replicated.
Note that the integer in `ShardedAxis` indexes only the actually sharded axes
(i.e. counting as if the `NoSharding` dimensions were filtered out). For example,
given the sharding `[Unstacked(n), NoSharding(), Chunked(m)]`, an entry of
`ShardedAxis(1)` refers to the `Chunked(m)` axis, not the unsharded dimension.
144"""
145MeshDimAssignment = Union[ShardedAxis, Replicated]
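
# For illustration (a rough sketch, not executed here): with
#   sharding     = (Unstacked(n), NoSharding(), Chunked((m,)))
#   mesh_mapping = (ShardedAxis(0), ShardedAxis(1))
# the chunks are laid out over a logical device mesh of shape (n, m):
# ShardedAxis(0) selects the Unstacked(n) axis, ShardedAxis(1) selects the
# Chunked((m,)) axis, and the NoSharding dimension is skipped in the numbering.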
146
147class ShardingSpec:
148  """Describes the sharding of an ndarray.
149
150  Attributes:
151    sharding: specifies how the array is supposed to get partitioned into chunks.
152      Its length should match the rank of the array. See the docstring of
153      `AvalDimSharding` for the supported partitioning schemes.
    mesh_mapping: describes an assignment of the array chunks created by `sharding`
155      to a logical device mesh. The length of the tuple is equal to the rank of the
156      mesh. Each mesh dimension can either get partitions of data varying along one
157      of the sharded dimensions, or the data can be replicated. See the docstring of
158      `MeshDimAssignment` for more information.
159  """
160  sharding: Tuple[AvalDimSharding, ...]
161  mesh_mapping: Tuple[MeshDimAssignment, ...]
162
163  def __init__(self,
164               sharding: Iterable[AvalDimSharding],
165               mesh_mapping: Iterable[MeshDimAssignment]):
166    self.sharding = tuple(sharding)
167    assert all(x is not None for x in self.sharding)
168    self.mesh_mapping = tuple(mesh_mapping)
169
170  @property
171  def mesh_shape(self):
172    sharded_axis_sizes = []
173    for sharding in self.sharding:
174      if isinstance(sharding, NoSharding):
175        continue
176      elif isinstance(sharding, Unstacked):
177        sharded_axis_sizes.append(sharding.size)
178      elif isinstance(sharding, Chunked):
179        sharded_axis_sizes.extend(sharding.chunks)
180      else:
181        assert_unreachable(sharding)
182    return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) else a.replicas
183                 for a in self.mesh_mapping)
184
185  def sharding_proto(self):
186    """Converts a ShardingSpec to an OpSharding proto.
187
188    See
189    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
190    for details on the OpSharding proto.
    Unfortunately the semantics are not very well described in the proto spec,
    but the code here might help:
192    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
193    """
194    mesh_shape = self.mesh_shape
195    mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
196
197    sharded_axes = {}  # maps sharded axis identifiers to mesh axis indices to which they're mapped
198    replicated_maxes = []  # lists mesh axis identifiers to replicate over
199    for maxis, assignment in enumerate(self.mesh_mapping):
200      if isinstance(assignment, Replicated):
201        replicated_maxes.append(maxis)
202      elif isinstance(assignment, ShardedAxis):
203        sharded_axes[assignment.axis] = maxis
204      else:
205        assert_unreachable(assignment)
206
207    proto = xc.OpSharding()
208    if len(replicated_maxes) == len(self.mesh_mapping):
209      proto.type = xc.OpSharding.Type.REPLICATED
210      return proto
211    else:
212      proto.type = xc.OpSharding.Type.OTHER
213
214    mesh_permutation = []
215    new_mesh_shape = []
216    next_sharded_axis = 0
217    for axis, sharding in enumerate(self.sharding):
218      if isinstance(sharding, NoSharding):
219        new_mesh_shape.append(1)  # Add a dummy mesh axis we won't be sharding over
220      elif isinstance(sharding, Chunked):
221        for nchunks in sharding.chunks:
222          maxis = sharded_axes[next_sharded_axis]
223          assert mesh_shape[maxis] == nchunks
224          mesh_permutation.append(maxis)
225          next_sharded_axis += 1
226        new_mesh_shape.append(int(np.prod(sharding.chunks)))
227      elif isinstance(sharding, Unstacked):
228        raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
229      else:
230        assert_unreachable(sharding)
231
232    # Create the partial sharding proto if tensor is replicated over some mesh axes
233    if replicated_maxes:
234      new_mesh_shape.append(-1)
235      mesh_permutation.extend(replicated_maxes)
236      proto.replicate_on_last_tile_dim = True
237
238    proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape)
239    proto.tile_assignment_dimensions = list(proto_mesh.shape)
240    proto.tile_assignment_devices = list(proto_mesh.flat)
241    return proto
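
  # For illustration (a rough sketch, not executed here): with
  #   ShardingSpec(sharding=(Chunked((2,)), NoSharding()),
  #                mesh_mapping=(ShardedAxis(0),))
  # the mesh has shape (2,), so sharding_proto() produces an OpSharding of type
  # OTHER with tile_assignment_dimensions == [2, 1] and
  # tile_assignment_devices == [0, 1], i.e. the rows are split across two
  # devices and the columns are left whole.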
242
243  def indices(self, shape: Tuple[int, ...]) -> np.ndarray:
244    """Returns NumPy-style indices corresponding to a sharding spec.
245
246    Args:
247      shape: The shape of the logical array being sharded.
248
249    Returns:
      An ndarray with the same shape as the logical mesh (as derived from
251      `mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
252      the data array to be placed on a corresponding device. The indices can be
253      ints, slice objects with step=1, or tuples of those.
254    """
255    assert len(shape) == len(self.sharding), (shape, self.sharding)
256
257    axis_indices: List[Sequence[Index]] = []
258    shard_indices_shape = []
259    for dim, sharding in enumerate(self.sharding):
260      axis_size = shape[dim]
261      if isinstance(sharding, NoSharding):
262        axis_indices.append([slice(None)])
263        # NOTE: We don't append unsharded dimensions to shard_indices_shape here,
264        #       because they do not appear in the mesh mapping.
265      elif isinstance(sharding, Unstacked):
266        assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
267        axis_indices.append(range(axis_size))
268        shard_indices_shape.append(axis_size)
269      elif isinstance(sharding, Chunked):
270        total_chunks = int(np.prod(sharding.chunks))
271        shard_size, ragged = divmod(axis_size, total_chunks)
272        assert not ragged, (axis_size, total_chunks, dim)
273        axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
274                             for i in range(total_chunks)])
275        shard_indices_shape.extend(sharding.chunks)
276      else:
277        assert_unreachable(sharding)
278
279    # shard_indices is an ndarray representing the sharded axes of the logical array,
280    # with each dimension having size equal to the number of shards across the corresponding
281    # logical array dimension, and each element containing the multi-dimensional index that
282    # is used to extract the corresponding shard of the logical array.
283    shard_indices = np.empty([prod(shard_indices_shape)], dtype=np.object)
284    for i, idxs in enumerate(it.product(*axis_indices)):
285      shard_indices[i] = idxs
286    shard_indices = shard_indices.reshape(shard_indices_shape)
287
288    # Ensure that each sharded axis is used exactly once in the mesh mapping
289    num_sharded_dim = len(shard_indices_shape)
290    sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
291    assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
292            len(sharded_dim_perm) == num_sharded_dim)
293    # Replicate/reorder the indices according to the mesh mapping
294    replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated))
295    replica_dim, sharded_dim = it.count(0), iter(sharded_dim_perm)
296    perm = [next(replica_dim) if isinstance(a, Replicated) else
297            len(replica_sizes) + next(sharded_dim)
298            for a in self.mesh_mapping]
299    return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
300              .transpose(perm))
301
302  def __eq__(self, other):
303    return (self.sharding, self.mesh_mapping) == (other.sharding, other.mesh_mapping)
304
305  def __hash__(self):
306    return hash((self.sharding, self.mesh_mapping))
307
308  def __repr__(self):
309    return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'
310
311def spec_to_indices(shape: Tuple[int, ...],
312                    spec: ShardingSpec) -> Tuple[Index, ...]:
313  """Returns numpy-style indices corresponding to a sharding spec.
314
315  Each index describes a shard of the array. The order of the indices is the
316  same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
317  row-major).
318
319  Args:
320    shape: The shape of the logical array being sharded.
321    spec: Describes how the array is sharded and how the shards are assigned to
322          the logical mesh.
323
324  Returns:
325    A tuple of length equal to the size of the mesh (inferred as the product of
326    sharded dimension sizes and all replication factors).  Each element is an
327    int, a slice object with step=1, or a tuple thereof, to be treated as an
328    index into the full logical array.
329  """
330  return tuple(spec.indices(shape).flat)
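
# For illustration (a rough sketch, not executed here): for a (4, 6) array with
#   spec = ShardingSpec(sharding=(Chunked((2,)), NoSharding()),
#                       mesh_mapping=(ShardedAxis(0),))
# we would get
#   spec_to_indices((4, 6), spec)
#   # ~> ((slice(0, 2), slice(None)), (slice(2, 4), slice(None)))
# i.e. device 0 holds the first two rows and device 1 holds the last two.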
331
332
333### util
334
335def identity(x): return x
336
337# TODO(skye): expose PyLocalBuffers in xla_client
338def shard_args(devices: Sequence[xb.xla_client.Device],
339               indices: Sequence[Sequence[Index]],
340               args) -> Sequence[Sequence[xb.xla_client._xla.PyLocalBuffer]]:
341  """Shard each argument data array along its leading axis.
342
343  Args:
344    devices: sequence of Devices mapping replica index to a physical device.
345    indices: sequence of the same length as `args` describing how each arg
346      should be sharded/replicated across `devices`. Each element in `indices`
347      is the same length as `devices`.
348    args: a sequence of JaxTypes representing arguments to be sharded according
349      to `indices` and placed on `devices`.
350
351  Returns:
352    A list of device buffers with the same length as `devices` indexed by
353    replica number, so that the nth element is the argument to be passed to the
354    nth replica.
355  """
356  nargs, nrep = len(args), len(devices)
357  buffers = [[None] * nargs for _ in range(nrep)]
358  for a, arg in enumerate(args):
    # The shard_arg_handlers allow an extensible set of types to be sharded,
    # but we handle ShardedDeviceArray inline here as a special case for
    # performance.
361    # NOTE: we compare indices instead of sharding_spec because
362    # pmap_benchmark.pmap_shard_args_benchmark indicates this is faster.
363    if type(arg) is ShardedDeviceArray and indices[a] == arg.indices:
364      for r, buf in enumerate(arg.device_buffers):
365        buffers[r][a] = (buf if buf.device() == devices[r]
366                         else buf.copy_to_device(devices[r]))
367    else:
368      arg = xla.canonicalize_dtype(arg)
369      bufs = shard_arg_handlers[type(arg)](arg, devices, indices[a])
370      for r, buf in enumerate(bufs):
371        buffers[r][a] = buf
372
373  return buffers
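
# For illustration (a sketch of the layout, not executed here): with two
# devices and three arguments, shard_args returns something like
#   [[arg0_buf_for_dev0, arg1_buf_for_dev0, arg2_buf_for_dev0],
#    [arg0_buf_for_dev1, arg1_buf_for_dev1, arg2_buf_for_dev1]]
# i.e. buffers[r][a] is the buffer for argument `a` destined for device `r`.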
374
375
376shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
377shard_arg_handlers[core.Unit] = \
378    lambda x, devices, _: device_put(core.unit, devices, replicate=True)
379def _shard_array(x, devices, indices):
380  return device_put([x[i] for i in indices], devices)
381for _t in array_types:
382  shard_arg_handlers[_t] = _shard_array
383
384def _shard_device_array(x, devices, indices):
385  start_indices, limit_indices, removed_dims = map(tuple, unzip3(
386      _as_slice_indices(x, idx) for idx in indices))
387  shards = x._multi_slice(start_indices, limit_indices, removed_dims)
388  return device_put(shards, devices)
389shard_arg_handlers[xla._DeviceArray] = _shard_device_array
390shard_arg_handlers[xla._CppDeviceArray] = _shard_device_array
391
392
393# NOTE(skye): we could refactor to generate _multi_slice parameters directly
394# from the input ShardingSpec, rather than the indices. However, this would
395# require duplicating the ordering logic of spec_to_indices, which is more
396# subtle and more likely to change than the index logic we have to support here.
397def _as_slice_indices(arr: xla.DeviceArrayProtocol, idx: Index) -> Tuple[
398    Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
399  """Returns start_indices, limit_indices, removed_dims"""
400  start_indices = [0] * arr.ndim
401  limit_indices = list(arr.shape)
402  removed_dims = []
403
404  tuple_idx = idx if isinstance(idx, tuple) else (idx,)
405  for dim, sub_idx in enumerate(tuple_idx):
406    if isinstance(sub_idx, int):
407      start_indices[dim] = sub_idx
408      limit_indices[dim] = sub_idx + 1
409      removed_dims.append(dim)
410    elif sub_idx == slice(None):
411      continue
412    else:
413      assert isinstance(sub_idx, slice), sub_idx
414      assert isinstance(sub_idx.start, int), sub_idx
415      assert isinstance(sub_idx.stop, int), sub_idx
416      start_indices[dim] = sub_idx.start
417      limit_indices[dim] = sub_idx.stop
418
419  return tuple(start_indices), tuple(limit_indices), tuple(removed_dims) # type: ignore
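
# For illustration (a rough sketch, not executed here): for an array of shape
# (4, 6), the index (2, slice(0, 3)) would translate to
#   start_indices  == (2, 0)
#   limit_indices  == (3, 3)
#   removed_dims   == (0,)
# so the resulting shard has shape (3,) once the integer-indexed dimension is
# dropped.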
420
421
422def shard_aval(size, axis: int, aval):
423  try:
424    return shard_aval_handlers[type(aval)](size, axis, aval)
425  except KeyError as err:
426    raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err
427shard_aval_handlers: Dict[Type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
428shard_aval_handlers[core.AbstractUnit] = lambda size, axis, x: x
429def _shard_abstract_array(size, axis: int, x):
430  try:
431    if x.shape[axis] != size:
432      raise ValueError(f"Axis size {size} does not match dimension {axis} of "
433                       f"shape {x.shape}")
434  except IndexError:
    raise ValueError(f"Cannot split a {x.ndim}D value along axis {axis}") from None
436  return ShapedArray(tuple_delete(x.shape, axis), x.dtype)
437shard_aval_handlers[ShapedArray] = _shard_abstract_array
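
# For illustration (a rough sketch, not executed here):
#   shard_aval(8, 0, ShapedArray((8, 3), np.float32))
#   # ~> ShapedArray((3,), np.float32)
# i.e. the mapped axis is checked against `size` and then dropped from the
# shape, since each shard sees only its own slice of that axis.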
438
439# TODO(skye): expose PyLocalBuffers in xla_client
440def aval_to_result_handler(sharding_spec: Optional[ShardingSpec],
441                           indices: Optional[Tuple[Index]],
442                           aval: core.AbstractValue) -> Callable[
443                               [List[xb.xla_client._xla.PyLocalBuffer]], Any]:
444  """Returns a function for handling the raw buffers of a single output aval.
445
446  Args:
447    sharding_spec: indicates how the output is sharded across devices, or None
448      for non-array avals.
449    indices: the pre-computed result of spec_to_indices, or None for non-array
450      avals.
451    aval: the output AbstractValue.
452
453  Returns:
454    A function for handling the PyLocalBuffers that will eventually be produced
455    for this output. The function will return an object suitable for returning
456    to the user, e.g. a ShardedDeviceArray.
457  """
458  try:
459    return pxla_result_handlers[type(aval)](sharding_spec, indices, aval)
460  except KeyError as err:
461    raise TypeError("No pxla_result_handler for type: {}".format(type(aval))
462                    ) from err
463
464PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client._xla.PyLocalBuffer]], Any]]
465pxla_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
466pxla_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
467def array_result_handler(sharding_spec, indices, aval: ShapedArray):
468  return lambda bufs: ShardedDeviceArray(aval, sharding_spec, bufs, indices)
469pxla_result_handlers[ShapedArray] = array_result_handler
470pxla_result_handlers[ConcreteArray] = array_result_handler
471
472
473### lazy device-memory persistence and result handling
474
475class ShardedDeviceArray(xla._DeviceArray):
476  """A ShardedDeviceArray is an ndarray sharded across devices.
477
478  The purpose of a ShardedDeviceArray is to reduce the number of transfers when
479  executing replicated computations, by allowing results to persist on the
480  devices that produced them. That way dispatching a similarly replicated
481  computation that consumes the same sharded memory layout does not incur any
482  transfers.
483
484  A ShardedDeviceArray represents one logical ndarray value, and simulates the
485  behavior of an ndarray so that it can be treated by user code as an ndarray;
486  that is, it is only an optimization to reduce transfers.
487
488  Attributes:
489    aval: A ShapedArray indicating the shape and dtype of this array.
490    sharding_spec: describes how this array is sharded across `device_buffers`.
491    device_buffers: the buffers containing the data for this array. Each buffer
492      is the same shape and on a different device. Buffers are in row-major
493      order, with replication treated as an extra innermost dimension.
    indices: the result of `spec_to_indices(aval.shape, sharding_spec)`. Can
      optionally be precomputed for efficiency. A tuple of the same length as
      `device_buffers`. Each index indicates what portion of the full array is
      stored in the corresponding device buffer, i.e. `array[indices[i]] ==
      device_buffers[i].to_py()`.
499  """
500  __slots__ = ["device_buffers", "sharding_spec", "indices",
501               "_one_replica_buffer_indices"]
502
503  # TODO(skye): expose PyLocalBuffers in xla_client
504  def __init__(self,
505               aval: ShapedArray,
506               sharding_spec, # TODO(skye): add type annotation back, see below
507               device_buffers: List[xb.xla_client._xla.PyLocalBuffer] = None,
508               indices: Optional[Tuple[Index, ...]] = None):
509    xla.DeviceArray.__init__(self)
510
511    # TODO(skye): this is temporary staging while we switch users over to
512    # providing sharding_spec. It assumes that any pre-existing callers are
513    # creating pmap-style ShardedDeviceArrays over the first dimension.
514    if device_buffers is None:
515      device_buffers = sharding_spec
516      sharded_aval = ShapedArray(aval.shape[1:], aval.dtype)
517      sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0],
518                                          1, None, sharded_aval, 0)
519
520    # TODO(skye): assert invariants. Keep performance in mind though.
521    if indices is None:
522      indices = spec_to_indices(aval.shape, sharding_spec)
523    self.aval = aval
524    self.device_buffers = device_buffers
525    self.sharding_spec = sharding_spec
526    self.indices = indices
527    self._npy_value = None
528    self._one_replica_buffer_indices = None
529    if not core.skip_checks:
530      assert type(aval) is ShapedArray
531
532  @property
533  def one_replica_buffer_indices(self):
534    """Indices of buffers containing one complete copy of the array data."""
535    if self._one_replica_buffer_indices is None:
536      one_replica_indices = []
537      seen_index_hashes = set()
538      for i, index in enumerate(self.indices):
539        hashed_index = _hashable_index(index)
540        if hashed_index not in seen_index_hashes:
541          one_replica_indices.append(i)
542          seen_index_hashes.add(hashed_index)
543      self._one_replica_buffer_indices = one_replica_indices
544    return self._one_replica_buffer_indices
545
546  def copy_to_host_async(self):
547    for buffer_index in self.one_replica_buffer_indices:
548      self.device_buffers[buffer_index].copy_to_host_async()
549
550  def delete(self):
551    for buf in self.device_buffers:
552      buf.delete()
553    self.device_buffers = None
554    self._npy_value = None
555
556  def _check_if_deleted(self):
557    if self.device_buffers is None:
558      raise ValueError("ShardedDeviceArray has been deleted.")
559
560  def block_until_ready(self):
561    self._check_if_deleted()
562    for buf in self.device_buffers:
563      buf.block_host_until_ready()
564    return self
565
566  @property
567  def _value(self):
568    if self._npy_value is None:
569      self.copy_to_host_async()
570      npy_value = np.empty(self.aval.shape, self.aval.dtype)
571      for i in self.one_replica_buffer_indices:
572        npy_value[self.indices[i]] = self.device_buffers[i].to_py()
573      self._npy_value = npy_value
574    return self._npy_value
575
576  def __getitem__(self, idx):
577    if not isinstance(idx, tuple):
578      cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
579    else:
580      cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
581    if self._npy_value is None:
582      try:
583        buf_idx = self.indices.index(cidx)
584      except ValueError:
585        buf_idx = None
586      if buf_idx is not None:
587        buf = self.device_buffers[buf_idx]
588        # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58
589        # is the default.
590        aval = ShapedArray(
591            getattr(buf, "xla_shape", buf.shape)().dimensions(),
592            self.aval.dtype)
593        return xla.make_device_array(aval, None, lazy.array(aval.shape), buf)
594    return super(ShardedDeviceArray, self).__getitem__(idx)
595
596
597def _hashable_index(idx):
598  return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x,
599                  idx)
600
601# The fast path is handled directly in shard_args().
602# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
603def _shard_sharded_device_array_slow_path(x, devices, indices):
604  candidates = defaultdict(list)
605  for buf, idx in safe_zip(x.device_buffers, x.indices):
606    candidates[_hashable_index(idx)].append(buf)
607
608  bufs = []
609  for idx, device in safe_zip(indices, devices):
610    # Look up all buffers that contain the correct slice of the logical array.
611    candidates_list = candidates[_hashable_index(idx)]
612    if not candidates_list:
613      # This array isn't sharded correctly. Reshard it via host roundtrip.
614      # TODO(skye): more efficient reshard?
615      return shard_arg_handlers[type(x._value)](x._value, devices, indices)
616    # Try to find a candidate buffer already on the correct device,
617    # otherwise copy one of them.
618    for buf in candidates_list:
619      if buf.device() == device:
620        bufs.append(buf)
621        break
622    else:
623      bufs.append(buf.copy_to_device(device))
624  return bufs
625shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array_slow_path
626
627def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
628  return xb.constant(c, np.asarray(val), canonicalize_types=canonicalize_types)
629xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler)
630
631core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
632xla.device_put_handlers[ShardedDeviceArray] = xla._device_put_array
633xla.pytype_aval_mappings[ShardedDeviceArray] = op.attrgetter('aval')
634xla.canonicalize_dtype_handlers[ShardedDeviceArray] = identity
635
636
637### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
638
639def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size,
640                  global_axis_size, devices, name, in_axes, out_axes_thunk,
641                  donated_invars, global_arg_shapes):
642  abstract_args = unsafe_map(xla.abstractify, args)
643  compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
644                                   global_axis_size, devices, name,
645                                   in_axes, out_axes_thunk,
646                                   donated_invars, global_arg_shapes,
647                                   *abstract_args)
648  return compiled_fun(*args)
649
650@lu.cache
651def parallel_callable(fun: lu.WrappedFun,
652                      backend_name: Optional[str],
653                      axis_name,
654                      axis_size: int,
655                      global_axis_size: Optional[int],
656                      devices: Optional[Sequence[Any]],
657                      name: str,
658                      in_axes: Iterable[Optional[int]],
659                      out_axes_thunk: Callable[[], Sequence[Optional[int]]],
660                      donated_invars: Iterable[bool],
661                      global_arg_shapes,
662                      *avals):
663  if devices is not None and len(devices) == 0:
664    raise ValueError("'devices' argument to pmap must be non-empty, or None.")
665
666  # Determine global_axis_size for use in AxisEnv.
667  # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
668  # if xb.host_count() > 1 and global_axis_size is None and inner_pmap:
669  #   raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
670  if (xb.host_count() == 1 and global_axis_size is not None and
671      global_axis_size != axis_size):
672    raise ValueError(
673        f"Specified axis_size {global_axis_size} doesn't match received "
674        f"axis_size {axis_size}.")
675
676  must_run_on_all_devices = False
677  no_nested_sharding = False
678  if global_axis_size is None:
679    if xb.host_count() == 1:
680      global_axis_size = axis_size
681    elif devices:
682      # This allows each host in a multi-host pmap to run on a different number
683      # of devices, but precludes nested sharding (i.e. inner pmaps or
684      # sharded_jits).
685      global_axis_size = len(devices)
686      no_nested_sharding = True
687    else:
688      # This assumes all hosts run on the same number of devices. We make sure
689      # this assumption is true by requiring that the pmap is run on all devices
690      # (and making the further assumption that each host has the same number of
691      # devices). Nested sharding is ok in this case.
692      global_axis_size = axis_size * xb.host_count()
693      assert all(len(xb.local_devices(host_id)) == xb.local_device_count()
694                 for host_id in xb.host_ids())
695      must_run_on_all_devices = True
696
697  if devices:
698    local_devices = [d for d in devices if d.host_id == xb.host_id()]
699    assert len(local_devices) > 0
700  else:
701    local_devices = None  # type: ignore
702
703  if config.omnistaging_enabled:
704    sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
705                          for axis, aval in safe_zip(in_axes, avals))
706    if any(s is not None for s in global_arg_shapes):
707      # TODO(skye): we could take this branch unconditionally if we handled
708      # grad of global_arg_shapes correctly.
709      global_sharded_avals = [
710          ShapedArray(shape, aval.dtype) if shape is not None else aval
711          for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
712    else:
713      global_sharded_avals = sharded_avals  # type: ignore
714    logging.vlog(2, "sharded_avals: %s", sharded_avals)
715    logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
716
717    with core.extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
718      jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals)
719    jaxpr = xla.apply_outfeed_rewriter(jaxpr)
720  else:
721    @lu.wrap_init
722    def dynamic_fun(dummy, *args):
723      with extend_dynamic_axis_env(axis_name, dummy._trace, global_axis_size):  # type: ignore
724        return fun.call_wrapped(*args)
725
726    sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
727                          for axis, aval in safe_zip(in_axes, avals))
728    pvals = [pe.PartialVal.unknown(aval) for aval in sharded_avals]
    # We add a dummy first invar to carry the trace details to `dynamic_fun`.
730    pval = pe.PartialVal.unknown(core.abstract_unit)  # dummy value for axis env
731    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(  # type: ignore
732      dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True)  # type: ignore
733    jaxpr.invars = jaxpr.invars[1:]  # ignore dummy
734    jaxpr = xla.apply_outfeed_rewriter(jaxpr)
735
736    out_pvs, out_consts = unzip2(out_pvals)
737    global_sharded_avals = sharded_avals  # type: ignore
738
739  out_axes = out_axes_thunk()
740  if config.omnistaging_enabled:
741    assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes))
742  else:
743    assert len(out_pvals) == len(out_axes), (len(out_pvals), len(out_axes))
744    assert all(out_axis == 0 for out_axis in out_axes)
745
746  # TODO(skye,mattjj): allow more collectives on multi-host as we test them, but
747  # for now raise an error
748  if devices is not None:
749    is_multi_host_pmap = len(local_devices) != len(devices)
750  else:
751    is_multi_host_pmap = xb.host_count() > 1
752  if is_multi_host_pmap:
753    check_multihost_collective_allowlist(jaxpr)
754
755  if not config.omnistaging_enabled:
756    if all(pv is None for pv in out_pvs):
757      # When the output doesn't depend on the input we don't need to compile an
758      # XLA computation at all; we handle this as a special case so we can stage
759      # out multi-replica XLA computations regardless of the hardware available.
760      # The 'None' values here are just dummies we know will be ignored.
761      handlers = [
762        _pval_to_result_handler(  # type: ignore
763          axis_size, None, None, None, pval, local_devices, backend_name)  # type: ignore
764        for pval in out_pvals  # type: ignore
765      ]
766      results = [handler(None) for handler in handlers]
767      return lambda *_: results
768
769
770  # TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
771  jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
772  num_local_replicas = axis_size * jaxpr_replicas
773  num_global_replicas = global_axis_size * jaxpr_replicas
774
775  (arg_parts, out_parts, num_partitions, local_arg_parts, local_out_parts,
776   local_num_partitions) = _find_partitions(jaxpr)
777
778  if local_num_partitions is None:
779    local_num_partitions = num_partitions
780
781  if local_arg_parts is None:
782    local_arg_parts = arg_parts
783  if local_out_parts is None:
784    local_out_parts = out_parts
785
786  logging.vlog(2, "num_replicas: %d  num_local_replicas: %d",
787               num_global_replicas, num_local_replicas)
788  logging.vlog(2, "num_partitions: %d  local_num_partitions: %d",
789               num_partitions, local_num_partitions)
790  logging.vlog(2, "arg_parts: %s", arg_parts)
791  logging.vlog(2, "local_arg_parts: %s", local_arg_parts)
792  logging.vlog(2, "out_parts: %s", out_parts)
793  logging.vlog(2, "local_out_parts: %s", local_out_parts)
794  logging.vlog(2, "devices: %s", devices)
795  logging.vlog(2, "local_devices: %s", local_devices)
796
797  num_local_shards = num_local_replicas * local_num_partitions
798  num_global_shards = num_global_replicas * num_partitions
799
800  if (xb.host_count() > 1 and must_run_on_all_devices and
801      num_local_shards != xb.local_device_count()):
802    if num_local_shards == axis_size:
803      raise ValueError(
804         f"On multi-host platforms, the input to pmapped functions must have "
805         f"leading axis size equal to the number of local devices if no "
806         f"`devices` argument is specified. Got axis_size={axis_size}, "
807         f"num_local_devices={xb.local_device_count()}")
808    else:
809      raise ValueError(
810        f"On multi-host platforms, pmapped functions must run across all "
811        f"devices, i.e. num_replicas * num_partitions should equal the "
812        f"number of local devices. Got num_replicas={num_local_replicas}, "
813        f"num_partitions={num_partitions}, and "
814        f"num_local_devices={xb.local_device_count()}")
815
816  if no_nested_sharding and (jaxpr_replicas > 1 or num_partitions > 1):
817    raise ValueError(
818      f"On multi-host platforms, pmapped functions that both have `devices` "
819      f"specified and contain an inner_pmap or sharded_jit must specify an "
820      f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
821      f"{jaxpr_replicas} and nested_partitions={num_partitions}")
822
823  log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
824  logging.log(log_priority,
825              f"Compiling {fun.__name__} for {num_global_shards} devices with "
826              f"args {avals}. (num_replicas={num_global_replicas} "
827              f"num_partitions={num_partitions})")
828
829  axis_env = xla.AxisEnv(num_global_replicas, (axis_name,), (global_axis_size,))
830
831  tuple_args = len(global_sharded_avals) > 100  # pass long arg lists as tuple for TPU
832
833  c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
834  xla_consts = map(partial(xb.constant, c), consts)
835  replicated_args = [axis is None for axis in in_axes]
836  xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args,
837                                                    replicated=replicated_args,
838                                                    partitions=arg_parts,
839                                                    donated_invars=donated_invars)
840  with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
841    out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend_name, axis_env, xla_consts,
842                                  extend_name_stack(wrap_name(name, 'pmap')), *xla_args)
843  build_out_tuple = partial(xops.Tuple, c, out_nodes)
844  if out_parts is not None:
845    out_tuple = xb.with_sharding(c, out_parts, build_out_tuple)
846  else:
847    out_tuple = build_out_tuple()
848  backend = xb.get_backend(backend_name)
849  if backend.platform in ("gpu", "tpu"):
850    donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
851  built = c.Build(out_tuple)
852
853  if devices is None:
854    if num_global_shards > xb.device_count(backend):
855      msg = ("compiling computation that requires {} logical devices, but only {} XLA "
856             "devices are available (num_replicas={}, num_partitions={})")
857      raise ValueError(msg.format(num_global_shards, xb.device_count(backend),
858                                  num_global_replicas, num_partitions))
859
860    # On a single host, we use the platform's default device assignment to
861    # potentially take advantage of device locality. On multiple hosts, the
862    # default device assignment may interleave different hosts' replicas,
863    # violating pmap's semantics where data is sharded across replicas in
864    # row-major order. Instead, manually create a device assignment that ensures
    # each host is responsible for a contiguous set of replicas.
866    if num_global_shards > num_local_shards:
867      # TODO(skye): use a locality-aware assignment that satisfies the above
868      # constraint.
869      devices = [d for host_id in xb.host_ids()
870                 for d in xb.local_devices(host_id)]
871    else:
872      devices = xb.get_backend(backend).get_default_device_assignment(
873          num_global_replicas, num_partitions)
874  else:
875    if num_local_shards != len(local_devices):
876      local_devices_str = ", ".join(map(str, local_devices))
877      raise ValueError(
878          "Leading axis size of input to pmapped function must equal the "
879          "number of local devices passed to pmap. Got axis_size=%d, "
880          "num_local_devices=%d.\n(Local devices passed to pmap: %s)"
881          % (axis_size, len(local_devices), local_devices_str))
882    if num_global_shards != len(devices):
883      raise ValueError("compiling computation that creates %s shards, "
884                       "but %s devices were specified" %
885                       (num_global_shards, len(devices)))
886
887  # 'devices' may be 1D or 2D at this point (e.g.
888  # get_default_device_assignment() returns 2D assignment, caller may have
889  # provided 1D list of devices).
890  device_assignment = tree_map(lambda d: d.id, devices)
891  # Convert to 2D in case it's 1D and we have > 1 partitions.
892  device_assignment = np.array(device_assignment).reshape(
893      (num_global_replicas, num_partitions))
894  # TODO(b/162356737): Enabling SPMD partitioning causes issues with some
895  # non-partitioned workloads, so disable unless needed.
896  use_spmd_partitioning = num_partitions > 1
897  compile_options = xb.get_compile_options(
898      num_replicas=num_global_replicas,
899      num_partitions=num_partitions,
900      device_assignment=device_assignment,
901      use_spmd_partitioning=use_spmd_partitioning,
902  )
903  compile_options.parameter_is_tupled_arguments = tuple_args
904  compiled = xla.backend_compile(backend, built, compile_options)
905
906  local_arg_parts_ = local_arg_parts or [None] * len(avals)
907  input_sharding_specs = [
908      _pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions,
909                          parts, aval, in_axis)
910      if aval is not core.abstract_unit else None
911      for aval, parts, in_axis in safe_zip(sharded_avals, local_arg_parts_, in_axes)]
912  input_indices = [spec_to_indices(aval.shape, spec)
913                   if spec is not None else None
914                   for aval, spec in safe_zip(avals, input_sharding_specs)]
915  handle_args = partial(shard_args, compiled.local_devices(), input_indices)
916  if config.omnistaging_enabled:
917    nouts = len(out_sharded_avals)
918    if out_parts is None:
919      out_parts = (None,) * nouts
920    if local_out_parts is None:
921      local_out_parts = (None,) * nouts
922
923    local_out_avals = [get_local_aval(aval, parts, lparts)
924                       for aval, parts, lparts
925                       in safe_zip(out_sharded_avals, out_parts, local_out_parts)]
926    local_unmapped_avals = [core.unmapped_aval(axis_size, out_axis, aval)
927                            if out_axis is not None else aval
928                            for aval, out_axis in safe_zip(local_out_avals, out_axes)]
929
930    out_specs = [_pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions,
931                                     parts, aval, out_axis)
932                if aval is not core.abstract_unit else None
933                for parts, aval, out_axis in safe_zip(local_out_parts, local_out_avals, out_axes)]
934    handle_outs = avals_to_results_handler(
935        num_local_replicas, local_num_partitions, out_specs, local_unmapped_avals)
936  else:
937    handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas,  # type: ignore
938                                            local_num_partitions,
939                                            local_out_parts, out_pvals,
940                                            compiled.local_devices(), backend)
941
942  return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
943
944multi_host_supported_collectives: Set[core.Primitive] = set()
945
946
947def check_multihost_collective_allowlist(jaxpr):
948  used_collectives = set(xla.jaxpr_collectives(jaxpr))
949  if not used_collectives.issubset(multi_host_supported_collectives):
950    bad_collectives = used_collectives - multi_host_supported_collectives
951    msg = "using collectives that aren't supported for multi-host: {}"
952    raise TypeError(msg.format(", ".join(map(str, bad_collectives))))
953
954
955PartitionsOrReplicated = Optional[Tuple[int, ...]]
956
957def _find_partitions(jaxpr) -> Tuple[
958    Optional[Tuple[PartitionsOrReplicated, ...]],
959    Optional[Tuple[PartitionsOrReplicated, ...]],
960    int,
961    Optional[Tuple[PartitionsOrReplicated, ...]],
962    Optional[Tuple[PartitionsOrReplicated, ...]],
963    Optional[int]]:
964  """Returns (in_partitions, out_partitions, num_partitions, local_in_parts,
965              local_out_parts, local_num_partitions).
966  """
967  for eqn in jaxpr.eqns:
968    if eqn.primitive.name == "sharded_call":
969      if len(jaxpr.eqns) > 1:
970        raise NotImplementedError(
971            "pmap of sharded_jit + non-sharded operations not yet implemented.")
972      num_partitions = reconcile_num_partitions(eqn.params["call_jaxpr"],
973                                                eqn.params["nparts"])
974      return (eqn.params["in_parts"],
975              eqn.params["out_parts_thunk"](),
976              num_partitions,
977              eqn.params["local_in_parts"],
978              eqn.params["local_out_parts_thunk"](),
979              eqn.params["local_nparts"])
980  return None, None, 1, None, None, None
981
982def reconcile_num_partitions(jaxpr, outer_num_parts: Optional[int]):
983  """Returns the total number of partitions to use.
984
985  Validates that any inner partitioning matches outer_num_parts if provided, and
986  returns the number of partitions to use based on outer_num_parts and any inner
987  partitioning.
988  """
989  inner_num_parts = _inner_partitions(jaxpr, outer_num_parts)
990  if outer_num_parts is None and inner_num_parts is None:
991    # No partitions specified anywhere, everything is replicated.
992    return 1
993  if outer_num_parts is None:
994    return inner_num_parts
995  return outer_num_parts
996
997
998def _inner_partitions(jaxpr, expected_num_parts: Optional[int]):
999  """Returns the total number of partitions from PartitionSpecs inside `jaxpr`.
1000
1001  Also validates that this number matches `expected_num_parts` if provided.
1002  """
1003  for eqn in jaxpr.eqns:
1004    if eqn.primitive.name in ["sharding_constraint", "infeed"]:
1005      parts = eqn.params["partitions"]
1006      nparts = get_num_partitions(parts)
1007      if expected_num_parts is None:
1008        expected_num_parts = nparts
1009      elif nparts is not None and nparts != expected_num_parts:
1010        # TODO(skye): raise this error as we trace the jaxpr
1011        raise ValueError(
1012            f"with_sharding_constraint with partitions={parts} "
1013            f"(total partitions: {nparts}) doesn't match expected number of "
1014            f"partitions: {expected_num_parts}. If these partitions look "
1015            f"right, check outer sharded_jit and/or other "
1016            f"with_sharding_constraint calls.")
1017    else:
1018      for subjaxpr in core.jaxprs_in_params(eqn.params):
1019        expected_num_parts = _inner_partitions(subjaxpr, expected_num_parts)
1020  return expected_num_parts
1021
1022
1023def get_num_partitions(*partitions):
1024  partition_specs = tree_flatten(partitions)[0]
1025  if len(partition_specs) == 0:
1026    # Everything is specified as replicated (all Nones).
1027    return None
1028  num_partitions_set = {np.prod(spec) for spec in partition_specs}
1029  if len(num_partitions_set) > 1:
1030    raise ValueError(
1031        f"All partition specs must use the same number of total partitions, "
1032        f"got {partitions}, with distinct number of partitions "
1033        f"{num_partitions_set} (the total number of partitions is the product "
1034        f"of a partition spec)")
1035  assert len(num_partitions_set) == 1
1036  return num_partitions_set.pop()
1037
1038
1039def get_global_aval(local_aval, global_parts: PartitionsOrReplicated,
1040                    local_parts: PartitionsOrReplicated):
1041  if local_aval is core.abstract_unit:
1042    return local_aval
1043  if global_parts is None:
1044    return local_aval
1045  assert local_parts is not None
1046  global_shape = [dim * _safe_div(ngparts, nlparts)
1047                  for dim, ngparts, nlparts
1048                  in safe_zip(local_aval.shape, global_parts, local_parts)]
1049  return ShapedArray(global_shape, local_aval.dtype)
1050
1051
1052def get_local_aval(global_aval, global_parts: PartitionsOrReplicated,
1053                   local_parts: PartitionsOrReplicated):
1054  if global_aval is core.abstract_unit:
1055    return global_aval
1056  if global_parts is None:
1057    return global_aval
1058  assert local_parts is not None
1059  local_shape = [_safe_div(dim, _safe_div(ngparts, nlparts))
1060                 for dim, ngparts, nlparts
1061                 in safe_zip(global_aval.shape, global_parts, local_parts)]
1062  return ShapedArray(local_shape, global_aval.dtype)
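
# For illustration (a rough sketch, not executed here): with global_parts
# (4, 1) and local_parts (2, 1), each host sees half of the partitions along
# the first axis, so
#   get_global_aval(ShapedArray((8, 3), np.float32), (4, 1), (2, 1))
#   # ~> ShapedArray((16, 3), np.float32)
#   get_local_aval(ShapedArray((16, 3), np.float32), (4, 1), (2, 1))
#   # ~> ShapedArray((8, 3), np.float32)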
1063
1064
1065def _safe_div(x, y):
1066  result, ragged = divmod(x, y)
1067  assert not ragged, f"{x} % {y} != 0"
1068  return result
1069
1070
1071class ResultToPopulate: pass
1072result_to_populate = ResultToPopulate()
1073
1074class ResultsHandler:
1075  __slots__ = ("nrep", "npart", "nouts", "out_specs", "out_indices", "handlers",
1076               "unmapped_local_out_avals")
1077
1078  def __init__(self, nrep, npart, nouts, out_specs, out_indices, handlers,
1079               unmapped_local_out_avals):
1080    self.nrep = nrep
1081    self.npart = npart
1082    self.nouts = nouts
1083    self.out_specs = out_specs
1084    self.out_indices = out_indices
1085    self.handlers = handlers
1086    self.unmapped_local_out_avals = unmapped_local_out_avals
1087
1088  def __call__(self, out_bufs):
1089    assert self.nrep * self.npart == len(out_bufs)
1090    buffers = [[result_to_populate] * (self.nrep * self.npart)
1091               for _ in range(self.nouts)]
1092    for r, tuple_buf in enumerate(out_bufs):
1093      for i, buf in enumerate(tuple_buf):
1094        buffers[i][r] = buf
1095    assert not any(
1096        buf is result_to_populate for bufs in buffers for buf in bufs)
1097    return [h(bufs) for h, bufs in safe_zip(self.handlers, buffers)]
1098
1099def avals_to_results_handler(nrep, npart, out_specs, unmapped_local_out_avals):
1100  nouts = len(unmapped_local_out_avals)
1101  out_indices = [spec_to_indices(aval.shape, spec)
1102                 if aval is not core.abstract_unit else None
1103                 for aval, spec in safe_zip(unmapped_local_out_avals, out_specs)]  # pytype: disable=attribute-error
1104  handlers = [aval_to_result_handler(spec, idcs, aval)
1105              for spec, idcs, aval in safe_zip(out_specs, out_indices, unmapped_local_out_avals)]
1106
1107  return ResultsHandler(nrep, npart, nouts, out_specs, out_indices, handlers,
1108                        unmapped_local_out_avals)
1109
1110def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
1111  """Replicates ``val`` across multiple devices.
1112
1113  Args:
1114    val: the value to be replicated.
1115    axis_size: the length of the output, i.e. the logical number of replicas to
      create. Usually equal to `nrep`, but in the case of nested pmaps, `nrep`
      may be a multiple of `axis_size`.
1118    nrep: the number of replicas to create. If ``devices`` is set, must be equal
1119      to ``len(devices)``.
1120    devices: the devices to replicate across. If None, ``nrep`` will be used to
1121      generate a default device assignment.
1122    backend: string specifying which backend to use.
    in_axis: axis along which the value is to be replicated.
1124
1125  Returns:
1126    A ShardedDeviceArray of length `axis_size` where each shard is equal to
1127    ``val``.
1128  """
1129  device_count = (len(devices) if devices else xb.local_device_count(backend))
1130  if nrep > device_count:
1131    msg = ("Cannot replicate across %d replicas because only %d local devices "
1132           "are available." % (nrep, device_count))
1133    if devices:
1134      msg += (" (local devices = %s)"
1135              % ", ".join(map(str, devices)) if devices else str(None))
1136    raise ValueError(msg)
1137
1138  if devices is None:
1139    assert nrep is not None
1140    # TODO(skye): use different device assignment on multihost
1141    devices = xb.get_backend(backend).get_default_device_assignment(nrep)
1142  assert nrep == len(devices)
1143
1144  aval = xla.abstractify(val)  # type: ShapedArray
1145  replicated_aval = ShapedArray((axis_size,) + aval.shape, aval.dtype)
1146  # TODO(skye): figure out how partitioning should work here
1147  sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, in_axis)
1148  device_buffers = device_put(val, devices, replicate=True)
1149  return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers)
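
# For illustration (a rough sketch, not executed here): on a host with 4 local
# devices,
#   replicate(np.arange(3), axis_size=4, nrep=4)
# would produce a ShardedDeviceArray of shape (4, 3) whose four shards (one per
# device) each hold a copy of [0, 1, 2].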
1150
1151def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, map_axis: Optional[int]):
  """Sharding spec for arguments or results of a pmap.

1153  Args:
1154    nrep: number of local XLA replicas (product of local axis sizes)
1155    axis_size: local axis size for outer pmap
1156    npart: total number of XLA partitions (required by sharded_jit calls)
1157    parts: the partitioning of the value or None
1158    sharded_aval: the aval of the value inside the outer pmap, an instance of
1159      a ShapedArray.
1160    map_axis: the axis along which the value is mapped in the outer pmap
1161  Returns:
1162    A ShardingSpec.
1163  """
1164  assert isinstance(sharded_aval, ShapedArray), sharded_aval
1165  replication_factor, ragged = divmod(nrep, axis_size)
1166  assert not ragged
1167  # get the sharding spec from inner sharded_jits as if we weren't in a pmap
1168  pspec = partitioned_sharding_spec(npart, parts, sharded_aval)
1169  maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
1170  if map_axis is not None:
1171    sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
1172    def shift_sharded_axis(a: MeshDimAssignment):
1173      if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
1174        return ShardedAxis(a.axis + 1)
1175      return a
1176    # replication_factor represents the product of inner pmaps, so it goes
1177    # after the outer pmapped axis at index 0
1178    return ShardingSpec(
1179      sharding=tuple_insert(pspec.sharding, map_axis, Unstacked(axis_size)),
1180      mesh_mapping=it.chain([ShardedAxis(sharded_in_axis)],
1181                            maybe_replicate,
1182                            map(shift_sharded_axis, pspec.mesh_mapping)))
1183  else:
1184    return ShardingSpec(
1185      sharding=pspec.sharding,
1186      mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
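
# For illustration (a rough sketch, not executed here): for an outer pmap with
# axis_size=4 when there are 8 local replicas in total (e.g. due to an inner
# pmap of size 2), no partitions, and a per-device aval of shape (3,) mapped
# along axis 0:
#   _pmap_sharding_spec(8, 4, 1, None, ShapedArray((3,), np.float32), 0)
#   # ~> ShardingSpec(sharding=(Unstacked(4), NoSharding()),
#   #                 mesh_mapping=(ShardedAxis(0), Replicated(2)))
# which matches the row-major-with-inner-replication layout described at the
# top of this file.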
1187
1188def partitioned_sharding_spec(num_partitions: int,
1189                              partitions: Optional[Sequence[int]],
1190                              aval) -> ShardingSpec:
1191  if partitions is None:
1192    maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),)
1193    return ShardingSpec(
1194        sharding=[_UNSHARDED_INSTANCE] * len(aval.shape),
1195        mesh_mapping=maybe_replicate)
1196  else:
1197    assert len(partitions) == len(aval.shape)
1198    return ShardingSpec(sharding=map(Chunked, partitions),
1199                        mesh_mapping=map(ShardedAxis, range(len(partitions))))
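
# For illustration (a rough sketch, not executed here):
#   partitioned_sharding_spec(2, (2, 1), ShapedArray((8, 3), np.float32))
#   # ~> ShardingSpec(sharding=(Chunked((2,)), Chunked((1,))),
#   #                 mesh_mapping=(ShardedAxis(0), ShardedAxis(1)))
# while partitioned_sharding_spec(2, None, aval) marks every dimension as
# unsharded and simply replicates over the 2 partitions.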
1200
1201
1202def execute_replicated(compiled, backend, in_handler, out_handler, *args):
1203  input_bufs = in_handler(args)
1204  out_bufs = compiled.execute_on_local_devices(list(input_bufs))
1205  return out_handler(out_bufs)
1206
1207
1208xla_pmap_p = core.MapPrimitive('xla_pmap')
1209xla_pmap = xla_pmap_p.bind
1210xla_pmap_p.def_impl(xla_pmap_impl)
1211
1212# Set param update handlers to update `donated_invars` just like xla_call_p
1213pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
1214ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
1215ad.call_transpose_param_updaters[xla_pmap_p] = \
1216    ad.call_transpose_param_updaters[xla.xla_call_p]
1217
1218def _pmap_translation_rule(c, axis_env,
1219                           in_nodes, name_stack, axis_name, axis_size,
1220                           global_axis_size, devices, name,
1221                           call_jaxpr, *, backend=None, in_axes, out_axes,
1222                           donated_invars, global_arg_shapes):
1223  del donated_invars  # Unused.
1224  # We in-line here rather than generating a Call HLO as in the xla_call
1225  # translation rule just because the extra tuple stuff is a pain.
1226  if axis_env.names and devices is not None:
1227    raise ValueError("Nested pmap with explicit devices argument.")
1228  if global_axis_size is None:
1229    global_axis_size = axis_size
1230  new_env = xla.extend_axis_env(axis_env, axis_name, global_axis_size)
1231  # Shard the in_nodes that are mapped
1232  in_avals = [v.aval for v in call_jaxpr.invars]
1233  in_nodes_sharded = (
1234    _xla_shard(c, aval, new_env, in_node, in_axis) if in_axis is not None else in_node
1235    for aval, in_node, in_axis in safe_zip(in_avals, in_nodes, in_axes))
1236
1237  with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
1238    sharded_outs = xla.jaxpr_subcomp(
1239        c, call_jaxpr, backend, new_env, (),
1240        extend_name_stack(name_stack, wrap_name(name, 'pmap')), *in_nodes_sharded)
1241  out_avals = [v.aval for v in call_jaxpr.outvars]
1242  outs = [_xla_unshard(c, aval, new_env, out_axis, shard, backend=backend)
1243          for aval, out_axis, shard in safe_zip(out_avals, out_axes, sharded_outs)]
1244  return xops.Tuple(c, outs)
1245
1246xla.call_translations[xla_pmap_p] = _pmap_translation_rule
1247ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
1248
1249def _xla_shard(c, aval, axis_env, x, in_axis):
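  """Slices out this replica's chunk of x along in_axis.

  The position along the mapped axis is computed from the replica id (see
  _unravel_index); a size-1 slice is taken at that position and the sliced
  axis is then reshaped away.
  """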
1250  if aval is core.abstract_unit:
1251    return x
1252  elif aval is core.abstract_token:
1253    return x
1254  elif isinstance(aval, ShapedArray):
1255    dims = list(c.get_shape(x).dimensions())
1256    zero = xb.constant(c, np.zeros((), dtype=np.uint32))
1257    idxs = [zero] * (len(dims) - 1)
1258    idxs.insert(in_axis, _unravel_index(c, axis_env))
1259    dims_unsqueezed = dims.copy()
1260    dims_unsqueezed[in_axis] = 1
1261    dims_squeezed = dims.copy()
1262    dims_squeezed.pop(in_axis)
1263    return xops.Reshape(xops.DynamicSlice(x, idxs, dims_unsqueezed), dims_squeezed)
1264  else:
1265    raise TypeError((aval, c.get_shape(x)))
1266
1267# TODO(b/110096942): more efficient gather
1268def _xla_unshard(c, aval, axis_env, out_axis, x, backend):
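  """Gathers the per-replica values of x into a new axis placed at out_axis.

  Each replica writes its value into a zero-initialized buffer with an extra
  leading axis of size equal to the innermost mapped axis, at the position
  given by its replica id, and the buffers are combined with a
  CrossReplicaSum. If out_axis is not 0, the gathered axis is then transposed
  into place.
  """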
1269  if aval is core.abstract_unit:
1270    return x
1271  elif aval is core.abstract_token:
1272    return x
1273  elif isinstance(aval, ShapedArray):
1274    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
1275    convert_bool = (np.issubdtype(aval.dtype, np.bool_)
1276                    and xb.get_backend(backend).platform in ('cpu', 'gpu'))
1277    if convert_bool:
1278      x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))
1279
1280    xla_shape = c.get_shape(x)
1281    dims = list(xla_shape.dimensions())
    padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())),
                            [axis_env.sizes[-1]] + dims)
1284    zero = xb.constant(c, np.zeros((), dtype=np.uint32))
1285    idxs = [_unravel_index(c, axis_env)] + [zero] * len(dims)
1286    padded = xops.DynamicUpdateSlice(padded, xops.Reshape(x, [1] + dims), idxs)
1287    replica_groups_protos = xc.make_replica_groups(
1288      xla.axis_groups(axis_env, axis_env.names[-1]))
1289    out = xops.CrossReplicaSum(padded, replica_groups_protos)
1290    if out_axis != 0:
1291      # TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
      perm = list(range(1, len(dims) + 1))
1293      perm.insert(out_axis, 0)
1294      out = xops.Transpose(out, perm)
1295
1296    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
1297    if convert_bool:
1298      nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
1299      out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_))
1300    return out
1301  else:
1302    raise TypeError((aval, c.get_shape(x)))
1303
1304def _unravel_index(c, axis_env):
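  """Computes this replica's position along the innermost mapped axis.

  The replica id is divided by the replication factor (nreps // prod(sizes))
  and then reduced modulo the size of the innermost mapped axis.
  """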
1305  div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), np.uint32))
1306  mod = xb.constant(c, np.array(axis_env.sizes[-1], np.uint32))
1307  return xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
1308
1309# ------------------- xmap -------------------
1310
1311MeshAxisName = Any
1312"""
1313ArrayMapping specifies how an ndarray should map to mesh axes.
1314
Note that the ordering is crucial when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). In that case, the
order of entries determines a major-to-minor order on the mesh axes, which in turn
determines how chunks of the value along the repeated dimension are assigned.
1319
1320For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
1321The second dimension of the value would get chunked into 6 pieces, and assigned to the
1322mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
1323that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping were {'y': 1, 'x': 1}, then the
1325mesh devices ndarray would have to be transposed before flattening and assignment.
1326"""
1327ArrayMapping = OrderedDictType[MeshAxisName, int]
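# For illustration (not used below): an ArrayMapping that maps mesh axis 'x'
# over positional axis 0 of a value and mesh axis 'y' over positional axis 1
# would be written OrderedDict([('x', 0), ('y', 1)]).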
1328
1329class Mesh:
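  """A multidimensional array of devices with named axes.

  devices is an ndarray of Device objects and axis_names assigns one name per
  array dimension. For example (an illustrative sketch, assuming eight local
  devices are available):

    mesh = Mesh(np.array(xb.local_devices()).reshape(2, 4), ('x', 'y'))
    mesh.shape  # OrderedDict([('x', 2), ('y', 4)])
  """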
1330  __slots__ = ('devices', 'axis_names')
1331
1332  def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
1333    assert devices.ndim == len(axis_names)
1334    # TODO: Make sure that devices are unique? At least with the quick and
1335    #       dirty check that the array size is not larger than the number of
1336    #       available devices?
1337    self.devices = devices
1338    self.axis_names = tuple(axis_names)
1339
1340  @property
1341  def shape(self):
1342    return OrderedDict((name, size) for name, size in safe_zip(self.axis_names, self.devices.shape))
1343
1344  @property
1345  def size(self):
1346    return np.prod(list(self.shape.values()))
1347
1348  # TODO: This is pretty expensive to compute. Cache this on the mesh object?
1349  @property
1350  def local_mesh(self):
1351    host_id = xb.host_id()
1352    is_local_device = np.vectorize(lambda d: d.host_id == host_id, otypes=[bool])(self.devices)
1353    subcube_indices = []
1354    # We take the smallest slice of each dimension that doesn't skip any local device.
1355    for axis in range(self.devices.ndim):
1356      other_axes = tuple_delete(tuple(range(self.devices.ndim)), axis)
1357      # NOTE: This re-reduces over many axes multiple times, so we could definitely
1358      #       optimize it, but I hope it won't be a bottleneck anytime soon.
1359      local_slices = is_local_device.any(other_axes, keepdims=False)
1360      nonzero_indices = np.flatnonzero(local_slices)
1361      start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
1362      subcube_indices.append(slice(start, end + 1))
1363    subcube_indices = tuple(subcube_indices)
    # The assertion below holds only if the local devices form a subcube of the
    # full device array: we always take the smallest "hull" spanned by the local
    # devices, and if they don't form a subcube, that hull necessarily contains
    # non-local devices.
1368    assert is_local_device[subcube_indices].all()
1369    return Mesh(self.devices[subcube_indices], self.axis_names)
1370
1371  def __getitem__(self, new_axes):
1372    indices = [0] * len(self.axis_names)
1373    axis_pos = {name: i for i, name in enumerate(self.axis_names)}
1374    for axis in new_axes:
1375      indices[axis_pos[axis]] = slice(None)
1376    new_devices = self.devices[tuple(indices)]
1377    new_devices = new_devices.transpose(tuple(axis_pos[axis] for axis in new_axes))
1378    return Mesh(new_devices, new_axes)
1379
1380  @property
1381  def device_ids(self):
1382    return np.vectorize(lambda d: d.id, otypes=[int])(self.devices)
1383
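# tile_aval_nd maps an aval to its per-shard view: each axis mentioned in
# in_axes is divided by the size of the mesh axis it maps to. untile_aval_nd
# is the inverse, scaling the mapped axes back up to their global sizes.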
1384def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
1385  if aval is core.abstract_unit:
1386    return aval
1387  assert isinstance(aval, ShapedArray)
1388  shape = list(aval.shape)
1389  for name, axis in in_axes.items():
1390    assert shape[axis] % axis_sizes[name] == 0
1391    shape[axis] //= axis_sizes[name]
1392  return ShapedArray(tuple(shape), aval.dtype)
1393
1394def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
1395  if aval is core.abstract_unit:
1396    return aval
1397  assert isinstance(aval, ShapedArray)
1398  shape = list(aval.shape)
1399  for name, axis in out_axes.items():
1400    shape[axis] *= axis_sizes[name]
1401  return ShapedArray(tuple(shape), aval.dtype)
1402
1403def mesh_tiled_callable(fun: lu.WrappedFun,
1404                        transformed_name: str,
1405                        backend_name: Optional[str],
1406                        mesh: Mesh,
1407                        in_axes: Sequence[ArrayMapping],
1408                        out_axes: Sequence[ArrayMapping],
1409                        spmd_lowering,
1410                        *local_in_untiled_avals):
1411  assert config.omnistaging_enabled
1412  local_mesh = mesh.local_mesh
1413  global_axis_sizes = mesh.shape
1414  local_axis_sizes = local_mesh.shape
1415
1416  log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
1417  logging.log(log_priority,
1418              f"Compiling {fun.__name__} for {tuple(global_axis_sizes.items())} "
1419              f"mesh with args {local_in_untiled_avals}. Argument mapping: {in_axes}.")
1420
1421  # 1. Trace to jaxpr and preprocess/verify it
1422  in_tiled_avals = [tile_aval_nd(local_axis_sizes, aval_in_axes, aval)
1423                    for aval, aval_in_axes in safe_zip(local_in_untiled_avals, in_axes)]
1424  if spmd_lowering:
1425    # TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
1426    for name, size in reversed(mesh.shape.items()):
1427      fun = vtile(fun,
1428                  tuple(a.get(name, None) for a in in_axes),
1429                  tuple(a.get(name, None) for a in out_axes),
1430                  tile_size=size, axis_name=name)
1431    global_in_untiled_avals = [untile_aval_nd(global_axis_sizes, aval_in_axes, aval)
1432                               for aval, aval_in_axes in safe_zip(in_tiled_avals, in_axes)]
1433    in_jaxpr_avals = global_in_untiled_avals
1434  else:
1435    in_jaxpr_avals = in_tiled_avals
1436  with core.extend_axis_env_nd(mesh.shape.items()):
1437    jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
1438  assert len(out_axes) == len(out_jaxpr_avals)
1439  if spmd_lowering:
1440    global_out_untiled_avals = out_jaxpr_avals
1441    out_tiled_avals = [tile_aval_nd(global_axis_sizes, aval_out_axes, aval)
1442                       for aval, aval_out_axes in safe_zip(global_out_untiled_avals, out_axes)]
1443  else:
1444    out_tiled_avals = out_jaxpr_avals
1445  local_out_untiled_avals = [untile_aval_nd(local_axis_sizes, aval_out_axes, aval)
1446                             for aval, aval_out_axes in safe_zip(out_tiled_avals, out_axes)]
1447  _sanitize_mesh_jaxpr(jaxpr)
1448  if local_mesh.shape != mesh.shape:
1449    check_multihost_collective_allowlist(jaxpr)
1450  jaxpr = xla.apply_outfeed_rewriter(jaxpr)
1451
  # 2. Build up the HLO
1453  c = xb.make_computation_builder(f"xmap_{fun.__name__}")
1454  xla_consts = map(partial(xb.constant, c), consts)
1455  donated_invars = (False,) * len(in_jaxpr_avals)  # TODO(apaszke): support donation
1456  tuple_args = len(in_jaxpr_avals) > 100  # pass long arg lists as tuple for TPU
1457  in_partitions: Optional[List]
1458  if spmd_lowering:
1459    replicated_args = [False] * len(in_jaxpr_avals)
1460    global_sharding_spec = mesh_sharding_specs(global_axis_sizes, mesh.axis_names)
1461    in_partitions = [global_sharding_spec(aval, aval_in_axes).sharding_proto()
1462                     if aval is not core.abstract_unit else None
1463                     for aval, aval_in_axes in safe_zip(global_in_untiled_avals, in_axes)]
1464    out_partitions = [global_sharding_spec(aval, aval_out_axes).sharding_proto()
1465                      for aval, aval_out_axes in safe_zip(global_out_untiled_avals, out_axes)]
1466    partitions_proto = True
1467    axis_env = xla.AxisEnv(nreps=1, names=(), sizes=())  # All named axes have been vmapped
1468  else:
1469    replicated_args = [not axis for axis in in_axes]
1470    in_partitions = None
1471    partitions_proto = False
1472    axis_env = xla.AxisEnv(nreps=mesh.size,
1473                           names=tuple(global_axis_sizes.keys()),
1474                           sizes=tuple(global_axis_sizes.values()))
1475  xla_args, donated_invars = xla._xla_callable_args(
1476      c, in_jaxpr_avals, tuple_args,
1477      replicated=replicated_args,
1478      partitions=in_partitions,
1479      partitions_proto=partitions_proto,
1480      donated_invars=donated_invars)
1481  with core.extend_axis_env_nd(mesh.shape.items()):
1482    out_nodes = xla.jaxpr_subcomp(
1483        c, jaxpr, backend_name, axis_env, xla_consts,
1484        extend_name_stack(wrap_name(transformed_name, 'xmap')), *xla_args)
1485  backend = xb.get_backend(backend_name)
1486  if spmd_lowering:
1487    out_partitions_t = xb.tuple_sharding_proto(out_partitions)
1488    out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)
1489  else:
1490    out_tuple = xops.Tuple(c, out_nodes)
1491  # TODO(apaszke): Does that work with SPMD sharding?
1492  if backend.platform in ("gpu", "tpu"):
1493    donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
1494  built = c.Build(out_tuple)
1495
  # 3. Compile the HLO
1497  if spmd_lowering:
1498    num_replicas, num_partitions = 1, mesh.size
1499    num_local_replicas, num_local_partitions = 1, local_mesh.size
1500  else:
1501    num_replicas, num_partitions = mesh.size, 1
1502    num_local_replicas, num_local_partitions = local_mesh.size, 1
1503  device_assignment = mesh.device_ids.reshape((num_replicas, num_partitions))
1504  compile_options = xb.get_compile_options(
1505      num_replicas=num_replicas,
1506      num_partitions=num_partitions,
1507      device_assignment=device_assignment,
1508      use_spmd_partitioning=spmd_lowering,
1509  )
1510  compile_options.parameter_is_tupled_arguments = tuple_args
1511  compiled = xla.backend_compile(backend, built, compile_options)
1512
  # 4. Argument sharding / output wrapping
1514  local_sharding_spec = mesh_sharding_specs(local_axis_sizes, mesh.axis_names)
1515  local_input_specs = [local_sharding_spec(aval, aval_in_axes)
1516                       if aval is not core.abstract_unit else None
1517                       for aval, aval_in_axes in safe_zip(local_in_untiled_avals, in_axes)]
1518  input_indices = [spec_to_indices(aval.shape, spec)
1519                   if spec is not None else None
1520                   for aval, spec in safe_zip(local_in_untiled_avals, local_input_specs)]
1521  handle_args = partial(shard_args, compiled.local_devices(), input_indices)
1522
1523  local_output_specs = [local_sharding_spec(aval, aval_out_axes)
1524                        for aval, aval_out_axes in safe_zip(local_out_untiled_avals, out_axes)]
1525  handle_outs = avals_to_results_handler(num_local_replicas, num_local_partitions,
1526                                         local_output_specs, local_out_untiled_avals)
1527
1528  return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
1529
# NOTE: This splits each mapped input axis into (tile_size, axis_size // tile_size),
#       maps the function over the leading tile_size-sized dimension under
#       axis_name, and merges the corresponding output axes back together.
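# For example (an illustrative sketch): with in_axes_flat=(0,),
# out_axes_flat=(0,), tile_size=4, and axis_name='x', an argument of shape
# (8, 3) is reshaped to (4, 2, 3), the function is mapped over the leading
# axis of size 4 under the name 'x', and the corresponding output axes are
# merged back together on the way out.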
1531def vtile(f_flat,
1532          in_axes_flat: Tuple[Optional[int], ...],
1533          out_axes_flat: Tuple[Optional[int], ...],
1534          tile_size: Optional[int], axis_name):
1535  if tile_size == 1:
1536    return f_flat
1537
1538  @curry
1539  def tile_axis(arg, axis: Optional[int], tile_size):
1540    if axis is None:
1541      return arg
1542    shape = list(arg.shape)
1543    shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
1544    return arg.reshape(shape)
1545
1546  def untile_axis(out, axis: Optional[int]):
1547    if axis is None:
1548      return out
1549    shape = list(out.shape)
1550    shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
1551    return out.reshape(shape)
1552
1553  @lu.transformation
1554  def _map_to_tile(*args_flat):
1555    sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
1556    tile_size_ = tile_size or next(sizes, None)
1557    assert tile_size_ is not None, "No mapped arguments?"
1558    outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
1559    yield map(untile_axis, outputs_flat, out_axes_flat)
1560
1561  return _map_to_tile(
1562      batching.batch(f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat))
1563
1564_forbidden_primitives = {
1565  'xla_pmap': 'pmap',
1566  'soft_pmap': 'soft_pmap',
1567  'sharded_call': 'sharded_jit',
1568}
1569def _sanitize_mesh_jaxpr(jaxpr):
1570  for eqn in jaxpr.eqns:
1571    if eqn.primitive.name in _forbidden_primitives:
      raise RuntimeError(f"Nesting {_forbidden_primitives[eqn.primitive.name]} "
                         f"inside xmaps is not supported!")
1574    core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)
1575
1576
1577def mesh_sharding_specs(axis_sizes, axis_names):
1578  mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
1579  # NOTE: This takes in the non-sharded avals!
1580  def mk_sharding_spec(aval, aval_axes):
1581    sharding = [_UNSHARDED_INSTANCE] * len(aval.shape)
1582    mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
1583    next_sharded_axis = 0
1584    aval_shape = list(aval.shape)
1585    # NOTE: sorted is stable, which is important when multiple resources
1586    #       map to the same axis.
1587    for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
1588      assert aval_shape[axis] % axis_sizes[name] == 0, (axis_sizes[name], aval.shape[axis])
1589      aval_shape[axis] //= axis_sizes[name]
1590      if isinstance(sharding[axis], NoSharding):
1591        sharding[axis] = Chunked(())
1592      sharding[axis] = Chunked(sharding[axis].chunks + (axis_sizes[name],))
1593      assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
1594          "Value mapped to the same mesh axis twice"
1595      mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
1596      next_sharded_axis += 1
1597    return ShardingSpec(sharding, mesh_mapping)
1598  return mk_sharding_spec
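# For illustration (a hand-worked example, not used elsewhere): with axis_sizes
# OrderedDict([('x', 2), ('y', 3)]) and axis_names ('x', 'y'), applying the
# returned mk_sharding_spec to an aval of shape (4, 6) with aval_axes
# OrderedDict([('x', 0), ('y', 1)]) yields
#   ShardingSpec(sharding=[Chunked((2,)), Chunked((3,))],
#                mesh_mapping=[ShardedAxis(0), ShardedAxis(1)])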
1599
1600
1601# ------------------- soft_pmap -------------------
1602
1603def soft_pmap_impl(fun: lu.WrappedFun, *args, axis_name, axis_size, in_axes, out_axes_thunk):
1604  abstract_args = unsafe_map(xla.abstractify, args)
1605  compiled_fun = _soft_pmap_callable(fun, axis_name, axis_size, in_axes, out_axes_thunk,
1606                                     *abstract_args)
1607  return compiled_fun(*args)
1608
1609@lu.cache
1610def _soft_pmap_callable(fun, axis_name, axis_size, in_axes, out_axes_thunk, *avals):
1611  mapped_avals = [core.mapped_aval(axis_size, in_axis, aval) if in_axis is not None else aval
1612                  for in_axis, aval in safe_zip(in_axes, avals)]
1613  with core.extend_axis_env(axis_name, axis_size, None):  # type: ignore
1614    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
1615  out_axes = out_axes_thunk()
1616  assert all(out_axis == 0 for out_axis in out_axes)
1617  jaxpr = xla.apply_outfeed_rewriter(jaxpr)
1618
1619  num_devices = xb.local_device_count()
1620  chunk_size, ragged = divmod(axis_size, num_devices)
1621  if ragged:
1622    msg = f"number of devices {num_devices} must divide axis size {axis_size}"
1623    raise NotImplementedError(msg)
1624
1625  jaxpr, _, consts = _soft_pmap_jaxpr(jaxpr, consts, in_axes,
1626                                      axis_name, axis_size, chunk_size)
1627  jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
1628  if jaxpr_replicas != 1: raise NotImplementedError
1629
1630  tuple_args = len(avals) > 100  # pass long arg lists as tuple for TPU
1631
1632  c = xb.make_computation_builder("soft_pmap_{}".format(fun.__name__))
1633  xla_consts = map(partial(xb.constant, c), consts)
1634  chunked_avals = [core.unmapped_aval(chunk_size, in_axis, aval) if in_axis is not None else aval
1635                   for in_axis, aval in safe_zip(in_axes, mapped_avals)]
1636  xla_args, _ = xla._xla_callable_args(c, chunked_avals, tuple_args)
1637  axis_env = xla.AxisEnv(num_devices, (axis_name,), (num_devices,))
1638  out_nodes = xla.jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts,
1639                                'soft_pmap', *xla_args)
1640  built = c.Build(xops.Tuple(c, out_nodes))
1641
1642  compile_options = xb.get_compile_options(
1643          num_replicas=num_devices, num_partitions=1, device_assignment=None)
1644  compile_options.tuple_arguments = tuple_args
1645  backend = xb.get_backend(None)
1646  compiled = xla.backend_compile(backend, built, compile_options)
1647
1648  input_specs = [
1649      ShardingSpec(
1650          sharding=tuple_insert((_UNSHARDED_INSTANCE,) *
1651                                (aval.ndim - 1), in_axis, Chunked(num_devices)),
1652          mesh_mapping=[ShardedAxis(0)])
1653      if in_axis is not None else ShardingSpec(
1654          sharding=[_UNSHARDED_INSTANCE] * aval.ndim,
1655          mesh_mapping=[Replicated(num_devices)])
1656      for aval, in_axis in safe_zip(avals, in_axes)
1657  ]
1658  input_indices = [spec and spec_to_indices(aval.shape, spec)
1659                   for aval, spec in safe_zip(avals, input_specs)]
1660  handle_args = partial(shard_args, compiled.local_devices(), input_indices)
1661  handle_outs = soft_pmap_avals_to_results_handler(num_devices, chunk_size, out_avals)
1662
1663  return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
1664
1665def _soft_pmap_jaxpr(jaxpr, consts, in_axes, axis_name, axis_size, chunk_size):
1666  assert all(in_axis is None or in_axis == 0 for in_axis in in_axes), in_axes
1667  mapped_invars = [in_axis is not None for in_axis in in_axes]
1668  fun = partial(_soft_pmap_interp, chunk_size, jaxpr, consts, mapped_invars)
1669  in_avals = [core.unmapped_aval(chunk_size, in_axis, v.aval) if in_axis is not None else v.aval
1670              for v, in_axis in safe_zip(jaxpr.invars, in_axes)]
1671  with core.extend_axis_env(axis_name, axis_size, None):
1672    return pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
1673
1674def _soft_pmap_interp(chunk_size, jaxpr, consts, mapped_invars, *args):
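  """Evaluates jaxpr one chunk at a time, tracking which values are mapped.

  Each value in the environment carries a flag saying whether it has a leading
  chunked axis. Collective primitives are handled by soft_pmap_rules, call
  primitives are inlined, and any other primitive with at least one mapped
  input is evaluated via its batching rule with the chunked axis at position
  0. Unmapped outputs are broadcast to the chunk size at the end.
  """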
1675
1676  env: Dict[Var, Tuple[Any, bool]] = {}
1677
1678  def read(atom: Union[Var, Literal]) -> Tuple[Any, bool]:
1679    if isinstance(atom, Literal):
1680      return (atom.val, False)
1681    else:
1682      return env[atom]
1683
1684  def write(v: Var, val: Any, mapped: bool) -> None:
1685    env[v] = (val, mapped)
1686
1687  write(core.unitvar, core.unit, False)
1688  map(write, jaxpr.constvars, consts, (False,) * len(consts))
1689  map(write, jaxpr.invars, args, mapped_invars)
1690  for eqn in jaxpr.eqns:
1691    in_vals, in_mapped = unzip2(map(read, eqn.invars))
1692    if eqn.primitive in xla.parallel_translations:
1693      rule = soft_pmap_rules[eqn.primitive]
1694      out_vals, out_mapped = rule(in_vals, in_mapped, chunk_size, **eqn.params)
1695      if not eqn.primitive.multiple_results:
1696        out_vals, out_mapped = [out_vals], [out_mapped]
1697    elif isinstance(eqn.primitive, core.CallPrimitive):
1698      # we just inline here for convenience
1699      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
1700      out_vals = _soft_pmap_interp(chunk_size, call_jaxpr, (), in_mapped, *in_vals)
1701      out_mapped = [True] * len(out_vals)
1702    elif isinstance(eqn.primitive, core.MapPrimitive):
1703      raise NotImplementedError  # TODO
1704    else:
1705      if any(in_mapped):
1706        rule = batching.get_primitive_batcher(eqn.primitive, None)
1707        in_axes = [0 if m else batching.not_mapped for m in in_mapped]
1708        out_vals, out_axes = rule(in_vals, in_axes, **eqn.params)
1709        if not eqn.primitive.multiple_results:
1710          out_vals, out_axes = [out_vals], [out_axes]
1711        out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x
1712                    for x, d in safe_zip(out_vals, out_axes)]
1713        out_mapped = [d is not not_mapped for d in out_axes]
1714      else:
1715        out_vals = eqn.primitive.bind(*in_vals, **eqn.params)
1716        if not eqn.primitive.multiple_results:
1717          out_vals = [out_vals]
1718        out_mapped = [False for _ in out_vals]
1719    map(write, eqn.outvars, out_vals, out_mapped)
1720
1721  out_vals, out_mapped = unzip2(map(read, jaxpr.outvars))
1722  out_vals = [out if mapped else broadcast(out, chunk_size, 0)
1723              for out, mapped in safe_zip(out_vals, out_mapped)]
1724  return out_vals
1725
# TODO(mattjj): dedup with the other aval_to_result_handler via ShardingSpec
1727def soft_pmap_avals_to_results_handler(num_devices, chunk_size, out_avals):
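  """Returns a handler that regroups raw output buffers into results.

  The handler transposes the per-replica list of per-output buffers into a
  per-output list of per-replica buffers, then wraps each group using the
  handler built by soft_pmap_aval_to_result_handler.
  """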
1728  nouts = len(out_avals)
1729  handlers = [soft_pmap_aval_to_result_handler(chunk_size, num_devices, aval)
1730              for aval in out_avals]
1731  def handler(out_bufs):
1732    buffers = [[result_to_populate] * num_devices for _ in range(nouts)]
1733    for r, tuple_buf in enumerate(out_bufs):
1734      for i, buf in enumerate(tuple_buf):
1735        buffers[i][r] = buf
1736    assert not any(buf is result_to_populate for bufs in buffers
1737                   for buf in bufs)
1738    return [h(bufs) for h, bufs in safe_zip(handlers, buffers)]
1739  return handler
1740
1741def soft_pmap_aval_to_result_handler(chunk_size, num_devices, aval):
1742  axis_size = chunk_size * num_devices
1743  if aval is core.abstract_unit:
1744    return lambda _: core.unit
1745  elif isinstance(aval, core.ShapedArray):
1746    new_aval = ShapedArray((axis_size,) + aval.shape, aval.dtype)
1747    spec = ShardingSpec(
1748        sharding=(Chunked(num_devices),) + (_UNSHARDED_INSTANCE,) * aval.ndim,
1749        mesh_mapping=(ShardedAxis(0),))
1750    return lambda bufs: ShardedDeviceArray(new_aval, spec, bufs)
1751  else:
1752    raise TypeError(aval)
1753
1754soft_pmap_p = core.MapPrimitive('soft_pmap')
1755soft_pmap = soft_pmap_p.bind
1756soft_pmap_p.def_impl(soft_pmap_impl)
1757
1758soft_pmap_rules: Dict[core.Primitive, Callable] = {}
1759
1760@contextmanager
1761def maybe_extend_axis_env(*args, **kwargs):
1762  with core.extend_axis_env(*args, **kwargs):
1763    yield
1764
1765@config.register_omnistaging_disabler
1766@no_type_check
1767def omnistaging_disabler() -> None:
1768  global DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
1769      _thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
1770      apply_parallel_primitive, parallel_pure_rules, \
1771      _pvals_to_results_handler, _pval_to_result_handler, replicate, \
1772      axis_index, maybe_extend_axis_env
1773
1774  @contextmanager
1775  def maybe_extend_axis_env(*args, **kwargs):
1776    yield
1777
1778  def _pvals_to_results_handler(
1779      size, nrep, npart,
1780      out_parts: Optional[Tuple[PartitionsOrReplicated, ...]],
1781      out_pvals, devices, backend):
1782    nouts = len(out_pvals)
1783    if out_parts is None:
1784      out_parts = (None,) * len(out_pvals)
1785    handlers = [
1786        _pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend)
1787        for pval, parts in safe_zip(out_pvals, out_parts)  # type: ignore
1788    ]
1789
1790    def handler(out_bufs):
1791      assert nrep * npart == len(out_bufs)
1792      buffers = [[result_to_populate] * nrep * npart for _ in range(nouts)]
1793      for r, tuple_buf in enumerate(out_bufs):
1794        for i, buf in enumerate(tuple_buf):
1795          buffers[i][r] = buf
1796      assert not any(buf is result_to_populate for bufs in buffers
1797                    for buf in bufs)
1798      return [h(bufs) for h, bufs in safe_zip(handlers, buffers)]
1799    return handler
1800
1801  def _pval_to_result_handler(axis_size, nrep, npart, parts, pval, devices, backend):
1802    if devices:
1803      assert all(d.host_id == xb.host_id(backend) for d in devices)
1804    pv, const = pval
1805    if pv is None:
1806      if nrep is None:
1807        nrep = axis_size
1808        # If 'const' is a ShardedDeviceArray, it must have come from a pmap nested
1809        # inside the one we're currently evaluating, and we should replicate
1810        # 'const' across the total number of devices needed. We don't necessarily
1811        # know the nested pmap's axis_size (e.g. the jaxpr for
1812        # pmap(pmap(lambda x: 3)) is trivial, with no pmaps), but we can use the
1813        # axis size of the output 'const'.
1814        # TODO: we might be doing unnecessary device transfers in the inner pmap.
1815        if isinstance(const, ShardedDeviceArray):
1816          nrep *= len(const)
1817
1818      bcast_const = (core.unit if const is core.unit
1819                    else replicate(const, axis_size, nrep, devices, backend))  # type: ignore
1820      return lambda _: bcast_const  # type: ignore
1821    else:
1822      if pv is not core.abstract_unit:
1823        unsharded_aval = ShapedArray((axis_size,) + pv.shape, pv.dtype)
1824        sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, pv, 0)
1825        indices = spec_to_indices(unsharded_aval.shape, sharding_spec)
1826      else:
1827        sharding_spec = indices = None
1828        unsharded_aval = pv
1829      return aval_to_result_handler(sharding_spec, indices, unsharded_aval)
1830
1831  @contextmanager
1832  def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size):
1833    dynamic_axis_env = _thread_local_state.dynamic_axis_env
1834    dynamic_axis_env.append(DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size))
1835    try:
1836      yield
1837    finally:
1838      dynamic_axis_env.pop()
1839
1840  def unmapped_device_count(backend=None):
1841    dynamic_axis_env = _thread_local_state.dynamic_axis_env
1842    mapped = prod(frame.hard_size for frame in dynamic_axis_env)
1843    unmapped, ragged = divmod(xb.device_count(backend), mapped)
1844    assert not ragged and unmapped > 0
1845    return unmapped
1846
1847  def apply_parallel_primitive(prim, *args, **params):
1848    # This is the op-by-op version of applying a collective primitive, like a psum
1849    # that doesn't have a data dependence on the argument of a pmap function. In
1850    # particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We
1851    # look up information in the dynamic axis env.
1852    dynamic_axis_env = _thread_local_state.dynamic_axis_env
1853    axis_name = params.pop('axis_name')
1854    axis_index_groups = params.pop('axis_index_groups')
1855    if axis_index_groups is not None:
1856      shape = (len(axis_index_groups[0]),)
1857    else:
1858      logical_size = lambda frame: frame.hard_size
1859      if isinstance(axis_name, (list, tuple)):
1860        shape = tuple(logical_size(dynamic_axis_env[name]) for name in axis_name)
1861      else:
1862        shape = (logical_size(dynamic_axis_env[axis_name]),)
1863    return parallel_pure_rules[prim](*args, shape=shape, **params)
1864
1865  pe.staged_out_calls.add(xla_pmap_p)  # type: ignore
1866
1867parallel_pure_rules = {}  # type: ignore
1868
1869class DynamicAxisEnvFrame(object):
1870  __slots__ = ["name", "pmap_trace", "hard_size"]
1871  def __init__(self, name, pmap_trace, hard_size):
1872    self.name = name
1873    self.pmap_trace = pmap_trace
1874    self.hard_size = hard_size
1875
1876class DynamicAxisEnv(list):
1877  def __contains__(self, axis_name):
1878    return axis_name in (frame.name for frame in self)
1879
1880  def __getitem__(self, axis_name):
1881    if axis_name not in self:
1882      raise NameError("unbound axis name: {}".format(axis_name))
1883    for frame in reversed(self):
1884      if frame.name == axis_name:
1885        return frame
1886
1887    raise AssertionError
1888
1889  @property
1890  def sizes(self):
1891    return tuple(frame.hard_size for frame in self)
1892
1893  @property
1894  def nreps(self):
1895    return prod(frame.hard_size for frame in self)
1896
1897class _ThreadLocalState(threading.local):
1898  def __init__(self):
1899    self.dynamic_axis_env = DynamicAxisEnv()
1900
1901_thread_local_state = _ThreadLocalState()
1902
def device_put(x, devices: Sequence[xb.xla_client.Device],
               replicate: bool = False) -> List[xb.xla_client._xla.PyLocalBuffer]:
1904  """Call device_put on a sequence of devices and return a flat sequence of buffers."""
1905  if replicate:
1906    return list(it.chain.from_iterable(xla.device_put(x, device) for device in devices))
1907  else:
1908    return list(it.chain.from_iterable(xla.device_put(val, device) for val, device in safe_zip(x, devices)))
1909