15import numpy as np
16from typing import Any, Callable, Dict, Optional, Tuple, Union
18import jax
19from ..config import config
20from .. import core
21from ..core import ShapedArray, raise_to_shaped, Trace, Tracer
22from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
23from .. import linear_util as lu
24from .._src.util import (unzip2, partial, safe_map, wrap_name, split_list,
25                         canonicalize_axis, moveaxis, as_hashable_function)
26from . import xla
27from . import partial_eval as pe
# Module-wide alias: every `map(...)` in this file is `safe_map`, deliberately
# shadowing the builtin.
map = safe_map
def batch(fun: lu.WrappedFun, axis_name, axis_size, in_dims, out_dim_dests,
          ) -> lu.WrappedFun:
  """Vectorize `fun` along the axis `axis_name` of size `axis_size`.

  Analogue of `jvp` in ad.py.

  Args:
    fun: the function to batch.
    axis_name: name bound to the batched axis while `fun` runs.
    axis_size: size of the batched axis.
    in_dims: per-input batch dimension (int) or None for unbatched inputs; may
      also be a thunk producing that sequence.
    out_dim_dests: per-output destination dimension for the batch axis; may
      also be a thunk.

  Returns:
    A `lu.WrappedFun` whose outputs have their batch axes moved (or broadcast)
    to `out_dim_dests`.
  """
  traced, out_dims_thunk = batch_subtrace(fun)
  vectorized = batchfun(traced, axis_name, axis_size, in_dims)
  return _match_axes(vectorized, axis_size, out_dims_thunk, out_dim_dests)
@lu.transformation
def batchfun(axis_name, axis_size, in_dims, *in_vals):
  """Outer transformation: run the wrapped function under a fresh BatchTrace.

  Analogue of `jvpfun` in ad.py. Callers invoke it as
  `batchfun(fun, axis_name, axis_size, in_dims)`, so it must be decorated with
  `lu.transformation`.

  Args:
    axis_name: name bound (via `core.extend_axis_env`) to the batched axis.
    axis_size: size of the batched axis.
    in_dims: per-input batch dim (int or None), or a thunk returning them.
    *in_vals: the input values.

  Yields:
    First, the arguments for the inner function, `(main, in_dims, *in_vals)`.
    Then, the inner function's outputs unchanged.
  """
  in_dims = in_dims() if callable(in_dims) else in_dims
  # Canonicalize (e.g. negative) integer batch dims against each input's rank.
  # Unit values are skipped, which only happens without omnistaging.
  in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
             and not isinstance(core.get_aval(x), core.AbstractUnit)  # non-omnistaging
             else ax for x, ax in zip(in_vals, in_dims)]
  with core.new_main(BatchTrace, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      out_vals = yield (main, in_dims, *in_vals), {}
      # Drop the local reference before the context managers exit, presumably
      # so the main trace is not kept alive by this frame.
      del main
  yield out_vals
@lu.transformation_with_aux
def batch_subtrace(main, in_dims, *in_vals):
  """Inner transformation: run the wrapped function on BatchTracers.

  Analogue of `jvp_subtrace` in ad.py. Callers unpack
  `fun, out_dims_thunk = batch_subtrace(fun, ...)`, so it must be decorated
  with `lu.transformation_with_aux`.

  Args:
    main: the BatchTrace main trace.
    in_dims: per-input batch dim; inputs with dim None are passed unwrapped.
    *in_vals: the input values.

  Yields:
    The wrapped inputs. Then `(out_vals, out_dims)`: the raw output values, plus
    their batch dims as auxiliary output.
  """
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  # Raise every output into this trace so each exposes `val` and `batch_dim`.
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  yield out_vals, out_dims
@lu.transformation
def _match_axes(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
  """Move each output's batch axis to its requested destination.

  Called as `_match_axes(fun, axis_size, out_dims_thunk, out_dim_dests)`, so it
  must be decorated with `lu.transformation`.

  Args:
    axis_size: size used when an unbatched output must be broadcast.
    out_dims_thunk: thunk returning the outputs' actual batch dims; it is only
      read after the wrapped function has run.
    out_dim_dests: per-output destination dims, or a thunk returning them.

  Raises:
    ValueError: if an output is batched but its destination is not an int.
  """
  out_vals = yield in_vals, {}
  out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  out_dims = out_dims_thunk()
  for od, od_dest in zip(out_dims, out_dim_dests):
    if od is not None and not isinstance(od_dest, int):
      msg = f"vmap has mapped output but out_axes is {od_dest}"
      raise ValueError(msg)
  yield map(partial(matchaxis, axis_size), out_dims, out_dim_dests, out_vals)
74# These next two functions, `batch_fun2` and `_batch_fun2`, are deprecated; the
75# former is only called from `custom_transforms`, which itself is deprecated.
76# TODO(mattjj): delete these along with custom_transforms
def batch_fun2(fun: lu.WrappedFun, in_dims):
  """Deprecated: like `batch`, but hands back the output batch dims instead.

  There is no `out_dim_dests`; the caller receives a thunk of output batch dims
  alongside the wrapped function.
  """
  traced, out_dims_thunk = batch_subtrace(fun)
  wrapped = _batch_fun2(traced, in_dims)
  return wrapped, out_dims_thunk
@lu.transformation
def _batch_fun2(in_dims, *in_vals, **params):
  """Deprecated outer transformation used by `batch_fun2`.

  Like `batchfun`, but it binds no axis name or size and does not canonicalize
  `in_dims`. It forwards `params` to the wrapped function. Called as
  `_batch_fun2(fun, in_dims)`, so it must be decorated with
  `lu.transformation`.
  """
  with core.new_main(BatchTrace, axis_name=None) as main:
    out_vals = yield (main, in_dims,) + in_vals, params
    # Drop the local reference before leaving the `new_main` context.
    del main
  yield out_vals
91### tracer
93# TODO(mattjj): use a special sentinel type rather than None
# Type of the "no batch axis" marker; BatchTracer asserts batch_dim is an int
# or an instance of this type.
NotMapped = type(None)
# Marker stored as a batch dim for values that carry no batch axis.
not_mapped = None
class BatchTracer(Tracer):
  """Tracer pairing a value with the position of its batch axis.

  `batch_dim` is either an int indexing an axis of `val`, or `not_mapped` when
  `val` carries no batch axis.
  """
  __slots__ = ['val', 'batch_dim']

  def __init__(self, trace, val, batch_dim: Optional[int]):
    assert core.skip_checks or type(batch_dim) in (int, NotMapped)  # type: ignore
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    # Per-example abstract value: `val`'s aval with the batch axis deleted.
    base = raise_to_shaped(core.get_aval(self.val))
    if self.batch_dim is not_mapped or base is core.abstract_unit:
      return base
    if type(base) is not ShapedArray:
      raise TypeError(base)
    assert 0 <= self.batch_dim < base.ndim
    per_example_shape = tuple(np.delete(base.shape, self.batch_dim))
    return ShapedArray(per_example_shape, base.dtype)

  def full_lower(self):
    # Only an unbatched tracer can be lowered to its underlying value.
    return self if self.batch_dim is not not_mapped else core.full_lower(self.val)
class BatchTrace(Trace):
  """Trace that evaluates primitives on BatchTracers via batching rules.

  `axis_name` names the axis this trace vectorizes. It is used for two things:
  selecting the axis frame handed to collective rules, and being forwarded to
  initial-style batching rules through `get_primitive_batcher`.
  """

  def __init__(self, *args, axis_name):
    super().__init__(*args)
    self.axis_name = axis_name

  def pure(self, val):
    # Constants enter the trace with no batch axis.
    return BatchTracer(self, val, not_mapped)

  def lift(self, val):
    # Lifted values are likewise unbatched.
    return BatchTracer(self, val, not_mapped)

  def sublift(self, val):
    # Re-wrap a BatchTracer in this trace, keeping its value and batch dim.
    return BatchTracer(self, val.val, val.batch_dim)

  def process_primitive(self, primitive, tracers, params):
    """Apply `primitive` to batched inputs using the appropriate rule."""
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims_in):
      # Nothing is batched: bind the primitive on the raw values.
      return primitive.bind(*vals_in, **params)
    if (primitive in collective_rules and
          _main_trace_for_axis_names(self.main, params['axis_name'])):
      # A collective over the axis this trace vectorizes: its rule receives
      # the axis frame for `self.axis_name`.
      frame = core.axis_frame(self.axis_name)
      val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params)
    else:
      batched_primitive = get_primitive_batcher(primitive, self.axis_name)
      val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
    if primitive.multiple_results:
      return map(partial(BatchTracer, self), val_out, dim_out)
    else:
      return BatchTracer(self, val_out, dim_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    """Batch a call primitive by batching its body `f` with `batch_subtrace`."""
    assert call_primitive.multiple_results
    params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims):
      return call_primitive.bind(f, *vals, **params)
    else:
      f, dims_out = batch_subtrace(f, self.main, dims)
      vals_out = call_primitive.bind(f, *vals, **params)
      # `dims_out` is a thunk filled in by running `f`, so read it after bind.
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Return the raw outputs plus a `todo` that re-wraps them as BatchTracers."""
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    """Batch a map primitive, shifting its in/out axes around our batch dims."""
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is not_mapped for dim in dims):
      return map_primitive.bind(f, *vals, **params)
    else:
      assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
      # The logic for the dimension math below is as follows:
      # ╔═════════════╦════════════════════════════════════════╦═══════════╗
      # ║ d / in_axis ║ None                                   ║ int       ║
      # ╠═════════════╬════════════════════════════════════════╩═══════════╣
      # ║ None        ║ No extra axis, so in_axis unaffected               ║
      # ╠═════════════╬════════════════════════════════════════╦═══════════╣
      # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
      # ╚═════════════╩════════════════════════════════════════╩═══════════╝
      # When both d and in_axis are defined then:
      # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
      # - If `d >  in_axis`, we have to decrement `d` (as `in_axis` will get removed).
      def both_mapped(in_out_axis, d):
        return in_out_axis is not None and d is not not_mapped
      new_in_axes = tuple(
        in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
        for d, in_axis in zip(dims, params['in_axes']))
      new_dims = tuple(
        d - 1 if both_mapped(in_axis, d) and in_axis < d else d
        for d, in_axis in zip(dims, params['in_axes']))
      f, dims_out = batch_subtrace(f, self.main, new_dims)
      out_axes_thunk = params['out_axes_thunk']
      # NOTE: This assumes that the choice of the dimensions over which outputs
      #       are batched is entirely dependent on the function and not e.g. on the
      #       data or its shapes.
      @as_hashable_function(closure=out_axes_thunk)
      def new_out_axes_thunk():
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes_thunk(), dims_out()))
      new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
      vals_out = map_primitive.bind(f, *vals, **new_params)
      # Use a fresh name: `new_out_axes_thunk` closes over `dims_out`, so
      # rebinding `dims_out` here would make any later call of that thunk fail.
      outer_dims_out = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
                        for d, out_axis in zip(dims_out(), out_axes_thunk())]
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, outer_dims_out)]

  def post_process_map(self, call_primitive, out_tracers, params):
    """Like `post_process_call`, but adjusts batch dims around `out_axes`.

    For map primitives, `todo` is returned as a pair `(todo, out_axes_transform)`.
    Here `out_axes_transform` shifts each out axis past a preceding batch dim.
    """
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def both_mapped(in_out_axis, d):
      return in_out_axis is not None and d is not not_mapped
    def todo(vals):
      trace = main.with_cur_sublevel()
      return [BatchTracer(trace, v, d + 1 if both_mapped(out_axis, d) and out_axis <= d else d)
              for v, d, out_axis in zip(vals, dims, params['out_axes_thunk']())]
    if call_primitive.map_primitive:
      def out_axes_transform(out_axes):
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes, dims))
      todo = (todo, out_axes_transform)
    return vals, todo

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Batch both the primal function and its custom JVP rule."""
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
    out_vals = prim.bind(fun, jvp, *in_vals)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      # The JVP rule ran: its dims cover primals then tangents, which
      # `batch_custom_jvp_subtrace` made identical, so keep the first half.
      assert out_dims == out_dims[:len(out_dims) // 2] * 2
      out_dims = out_dims[:len(out_dims) // 2]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  def post_process_custom_jvp_call(self, out_tracers, params):
    """Return the raw outputs plus a `todo` that re-wraps them as BatchTracers."""
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
    """Batch a custom_vjp call: primal, forward, and backward functions."""
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
                  if d is not not_mapped}
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
    bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
                               out_dims2, in_dims)
    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      # The forward rule ran and reported more dims than there are outputs
      # (presumably residuals first); keep the trailing len(out_vals) entries.
      out_dims = out_dims[-len(out_vals) % len(out_dims):]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  post_process_custom_vjp_call = post_process_custom_jvp_call
def _main_trace_for_axis_names(main_trace: core.MainTrace,
                               axis_name: Union[core.AxisName, Tuple[core.AxisName, ...]]
                               ) -> bool:
  """Return True if `main_trace` is the main trace bound to any of `axis_name`.

  A primitive's axis names alone are not enough to identify a trace, because an
  inner binding can shadow an outer one with the same name. The main trace
  object is therefore compared by identity instead.
  """
  names = axis_name if isinstance(axis_name, (list, tuple)) else (axis_name,)
  for name in names:
    if core.axis_frame(name).main_trace is main_trace:
      return True
  return False
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests):
  """Batch a custom_vjp backward function.

  Outputs whose destination dim is None have their batch axis summed out
  (through `_match_axes_and_sum`) rather than raising an error.
  """
  bwd_traced, out_dims_thunk = batch_subtrace(bwd)
  bwd_vectorized = batchfun(bwd_traced, axis_name, axis_size, in_dims)
  return _match_axes_and_sum(bwd_vectorized, axis_size, out_dims_thunk,
                             out_dim_dests)
@lu.transformation
def _match_axes_and_sum(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
  """Like `_match_axes`, but sum out batch axes whose destination is None.

  Called as `_match_axes_and_sum(fun, axis_size, out_dims_thunk, out_dim_dests)`,
  so it must be decorated with `lu.transformation`.

  Args:
    axis_size: size used when an unbatched output must be broadcast.
    out_dims_thunk: thunk returning the outputs' actual batch dims.
    out_dim_dests: per-output destination dims, or a thunk returning them.
  """
  out_vals = yield in_vals, {}
  # Accept a thunk for the destinations, consistent with `_match_axes`.
  out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  yield map(partial(matchaxis, axis_size, sum_match=True),
            out_dims_thunk(), out_dim_dests, out_vals)
290### primitives
# A batching rule is called as `rule(batched_args, batch_dims, **params)` and
# returns the output(s) together with their batch dim(s).
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
# Batching rules keyed by primitive; looked up by `get_primitive_batcher`.
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
# Rules that additionally receive an `axis_name` keyword; checked first by
# `get_primitive_batcher`.
initial_style_batchers : Dict[core.Primitive, Any] = {}
def get_primitive_batcher(p, axis_name):
  """Return the batching rule for primitive `p`.

  Initial-style rules take precedence and get `axis_name` bound as a keyword.

  Raises:
    NotImplementedError: if no rule is registered for `p`.
  """
  if p in initial_style_batchers:
    return partial(initial_style_batchers[p], axis_name=axis_name)
  try:
    rule = primitive_batchers[p]
  except KeyError as err:
    raise NotImplementedError(
        "Batching rule for '{}' not implemented".format(p)) from err
  return rule
def defvectorized(prim):
  """Register `vectorized_batcher` as the batching rule for `prim`."""
  rule = partial(vectorized_batcher, prim)
  primitive_batchers[prim] = rule
def vectorized_batcher(prim, batched_args, batch_dims, **params):
  """Bind `prim` directly on the batched operands.

  All operands must share the same batch dim, which is also the output's.
  """
  lead_bdim = batch_dims[0]
  assert all(lead_bdim == bd for bd in batch_dims[1:]), batch_dims
  return prim.bind(*batched_args, **params), lead_bdim
def defbroadcasting(prim):
  """Register `broadcast_batcher` as the batching rule for `prim`."""
  rule = partial(broadcast_batcher, prim)
  primitive_batchers[prim] = rule
def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    args: the possibly-batched arguments
    dims: list or tuple of the same length as `args`, where each
      entry indicates the batching state of the corresponding entry to `args`:
      either an int indicating the batch dimension, or else `not_mapped`
      indicating no batching.

  Returns:
    The primitive's output(s) and the output batch dim(s). The dim is the
    shared input batch dim when all non-scalar inputs agree on shape and batch
    dim; otherwise it is 0.
  """
  shapes = {(x.shape, d) for x, d in zip(args, dims) if np.ndim(x)}
  if len(shapes) != 1:
    size, = {shape[d] for shape, d in shapes if d is not not_mapped}
    # Bring every batch axis to the front (broadcasting unbatched args)...
    leading = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
    # ...then right-pad lower-rank batched args so ranks line up.
    ndim = max(np.ndim(x) for x in leading)  # special-case scalar broadcasting
    padded = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(leading, dims)]
    out = prim.bind(*padded, **params)
    out_dim = 0
  else:
    # Only agreeing batch dims and scalars: the primitive broadcasts for us.
    out_dim = next(d for d in dims if d is not not_mapped)
    out = prim.bind(*args, **params)
  if prim.multiple_results:
    return out, (out_dim,) * len(out)
  return out, out_dim
def _handle_scalar_broadcasting(nd, x, d):
  """Right-pad a batched `x` with singleton axes up to rank `nd`.

  Unbatched values, and values already of rank `nd`, are returned unchanged.
  """
  if d is not not_mapped and nd != np.ndim(x):
    padding = (1,) * (nd - np.ndim(x))
    return x.reshape(x.shape + padding)
  return x
def defreducer(prim):
  """Register `reducer_batcher` as the batching rule for `prim`."""
  rule = partial(reducer_batcher, prim)
  primitive_batchers[prim] = rule
def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  """Batching rule for a reduction of a single operand over `axes`.

  `axes` is given relative to the unbatched operand. Entries at or beyond the
  batch dim are shifted up by one to skip over it. The output batch dim is the
  batch axis's position among the axes that survive the reduction. If `params`
  carries `input_shape`, it is replaced by the batched operand's shape.
  """
  operand, = batched_args
  bdim, = batch_dims
  axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
  surviving_axes = list(np.delete(np.arange(operand.ndim), axes))
  bdim_out = int(surviving_axes.index(bdim))
  if 'input_shape' in params:
    params = dict(params, input_shape=operand.shape)
  return prim.bind(operand, axes=axes, **params), bdim_out
357# sets up primitive batchers for ad_util and xla primitives
def add_batched(batched_args, batch_dims):
  """Batching rule for `add_jaxvals_p`.

  The operands' batch dims are reconciled before adding: an unbatched side is
  broadcast to the other's batch axis, or x's batch axis is moved to y's.
  """
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx == bdy or core.get_aval(x) == core.abstract_unit:
    return add_jaxvals(x, y), bdx
  if bdx is not_mapped:
    return add_jaxvals(broadcast(x, y.shape[bdy], bdy), y), bdy
  if bdy is not_mapped:
    return add_jaxvals(x, broadcast(y, x.shape[bdx], bdx)), bdx
  return add_jaxvals(moveaxis(x, bdx, bdy), y), bdy
primitive_batchers[add_jaxvals_p] = add_batched
def zeros_like_batched(batched_args, batch_dims):
  """Batching rule for `zeros_like_p`: the result keeps the input's batch dim."""
  (val,), (bdim,) = batched_args, batch_dims
  return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched
383### util
def broadcast(x, sz, axis):
  """Broadcast `x` to gain a new axis of size `sz` at position `axis`.

  Unit values are returned as `core.unit`.
  """
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  out_shape = list(np.shape(x))
  out_shape.insert(axis, sz)
  # Each original axis maps to its slot in the output, skipping the new axis.
  kept_dims = tuple(np.delete(np.arange(len(out_shape)), axis))
  return jax.lax.broadcast_in_dim(x, out_shape, kept_dims)
def matchaxis(sz, src, dst, x, sum_match=False):
  """Relocate `x`'s batch axis from `src` to `dst`.

  The axis is moved when both are ints, broadcast in (size `sz`) when `src` is
  unmapped, and summed out when `dst` is None and `sum_match` is set.

  Raises:
    ValueError: for any other (src, dst) combination.
  """
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if src == dst:
    return x
  if type(src) == type(dst) == int:
    return moveaxis(x, src, dst)
  if src is not_mapped and dst is not not_mapped:
    # `dst` indexes the result, which has one more axis than `x`.
    target = canonicalize_axis(dst, np.ndim(x) + 1)
    return broadcast(x, sz, target)
  if dst is None and sum_match:
    return x.sum(src)
  raise ValueError((src, dst))
def bdim_at_front(x, bdim, size):
  """Return `x` with its batch axis at position 0.

  An unbatched `x` is broadcast to a leading axis of `size`.
  """
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  return broadcast(x, size, 0) if bdim is not_mapped else moveaxis(x, bdim, 0)
def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name):
  """Batch a closed jaxpr whose batched inputs carry their batch axis at 0.

  Returns the batched ClosedJaxpr and a list saying which outputs are batched.
  Batched outputs have their batch axis at 0.
  """
  fun = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  fun, out_batched = batch_subtrace_instantiate(fun, instantiate, axis_size)
  in_dims = [0 if b else None for b in in_batched]
  fun = batchfun(fun, axis_name, axis_size, in_dims)
  # Batched inputs gain a leading axis of `axis_size` in their abstract value.
  avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
              for aval, b in zip(closed_jaxpr.in_avals, in_batched)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(fun, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
@lu.transformation_with_aux
def batch_subtrace_instantiate(instantiate, axis_size, main, in_dims, *in_vals):
  """Like `batch_subtrace`, but normalizes output batch axes to position 0.

  Analogue of `jvp_subtrace` in ad.py. Callers unpack
  `f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)`, so
  it must be decorated with `lu.transformation_with_aux`.

  Args:
    instantiate: bool, or per-output bools. Unbatched outputs marked True are
      broadcast to a leading axis of `axis_size`.
    axis_size: size of the batched axis.
    main: the BatchTrace main trace.
    in_dims: per-input batch dim; inputs with dim None are passed unwrapped.

  Yields:
    The wrapped inputs. Then `(out_vals, out_batched)`: the outputs, with every
    batch axis at 0, and per-output flags saying whether each is batched.
  """
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)

  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_vals)
  out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0
              else broadcast(x, axis_size, 0) if d is not_mapped and inst else x
              for x, d, inst in zip(out_vals, out_dims, instantiate)]
  out_batched = [d is not not_mapped or inst
                 for d, inst in zip(out_dims, instantiate)]
  yield out_vals, out_batched
@lu.transformation_with_aux
def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
  """Inner transformation for batching a custom JVP rule.

  The wrapped rule takes primals followed by tangents, so `in_dims` is applied
  to both halves. The caller unpacks
  `jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, main, in_dims)`, so it must
  be decorated with `lu.transformation_with_aux`.

  Each output primal/tangent pair has its batch dims merged with
  `_merge_bdims`, and both halves are moved to that merged dim. The aux output
  is therefore the merged dims repeated twice.
  """
  # All mapped inputs must share one batch size.
  size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims * 2)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
  out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
  out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
  out_primals  = map(partial(matchaxis, size),  out_primal_bds, out_dims,  out_primals)
  out_tangents = map(partial(matchaxis, size), out_tangent_bds, out_dims, out_tangents)
  yield out_primals + out_tangents, out_dims * 2
462def _merge_bdims(x, y):
463  if x == y:
464    return x
465  elif x is not_mapped:
466    return y
467  elif y is not_mapped:
468    return x
469  else:
470    return x  # arbitrary
def omnistaging_disabler() -> None:
  """Rebind module-level `batch_jaxpr` to a non-omnistaging implementation.

  The replacement traces with `pe.trace_to_jaxpr` on unknown `PartialVal`s
  instead of `pe.trace_to_jaxpr_dynamic`.
  """
  # NOTE(review): not called within this view; presumably invoked by a config
  # hook when omnistaging is disabled -- confirm.
  global batch_jaxpr

  def batch_jaxpr(jaxpr, axis_size, in_batched, instantiate, axis_name):
    f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
    f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
    f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched])
    avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
                for aval, b in zip(jaxpr.in_avals, in_batched)]
    in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
    # Only the jaxpr and its consts are needed; the output pvals are unused.
    jaxpr_out, _, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
    return core.ClosedJaxpr(jaxpr_out, consts_out), out_batched()
# Batching rules for collective primitives, keyed by primitive. They are used by
# `BatchTrace.process_primitive` when the primitive's `axis_name` resolves to
# that trace's main trace. Each is called as
# `rule(axis_frame, vals_in, dims_in, **params)` and returns `(val_out, dim_out)`.
collective_rules: Dict[core.Primitive, Callable] = {}