1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""
15Parallelization primitives.
16"""
17
18import collections
19import string
20import warnings
21
22import numpy as np
23
24from jax import core
25from jax import dtypes
26from jax import tree_util
27from jax._src import source_info_util
28from . import lax
29from jax.core import ShapedArray, raise_to_shaped
30from jax.interpreters import ad
31from jax.interpreters import xla
32from jax.interpreters import pxla
33from jax.interpreters import batching
34from jax.interpreters import partial_eval as pe
35from jax._src.util import partial, unzip2, prod
36from jax.lib import xla_client as xc
37from jax.lib import xla_bridge as xb
38from jax.config import config
39from jax._src.numpy import lax_numpy
40
41xops = xc.ops
42
43
44### parallel traceables
45
46def psum(x, axis_name, *, axis_index_groups=None):
47  """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.
48
49  If ``x`` is a pytree then the result is equivalent to mapping this function to
50  each leaf in the tree.
51
52  Inputs of boolean dtype are converted to integers before the reduction.
53
54  Args:
55    x: array(s) with a mapped axis named ``axis_name``.
56    axis_name: hashable Python object used to name a pmapped axis (see the
57      :func:`jax.pmap` documentation for more details).
58    axis_index_groups: optional list of lists containing axis indices (e.g. for
59      an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
60      two and last two replicas). Groups must cover all axis indices exactly
61      once, and all groups must be the same size.
62
63
64  Returns:
65    Array(s) with the same shape as ``x`` representing the result of an
66    all-reduce sum along the axis ``axis_name``.
67
68  For example, with 4 XLA devices available:
69
70  >>> x = np.arange(4)
71  >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
72  >>> print(y)
73  [6 6 6 6]
74  >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
75  >>> print(y)
76  [ 0.          0.16666667  0.33333334  0.5       ]
77  """
78  if not isinstance(axis_name, (tuple, list)):
79    axis_name = (axis_name,)
80  _validate_axis_index_groups(axis_index_groups)
81  leaves, treedef = tree_util.tree_flatten(x)
82  leaves = [lax.convert_element_type(l, np.int32)
83            if dtypes.dtype(l) == np.bool_ else l for l in leaves]
84  out_flat = psum_p.bind(*leaves, axis_name=axis_name,
85                         axis_index_groups=axis_index_groups)
86  return tree_util.tree_unflatten(treedef, out_flat)
87
88def pmean(x, axis_name, *, axis_index_groups=None):
89  """Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``.
90
91  If ``x`` is a pytree then the result is equivalent to mapping this function to
92  each leaf in the tree.
93
94  Args:
95    x: array(s) with a mapped axis named ``axis_name``.
96    axis_name: hashable Python object used to name a pmapped axis (see the
97      :func:`jax.pmap` documentation for more details).
98    axis_index_groups: optional list of lists containing axis indices (e.g. for
99      an axis of size 4, [[0, 1], [2, 3]] would perform pmeans over the first
100      two and last two replicas). Groups must cover all axis indices exactly
101      once, and all groups must be the same size.
102
103  Returns:
104    Array(s) with the same shape as ``x`` representing the result of an
105    all-reduce mean along the axis ``axis_name``.
106
107  For example, with 4 XLA devices available:
108
109  >>> x = np.arange(4)
110  >>> y = jax.pmap(lambda x: jax.lax.pmean(x, 'i'), axis_name='i')(x)
111  >>> print(y)
112  [ 1.5         1.5         1.5         1.5       ]
113  >>> y = jax.pmap(lambda x: x / jax.lax.pmean(x, 'i'), axis_name='i')(x)
114  >>> print(y)
115  [ 0.          0.66666667  1.33333334  2.0       ]
116  """
117  x = psum(x, axis_name=axis_name, axis_index_groups=axis_index_groups)
118  n = psum(1, axis_name=axis_name, axis_index_groups=axis_index_groups)
119  return tree_util.tree_map(lambda v: v / n, x)
120
121def pmax(x, axis_name, *, axis_index_groups=None):
122  """Compute an all-reduce max on ``x`` over the pmapped axis ``axis_name``.
123
124  If ``x`` is a pytree then the result is equivalent to mapping this function to
125  each leaf in the tree.
126
127  Args:
128    x: array(s) with a mapped axis named ``axis_name``.
129    axis_name: hashable Python object used to name a pmapped axis (see the
130      :func:`jax.pmap` documentation for more details).
131    axis_index_groups: optional list of lists containing axis indices (e.g. for
132      an axis of size 4, [[0, 1], [2, 3]] would perform pmaxes over the first
133      two and last two replicas). Groups must cover all axis indices exactly
134      once, and all groups must be the same size.
135
136  Returns:
137    Array(s) with the same shape as ``x`` representing the result of an
138    all-reduce max along the axis ``axis_name``.
139  """
140  if not isinstance(axis_name, (tuple, list)):
141    axis_name = (axis_name,)
142  _validate_axis_index_groups(axis_index_groups)
143  leaves, treedef = tree_util.tree_flatten(x)
144  out_flat = pmax_p.bind(*leaves, axis_name=axis_name,
145                         axis_index_groups=axis_index_groups)
146  return tree_util.tree_unflatten(treedef, out_flat)
147
148def pmin(x, axis_name, *, axis_index_groups=None):
149  """Compute an all-reduce min on ``x`` over the pmapped axis ``axis_name``.
150
151  If ``x`` is a pytree then the result is equivalent to mapping this function to
152  each leaf in the tree.
153
154  Args:
155    x: array(s) with a mapped axis named ``axis_name``.
156    axis_name: hashable Python object used to name a pmapped axis (see the
157      :func:`jax.pmap` documentation for more details).
158    axis_index_groups: optional list of lists containing axis indices (e.g. for
159      an axis of size 4, [[0, 1], [2, 3]] would perform pmins over the first
160      two and last two replicas). Groups must cover all axis indices exactly
161      once, and all groups must be the same size.
162
163  Returns:
164    Array(s) with the same shape as ``x`` representing the result of an
165    all-reduce min along the axis ``axis_name``.
166  """
167  if not isinstance(axis_name, (tuple, list)):
168    axis_name = (axis_name,)
169  _validate_axis_index_groups(axis_index_groups)
170  leaves, treedef = tree_util.tree_flatten(x)
171  out_flat = pmin_p.bind(*leaves, axis_name=axis_name,
172                         axis_index_groups=axis_index_groups)
173  return tree_util.tree_unflatten(treedef, out_flat)
174
175def _validate_axis_index_groups(axis_index_groups):
176  if axis_index_groups is None:
177    return
178  len_0 = len(axis_index_groups[0])
179  if any(len(g) != len_0 for g in axis_index_groups):
180    raise ValueError("axis_index_groups must all be the same size")
181  axis_space = range(len_0 * len(axis_index_groups))
182  if {i for g in axis_index_groups for i in g} != set(axis_space):
183    raise ValueError("axis_index_groups must cover all indices exactly once")
184
185def ppermute(x, axis_name, perm):
186  """Perform a collective permutation according to the permutation ``perm``.
187
188  If ``x`` is a pytree then the result is equivalent to mapping this function to
189  each leaf in the tree.
190
191  This function is an analog of the CollectivePermute XLA HLO.
192
193  Args:
194    x: array(s) with a mapped axis named ``axis_name``.
195    axis_name: hashable Python object used to name a pmapped axis (see the
196      :func:`jax.pmap` documentation for more details).
197    perm: list of pairs of ints, representing
198      ``(source_index, destination_index)``
199      pairs that encode how the mapped axis named ``axis_name`` should be
200      shuffled. The integer values are treated as indices into the mapped axis
201      ``axis_name``. Any two pairs should not have the same source index or the
202      same destination index. For each index of the axis ``axis_name`` that does
203      not correspond to a destination index in ``perm``, the corresponding
204      values in the result are filled with zeros of the appropriate type.
205
206  Returns:
207    Array(s) with the same shape as ``x`` with slices along the axis
208    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
209  """
210  return tree_util.tree_map(
211      partial(ppermute_p.bind, axis_name=axis_name, perm=tuple(perm)), x)
212
213def pshuffle(x, axis_name, perm):
214  """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding
215
216  If ``x`` is a pytree then the result is equivalent to mapping this function to
217  each leaf in the tree.
218
219  Args:
220    x: array(s) with a mapped axis named ``axis_name``.
221    axis_name: hashable Python object used to name a pmapped axis (see the
222      :func:`jax.pmap` documentation for more details).
223    perm: list of of ints encoding sources for the permutation to be applied to
224      the axis named ``axis_name``, so that the output at axis index i
225      comes from the input at axis index perm[i]. Every integer in [0, N) should
226      be included exactly once for axis size N.
227
228  Returns:
229    Array(s) with the same shape as ``x`` with slices along the axis
230    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
231  """
232  if set(perm) != set(range(len(perm))):
233    raise ValueError(f"`perm` does not represent a permutation: {perm}")
234  return ppermute(x, axis_name, list(zip(perm, range(len(perm)))))
235
236
237def pswapaxes(x, axis_name, axis, *, axis_index_groups=None):
238  """Swap the pmapped axis ``axis_name`` with the unmapped axis ``axis``.
239
240  If ``x`` is a pytree then the result is equivalent to mapping this function to
241  each leaf in the tree.
242
243  The group size of the mapped axis size must be equal to the size of the
244  unmapped axis; that is, we must have
245  ``lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]``.
246  By default, when ``axis_index_groups=None``, this encompasses all the devices.
247
248  This function is a special case of ``all_to_all`` where the pmapped axis of
249  the input is placed at the position ``axis`` in the output. That is, it is
250  equivalent to ``all_to_all(x, axis_name, axis, axis)``.
251
252  Args:
253    x: array(s) with a mapped axis named ``axis_name``.
254    axis_name: hashable Python object used to name a pmapped axis (see the
255      :func:`jax.pmap` documentation for more details).
256    axis: int indicating the unmapped axis of ``x`` to map with the name
257      ``axis_name``.
258    axis_index_groups: optional list of lists containing axis indices (e.g. for
259      an axis of size 4, [[0, 1], [2, 3]] would run pswapaxes over the first
260      two and last two replicas). Groups must cover all axis indices exactly
261      once, and all groups must be the same size.
262
263  Returns:
264    Array(s) with the same shape as ``x``.
265  """
266  return all_to_all(x, axis_name, axis, axis, axis_index_groups=axis_index_groups)
267
268def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None):
269  """Materialize the mapped axis and map a different axis.
270
271  If ``x`` is a pytree then the result is equivalent to mapping this function to
272  each leaf in the tree.
273
274  In the output, the input mapped axis ``axis_name`` is materialized at the
275  logical axis position ``concat_axis``, and the input unmapped axis at position
276  ``split_axis`` is mapped with the name ``axis_name``.
277
278  The group size of the mapped axis size must be equal to the size of the
279  unmapped axis; that is, we must have
280  ``lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]``.
281  By default, when ``axis_index_groups=None``, this encompasses all the devices.
282
283  Args:
284    x: array(s) with a mapped axis named ``axis_name``.
285    axis_name: hashable Python object used to name a pmapped axis (see the
286      :func:`jax.pmap` documentation for more details).
287    split_axis: int indicating the unmapped axis of ``x`` to map with the name
288      ``axis_name``.
289    concat_axis: int indicating the position in the output to materialize the
290      mapped axis of the input with the name ``axis_name``.
291    axis_index_groups: optional list of lists containing axis indices (e.g. for
292      an axis of size 4, [[0, 1], [2, 3]] would run all_to_all over the first
293      two and last two replicas). Groups must cover all axis indices exactly
294      once, and all groups must be the same size.
295
296  Returns:
297    Array(s) with shape given by the expression::
298
299      np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)
300
301    where ``axis_size`` is the size of the mapped axis named ``axis_name`` in
302    the input ``x``, i.e. ``axis_size = lax.psum(1, axis_name)``.
303  """
304  def bind(x):
305    group_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
306    if group_size != x.shape[split_axis]:
307      msg = ("all_to_all requires the size of the mapped axis axis_name to "
308             "equal x.shape[split_axis], but they are {} and {} respectively.")
309      raise ValueError(msg.format(group_size, x.shape[split_axis]))
310    return all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
311                             axis_name=axis_name,
312                             axis_index_groups=axis_index_groups)
313
314  return tree_util.tree_map(bind, x)
315
316def axis_index(axis_name):
317  """Return the index along the mapped axis ``axis_name``.
318
319  Args:
320    axis_name: hashable Python object used to name the mapped axis.
321
322  Returns:
323    An integer representing the index.
324
325  For example, with 8 XLA devices available:
326
327  >>> from functools import partial
328  >>> @partial(jax.pmap, axis_name='i')
329  ... def f(_):
330  ...   return lax.axis_index('i')
331  ...
332  >>> f(np.zeros(4))
333  ShardedDeviceArray([0, 1, 2, 3], dtype=int32)
334  >>> f(np.zeros(8))
335  ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
336  >>> @partial(jax.pmap, axis_name='i')
337  ... @partial(jax.pmap, axis_name='j')
338  ... def f(_):
339  ...   return lax.axis_index('i'), lax.axis_index('j')
340  ...
341  >>> x, y = f(np.zeros((4, 2)))
342  >>> print(x)
343  [[0 0]
344  [1 1]
345  [2 2]
346  [3 3]]
347  >>> print(y)
348  [[0 1]
349  [0 1]
350  [0 1]
351  [0 1]]
352  """
353  return axis_index_p.bind(axis_name=axis_name)
354
355
356def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ())):
357  if not isinstance(axis_name, (list, tuple)):
358    axis_name = (axis_name,)
359  return pdot_p.bind(x, y, axis_name=axis_name,
360                     pos_contract=pos_contract, pos_batch=pos_batch)
361
362
363def xeinsum(spec: str, x, y):
364  in_spec, out_spec = spec.split('->')
365  (lhs_subs, lhs_named), (rhs_subs, rhs_named) = XeinsumSpecParser(in_spec).parse_args()
366  (out_subs, out_named), = XeinsumSpecParser(out_spec).parse_args()
367  all_named = {*lhs_named, *rhs_named, *out_named}
368  all_subs = {*lhs_subs, *rhs_subs, *out_subs}
369  lhs_uniques = set(lhs_subs) - set(rhs_subs)
370  rhs_uniques = set(rhs_subs) - set(lhs_subs)
371  if all_subs & all_named:
372    raise NotImplementedError
373  if not set(out_named).issubset({*lhs_named, *rhs_named}):
374    raise ValueError
375
376  # if a named axis appears in both inputs and not the output, contract!
377  named_contract = list(all_named - set(out_named))
378
379  # if a subscript appears in both inputs and not the outputs, contract!
380  subs_contract = all_subs - set(out_subs)
381
382  lhs_reduce_axes = [lhs_subs.index(n) for n in lhs_uniques & subs_contract]
383  if lhs_reduce_axes:
384    x = lax._reduce_sum(x, lhs_reduce_axes)
385    for i in sorted(lhs_reduce_axes, reverse=True):
386      del lhs_subs[i]
387
388  rhs_reduce_axes = [rhs_subs.index(n) for n in rhs_uniques & subs_contract]
389  if rhs_reduce_axes:
390    y = lax._reduce_sum(y, rhs_reduce_axes)
391    for i in sorted(rhs_reduce_axes, reverse=True):
392      del rhs_subs[i]
393
394  pos_contract = unzip2((lhs_subs.index(n), rhs_subs.index(n))
395                        for n in subs_contract - (lhs_uniques | rhs_uniques))
396
397  # if a subscript apperas in both inputs _and_ the outputs, batch!
398  subs_batch = all_subs - subs_contract
399  if subs_batch & (lhs_uniques | rhs_uniques):
400    raise NotImplementedError
401
402  pos_batch = unzip2((lhs_subs.index(n), rhs_subs.index(n))
403                        for n in subs_batch)
404
405  return pdot(x, y, axis_name=named_contract,
406              pos_contract=pos_contract, pos_batch=pos_batch)
407
408class XeinsumSpecParser:
409  spec: str
410  pos: int
411
412  def __init__(self, spec: str):
413    self.spec = spec
414    self.pos = 0
415
416  @property
417  def eof(self):
418    return self.pos == len(self.spec)
419
420  @property
421  def cur(self):
422    return self.spec[self.pos]
423
424  def parse_subscript(self):
425    if self.cur in string.ascii_lowercase:
426      out = self.cur
427      self.pos += 1
428      return out, True
429    else:
430      return None, False
431
432  def parse_axis_name(self):
433    try:
434      end = self.spec.index('}', self.pos)
435    except ValueError:
436      assert False
437
438    try:
439      end = self.spec.index(',', self.pos, end)
440    except ValueError:
441      pass
442
443    axis_name = self.spec[self.pos:end]
444    assert axis_name
445    self.pos = end + 1
446    return axis_name, self.spec[end] == ','
447
448  def maybe_take(self, char: str, on_eof: bool = False):
449    if self.eof:
450      return on_eof
451    if self.cur == char:
452      self.pos += 1
453      return True
454
455  def parse_arg(self):
456    subscripts = []
457    names = []
458    while not self.eof:
459      subscript, cont = self.parse_subscript()
460      if not cont: break
461      subscripts.append(subscript)
462    if self.eof:
463      return False, (subscripts, names)
464    if self.maybe_take(','):
465      return True, (subscripts, names)
466    else:
467      assert self.maybe_take('{')
468      while True:
469        axis_name, cont = self.parse_axis_name()
470        names.append(axis_name)
471        if not cont: break
472      return self.maybe_take(',', False), (subscripts, names)
473
474  def parse_args(self):
475    arg_specs = []
476    cont = True
477    while not self.eof:
478      cont, result = self.parse_arg()
479      arg_specs.append(result)
480    if cont:
481      arg_specs.append(([], []))
482    return arg_specs
483
484
485### parallel primitives
486
487def _allreduce_soft_pmap_rule(prim, reducer, vals, mapped, chunk_size,
488                              *, axis_name, axis_index_groups):
489  if axis_index_groups is not None:
490    raise NotImplementedError("soft_pmap does not yet support axis_index_groups")
491  reduced_vals = [reducer(x, [0]) if m else x for x, m in zip(vals, mapped)]
492  outs = prim.bind(*reduced_vals, axis_name=axis_name,
493                   axis_index_groups=axis_index_groups)
494  return outs, (False,) * len(vals)
495
496# This is only used for collectives that do not include the vmapped axis name,
497# which is why the rule is so simple.
498def _collective_batcher(prim, args, dims, **params):
499  return prim.bind(*args, **params), dims if prim.multiple_results else dims[0]
500
501def _batched_reduction_collective(
502    prim, if_mapped, if_unmapped, frame, vals_in, dims_in, axis_name,
503    axis_index_groups):
504  assert prim.multiple_results
505  assert frame.name in axis_name
506  if axis_index_groups is not None:
507    raise NotImplementedError("axis_index_groups not supported in vmap collectives. "
508                              "Please open a feature request!")
509  vals_out = [if_mapped(v, d) if d is not batching.not_mapped
510              else if_unmapped(v, frame.size) for v, d in zip(vals_in, dims_in)]
511  if len(axis_name) > 1:
512    remaining_axis_names = tuple(n for n in axis_name if n != frame.name)
513    vals_out = prim.bind(*vals_out, axis_name=remaining_axis_names,
514                         axis_index_groups=None)
515  return vals_out, [batching.not_mapped] * len(vals_out)
516
517def _replica_groups(axis_env, axis_name, axis_index_groups):
518  replica_groups = xla.axis_groups(axis_env, axis_name)
519  if axis_index_groups is not None:
520    replica_groups = [[axis_group[i] for i in axis_index_group]
521                      for axis_group in replica_groups
522                      for axis_index_group in axis_index_groups]
523  return replica_groups
524
525def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
526                                axis_env, platform):
527  if platform in ("cpu", "tpu"):
528    return _notuple_allreduce_translation_rule(
529        prim, c, *args, axis_name=axis_name,
530        axis_index_groups=axis_index_groups, axis_env=axis_env,
531        platform=platform)
532
533  # XLA's tuple all-reduce doesn't support different dtypes in the same
534  # allreduce. Instead, we perform once all-reduce for each argument input type.
535  args_by_type = collections.defaultdict(lambda: ([], []))
536  for i, arg in enumerate(args):
537    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
538    indices.append(i)
539    dtype_args.append(arg)
540
541  # The outputs, in the original argument order.
542  out = [None] * len(args)
543  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
544  replica_groups_protos = xc.make_replica_groups(replica_groups)
545  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
546    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
547    n = len(dtype_args)
548    if is_complex and prim is lax.add_p:
549      # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
550      # special case because it's not currently handled by XLA:GPU
551      dtype_args = ([xops.Real(x) for x in dtype_args] +
552                    [xops.Imag(x) for x in dtype_args])
553    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
554    computation = xla.primitive_subcomputation(prim, scalar, scalar)
555    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
556                                replica_groups_protos, None, None)
557    if is_complex and prim is lax.add_p:
558      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
559                         xops.GetTupleElement(all_reduce, n + i)) for i in range(n)]
560    else:
561      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
562    for i, x in zip(indices, xs):
563      out[i] = x
564  return xops.Tuple(c, out)
565
566# TODO(b/155446630): An XLA:TPU optimization pass also doesn't support
567# tuple all-reduce yet. Meanwhile, rely on deterministic compiler behavior.
568def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env,
569                                        axis_index_groups, platform):
570  def all_reduce(x):
571    replica_groups_protos = xc.make_replica_groups(
572        _replica_groups(axis_env, axis_name, axis_index_groups))
573    scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
574    computation = xla.primitive_subcomputation(prim, scalar, scalar)
575    return xops.AllReduce(x, computation, replica_groups_protos, None, None)
576
577  if prim is not lax.add_p:
578    outs = [all_reduce(x) for x in args]
579  else:
580    # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
581    # special case because it's not currently handled by XLA:GPU
582    outs = [xops.Complex(all_reduce(xops.Real(x)), all_reduce(xops.Imag(x)))
583            if dtypes.issubdtype(c.get_shape(x).numpy_dtype(), np.complexfloating)
584            else all_reduce(x) for x in args]
585  return xops.Tuple(c, outs)
586
587def _psum_transpose_rule(cts, *args, axis_name, axis_index_groups):
588  nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
589  nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name,
590                               axis_index_groups=axis_index_groups)
591  return tree_util.tree_unflatten(treedef, nonzero_in_cts)
592
593psum_p = core.Primitive('psum')
594psum_p.multiple_results = True
595psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
596pxla.soft_pmap_rules[psum_p] = \
597    partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
598xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule, lax.add_p)  # type: ignore
599ad.deflinear2(psum_p, _psum_transpose_rule)
600pxla.multi_host_supported_collectives.add(psum_p)
601batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
602batching.collective_rules[psum_p] = \
603  partial(_batched_reduction_collective,
604          psum_p,
605          lambda v, d: v.sum(d),
606          lambda v, axis_size: axis_size * v)
607
608# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
609# tracing time.
610@psum_p.def_custom_bind
611def psum_bind(*args, axis_name, axis_index_groups):
612  if all(not isinstance(x, core.Tracer) for x in args):
613    if axis_index_groups is not None:
614      size = len(axis_index_groups[0])
615    elif isinstance(axis_name, (list, tuple)):
616      size = prod([core.axis_frame(name).size for name in axis_name])  # type: ignore
617    else:
618      size = core.axis_frame(axis_name).size  # type: ignore
619    return tuple(size * x for x in args)
620  return core.Primitive.bind(
621      psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)
622
623
624pmax_p = core.Primitive('pmax')
625pmax_p.multiple_results = True
626pmax_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
627xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule, lax.max_p)
628pxla.multi_host_supported_collectives.add(pmax_p)
629batching.primitive_batchers[pmax_p] = partial(_collective_batcher, pmax_p)
630batching.collective_rules[pmax_p] = \
631  partial(_batched_reduction_collective, pmax_p,
632          lambda v, d: v.max(d), lambda v, axis_size: v)
633
634
635pmin_p = core.Primitive('pmin')
636pmin_p.multiple_results = True
637pmin_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
638xla.parallel_translations[pmin_p] = partial(_allreduce_translation_rule, lax.min_p)
639pxla.multi_host_supported_collectives.add(pmin_p)
640batching.primitive_batchers[pmin_p] = partial(_collective_batcher, pmin_p)
641batching.collective_rules[pmin_p] = \
642  partial(_batched_reduction_collective, pmin_p,
643          lambda v, d: v.min(d), lambda v, axis_size: v)
644
645
646def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform):
647  replica_groups = _replica_groups(axis_env, axis_name, None)
648  group_size = len(replica_groups[0])
649  srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm)
650  if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))):
651    msg = "ppermute sources and destinations must be unique, got {}."
652    raise ValueError(msg.format(perm))
653
654  full_perm = []
655  for grp in replica_groups:
656    grp = list(sorted(grp))
657    full_perm.extend((grp[src], grp[dst]) for src, dst in perm)
658  return xops.CollectivePermute(x, full_perm)
659
660def _ppermute_transpose_rule(t, x, perm, axis_name):
661  srcs, dsts = unzip2(perm)
662  inverse_perm = list(zip(dsts, srcs))
663  return [ppermute(t, axis_name=axis_name, perm=inverse_perm)]
664
665def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm):
666  assert len(perm) == frame.size, "Permutation doesn't match the axis size!"
667  assert axis_name == frame.name, "ppermute batcher called with wrong axis name"
668  (v,), (d,) = vals_in, dims_in
669  assert d is not batching.not_mapped
670  perm_indices = [None] * frame.size
671  for src, dst in perm:
672    perm_indices[src] = dst
673  return lax_numpy.take(v, perm_indices, d), d
674
675ppermute_p = core.Primitive('ppermute')
676ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
677ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
678xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
679pxla.multi_host_supported_collectives.add(ppermute_p)
680batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
681batching.collective_rules[ppermute_p] = _ppermute_batcher
682
683
684def _moveaxis(src, dst, x):
685  perm = [i for i in range(x.ndim) if i != src]
686  perm.insert(dst, src)
687  return lax.transpose(x, perm)
688
689def _all_to_all_via_all_gather(x, *, axis_name, split_axis, concat_axis, axis_index_groups):
690  global_full = all_gather(x, axis_name, axis_index_groups=axis_index_groups)
691  idx = axis_index(axis_name)
692  if axis_index_groups:
693    idx = idx % len(axis_index_groups[0])
694  local_slice = lax.dynamic_index_in_dim(global_full, idx, split_axis + 1, keepdims=False)
695  return _moveaxis(0, concat_axis, local_slice)
696
697def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
698                                 axis_index_groups, axis_env, platform):
699  # Workaround for AllToAll not being implemented on CPU.
700  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
701  if len(replica_groups[0]) == 1:
702    return x
703  elif platform != 'tpu':
704    warnings.warn("all_to_all (and pswapaxes) are only implemented properly for TPUs. All other "
705                  "backends emulate it using a very slow and memory intensive algorithm, so expect "
706                  "significant slowdowns.")
707    lowering = xla.lower_fun(_all_to_all_via_all_gather, multiple_results=False, parallel=True)
708    return lowering(c, x,
709                    split_axis=split_axis, concat_axis=concat_axis, axis_name=axis_name,
710                    axis_index_groups=axis_index_groups, axis_env=axis_env, platform=platform)
711  else:
712    split_count = len(replica_groups[0])
713    if not all(split_count == len(g) for g in replica_groups):
714      raise ValueError('Replica groups must be equally sized')
715    replica_groups_protos = xc.make_replica_groups(replica_groups)
716    if concat_axis == split_axis:
717      return xops.AllToAll(x, split_axis, concat_axis, split_count,
718                           replica_groups_protos)
719    else:
720      if concat_axis < split_axis:
721        split_axis += 1
722      elif split_axis < concat_axis:
723        concat_axis += 1
724      x = xla.lower_fun(partial(lax.expand_dims, dimensions=(concat_axis,)), multiple_results=False)(c, x)
725      x = xops.AllToAll(x, split_axis, concat_axis, split_count, replica_groups_protos)
726      x = xla.lower_fun(partial(lax.squeeze, dimensions=(split_axis,)), multiple_results=False)(c, x)
727      return x
728
729def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups):
730  return (all_to_all(
731      cts,
732      axis_name=axis_name,
733      split_axis=concat_axis,
734      concat_axis=split_axis,
735      axis_index_groups=axis_index_groups),)
736
737
738def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, axis_index_groups):
739  x, = vals_in
740  d, = dims_in
741  if d <= split_axis:
742    split_axis += 1
743  if d <= concat_axis:
744    concat_axis += 1
745  # Note: At this point split_axis and concat_axis are adjusted to the extra
746  #       dimension and we have d != split_axis and d != concat_axis.
747  if split_axis < d < concat_axis:
748    d -= 1
749  elif concat_axis < d < split_axis:
750    d += 1
751  result = all_to_all_p.bind(
752      x,
753      axis_name=axis_name,
754      split_axis=split_axis,
755      concat_axis=concat_axis,
756      axis_index_groups=axis_index_groups)
757  return result, d
758
759def _all_to_all_batched_collective(frame, vals_in, dims_in,
760                                   axis_name, split_axis, concat_axis,
761                                   axis_index_groups):
762  if isinstance(axis_name, (list, tuple)) and len(axis_name) > 1:
763    raise NotImplementedError("update after #4835")  # TODO(mattjj,apaszke)
764  x, = vals_in
765  d, = dims_in
766  split_axis_adj = split_axis + (1 if d <= split_axis else 0)
767  concat_axis_adj = concat_axis + (1 if split_axis_adj <= concat_axis else 0)
768  if d < split_axis_adj < concat_axis_adj:
769    split_axis_adj -= 1
770  elif concat_axis_adj < split_axis_adj < d:
771    split_axis_adj += 1
772  return _moveaxis(d, concat_axis_adj, x), split_axis_adj
773
774def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_groups):
775  input_aval = raise_to_shaped(x)
776  shape = list(input_aval.shape)
777  size = shape.pop(split_axis)
778  shape.insert(concat_axis, size)
779  return ShapedArray(tuple(shape), input_aval.dtype, weak_type=False)
780
781all_to_all_p = core.Primitive('all_to_all')
782all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
783xla.parallel_translations[all_to_all_p] = _all_to_all_translation_rule
784ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
785pxla.multi_host_supported_collectives.add(all_to_all_p)
786batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
787batching.collective_rules[all_to_all_p] = _all_to_all_batched_collective
788
789
790def _expand(dim, size, index, x):
791  shape = list(x.shape)
792  shape.insert(dim, size)
793  out = lax.full(shape, lax._const(x, 0))
794  return lax.dynamic_update_index_in_dim(out, x, index, dim)
795
796def all_gather(x, axis_name, *, axis_index_groups=None):
797  """Gather values of x across all replicas.
798
799  If ``x`` is a pytree then the result is equivalent to mapping this function to
800  each leaf in the tree.
801
802  This is equivalent to, but faster than, all_to_all(broadcast(x)).
803
804  Args:
805    x: array(s) with a mapped axis named ``axis_name``.
806    axis_name: hashable Python object used to name a pmapped axis (see the
807      :func:`jax.pmap` documentation for more details).
808    axis_index_groups: optional list of lists containing axis indices (e.g. for
809      an axis of size 4, [[0, 1], [2, 3]] would run all gather over the first
810      two and last two replicas). Groups must cover all axis indices exactly
811      once, and all groups must be the same size.
812
813  Returns:
814    Array(s) representing the result of an all-gather along the axis
815    ``axis_name``. Shapes are the same as ``x.shape``, but with a leading
816    dimension of the axis_size.
817
818  For example, with 4 XLA devices available:
819
820  >>> x = np.arange(4)
821  >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
822  >>> print(y)
823  [[0 1 2 3]
824   [0 1 2 3]
825   [0 1 2 3]
826   [0 1 2 3]]
827
828  An example of using axis_index_groups, groups split by even & odd device ids:
829
830  >>> x = np.arange(16).reshape(4, 4)
831  >>> print(x)
832  [[ 0.  1.  2.  3.]
833   [ 4.  5.  6.  7.]
834   [ 8.  9. 10. 11.]
835   [12. 13. 14. 15.]]
836  >>> y = jax.pmap(lambda x: jax.lax.all_gather(
837  ... x, 'i', axis_index_groups=[[0, 2], [3, 1]]))(x)
838  >>> print(y)
839  [[[ 0.  1.  2.  3.]
840    [ 8.  9. 10. 11.]]
841   [[12. 13. 14. 15.]
842    [ 4.  5.  6.  7.]]
843   [[ 0.  1.  2.  3.]
844    [ 8.  9. 10. 11.]]
845   [[12. 13. 14. 15.]
846    [ 4.  5.  6.  7.]]
847  """
848  axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
849  # The all_gather primitive doesn't work when omni-staging is disabled.
850  if not config.omnistaging_enabled:
851    return _all_gather_via_psum(x, all_gather_dimension=0, axis_name=axis_name,
852                                axis_index_groups=axis_index_groups, axis_size=axis_size)
853
854  def bind(x):
855    return all_gather_p.bind(x, all_gather_dimension=0, axis_name=axis_name,
856                             axis_index_groups=axis_index_groups, axis_size=axis_size)
857
858  return tree_util.tree_map(bind, x)
859
860def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
861  index = axis_index(axis_name)
862  if axis_index_groups is not None:
863    indices = np.array(axis_index_groups).flatten()
864    axis_index_to_group_index = indices.argsort() % len(axis_index_groups[0])
865    index = lax_numpy.array(axis_index_to_group_index)[index]
866  outs = tree_util.tree_map(partial(_expand, all_gather_dimension, axis_size, index), x)
867  return psum(outs, axis_name, axis_index_groups=axis_index_groups)
868
869def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
870  # Only called when the argument is not mapped.
871  out_shape = list(np.shape(x))
872  out_shape.insert(all_gather_dimension, axis_size)
873  broadcast_dims = [i for i in range(len(out_shape)) if i != all_gather_dimension]
874  return lax.broadcast_in_dim(x, out_shape, broadcast_dims)
875
876def _all_gather_translation_rule(c, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, axis_env, platform):
877  # TODO(cjfj): Enable this for TPU also?
878  if (platform == 'gpu') and (all_gather_dimension == 0):
879    new_shape = list(c.get_shape(x).dimensions())
880    new_shape.insert(all_gather_dimension, 1)
881    broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
882    x = xops.BroadcastInDim(x, new_shape, broadcast_dimensions)
883    replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
884    return xops.AllGather(x, all_gather_dimension=all_gather_dimension, shard_count=axis_size,
885                          replica_groups=xc.make_replica_groups(replica_groups))
886  else:
887    lowering = xla.lower_fun(_all_gather_via_psum, multiple_results=False, parallel=True)
888    return lowering(c, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name,
889                    axis_index_groups=axis_index_groups, axis_size=axis_size, axis_env=axis_env, platform=platform)
890
891def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
892  x_aval = raise_to_shaped(x)
893  new_shape = list(x_aval.shape)
894  new_shape.insert(all_gather_dimension, axis_size)
895  return ShapedArray(new_shape, x_aval.dtype)
896
897def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
898  # TODO(cjfj): Add reduce-scatter op to XLA?
899  concat_axis = 0
900  return (lax_numpy.sum(
901      all_to_all(
902          cts,
903          axis_name=axis_name,
904          split_axis=all_gather_dimension,
905          concat_axis=concat_axis,
906          axis_index_groups=axis_index_groups),
907      axis=concat_axis),)
908
909def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
910  (x,), (d,) = vals_in, dims_in
911  if d <= all_gather_dimension:
912    all_gather_dimension += 1
913  else:
914    d += 1
915  result = all_gather_p.bind(
916      x,
917      all_gather_dimension=all_gather_dimension,
918      axis_name=axis_name,
919      axis_index_groups=axis_index_groups,
920      axis_size=axis_size)
921  return result, d
922
923def _all_gather_batched_collective(frame, vals_in, dims_in, all_gather_dimension, axis_name, axis_index_groups, axis_size):
924  assert axis_index_groups is None, "axis_index_groups not supported in vmap"
925  assert axis_size == frame.size, "axis size doesn't match"
926  assert axis_name == frame.name, "batcher called with wrong axis name"
927  (x,), (d,) = vals_in, dims_in
928  assert d is not batching.not_mapped
929  return _moveaxis(d, all_gather_dimension, x), batching.not_mapped
930
931all_gather_p = core.Primitive('all_gather')
932all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
933all_gather_p.def_impl(_all_gather_impl)
934xla.parallel_translations[all_gather_p] = _all_gather_translation_rule
935ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
936pxla.multi_host_supported_collectives.add(all_gather_p)
937batching.primitive_batchers[all_gather_p] = _all_gather_batcher
938batching.collective_rules[all_gather_p] = _all_gather_batched_collective
939
940def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
941  axis_pos = list(axis_env.names).index(axis_name)
942  nreplicas = axis_env.nreps // prod(axis_env.sizes)
943  div = xb.constant(c, np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
944                                dtype=np.uint32))
945  mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
946  unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
947  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
948
949def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
950  assert not vals and not mapped
951  idx = axis_index(axis_name)  # type: ignore
952  return idx * chunk_size + np.arange(chunk_size, dtype=np.int32), True
953
954axis_index_p = core.Primitive('axis_index')
955xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
956pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule  # type: ignore
957axis_index_p.def_abstract_eval(
958    lambda *args, **params: ShapedArray((), np.int32))
959pxla.multi_host_supported_collectives.add(axis_index_p)
960
961# Axis index doesn't get any arguments, so that the default bind would have no
962# way to call into a data-dependency based trace such as vmap. Each trace that
963# wants to bind an axis name has to additionally implement `process_axis_index`
964# and put its main trace on the axis env stack.
965def _axis_index_bind(*, axis_name):
966  if not isinstance(axis_name, (tuple, list)):
967    axis_name = (axis_name,)
968  inner_size = 1
969  index = 0
970  for name in reversed(axis_name):
971    frame = core.axis_frame(name)
972    if frame.main_trace is not None:
973      trace = frame.main_trace.with_cur_sublevel()
974      name_idx = trace.process_axis_index(frame)
975    else:
976      name_idx = core.Primitive.bind(axis_index_p, axis_name=name)
977    index += name_idx * inner_size
978    inner_size *= psum(1, name)
979  return index
980axis_index_p.def_custom_bind(_axis_index_bind)
981
982def _process_axis_index(self, frame):
983  return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
984batching.BatchTrace.process_axis_index = _process_axis_index  # type: ignore
985
986
987pdot_p = core.Primitive('pdot')
988
989@pdot_p.def_impl
990def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch):
991  if axis_name: raise NameError(f"unbound axis name: {axis_name[0]}")
992  return lax.dot_general(x, y, [pos_contract, pos_batch])
993
994@pdot_p.def_abstract_eval
995def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch):
996  # TODO: avals with names, check inputs are mapped along axis_name, eliminate
997  if not len(set(axis_name)) == len(axis_name): raise ValueError
998  return lax.dot_general_p.abstract_eval(
999      x, y, dimension_numbers=[pos_contract, pos_batch],
1000      precision=None, preferred_element_type=None)
1001
1002def _pdot_vmap_collective_rule(frame, vals_in, dims_in, *, axis_name,
1003                               pos_contract, pos_batch):
1004  x, y = vals_in
1005  x_dim, y_dim = dims_in
1006  x_pos_contract, y_pos_contract = pos_contract
1007  x_pos_contract = [x_dim] + [d + (d >= x_dim) for d in x_pos_contract]
1008  y_pos_contract = [y_dim] + [d + (d >= y_dim) for d in y_pos_contract]
1009  x_pos_batch, y_pos_batch = pos_batch
1010  x_pos_batch = [d + (d >= x_dim) for d in x_pos_batch]
1011  y_pos_batch = [d + (d >= y_dim) for d in y_pos_batch]
1012  remaining_axis_names = tuple(n for n in axis_name if n != frame.name)
1013  out = pdot_p.bind(x, y, axis_name=remaining_axis_names,
1014                    pos_contract=[x_pos_contract, y_pos_contract],
1015                    pos_batch=[x_pos_batch, y_pos_batch])
1016  return out, None
1017batching.collective_rules[pdot_p] = _pdot_vmap_collective_rule
1018
1019def _pdot_vmap_batching_rule(vals_in, dims_in, *, axis_name, pos_contract,
1020                             pos_batch):
1021  x, y = vals_in
1022  (pos_contract, pos_batch), result_batch_dim = lax._dot_general_batch_dim_nums(
1023      (x.ndim, y.ndim), dims_in, [pos_contract, pos_batch])
1024  out = pdot_p.bind(x, y, axis_name=axis_name, pos_contract=pos_contract,
1025                    pos_batch=pos_batch)
1026  return out, result_batch_dim
1027batching.primitive_batchers[pdot_p] = _pdot_vmap_batching_rule
1028
1029def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
1030                           axis_env, platform):
1031  local_out = lax._dot_general_translation_rule(
1032      c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=None,
1033      preferred_element_type=None)
1034  if axis_name:
1035    out_tup = xla.parallel_translations[psum_p](
1036        c, local_out, axis_name=axis_name, axis_index_groups=None,
1037        axis_env=axis_env, platform=platform)
1038    out, = xla.xla_destructure(c, out_tup)
1039  else:
1040    out = local_out
1041  return out
1042xla.parallel_translations[pdot_p] = _pdot_translation_rule
1043
1044def _pdot_transpose_lhs(g, y, *, axis_name, pos_contract, pos_batch):
1045  # TODO: avals with names, call pbroadcast with axis_name
1046  return lax._dot_general_transpose_lhs(
1047      g, y, dimension_numbers=[pos_contract, pos_batch], precision=None,
1048      preferred_element_type=None)
1049def _pdot_transpose_rhs(g, x, *, axis_name, pos_contract, pos_batch):
1050  # TODO: avals with names, call pbroadcast with axis_name
1051  return lax._dot_general_transpose_rhs(
1052      g, x, dimension_numbers=[pos_contract, pos_batch], precision=None,
1053      preferred_element_type=None)
1054ad.defbilinear(pdot_p, _pdot_transpose_lhs, _pdot_transpose_rhs)
1055
1056pxla.multi_host_supported_collectives.add(pdot_p)
1057
1058
1059@config.register_omnistaging_disabler
1060def omnistaging_disabler() -> None:
1061  global axis_index
1062
1063  psum_p.bind = partial(core.Primitive.bind, psum_p)  # type: ignore
1064  psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))  # type: ignore
1065  pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)  # type: ignore
1066
1067  def _axis_index_bind(*, axis_name):
1068    dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
1069    frame = dynamic_axis_env[axis_name]
1070    sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
1071    nreps = dynamic_axis_env.nreps
1072    trace = frame.pmap_trace
1073
1074    out_aval = ShapedArray((), np.int32)
1075    out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
1076    eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
1077                            dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
1078                            source_info_util.current())
1079    out_tracer.recipe = eqn
1080
1081    return out_tracer
1082
1083  def _axis_index_translation_rule(c, nreps, sizes, axis_name):
1084    div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
1085    mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
1086    unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
1087    return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
1088
1089  axis_index_p.def_custom_bind(_axis_index_bind)
1090  axis_index_p.def_abstract_eval(
1091      lambda *args, **params: ShapedArray((), np.int32))
1092  xla.translations[axis_index_p] = _axis_index_translation_rule
1093