1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from functools import partial
16from typing import Callable, Iterable, Optional, Tuple, Union
17
18from absl import logging
19import numpy as np
20
21from .. import core
22from . import ad
23from . import partial_eval as pe
24# TODO(skye): separate pmap into it's own module?
25from . import pxla
26from . import xla
27from .. import linear_util as lu
28from ..lib import xla_bridge as xb
29from ..lib import xla_client as xc
30from ..api_util import argnums_partial, flatten_axes, flatten_fun, _ensure_index_tuple
31from ..tree_util import tree_flatten, tree_unflatten
32from .._src.util import (extend_name_stack, wrap_name, wraps, safe_zip,
33                         HashableFunction)
34from ..config import config, flags
35
36xops = xc._xla.ops
37
38FLAGS = flags.FLAGS
39
40
41def _map(f, *xs):
42  return tuple(map(f, *xs))
43
44
45class ResultToPopulate: pass
46result_to_populate = ResultToPopulate()
47
48
49def _avals_to_results_handler(nrep, npart, partitions, out_avals):
50  nouts = len(out_avals)
51  handlers = [_aval_to_result_handler(npart, parts, out_aval)
52              for parts, out_aval in safe_zip(partitions, out_avals)]
53
54  def handler(out_bufs):
55    assert nrep * npart == len(out_bufs)
56    buffers = [[result_to_populate] * nrep * npart for _ in range(nouts)]
57    for r, tuple_buf in enumerate(out_bufs):
58      for i, buf in enumerate(tuple_buf):
59        buffers[i][r] = buf
60    assert not any(buf is result_to_populate for bufs in buffers
61                  for buf in bufs)
62    return [h(bufs) for h, bufs in zip(handlers, buffers)]
63
64  return handler
65
66def _aval_to_result_handler(npart, parts, aval):
67  if aval is not core.abstract_unit:
68    spec = pxla.partitioned_sharding_spec(npart, parts, aval)
69    indices = pxla.spec_to_indices(aval.shape, spec)
70  else:
71    spec = indices = None
72  return pxla.aval_to_result_handler(spec, indices, aval)
73
74
75@lu.cache
76def _sharded_callable(
77    fun: lu.WrappedFun, nparts: Optional[int],
78    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
79    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
80    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
81    local_out_parts_thunk: Callable[[], Optional[Tuple[pxla.PartitionsOrReplicated, ...]]],
82    local_nparts: Optional[int], name: str, *abstract_args):
83  nrep = 1
84
85  if local_in_parts is None:
86    local_in_parts = in_parts
87
88  global_abstract_args = [pxla.get_global_aval(arg, parts, lparts)
89                          for arg, parts, lparts
90                          in safe_zip(abstract_args, in_parts, local_in_parts)]
91
92  logging.vlog(2, "abstract_args: %s", abstract_args)
93  logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
94  logging.vlog(2, "in_parts: %s", in_parts)
95  logging.vlog(2, "local_in_parts: %s", local_in_parts)
96
97  if config.omnistaging_enabled:
98    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
99        fun, global_abstract_args)
100  else:
101    in_pvals = [pe.PartialVal.unknown(aval) for aval in global_abstract_args]
102    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals,  # type: ignore
103                                                 instantiate=False, bottom=True)  # type: ignore
104
105    # TODO(skye): add tests for equationless jaxpr cases
106    if not jaxpr.eqns and all(outvar.aval is core.abstract_unit
107                              for outvar in jaxpr.outvars):
108      return lambda *_: [
109          const if pv is None else core.unit for pv, const in out_pvals
110      ]
111
112  if xb.get_backend().platform != "tpu":
113    # TODO(skye): fall back to regular jit?
114    raise ValueError("sharded_jit only works on TPU!")
115
116  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
117  assert nparts is not None
118  if nparts > xb.device_count():
119    raise ValueError(
120        f"sharded_jit computation requires {nparts} devices, "
121        f"but only {xb.device_count()} devices are available.")
122  if xb.local_device_count() < nparts < xb.device_count():
123    raise NotImplementedError(
124        f"sharded_jit across multiple hosts must use all available devices. "
125        f"Got {nparts} out of {xb.device_count()} requested devices "
126        f"(local device count: {xb.local_device_count()})")
127
128  if local_nparts is None:
129    if nparts > xb.local_device_count():
130      raise ValueError(
131        "Specify 'local_nparts' when using cross-process sharded_jit "
132        "and all inputs and outputs are replicated.")
133    else:
134      local_nparts = nparts
135  if local_nparts > xb.local_device_count():
136    raise ValueError(
137        f"sharded_jit computation requires {local_nparts} local devices, "
138        f"but only {xb.local_device_count()} local devices are available.")
139
140  logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)
141
142  out_parts = out_parts_thunk()
143
144  local_out_parts = local_out_parts_thunk()
145  if local_out_parts is None:
146    local_out_parts = out_parts
147
148  logging.vlog(2, "out_parts: %s", out_parts)
149  logging.vlog(2, "local_out_parts: %s", local_out_parts)
150
151  local_out_avals = [pxla.get_local_aval(out, parts, lparts)
152                     for out, parts, lparts
153                     in safe_zip(global_out_avals, out_parts, local_out_parts)]
154
155  log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
156  logging.log(log_priority,
157              f"Compiling {fun.__name__} for {nparts} devices with "
158              f"args {global_abstract_args}.")
159
160  c = xb.make_computation_builder("spjit_{}".format(fun.__name__))
161  xla_consts = _map(partial(xb.constant, c), consts)
162  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
163  axis_env = xla.AxisEnv(nrep, (), ())
164  out_nodes = xla.jaxpr_subcomp(
165      c, jaxpr, None, axis_env, xla_consts,
166      extend_name_stack(wrap_name(name, "sharded_jit")), *xla_args)
167  out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
168  built = c.Build(out_tuple)
169
170  if nparts <= xb.local_device_count():
171    devices = xb.local_devices()[:nparts]
172  else:
173    assert nparts == xb.device_count()
174    devices = xb.devices()
175  device_assignment = np.array([[d.id for d in devices]])
176  device_assignment = np.reshape(device_assignment, (-1, nparts))
177  # device_assignment = None  # TODO(skye): replace with default device assignment?
178
179  compiled = xla.backend_compile(
180      xb.get_backend(), built,
181      xb.get_compile_options(nrep, nparts, device_assignment))
182
183  input_specs = [
184      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
185      for parts, aval in zip(local_in_parts, abstract_args)]
186  input_indices = [pxla.spec_to_indices(aval.shape, spec)
187                   if spec is not None else None
188                   for aval, spec in zip(abstract_args, input_specs)]
189
190  handle_args = partial(pxla.shard_args, compiled.local_devices(),
191                        input_indices)
192  assert config.omnistaging_enabled
193  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
194                                          local_out_parts, local_out_avals)
195  return partial(_execute_spatially_partitioned, compiled, handle_args,
196                 handle_outs)
197
198
199def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,
200                                  in_parts, out_parts_thunk, nparts, backend,
201                                  name, call_jaxpr, local_in_parts,
202                                  local_out_parts_thunk, local_nparts):
203  subc = xc.XlaBuilder(f"sharded_jit_{name}")
204
205  # We assume any extra leading in_nodes are constants and replicate them.
206  num_extra_nodes = len(in_nodes) - len(in_parts)
207  assert num_extra_nodes >= 0
208  in_parts = (None,) * num_extra_nodes + in_parts
209
210  args = []
211  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
212    # We use xb.set_sharding instead of xb.with_sharding because inlined calls
213    # shouldn't have shardings set directly on the inputs or outputs.
214    arg = xb.parameter(subc, i, c.GetShape(n))
215    args.append(xb.set_sharding(subc, arg, sharding))
216
217  out_nodes = xla.jaxpr_subcomp(
218      subc, call_jaxpr, backend, axis_env, (),
219      extend_name_stack(name_stack, wrap_name(name, "sharded_jit")), *args)
220  out_parts = out_parts_thunk()
221  assert len(out_parts) == len(out_nodes)
222  out_nodes = [xb.set_sharding(subc, out, sharding)
223               for out, sharding in safe_zip(out_nodes, out_parts)]
224
225  subc = subc.build(xops.Tuple(subc, out_nodes))
226  return xops.Call(c, subc, list(in_nodes))
227
228
229def _execute_spatially_partitioned(compiled, in_handler, out_handler, *args):
230  input_bufs = in_handler(args)
231  out_bufs = compiled.execute_on_local_devices(list(input_bufs))
232  return out_handler(out_bufs)
233
234
235def _xla_sharded_args(c, avals, in_parts):
236  xla_args = []
237  for i, (sharding, aval) in enumerate(safe_zip(in_parts, avals)):
238    param = xb.with_sharding(c, sharding, xb.parameter, c, i,
239                             *xla.aval_to_xla_shapes(aval))
240    xla_args.append(param)
241  return xla_args
242
243
244def _sharded_call_impl(fun, *args, nparts, in_parts, out_parts_thunk,
245                       local_in_parts, local_out_parts_thunk, local_nparts,
246                       name):
247  compiled_fun = _sharded_callable(fun, nparts, in_parts, out_parts_thunk,
248                                   local_in_parts, local_out_parts_thunk,
249                                   local_nparts, name,
250                                   *map(xla.abstractify, args))
251  return compiled_fun(*args)
252
253
254sharded_call_p = core.CallPrimitive("sharded_call")
255sharded_call = sharded_call_p.bind
256sharded_call_p.def_impl(_sharded_call_impl)
257xla.call_translations[sharded_call_p] = _sharded_jit_translation_rule
258
259
260class PartitionSpec(tuple):
261  """Tuple of integer specifying how a value should be partitioned.
262
263  Each integer corresponds to how many ways a dimension is partitioned. We
264  create a separate class for this so JAX's pytree utilities can distinguish it
265  from a tuple that should be treated as a pytree.
266  """
267  def __new__(cls, *partitions):
268    return tuple.__new__(PartitionSpec, partitions)
269
270  def __repr__(self):
271    return "PartitionSpec%s" % tuple.__repr__(self)
272
273
274def sharded_jit(fun: Callable, in_parts, out_parts, num_partitions: int = None,
275                local_in_parts=None, local_out_parts=None,
276                local_num_partitions=None,
277                static_argnums: Union[int, Iterable[int]] = (),
278):
279  """Like ``jit``, but partitions ``fun`` across multiple devices.
280
281  WARNING: this feature is still under active development! It may not work well,
282  and may change without warning!
283
284  `sharded_jit` sets up ``fun`` for just-in-time compilation with XLA, but
285  unlike ``jit``, the compiled function will run across multiple devices
286  (e.g. multiple GPUs or multiple TPU cores). This is achieved by spatially
287  partitioning the data that flows through the computation, so each operation is
288  run across all devices and each device runs only a shard of the full
289  data. (Some data can optionally be replicated, which is sometimes more
290  efficient for small arrays when combined with larger spatially-partitioned
291  arrays.) Communication between devices is automatically inserted as necessary.
292
293  ``sharded_jit`` can be useful if the jitted version of ``fun`` would not fit
294  in a single device's memory, or to speed up ``fun`` by running each operation
295  in parallel across multiple devices.
296
297  Note: ``sharded_jit`` is currently available on TPU only!
298
299  Args:
300    fun: Function to be jitted.
301    in_parts: The input partitions, i.e. how each argument to ``fun`` should be
302      partitioned or replicated. This should be a PartitionSpec indicating into
303      how many partitions each dimension should be sharded, None indicating
304      replication, or (nested) standard Python containers thereof. For example,
305      ``in_parts=PartitionSpec(2,1)`` means all arguments should be partitioned
306      over two devices across the first dimension;
307      ``in_parts=(PartitionSpec(2,2), PartitionSpec(4,1), None)`` means the
308      first argument should be partitioned over four devices by splitting the
309      first two dimensions in half, the second argument should be partitioned
310      over the four devices across the first dimension, and the third argument
311      is replicated across the four devices. All PartitionSpecs in a given
312      ``sharded_jit`` call must correspond to the same total number of
313      partitions, i.e. the product of all PartitionSpecs must be equal.
314    out_parts: The output partitions, i.e. how each output of ``fun`` should be
315      partitioned or replicated. This follows the same convention as
316     ``in_parts``.
317    num_partitions: Optional. If set, explicitly specifies the number of devices
318      ``fun`` should partitioned across (rather than inferring it from
319      ``in_parts``, ``out_parts``, and/or any ``with_sharding_constraint``
320      calls).  Setting this should usually be unnecessary, but can be used to
321      maintain device persistence across multiple sharded_jit calls when some of
322      those calls only involve replicated values.
323    local_in_parts: Optional. This should be set when partitioning across
324      multiple processes, and says how each process's worth of data should be
325      partitioned (vs. in_parts which is the "global" partitioning across all
326      processes). This API is likely to change in the future.
327    local_out_parts: Optional. This should be set when partitioning across
328      multiple processes, and says how each process's worth of data should be
329      partitioned (vs. out_parts which is the "global" partitioning across all
330      processes). This API is likely to change in the future.
331    local_num_partitions: Optional. Explicitly specifies the numbers of local
332      devices to partitions across in a multi-process setting. This API is
333      likely to change in the future.
334    static_argnums: An int or collection of ints specifying which positional
335      arguments to treat as static (compile-time constant). Operations that only
336      depend on static arguments will be constant-folded. Calling the jitted
337      function with different values for these constants will trigger
338      recompilation. If the jitted function is called with fewer positional
339      arguments than indicated by ``static_argnums`` then an error is raised.
340      Each of the static arguments will be broadcasted to all devices.
341      Arguments that are not arrays or containers thereof must be marked as
342      static. Defaults to ().
343
344  Returns:
345    A version of ``fun`` that will be distributed across multiple devices.
346  """
347  if num_partitions is not None:
348    nparts = num_partitions
349  else:
350    nparts = pxla.get_num_partitions(in_parts, out_parts)
351
352  if local_num_partitions is not None:
353    local_nparts = local_num_partitions
354  else:
355    local_nparts = pxla.get_num_partitions(local_in_parts, local_out_parts)
356
357  static_argnums = _ensure_index_tuple(static_argnums)
358
359  @wraps(fun)
360  def wrapped(*args, **kwargs):
361    if kwargs:
362      raise NotImplementedError("sharded_jit over kwargs not yet supported")
363
364    f = lu.wrap_init(fun)
365    if static_argnums:
366      if max(static_argnums) >= len(args):
367        raise ValueError(
368            f"jitted function has static_argnums={static_argnums}"
369            f" but was called with only {len(args)} positional "
370            f"argument{'s' if len(args) > 1 else ''}. "
371            "All static broadcasted arguments must be passed positionally.")
372      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
373      f, args = argnums_partial(f, dyn_argnums, args)
374
375    args_flat, in_tree = tree_flatten((args, kwargs))
376    in_parts_flat = tuple(flatten_axes("sharded_jit in_parts",
377                                       in_tree.children()[0], in_parts))
378    if local_in_parts is not None:
379      local_in_parts_flat = tuple(flatten_axes("sharded_jit local_in_parts",
380                                               in_tree.children()[0], local_in_parts))
381    else:
382      local_in_parts_flat = None
383
384    flat_fun, out_tree = flatten_fun(f, in_tree)
385    # TODO(skye): having a function-typed param in a primitive seems dicey, is
386    # there a better way?
387    out_parts_thunk = HashableFunction(
388        lambda: tuple(flatten_axes("sharded_jit out_parts", out_tree(), out_parts)),
389        closure=out_parts)
390    if local_out_parts:
391      local_out_parts_thunk = HashableFunction(
392          lambda: tuple(flatten_axes("sharded_jit local_out_parts",
393                                     out_tree(), local_out_parts)),
394          closure=local_out_parts)
395    else:
396      local_out_parts_thunk = HashableFunction(lambda: None, closure=None)
397
398    out = sharded_call(
399        flat_fun,
400        *args_flat,
401        nparts=nparts,
402        in_parts=in_parts_flat,
403        out_parts_thunk=out_parts_thunk,
404        local_in_parts=local_in_parts_flat,
405        local_out_parts_thunk=local_out_parts_thunk,
406        local_nparts=local_nparts,
407        name=flat_fun.__name__)
408    return tree_unflatten(out_tree(), out)
409
410  return wrapped
411
412
413def _sharding_constraint_impl(x, partitions):
414  # TODO(skye): can we also prevent this from being called in other
415  # non-sharded_jit contexts? (e.g. pmap, control flow)
416  raise NotImplementedError(
417      "with_sharding_constraint() should only be called inside sharded_jit()")
418
419def _sharding_constraint_translation_rule(c, x_node, partitions):
420  return xb.set_sharding(c, x_node, partitions)
421
422sharding_constraint_p = core.Primitive("sharding_constraint")
423sharding_constraint_p.def_impl(_sharding_constraint_impl)
424sharding_constraint_p.def_abstract_eval(lambda x, partitions: x)
425ad.deflinear2(sharding_constraint_p,
426              lambda ct, _, partitions: (with_sharding_constraint(ct, partitions),))
427xla.translations[sharding_constraint_p] = _sharding_constraint_translation_rule
428
429def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):
430  """Identity-like function that specifies how ``x`` should be sharded.
431
432  WARNING: this feature is still under active development! It may not work well,
433  and may change without warning!
434
435  This should only be called inside a function transformed by ``sharded_jit``.
436  It constrains how the function is sharded: regardless of any other specified
437  partitions, the compiler will make sure that ``x`` is sharded according to
438  ``partitions``.  Note that a ``with_sharding_constraint`` call doesn't
439  necessarily correspond to a reshard, since the compiler is free to achieve
440  this sharding as long as the constraint is met, e.g. it might insert a reshard
441  earlier in the computation. Another way to think of this is that the
442  ``with_sharding_constraint`` call may flow "up" the function to preceding
443  operations as well as "down" to subsequent ones.
444
445  ``partitions`` must correspond to the same number of total partitions dictated
446  by the outer ``sharded_jit`` and any other ``with_sharding_constraint`` calls.
447  In the case where only replication has been specified, any ``partitions`` are
448  valid.
449
450  Example usage:
451    @partial(sharded_jit, in_parts=None, out_parts=None, num_shards=2
452    def f(x):
453      y = x + 1
454      y = with_sharding_constraint(y, PartitionSpec(2,1))
455      return y * 2
456
457  In this example, the inputs and outputs of ``f`` will be replicated, but the
458  inner value of ``y`` will be partitioned in half. ``f`` will run on two
459  devices due to the with_sharding_constraint call.
460
461  Args:
462    x: Array value
463    partitions: PartitionSpec indicating how ``x`` should be partitioned, or
464      None for replication.
465
466  Returns:
467    A new version of ``x`` with the specified sharding applied.
468  """
469  return sharding_constraint_p.bind(x, partitions=partitions)
470
471
472@config.register_omnistaging_disabler
473def omnistaging_disabler() -> None:
474  global _pvals_to_results_handler, _pval_to_result_handler
475
476  def _pvals_to_results_handler(nrep, npart, partitions, out_pvals):
477    nouts = len(out_pvals)
478    handlers = [_pval_to_result_handler(npart, parts, out_pval)
479                for parts, out_pval in safe_zip(partitions, out_pvals)]  # type: ignore
480
481    def handler(out_bufs):
482      assert nrep * npart == len(out_bufs)
483      buffers = [[result_to_populate] * nrep * npart for _ in range(nouts)]
484      for r, tuple_buf in enumerate(out_bufs):
485        for i, buf in enumerate(tuple_buf):
486          buffers[i][r] = buf
487      assert not any(buf is result_to_populate for bufs in buffers
488                    for buf in bufs)
489      return [h(bufs) for h, bufs in zip(handlers, buffers)]
490
491    return handler
492
493  def _pval_to_result_handler(npart, parts, pval):
494    pv, const = pval
495    if pv is None:
496      raise NotImplementedError  # TODO(skye): handle constant outputs
497    else:
498      if pv is not core.abstract_unit:
499        spec = pxla.partitioned_sharding_spec(npart, parts, pv)
500        indices = pxla.spec_to_indices(pv.shape, spec)
501      else:
502        spec = indices = None
503      return pxla.aval_to_result_handler(spec, indices, pv)
504