1# Copyright 2019 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14""" 15Parallelization primitives. 16""" 17 18import collections 19import string 20import warnings 21 22import numpy as np 23 24from jax import core 25from jax import dtypes 26from jax import tree_util 27from jax._src import source_info_util 28from . import lax 29from jax.core import ShapedArray, raise_to_shaped 30from jax.interpreters import ad 31from jax.interpreters import xla 32from jax.interpreters import pxla 33from jax.interpreters import batching 34from jax.interpreters import partial_eval as pe 35from jax._src.util import partial, unzip2, prod 36from jax.lib import xla_client as xc 37from jax.lib import xla_bridge as xb 38from jax.config import config 39from jax._src.numpy import lax_numpy 40 41xops = xc.ops 42 43 44### parallel traceables 45 46def psum(x, axis_name, *, axis_index_groups=None): 47 """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``. 48 49 If ``x`` is a pytree then the result is equivalent to mapping this function to 50 each leaf in the tree. 51 52 Inputs of boolean dtype are converted to integers before the reduction. 53 54 Args: 55 x: array(s) with a mapped axis named ``axis_name``. 56 axis_name: hashable Python object used to name a pmapped axis (see the 57 :func:`jax.pmap` documentation for more details). 58 axis_index_groups: optional list of lists containing axis indices (e.g. for 59 an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first 60 two and last two replicas). Groups must cover all axis indices exactly 61 once, and all groups must be the same size. 62 63 64 Returns: 65 Array(s) with the same shape as ``x`` representing the result of an 66 all-reduce sum along the axis ``axis_name``. 67 68 For example, with 4 XLA devices available: 69 70 >>> x = np.arange(4) 71 >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x) 72 >>> print(y) 73 [6 6 6 6] 74 >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x) 75 >>> print(y) 76 [ 0. 0.16666667 0.33333334 0.5 ] 77 """ 78 if not isinstance(axis_name, (tuple, list)): 79 axis_name = (axis_name,) 80 _validate_axis_index_groups(axis_index_groups) 81 leaves, treedef = tree_util.tree_flatten(x) 82 leaves = [lax.convert_element_type(l, np.int32) 83 if dtypes.dtype(l) == np.bool_ else l for l in leaves] 84 out_flat = psum_p.bind(*leaves, axis_name=axis_name, 85 axis_index_groups=axis_index_groups) 86 return tree_util.tree_unflatten(treedef, out_flat) 87 88def pmean(x, axis_name, *, axis_index_groups=None): 89 """Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``. 90 91 If ``x`` is a pytree then the result is equivalent to mapping this function to 92 each leaf in the tree. 93 94 Args: 95 x: array(s) with a mapped axis named ``axis_name``. 96 axis_name: hashable Python object used to name a pmapped axis (see the 97 :func:`jax.pmap` documentation for more details). 98 axis_index_groups: optional list of lists containing axis indices (e.g. for 99 an axis of size 4, [[0, 1], [2, 3]] would perform pmeans over the first 100 two and last two replicas). Groups must cover all axis indices exactly 101 once, and all groups must be the same size. 102 103 Returns: 104 Array(s) with the same shape as ``x`` representing the result of an 105 all-reduce mean along the axis ``axis_name``. 106 107 For example, with 4 XLA devices available: 108 109 >>> x = np.arange(4) 110 >>> y = jax.pmap(lambda x: jax.lax.pmean(x, 'i'), axis_name='i')(x) 111 >>> print(y) 112 [ 1.5 1.5 1.5 1.5 ] 113 >>> y = jax.pmap(lambda x: x / jax.lax.pmean(x, 'i'), axis_name='i')(x) 114 >>> print(y) 115 [ 0. 0.66666667 1.33333334 2.0 ] 116 """ 117 x = psum(x, axis_name=axis_name, axis_index_groups=axis_index_groups) 118 n = psum(1, axis_name=axis_name, axis_index_groups=axis_index_groups) 119 return tree_util.tree_map(lambda v: v / n, x) 120 121def pmax(x, axis_name, *, axis_index_groups=None): 122 """Compute an all-reduce max on ``x`` over the pmapped axis ``axis_name``. 123 124 If ``x`` is a pytree then the result is equivalent to mapping this function to 125 each leaf in the tree. 126 127 Args: 128 x: array(s) with a mapped axis named ``axis_name``. 129 axis_name: hashable Python object used to name a pmapped axis (see the 130 :func:`jax.pmap` documentation for more details). 131 axis_index_groups: optional list of lists containing axis indices (e.g. for 132 an axis of size 4, [[0, 1], [2, 3]] would perform pmaxes over the first 133 two and last two replicas). Groups must cover all axis indices exactly 134 once, and all groups must be the same size. 135 136 Returns: 137 Array(s) with the same shape as ``x`` representing the result of an 138 all-reduce max along the axis ``axis_name``. 139 """ 140 if not isinstance(axis_name, (tuple, list)): 141 axis_name = (axis_name,) 142 _validate_axis_index_groups(axis_index_groups) 143 leaves, treedef = tree_util.tree_flatten(x) 144 out_flat = pmax_p.bind(*leaves, axis_name=axis_name, 145 axis_index_groups=axis_index_groups) 146 return tree_util.tree_unflatten(treedef, out_flat) 147 148def pmin(x, axis_name, *, axis_index_groups=None): 149 """Compute an all-reduce min on ``x`` over the pmapped axis ``axis_name``. 150 151 If ``x`` is a pytree then the result is equivalent to mapping this function to 152 each leaf in the tree. 153 154 Args: 155 x: array(s) with a mapped axis named ``axis_name``. 156 axis_name: hashable Python object used to name a pmapped axis (see the 157 :func:`jax.pmap` documentation for more details). 158 axis_index_groups: optional list of lists containing axis indices (e.g. for 159 an axis of size 4, [[0, 1], [2, 3]] would perform pmins over the first 160 two and last two replicas). Groups must cover all axis indices exactly 161 once, and all groups must be the same size. 162 163 Returns: 164 Array(s) with the same shape as ``x`` representing the result of an 165 all-reduce min along the axis ``axis_name``. 166 """ 167 if not isinstance(axis_name, (tuple, list)): 168 axis_name = (axis_name,) 169 _validate_axis_index_groups(axis_index_groups) 170 leaves, treedef = tree_util.tree_flatten(x) 171 out_flat = pmin_p.bind(*leaves, axis_name=axis_name, 172 axis_index_groups=axis_index_groups) 173 return tree_util.tree_unflatten(treedef, out_flat) 174 175def _validate_axis_index_groups(axis_index_groups): 176 if axis_index_groups is None: 177 return 178 len_0 = len(axis_index_groups[0]) 179 if any(len(g) != len_0 for g in axis_index_groups): 180 raise ValueError("axis_index_groups must all be the same size") 181 axis_space = range(len_0 * len(axis_index_groups)) 182 if {i for g in axis_index_groups for i in g} != set(axis_space): 183 raise ValueError("axis_index_groups must cover all indices exactly once") 184 185def ppermute(x, axis_name, perm): 186 """Perform a collective permutation according to the permutation ``perm``. 187 188 If ``x`` is a pytree then the result is equivalent to mapping this function to 189 each leaf in the tree. 190 191 This function is an analog of the CollectivePermute XLA HLO. 192 193 Args: 194 x: array(s) with a mapped axis named ``axis_name``. 195 axis_name: hashable Python object used to name a pmapped axis (see the 196 :func:`jax.pmap` documentation for more details). 197 perm: list of pairs of ints, representing 198 ``(source_index, destination_index)`` 199 pairs that encode how the mapped axis named ``axis_name`` should be 200 shuffled. The integer values are treated as indices into the mapped axis 201 ``axis_name``. Any two pairs should not have the same source index or the 202 same destination index. For each index of the axis ``axis_name`` that does 203 not correspond to a destination index in ``perm``, the corresponding 204 values in the result are filled with zeros of the appropriate type. 205 206 Returns: 207 Array(s) with the same shape as ``x`` with slices along the axis 208 ``axis_name`` gathered from ``x`` according to the permutation ``perm``. 209 """ 210 return tree_util.tree_map( 211 partial(ppermute_p.bind, axis_name=axis_name, perm=tuple(perm)), x) 212 213def pshuffle(x, axis_name, perm): 214 """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding 215 216 If ``x`` is a pytree then the result is equivalent to mapping this function to 217 each leaf in the tree. 218 219 Args: 220 x: array(s) with a mapped axis named ``axis_name``. 221 axis_name: hashable Python object used to name a pmapped axis (see the 222 :func:`jax.pmap` documentation for more details). 223 perm: list of of ints encoding sources for the permutation to be applied to 224 the axis named ``axis_name``, so that the output at axis index i 225 comes from the input at axis index perm[i]. Every integer in [0, N) should 226 be included exactly once for axis size N. 227 228 Returns: 229 Array(s) with the same shape as ``x`` with slices along the axis 230 ``axis_name`` gathered from ``x`` according to the permutation ``perm``. 231 """ 232 if set(perm) != set(range(len(perm))): 233 raise ValueError(f"`perm` does not represent a permutation: {perm}") 234 return ppermute(x, axis_name, list(zip(perm, range(len(perm))))) 235 236 237def pswapaxes(x, axis_name, axis, *, axis_index_groups=None): 238 """Swap the pmapped axis ``axis_name`` with the unmapped axis ``axis``. 239 240 If ``x`` is a pytree then the result is equivalent to mapping this function to 241 each leaf in the tree. 242 243 The group size of the mapped axis size must be equal to the size of the 244 unmapped axis; that is, we must have 245 ``lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]``. 246 By default, when ``axis_index_groups=None``, this encompasses all the devices. 247 248 This function is a special case of ``all_to_all`` where the pmapped axis of 249 the input is placed at the position ``axis`` in the output. That is, it is 250 equivalent to ``all_to_all(x, axis_name, axis, axis)``. 251 252 Args: 253 x: array(s) with a mapped axis named ``axis_name``. 254 axis_name: hashable Python object used to name a pmapped axis (see the 255 :func:`jax.pmap` documentation for more details). 256 axis: int indicating the unmapped axis of ``x`` to map with the name 257 ``axis_name``. 258 axis_index_groups: optional list of lists containing axis indices (e.g. for 259 an axis of size 4, [[0, 1], [2, 3]] would run pswapaxes over the first 260 two and last two replicas). Groups must cover all axis indices exactly 261 once, and all groups must be the same size. 262 263 Returns: 264 Array(s) with the same shape as ``x``. 265 """ 266 return all_to_all(x, axis_name, axis, axis, axis_index_groups=axis_index_groups) 267 268def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None): 269 """Materialize the mapped axis and map a different axis. 270 271 If ``x`` is a pytree then the result is equivalent to mapping this function to 272 each leaf in the tree. 273 274 In the output, the input mapped axis ``axis_name`` is materialized at the 275 logical axis position ``concat_axis``, and the input unmapped axis at position 276 ``split_axis`` is mapped with the name ``axis_name``. 277 278 The group size of the mapped axis size must be equal to the size of the 279 unmapped axis; that is, we must have 280 ``lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]``. 281 By default, when ``axis_index_groups=None``, this encompasses all the devices. 282 283 Args: 284 x: array(s) with a mapped axis named ``axis_name``. 285 axis_name: hashable Python object used to name a pmapped axis (see the 286 :func:`jax.pmap` documentation for more details). 287 split_axis: int indicating the unmapped axis of ``x`` to map with the name 288 ``axis_name``. 289 concat_axis: int indicating the position in the output to materialize the 290 mapped axis of the input with the name ``axis_name``. 291 axis_index_groups: optional list of lists containing axis indices (e.g. for 292 an axis of size 4, [[0, 1], [2, 3]] would run all_to_all over the first 293 two and last two replicas). Groups must cover all axis indices exactly 294 once, and all groups must be the same size. 295 296 Returns: 297 Array(s) with shape given by the expression:: 298 299 np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size) 300 301 where ``axis_size`` is the size of the mapped axis named ``axis_name`` in 302 the input ``x``, i.e. ``axis_size = lax.psum(1, axis_name)``. 303 """ 304 def bind(x): 305 group_size = psum(1, axis_name, axis_index_groups=axis_index_groups) 306 if group_size != x.shape[split_axis]: 307 msg = ("all_to_all requires the size of the mapped axis axis_name to " 308 "equal x.shape[split_axis], but they are {} and {} respectively.") 309 raise ValueError(msg.format(group_size, x.shape[split_axis])) 310 return all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis, 311 axis_name=axis_name, 312 axis_index_groups=axis_index_groups) 313 314 return tree_util.tree_map(bind, x) 315 316def axis_index(axis_name): 317 """Return the index along the mapped axis ``axis_name``. 318 319 Args: 320 axis_name: hashable Python object used to name the mapped axis. 321 322 Returns: 323 An integer representing the index. 324 325 For example, with 8 XLA devices available: 326 327 >>> from functools import partial 328 >>> @partial(jax.pmap, axis_name='i') 329 ... def f(_): 330 ... return lax.axis_index('i') 331 ... 332 >>> f(np.zeros(4)) 333 ShardedDeviceArray([0, 1, 2, 3], dtype=int32) 334 >>> f(np.zeros(8)) 335 ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32) 336 >>> @partial(jax.pmap, axis_name='i') 337 ... @partial(jax.pmap, axis_name='j') 338 ... def f(_): 339 ... return lax.axis_index('i'), lax.axis_index('j') 340 ... 341 >>> x, y = f(np.zeros((4, 2))) 342 >>> print(x) 343 [[0 0] 344 [1 1] 345 [2 2] 346 [3 3]] 347 >>> print(y) 348 [[0 1] 349 [0 1] 350 [0 1] 351 [0 1]] 352 """ 353 return axis_index_p.bind(axis_name=axis_name) 354 355 356def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ())): 357 if not isinstance(axis_name, (list, tuple)): 358 axis_name = (axis_name,) 359 return pdot_p.bind(x, y, axis_name=axis_name, 360 pos_contract=pos_contract, pos_batch=pos_batch) 361 362 363def xeinsum(spec: str, x, y): 364 in_spec, out_spec = spec.split('->') 365 (lhs_subs, lhs_named), (rhs_subs, rhs_named) = XeinsumSpecParser(in_spec).parse_args() 366 (out_subs, out_named), = XeinsumSpecParser(out_spec).parse_args() 367 all_named = {*lhs_named, *rhs_named, *out_named} 368 all_subs = {*lhs_subs, *rhs_subs, *out_subs} 369 lhs_uniques = set(lhs_subs) - set(rhs_subs) 370 rhs_uniques = set(rhs_subs) - set(lhs_subs) 371 if all_subs & all_named: 372 raise NotImplementedError 373 if not set(out_named).issubset({*lhs_named, *rhs_named}): 374 raise ValueError 375 376 # if a named axis appears in both inputs and not the output, contract! 377 named_contract = list(all_named - set(out_named)) 378 379 # if a subscript appears in both inputs and not the outputs, contract! 380 subs_contract = all_subs - set(out_subs) 381 382 lhs_reduce_axes = [lhs_subs.index(n) for n in lhs_uniques & subs_contract] 383 if lhs_reduce_axes: 384 x = lax._reduce_sum(x, lhs_reduce_axes) 385 for i in sorted(lhs_reduce_axes, reverse=True): 386 del lhs_subs[i] 387 388 rhs_reduce_axes = [rhs_subs.index(n) for n in rhs_uniques & subs_contract] 389 if rhs_reduce_axes: 390 y = lax._reduce_sum(y, rhs_reduce_axes) 391 for i in sorted(rhs_reduce_axes, reverse=True): 392 del rhs_subs[i] 393 394 pos_contract = unzip2((lhs_subs.index(n), rhs_subs.index(n)) 395 for n in subs_contract - (lhs_uniques | rhs_uniques)) 396 397 # if a subscript apperas in both inputs _and_ the outputs, batch! 398 subs_batch = all_subs - subs_contract 399 if subs_batch & (lhs_uniques | rhs_uniques): 400 raise NotImplementedError 401 402 pos_batch = unzip2((lhs_subs.index(n), rhs_subs.index(n)) 403 for n in subs_batch) 404 405 return pdot(x, y, axis_name=named_contract, 406 pos_contract=pos_contract, pos_batch=pos_batch) 407 408class XeinsumSpecParser: 409 spec: str 410 pos: int 411 412 def __init__(self, spec: str): 413 self.spec = spec 414 self.pos = 0 415 416 @property 417 def eof(self): 418 return self.pos == len(self.spec) 419 420 @property 421 def cur(self): 422 return self.spec[self.pos] 423 424 def parse_subscript(self): 425 if self.cur in string.ascii_lowercase: 426 out = self.cur 427 self.pos += 1 428 return out, True 429 else: 430 return None, False 431 432 def parse_axis_name(self): 433 try: 434 end = self.spec.index('}', self.pos) 435 except ValueError: 436 assert False 437 438 try: 439 end = self.spec.index(',', self.pos, end) 440 except ValueError: 441 pass 442 443 axis_name = self.spec[self.pos:end] 444 assert axis_name 445 self.pos = end + 1 446 return axis_name, self.spec[end] == ',' 447 448 def maybe_take(self, char: str, on_eof: bool = False): 449 if self.eof: 450 return on_eof 451 if self.cur == char: 452 self.pos += 1 453 return True 454 455 def parse_arg(self): 456 subscripts = [] 457 names = [] 458 while not self.eof: 459 subscript, cont = self.parse_subscript() 460 if not cont: break 461 subscripts.append(subscript) 462 if self.eof: 463 return False, (subscripts, names) 464 if self.maybe_take(','): 465 return True, (subscripts, names) 466 else: 467 assert self.maybe_take('{') 468 while True: 469 axis_name, cont = self.parse_axis_name() 470 names.append(axis_name) 471 if not cont: break 472 return self.maybe_take(',', False), (subscripts, names) 473 474 def parse_args(self): 475 arg_specs = [] 476 cont = True 477 while not self.eof: 478 cont, result = self.parse_arg() 479 arg_specs.append(result) 480 if cont: 481 arg_specs.append(([], [])) 482 return arg_specs 483 484 485### parallel primitives 486 487def _allreduce_soft_pmap_rule(prim, reducer, vals, mapped, chunk_size, 488 *, axis_name, axis_index_groups): 489 if axis_index_groups is not None: 490 raise NotImplementedError("soft_pmap does not yet support axis_index_groups") 491 reduced_vals = [reducer(x, [0]) if m else x for x, m in zip(vals, mapped)] 492 outs = prim.bind(*reduced_vals, axis_name=axis_name, 493 axis_index_groups=axis_index_groups) 494 return outs, (False,) * len(vals) 495 496# This is only used for collectives that do not include the vmapped axis name, 497# which is why the rule is so simple. 498def _collective_batcher(prim, args, dims, **params): 499 return prim.bind(*args, **params), dims if prim.multiple_results else dims[0] 500 501def _batched_reduction_collective( 502 prim, if_mapped, if_unmapped, frame, vals_in, dims_in, axis_name, 503 axis_index_groups): 504 assert prim.multiple_results 505 assert frame.name in axis_name 506 if axis_index_groups is not None: 507 raise NotImplementedError("axis_index_groups not supported in vmap collectives. " 508 "Please open a feature request!") 509 vals_out = [if_mapped(v, d) if d is not batching.not_mapped 510 else if_unmapped(v, frame.size) for v, d in zip(vals_in, dims_in)] 511 if len(axis_name) > 1: 512 remaining_axis_names = tuple(n for n in axis_name if n != frame.name) 513 vals_out = prim.bind(*vals_out, axis_name=remaining_axis_names, 514 axis_index_groups=None) 515 return vals_out, [batching.not_mapped] * len(vals_out) 516 517def _replica_groups(axis_env, axis_name, axis_index_groups): 518 replica_groups = xla.axis_groups(axis_env, axis_name) 519 if axis_index_groups is not None: 520 replica_groups = [[axis_group[i] for i in axis_index_group] 521 for axis_group in replica_groups 522 for axis_index_group in axis_index_groups] 523 return replica_groups 524 525def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups, 526 axis_env, platform): 527 if platform in ("cpu", "tpu"): 528 return _notuple_allreduce_translation_rule( 529 prim, c, *args, axis_name=axis_name, 530 axis_index_groups=axis_index_groups, axis_env=axis_env, 531 platform=platform) 532 533 # XLA's tuple all-reduce doesn't support different dtypes in the same 534 # allreduce. Instead, we perform once all-reduce for each argument input type. 535 args_by_type = collections.defaultdict(lambda: ([], [])) 536 for i, arg in enumerate(args): 537 indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()] 538 indices.append(i) 539 dtype_args.append(arg) 540 541 # The outputs, in the original argument order. 542 out = [None] * len(args) 543 replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups) 544 replica_groups_protos = xc.make_replica_groups(replica_groups) 545 for dtype, (indices, dtype_args) in sorted(args_by_type.items()): 546 is_complex = dtypes.issubdtype(dtype, np.complexfloating) 547 n = len(dtype_args) 548 if is_complex and prim is lax.add_p: 549 # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a 550 # special case because it's not currently handled by XLA:GPU 551 dtype_args = ([xops.Real(x) for x in dtype_args] + 552 [xops.Imag(x) for x in dtype_args]) 553 scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype()) 554 computation = xla.primitive_subcomputation(prim, scalar, scalar) 555 all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation, 556 replica_groups_protos, None, None) 557 if is_complex and prim is lax.add_p: 558 xs = [xops.Complex(xops.GetTupleElement(all_reduce, i), 559 xops.GetTupleElement(all_reduce, n + i)) for i in range(n)] 560 else: 561 xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)] 562 for i, x in zip(indices, xs): 563 out[i] = x 564 return xops.Tuple(c, out) 565 566# TODO(b/155446630): An XLA:TPU optimization pass also doesn't support 567# tuple all-reduce yet. Meanwhile, rely on deterministic compiler behavior. 568def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env, 569 axis_index_groups, platform): 570 def all_reduce(x): 571 replica_groups_protos = xc.make_replica_groups( 572 _replica_groups(axis_env, axis_name, axis_index_groups)) 573 scalar = ShapedArray((), c.get_shape(x).numpy_dtype()) 574 computation = xla.primitive_subcomputation(prim, scalar, scalar) 575 return xops.AllReduce(x, computation, replica_groups_protos, None, None) 576 577 if prim is not lax.add_p: 578 outs = [all_reduce(x) for x in args] 579 else: 580 # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a 581 # special case because it's not currently handled by XLA:GPU 582 outs = [xops.Complex(all_reduce(xops.Real(x)), all_reduce(xops.Imag(x))) 583 if dtypes.issubdtype(c.get_shape(x).numpy_dtype(), np.complexfloating) 584 else all_reduce(x) for x in args] 585 return xops.Tuple(c, outs) 586 587def _psum_transpose_rule(cts, *args, axis_name, axis_index_groups): 588 nonzero_out_cts, treedef = tree_util.tree_flatten(cts) 589 nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name, 590 axis_index_groups=axis_index_groups) 591 return tree_util.tree_unflatten(treedef, nonzero_in_cts) 592 593psum_p = core.Primitive('psum') 594psum_p.multiple_results = True 595psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args)) 596pxla.soft_pmap_rules[psum_p] = \ 597 partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum) 598xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule, lax.add_p) # type: ignore 599ad.deflinear2(psum_p, _psum_transpose_rule) 600pxla.multi_host_supported_collectives.add(psum_p) 601batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p) 602batching.collective_rules[psum_p] = \ 603 partial(_batched_reduction_collective, 604 psum_p, 605 lambda v, d: v.sum(d), 606 lambda v, axis_size: axis_size * v) 607 608# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at 609# tracing time. 610@psum_p.def_custom_bind 611def psum_bind(*args, axis_name, axis_index_groups): 612 if all(not isinstance(x, core.Tracer) for x in args): 613 if axis_index_groups is not None: 614 size = len(axis_index_groups[0]) 615 elif isinstance(axis_name, (list, tuple)): 616 size = prod([core.axis_frame(name).size for name in axis_name]) # type: ignore 617 else: 618 size = core.axis_frame(axis_name).size # type: ignore 619 return tuple(size * x for x in args) 620 return core.Primitive.bind( 621 psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups) 622 623 624pmax_p = core.Primitive('pmax') 625pmax_p.multiple_results = True 626pmax_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args)) 627xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule, lax.max_p) 628pxla.multi_host_supported_collectives.add(pmax_p) 629batching.primitive_batchers[pmax_p] = partial(_collective_batcher, pmax_p) 630batching.collective_rules[pmax_p] = \ 631 partial(_batched_reduction_collective, pmax_p, 632 lambda v, d: v.max(d), lambda v, axis_size: v) 633 634 635pmin_p = core.Primitive('pmin') 636pmin_p.multiple_results = True 637pmin_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args)) 638xla.parallel_translations[pmin_p] = partial(_allreduce_translation_rule, lax.min_p) 639pxla.multi_host_supported_collectives.add(pmin_p) 640batching.primitive_batchers[pmin_p] = partial(_collective_batcher, pmin_p) 641batching.collective_rules[pmin_p] = \ 642 partial(_batched_reduction_collective, pmin_p, 643 lambda v, d: v.min(d), lambda v, axis_size: v) 644 645 646def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform): 647 replica_groups = _replica_groups(axis_env, axis_name, None) 648 group_size = len(replica_groups[0]) 649 srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm) 650 if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))): 651 msg = "ppermute sources and destinations must be unique, got {}." 652 raise ValueError(msg.format(perm)) 653 654 full_perm = [] 655 for grp in replica_groups: 656 grp = list(sorted(grp)) 657 full_perm.extend((grp[src], grp[dst]) for src, dst in perm) 658 return xops.CollectivePermute(x, full_perm) 659 660def _ppermute_transpose_rule(t, x, perm, axis_name): 661 srcs, dsts = unzip2(perm) 662 inverse_perm = list(zip(dsts, srcs)) 663 return [ppermute(t, axis_name=axis_name, perm=inverse_perm)] 664 665def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm): 666 assert len(perm) == frame.size, "Permutation doesn't match the axis size!" 667 assert axis_name == frame.name, "ppermute batcher called with wrong axis name" 668 (v,), (d,) = vals_in, dims_in 669 assert d is not batching.not_mapped 670 perm_indices = [None] * frame.size 671 for src, dst in perm: 672 perm_indices[src] = dst 673 return lax_numpy.take(v, perm_indices, d), d 674 675ppermute_p = core.Primitive('ppermute') 676ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) 677ad.deflinear2(ppermute_p, _ppermute_transpose_rule) 678xla.parallel_translations[ppermute_p] = _ppermute_translation_rule 679pxla.multi_host_supported_collectives.add(ppermute_p) 680batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p) 681batching.collective_rules[ppermute_p] = _ppermute_batcher 682 683 684def _moveaxis(src, dst, x): 685 perm = [i for i in range(x.ndim) if i != src] 686 perm.insert(dst, src) 687 return lax.transpose(x, perm) 688 689def _all_to_all_via_all_gather(x, *, axis_name, split_axis, concat_axis, axis_index_groups): 690 global_full = all_gather(x, axis_name, axis_index_groups=axis_index_groups) 691 idx = axis_index(axis_name) 692 if axis_index_groups: 693 idx = idx % len(axis_index_groups[0]) 694 local_slice = lax.dynamic_index_in_dim(global_full, idx, split_axis + 1, keepdims=False) 695 return _moveaxis(0, concat_axis, local_slice) 696 697def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name, 698 axis_index_groups, axis_env, platform): 699 # Workaround for AllToAll not being implemented on CPU. 700 replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups) 701 if len(replica_groups[0]) == 1: 702 return x 703 elif platform != 'tpu': 704 warnings.warn("all_to_all (and pswapaxes) are only implemented properly for TPUs. All other " 705 "backends emulate it using a very slow and memory intensive algorithm, so expect " 706 "significant slowdowns.") 707 lowering = xla.lower_fun(_all_to_all_via_all_gather, multiple_results=False, parallel=True) 708 return lowering(c, x, 709 split_axis=split_axis, concat_axis=concat_axis, axis_name=axis_name, 710 axis_index_groups=axis_index_groups, axis_env=axis_env, platform=platform) 711 else: 712 split_count = len(replica_groups[0]) 713 if not all(split_count == len(g) for g in replica_groups): 714 raise ValueError('Replica groups must be equally sized') 715 replica_groups_protos = xc.make_replica_groups(replica_groups) 716 if concat_axis == split_axis: 717 return xops.AllToAll(x, split_axis, concat_axis, split_count, 718 replica_groups_protos) 719 else: 720 if concat_axis < split_axis: 721 split_axis += 1 722 elif split_axis < concat_axis: 723 concat_axis += 1 724 x = xla.lower_fun(partial(lax.expand_dims, dimensions=(concat_axis,)), multiple_results=False)(c, x) 725 x = xops.AllToAll(x, split_axis, concat_axis, split_count, replica_groups_protos) 726 x = xla.lower_fun(partial(lax.squeeze, dimensions=(split_axis,)), multiple_results=False)(c, x) 727 return x 728 729def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups): 730 return (all_to_all( 731 cts, 732 axis_name=axis_name, 733 split_axis=concat_axis, 734 concat_axis=split_axis, 735 axis_index_groups=axis_index_groups),) 736 737 738def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, axis_index_groups): 739 x, = vals_in 740 d, = dims_in 741 if d <= split_axis: 742 split_axis += 1 743 if d <= concat_axis: 744 concat_axis += 1 745 # Note: At this point split_axis and concat_axis are adjusted to the extra 746 # dimension and we have d != split_axis and d != concat_axis. 747 if split_axis < d < concat_axis: 748 d -= 1 749 elif concat_axis < d < split_axis: 750 d += 1 751 result = all_to_all_p.bind( 752 x, 753 axis_name=axis_name, 754 split_axis=split_axis, 755 concat_axis=concat_axis, 756 axis_index_groups=axis_index_groups) 757 return result, d 758 759def _all_to_all_batched_collective(frame, vals_in, dims_in, 760 axis_name, split_axis, concat_axis, 761 axis_index_groups): 762 if isinstance(axis_name, (list, tuple)) and len(axis_name) > 1: 763 raise NotImplementedError("update after #4835") # TODO(mattjj,apaszke) 764 x, = vals_in 765 d, = dims_in 766 split_axis_adj = split_axis + (1 if d <= split_axis else 0) 767 concat_axis_adj = concat_axis + (1 if split_axis_adj <= concat_axis else 0) 768 if d < split_axis_adj < concat_axis_adj: 769 split_axis_adj -= 1 770 elif concat_axis_adj < split_axis_adj < d: 771 split_axis_adj += 1 772 return _moveaxis(d, concat_axis_adj, x), split_axis_adj 773 774def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_groups): 775 input_aval = raise_to_shaped(x) 776 shape = list(input_aval.shape) 777 size = shape.pop(split_axis) 778 shape.insert(concat_axis, size) 779 return ShapedArray(tuple(shape), input_aval.dtype, weak_type=False) 780 781all_to_all_p = core.Primitive('all_to_all') 782all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval) 783xla.parallel_translations[all_to_all_p] = _all_to_all_translation_rule 784ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule) 785pxla.multi_host_supported_collectives.add(all_to_all_p) 786batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher 787batching.collective_rules[all_to_all_p] = _all_to_all_batched_collective 788 789 790def _expand(dim, size, index, x): 791 shape = list(x.shape) 792 shape.insert(dim, size) 793 out = lax.full(shape, lax._const(x, 0)) 794 return lax.dynamic_update_index_in_dim(out, x, index, dim) 795 796def all_gather(x, axis_name, *, axis_index_groups=None): 797 """Gather values of x across all replicas. 798 799 If ``x`` is a pytree then the result is equivalent to mapping this function to 800 each leaf in the tree. 801 802 This is equivalent to, but faster than, all_to_all(broadcast(x)). 803 804 Args: 805 x: array(s) with a mapped axis named ``axis_name``. 806 axis_name: hashable Python object used to name a pmapped axis (see the 807 :func:`jax.pmap` documentation for more details). 808 axis_index_groups: optional list of lists containing axis indices (e.g. for 809 an axis of size 4, [[0, 1], [2, 3]] would run all gather over the first 810 two and last two replicas). Groups must cover all axis indices exactly 811 once, and all groups must be the same size. 812 813 Returns: 814 Array(s) representing the result of an all-gather along the axis 815 ``axis_name``. Shapes are the same as ``x.shape``, but with a leading 816 dimension of the axis_size. 817 818 For example, with 4 XLA devices available: 819 820 >>> x = np.arange(4) 821 >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x) 822 >>> print(y) 823 [[0 1 2 3] 824 [0 1 2 3] 825 [0 1 2 3] 826 [0 1 2 3]] 827 828 An example of using axis_index_groups, groups split by even & odd device ids: 829 830 >>> x = np.arange(16).reshape(4, 4) 831 >>> print(x) 832 [[ 0. 1. 2. 3.] 833 [ 4. 5. 6. 7.] 834 [ 8. 9. 10. 11.] 835 [12. 13. 14. 15.]] 836 >>> y = jax.pmap(lambda x: jax.lax.all_gather( 837 ... x, 'i', axis_index_groups=[[0, 2], [3, 1]]))(x) 838 >>> print(y) 839 [[[ 0. 1. 2. 3.] 840 [ 8. 9. 10. 11.]] 841 [[12. 13. 14. 15.] 842 [ 4. 5. 6. 7.]] 843 [[ 0. 1. 2. 3.] 844 [ 8. 9. 10. 11.]] 845 [[12. 13. 14. 15.] 846 [ 4. 5. 6. 7.]] 847 """ 848 axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) 849 # The all_gather primitive doesn't work when omni-staging is disabled. 850 if not config.omnistaging_enabled: 851 return _all_gather_via_psum(x, all_gather_dimension=0, axis_name=axis_name, 852 axis_index_groups=axis_index_groups, axis_size=axis_size) 853 854 def bind(x): 855 return all_gather_p.bind(x, all_gather_dimension=0, axis_name=axis_name, 856 axis_index_groups=axis_index_groups, axis_size=axis_size) 857 858 return tree_util.tree_map(bind, x) 859 860def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size): 861 index = axis_index(axis_name) 862 if axis_index_groups is not None: 863 indices = np.array(axis_index_groups).flatten() 864 axis_index_to_group_index = indices.argsort() % len(axis_index_groups[0]) 865 index = lax_numpy.array(axis_index_to_group_index)[index] 866 outs = tree_util.tree_map(partial(_expand, all_gather_dimension, axis_size, index), x) 867 return psum(outs, axis_name, axis_index_groups=axis_index_groups) 868 869def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size): 870 # Only called when the argument is not mapped. 871 out_shape = list(np.shape(x)) 872 out_shape.insert(all_gather_dimension, axis_size) 873 broadcast_dims = [i for i in range(len(out_shape)) if i != all_gather_dimension] 874 return lax.broadcast_in_dim(x, out_shape, broadcast_dims) 875 876def _all_gather_translation_rule(c, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, axis_env, platform): 877 # TODO(cjfj): Enable this for TPU also? 878 if (platform == 'gpu') and (all_gather_dimension == 0): 879 new_shape = list(c.get_shape(x).dimensions()) 880 new_shape.insert(all_gather_dimension, 1) 881 broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension] 882 x = xops.BroadcastInDim(x, new_shape, broadcast_dimensions) 883 replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups) 884 return xops.AllGather(x, all_gather_dimension=all_gather_dimension, shard_count=axis_size, 885 replica_groups=xc.make_replica_groups(replica_groups)) 886 else: 887 lowering = xla.lower_fun(_all_gather_via_psum, multiple_results=False, parallel=True) 888 return lowering(c, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, 889 axis_index_groups=axis_index_groups, axis_size=axis_size, axis_env=axis_env, platform=platform) 890 891def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size): 892 x_aval = raise_to_shaped(x) 893 new_shape = list(x_aval.shape) 894 new_shape.insert(all_gather_dimension, axis_size) 895 return ShapedArray(new_shape, x_aval.dtype) 896 897def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size): 898 # TODO(cjfj): Add reduce-scatter op to XLA? 899 concat_axis = 0 900 return (lax_numpy.sum( 901 all_to_all( 902 cts, 903 axis_name=axis_name, 904 split_axis=all_gather_dimension, 905 concat_axis=concat_axis, 906 axis_index_groups=axis_index_groups), 907 axis=concat_axis),) 908 909def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size): 910 (x,), (d,) = vals_in, dims_in 911 if d <= all_gather_dimension: 912 all_gather_dimension += 1 913 else: 914 d += 1 915 result = all_gather_p.bind( 916 x, 917 all_gather_dimension=all_gather_dimension, 918 axis_name=axis_name, 919 axis_index_groups=axis_index_groups, 920 axis_size=axis_size) 921 return result, d 922 923def _all_gather_batched_collective(frame, vals_in, dims_in, all_gather_dimension, axis_name, axis_index_groups, axis_size): 924 assert axis_index_groups is None, "axis_index_groups not supported in vmap" 925 assert axis_size == frame.size, "axis size doesn't match" 926 assert axis_name == frame.name, "batcher called with wrong axis name" 927 (x,), (d,) = vals_in, dims_in 928 assert d is not batching.not_mapped 929 return _moveaxis(d, all_gather_dimension, x), batching.not_mapped 930 931all_gather_p = core.Primitive('all_gather') 932all_gather_p.def_abstract_eval(_all_gather_abstract_eval) 933all_gather_p.def_impl(_all_gather_impl) 934xla.parallel_translations[all_gather_p] = _all_gather_translation_rule 935ad.deflinear2(all_gather_p, _all_gather_transpose_rule) 936pxla.multi_host_supported_collectives.add(all_gather_p) 937batching.primitive_batchers[all_gather_p] = _all_gather_batcher 938batching.collective_rules[all_gather_p] = _all_gather_batched_collective 939 940def _axis_index_translation_rule(c, *, axis_name, axis_env, platform): 941 axis_pos = list(axis_env.names).index(axis_name) 942 nreplicas = axis_env.nreps // prod(axis_env.sizes) 943 div = xb.constant(c, np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]), 944 dtype=np.uint32)) 945 mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32)) 946 unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod) 947 return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32)) 948 949def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name): 950 assert not vals and not mapped 951 idx = axis_index(axis_name) # type: ignore 952 return idx * chunk_size + np.arange(chunk_size, dtype=np.int32), True 953 954axis_index_p = core.Primitive('axis_index') 955xla.parallel_translations[axis_index_p] = _axis_index_translation_rule 956pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule # type: ignore 957axis_index_p.def_abstract_eval( 958 lambda *args, **params: ShapedArray((), np.int32)) 959pxla.multi_host_supported_collectives.add(axis_index_p) 960 961# Axis index doesn't get any arguments, so that the default bind would have no 962# way to call into a data-dependency based trace such as vmap. Each trace that 963# wants to bind an axis name has to additionally implement `process_axis_index` 964# and put its main trace on the axis env stack. 965def _axis_index_bind(*, axis_name): 966 if not isinstance(axis_name, (tuple, list)): 967 axis_name = (axis_name,) 968 inner_size = 1 969 index = 0 970 for name in reversed(axis_name): 971 frame = core.axis_frame(name) 972 if frame.main_trace is not None: 973 trace = frame.main_trace.with_cur_sublevel() 974 name_idx = trace.process_axis_index(frame) 975 else: 976 name_idx = core.Primitive.bind(axis_index_p, axis_name=name) 977 index += name_idx * inner_size 978 inner_size *= psum(1, name) 979 return index 980axis_index_p.def_custom_bind(_axis_index_bind) 981 982def _process_axis_index(self, frame): 983 return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0) 984batching.BatchTrace.process_axis_index = _process_axis_index # type: ignore 985 986 987pdot_p = core.Primitive('pdot') 988 989@pdot_p.def_impl 990def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch): 991 if axis_name: raise NameError(f"unbound axis name: {axis_name[0]}") 992 return lax.dot_general(x, y, [pos_contract, pos_batch]) 993 994@pdot_p.def_abstract_eval 995def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch): 996 # TODO: avals with names, check inputs are mapped along axis_name, eliminate 997 if not len(set(axis_name)) == len(axis_name): raise ValueError 998 return lax.dot_general_p.abstract_eval( 999 x, y, dimension_numbers=[pos_contract, pos_batch], 1000 precision=None, preferred_element_type=None) 1001 1002def _pdot_vmap_collective_rule(frame, vals_in, dims_in, *, axis_name, 1003 pos_contract, pos_batch): 1004 x, y = vals_in 1005 x_dim, y_dim = dims_in 1006 x_pos_contract, y_pos_contract = pos_contract 1007 x_pos_contract = [x_dim] + [d + (d >= x_dim) for d in x_pos_contract] 1008 y_pos_contract = [y_dim] + [d + (d >= y_dim) for d in y_pos_contract] 1009 x_pos_batch, y_pos_batch = pos_batch 1010 x_pos_batch = [d + (d >= x_dim) for d in x_pos_batch] 1011 y_pos_batch = [d + (d >= y_dim) for d in y_pos_batch] 1012 remaining_axis_names = tuple(n for n in axis_name if n != frame.name) 1013 out = pdot_p.bind(x, y, axis_name=remaining_axis_names, 1014 pos_contract=[x_pos_contract, y_pos_contract], 1015 pos_batch=[x_pos_batch, y_pos_batch]) 1016 return out, None 1017batching.collective_rules[pdot_p] = _pdot_vmap_collective_rule 1018 1019def _pdot_vmap_batching_rule(vals_in, dims_in, *, axis_name, pos_contract, 1020 pos_batch): 1021 x, y = vals_in 1022 (pos_contract, pos_batch), result_batch_dim = lax._dot_general_batch_dim_nums( 1023 (x.ndim, y.ndim), dims_in, [pos_contract, pos_batch]) 1024 out = pdot_p.bind(x, y, axis_name=axis_name, pos_contract=pos_contract, 1025 pos_batch=pos_batch) 1026 return out, result_batch_dim 1027batching.primitive_batchers[pdot_p] = _pdot_vmap_batching_rule 1028 1029def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch, 1030 axis_env, platform): 1031 local_out = lax._dot_general_translation_rule( 1032 c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=None, 1033 preferred_element_type=None) 1034 if axis_name: 1035 out_tup = xla.parallel_translations[psum_p]( 1036 c, local_out, axis_name=axis_name, axis_index_groups=None, 1037 axis_env=axis_env, platform=platform) 1038 out, = xla.xla_destructure(c, out_tup) 1039 else: 1040 out = local_out 1041 return out 1042xla.parallel_translations[pdot_p] = _pdot_translation_rule 1043 1044def _pdot_transpose_lhs(g, y, *, axis_name, pos_contract, pos_batch): 1045 # TODO: avals with names, call pbroadcast with axis_name 1046 return lax._dot_general_transpose_lhs( 1047 g, y, dimension_numbers=[pos_contract, pos_batch], precision=None, 1048 preferred_element_type=None) 1049def _pdot_transpose_rhs(g, x, *, axis_name, pos_contract, pos_batch): 1050 # TODO: avals with names, call pbroadcast with axis_name 1051 return lax._dot_general_transpose_rhs( 1052 g, x, dimension_numbers=[pos_contract, pos_batch], precision=None, 1053 preferred_element_type=None) 1054ad.defbilinear(pdot_p, _pdot_transpose_lhs, _pdot_transpose_rhs) 1055 1056pxla.multi_host_supported_collectives.add(pdot_p) 1057 1058 1059@config.register_omnistaging_disabler 1060def omnistaging_disabler() -> None: 1061 global axis_index 1062 1063 psum_p.bind = partial(core.Primitive.bind, psum_p) # type: ignore 1064 psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p)) # type: ignore 1065 pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args) # type: ignore 1066 1067 def _axis_index_bind(*, axis_name): 1068 dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env 1069 frame = dynamic_axis_env[axis_name] 1070 sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1] 1071 nreps = dynamic_axis_env.nreps 1072 trace = frame.pmap_trace 1073 1074 out_aval = ShapedArray((), np.int32) 1075 out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None) 1076 eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p, 1077 dict(nreps=nreps, sizes=sizes, axis_name=axis_name), 1078 source_info_util.current()) 1079 out_tracer.recipe = eqn 1080 1081 return out_tracer 1082 1083 def _axis_index_translation_rule(c, nreps, sizes, axis_name): 1084 div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32)) 1085 mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32)) 1086 unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod) 1087 return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32)) 1088 1089 axis_index_p.def_custom_bind(_axis_index_bind) 1090 axis_index_p.def_abstract_eval( 1091 lambda *args, **params: ShapedArray((), np.int32)) 1092 xla.translations[axis_index_p] = _axis_index_translation_rule 1093