# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Callable, Iterable, Optional, Tuple, Union

from absl import logging
import numpy as np

from .. import core
from . import ad
from . import partial_eval as pe
# TODO(skye): separate pmap into its own module?
from . import pxla
from . import xla
from .. import linear_util as lu
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from ..api_util import argnums_partial, flatten_axes, flatten_fun, _ensure_index_tuple
from ..tree_util import tree_flatten, tree_unflatten
from .._src.util import (extend_name_stack, wrap_name, wraps, safe_zip,
                         HashableFunction)
from ..config import config, flags

xops = xc._xla.ops

FLAGS = flags.FLAGS


def _map(f, *xs):
  """Eager ``map`` that returns a tuple instead of an iterator."""
  return tuple(map(f, *xs))


class ResultToPopulate: pass
# Sentinel marking buffer slots that have not been filled in yet.
result_to_populate = ResultToPopulate()


def _avals_to_results_handler(nrep, npart, partitions, out_avals):
  """Builds a function that converts raw output buffers into results.

  Args:
    nrep: number of replicas.
    npart: number of partitions.
    partitions: per-output partitioning, zipped with ``out_avals``.
    out_avals: abstract values of the outputs.

  Returns:
    A ``handler(out_bufs)`` callable. ``out_bufs`` must have
    ``nrep * npart`` entries, each one a sequence holding one buffer per
    output. The handler regroups the buffers per output and applies the
    matching per-output result handler, returning a list of results.
  """
  nouts = len(out_avals)
  handlers = [_aval_to_result_handler(npart, parts, out_aval)
              for parts, out_aval in safe_zip(partitions, out_avals)]

  def handler(out_bufs):
    assert nrep * npart == len(out_bufs)
    # buffers[i][r] is the buffer for output ``i`` taken from entry ``r`` of
    # ``out_bufs``, i.e. the transpose of ``out_bufs``.
    buffers = [[result_to_populate] * nrep * npart for _ in range(nouts)]
    for r, tuple_buf in enumerate(out_bufs):
      for i, buf in enumerate(tuple_buf):
        buffers[i][r] = buf
    # Every slot must have been overwritten by the loop above.
    assert not any(buf is result_to_populate for bufs in buffers
                   for buf in bufs)
    return [h(bufs) for h, bufs in zip(handlers, buffers)]

  return handler

def _aval_to_result_handler(npart, parts, aval):
  """Returns the pxla result handler for a single output ``aval``.

  Unit outputs carry no sharding spec or indices; everything else gets a
  partitioned sharding spec derived from ``npart`` and ``parts``.
  """
  if aval is not core.abstract_unit:
    spec = pxla.partitioned_sharding_spec(npart, parts, aval)
    indices = pxla.spec_to_indices(aval.shape, spec)
  else:
    spec = indices = None
  return pxla.aval_to_result_handler(spec, indices, aval)


@lu.cache
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  """Traces, compiles and wraps ``fun`` as a partitioned executable.

  Args:
    fun: the flattened function to compile.
    nparts: requested number of partitions, or None; reconciled against the
      traced jaxpr via ``pxla.reconcile_num_partitions``.
    in_parts: global partitioning of each flat argument.
    out_parts_thunk: thunk returning the global partitioning of each flat
      output; only called after tracing.
    local_in_parts: per-process partitioning of each flat argument, or None
      to reuse ``in_parts``.
    local_out_parts_thunk: thunk returning the per-process partitioning of
      each flat output, or None to reuse the global output partitioning.
    local_nparts: number of local partitions, or None to reuse ``nparts``
      when it fits on the local devices.
    name: name used in the XLA name stack.
    *abstract_args: (local) abstract values of the flat arguments.

  Returns:
    A callable taking the concrete flat arguments and returning the flat
    outputs.

  Raises:
    ValueError: if the backend is not TPU, or if more (local) devices are
      required than are available, or if ``local_nparts`` is needed but was
      not given.
    NotImplementedError: if a multi-host computation would use only a subset
      of all devices.
  """
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [pxla.get_global_aval(arg, parts, lparts)
                          for arg, parts, lparts
                          in safe_zip(abstract_args, in_parts, local_in_parts)]

  logging.vlog(2, "abstract_args: %s", abstract_args)
  logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
  logging.vlog(2, "in_parts: %s", in_parts)
  logging.vlog(2, "local_in_parts: %s", local_in_parts)

  if config.omnistaging_enabled:
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)
  else:
    in_pvals = [pe.PartialVal.unknown(aval) for aval in global_abstract_args]
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals,  # type: ignore
                                                 instantiate=False, bottom=True)  # type: ignore

    # This shortcut reads ``out_pvals``, which is only bound on this
    # (non-omnistaging) path, so it has to live inside this branch.
    # TODO(skye): add tests for equationless jaxpr cases
    if not jaxpr.eqns and all(outvar.aval is core.abstract_unit
                              for outvar in jaxpr.outvars):
      return lambda *_: [
          const if pv is None else core.unit for pv, const in out_pvals
      ]

  if xb.get_backend().platform != "tpu":
    # TODO(skye): fall back to regular jit?
    raise ValueError("sharded_jit only works on TPU!")

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
        "Specify 'local_nparts' when using cross-process sharded_jit "
        "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  logging.vlog(2, "out_parts: %s", out_parts)
  logging.vlog(2, "local_out_parts: %s", local_out_parts)

  # NOTE(review): ``global_out_avals`` is only bound on the omnistaging path
  # above; the non-omnistaging path is not supported from here on (see the
  # ``assert config.omnistaging_enabled`` below).
  local_out_avals = [pxla.get_local_aval(out, parts, lparts)
                     for out, parts, lparts
                     in safe_zip(global_out_avals, out_parts, local_out_parts)]

  log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
  logging.log(log_priority,
              f"Compiling {fun.__name__} for {nparts} devices with "
              f"args {global_abstract_args}.")

  c = xb.make_computation_builder("spjit_{}".format(fun.__name__))
  xla_consts = _map(partial(xb.constant, c), consts)
  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
  axis_env = xla.AxisEnv(nrep, (), ())
  out_nodes = xla.jaxpr_subcomp(
      c, jaxpr, None, axis_env, xla_consts,
      extend_name_stack(wrap_name(name, "sharded_jit")), *xla_args)
  out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
  built = c.Build(out_tuple)

  # Single-host computations use a prefix of the local devices; multi-host
  # computations must span every device (enforced above).
  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d.id for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = xla.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)]
  input_indices = [pxla.spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in zip(abstract_args, input_specs)]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  assert config.omnistaging_enabled
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)


def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,
                                  in_parts, out_parts_thunk, nparts, backend,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  """XLA translation rule for ``sharded_call`` nested in another computation.

  Builds ``call_jaxpr`` as a sub-computation whose parameters and outputs are
  annotated with the requested shardings, then emits a call to it from ``c``.
  """
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    # We use xb.set_sharding instead of xb.with_sharding because inlined calls
    # shouldn't have shardings set directly on the inputs or outputs.
    arg = xb.parameter(subc, i, c.GetShape(n))
    args.append(xb.set_sharding(subc, arg, sharding))

  out_nodes = xla.jaxpr_subcomp(
      subc, call_jaxpr, backend, axis_env, (),
      extend_name_stack(name_stack, wrap_name(name, "sharded_jit")), *args)
  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xb.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xops.Call(c, subc, list(in_nodes))


def _execute_spatially_partitioned(compiled, in_handler, out_handler, *args):
  """Runs ``compiled`` on the local devices.

  ``in_handler`` turns ``args`` into input buffers and ``out_handler`` turns
  the buffers returned by the executable into results.
  """
  input_bufs = in_handler(args)
  out_bufs = compiled.execute_on_local_devices(list(input_bufs))
  return out_handler(out_bufs)


def _xla_sharded_args(c, avals, in_parts):
  """Creates one sharded XLA parameter on builder ``c`` per input aval."""
  xla_args = []
  for i, (sharding, aval) in enumerate(safe_zip(in_parts, avals)):
    param = xb.with_sharding(c, sharding, xb.parameter, c, i,
                             *xla.aval_to_xla_shapes(aval))
    xla_args.append(param)
  return xla_args


def _sharded_call_impl(fun, *args, nparts, in_parts, out_parts_thunk,
                       local_in_parts, local_out_parts_thunk, local_nparts,
                       name):
  """Eager implementation of ``sharded_call``: compile (cached) and run."""
  compiled_fun = _sharded_callable(fun, nparts, in_parts, out_parts_thunk,
                                   local_in_parts, local_out_parts_thunk,
                                   local_nparts, name,
                                   *map(xla.abstractify, args))
  return compiled_fun(*args)


sharded_call_p = core.CallPrimitive("sharded_call")
sharded_call = sharded_call_p.bind
sharded_call_p.def_impl(_sharded_call_impl)
xla.call_translations[sharded_call_p] = _sharded_jit_translation_rule


class PartitionSpec(tuple):
  """Tuple of integer specifying how a value should be partitioned.

  Each integer corresponds to how many ways a dimension is partitioned. We
  create a separate class for this so JAX's pytree utilities can distinguish it
  from a tuple that should be treated as a pytree.
  """
  def __new__(cls, *partitions):
    # Use ``cls`` rather than a hard-coded class so subclasses construct
    # instances of themselves.
    return tuple.__new__(cls, partitions)

  def __repr__(self):
    return "PartitionSpec%s" % tuple.__repr__(self)


def sharded_jit(fun: Callable, in_parts, out_parts,
                num_partitions: Optional[int] = None,
                local_in_parts=None, local_out_parts=None,
                local_num_partitions=None,
                static_argnums: Union[int, Iterable[int]] = (),
):
  """Like ``jit``, but partitions ``fun`` across multiple devices.

  WARNING: this feature is still under active development! It may not work well,
  and may change without warning!

  `sharded_jit` sets up ``fun`` for just-in-time compilation with XLA, but
  unlike ``jit``, the compiled function will run across multiple devices
  (e.g. multiple GPUs or multiple TPU cores). This is achieved by spatially
  partitioning the data that flows through the computation, so each operation is
  run across all devices and each device runs only a shard of the full
  data. (Some data can optionally be replicated, which is sometimes more
  efficient for small arrays when combined with larger spatially-partitioned
  arrays.) Communication between devices is automatically inserted as necessary.

  ``sharded_jit`` can be useful if the jitted version of ``fun`` would not fit
  in a single device's memory, or to speed up ``fun`` by running each operation
  in parallel across multiple devices.

  Note: ``sharded_jit`` is currently available on TPU only!

  Args:
    fun: Function to be jitted.
    in_parts: The input partitions, i.e. how each argument to ``fun`` should be
      partitioned or replicated. This should be a PartitionSpec indicating into
      how many partitions each dimension should be sharded, None indicating
      replication, or (nested) standard Python containers thereof. For example,
      ``in_parts=PartitionSpec(2,1)`` means all arguments should be partitioned
      over two devices across the first dimension;
      ``in_parts=(PartitionSpec(2,2), PartitionSpec(4,1), None)`` means the
      first argument should be partitioned over four devices by splitting the
      first two dimensions in half, the second argument should be partitioned
      over the four devices across the first dimension, and the third argument
      is replicated across the four devices. All PartitionSpecs in a given
      ``sharded_jit`` call must correspond to the same total number of
      partitions, i.e. the product of all PartitionSpecs must be equal.
    out_parts: The output partitions, i.e. how each output of ``fun`` should be
      partitioned or replicated. This follows the same convention as
      ``in_parts``.
    num_partitions: Optional. If set, explicitly specifies the number of devices
      ``fun`` should be partitioned across (rather than inferring it from
      ``in_parts``, ``out_parts``, and/or any ``with_sharding_constraint``
      calls).  Setting this should usually be unnecessary, but can be used to
      maintain device persistence across multiple sharded_jit calls when some of
      those calls only involve replicated values.
    local_in_parts: Optional. This should be set when partitioning across
      multiple processes, and says how each process's worth of data should be
      partitioned (vs. in_parts which is the "global" partitioning across all
      processes). This API is likely to change in the future.
    local_out_parts: Optional. This should be set when partitioning across
      multiple processes, and says how each process's worth of data should be
      partitioned (vs. out_parts which is the "global" partitioning across all
      processes). This API is likely to change in the future.
    local_num_partitions: Optional. Explicitly specifies the number of local
      devices to partition across in a multi-process setting. This API is
      likely to change in the future.
    static_argnums: An int or collection of ints specifying which positional
      arguments to treat as static (compile-time constant). Operations that only
      depend on static arguments will be constant-folded. Calling the jitted
      function with different values for these constants will trigger
      recompilation. If the jitted function is called with fewer positional
      arguments than indicated by ``static_argnums`` then an error is raised.
      Each of the static arguments will be broadcasted to all devices.
      Arguments that are not arrays or containers thereof must be marked as
      static. Defaults to ().

  Returns:
    A version of ``fun`` that will be distributed across multiple devices.
  """
  if num_partitions is not None:
    nparts = num_partitions
  else:
    nparts = pxla.get_num_partitions(in_parts, out_parts)

  if local_num_partitions is not None:
    local_nparts = local_num_partitions
  else:
    local_nparts = pxla.get_num_partitions(local_in_parts, local_out_parts)

  static_argnums = _ensure_index_tuple(static_argnums)

  @wraps(fun)
  def wrapped(*args, **kwargs):
    if kwargs:
      raise NotImplementedError("sharded_jit over kwargs not yet supported")

    f = lu.wrap_init(fun)
    if static_argnums:
      if max(static_argnums) >= len(args):
        raise ValueError(
            f"jitted function has static_argnums={static_argnums}"
            f" but was called with only {len(args)} positional "
            f"argument{'s' if len(args) != 1 else ''}. "
            "All static broadcasted arguments must be passed positionally.")
      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
      f, args = argnums_partial(f, dyn_argnums, args)

    args_flat, in_tree = tree_flatten((args, kwargs))
    in_parts_flat = tuple(flatten_axes("sharded_jit in_parts",
                                       in_tree.children()[0], in_parts))
    if local_in_parts is not None:
      local_in_parts_flat = tuple(flatten_axes("sharded_jit local_in_parts",
                                               in_tree.children()[0], local_in_parts))
    else:
      local_in_parts_flat = None

    flat_fun, out_tree = flatten_fun(f, in_tree)
    # The output tree is only known after tracing, so the flattened output
    # partitions are passed as thunks.
    # TODO(skye): having a function-typed param in a primitive seems dicey, is
    # there a better way?
    out_parts_thunk = HashableFunction(
        lambda: tuple(flatten_axes("sharded_jit out_parts", out_tree(), out_parts)),
        closure=out_parts)
    # Compare against None (as for local_in_parts above) so that a falsy but
    # valid spec such as ``PartitionSpec()`` is not silently dropped.
    if local_out_parts is not None:
      local_out_parts_thunk = HashableFunction(
          lambda: tuple(flatten_axes("sharded_jit local_out_parts",
                                     out_tree(), local_out_parts)),
          closure=local_out_parts)
    else:
      local_out_parts_thunk = HashableFunction(lambda: None, closure=None)

    out = sharded_call(
        flat_fun,
        *args_flat,
        nparts=nparts,
        in_parts=in_parts_flat,
        out_parts_thunk=out_parts_thunk,
        local_in_parts=local_in_parts_flat,
        local_out_parts_thunk=local_out_parts_thunk,
        local_nparts=local_nparts,
        name=flat_fun.__name__)
    return tree_unflatten(out_tree(), out)

  return wrapped


def _sharding_constraint_impl(x, partitions):
  """Eager impl: always raises, the primitive is only valid when staged out."""
  # TODO(skye): can we also prevent this from being called in other
  # non-sharded_jit contexts? (e.g. pmap, control flow)
  raise NotImplementedError(
      "with_sharding_constraint() should only be called inside sharded_jit()")

def _sharding_constraint_translation_rule(c, x_node, partitions):
  """XLA translation: annotate ``x_node`` with the requested sharding."""
  return xb.set_sharding(c, x_node, partitions)

sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)
sharding_constraint_p.def_abstract_eval(lambda x, partitions: x)
# The transpose applies the same sharding constraint to the cotangent.
ad.deflinear2(sharding_constraint_p,
              lambda ct, _, partitions: (with_sharding_constraint(ct, partitions),))
xla.translations[sharding_constraint_p] = _sharding_constraint_translation_rule

def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):
  """Identity-like function that specifies how ``x`` should be sharded.

  WARNING: this feature is still under active development! It may not work well,
  and may change without warning!

  This should only be called inside a function transformed by ``sharded_jit``.
  It constrains how the function is sharded: regardless of any other specified
  partitions, the compiler will make sure that ``x`` is sharded according to
  ``partitions``.  Note that a ``with_sharding_constraint`` call doesn't
  necessarily correspond to a reshard, since the compiler is free to achieve
  this sharding as long as the constraint is met, e.g. it might insert a reshard
  earlier in the computation. Another way to think of this is that the
  ``with_sharding_constraint`` call may flow "up" the function to preceding
  operations as well as "down" to subsequent ones.

  ``partitions`` must correspond to the same number of total partitions dictated
  by the outer ``sharded_jit`` and any other ``with_sharding_constraint`` calls.
  In the case where only replication has been specified, any ``partitions`` are
  valid.

  Example usage:
    @partial(sharded_jit, in_parts=None, out_parts=None, num_partitions=2)
    def f(x):
      y = x + 1
      y = with_sharding_constraint(y, PartitionSpec(2,1))
      return y * 2

  In this example, the inputs and outputs of ``f`` will be replicated, but the
  inner value of ``y`` will be partitioned in half. ``f`` will run on two
  devices due to the with_sharding_constraint call.

  Args:
    x: Array value
    partitions: PartitionSpec indicating how ``x`` should be partitioned, or
      None for replication.

  Returns:
    A new version of ``x`` with the specified sharding applied.
  """
  return sharding_constraint_p.bind(x, partitions=partitions)


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Installs the PartialVal-based result handlers as module globals.

  Registered with ``config`` as an omnistaging disabler. Running it defines
  ``_pvals_to_results_handler`` and ``_pval_to_result_handler`` at module
  level; they mirror ``_avals_to_results_handler`` and
  ``_aval_to_result_handler`` but take ``(pv, const)`` partial values.
  """
  global _pvals_to_results_handler, _pval_to_result_handler

  def _pvals_to_results_handler(nrep, npart, partitions, out_pvals):
    nouts = len(out_pvals)
    handlers = [_pval_to_result_handler(npart, parts, out_pval)
                for parts, out_pval in safe_zip(partitions, out_pvals)]  # type: ignore

    def handler(out_bufs):
      assert nrep * npart == len(out_bufs)
      # buffers[i][r] is the buffer for output ``i`` taken from entry ``r`` of
      # ``out_bufs``.
      buffers = [[result_to_populate] * nrep * npart for _ in range(nouts)]
      for r, tuple_buf in enumerate(out_bufs):
        for i, buf in enumerate(tuple_buf):
          buffers[i][r] = buf
      assert not any(buf is result_to_populate for bufs in buffers
                     for buf in bufs)
      return [h(bufs) for h, bufs in zip(handlers, buffers)]

    return handler

  def _pval_to_result_handler(npart, parts, pval):
    pv, const = pval
    if pv is None:
      raise NotImplementedError  # TODO(skye): handle constant outputs
    else:
      if pv is not core.abstract_unit:
        spec = pxla.partitioned_sharding_spec(npart, parts, pv)
        indices = pxla.spec_to_indices(pv.shape, spec)
      else:
        spec = indices = None
      return pxla.aval_to_result_handler(spec, indices, pv)