1# Copyright 2018 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Implementation of pmap and related functionality.""" 15 16# A ShardingSpec describes at a high level how a logical array is sharded across 17# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also 18# describe how to shard inputs to a parallel computation). spec_to_indices() 19# encodes exactly how a given ShardingSpec is translated to device buffers, i.e. 20# how the sharded array is "laid out" across devices. Given a sequence of 21# devices, we shard the data across the devices in row-major order, with 22# replication treated as an extra inner dimension. 23# 24# For example, given the logical data array [1, 2, 3, 4], if we were to 25# partition this array 4 ways with a replication factor of 2, for a total of 8 26# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4]. 27# 28# This encoding is assumed by various parts of the system, e.g. generating 29# replica groups for collective operations. 30 31import sys 32from contextlib import contextmanager 33from collections import defaultdict, OrderedDict 34import itertools as it 35import operator as op 36import threading 37from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, 38 Type, Union, Iterable, no_type_check, NamedTuple, TYPE_CHECKING) 39 40from absl import logging 41import numpy as np 42 43from ..config import flags, config 44from .. import core 45from .. import linear_util as lu 46from .. import lazy 47from ..abstract_arrays import array_types 48from ..core import ConcreteArray, ShapedArray, Var, Literal 49from .._src.util import (partial, unzip2, unzip3, prod, safe_map, safe_zip, 50 extend_name_stack, wrap_name, assert_unreachable, 51 tuple_insert, tuple_delete, taggedtuple, curry) 52from ..lib import xla_bridge as xb 53from ..lib import xla_client as xc 54from ..tree_util import tree_flatten, tree_map 55from .batching import broadcast, not_mapped, moveaxis 56from . import batching 57from . import partial_eval as pe 58from . import xla 59from . import ad 60 61if sys.version_info >= (3, 9): 62 OrderedDictType = OrderedDict 63else: 64 OrderedDictType = Dict 65 66xops = xc.ops 67 68FLAGS = flags.FLAGS 69 70unsafe_map, map = map, safe_map # type: ignore 71 72Index = Union[int, slice, Tuple[Union[int, slice], ...]] 73 74 75class NoSharding: 76 77 def __eq__(self, other): 78 return isinstance(other, NoSharding) 79 80 def __repr__(self): 81 return "NoSharding()" 82 83 84_UNSHARDED_INSTANCE = NoSharding() 85 86# mypy is very unhappy about taggedtuple 87if TYPE_CHECKING: 88 class Unstacked(NamedTuple): 89 size: int 90else: 91 Unstacked = taggedtuple('Unstacked', ('size',)) 92 93class Chunked: 94 chunks: Tuple[int, ...] 
95 96 def __init__(self, chunks: Union[int, Tuple[int, ...]]): 97 if not isinstance(chunks, tuple): 98 chunks = (chunks,) 99 object.__setattr__(self, 'chunks', chunks) 100 101 def __setattr__(self, name, value): 102 raise RuntimeError("Chunked is immutable") 103 104 def __delattr__(self, name): 105 raise RuntimeError("Chunked is immutable") 106 107 def __hash__(self): 108 return hash(self.chunks) 109 110 def __eq__(self, other): 111 return type(other) is Chunked and self.chunks == other.chunks 112 113 def __repr__(self): 114 return f'Chunked({self.chunks})' 115 116""" 117Represents all the ways we can shard a dimension. 118- `None` means no sharding; 119- `Chunked` means that the dimension is split into the specified number of chunks, 120 but the split dimension itself is preserved inside the map; 121- `Unstacked` means that the dimension is split into chunks of size 1, and doesn't 122 appear inside the map. 123""" 124AvalDimSharding = Union[Unstacked, Chunked, NoSharding] 125 126# mypy is very unhappy about taggedtuple 127if TYPE_CHECKING: 128 class ShardedAxis(NamedTuple): 129 axis: int 130 class Replicated(NamedTuple): 131 replicas: int 132else: 133 ShardedAxis = taggedtuple('ShardedAxis', ('axis',)) 134 Replicated = taggedtuple('Replicated', ('replicas',)) 135 136""" 137Assigns sharded axes to mesh dimensions. 138 139When no axis is assigned, the data is replicated. 140Note that `ShardedAxis(2)` refers to the second actually sharded axis (i.e. 141counting as if the None dimensions of sharding were filtered out). For example, 142given the sharding `[Unstacked(n), None, Chunked(m)]`, an entry of `ShardedAxis(1)` 143refers to the `Chunked(m)` axis, not the `None`. 144""" 145MeshDimAssignment = Union[ShardedAxis, Replicated] 146 147class ShardingSpec: 148 """Describes the sharding of an ndarray. 149 150 Attributes: 151 sharding: specifies how the array is supposed to get partitioned into chunks. 152 Its length should match the rank of the array. See the docstring of 153 `AvalDimSharding` for the supported partitioning schemes. 154 mesh_mapping` describes an assignments of the array chunks created by `sharding` 155 to a logical device mesh. The length of the tuple is equal to the rank of the 156 mesh. Each mesh dimension can either get partitions of data varying along one 157 of the sharded dimensions, or the data can be replicated. See the docstring of 158 `MeshDimAssignment` for more information. 159 """ 160 sharding: Tuple[AvalDimSharding, ...] 161 mesh_mapping: Tuple[MeshDimAssignment, ...] 162 163 def __init__(self, 164 sharding: Iterable[AvalDimSharding], 165 mesh_mapping: Iterable[MeshDimAssignment]): 166 self.sharding = tuple(sharding) 167 assert all(x is not None for x in self.sharding) 168 self.mesh_mapping = tuple(mesh_mapping) 169 170 @property 171 def mesh_shape(self): 172 sharded_axis_sizes = [] 173 for sharding in self.sharding: 174 if isinstance(sharding, NoSharding): 175 continue 176 elif isinstance(sharding, Unstacked): 177 sharded_axis_sizes.append(sharding.size) 178 elif isinstance(sharding, Chunked): 179 sharded_axis_sizes.extend(sharding.chunks) 180 else: 181 assert_unreachable(sharding) 182 return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) else a.replicas 183 for a in self.mesh_mapping) 184 185 def sharding_proto(self): 186 """Converts a ShardingSpec to an OpSharding proto. 187 188 See 189 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601 190 for details on the OpSharding proto. 
191 Unfortunately the semantics are not very well described in the proto spec, but the code here might help: 192 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py 193 """ 194 mesh_shape = self.mesh_shape 195 mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape) 196 197 sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped 198 replicated_maxes = [] # lists mesh axis identifiers to replicate over 199 for maxis, assignment in enumerate(self.mesh_mapping): 200 if isinstance(assignment, Replicated): 201 replicated_maxes.append(maxis) 202 elif isinstance(assignment, ShardedAxis): 203 sharded_axes[assignment.axis] = maxis 204 else: 205 assert_unreachable(assignment) 206 207 proto = xc.OpSharding() 208 if len(replicated_maxes) == len(self.mesh_mapping): 209 proto.type = xc.OpSharding.Type.REPLICATED 210 return proto 211 else: 212 proto.type = xc.OpSharding.Type.OTHER 213 214 mesh_permutation = [] 215 new_mesh_shape = [] 216 next_sharded_axis = 0 217 for axis, sharding in enumerate(self.sharding): 218 if isinstance(sharding, NoSharding): 219 new_mesh_shape.append(1) # Add a dummy mesh axis we won't be sharding over 220 elif isinstance(sharding, Chunked): 221 for nchunks in sharding.chunks: 222 maxis = sharded_axes[next_sharded_axis] 223 assert mesh_shape[maxis] == nchunks 224 mesh_permutation.append(maxis) 225 next_sharded_axis += 1 226 new_mesh_shape.append(int(np.prod(sharding.chunks))) 227 elif isinstance(sharding, Unstacked): 228 raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding") 229 else: 230 assert_unreachable(sharding) 231 232 # Create the partial sharding proto if tensor is replicated over some mesh axes 233 if replicated_maxes: 234 new_mesh_shape.append(-1) 235 mesh_permutation.extend(replicated_maxes) 236 proto.replicate_on_last_tile_dim = True 237 238 proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape) 239 proto.tile_assignment_dimensions = list(proto_mesh.shape) 240 proto.tile_assignment_devices = list(proto_mesh.flat) 241 return proto 242 243 def indices(self, shape: Tuple[int, ...]) -> np.ndarray: 244 """Returns NumPy-style indices corresponding to a sharding spec. 245 246 Args: 247 shape: The shape of the logical array being sharded. 248 249 Returns: 250 An ndarray with the same shape as the logical mesh (as derived form 251 `mesh_mapping`). Each entry is a NumPy-style index selecting the subset of 252 the data array to be placed on a corresponding device. The indices can be 253 ints, slice objects with step=1, or tuples of those. 254 """ 255 assert len(shape) == len(self.sharding), (shape, self.sharding) 256 257 axis_indices: List[Sequence[Index]] = [] 258 shard_indices_shape = [] 259 for dim, sharding in enumerate(self.sharding): 260 axis_size = shape[dim] 261 if isinstance(sharding, NoSharding): 262 axis_indices.append([slice(None)]) 263 # NOTE: We don't append unsharded dimensions to shard_indices_shape here, 264 # because they do not appear in the mesh mapping. 
265 elif isinstance(sharding, Unstacked): 266 assert axis_size == sharding.size, f'{axis_size} != {sharding.size}' 267 axis_indices.append(range(axis_size)) 268 shard_indices_shape.append(axis_size) 269 elif isinstance(sharding, Chunked): 270 total_chunks = int(np.prod(sharding.chunks)) 271 shard_size, ragged = divmod(axis_size, total_chunks) 272 assert not ragged, (axis_size, total_chunks, dim) 273 axis_indices.append([slice(i * shard_size, (i + 1) * shard_size) 274 for i in range(total_chunks)]) 275 shard_indices_shape.extend(sharding.chunks) 276 else: 277 assert_unreachable(sharding) 278 279 # shard_indices is an ndarray representing the sharded axes of the logical array, 280 # with each dimension having size equal to the number of shards across the corresponding 281 # logical array dimension, and each element containing the multi-dimensional index that 282 # is used to extract the corresponding shard of the logical array. 283 shard_indices = np.empty([prod(shard_indices_shape)], dtype=np.object) 284 for i, idxs in enumerate(it.product(*axis_indices)): 285 shard_indices[i] = idxs 286 shard_indices = shard_indices.reshape(shard_indices_shape) 287 288 # Ensure that each sharded axis is used exactly once in the mesh mapping 289 num_sharded_dim = len(shard_indices_shape) 290 sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)] 291 assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and 292 len(sharded_dim_perm) == num_sharded_dim) 293 # Replicate/reorder the indices according to the mesh mapping 294 replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated)) 295 replica_dim, sharded_dim = it.count(0), iter(sharded_dim_perm) 296 perm = [next(replica_dim) if isinstance(a, Replicated) else 297 len(replica_sizes) + next(sharded_dim) 298 for a in self.mesh_mapping] 299 return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape) 300 .transpose(perm)) 301 302 def __eq__(self, other): 303 return (self.sharding, self.mesh_mapping) == (other.sharding, other.mesh_mapping) 304 305 def __hash__(self): 306 return hash((self.sharding, self.mesh_mapping)) 307 308 def __repr__(self): 309 return f'ShardingSpec({self.sharding}, {self.mesh_mapping})' 310 311def spec_to_indices(shape: Tuple[int, ...], 312 spec: ShardingSpec) -> Tuple[Index, ...]: 313 """Returns numpy-style indices corresponding to a sharding spec. 314 315 Each index describes a shard of the array. The order of the indices is the 316 same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out 317 row-major). 318 319 Args: 320 shape: The shape of the logical array being sharded. 321 spec: Describes how the array is sharded and how the shards are assigned to 322 the logical mesh. 323 324 Returns: 325 A tuple of length equal to the size of the mesh (inferred as the product of 326 sharded dimension sizes and all replication factors). Each element is an 327 int, a slice object with step=1, or a tuple thereof, to be treated as an 328 index into the full logical array. 329 """ 330 return tuple(spec.indices(shape).flat) 331 332 333### util 334 335def identity(x): return x 336 337# TODO(skye): expose PyLocalBuffers in xla_client 338def shard_args(devices: Sequence[xb.xla_client.Device], 339 indices: Sequence[Sequence[Index]], 340 args) -> Sequence[Sequence[xb.xla_client._xla.PyLocalBuffer]]: 341 """Shard each argument data array along its leading axis. 342 343 Args: 344 devices: sequence of Devices mapping replica index to a physical device. 
345 indices: sequence of the same length as `args` describing how each arg 346 should be sharded/replicated across `devices`. Each element in `indices` 347 is the same length as `devices`. 348 args: a sequence of JaxTypes representing arguments to be sharded according 349 to `indices` and placed on `devices`. 350 351 Returns: 352 A list of device buffers with the same length as `devices` indexed by 353 replica number, so that the nth element is the argument to be passed to the 354 nth replica. 355 """ 356 nargs, nrep = len(args), len(devices) 357 buffers = [[None] * nargs for _ in range(nrep)] 358 for a, arg in enumerate(args): 359 # The shard_arg_handlers allow an extensible set of types to be sharded, but 360 # inline handling for ShardedDeviceArray as a special case for performance 361 # NOTE: we compare indices instead of sharding_spec because 362 # pmap_benchmark.pmap_shard_args_benchmark indicates this is faster. 363 if type(arg) is ShardedDeviceArray and indices[a] == arg.indices: 364 for r, buf in enumerate(arg.device_buffers): 365 buffers[r][a] = (buf if buf.device() == devices[r] 366 else buf.copy_to_device(devices[r])) 367 else: 368 arg = xla.canonicalize_dtype(arg) 369 bufs = shard_arg_handlers[type(arg)](arg, devices, indices[a]) 370 for r, buf in enumerate(bufs): 371 buffers[r][a] = buf 372 373 return buffers 374 375 376shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {} 377shard_arg_handlers[core.Unit] = \ 378 lambda x, devices, _: device_put(core.unit, devices, replicate=True) 379def _shard_array(x, devices, indices): 380 return device_put([x[i] for i in indices], devices) 381for _t in array_types: 382 shard_arg_handlers[_t] = _shard_array 383 384def _shard_device_array(x, devices, indices): 385 start_indices, limit_indices, removed_dims = map(tuple, unzip3( 386 _as_slice_indices(x, idx) for idx in indices)) 387 shards = x._multi_slice(start_indices, limit_indices, removed_dims) 388 return device_put(shards, devices) 389shard_arg_handlers[xla._DeviceArray] = _shard_device_array 390shard_arg_handlers[xla._CppDeviceArray] = _shard_device_array 391 392 393# NOTE(skye): we could refactor to generate _multi_slice parameters directly 394# from the input ShardingSpec, rather than the indices. However, this would 395# require duplicating the ordering logic of spec_to_indices, which is more 396# subtle and more likely to change than the index logic we have to support here. 
397def _as_slice_indices(arr: xla.DeviceArrayProtocol, idx: Index) -> Tuple[ 398 Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]: 399 """Returns start_indices, limit_indices, removed_dims""" 400 start_indices = [0] * arr.ndim 401 limit_indices = list(arr.shape) 402 removed_dims = [] 403 404 tuple_idx = idx if isinstance(idx, tuple) else (idx,) 405 for dim, sub_idx in enumerate(tuple_idx): 406 if isinstance(sub_idx, int): 407 start_indices[dim] = sub_idx 408 limit_indices[dim] = sub_idx + 1 409 removed_dims.append(dim) 410 elif sub_idx == slice(None): 411 continue 412 else: 413 assert isinstance(sub_idx, slice), sub_idx 414 assert isinstance(sub_idx.start, int), sub_idx 415 assert isinstance(sub_idx.stop, int), sub_idx 416 start_indices[dim] = sub_idx.start 417 limit_indices[dim] = sub_idx.stop 418 419 return tuple(start_indices), tuple(limit_indices), tuple(removed_dims) # type: ignore 420 421 422def shard_aval(size, axis: int, aval): 423 try: 424 return shard_aval_handlers[type(aval)](size, axis, aval) 425 except KeyError as err: 426 raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err 427shard_aval_handlers: Dict[Type[core.AbstractValue], Callable[[int, int, Any], Any]] = {} 428shard_aval_handlers[core.AbstractUnit] = lambda size, axis, x: x 429def _shard_abstract_array(size, axis: int, x): 430 try: 431 if x.shape[axis] != size: 432 raise ValueError(f"Axis size {size} does not match dimension {axis} of " 433 f"shape {x.shape}") 434 except IndexError: 435 raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None 436 return ShapedArray(tuple_delete(x.shape, axis), x.dtype) 437shard_aval_handlers[ShapedArray] = _shard_abstract_array 438 439# TODO(skye): expose PyLocalBuffers in xla_client 440def aval_to_result_handler(sharding_spec: Optional[ShardingSpec], 441 indices: Optional[Tuple[Index]], 442 aval: core.AbstractValue) -> Callable[ 443 [List[xb.xla_client._xla.PyLocalBuffer]], Any]: 444 """Returns a function for handling the raw buffers of a single output aval. 445 446 Args: 447 sharding_spec: indicates how the output is sharded across devices, or None 448 for non-array avals. 449 indices: the pre-computed result of spec_to_indices, or None for non-array 450 avals. 451 aval: the output AbstractValue. 452 453 Returns: 454 A function for handling the PyLocalBuffers that will eventually be produced 455 for this output. The function will return an object suitable for returning 456 to the user, e.g. a ShardedDeviceArray. 457 """ 458 try: 459 return pxla_result_handlers[type(aval)](sharding_spec, indices, aval) 460 except KeyError as err: 461 raise TypeError("No pxla_result_handler for type: {}".format(type(aval)) 462 ) from err 463 464PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client._xla.PyLocalBuffer]], Any]] 465pxla_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {} 466pxla_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit 467def array_result_handler(sharding_spec, indices, aval: ShapedArray): 468 return lambda bufs: ShardedDeviceArray(aval, sharding_spec, bufs, indices) 469pxla_result_handlers[ShapedArray] = array_result_handler 470pxla_result_handlers[ConcreteArray] = array_result_handler 471 472 473### lazy device-memory persistence and result handling 474 475class ShardedDeviceArray(xla._DeviceArray): 476 """A ShardedDeviceArray is an ndarray sharded across devices. 
477 478 The purpose of a ShardedDeviceArray is to reduce the number of transfers when 479 executing replicated computations, by allowing results to persist on the 480 devices that produced them. That way dispatching a similarly replicated 481 computation that consumes the same sharded memory layout does not incur any 482 transfers. 483 484 A ShardedDeviceArray represents one logical ndarray value, and simulates the 485 behavior of an ndarray so that it can be treated by user code as an ndarray; 486 that is, it is only an optimization to reduce transfers. 487 488 Attributes: 489 aval: A ShapedArray indicating the shape and dtype of this array. 490 sharding_spec: describes how this array is sharded across `device_buffers`. 491 device_buffers: the buffers containing the data for this array. Each buffer 492 is the same shape and on a different device. Buffers are in row-major 493 order, with replication treated as an extra innermost dimension. 494 indices: the result of spec_to_indices(sharding_spec). Can optionally be 495 precomputed for efficiency. A list the same length as 496 `device_buffers`. Each index indicates what portion of the full array is 497 stored in the corresponding device buffer, i.e. `array[indices[i]] == 498 device_buffers[i].to_py()`. 499 """ 500 __slots__ = ["device_buffers", "sharding_spec", "indices", 501 "_one_replica_buffer_indices"] 502 503 # TODO(skye): expose PyLocalBuffers in xla_client 504 def __init__(self, 505 aval: ShapedArray, 506 sharding_spec, # TODO(skye): add type annotation back, see below 507 device_buffers: List[xb.xla_client._xla.PyLocalBuffer] = None, 508 indices: Optional[Tuple[Index, ...]] = None): 509 xla.DeviceArray.__init__(self) 510 511 # TODO(skye): this is temporary staging while we switch users over to 512 # providing sharding_spec. It assumes that any pre-existing callers are 513 # creating pmap-style ShardedDeviceArrays over the first dimension. 514 if device_buffers is None: 515 device_buffers = sharding_spec 516 sharded_aval = ShapedArray(aval.shape[1:], aval.dtype) 517 sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0], 518 1, None, sharded_aval, 0) 519 520 # TODO(skye): assert invariants. Keep performance in mind though. 
521 if indices is None: 522 indices = spec_to_indices(aval.shape, sharding_spec) 523 self.aval = aval 524 self.device_buffers = device_buffers 525 self.sharding_spec = sharding_spec 526 self.indices = indices 527 self._npy_value = None 528 self._one_replica_buffer_indices = None 529 if not core.skip_checks: 530 assert type(aval) is ShapedArray 531 532 @property 533 def one_replica_buffer_indices(self): 534 """Indices of buffers containing one complete copy of the array data.""" 535 if self._one_replica_buffer_indices is None: 536 one_replica_indices = [] 537 seen_index_hashes = set() 538 for i, index in enumerate(self.indices): 539 hashed_index = _hashable_index(index) 540 if hashed_index not in seen_index_hashes: 541 one_replica_indices.append(i) 542 seen_index_hashes.add(hashed_index) 543 self._one_replica_buffer_indices = one_replica_indices 544 return self._one_replica_buffer_indices 545 546 def copy_to_host_async(self): 547 for buffer_index in self.one_replica_buffer_indices: 548 self.device_buffers[buffer_index].copy_to_host_async() 549 550 def delete(self): 551 for buf in self.device_buffers: 552 buf.delete() 553 self.device_buffers = None 554 self._npy_value = None 555 556 def _check_if_deleted(self): 557 if self.device_buffers is None: 558 raise ValueError("ShardedDeviceArray has been deleted.") 559 560 def block_until_ready(self): 561 self._check_if_deleted() 562 for buf in self.device_buffers: 563 buf.block_host_until_ready() 564 return self 565 566 @property 567 def _value(self): 568 if self._npy_value is None: 569 self.copy_to_host_async() 570 npy_value = np.empty(self.aval.shape, self.aval.dtype) 571 for i in self.one_replica_buffer_indices: 572 npy_value[self.indices[i]] = self.device_buffers[i].to_py() 573 self._npy_value = npy_value 574 return self._npy_value 575 576 def __getitem__(self, idx): 577 if not isinstance(idx, tuple): 578 cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1) 579 else: 580 cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx)) 581 if self._npy_value is None: 582 try: 583 buf_idx = self.indices.index(cidx) 584 except ValueError: 585 buf_idx = None 586 if buf_idx is not None: 587 buf = self.device_buffers[buf_idx] 588 # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 589 # is the default. 590 aval = ShapedArray( 591 getattr(buf, "xla_shape", buf.shape)().dimensions(), 592 self.aval.dtype) 593 return xla.make_device_array(aval, None, lazy.array(aval.shape), buf) 594 return super(ShardedDeviceArray, self).__getitem__(idx) 595 596 597def _hashable_index(idx): 598 return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, 599 idx) 600 601# The fast path is handled directly in shard_args(). 602# TODO(skye): is there a simpler way to rewrite this using sharding_spec? 603def _shard_sharded_device_array_slow_path(x, devices, indices): 604 candidates = defaultdict(list) 605 for buf, idx in safe_zip(x.device_buffers, x.indices): 606 candidates[_hashable_index(idx)].append(buf) 607 608 bufs = [] 609 for idx, device in safe_zip(indices, devices): 610 # Look up all buffers that contain the correct slice of the logical array. 611 candidates_list = candidates[_hashable_index(idx)] 612 if not candidates_list: 613 # This array isn't sharded correctly. Reshard it via host roundtrip. 614 # TODO(skye): more efficient reshard? 615 return shard_arg_handlers[type(x._value)](x._value, devices, indices) 616 # Try to find a candidate buffer already on the correct device, 617 # otherwise copy one of them. 
618 for buf in candidates_list: 619 if buf.device() == device: 620 bufs.append(buf) 621 break 622 else: 623 bufs.append(buf.copy_to_device(device)) 624 return bufs 625shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array_slow_path 626 627def _sharded_device_array_constant_handler(c, val, canonicalize_types=True): 628 return xb.constant(c, np.asarray(val), canonicalize_types=canonicalize_types) 629xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler) 630 631core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray 632xla.device_put_handlers[ShardedDeviceArray] = xla._device_put_array 633xla.pytype_aval_mappings[ShardedDeviceArray] = op.attrgetter('aval') 634xla.canonicalize_dtype_handlers[ShardedDeviceArray] = identity 635 636 637### the xla_pmap primitive and its rules are comparable to xla_call in xla.py 638 639def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size, 640 global_axis_size, devices, name, in_axes, out_axes_thunk, 641 donated_invars, global_arg_shapes): 642 abstract_args = unsafe_map(xla.abstractify, args) 643 compiled_fun = parallel_callable(fun, backend, axis_name, axis_size, 644 global_axis_size, devices, name, 645 in_axes, out_axes_thunk, 646 donated_invars, global_arg_shapes, 647 *abstract_args) 648 return compiled_fun(*args) 649 650@lu.cache 651def parallel_callable(fun: lu.WrappedFun, 652 backend_name: Optional[str], 653 axis_name, 654 axis_size: int, 655 global_axis_size: Optional[int], 656 devices: Optional[Sequence[Any]], 657 name: str, 658 in_axes: Iterable[Optional[int]], 659 out_axes_thunk: Callable[[], Sequence[Optional[int]]], 660 donated_invars: Iterable[bool], 661 global_arg_shapes, 662 *avals): 663 if devices is not None and len(devices) == 0: 664 raise ValueError("'devices' argument to pmap must be non-empty, or None.") 665 666 # Determine global_axis_size for use in AxisEnv. 667 # TODO(mattjj,skyewm): revive this check (inner_pmap always False now) 668 # if xb.host_count() > 1 and global_axis_size is None and inner_pmap: 669 # raise ValueError("'axis_size' must be specified for nested multi-host pmaps") 670 if (xb.host_count() == 1 and global_axis_size is not None and 671 global_axis_size != axis_size): 672 raise ValueError( 673 f"Specified axis_size {global_axis_size} doesn't match received " 674 f"axis_size {axis_size}.") 675 676 must_run_on_all_devices = False 677 no_nested_sharding = False 678 if global_axis_size is None: 679 if xb.host_count() == 1: 680 global_axis_size = axis_size 681 elif devices: 682 # This allows each host in a multi-host pmap to run on a different number 683 # of devices, but precludes nested sharding (i.e. inner pmaps or 684 # sharded_jits). 685 global_axis_size = len(devices) 686 no_nested_sharding = True 687 else: 688 # This assumes all hosts run on the same number of devices. We make sure 689 # this assumption is true by requiring that the pmap is run on all devices 690 # (and making the further assumption that each host has the same number of 691 # devices). Nested sharding is ok in this case. 
692 global_axis_size = axis_size * xb.host_count() 693 assert all(len(xb.local_devices(host_id)) == xb.local_device_count() 694 for host_id in xb.host_ids()) 695 must_run_on_all_devices = True 696 697 if devices: 698 local_devices = [d for d in devices if d.host_id == xb.host_id()] 699 assert len(local_devices) > 0 700 else: 701 local_devices = None # type: ignore 702 703 if config.omnistaging_enabled: 704 sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval 705 for axis, aval in safe_zip(in_axes, avals)) 706 if any(s is not None for s in global_arg_shapes): 707 # TODO(skye): we could take this branch unconditionally if we handled 708 # grad of global_arg_shapes correctly. 709 global_sharded_avals = [ 710 ShapedArray(shape, aval.dtype) if shape is not None else aval 711 for shape, aval in safe_zip(global_arg_shapes, sharded_avals)] 712 else: 713 global_sharded_avals = sharded_avals # type: ignore 714 logging.vlog(2, "sharded_avals: %s", sharded_avals) 715 logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals) 716 717 with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore 718 jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals) 719 jaxpr = xla.apply_outfeed_rewriter(jaxpr) 720 else: 721 @lu.wrap_init 722 def dynamic_fun(dummy, *args): 723 with extend_dynamic_axis_env(axis_name, dummy._trace, global_axis_size): # type: ignore 724 return fun.call_wrapped(*args) 725 726 sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval 727 for axis, aval in safe_zip(in_axes, avals)) 728 pvals = [pe.PartialVal.unknown(aval) for aval in sharded_avals] 729 # We add a dummy first invar, to carry the trace details to `dynamic_fun` 730 pval = pe.PartialVal.unknown(core.abstract_unit) # dummy value for axis env 731 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore 732 dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore 733 jaxpr.invars = jaxpr.invars[1:] # ignore dummy 734 jaxpr = xla.apply_outfeed_rewriter(jaxpr) 735 736 out_pvs, out_consts = unzip2(out_pvals) 737 global_sharded_avals = sharded_avals # type: ignore 738 739 out_axes = out_axes_thunk() 740 if config.omnistaging_enabled: 741 assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes)) 742 else: 743 assert len(out_pvals) == len(out_axes), (len(out_pvals), len(out_axes)) 744 assert all(out_axis == 0 for out_axis in out_axes) 745 746 # TODO(skye,mattjj): allow more collectives on multi-host as we test them, but 747 # for now raise an error 748 if devices is not None: 749 is_multi_host_pmap = len(local_devices) != len(devices) 750 else: 751 is_multi_host_pmap = xb.host_count() > 1 752 if is_multi_host_pmap: 753 check_multihost_collective_allowlist(jaxpr) 754 755 if not config.omnistaging_enabled: 756 if all(pv is None for pv in out_pvs): 757 # When the output doesn't depend on the input we don't need to compile an 758 # XLA computation at all; we handle this as a special case so we can stage 759 # out multi-replica XLA computations regardless of the hardware available. 760 # The 'None' values here are just dummies we know will be ignored. 
761 handlers = [ 762 _pval_to_result_handler( # type: ignore 763 axis_size, None, None, None, pval, local_devices, backend_name) # type: ignore 764 for pval in out_pvals # type: ignore 765 ] 766 results = [handler(None) for handler in handlers] 767 return lambda *_: results 768 769 770 # TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits 771 jaxpr_replicas = xla.jaxpr_replicas(jaxpr) 772 num_local_replicas = axis_size * jaxpr_replicas 773 num_global_replicas = global_axis_size * jaxpr_replicas 774 775 (arg_parts, out_parts, num_partitions, local_arg_parts, local_out_parts, 776 local_num_partitions) = _find_partitions(jaxpr) 777 778 if local_num_partitions is None: 779 local_num_partitions = num_partitions 780 781 if local_arg_parts is None: 782 local_arg_parts = arg_parts 783 if local_out_parts is None: 784 local_out_parts = out_parts 785 786 logging.vlog(2, "num_replicas: %d num_local_replicas: %d", 787 num_global_replicas, num_local_replicas) 788 logging.vlog(2, "num_partitions: %d local_num_partitions: %d", 789 num_partitions, local_num_partitions) 790 logging.vlog(2, "arg_parts: %s", arg_parts) 791 logging.vlog(2, "local_arg_parts: %s", local_arg_parts) 792 logging.vlog(2, "out_parts: %s", out_parts) 793 logging.vlog(2, "local_out_parts: %s", local_out_parts) 794 logging.vlog(2, "devices: %s", devices) 795 logging.vlog(2, "local_devices: %s", local_devices) 796 797 num_local_shards = num_local_replicas * local_num_partitions 798 num_global_shards = num_global_replicas * num_partitions 799 800 if (xb.host_count() > 1 and must_run_on_all_devices and 801 num_local_shards != xb.local_device_count()): 802 if num_local_shards == axis_size: 803 raise ValueError( 804 f"On multi-host platforms, the input to pmapped functions must have " 805 f"leading axis size equal to the number of local devices if no " 806 f"`devices` argument is specified. Got axis_size={axis_size}, " 807 f"num_local_devices={xb.local_device_count()}") 808 else: 809 raise ValueError( 810 f"On multi-host platforms, pmapped functions must run across all " 811 f"devices, i.e. num_replicas * num_partitions should equal the " 812 f"number of local devices. Got num_replicas={num_local_replicas}, " 813 f"num_partitions={num_partitions}, and " 814 f"num_local_devices={xb.local_device_count()}") 815 816 if no_nested_sharding and (jaxpr_replicas > 1 or num_partitions > 1): 817 raise ValueError( 818 f"On multi-host platforms, pmapped functions that both have `devices` " 819 f"specified and contain an inner_pmap or sharded_jit must specify an " 820 f"`axis_size` (or remove the `devices` argument). Got nested_replicas=" 821 f"{jaxpr_replicas} and nested_partitions={num_partitions}") 822 823 log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG 824 logging.log(log_priority, 825 f"Compiling {fun.__name__} for {num_global_shards} devices with " 826 f"args {avals}. 
(num_replicas={num_global_replicas} " 827 f"num_partitions={num_partitions})") 828 829 axis_env = xla.AxisEnv(num_global_replicas, (axis_name,), (global_axis_size,)) 830 831 tuple_args = len(global_sharded_avals) > 100 # pass long arg lists as tuple for TPU 832 833 c = xb.make_computation_builder("pmap_{}".format(fun.__name__)) 834 xla_consts = map(partial(xb.constant, c), consts) 835 replicated_args = [axis is None for axis in in_axes] 836 xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args, 837 replicated=replicated_args, 838 partitions=arg_parts, 839 donated_invars=donated_invars) 840 with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore 841 out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend_name, axis_env, xla_consts, 842 extend_name_stack(wrap_name(name, 'pmap')), *xla_args) 843 build_out_tuple = partial(xops.Tuple, c, out_nodes) 844 if out_parts is not None: 845 out_tuple = xb.with_sharding(c, out_parts, build_out_tuple) 846 else: 847 out_tuple = build_out_tuple() 848 backend = xb.get_backend(backend_name) 849 if backend.platform in ("gpu", "tpu"): 850 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args) 851 built = c.Build(out_tuple) 852 853 if devices is None: 854 if num_global_shards > xb.device_count(backend): 855 msg = ("compiling computation that requires {} logical devices, but only {} XLA " 856 "devices are available (num_replicas={}, num_partitions={})") 857 raise ValueError(msg.format(num_global_shards, xb.device_count(backend), 858 num_global_replicas, num_partitions)) 859 860 # On a single host, we use the platform's default device assignment to 861 # potentially take advantage of device locality. On multiple hosts, the 862 # default device assignment may interleave different hosts' replicas, 863 # violating pmap's semantics where data is sharded across replicas in 864 # row-major order. Instead, manually create a device assignment that ensures 865 # each host is responsible for a continguous set of replicas. 866 if num_global_shards > num_local_shards: 867 # TODO(skye): use a locality-aware assignment that satisfies the above 868 # constraint. 869 devices = [d for host_id in xb.host_ids() 870 for d in xb.local_devices(host_id)] 871 else: 872 devices = xb.get_backend(backend).get_default_device_assignment( 873 num_global_replicas, num_partitions) 874 else: 875 if num_local_shards != len(local_devices): 876 local_devices_str = ", ".join(map(str, local_devices)) 877 raise ValueError( 878 "Leading axis size of input to pmapped function must equal the " 879 "number of local devices passed to pmap. Got axis_size=%d, " 880 "num_local_devices=%d.\n(Local devices passed to pmap: %s)" 881 % (axis_size, len(local_devices), local_devices_str)) 882 if num_global_shards != len(devices): 883 raise ValueError("compiling computation that creates %s shards, " 884 "but %s devices were specified" % 885 (num_global_shards, len(devices))) 886 887 # 'devices' may be 1D or 2D at this point (e.g. 888 # get_default_device_assignment() returns 2D assignment, caller may have 889 # provided 1D list of devices). 890 device_assignment = tree_map(lambda d: d.id, devices) 891 # Convert to 2D in case it's 1D and we have > 1 partitions. 892 device_assignment = np.array(device_assignment).reshape( 893 (num_global_replicas, num_partitions)) 894 # TODO(b/162356737): Enabling SPMD partitioning causes issues with some 895 # non-partitioned workloads, so disable unless needed. 
896 use_spmd_partitioning = num_partitions > 1 897 compile_options = xb.get_compile_options( 898 num_replicas=num_global_replicas, 899 num_partitions=num_partitions, 900 device_assignment=device_assignment, 901 use_spmd_partitioning=use_spmd_partitioning, 902 ) 903 compile_options.parameter_is_tupled_arguments = tuple_args 904 compiled = xla.backend_compile(backend, built, compile_options) 905 906 local_arg_parts_ = local_arg_parts or [None] * len(avals) 907 input_sharding_specs = [ 908 _pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions, 909 parts, aval, in_axis) 910 if aval is not core.abstract_unit else None 911 for aval, parts, in_axis in safe_zip(sharded_avals, local_arg_parts_, in_axes)] 912 input_indices = [spec_to_indices(aval.shape, spec) 913 if spec is not None else None 914 for aval, spec in safe_zip(avals, input_sharding_specs)] 915 handle_args = partial(shard_args, compiled.local_devices(), input_indices) 916 if config.omnistaging_enabled: 917 nouts = len(out_sharded_avals) 918 if out_parts is None: 919 out_parts = (None,) * nouts 920 if local_out_parts is None: 921 local_out_parts = (None,) * nouts 922 923 local_out_avals = [get_local_aval(aval, parts, lparts) 924 for aval, parts, lparts 925 in safe_zip(out_sharded_avals, out_parts, local_out_parts)] 926 local_unmapped_avals = [core.unmapped_aval(axis_size, out_axis, aval) 927 if out_axis is not None else aval 928 for aval, out_axis in safe_zip(local_out_avals, out_axes)] 929 930 out_specs = [_pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions, 931 parts, aval, out_axis) 932 if aval is not core.abstract_unit else None 933 for parts, aval, out_axis in safe_zip(local_out_parts, local_out_avals, out_axes)] 934 handle_outs = avals_to_results_handler( 935 num_local_replicas, local_num_partitions, out_specs, local_unmapped_avals) 936 else: 937 handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas, # type: ignore 938 local_num_partitions, 939 local_out_parts, out_pvals, 940 compiled.local_devices(), backend) 941 942 return partial(execute_replicated, compiled, backend, handle_args, handle_outs) 943 944multi_host_supported_collectives: Set[core.Primitive] = set() 945 946 947def check_multihost_collective_allowlist(jaxpr): 948 used_collectives = set(xla.jaxpr_collectives(jaxpr)) 949 if not used_collectives.issubset(multi_host_supported_collectives): 950 bad_collectives = used_collectives - multi_host_supported_collectives 951 msg = "using collectives that aren't supported for multi-host: {}" 952 raise TypeError(msg.format(", ".join(map(str, bad_collectives)))) 953 954 955PartitionsOrReplicated = Optional[Tuple[int, ...]] 956 957def _find_partitions(jaxpr) -> Tuple[ 958 Optional[Tuple[PartitionsOrReplicated, ...]], 959 Optional[Tuple[PartitionsOrReplicated, ...]], 960 int, 961 Optional[Tuple[PartitionsOrReplicated, ...]], 962 Optional[Tuple[PartitionsOrReplicated, ...]], 963 Optional[int]]: 964 """Returns (in_partitions, out_partitions, num_partitions, local_in_parts, 965 local_out_parts, local_num_partitions). 
966 """ 967 for eqn in jaxpr.eqns: 968 if eqn.primitive.name == "sharded_call": 969 if len(jaxpr.eqns) > 1: 970 raise NotImplementedError( 971 "pmap of sharded_jit + non-sharded operations not yet implemented.") 972 num_partitions = reconcile_num_partitions(eqn.params["call_jaxpr"], 973 eqn.params["nparts"]) 974 return (eqn.params["in_parts"], 975 eqn.params["out_parts_thunk"](), 976 num_partitions, 977 eqn.params["local_in_parts"], 978 eqn.params["local_out_parts_thunk"](), 979 eqn.params["local_nparts"]) 980 return None, None, 1, None, None, None 981 982def reconcile_num_partitions(jaxpr, outer_num_parts: Optional[int]): 983 """Returns the total number of partitions to use. 984 985 Validates that any inner partitioning matches outer_num_parts if provided, and 986 returns the number of partitions to use based on outer_num_parts and any inner 987 partitioning. 988 """ 989 inner_num_parts = _inner_partitions(jaxpr, outer_num_parts) 990 if outer_num_parts is None and inner_num_parts is None: 991 # No partitions specified anywhere, everything is replicated. 992 return 1 993 if outer_num_parts is None: 994 return inner_num_parts 995 return outer_num_parts 996 997 998def _inner_partitions(jaxpr, expected_num_parts: Optional[int]): 999 """Returns the total number of partitions from PartitionSpecs inside `jaxpr`. 1000 1001 Also validates that this number matches `expected_num_parts` if provided. 1002 """ 1003 for eqn in jaxpr.eqns: 1004 if eqn.primitive.name in ["sharding_constraint", "infeed"]: 1005 parts = eqn.params["partitions"] 1006 nparts = get_num_partitions(parts) 1007 if expected_num_parts is None: 1008 expected_num_parts = nparts 1009 elif nparts is not None and nparts != expected_num_parts: 1010 # TODO(skye): raise this error as we trace the jaxpr 1011 raise ValueError( 1012 f"with_sharding_constraint with partitions={parts} " 1013 f"(total partitions: {nparts}) doesn't match expected number of " 1014 f"partitions: {expected_num_parts}. If these partitions look " 1015 f"right, check outer sharded_jit and/or other " 1016 f"with_sharding_constraint calls.") 1017 else: 1018 for subjaxpr in core.jaxprs_in_params(eqn.params): 1019 expected_num_parts = _inner_partitions(subjaxpr, expected_num_parts) 1020 return expected_num_parts 1021 1022 1023def get_num_partitions(*partitions): 1024 partition_specs = tree_flatten(partitions)[0] 1025 if len(partition_specs) == 0: 1026 # Everything is specified as replicated (all Nones). 
1027 return None 1028 num_partitions_set = {np.prod(spec) for spec in partition_specs} 1029 if len(num_partitions_set) > 1: 1030 raise ValueError( 1031 f"All partition specs must use the same number of total partitions, " 1032 f"got {partitions}, with distinct number of partitions " 1033 f"{num_partitions_set} (the total number of partitions is the product " 1034 f"of a partition spec)") 1035 assert len(num_partitions_set) == 1 1036 return num_partitions_set.pop() 1037 1038 1039def get_global_aval(local_aval, global_parts: PartitionsOrReplicated, 1040 local_parts: PartitionsOrReplicated): 1041 if local_aval is core.abstract_unit: 1042 return local_aval 1043 if global_parts is None: 1044 return local_aval 1045 assert local_parts is not None 1046 global_shape = [dim * _safe_div(ngparts, nlparts) 1047 for dim, ngparts, nlparts 1048 in safe_zip(local_aval.shape, global_parts, local_parts)] 1049 return ShapedArray(global_shape, local_aval.dtype) 1050 1051 1052def get_local_aval(global_aval, global_parts: PartitionsOrReplicated, 1053 local_parts: PartitionsOrReplicated): 1054 if global_aval is core.abstract_unit: 1055 return global_aval 1056 if global_parts is None: 1057 return global_aval 1058 assert local_parts is not None 1059 local_shape = [_safe_div(dim, _safe_div(ngparts, nlparts)) 1060 for dim, ngparts, nlparts 1061 in safe_zip(global_aval.shape, global_parts, local_parts)] 1062 return ShapedArray(local_shape, global_aval.dtype) 1063 1064 1065def _safe_div(x, y): 1066 result, ragged = divmod(x, y) 1067 assert not ragged, f"{x} % {y} != 0" 1068 return result 1069 1070 1071class ResultToPopulate: pass 1072result_to_populate = ResultToPopulate() 1073 1074class ResultsHandler: 1075 __slots__ = ("nrep", "npart", "nouts", "out_specs", "out_indices", "handlers", 1076 "unmapped_local_out_avals") 1077 1078 def __init__(self, nrep, npart, nouts, out_specs, out_indices, handlers, 1079 unmapped_local_out_avals): 1080 self.nrep = nrep 1081 self.npart = npart 1082 self.nouts = nouts 1083 self.out_specs = out_specs 1084 self.out_indices = out_indices 1085 self.handlers = handlers 1086 self.unmapped_local_out_avals = unmapped_local_out_avals 1087 1088 def __call__(self, out_bufs): 1089 assert self.nrep * self.npart == len(out_bufs) 1090 buffers = [[result_to_populate] * (self.nrep * self.npart) 1091 for _ in range(self.nouts)] 1092 for r, tuple_buf in enumerate(out_bufs): 1093 for i, buf in enumerate(tuple_buf): 1094 buffers[i][r] = buf 1095 assert not any( 1096 buf is result_to_populate for bufs in buffers for buf in bufs) 1097 return [h(bufs) for h, bufs in safe_zip(self.handlers, buffers)] 1098 1099def avals_to_results_handler(nrep, npart, out_specs, unmapped_local_out_avals): 1100 nouts = len(unmapped_local_out_avals) 1101 out_indices = [spec_to_indices(aval.shape, spec) 1102 if aval is not core.abstract_unit else None 1103 for aval, spec in safe_zip(unmapped_local_out_avals, out_specs)] # pytype: disable=attribute-error 1104 handlers = [aval_to_result_handler(spec, idcs, aval) 1105 for spec, idcs, aval in safe_zip(out_specs, out_indices, unmapped_local_out_avals)] 1106 1107 return ResultsHandler(nrep, npart, nouts, out_specs, out_indices, handlers, 1108 unmapped_local_out_avals) 1109 1110def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0): 1111 """Replicates ``val`` across multiple devices. 1112 1113 Args: 1114 val: the value to be replicated. 1115 axis_size: the length of the output, i.e. the logical number of replicas to 1116 create. 
Usually equal to `nrep`, but in the case of nested pmaps, `nrep` may 1117 be a multiple of `axis_size`. 1118 nrep: the number of replicas to create. If ``devices`` is set, must be equal 1119 to ``len(devices)``. 1120 devices: the devices to replicate across. If None, ``nrep`` will be used to 1121 generate a default device assignment. 1122 backend: string specifying which backend to use. 1123 in_axis: axis along which the value is to be replciated. 1124 1125 Returns: 1126 A ShardedDeviceArray of length `axis_size` where each shard is equal to 1127 ``val``. 1128 """ 1129 device_count = (len(devices) if devices else xb.local_device_count(backend)) 1130 if nrep > device_count: 1131 msg = ("Cannot replicate across %d replicas because only %d local devices " 1132 "are available." % (nrep, device_count)) 1133 if devices: 1134 msg += (" (local devices = %s)" 1135 % ", ".join(map(str, devices)) if devices else str(None)) 1136 raise ValueError(msg) 1137 1138 if devices is None: 1139 assert nrep is not None 1140 # TODO(skye): use different device assignment on multihost 1141 devices = xb.get_backend(backend).get_default_device_assignment(nrep) 1142 assert nrep == len(devices) 1143 1144 aval = xla.abstractify(val) # type: ShapedArray 1145 replicated_aval = ShapedArray((axis_size,) + aval.shape, aval.dtype) 1146 # TODO(skye): figure out how partitioning should work here 1147 sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, in_axis) 1148 device_buffers = device_put(val, devices, replicate=True) 1149 return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers) 1150 1151def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, map_axis: Optional[int]): 1152 """Sharding spec for arguments or results of a pmap. 1153 Args: 1154 nrep: number of local XLA replicas (product of local axis sizes) 1155 axis_size: local axis size for outer pmap 1156 npart: total number of XLA partitions (required by sharded_jit calls) 1157 parts: the partitioning of the value or None 1158 sharded_aval: the aval of the value inside the outer pmap, an instance of 1159 a ShapedArray. 1160 map_axis: the axis along which the value is mapped in the outer pmap 1161 Returns: 1162 A ShardingSpec. 
1163 """ 1164 assert isinstance(sharded_aval, ShapedArray), sharded_aval 1165 replication_factor, ragged = divmod(nrep, axis_size) 1166 assert not ragged 1167 # get the sharding spec from inner sharded_jits as if we weren't in a pmap 1168 pspec = partitioned_sharding_spec(npart, parts, sharded_aval) 1169 maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),) 1170 if map_axis is not None: 1171 sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis]) 1172 def shift_sharded_axis(a: MeshDimAssignment): 1173 if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis: 1174 return ShardedAxis(a.axis + 1) 1175 return a 1176 # replication_factor represents the product of inner pmaps, so it goes 1177 # after the outer pmapped axis at index 0 1178 return ShardingSpec( 1179 sharding=tuple_insert(pspec.sharding, map_axis, Unstacked(axis_size)), 1180 mesh_mapping=it.chain([ShardedAxis(sharded_in_axis)], 1181 maybe_replicate, 1182 map(shift_sharded_axis, pspec.mesh_mapping))) 1183 else: 1184 return ShardingSpec( 1185 sharding=pspec.sharding, 1186 mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping) 1187 1188def partitioned_sharding_spec(num_partitions: int, 1189 partitions: Optional[Sequence[int]], 1190 aval) -> ShardingSpec: 1191 if partitions is None: 1192 maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),) 1193 return ShardingSpec( 1194 sharding=[_UNSHARDED_INSTANCE] * len(aval.shape), 1195 mesh_mapping=maybe_replicate) 1196 else: 1197 assert len(partitions) == len(aval.shape) 1198 return ShardingSpec(sharding=map(Chunked, partitions), 1199 mesh_mapping=map(ShardedAxis, range(len(partitions)))) 1200 1201 1202def execute_replicated(compiled, backend, in_handler, out_handler, *args): 1203 input_bufs = in_handler(args) 1204 out_bufs = compiled.execute_on_local_devices(list(input_bufs)) 1205 return out_handler(out_bufs) 1206 1207 1208xla_pmap_p = core.MapPrimitive('xla_pmap') 1209xla_pmap = xla_pmap_p.bind 1210xla_pmap_p.def_impl(xla_pmap_impl) 1211 1212# Set param update handlers to update `donated_invars` just like xla_call_p 1213pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p] 1214ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p] 1215ad.call_transpose_param_updaters[xla_pmap_p] = \ 1216 ad.call_transpose_param_updaters[xla.xla_call_p] 1217 1218def _pmap_translation_rule(c, axis_env, 1219 in_nodes, name_stack, axis_name, axis_size, 1220 global_axis_size, devices, name, 1221 call_jaxpr, *, backend=None, in_axes, out_axes, 1222 donated_invars, global_arg_shapes): 1223 del donated_invars # Unused. 1224 # We in-line here rather than generating a Call HLO as in the xla_call 1225 # translation rule just because the extra tuple stuff is a pain. 
1226 if axis_env.names and devices is not None: 1227 raise ValueError("Nested pmap with explicit devices argument.") 1228 if global_axis_size is None: 1229 global_axis_size = axis_size 1230 new_env = xla.extend_axis_env(axis_env, axis_name, global_axis_size) 1231 # Shard the in_nodes that are mapped 1232 in_avals = [v.aval for v in call_jaxpr.invars] 1233 in_nodes_sharded = ( 1234 _xla_shard(c, aval, new_env, in_node, in_axis) if in_axis is not None else in_node 1235 for aval, in_node, in_axis in safe_zip(in_avals, in_nodes, in_axes)) 1236 1237 with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore 1238 sharded_outs = xla.jaxpr_subcomp( 1239 c, call_jaxpr, backend, new_env, (), 1240 extend_name_stack(name_stack, wrap_name(name, 'pmap')), *in_nodes_sharded) 1241 out_avals = [v.aval for v in call_jaxpr.outvars] 1242 outs = [_xla_unshard(c, aval, new_env, out_axis, shard, backend=backend) 1243 for aval, out_axis, shard in safe_zip(out_avals, out_axes, sharded_outs)] 1244 return xops.Tuple(c, outs) 1245 1246xla.call_translations[xla_pmap_p] = _pmap_translation_rule 1247ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p) 1248 1249def _xla_shard(c, aval, axis_env, x, in_axis): 1250 if aval is core.abstract_unit: 1251 return x 1252 elif aval is core.abstract_token: 1253 return x 1254 elif isinstance(aval, ShapedArray): 1255 dims = list(c.get_shape(x).dimensions()) 1256 zero = xb.constant(c, np.zeros((), dtype=np.uint32)) 1257 idxs = [zero] * (len(dims) - 1) 1258 idxs.insert(in_axis, _unravel_index(c, axis_env)) 1259 dims_unsqueezed = dims.copy() 1260 dims_unsqueezed[in_axis] = 1 1261 dims_squeezed = dims.copy() 1262 dims_squeezed.pop(in_axis) 1263 return xops.Reshape(xops.DynamicSlice(x, idxs, dims_unsqueezed), dims_squeezed) 1264 else: 1265 raise TypeError((aval, c.get_shape(x))) 1266 1267# TODO(b/110096942): more efficient gather 1268def _xla_unshard(c, aval, axis_env, out_axis, x, backend): 1269 if aval is core.abstract_unit: 1270 return x 1271 elif aval is core.abstract_token: 1272 return x 1273 elif isinstance(aval, ShapedArray): 1274 # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU 1275 convert_bool = (np.issubdtype(aval.dtype, np.bool_) 1276 and xb.get_backend(backend).platform in ('cpu', 'gpu')) 1277 if convert_bool: 1278 x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32)) 1279 1280 xla_shape = c.get_shape(x) 1281 dims = list(xla_shape.dimensions()) 1282 padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())), 1283 [axis_env.sizes[-1]] + dims) 1284 zero = xb.constant(c, np.zeros((), dtype=np.uint32)) 1285 idxs = [_unravel_index(c, axis_env)] + [zero] * len(dims) 1286 padded = xops.DynamicUpdateSlice(padded, xops.Reshape(x, [1] + dims), idxs) 1287 replica_groups_protos = xc.make_replica_groups( 1288 xla.axis_groups(axis_env, axis_env.names[-1])) 1289 out = xops.CrossReplicaSum(padded, replica_groups_protos) 1290 if out_axis != 0: 1291 # TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead 1292 perm = list(range(1, len(dims))) 1293 perm.insert(out_axis, 0) 1294 out = xops.Transpose(out, perm) 1295 1296 # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU 1297 if convert_bool: 1298 nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32))) 1299 out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_)) 1300 return out 1301 else: 1302 raise TypeError((aval, c.get_shape(x))) 1303 1304def _unravel_index(c, axis_env): 1305 div = 
xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), np.uint32)) 1306 mod = xb.constant(c, np.array(axis_env.sizes[-1], np.uint32)) 1307 return xops.Rem(xops.Div(xops.ReplicaId(c), div), mod) 1308 1309# ------------------- xmap ------------------- 1310 1311MeshAxisName = Any 1312""" 1313ArrayMapping specifies how an ndarray should map to mesh axes. 1314 1315Note that the ordering is crucial for the cases when this mapping is non-injective 1316(i.e. when multiple mesh axes map to the same positional axis). Then, the 1317order of entries of the mapping determines a major-to-minor order on mesh axes, 1318according to which chunks of the value along the repeated dimension will be assigned. 1319 1320For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}. 1321The second dimension of the value would get chunked into 6 pieces, and assigned to the 1322mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case, 1323that would mean that a flat list of chunks would get assigned to a flattened list of 1324mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the 1325mesh devices ndarray would have to be transposed before flattening and assignment. 1326""" 1327ArrayMapping = OrderedDictType[MeshAxisName, int] 1328 1329class Mesh: 1330 __slots__ = ('devices', 'axis_names') 1331 1332 def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]): 1333 assert devices.ndim == len(axis_names) 1334 # TODO: Make sure that devices are unique? At least with the quick and 1335 # dirty check that the array size is not larger than the number of 1336 # available devices? 1337 self.devices = devices 1338 self.axis_names = tuple(axis_names) 1339 1340 @property 1341 def shape(self): 1342 return OrderedDict((name, size) for name, size in safe_zip(self.axis_names, self.devices.shape)) 1343 1344 @property 1345 def size(self): 1346 return np.prod(list(self.shape.values())) 1347 1348 # TODO: This is pretty expensive to compute. Cache this on the mesh object? 1349 @property 1350 def local_mesh(self): 1351 host_id = xb.host_id() 1352 is_local_device = np.vectorize(lambda d: d.host_id == host_id, otypes=[bool])(self.devices) 1353 subcube_indices = [] 1354 # We take the smallest slice of each dimension that doesn't skip any local device. 1355 for axis in range(self.devices.ndim): 1356 other_axes = tuple_delete(tuple(range(self.devices.ndim)), axis) 1357 # NOTE: This re-reduces over many axes multiple times, so we could definitely 1358 # optimize it, but I hope it won't be a bottleneck anytime soon. 1359 local_slices = is_local_device.any(other_axes, keepdims=False) 1360 nonzero_indices = np.flatnonzero(local_slices) 1361 start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices)) 1362 subcube_indices.append(slice(start, end + 1)) 1363 subcube_indices = tuple(subcube_indices) 1364 # We only end up with all conditions being true if the local devices formed a 1365 # subcube of the full array. This is because we were biased towards taking a 1366 # "hull" spanned by the devices, and in case the local devices don't form a 1367 # subcube that hull will contain non-local devices. 
class Mesh:
  __slots__ = ('devices', 'axis_names')

  def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
    assert devices.ndim == len(axis_names)
    # TODO: Make sure that devices are unique? At least with the quick and
    # dirty check that the array size is not larger than the number of
    # available devices?
    self.devices = devices
    self.axis_names = tuple(axis_names)

  @property
  def shape(self):
    return OrderedDict((name, size) for name, size in safe_zip(self.axis_names, self.devices.shape))

  @property
  def size(self):
    return np.prod(list(self.shape.values()))

  # TODO: This is pretty expensive to compute. Cache this on the mesh object?
  @property
  def local_mesh(self):
    host_id = xb.host_id()
    is_local_device = np.vectorize(lambda d: d.host_id == host_id, otypes=[bool])(self.devices)
    subcube_indices = []
    # We take the smallest slice of each dimension that doesn't skip any local device.
    for axis in range(self.devices.ndim):
      other_axes = tuple_delete(tuple(range(self.devices.ndim)), axis)
      # NOTE: This re-reduces over many axes multiple times, so we could definitely
      # optimize it, but I hope it won't be a bottleneck anytime soon.
      local_slices = is_local_device.any(other_axes, keepdims=False)
      nonzero_indices = np.flatnonzero(local_slices)
      start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
      subcube_indices.append(slice(start, end + 1))
    subcube_indices = tuple(subcube_indices)
    # The assertion below only holds if the local devices form a subcube of the
    # full device array. That is because we were biased towards taking the "hull"
    # spanned by the local devices, and if they don't form a subcube that hull
    # will contain non-local devices.
    assert is_local_device[subcube_indices].all()
    return Mesh(self.devices[subcube_indices], self.axis_names)

  def __getitem__(self, new_axes):
    indices = [0] * len(self.axis_names)
    axis_pos = {name: i for i, name in enumerate(self.axis_names)}
    for axis in new_axes:
      indices[axis_pos[axis]] = slice(None)
    new_devices = self.devices[tuple(indices)]
    # After integer indexing only the selected axes remain, in their original
    # mesh order, so the transpose permutation has to be computed relative to
    # that remaining order rather than the full mesh.
    kept_positions = sorted(axis_pos[axis] for axis in new_axes)
    new_devices = new_devices.transpose(
        tuple(kept_positions.index(axis_pos[axis]) for axis in new_axes))
    return Mesh(new_devices, new_axes)

  @property
  def device_ids(self):
    return np.vectorize(lambda d: d.id, otypes=[int])(self.devices)
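
# Example usage (a sketch, assuming eight local devices; xb.devices() returns
# the available device objects):
#
#   devices = np.asarray(xb.devices(), dtype=object).reshape(2, 4)
#   mesh = Mesh(devices, ('x', 'y'))
#   mesh.shape        # OrderedDict([('x', 2), ('y', 4)])
#   mesh.size         # 8
#   mesh.device_ids   # 2x4 integer array of device ids
#   mesh[('y',)]      # 1D sub-mesh spanning the 'y' axis at x=0
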
def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
  if aval is core.abstract_unit:
    return aval
  assert isinstance(aval, ShapedArray)
  shape = list(aval.shape)
  for name, axis in in_axes.items():
    assert shape[axis] % axis_sizes[name] == 0
    shape[axis] //= axis_sizes[name]
  return ShapedArray(tuple(shape), aval.dtype)

def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
  if aval is core.abstract_unit:
    return aval
  assert isinstance(aval, ShapedArray)
  shape = list(aval.shape)
  for name, axis in out_axes.items():
    shape[axis] *= axis_sizes[name]
  return ShapedArray(tuple(shape), aval.dtype)
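
# For instance (illustrative shapes only), tiling divides each mapped dimension
# by the sizes of the mesh axes mapped to it, and untiling inverts that:
#
#   axis_sizes = OrderedDict([('x', 4)])
#   aval = ShapedArray((8, 3), np.float32)
#   tile_aval_nd(axis_sizes, OrderedDict([('x', 0)]), aval)      # -> shape (2, 3)
#   untile_aval_nd(axis_sizes, OrderedDict([('x', 0)]),
#                  ShapedArray((2, 3), np.float32))               # -> shape (8, 3)
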
def mesh_tiled_callable(fun: lu.WrappedFun,
                        transformed_name: str,
                        backend_name: Optional[str],
                        mesh: Mesh,
                        in_axes: Sequence[ArrayMapping],
                        out_axes: Sequence[ArrayMapping],
                        spmd_lowering,
                        *local_in_untiled_avals):
  assert config.omnistaging_enabled
  local_mesh = mesh.local_mesh
  global_axis_sizes = mesh.shape
  local_axis_sizes = local_mesh.shape

  log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
  logging.log(log_priority,
              f"Compiling {fun.__name__} for {tuple(global_axis_sizes.items())} "
              f"mesh with args {local_in_untiled_avals}. Argument mapping: {in_axes}.")

  # 1. Trace to jaxpr and preprocess/verify it
  in_tiled_avals = [tile_aval_nd(local_axis_sizes, aval_in_axes, aval)
                    for aval, aval_in_axes in safe_zip(local_in_untiled_avals, in_axes)]
  if spmd_lowering:
    # TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
    for name, size in reversed(mesh.shape.items()):
      fun = vtile(fun,
                  tuple(a.get(name, None) for a in in_axes),
                  tuple(a.get(name, None) for a in out_axes),
                  tile_size=size, axis_name=name)
    global_in_untiled_avals = [untile_aval_nd(global_axis_sizes, aval_in_axes, aval)
                               for aval, aval_in_axes in safe_zip(in_tiled_avals, in_axes)]
    in_jaxpr_avals = global_in_untiled_avals
  else:
    in_jaxpr_avals = in_tiled_avals
  with core.extend_axis_env_nd(mesh.shape.items()):
    jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
  assert len(out_axes) == len(out_jaxpr_avals)
  if spmd_lowering:
    global_out_untiled_avals = out_jaxpr_avals
    out_tiled_avals = [tile_aval_nd(global_axis_sizes, aval_out_axes, aval)
                       for aval, aval_out_axes in safe_zip(global_out_untiled_avals, out_axes)]
  else:
    out_tiled_avals = out_jaxpr_avals
  local_out_untiled_avals = [untile_aval_nd(local_axis_sizes, aval_out_axes, aval)
                             for aval, aval_out_axes in safe_zip(out_tiled_avals, out_axes)]
  _sanitize_mesh_jaxpr(jaxpr)
  if local_mesh.shape != mesh.shape:
    check_multihost_collective_allowlist(jaxpr)
  jaxpr = xla.apply_outfeed_rewriter(jaxpr)

  # 2. Build up the HLO
  c = xb.make_computation_builder(f"xmap_{fun.__name__}")
  xla_consts = map(partial(xb.constant, c), consts)
  donated_invars = (False,) * len(in_jaxpr_avals)  # TODO(apaszke): support donation
  tuple_args = len(in_jaxpr_avals) > 100  # pass long arg lists as tuple for TPU
  in_partitions: Optional[List]
  if spmd_lowering:
    replicated_args = [False] * len(in_jaxpr_avals)
    global_sharding_spec = mesh_sharding_specs(global_axis_sizes, mesh.axis_names)
    in_partitions = [global_sharding_spec(aval, aval_in_axes).sharding_proto()
                     if aval is not core.abstract_unit else None
                     for aval, aval_in_axes in safe_zip(global_in_untiled_avals, in_axes)]
    out_partitions = [global_sharding_spec(aval, aval_out_axes).sharding_proto()
                      for aval, aval_out_axes in safe_zip(global_out_untiled_avals, out_axes)]
    partitions_proto = True
    axis_env = xla.AxisEnv(nreps=1, names=(), sizes=())  # All named axes have been vmapped
  else:
    replicated_args = [not axis for axis in in_axes]
    in_partitions = None
    partitions_proto = False
    axis_env = xla.AxisEnv(nreps=mesh.size,
                           names=tuple(global_axis_sizes.keys()),
                           sizes=tuple(global_axis_sizes.values()))
  xla_args, donated_invars = xla._xla_callable_args(
      c, in_jaxpr_avals, tuple_args,
      replicated=replicated_args,
      partitions=in_partitions,
      partitions_proto=partitions_proto,
      donated_invars=donated_invars)
  with core.extend_axis_env_nd(mesh.shape.items()):
    out_nodes = xla.jaxpr_subcomp(
        c, jaxpr, backend_name, axis_env, xla_consts,
        extend_name_stack(wrap_name(transformed_name, 'xmap')), *xla_args)
  backend = xb.get_backend(backend_name)
  if spmd_lowering:
    out_partitions_t = xb.tuple_sharding_proto(out_partitions)
    out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)
  else:
    out_tuple = xops.Tuple(c, out_nodes)
  # TODO(apaszke): Does that work with SPMD sharding?
  if backend.platform in ("gpu", "tpu"):
    donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
  built = c.Build(out_tuple)

  # 3. Compile the HLO
  if spmd_lowering:
    num_replicas, num_partitions = 1, mesh.size
    num_local_replicas, num_local_partitions = 1, local_mesh.size
  else:
    num_replicas, num_partitions = mesh.size, 1
    num_local_replicas, num_local_partitions = local_mesh.size, 1
  device_assignment = mesh.device_ids.reshape((num_replicas, num_partitions))
  compile_options = xb.get_compile_options(
      num_replicas=num_replicas,
      num_partitions=num_partitions,
      device_assignment=device_assignment,
      use_spmd_partitioning=spmd_lowering,
  )
  compile_options.parameter_is_tupled_arguments = tuple_args
  compiled = xla.backend_compile(backend, built, compile_options)

  # 4. Argument sharding / output wrapping
  local_sharding_spec = mesh_sharding_specs(local_axis_sizes, mesh.axis_names)
  local_input_specs = [local_sharding_spec(aval, aval_in_axes)
                       if aval is not core.abstract_unit else None
                       for aval, aval_in_axes in safe_zip(local_in_untiled_avals, in_axes)]
  input_indices = [spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in safe_zip(local_in_untiled_avals, local_input_specs)]
  handle_args = partial(shard_args, compiled.local_devices(), input_indices)

  local_output_specs = [local_sharding_spec(aval, aval_out_axes)
                        for aval, aval_out_axes in safe_zip(local_out_untiled_avals, out_axes)]
  handle_outs = avals_to_results_handler(num_local_replicas, num_local_partitions,
                                         local_output_specs, local_out_untiled_avals)

  return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
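
# To make the replica/partition split above concrete (hypothetical 2x4 mesh with
# all devices local): with spmd_lowering the program is compiled once with
# num_replicas=1, num_partitions=8 and the SPMD partitioner places one shard of
# each value on every device, whereas without it the tiled program is compiled
# with num_replicas=8, num_partitions=1 and every device runs a full replica.
#
#   mesh = Mesh(np.asarray(xb.devices(), dtype=object).reshape(2, 4), ('x', 'y'))
#   # spmd_lowering=True:   num_replicas, num_partitions = 1, 8
#   # spmd_lowering=False:  num_replicas, num_partitions = 8, 1
#   # device_assignment = mesh.device_ids.reshape((num_replicas, num_partitions)),
#   # so the row-major order of the mesh decides which device runs which
#   # logical replica/partition.
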
# NOTE: This divides the size of each mapped input axis by the tile_size and
# multiplies the size of each mapped output axis by it.
def vtile(f_flat,
          in_axes_flat: Tuple[Optional[int], ...],
          out_axes_flat: Tuple[Optional[int], ...],
          tile_size: Optional[int], axis_name):
  if tile_size == 1:
    return f_flat

  @curry
  def tile_axis(arg, axis: Optional[int], tile_size):
    if axis is None:
      return arg
    shape = list(arg.shape)
    shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
    return arg.reshape(shape)

  def untile_axis(out, axis: Optional[int]):
    if axis is None:
      return out
    shape = list(out.shape)
    shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
    return out.reshape(shape)

  @lu.transformation
  def _map_to_tile(*args_flat):
    sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
    tile_size_ = tile_size or next(sizes, None)
    assert tile_size_ is not None, "No mapped arguments?"
    outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
    yield map(untile_axis, outputs_flat, out_axes_flat)

  return _map_to_tile(
      batching.batch(f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat))
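
# As a shape-level illustration of what the reshapes above do (values are
# illustrative; the real transformation operates on traced arguments):
#
#   x = np.zeros((8, 3))
#   tiled = x.reshape([2, 8 // 2, 3])     # tile_axis with axis=0, tile_size=2
#   assert tiled.shape == (2, 4, 3)       # the leading axis is now the tile index
#   untiled = tiled.reshape([2 * 4, 3])   # untile_axis with axis=0
#   assert untiled.shape == (8, 3)
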
_forbidden_primitives = {
  'xla_pmap': 'pmap',
  'soft_pmap': 'soft_pmap',
  'sharded_call': 'sharded_jit',
}
def _sanitize_mesh_jaxpr(jaxpr):
  for eqn in jaxpr.eqns:
    if eqn.primitive.name in _forbidden_primitives:
      raise RuntimeError(f"Nesting {_forbidden_primitives[eqn.primitive.name]} "
                         f"inside xmaps not supported!")
    core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)


def mesh_sharding_specs(axis_sizes, axis_names):
  mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
  # NOTE: This takes in the non-sharded avals!
  def mk_sharding_spec(aval, aval_axes):
    sharding = [_UNSHARDED_INSTANCE] * len(aval.shape)
    mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
    next_sharded_axis = 0
    aval_shape = list(aval.shape)
    # NOTE: sorted is stable, which is important when multiple resources
    # map to the same axis.
    for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
      assert aval_shape[axis] % axis_sizes[name] == 0, (axis_sizes[name], aval.shape[axis])
      aval_shape[axis] //= axis_sizes[name]
      if isinstance(sharding[axis], NoSharding):
        sharding[axis] = Chunked(())
      sharding[axis] = Chunked(sharding[axis].chunks + (axis_sizes[name],))
      assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
          "Value mapped to the same mesh axis twice"
      mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
      next_sharded_axis += 1
    return ShardingSpec(sharding, mesh_mapping)
  return mk_sharding_spec
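
# A worked instance of the spec construction above (illustrative values):
#
#   spec_fn = mesh_sharding_specs(OrderedDict([('x', 2), ('y', 3)]), ('x', 'y'))
#   spec = spec_fn(ShapedArray((6, 4), np.float32), OrderedDict([('x', 0), ('y', 0)]))
#   # spec.sharding     == (Chunked((2, 3)), NoSharding())
#   # spec.mesh_mapping == (ShardedAxis(0), ShardedAxis(1))
#   # i.e. axis 0 of the value is chunked 2-ways and then 3-ways, the two
#   # resulting sharded axes are assigned to the 'x' and 'y' mesh dimensions
#   # respectively, and axis 1 stays unsharded.
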
# ------------------- soft_pmap -------------------

def soft_pmap_impl(fun: lu.WrappedFun, *args, axis_name, axis_size, in_axes, out_axes_thunk):
  abstract_args = unsafe_map(xla.abstractify, args)
  compiled_fun = _soft_pmap_callable(fun, axis_name, axis_size, in_axes, out_axes_thunk,
                                     *abstract_args)
  return compiled_fun(*args)

@lu.cache
def _soft_pmap_callable(fun, axis_name, axis_size, in_axes, out_axes_thunk, *avals):
  mapped_avals = [core.mapped_aval(axis_size, in_axis, aval) if in_axis is not None else aval
                  for in_axis, aval in safe_zip(in_axes, avals)]
  with core.extend_axis_env(axis_name, axis_size, None):  # type: ignore
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
  out_axes = out_axes_thunk()
  assert all(out_axis == 0 for out_axis in out_axes)
  jaxpr = xla.apply_outfeed_rewriter(jaxpr)

  num_devices = xb.local_device_count()
  chunk_size, ragged = divmod(axis_size, num_devices)
  if ragged:
    msg = f"number of devices {num_devices} must divide axis size {axis_size}"
    raise NotImplementedError(msg)

  jaxpr, _, consts = _soft_pmap_jaxpr(jaxpr, consts, in_axes,
                                      axis_name, axis_size, chunk_size)
  jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
  if jaxpr_replicas != 1: raise NotImplementedError

  tuple_args = len(avals) > 100  # pass long arg lists as tuple for TPU

  c = xb.make_computation_builder("soft_pmap_{}".format(fun.__name__))
  xla_consts = map(partial(xb.constant, c), consts)
  chunked_avals = [core.unmapped_aval(chunk_size, in_axis, aval) if in_axis is not None else aval
                   for in_axis, aval in safe_zip(in_axes, mapped_avals)]
  xla_args, _ = xla._xla_callable_args(c, chunked_avals, tuple_args)
  axis_env = xla.AxisEnv(num_devices, (axis_name,), (num_devices,))
  out_nodes = xla.jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts,
                                'soft_pmap', *xla_args)
  built = c.Build(xops.Tuple(c, out_nodes))

  compile_options = xb.get_compile_options(
      num_replicas=num_devices, num_partitions=1, device_assignment=None)
  compile_options.tuple_arguments = tuple_args
  backend = xb.get_backend(None)
  compiled = xla.backend_compile(backend, built, compile_options)

  input_specs = [
      ShardingSpec(
          sharding=tuple_insert((_UNSHARDED_INSTANCE,) * (aval.ndim - 1),
                                in_axis, Chunked(num_devices)),
          mesh_mapping=[ShardedAxis(0)])
      if in_axis is not None else ShardingSpec(
          sharding=[_UNSHARDED_INSTANCE] * aval.ndim,
          mesh_mapping=[Replicated(num_devices)])
      for aval, in_axis in safe_zip(avals, in_axes)
  ]
  input_indices = [spec and spec_to_indices(aval.shape, spec)
                   for aval, spec in safe_zip(avals, input_specs)]
  handle_args = partial(shard_args, compiled.local_devices(), input_indices)
  handle_outs = soft_pmap_avals_to_results_handler(num_devices, chunk_size, out_avals)

  return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
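
# For intuition about the input specs above (illustrative): with 4 local devices
# and an argument of aval ShapedArray((8, 3), np.float32) mapped along in_axis=0,
# the spec built above is equivalent to
#
#   ShardingSpec(sharding=(Chunked(4), NoSharding()),
#                mesh_mapping=(ShardedAxis(0),))
#
# so each device receives one (2, 3) chunk of the argument, while an unmapped
# argument gets mesh_mapping=(Replicated(4),), i.e. a full copy on every device.
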
def _soft_pmap_jaxpr(jaxpr, consts, in_axes, axis_name, axis_size, chunk_size):
  assert all(in_axis is None or in_axis == 0 for in_axis in in_axes), in_axes
  mapped_invars = [in_axis is not None for in_axis in in_axes]
  fun = partial(_soft_pmap_interp, chunk_size, jaxpr, consts, mapped_invars)
  in_avals = [core.unmapped_aval(chunk_size, in_axis, v.aval) if in_axis is not None else v.aval
              for v, in_axis in safe_zip(jaxpr.invars, in_axes)]
  with core.extend_axis_env(axis_name, axis_size, None):
    return pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)

def _soft_pmap_interp(chunk_size, jaxpr, consts, mapped_invars, *args):
  env: Dict[Var, Tuple[Any, bool]] = {}

  def read(atom: Union[Var, Literal]) -> Tuple[Any, bool]:
    if isinstance(atom, Literal):
      return (atom.val, False)
    else:
      return env[atom]

  def write(v: Var, val: Any, mapped: bool) -> None:
    env[v] = (val, mapped)

  write(core.unitvar, core.unit, False)
  map(write, jaxpr.constvars, consts, (False,) * len(consts))
  map(write, jaxpr.invars, args, mapped_invars)
  for eqn in jaxpr.eqns:
    in_vals, in_mapped = unzip2(map(read, eqn.invars))
    if eqn.primitive in xla.parallel_translations:
      rule = soft_pmap_rules[eqn.primitive]
      out_vals, out_mapped = rule(in_vals, in_mapped, chunk_size, **eqn.params)
      if not eqn.primitive.multiple_results:
        out_vals, out_mapped = [out_vals], [out_mapped]
    elif isinstance(eqn.primitive, core.CallPrimitive):
      # we just inline here for convenience
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      out_vals = _soft_pmap_interp(chunk_size, call_jaxpr, (), in_mapped, *in_vals)
      out_mapped = [True] * len(out_vals)
    elif isinstance(eqn.primitive, core.MapPrimitive):
      raise NotImplementedError  # TODO
    else:
      if any(in_mapped):
        rule = batching.get_primitive_batcher(eqn.primitive, None)
        in_axes = [0 if m else batching.not_mapped for m in in_mapped]
        out_vals, out_axes = rule(in_vals, in_axes, **eqn.params)
        if not eqn.primitive.multiple_results:
          out_vals, out_axes = [out_vals], [out_axes]
        out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x
                    for x, d in safe_zip(out_vals, out_axes)]
        out_mapped = [d is not not_mapped for d in out_axes]
      else:
        out_vals = eqn.primitive.bind(*in_vals, **eqn.params)
        if not eqn.primitive.multiple_results:
          out_vals = [out_vals]
        out_mapped = [False for _ in out_vals]
    map(write, eqn.outvars, out_vals, out_mapped)

  out_vals, out_mapped = unzip2(map(read, jaxpr.outvars))
  out_vals = [out if mapped else broadcast(out, chunk_size, 0)
              for out, mapped in safe_zip(out_vals, out_mapped)]
  return out_vals

# TODO(mattjj): dedup with other aval_to_result_handler via ShardingSpec
def soft_pmap_avals_to_results_handler(num_devices, chunk_size, out_avals):
  nouts = len(out_avals)
  handlers = [soft_pmap_aval_to_result_handler(chunk_size, num_devices, aval)
              for aval in out_avals]
  def handler(out_bufs):
    buffers = [[result_to_populate] * num_devices for _ in range(nouts)]
    for r, tuple_buf in enumerate(out_bufs):
      for i, buf in enumerate(tuple_buf):
        buffers[i][r] = buf
    assert not any(buf is result_to_populate for bufs in buffers
                   for buf in bufs)
    return [h(bufs) for h, bufs in safe_zip(handlers, buffers)]
  return handler

def soft_pmap_aval_to_result_handler(chunk_size, num_devices, aval):
  axis_size = chunk_size * num_devices
  if aval is core.abstract_unit:
    return lambda _: core.unit
  elif isinstance(aval, core.ShapedArray):
    new_aval = ShapedArray((axis_size,) + aval.shape, aval.dtype)
    spec = ShardingSpec(
        sharding=(Chunked(num_devices),) + (_UNSHARDED_INSTANCE,) * aval.ndim,
        mesh_mapping=(ShardedAxis(0),))
    return lambda bufs: ShardedDeviceArray(new_aval, spec, bufs)
  else:
    raise TypeError(aval)

soft_pmap_p = core.MapPrimitive('soft_pmap')
soft_pmap = soft_pmap_p.bind
soft_pmap_p.def_impl(soft_pmap_impl)

soft_pmap_rules: Dict[core.Primitive, Callable] = {}

@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
  with core.extend_axis_env(*args, **kwargs):
    yield

@config.register_omnistaging_disabler
@no_type_check
def omnistaging_disabler() -> None:
  global DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
      _thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
      apply_parallel_primitive, parallel_pure_rules, \
      _pvals_to_results_handler, _pval_to_result_handler, replicate, \
      axis_index, maybe_extend_axis_env

  @contextmanager
  def maybe_extend_axis_env(*args, **kwargs):
    yield

  def _pvals_to_results_handler(
      size, nrep, npart,
      out_parts: Optional[Tuple[PartitionsOrReplicated, ...]],
      out_pvals, devices, backend):
    nouts = len(out_pvals)
    if out_parts is None:
      out_parts = (None,) * len(out_pvals)
    handlers = [
        _pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend)
        for pval, parts in safe_zip(out_pvals, out_parts)  # type: ignore
    ]

    def handler(out_bufs):
      assert nrep * npart == len(out_bufs)
      buffers = [[result_to_populate] * nrep * npart for _ in range(nouts)]
      for r, tuple_buf in enumerate(out_bufs):
        for i, buf in enumerate(tuple_buf):
          buffers[i][r] = buf
      assert not any(buf is result_to_populate for bufs in buffers
                     for buf in bufs)
      return [h(bufs) for h, bufs in safe_zip(handlers, buffers)]
    return handler

  def _pval_to_result_handler(axis_size, nrep, npart, parts, pval, devices, backend):
    if devices:
      assert all(d.host_id == xb.host_id(backend) for d in devices)
    pv, const = pval
    if pv is None:
      if nrep is None:
        nrep = axis_size
        # If 'const' is a ShardedDeviceArray, it must have come from a pmap nested
        # inside the one we're currently evaluating, and we should replicate
        # 'const' across the total number of devices needed. We don't necessarily
        # know the nested pmap's axis_size (e.g. the jaxpr for
        # pmap(pmap(lambda x: 3)) is trivial, with no pmaps), but we can use the
        # axis size of the output 'const'.
        # TODO: we might be doing unnecessary device transfers in the inner pmap.
        if isinstance(const, ShardedDeviceArray):
          nrep *= len(const)

      bcast_const = (core.unit if const is core.unit
                     else replicate(const, axis_size, nrep, devices, backend))  # type: ignore
      return lambda _: bcast_const  # type: ignore
    else:
      if pv is not core.abstract_unit:
        unsharded_aval = ShapedArray((axis_size,) + pv.shape, pv.dtype)
        sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, pv, 0)
        indices = spec_to_indices(unsharded_aval.shape, sharding_spec)
      else:
        sharding_spec = indices = None
        unsharded_aval = pv
      return aval_to_result_handler(sharding_spec, indices, unsharded_aval)

  @contextmanager
  def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size):
    dynamic_axis_env = _thread_local_state.dynamic_axis_env
    dynamic_axis_env.append(DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size))
    try:
      yield
    finally:
      dynamic_axis_env.pop()

  def unmapped_device_count(backend=None):
    dynamic_axis_env = _thread_local_state.dynamic_axis_env
    mapped = prod(frame.hard_size for frame in dynamic_axis_env)
    unmapped, ragged = divmod(xb.device_count(backend), mapped)
    assert not ragged and unmapped > 0
    return unmapped

  def apply_parallel_primitive(prim, *args, **params):
    # This is the op-by-op version of applying a collective primitive, like a psum
    # that doesn't have a data dependence on the argument of a pmap function. In
    # particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We
    # look up information in the dynamic axis env.
    dynamic_axis_env = _thread_local_state.dynamic_axis_env
    axis_name = params.pop('axis_name')
    axis_index_groups = params.pop('axis_index_groups')
    if axis_index_groups is not None:
      shape = (len(axis_index_groups[0]),)
    else:
      logical_size = lambda frame: frame.hard_size
      if isinstance(axis_name, (list, tuple)):
        shape = tuple(logical_size(dynamic_axis_env[name]) for name in axis_name)
      else:
        shape = (logical_size(dynamic_axis_env[axis_name]),)
    return parallel_pure_rules[prim](*args, shape=shape, **params)

  pe.staged_out_calls.add(xla_pmap_p)  # type: ignore
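
# The op-by-op path above is what makes the axis-size idiom mentioned in the
# apply_parallel_primitive comment work when staging is disabled (a sketch;
# assumes at least 4 local devices and uses the public jax.pmap / jax.lax.psum
# entry points):
#
#   from jax import lax, pmap
#   sizes = pmap(lambda x: lax.psum(1, 'i'), axis_name='i')(np.arange(4))
#   # psum(1, 'i') has no data dependence on x, so it is evaluated op-by-op via
#   # apply_parallel_primitive, which reads the axis size from the dynamic axis
#   # env; every entry of `sizes` is 4.
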
parallel_pure_rules = {}  # type: ignore

class DynamicAxisEnvFrame(object):
  __slots__ = ["name", "pmap_trace", "hard_size"]
  def __init__(self, name, pmap_trace, hard_size):
    self.name = name
    self.pmap_trace = pmap_trace
    self.hard_size = hard_size

class DynamicAxisEnv(list):
  def __contains__(self, axis_name):
    return axis_name in (frame.name for frame in self)

  def __getitem__(self, axis_name):
    if axis_name not in self:
      raise NameError("unbound axis name: {}".format(axis_name))
    for frame in reversed(self):
      if frame.name == axis_name:
        return frame

    raise AssertionError

  @property
  def sizes(self):
    return tuple(frame.hard_size for frame in self)

  @property
  def nreps(self):
    return prod(frame.hard_size for frame in self)

class _ThreadLocalState(threading.local):
  def __init__(self):
    self.dynamic_axis_env = DynamicAxisEnv()

_thread_local_state = _ThreadLocalState()

def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool = False) -> List[xb.xla_client._xla.PyLocalBuffer]:
  """Call device_put on a sequence of devices and return a flat sequence of buffers."""
  if replicate:
    return list(it.chain.from_iterable(xla.device_put(x, device) for device in devices))
  else:
    return list(it.chain.from_iterable(xla.device_put(val, device) for val, device in safe_zip(x, devices)))
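
# Example (a sketch; assumes at least two addressable devices, obtained here via
# xb.local_devices()):
#
#   devices = xb.local_devices()[:2]
#   bufs = device_put(np.arange(4), devices, replicate=True)
#   # len(bufs) == 2: one copy of the array lives on each device.
#   bufs = device_put([np.zeros(3), np.ones(3)], devices)
#   # Without replicate=True, x must be a sequence with one value per device.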