1# Copyright 2018 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15import numpy as np 16from typing import Any, Callable, Dict, Optional, Tuple, Union 17 18import jax 19from ..config import config 20from .. import core 21from ..core import ShapedArray, raise_to_shaped, Trace, Tracer 22from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p 23from .. import linear_util as lu 24from .._src.util import (unzip2, partial, safe_map, wrap_name, split_list, 25 canonicalize_axis, moveaxis, as_hashable_function) 26from . import xla 27from . 
import partial_eval as pe

# Throughout this module `map` refers to util's `safe_map`, not the builtin.
map = safe_map

def batch(fun: lu.WrappedFun, axis_name, axis_size, in_dims, out_dim_dests,
          ) -> lu.WrappedFun:
  """Wrap `fun` so that calling it evaluates it under a `BatchTrace` (vmap).

  This is the batching analogue of `jvp` in ad.py. Three transformations are
  stacked on `fun`, innermost first:

    1. `batch_subtrace` wraps inputs as `BatchTracer`s and records the output
       batch dims in an auxiliary thunk;
    2. `batchfun` resolves/canonicalizes `in_dims` and runs the function under
       a fresh `BatchTrace` main trace with `axis_name` bound to `axis_size`;
    3. `_match_axes` moves or broadcasts each output so that its batch axis
       ends up at the corresponding entry of `out_dim_dests`.

  `in_dims` and `out_dim_dests` may each be passed directly or as a
  zero-argument callable returning them.
  """
  inner, out_dims_thunk = batch_subtrace(fun)
  traced = batchfun(inner, axis_name, axis_size, in_dims)
  return _match_axes(traced, axis_size, out_dims_thunk, out_dim_dests)

@lu.transformation
def batchfun(axis_name, axis_size, in_dims, *in_vals):
  """Outer batching transformation (the analogue of `jvpfun` in ad.py).

  Resolves `in_dims` (calling it if it is a thunk) and canonicalizes integer
  axes against each input's rank. It then runs the wrapped function with a
  new `BatchTrace` main trace, with `axis_name` bound to `axis_size` in the
  axis environment.
  """
  if callable(in_dims):
    in_dims = in_dims()
  canonical_dims = []
  for val, ax in zip(in_vals, in_dims):
    # Units (non-omnistaging) have no array rank to canonicalize against; the
    # aval is only inspected when `ax` is an int.
    if isinstance(ax, int) and not isinstance(core.get_aval(val), core.AbstractUnit):
      ax = canonicalize_axis(ax, np.ndim(val))
    canonical_dims.append(ax)
  with core.new_main(BatchTrace, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      out_vals = yield (main, canonical_dims, *in_vals), {}
      # Drop the local reference to `main` before the `new_main` context exits.
      del main
  yield out_vals

@lu.transformation_with_aux
def batch_subtrace(main, in_dims, *in_vals):
  """Inner batching transformation (the analogue of `jvp_subtrace` in ad.py).

  Each input whose dim is not None is wrapped in a `BatchTracer` at the
  current sublevel of `main`. After the wrapped function runs, its outputs
  are raised into that trace. The underlying values are returned, and their
  batch dims are emitted as the auxiliary output.
  """
  trace = main.with_cur_sublevel()
  in_tracers = [val if dim is None else BatchTracer(trace, val, dim)
                for val, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  out_tracers = [trace.full_raise(out) for out in outs]
  out_vals = tuple(tracer.val for tracer in out_tracers)
  out_dims = tuple(tracer.batch_dim for tracer in out_tracers)
  yield out_vals, out_dims

@lu.transformation
def _match_axes(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
  """Move each output's batch axis to its requested destination axis.

  `out_dim_dests` may be a thunk. Raises ValueError when an output is batched
  but its destination is not an int.
  """
  out_vals = yield in_vals, {}
  dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  out_dims = out_dims_thunk()
  for od, dest in zip(out_dims, dests):
    if od is None or isinstance(dest, int):
      continue
    raise ValueError(f"vmap has mapped output but out_axes is {dest}")
  yield map(partial(matchaxis, axis_size), out_dims, dests, out_vals)


# These next two functions, `batch_fun2` and `_batch_fun2`, are deprecated; the
# former is only called from
# `custom_transforms`, which itself is deprecated.
# TODO(mattjj): delete these along with custom_transforms

def batch_fun2(fun: lu.WrappedFun, in_dims):
  """Batch `fun`, returning it together with a thunk of its output batch dims.

  Unlike `batch`, no output-axis destinations are taken; outputs keep whatever
  batch dims the batching rules produced.
  """
  # like `batch` but returns output batch dims (so no out_dim_dests)
  fun, out_dims = batch_subtrace(fun)
  return _batch_fun2(fun, in_dims), out_dims

@lu.transformation
def _batch_fun2(in_dims, *in_vals, **params):
  # Opens a BatchTrace with no axis name. Unlike `batchfun`, it does not extend
  # the axis env, and it forwards `params` to the wrapped function.
  with core.new_main(BatchTrace, axis_name=None) as main:
    out_vals = yield (main, in_dims,) + in_vals, params
    del main
  yield out_vals


### tracer

# TODO(mattjj): use a special sentinel type rather than None
NotMapped = type(None)  # the type of the `not_mapped` sentinel
not_mapped = None  # batch_dim value meaning "this value has no batch axis"

class BatchTracer(Tracer):
  """Tracer for a value being batched by `BatchTrace`.

  `val` is the underlying value. `batch_dim` is the axis of `val` being
  mapped over, or `not_mapped` when `val` carries no batch axis.
  """
  __slots__ = ['val', 'batch_dim']

  def __init__(self, trace, val, batch_dim: Optional[int]):
    assert core.skip_checks or type(batch_dim) in (int, NotMapped)  # type: ignore
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    # Per-example abstract value: `val`'s shaped aval with the batch axis
    # (if any) deleted from its shape. Units pass through unchanged.
    aval = raise_to_shaped(core.get_aval(self.val))
    if self.batch_dim is not_mapped:
      return aval
    else:
      if aval is core.abstract_unit:
        return aval
      elif type(aval) is ShapedArray:
        assert 0 <= self.batch_dim < aval.ndim
        new_shape = tuple(np.delete(aval.shape, self.batch_dim))
        return ShapedArray(new_shape, aval.dtype)
      else:
        raise TypeError(aval)

  def full_lower(self):
    # An unbatched tracer lowers to its underlying value; a batched one stays.
    if self.batch_dim is not_mapped:
      return core.full_lower(self.val)
    else:
      return self

class BatchTrace(Trace):
  """Trace that evaluates primitives on batched values via batching rules.

  `axis_name` is the name of the vmapped axis this trace is responsible for.
  """
  def __init__(self, *args, axis_name):
    super().__init__(*args)
    self.axis_name = axis_name

  def pure(self, val):
    # Values entering from outside are treated as unbatched.
    return BatchTracer(self, val, not_mapped)

  def lift(self, val):
    return BatchTracer(self, val, not_mapped)

  def sublift(self, val):
    # Rewrap another BatchTracer into this trace, keeping its value and dim.
    return BatchTracer(self, val.val, val.batch_dim)

  def process_primitive(self, primitive, tracers, params):
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    # Nothing is batched: bind the primitive on the raw values.
    if all(bdim is not_mapped for bdim in dims_in):
      return primitive.bind(*vals_in, **params)
    # Collective primitives whose `axis_name` refers to this trace get the
    # collective rule, which additionally receives this trace's axis frame.
    if (primitive in collective_rules and
        _main_trace_for_axis_names(self.main, params['axis_name'])):
      frame = core.axis_frame(self.axis_name)
      val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params)
    else:
      batched_primitive = get_primitive_batcher(primitive, self.axis_name)
      val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
    if primitive.multiple_results:
      return map(partial(BatchTracer, self), val_out, dim_out)
    else:
      return BatchTracer(self, val_out, dim_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    assert call_primitive.multiple_results
    params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims):
      return call_primitive.bind(f, *vals, **params)
    else:
      # Batch the callee under a sub-trace of this main trace; its output
      # batch dims are read from the aux thunk after the call.
      f, dims_out = batch_subtrace(f, self.main, dims)
      vals_out = call_primitive.bind(f, *vals, **params)
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]

  def post_process_call(self, call_primitive, out_tracers, params):
    # Returns the raw values plus a `todo` callback that rewraps them as
    # BatchTracers (with the same batch dims) at the current sublevel.
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is not_mapped for dim in dims):
      return map_primitive.bind(f, *vals, **params)
    else:
      # All batched inputs must agree on the batch axis size.
      assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
      # The logic for the dimension math below is as follows:
      # ╔═════════════╦════════════════════════════════════════╦═══════════╗
      # ║ d / in_axis ║ None                                   ║ int       ║
      # ╠═════════════╬════════════════════════════════════════╩═══════════╣
      # ║ None        ║ No extra axis, so in_axis unaffected               ║
      # ╠═════════════╬════════════════════════════════════════╦═══════════╣
      # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
      # ╚═════════════╩════════════════════════════════════════╩═══════════╝
      # When both d and in_axis are defined then:
      # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
      # - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed).
      def both_mapped(in_out_axis, d):
        return in_out_axis is not None and d is not not_mapped
      new_in_axes = tuple(
        in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
        for d, in_axis in zip(dims, params['in_axes']))
      new_dims = tuple(
        d - 1 if both_mapped(in_axis, d) and in_axis < d else d
        for d, in_axis in zip(dims, params['in_axes']))
      f, dims_out = batch_subtrace(f, self.main, new_dims)
      out_axes_thunk = params['out_axes_thunk']
      # NOTE: This assumes that the choice of the dimensions over which outputs
      # are batched is entirely dependent on the function and not e.g. on the
      # data or its shapes.
      # Output axes shift by one when the batch dim sits before them.
      @as_hashable_function(closure=out_axes_thunk)
      def new_out_axes_thunk():
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes_thunk(), dims_out()))
      new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
      vals_out = map_primitive.bind(f, *vals, **new_params)
      # Conversely, a batch dim at or after a re-inserted out_axis shifts by one.
      dims_out = (d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
                  for d, out_axis in zip(dims_out(), out_axes_thunk()))
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]

  def post_process_map(self, call_primitive, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def both_mapped(in_out_axis, d):
      return in_out_axis is not None and d is not not_mapped
    def todo(vals):
      trace = main.with_cur_sublevel()
      return [BatchTracer(trace, v, d + 1 if both_mapped(out_axis, d) and out_axis <= d else d)
              for v, d, out_axis in zip(vals, dims, params['out_axes_thunk']())]
    # For map primitives, `todo` is paired with a transform that shifts
    # out_axes by the same rule `process_map` uses.
    if call_primitive.map_primitive:
      def out_axes_transform(out_axes):
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes, dims))
      todo = (todo, out_axes_transform)
    return vals, todo

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
    out_vals = prim.bind(fun, jvp, *in_vals)
    # `fst` presumably reports which of the two aux thunks was populated
    # (TODO confirm against lu.merge_linear_aux).
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      # The jvp aux lists the dims twice (primals then tangents, see
      # batch_custom_jvp_subtrace); keep one copy.
      assert out_dims == out_dims[:len(out_dims) // 2] * 2
      out_dims = out_dims[:len(out_dims) // 2]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  def post_process_custom_jvp_call(self, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    # All batched inputs must share one axis size (the unpack enforces it).
    axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
                  if d is not not_mapped}
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
    # bwd's input dims are fwd's output dims (passed as a thunk).
    bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
                               out_dims2, in_dims)
    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      # Keep only the last len(out_vals) dims; presumably fwd's extra outputs
      # (residuals) come first -- TODO confirm.
      out_dims = out_dims[-len(out_vals) % len(out_dims):]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  post_process_custom_vjp_call = post_process_custom_jvp_call

def _main_trace_for_axis_names(main_trace: core.MainTrace,
                               axis_name: Union[core.AxisName, Tuple[core.AxisName, ...]]
                               ) -> bool:
  # This function exists to identify whether a main trace corresponds to any of
  # the axis names used by a primitive. Axis names alone aren't enough because
  # axis names can shadow, so we use the main trace as a tag.
  if not isinstance(axis_name, (list, tuple)):
    axis_name = (axis_name,)
  return any(main_trace is core.axis_frame(n).main_trace for n in axis_name)

def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests):
  """Batch a custom-VJP backward function.

  Like `batch`, except outputs go through `_match_axes_and_sum`: a batched
  output whose destination is `None` is summed over its batch axis.
  """
  bwd, out_dims_thunk = batch_subtrace(bwd)
  return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims),
                             axis_size, out_dims_thunk, out_dim_dests)

@lu.transformation
def _match_axes_and_sum(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
  # this is like _match_axes, but we do reduce-sums as needed
  out_vals = yield in_vals, {}
  yield map(partial(matchaxis, axis_size, sum_match=True),
            out_dims_thunk(), out_dim_dests, out_vals)


### primitives

# A batching rule maps (batched_args, batch_dims, **params) to
# (out, out_dim), or to (outs, out_dims) for multiple-result primitives.
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
# Rules registered here are additionally passed `axis_name`.
initial_style_batchers : Dict[core.Primitive, Any] = {}

def get_primitive_batcher(p, axis_name):
  """Return the batching rule for `p`.

  Raises:
    NotImplementedError: if no rule is registered for `p`.
  """
  if p in initial_style_batchers:
    return partial(initial_style_batchers[p], axis_name=axis_name)
  try:
    return primitive_batchers[p]
  except KeyError as err:
    msg = "Batching rule for '{}' not implemented"
    raise NotImplementedError(msg.format(p)) from err

def defvectorized(prim):
  # Register `vectorized_batcher` as the rule for `prim`.
  primitive_batchers[prim] = partial(vectorized_batcher, prim)

def vectorized_batcher(prim, batched_args, batch_dims, **params):
  # Bind `prim` directly on the batched args; they must all share one batch
  # dim, which is also the output's batch dim.
  assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims
  return prim.bind(*batched_args, **params), batch_dims[0]

def defbroadcasting(prim):
  # Register `broadcast_batcher` as the rule for `prim`.
  primitive_batchers[prim] = partial(broadcast_batcher, prim)

def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    args: the possibly-batched arguments
    dims: list or tuple of the same length as `args`, where each
      entry indicates the batching state of the corresponding entry to `args`:
      either an int indicating the batch dimension, or else `not_mapped`
      indicating no batching.

  Returns:
    `(out, out_dim)` (with one dim per output for multiple-result
    primitives). If all non-scalar args share one shape and batch dim, the
    primitive is bound as-is and that dim is returned. Otherwise every arg's
    batch axis is moved or broadcast to the front, and the output dim is 0.
  """
  shapes = {(x.shape, d) for x, d in zip(args, dims) if np.ndim(x)}
  if len(shapes) == 1:
    # if there's only agreeing batch dims and scalars, just call the primitive
    d = next(d for d in dims if d is not not_mapped)
    out = prim.bind(*args, **params)
    return (out, (d,) * len(out)) if prim.multiple_results else (out, d)
  else:
    size, = {shape[d] for shape, d in shapes if d is not not_mapped}
    args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
    ndim = max(np.ndim(x) for x in args)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
    out = prim.bind(*args, **params)
    return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)

def _handle_scalar_broadcasting(nd, x, d):
  # Pad a batched, lower-rank `x` with trailing size-1 axes up to rank `nd`;
  # unbatched or full-rank values are returned unchanged.
  if d is not_mapped or nd == np.ndim(x):
    return x
  else:
    return x.reshape(x.shape + (1,) * (nd - np.ndim(x)))

def defreducer(prim):
  # Register `reducer_batcher` as the rule for `prim`.
  primitive_batchers[prim] = partial(reducer_batcher, prim)

def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  operand, = batched_args
  bdim, = batch_dims
  # Reduction axes at or past the batch dim shift by one in the batched operand.
  axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
  # Position of the batch axis once the reduced axes are removed.
  bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
  if 'input_shape' in params:
    params = dict(params, input_shape=operand.shape)
  return prim.bind(operand, axes=axes, **params), bdim_out

# sets up primitive batchers for ad_util and xla primitives

def add_batched(batched_args, batch_dims):
  # Batching rule for add_jaxvals_p: align the two batch dims, then add.
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx == bdy or core.get_aval(x) == core.abstract_unit:
    return add_jaxvals(x, y), bdx
  elif bdx is not_mapped:
    x = broadcast(x, y.shape[bdy], bdy)
    return add_jaxvals(x, y), bdy
  elif bdy is not_mapped:
    y = broadcast(y, x.shape[bdx], bdx)
    return add_jaxvals(x, y), bdx
  else:
    x = moveaxis(x, bdx, bdy)
    return add_jaxvals(x, y), bdy
primitive_batchers[add_jaxvals_p] = add_batched

def zeros_like_batched(batched_args, batch_dims):
  # Batching rule for zeros_like_p: the batch dim passes straight through.
  val, = batched_args
  bdim, = batch_dims
  return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched

defvectorized(xla.device_put_p)

### util

def broadcast(x, sz, axis):
  """Insert a new axis of size `sz` at position `axis` by broadcasting `x`.

  Units are returned as `core.unit`.
  """
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  shape = list(np.shape(x))
  shape.insert(axis, sz)
  broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)

def matchaxis(sz, src, dst, x, sum_match=False):
  """Make `x`'s batch axis go from `src` to `dst`.

  Moves the axis when both are ints, and broadcasts a new axis of size `sz`
  when `src` is unmapped. When `dst` is None and `sum_match` is set, it sums
  over `src`.

  Raises:
    ValueError: for any other (src, dst) combination.
  """
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if src == dst:
    return x
  elif type(src) == type(dst) == int:
    return moveaxis(x, src, dst)
  elif src is not_mapped and dst is not not_mapped:
    return broadcast(
      x, sz, canonicalize_axis(dst, np.ndim(x) + 1))
  elif dst is None and sum_match:
    return x.sum(src)
  else:
    raise ValueError((src, dst))

def bdim_at_front(x, bdim, size):
  # Return `x` with its batch axis at position 0, broadcasting to `size` if
  # `x` is unbatched. Units are returned as `core.unit`.
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if bdim is not_mapped:
    return broadcast(x, size, 0)
  else:
    return moveaxis(x, bdim, 0)


def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name):
  """Batch a ClosedJaxpr along a new leading axis of size `axis_size`.

  Inputs flagged in `in_batched` are treated as batched along axis 0. Returns
  the batched ClosedJaxpr and a list saying which outputs are batched. Batched
  (or instantiated) outputs carry their batch axis at 0.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
  f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched])
  avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
              for aval, b in zip(closed_jaxpr.in_avals, in_batched)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()

@lu.transformation_with_aux
def batch_subtrace_instantiate(instantiate, axis_size, main, in_dims, *in_vals):
  # this is like `batch_subtrace` but we take an extra `instantiate` arg
  # analogue of `jvp_subtrace` in ad.py
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)

  # `instantiate` may be one bool for all outputs or a per-output sequence.
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_vals)
  # Batched outputs get their batch axis moved to 0; unbatched outputs that
  # must be instantiated are broadcast along a new leading axis.
  out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0
              else broadcast(x, axis_size, 0) if d is not_mapped and inst else x
              for x, d, inst in zip(out_vals, out_dims, instantiate)]
  out_batched = [d is not not_mapped or inst
                 for d, inst in zip(out_dims, instantiate)]
  yield out_vals, out_batched

@lu.transformation_with_aux
def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
  # Batch size from the batched inputs (zip stops after len(in_dims) values).
  size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  trace = main.with_cur_sublevel()
  # `in_dims * 2`: each dim applies twice, presumably to a primal and then its
  # tangent.
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims * 2)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
  out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
  # Primal and tangent outputs are moved to one shared batch dim per pair.
  out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
  out_primals = map(partial(matchaxis, size), out_primal_bds, out_dims, out_primals)
  out_tangents = map(partial(matchaxis, size), out_tangent_bds, out_dims, out_tangents)
  yield out_primals + out_tangents, out_dims * 2
def _merge_bdims(x, y):
  """Choose one batch dim for a pair of dims (e.g. a primal and its tangent).

  Returns the shared dim when `x == y` and the mapped dim when exactly one is
  `not_mapped`. When both are mapped but differ, it arbitrarily returns `x`.
  """
  if x == y:
    return x
  elif x is not_mapped:
    return y
  elif y is not_mapped:
    return x
  else:
    return x  # arbitrary


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Rebind the module-level `batch_jaxpr` to a non-omnistaging version.

  Registered with `config.register_omnistaging_disabler`. The replacement
  traces with `pe.trace_to_jaxpr` on unknown `PartialVal` inputs, instead of
  `pe.trace_to_jaxpr_dynamic`.
  """
  global batch_jaxpr

  def batch_jaxpr(jaxpr, axis_size, in_batched, instantiate, axis_name):
    f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
    f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
    f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched])
    # Batched inputs gain a leading axis of size `axis_size`.
    avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
                for aval, b in zip(jaxpr.in_avals, in_batched)]
    in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
    # The output pvals are not needed; only the jaxpr and its consts are kept.
    jaxpr_out, _, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
    return core.ClosedJaxpr(jaxpr_out, consts_out), out_batched()


# Batching rules for collective primitives. `BatchTrace.process_primitive` uses
# one when the primitive's `axis_name` refers to that trace's main trace.
collective_rules: Dict[core.Primitive, Callable] = {}