# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from typing import Any, Callable, Dict, Optional, Tuple, Union

import jax
from ..config import config
from .. import core
from ..core import ShapedArray, raise_to_shaped, Trace, Tracer
from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
from .. import linear_util as lu
from .._src.util import (unzip2, partial, safe_map, wrap_name, split_list,
                         canonicalize_axis, moveaxis, as_hashable_function)
from . import xla
from . import partial_eval as pe

map = safe_map

def batch(fun: lu.WrappedFun, axis_name, axis_size, in_dims, out_dim_dests,
          ) -> lu.WrappedFun:
  # analogue of `jvp` in ad.py
  fun, out_dims_thunk = batch_subtrace(fun)
  return _match_axes(batchfun(fun, axis_name, axis_size, in_dims),
                     axis_size, out_dims_thunk, out_dim_dests)
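# A minimal usage sketch (hypothetical, not a public API; internal details may
# differ): `batch` takes an `lu.WrappedFun`, an axis name, the axis size,
# per-input batch dims, and output dim destinations. Something along these
# lines batches a unary function over a leading axis of size 3:
#
#   fn = batch(lu.wrap_init(lambda x: x * 2), 'i', 3, (0,), (0,))
#   fn.call_wrapped(np.ones((3, 4)))  # expected: a (3, 4) array of twos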

@lu.transformation
def batchfun(axis_name, axis_size, in_dims, *in_vals):
  # analogue of `jvpfun` in ad.py
  in_dims = in_dims() if callable(in_dims) else in_dims
  in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
             and not isinstance(core.get_aval(x), core.AbstractUnit)  # non-omnistaging
             else ax for x, ax in zip(in_vals, in_dims)]
  with core.new_main(BatchTrace, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      out_vals = yield (main, in_dims, *in_vals), {}
      del main
  yield out_vals

@lu.transformation_with_aux
def batch_subtrace(main, in_dims, *in_vals):
  # analogue of `jvp_subtrace` in ad.py
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  yield out_vals, out_dims

@lu.transformation
def _match_axes(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
  out_vals = yield in_vals, {}
  out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  out_dims = out_dims_thunk()
  for od, od_dest in zip(out_dims, out_dim_dests):
    if od is not None and not isinstance(od_dest, int):
      msg = f"vmap has mapped output but out_axes is {od_dest}"
      raise ValueError(msg)
  yield map(partial(matchaxis, axis_size), out_dims, out_dim_dests, out_vals)


# These next two functions, `batch_fun2` and `_batch_fun2`, are deprecated; the
# former is only called from `custom_transforms`, which itself is deprecated.
# TODO(mattjj): delete these along with custom_transforms

def batch_fun2(fun: lu.WrappedFun, in_dims):
  # like `batch` above but returns output batch dims (so no out_dim_dests)
  fun, out_dims = batch_subtrace(fun)
  return _batch_fun2(fun, in_dims), out_dims

@lu.transformation
def _batch_fun2(in_dims, *in_vals, **params):
  with core.new_main(BatchTrace, axis_name=None) as main:
    out_vals = yield (main, in_dims,) + in_vals, params
    del main
  yield out_vals


### tracer

# TODO(mattjj): use a special sentinel type rather than None
NotMapped = type(None)
not_mapped = None

class BatchTracer(Tracer):
  __slots__ = ['val', 'batch_dim']

  def __init__(self, trace, val, batch_dim: Optional[int]):
    assert core.skip_checks or type(batch_dim) in (int, NotMapped)  # type: ignore
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    aval = raise_to_shaped(core.get_aval(self.val))
    if self.batch_dim is not_mapped:
      return aval
    else:
      if aval is core.abstract_unit:
        return aval
      elif type(aval) is ShapedArray:
        assert 0 <= self.batch_dim < aval.ndim
        new_shape = tuple(np.delete(aval.shape, self.batch_dim))
        return ShapedArray(new_shape, aval.dtype)
      else:
        raise TypeError(aval)

  def full_lower(self):
    if self.batch_dim is not_mapped:
      return core.full_lower(self.val)
    else:
      return self

class BatchTrace(Trace):
  def __init__(self, *args, axis_name):
    super().__init__(*args)
    self.axis_name = axis_name

  def pure(self, val):
    return BatchTracer(self, val, not_mapped)

  def lift(self, val):
    return BatchTracer(self, val, not_mapped)

  def sublift(self, val):
    return BatchTracer(self, val.val, val.batch_dim)

  def process_primitive(self, primitive, tracers, params):
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims_in):
      return primitive.bind(*vals_in, **params)
    if (primitive in collective_rules and
          _main_trace_for_axis_names(self.main, params['axis_name'])):
      frame = core.axis_frame(self.axis_name)
      val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params)
    else:
      batched_primitive = get_primitive_batcher(primitive, self.axis_name)
      val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
    if primitive.multiple_results:
      return map(partial(BatchTracer, self), val_out, dim_out)
    else:
      return BatchTracer(self, val_out, dim_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    assert call_primitive.multiple_results
    params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims):
      return call_primitive.bind(f, *vals, **params)
    else:
      f, dims_out = batch_subtrace(f, self.main, dims)
      vals_out = call_primitive.bind(f, *vals, **params)
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]

  def post_process_call(self, call_primitive, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is not_mapped for dim in dims):
      return map_primitive.bind(f, *vals, **params)
    else:
      assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
      # The logic for the dimension math below is as follows:
      # ╔═════════════╦════════════════════════════════════════╦═══════════╗
      # ║ d / in_axis ║ None                                   ║ int       ║
      # ╠═════════════╬════════════════════════════════════════╩═══════════╣
      # ║ None        ║ No extra axis, so in_axis unaffected               ║
      # ╠═════════════╬════════════════════════════════════════╦═══════════╣
      # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
      # ╚═════════════╩════════════════════════════════════════╩═══════════╝
      # When both d and in_axis are defined then:
      # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
      # - If `d >  in_axis`, we have to decrement `d` (as `in_axis` will get removed).
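      # A concrete example of the int/int case (hypothetical values): with
      # d=0 and in_axis=0 we have d <= in_axis, so in_axis becomes 1 and d
      # stays 0; with d=2 and in_axis=0 we have in_axis < d, so d becomes 1
      # while in_axis is unchanged (the mapped axis gets removed first).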
      def both_mapped(in_out_axis, d):
        return in_out_axis is not None and d is not not_mapped
      new_in_axes = tuple(
        in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
        for d, in_axis in zip(dims, params['in_axes']))
      new_dims = tuple(
        d - 1 if both_mapped(in_axis, d) and in_axis < d else d
        for d, in_axis in zip(dims, params['in_axes']))
      f, dims_out = batch_subtrace(f, self.main, new_dims)
      out_axes_thunk = params['out_axes_thunk']
      # NOTE: This assumes that the choice of the dimensions over which outputs
      #       are batched is entirely dependent on the function and not e.g. on the
      #       data or its shapes.
      @as_hashable_function(closure=out_axes_thunk)
      def new_out_axes_thunk():
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes_thunk(), dims_out()))
      new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
      vals_out = map_primitive.bind(f, *vals, **new_params)
      dims_out = (d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
                  for d, out_axis in zip(dims_out(), out_axes_thunk()))
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]

  def post_process_map(self, call_primitive, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def both_mapped(in_out_axis, d):
      return in_out_axis is not None and d is not not_mapped
    def todo(vals):
      trace = main.with_cur_sublevel()
      return [BatchTracer(trace, v, d + 1 if both_mapped(out_axis, d) and out_axis <= d else d)
              for v, d, out_axis in zip(vals, dims, params['out_axes_thunk']())]
    if call_primitive.map_primitive:
      def out_axes_transform(out_axes):
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes, dims))
      todo = (todo, out_axes_transform)
    return vals, todo

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
    out_vals = prim.bind(fun, jvp, *in_vals)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      assert out_dims == out_dims[:len(out_dims) // 2] * 2
      out_dims = out_dims[:len(out_dims) // 2]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  def post_process_custom_jvp_call(self, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
                  if d is not not_mapped}
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
    bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
                               out_dims2, in_dims)
    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      out_dims = out_dims[-len(out_vals) % len(out_dims):]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  post_process_custom_vjp_call = post_process_custom_jvp_call

def _main_trace_for_axis_names(main_trace: core.MainTrace,
                               axis_name: Union[core.AxisName, Tuple[core.AxisName, ...]]
                               ) -> bool:
  # This function exists to identify whether a main trace corresponds to any of
  # the axis names used by a primitive. Axis names alone aren't enough because
  # axis names can shadow, so we use the main trace as a tag.
  if not isinstance(axis_name, (list, tuple)):
    axis_name = (axis_name,)
  return any(main_trace is core.axis_frame(n).main_trace for n in axis_name)

def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests):
  bwd, out_dims_thunk = batch_subtrace(bwd)
  return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims),
                             axis_size, out_dims_thunk, out_dim_dests)

@lu.transformation
def _match_axes_and_sum(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
  # this is like _match_axes, but we do reduce-sums as needed
  out_vals = yield in_vals, {}
  yield map(partial(matchaxis, axis_size, sum_match=True),
            out_dims_thunk(), out_dim_dests, out_vals)


### primitives

BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
primitive_batchers: Dict[core.Primitive, BatchingRule] = {}
initial_style_batchers: Dict[core.Primitive, Any] = {}
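# A hedged sketch of how a rule lands in `primitive_batchers` (the primitive
# `foo_p` and its rule here are hypothetical): a rule receives the batched
# argument values and their batch dims, and returns the batched result along
# with the result's batch dim, e.g.
#
#   def foo_batcher(batched_args, batch_dims, **params):
#     x, = batched_args
#     bd, = batch_dims
#     return foo_p.bind(x, **params), bd
#   primitive_batchers[foo_p] = foo_batcher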

def get_primitive_batcher(p, axis_name):
  if p in initial_style_batchers:
    return partial(initial_style_batchers[p], axis_name=axis_name)
  try:
    return primitive_batchers[p]
  except KeyError as err:
    msg = "Batching rule for '{}' not implemented"
    raise NotImplementedError(msg.format(p)) from err

def defvectorized(prim):
  primitive_batchers[prim] = partial(vectorized_batcher, prim)

def vectorized_batcher(prim, batched_args, batch_dims, **params):
  assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims
  return prim.bind(*batched_args, **params), batch_dims[0]
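# For instance (hypothetical shapes): an elementwise primitive whose argument
# is batched along axis 1 can be bound unchanged, and the output inherits
# batch dim 1, since elementwise ops commute with batching.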

def defbroadcasting(prim):
  primitive_batchers[prim] = partial(broadcast_batcher, prim)

def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    args: the possibly-batched arguments
    dims: list or tuple of the same length as `args`, where each
      entry indicates the batching state of the corresponding entry of `args`:
      either an int indicating the batch dimension, or else `not_mapped`
      indicating no batching.
  """
  shapes = {(x.shape, d) for x, d in zip(args, dims) if np.ndim(x)}
  if len(shapes) == 1:
    # if there are only agreeing batch dims and scalars, just call the primitive
    d = next(d for d in dims if d is not not_mapped)
    out = prim.bind(*args, **params)
    return (out, (d,) * len(out)) if prim.multiple_results else (out, d)
  else:
    size, = {shape[d] for shape, d in shapes if d is not not_mapped}
    args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
    ndim = max(np.ndim(x) for x in args)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
    out = prim.bind(*args, **params)
    return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)
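# A worked example (hypothetical shapes): adding x of shape (5, 3) batched at
# dim 0 to an unmapped y of shape (3,) takes the `else` branch above: y is
# broadcast to (5, 3) with the batch dim at the front, the primitive is bound
# on the aligned operands, and the result is reported batched at dim 0.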

def _handle_scalar_broadcasting(nd, x, d):
  if d is not_mapped or nd == np.ndim(x):
    return x
  else:
    return x.reshape(x.shape + (1,) * (nd - np.ndim(x)))
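# E.g. (hypothetical): a batched scalar of shape (3,) alongside rank-3
# operands (nd=3) is reshaped to (3, 1, 1), so its trailing axes broadcast
# while the leading batch axis stays aligned.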

def defreducer(prim):
  primitive_batchers[prim] = partial(reducer_batcher, prim)

def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  operand, = batched_args
  bdim, = batch_dims
  axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
  bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
  if 'input_shape' in params:
    params = dict(params, input_shape=operand.shape)
  return prim.bind(operand, axes=axes, **params), bdim_out
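# E.g. (hypothetical): reducing over axes=(0, 1) of a (4, 5) operand that is
# batched at bdim=0 (making it (3, 4, 5)) shifts axes to (1, 2); the batch
# axis is the only one that survives the reduction, so bdim_out is 0.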

# sets up primitive batchers for ad_util and xla primitives

def add_batched(batched_args, batch_dims):
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx == bdy or core.get_aval(x) == core.abstract_unit:
    return add_jaxvals(x, y), bdx
  elif bdx is not_mapped:
    x = broadcast(x, y.shape[bdy], bdy)
    return add_jaxvals(x, y), bdy
  elif bdy is not_mapped:
    y = broadcast(y, x.shape[bdx], bdx)
    return add_jaxvals(x, y), bdx
  else:
    x = moveaxis(x, bdx, bdy)
    return add_jaxvals(x, y), bdy
primitive_batchers[add_jaxvals_p] = add_batched
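# E.g. (hypothetical shapes): adding x batched at bdx=0 to y batched at bdy=1
# takes the final branch above: x's batch axis is moved to position 1 so the
# operands line up, and the sum is reported batched at dim 1.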

def zeros_like_batched(batched_args, batch_dims):
  val, = batched_args
  bdim, = batch_dims
  return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched

defvectorized(xla.device_put_p)

### util

def broadcast(x, sz, axis):
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  shape = list(np.shape(x))
  shape.insert(axis, sz)
  broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)

def matchaxis(sz, src, dst, x, sum_match=False):
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if src == dst:
    return x
  elif type(src) == type(dst) == int:
    return moveaxis(x, src, dst)
  elif src is not_mapped and dst is not not_mapped:
    return broadcast(
      x, sz, canonicalize_axis(dst, np.ndim(x) + 1))
  elif dst is None and sum_match:
    return x.sum(src)
  else:
    raise ValueError((src, dst))
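# A few hypothetical cases: matchaxis(3, 0, 1, x) moves the batch axis from 0
# to 1; matchaxis(3, not_mapped, 0, x) broadcasts a new leading axis of size
# 3; and matchaxis(3, 0, None, x, sum_match=True) sums out axis 0, which is
# the reduce-sum behavior that `_match_axes_and_sum` relies on.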

def bdim_at_front(x, bdim, size):
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if bdim is not_mapped:
    return broadcast(x, size, 0)
  else:
    return moveaxis(x, bdim, 0)


def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name):
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
  f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched])
  avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
              for aval, b in zip(closed_jaxpr.in_avals, in_batched)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()

@lu.transformation_with_aux
def batch_subtrace_instantiate(instantiate, axis_size, main, in_dims, *in_vals):
  # this is like `batch_subtrace` but we take an extra `instantiate` arg
  # analogue of `jvp_subtrace` in ad.py
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)

  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_vals)
  out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0
              else broadcast(x, axis_size, 0) if d is not_mapped and inst else x
              for x, d, inst in zip(out_vals, out_dims, instantiate)]
  out_batched = [d is not not_mapped or inst
                 for d, inst in zip(out_dims, instantiate)]
  yield out_vals, out_batched

@lu.transformation_with_aux
def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
  size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims * 2)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
  out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
  out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
  out_primals  = map(partial(matchaxis, size),  out_primal_bds, out_dims,  out_primals)
  out_tangents = map(partial(matchaxis, size), out_tangent_bds, out_dims, out_tangents)
  yield out_primals + out_tangents, out_dims * 2

def _merge_bdims(x, y):
  if x == y:
    return x
  elif x is not_mapped:
    return y
  elif y is not_mapped:
    return x
  else:
    return x  # arbitrary


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  global batch_jaxpr

  def batch_jaxpr(jaxpr, axis_size, in_batched, instantiate, axis_name):
    f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
    f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
    f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched])
    avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
                for aval, b in zip(jaxpr.in_avals, in_batched)]
    in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
    jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
    avals_out, _ = unzip2(pvals_out)
    return core.ClosedJaxpr(jaxpr_out, consts_out), out_batched()


collective_rules: Dict[core.Primitive, Callable] = {}