# coding=utf-8
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Control flow primitives.
"""


import collections
import functools
import inspect
import itertools
import operator
import os
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar

import numpy as np

import jax
from jax import api
from jax import core
from jax import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
from jax import linear_util as lu
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import masking
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax._src.util import (partial, unzip2, unzip3, unzip4, safe_map, safe_zip,
                           split_list, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                           treedef_children, treedef_tuple, tree_multimap,
                           tree_leaves)
from jax import ad_util
from jax.config import config

xops = xla_client.ops

# Module-wide aliases. NOTE: `zip` deliberately shadows the builtin with
# `util.safe_zip` for the rest of this module (presumably a length-checking
# zip -- see jax._src.util).
_map = safe_map
zip = safe_zip
_reduce = functools.reduce

T = TypeVar('T')
Array = Any

@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
  """Traces ``fun`` to an open jaxpr for use as a control-flow sub-computation.

  ``fun`` is flattened with ``flatten_fun_nokwargs`` according to ``in_tree``
  and traced with ``pe.trace_to_jaxpr_dynamic`` on the abstract inputs
  ``in_avals``. The function is wrapped in ``cache()``, so repeated calls with
  the same arguments may be served from a memo.

  Returns:
    A triple ``(jaxpr, consts, out_tree)``:
    - the traced jaxpr, whose constants are still ``constvars`` (not yet
      converted to inputs);
    - the constant values captured during tracing;
    - the pytree structure of ``fun``'s output.
  """
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  # `out_tree` is a thunk; it is only populated once tracing has run.
  return jaxpr, consts, out_tree()

@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
  """Like ``_initial_style_open_jaxpr``, but returns a closed jaxpr.

  The traced jaxpr's constvars are turned into ordinary input binders with
  ``pe.convert_constvars_jaxpr``. The result is wrapped in a
  ``core.ClosedJaxpr`` with no consts. Callers must therefore pass the
  returned ``consts`` as extra arguments; callers in this file pass them
  before the regular operands.

  Returns:
    A triple ``(closed_jaxpr, consts, out_tree)``.
  """
  jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
  closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  return closed_jaxpr, consts, out_tree

@cache()
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
                                             in_tree, in_avals):
  # When staging the branches of a conditional into jaxprs, constants are
  # extracted from each branch and converted to jaxpr arguments. To use the
  # staged jaxprs as the branches to a conditional *primitive*, we need for
  # their (input) signatures to match. This function "joins" the staged jaxprs:
  # for each one, it makes another that accepts *all* constants, but only uses
  # those that it needs (dropping the rest).
  #
  # Returns (closed_jaxprs, consts, out_trees):
  # - one ClosedJaxpr per function in `funs`. Each takes, as leading inputs,
  #   the concatenation of *all* functions' consts (in `funs` order), followed
  #   by the operands described by `in_tree`/`in_avals`;
  # - that concatenated list of consts;
  # - each function's output pytree structure.

  jaxprs, all_consts, all_out_trees = unzip3(
      _initial_style_open_jaxpr(fun, in_tree, in_avals) for fun in funs)

  newvar = core.gensym(jaxprs, suffix='_')
  all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts]
                     for consts in all_consts]
  # Fresh placeholder variables for every function's consts. Function i uses
  # them to stand in for the consts of all the *other* functions.
  unused_const_vars = [[newvar(aval) for aval in const_avals]
                       for const_avals in all_const_avals]

  def pad_jaxpr_constvars(i, jaxpr):
    # Surround jaxpr i's own constvars with the unused placeholders of the
    # functions before it (prefix) and after it (suffix).
    prefix = util.concatenate(unused_const_vars[:i])
    suffix = util.concatenate(unused_const_vars[i + 1:])
    constvars = [*prefix, *jaxpr.constvars, *suffix]
    return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
                      outvars=jaxpr.outvars, eqns=jaxpr.eqns)

  consts = util.concatenate(all_consts)
  jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
  closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
                   for jaxpr in jaxprs]
  return closed_jaxprs, consts, all_out_trees

def _abstractify(x):
  """Returns the abstract value of ``x``, passed through ``raise_to_shaped``.

  Used in this file to build the input avals on which branch, cond and body
  functions are traced.
  """
  return raise_to_shaped(core.get_aval(x))

def _typecheck_param(prim, param, name, msg_required, pred):
  """Asserts ``pred`` via ``core.typecheck_assert`` with a descriptive message.

  The message names:
  - the primitive ``prim``;
  - the parameter ``name`` and the type of ``param``;
  - what was required (``msg_required``).

  It also includes ``str(param)``. When that string spans several lines, it
  is separated from the header by ``os.linesep`` instead of a space.
  """
  msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
         f'{msg_required} required:')
  param_str = str(param)
  sep = os.linesep if os.linesep in param_str else ' '
  msg = sep.join([msg, param_str])
  core.typecheck_assert(pred, msg)


### fori_loop and while_loop

def _fori_cond_fun(loop_carry):
  """while_loop cond for fori_loop: carry is ``(i, upper, x)``; loop while ``i < upper``."""
  i, upper, _ = loop_carry
  return lax.lt(i, upper)

@cache()
def _fori_body_fun(body_fun):
  """Wraps ``body_fun(i, x)`` as a while_loop body over ``(i, upper, x)`` that increments ``i``.

  NOTE(review): presumably memoized so that repeated fori_loop calls with the
  same ``body_fun`` produce the same wrapper object, letting the ``cache()`` on
  the tracing helpers hit -- TODO confirm.
  """
  def while_body_fun(loop_carry):
    i, upper, x = loop_carry
    return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)
  return while_body_fun

@cache()
def _fori_scan_body_fun(body_fun):
  """Wraps ``body_fun(i, x)`` as a scan body over carry ``(i, upper, x)`` with no per-step output.

  Only used by fori_loop's scan path, which is currently disabled.
  """
  def scanned_fun(loop_carry, _):
    i, upper, x = loop_carry
    return (lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)), None
  return scanned_fun

def fori_loop(lower, upper, body_fun, init_val):
  """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.

  The type signature in brief is

  .. code-block:: haskell

    fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

  The semantics of ``fori_loop`` are given by this Python implementation::

    def fori_loop(lower, upper, body_fun, init_val):
      val = init_val
      for i in range(lower, upper):
        val = body_fun(i, val)
      return val

  Unlike that Python version, ``fori_loop`` is implemented in terms of a call to
  :func:`jax.lax.while_loop`. See the :func:`jax.lax.while_loop` documentation
  for more information.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Args:
    lower: an integer representing the loop index lower bound (inclusive)
    upper: an integer representing the loop index upper bound (exclusive)
    body_fun: function of type ``(int, a) -> a``.
    init_val: initial loop carry value of type ``a``.

  Returns:
    Loop value from the final iteration, of type ``a``.

  Raises:
    TypeError: if ``lower`` and ``upper`` do not have the same (canonicalized)
      dtype.
  """
  # TODO(phawkins): perhaps do more type checking here, better error messages.
  lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
  upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
  if lower_dtype != upper_dtype:
    msg = ("lower and upper arguments to fori_loop must have equal types, "
           "got {} and {}")
    raise TypeError(msg.format(lower_dtype.name, upper_dtype.name))

  # If we can specialize on the trip count, call scan instead of a while_loop
  # to enable efficient reverse-mode differentiation.
  # NOTE: both branches below currently set use_scan = False. The scan path
  # (and `lower_`/`upper_`) is therefore unreachable, and fori_loop always
  # lowers to while_loop.
  try:
    lower_ = int(lower)
    upper_ = int(upper)
  except TypeError:
    use_scan = False
  else:
    use_scan = False  # TODO(mattjj): re-enable this

  if use_scan:
    (_, _, result), _ = scan(_fori_scan_body_fun(body_fun),
                             (lower, upper, init_val), None,
                             length=upper_ - lower_)
  else:
    _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
                              (lower, upper, init_val))
  return result


def while_loop(cond_fun: Callable[[T], bool],
               body_fun: Callable[[T], T],
               init_val: T) -> T:
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  When JIT is disabled, ``while_loop`` first tries to run as the plain Python
  loop shown above. It falls back to the primitive only if that raises a
  ``ConcretizationTypeError``.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.

  Raises:
    TypeError: if ``cond_fun`` does not return a single boolean scalar.
  """
  if jax.api._jit_is_disabled():
    try:
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val
    except core.ConcretizationTypeError:
      # Can't run this while_loop in Python (e.g. because there's a vmap
      # transformation on it), so we fall back to the primitive version.
      pass

  def _create_jaxpr(init_val):
    # Traces cond_fun and body_fun on the abstract values of `init_val` and
    # validates that cond_fun returns exactly one boolean scalar. Returns
    # (init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts,
    #  body_consts, body_tree).
    init_vals, in_tree = tree_flatten((init_val,))
    init_avals = tuple(_map(_abstractify, init_vals))
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
    if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
      msg = "cond_fun must return a boolean scalar, but got pytree {}."
      raise TypeError(msg.format(cond_tree))
    if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), np.bool_):
      msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
      raise TypeError(msg.format(cond_jaxpr.out_avals))
    return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree

  # The body input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
  init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
  if changed:
    new_init_val, = tree_unflatten(in_tree, new_init_vals)
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
  cond_jaxpr, cond_consts, body_consts, body_tree = rest

  # in_tree describes the 1-tuple `(init_val,)`; its single child is the
  # structure of init_val itself, which the body output must reproduce.
  in_tree_children = in_tree.children()
  assert len(in_tree_children) == 1
  _check_tree_and_avals("body_fun output and input",
                        body_tree, body_jaxpr.out_avals,
                        in_tree_children[0], init_avals)
  # Operand layout expected by while_p: [cond_consts, body_consts, carry].
  outs = while_p.bind(*itertools.chain(cond_consts, body_consts, init_vals),
                      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
                      body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
  return tree_unflatten(body_tree, outs)

def _while_loop_abstract_eval(*args, **kwargs):
  """Abstract eval for ``while_p``: outputs have the body jaxpr's output avals (raised to shaped)."""
  return _map(raise_to_shaped, kwargs["body_jaxpr"].out_avals)

def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
                                 cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts):
  """Lowers ``while_p`` to a single XLA ``While`` op.

  ``args`` is laid out as ``[cond_consts, body_consts, init_vals]`` (sizes
  ``cond_nconsts``, ``body_nconsts``, and the rest). The whole list is packed
  into one XLA tuple that is threaded through the loop. Only the trailing
  carry slots are updated by the body, and only they are returned.

  If the cond jaxpr's output is not a scalar (``batched``, e.g. after
  batching), the loop keeps running while *any* predicate element is True.
  The body then keeps the new carry only where the predicate is True and the
  old carry elsewhere.
  """
  cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
  batched = bool(cond_jaxpr.out_avals[0].shape)

  # Since jaxprs don't have tuples and have multiple return values, but we need
  # the HLO While loop to take a single tuple input and output a single boolean
  # (for the cond computation) or a single tuple output (for the body
  # computation), we build XLA computations that handle the tuple munging before
  # generating a Call into the computations formed from the jaxprs.

  init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)

  # Cond computation: reads cond consts (x) and carry (z); ignores body consts.
  cond_c = xb.make_computation_builder("cond_computation")
  cond_carry = xb.parameter(cond_c, 0, c.get_shape(init_carry))
  cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
  x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
  pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env,
                            _map(partial(xb.constant, cond_c), cond_jaxpr.consts),
                            extend_name_stack(name_stack, 'cond'), *(x + z))
  if batched:
    # Reduce a non-scalar predicate with logical OR over all of its dimensions.
    scalar = ShapedArray((), np.bool_)
    or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
    pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_,
                       list(range(cond_jaxpr.out_avals[0].ndim)))

  # Body computation: reads body consts (y) and carry (z) and produces a new
  # carry; the consts x and y are passed through unchanged.
  body_c = xb.make_computation_builder("body_computation")
  body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
  body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
  x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
  new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
                            _map(partial(xb.constant, body_c), body_jaxpr.consts),
                            extend_name_stack(name_stack, 'body'), *(y + z))
  if batched:
    # Re-evaluate the (non-scalar) predicate on the old carry, and only
    # advance the carry elements whose predicate is still True.
    body_pred, = xla.jaxpr_subcomp(body_c, cond_jaxpr.jaxpr, backend, axis_env,
                                   _map(partial(xb.constant, body_c), cond_jaxpr.consts),
                                   extend_name_stack(name_stack, 'body_pred'), *(x + z))
    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z, body_jaxpr.out_avals)
    assert _map(body_c.get_shape, new_z) == _map(body_c.get_shape, z)  # no broadcast
  new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z)))

  ans = xops.While(cond_c.build(pred), body_c.build(new_carry), init_carry)
  ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
  # Return only the final carry, dropping the threaded-through consts.
  _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
  return xops.Tuple(c, z)

def _pred_bcast_select(c, pred, x, y, x_y_aval: core.AbstractValue):
  """Selects ``x`` where ``pred`` is True and ``y`` elsewhere, as XLA ops.

  ``pred``'s shape must be a prefix of the (equal) shapes of ``x`` and ``y``.
  It is broadcast along their leading dimensions. Units pass ``x`` through
  unchanged, and tokens are joined with ``AfterAll`` rather than selected.
  """
  pred_shape = c.get_shape(pred).dimensions()
  x_shape = c.get_shape(x).dimensions()
  y_shape = c.get_shape(y).dimensions()
  assert x_shape == y_shape
  if x_y_aval is core.abstract_unit:
    return x
  elif x_y_aval is core.abstract_token:
    return xops.AfterAll(c, [x, y])
  else:
    assert pred_shape == x_shape[:len(pred_shape)] == y_shape[:len(pred_shape)]
    bcast_pred = xops.BroadcastInDim(pred, x_shape, list(range(len(pred_shape))))
    return xops.Select(bcast_pred, x, y)

def _while_loop_batching_rule(args, dims, axis_name,
                              cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  """Batching (vmap) rule for ``while_p``.

  Steps:
  1. Compute, by fixpoint iteration, which loop carries must be batched.
  2. Batch the cond and body jaxprs accordingly.
  3. Move every batched const/carry's batch dimension to axis 0. Carries that
     only become batched inside the loop are broadcast.
  4. Bind a new ``while_p``.

  Returns ``(outs, out_bdims)``, with batch dimension 0 for each batched carry
  and ``batching.not_mapped`` otherwise.
  """
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])

  # Fixpoint computation of which carry are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_bat.
  carry_bat = init_bat
  for _ in range(1 + len(carry_bat)):
    batched = bconst_bat + carry_bat
    body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, size, batched, instantiate=carry_bat, axis_name=axis_name)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat,
        instantiate=bool(cond_jaxpr.out_avals[0].shape),
        axis_name=axis_name)
    # A batched predicate forces every carry to be batched: with a batched
    # predicate the lowering updates the carry per batch element.
    carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
    if carry_bat_out == carry_bat:
      break
    else:
      carry_bat = _map(operator.or_, carry_bat, carry_bat_out)
  else:
    assert False, "Fixpoint not reached"

  consts, init = split_list(args, [cond_nconsts + body_nconsts])
  const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, const_dims)]
  # Carries that were unbatched on entry but must be batched are broadcast to
  # `size` along axis 0. Already-batched ones have their batch axis moved to 0.
  new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
              else batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
              for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)]

  outs = while_p.bind(*(new_consts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims

def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
                    body_jaxpr):
  """JVP rule for ``while_p``.

  Finds, by fixpoint iteration, which carries have nonzero tangents, then
  builds the JVP of the body jaxpr. It runs a single ``while_p`` in which:
  - the carry is ``[primal carry, nonzero carry tangents]``;
  - the body consts are ``[body consts, nonzero body-const tangents]``;
  - the cond jaxpr gets extra input binders for the carry tangents, which it
    does not use.

  Tangents of the cond consts are dropped.

  Returns ``(primal_outs, tangent_outs)``, with ``ad_util.Zero`` for carries
  whose tangent stays zero.
  """
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  cconst_nz, bconst_nz, init_nz = split_list(nonzeros, [cond_nconsts, body_nconsts])

  # Fixpoint: a carry tangent is nonzero if it is nonzero on entry or if the
  # body makes it nonzero.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    body_nonzeros = bconst_nz + carry_nz
    body_jvp, nonzeros_out = ad.jvp_jaxpr(
        body_jaxpr, body_nonzeros, instantiate=carry_nz)
    if nonzeros_out == carry_nz:
      break
    carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
  else:
    assert False, "Fixpoint not reached"

  # After the fixpoint, body_nonzeros == bconst_nz + carry_nz. Instantiate
  # (materialize) every tangent the JVP body expects to be nonzero.
  nonzeros = cconst_nz + body_nonzeros
  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts])
  _, bconst_dot, init_dot = split_list(tangents, [cond_nconsts, body_nconsts])
  # NOTE(review): _prune_zeros is defined elsewhere; presumably it drops the
  # remaining symbolic Zero tangents -- TODO confirm.
  bconst_dot = _prune_zeros(bconst_dot)
  init_dot = _prune_zeros(init_dot)

  num_carry = len(primals) - cond_nconsts - body_nconsts

  # Reorder the JVP body's binders to match the operand layout bound below:
  # [bconst, bconst_dot, carry, carry_dot] in, [carry, carry_dot] out.
  body_jvp_rearranged = ad.rearrange_binders(
      body_jvp,
      [body_nconsts, num_carry], [len(bconst_dot), len(init_dot)],
      [num_carry], [len(init_dot)])

  # The cond only looks at primals, but must accept the tangent carries too.
  newvar = core.gensym([cond_jaxpr.jaxpr])
  invars_aug = (
      cond_jaxpr.jaxpr.invars + [newvar(core.get_aval(x)) for x in init_dot])
  cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
                                    invars_aug,
                                    cond_jaxpr.jaxpr.outvars,
                                    cond_jaxpr.jaxpr.eqns)
  cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts)

  out = while_p.bind(
      *(cconst + bconst + bconst_dot + init + init_dot),
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_augmented,
      body_nconsts=len(bconst) + len(bconst_dot),
      body_jaxpr=body_jvp_rearranged)

  out_carry, out_carry_dot = split_list(out, [num_carry])
  out_tangents_iter = iter(out_carry_dot)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_carry, nonzeros_out)]
  return out_carry, out_tangents

def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
                        cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int,
                        body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]:
  """An implementation of partial evaluation for while.

  As long as some carry (and hence output) are known and the output of
  `cond_jaxpr` is known, we use a portion of the loop body to compute the
  known outputs of the `while_loop`. For the unknown outputs we generate a
  Jaxpr to run the whole while, including recomputing the known parts.

  This means that we don't actually save any computation by partial
  evaluation if there are unknown outputs.

  What this achieves is that we can give a proper error for reverse
  differentiation of `while`, because in that use of partial evaluation the
  primal inputs are considered "known", and only the tangent computation is
  unknown (see issue #2129).

  If the cond output is unknown, or all inputs are known, or all are unknown,
  the primitive is processed with ``trace.default_process_primitive``.
  """
  unknowns = [not t.pval.is_known() for t in tracers]
  params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
                body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)

  if config.omnistaging_enabled:
    partial_eval_jaxpr = pe.partial_eval_jaxpr
  else:
    partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)

  cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
  # Fixpoint computation of unknown carry. Each iteration promotes
  # at least one carry to unknown. We need one last iteration to prepare the jaxpr.
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    body_jaxpr_known, _, carry_out_uk = partial_eval_jaxpr(  # type: ignore
        body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
    if carry_out_uk == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
  else:
    assert False, "Fixpoint not reached"

  cond_jaxpr_known, _, cond_uk = partial_eval_jaxpr(  # type: ignore
      cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)

  if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
    # If conditional is unknown, or all inputs are known, or all are unknown,
    # just do the default processing.
    return trace.default_process_primitive(while_p, tracers, params)

  # Run the known part of the while. Prepare the inputs, as constants (if known), or
  # as core.unit.
  in_consts = [ core.unit if uk else t.pval.get_known()
                for uk, t in zip(cond_consts_uk + body_consts_uk + carry_uk,
                                 tracers)]
  # There should be no residuals for the cond_jaxpr_known
  assert 1 == len(cond_jaxpr_known.out_avals)
  # We ignore the residuals from the body_jaxpr_known, so the type of inputs matches
  # the type of outputs; residuals are at the end
  if len(body_jaxpr_known.out_avals) > len(body_jaxpr.out_avals):
    # NOTE(review): this truncates the outvars of body_jaxpr_known.jaxpr in
    # place (mutation, not a copy).
    # TODO(necula): this is not quite enough; we should drop the residual computations also
    body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:len(body_jaxpr.out_avals)]
  out_known = while_p.bind(
      *in_consts,
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_known,
      body_nconsts=body_nconsts,
      body_jaxpr=body_jaxpr_known)

  # Run the whole while_loop to get all the outputs, then merge with known ones
  out_all: Sequence[pe.Tracer] = trace.default_process_primitive(while_p, tracers, params)
  out_tracers: Sequence[pe.Tracer] = [
      out_unknown if uk
      else pe.JaxprTracer(trace, pe.PartialVal.known(known), out_unknown.recipe)
      for uk, out_unknown, known in zip(carry_uk, out_all, out_known)]

  return out_tracers

def _while_transpose_error(*_, **kwargs):
  """Transpose rule for ``while_p``: always raises, since reverse-mode AD of while loops is unsupported."""
  raise ValueError("Reverse-mode differentiation does not work for "
                   "lax.while_loop or lax.fori_loop. "
                   "Try using lax.scan instead.")

# The `while` primitive and its evaluation / transformation rules.
while_p = lax.Primitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.initial_style_translations[while_p] = _while_loop_translation_rule
ad.primitive_transposes[while_p] = _while_transpose_error
batching.initial_style_batchers[while_p] = _while_loop_batching_rule


### cond and switch

def switch(index, branches: Sequence[Callable], operand):
  """Apply exactly one of ``branches`` given by ``index``.

  If ``index`` is out of bounds, it is clamped to within bounds.

  Has the semantics of the following Python::

    def switch(index, branches, operand):
      index = clamp(0, index, len(branches) - 1)
      return branches[index](operand)

  Args:
    index: Integer scalar type, indicating which branch function to apply.
    branches: Sequence of functions (A -> B) to be applied based on `index`.
    operand: Operand (A) input to whichever branch is applied.

  Returns:
    The output (B) of ``branches[index](operand)`` for the clamped ``index``.

  Raises:
    TypeError: if ``index`` is not a scalar or is not of integer type.
    ValueError: if ``branches`` is empty.
  """
  if len(np.shape(index)) != 0:
    raise TypeError(
        f"Branch index must be scalar, "
        f"got {index} of shape {np.shape(index)}.")

  try:
    index_dtype = dtypes.result_type(index)
  except TypeError as err:
    msg = f"Index type must be an integer, got {index}."
    raise TypeError(msg) from err

  if index_dtype.kind not in 'iu':
    raise TypeError(
        f"Index type must be an integer, got {index} as {index_dtype}")

  branches = tuple(branches)

  if len(branches) == 0:
    raise ValueError("Empty branch sequence")
  elif len(branches) == 1:
    # A single branch needs no conditional at all.
    return branches[0](operand)

  index = lax.convert_element_type(index, np.int32)
  lo = np.array(0, np.int32)
  hi = np.array(len(branches) - 1, np.int32)
  index = lax.clamp(lo, index, hi)

  # With JIT disabled and a concrete index, just call the chosen branch.
  if (jax.api._jit_is_disabled() and
      isinstance(core.get_aval(index), ConcreteArray)):
    return branches[int(index)](operand)

  ops, ops_tree = tree_flatten((operand,))
  ops_avals = tuple(_map(_abstractify, ops))

  # Stage all branches with one shared list of consts so their input
  # signatures agree, as cond_p requires.
  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      branches, ops_tree, ops_avals)

  # Every branch must produce the same output pytree and avals as branch 0.
  for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
    _check_tree_and_avals(f"branch 0 and {i + 1} outputs",
                          out_trees[0], jaxprs[0].out_avals,
                          out_tree, jaxpr.out_avals)

  linear = (False,) * (len(consts) + len(ops))
  out = cond_p.bind(
      index, *consts, *ops, branches=tuple(jaxprs), linear=linear)
  return tree_unflatten(out_trees[0], out)


def cond(*args, **kwargs):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  ``cond()`` has equivalent semantics to this Python implementation::

    def cond(pred, true_fun, false_fun, operand):
      if pred:
        return true_fun(operand)
      else:
        return false_fun(operand)

  ``pred`` must be a scalar type.

  Functions ``true_fun``/``false_fun`` may not need to refer to an ``operand``
  to compute their result, but one must still be provided to the ``cond`` call
  and be accepted by both the branch functions, e.g.::

    jax.lax.cond(
        get_predicate_value(),
        lambda _: 23,
        lambda _: 42,
        operand=None)

  For backward compatibility, the deprecated form
  ``cond(pred, true_operand, true_fun, false_operand, false_fun)`` is also
  accepted.

  Args:
    pred: Boolean scalar type, indicating which branch function to apply.
      Integer and floating-point scalars are also accepted and treated as
      ``pred != 0``.
    true_fun: Function (A -> B), to be applied if ``pred`` is True.
    false_fun: Function (A -> B), to be applied if ``pred`` is False.
    operand: Operand (A) input to either branch depending on ``pred``. The type
      can be a scalar, array, or any pytree (nested Python tuple/list/dict)
      thereof.

  Returns:
    Value (B) of either ``true_fun(operand)`` or ``false_fun(operand)``,
    depending on the value of ``pred``. The type can be a scalar, array, or any
    pytree (nested Python tuple/list/dict) thereof.

  Raises:
    TypeError: if ``pred`` is not a scalar, or is neither boolean nor numeric.
  """

  # detect an attempt to call the former, deprecated cond
  try:
    ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
  except TypeError:
    pass
  else:
    return _cond_with_per_branch_args(*ba.args)

  return _cond(*args, **kwargs)

def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
  """Implements :func:`cond` for the ``(pred, true_fun, false_fun, operand)`` form.

  Validates ``pred``, converts it to an int32 branch index (False -> 0,
  True -> 1) and binds ``cond_p`` with ``branches=(false_jaxpr, true_jaxpr)``.
  """
  if len(np.shape(pred)) != 0:
    raise TypeError(
        f"Pred must be a scalar, got {pred} of shape {np.shape(pred)}.")

  try:
    pred_dtype = dtypes.result_type(pred)
  except TypeError as err:
    msg = ("Pred type must be either boolean or number, got {}.")
    raise TypeError(msg.format(pred)) from err

  # Integer/unsigned/floating preds are converted to booleans via `!= 0`.
  if pred_dtype.kind != 'b':
    if pred_dtype.kind in 'iuf':
      pred = pred != 0
    else:
      msg = ("Pred type must be either boolean or number, got {}.")
      raise TypeError(msg.format(pred_dtype))

  # With JIT disabled and a concrete pred, just call the chosen branch.
  if jax.api._jit_is_disabled() and isinstance(core.get_aval(pred), ConcreteArray):
    if pred:
      return true_fun(operand)
    else:
      return false_fun(operand)

  ops, ops_tree = tree_flatten((operand,))
  ops_avals = tuple(_map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      (true_fun, false_fun), ops_tree, ops_avals)
  true_jaxpr, false_jaxpr = jaxprs
  out_tree, false_out_tree = out_trees

  _check_tree_and_avals("true_fun and false_fun output",
                        out_tree, true_jaxpr.out_avals,
                        false_out_tree, false_jaxpr.out_avals)

  index = lax.convert_element_type(pred, np.int32)

  linear = (False,) * (len(consts) + len(ops))
  # Branch order matches the index: 0 -> false_jaxpr, 1 -> true_jaxpr.
  out = cond_p.bind(
      index, *consts, *ops,
      branches=(false_jaxpr, true_jaxpr), linear=linear)
  return tree_unflatten(out_tree, out)

def _cond_with_per_branch_args(pred,
                               true_operand, true_fun: Callable,
                               false_operand, false_fun: Callable):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Has equivalent semantics to this Python implementation::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
        return false_fun(false_operand)

  Pred has to be a scalar type; collection types (list, tuple) are not
  supported. Implemented by packing both operands into a single tuple operand,
  from which each branch picks its own element.
  """
  return _cond(pred,
               lambda op: true_fun(op[0]),
               lambda op: false_fun(op[1]),
               (true_operand, false_operand))

def _cond_abstract_eval(*args, **kwargs):
  """Abstract eval for ``cond_p``: outputs have branch 0's output avals (raised to shaped)."""
  return _map(raise_to_shaped, kwargs["branches"][0].out_avals)

def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
                           index, *args, branches, linear):
  """Lowers ``cond_p`` to an XLA ``Conditional`` op selected by ``index``.

  All operands are packed into one tuple that is passed to every branch
  computation. Each branch unpacks it and evaluates its own jaxpr.
  """
  del linear  # Unused.

  def make_computation(name, jaxpr, op_shape):
    # `c` shadows the outer builder: each branch gets its own XLA builder.
    c = xb.make_computation_builder(name + '_comp')
    op = xb.parameter(c, 0, op_shape)
    ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
    outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
                             _map(partial(xb.constant, c), jaxpr.consts),
                             extend_name_stack(name_stack, name + '_fun'), *ops)
    return c.build(xops.Tuple(c, outs))

  op = xops.Tuple(c, args)
  op_shape = c.get_shape(op)
  branch_computations = [
      make_computation(f'branch_{i}', jaxpr, op_shape)
      for i, jaxpr in enumerate(branches)]
  return xops.Conditional(index, branch_computations, [op] * len(branches))

def _select_tree(indices, branch_vals):
  """Selects ``branch_vals[indices]`` elementwise via a balanced tree of ``lax.select``.

  ``indices`` is expected to already have the values' shape (see
  ``_cond_index_bcast_and_select_tree``). Indices below 0 resolve to the first
  value and indices at or above ``len(branch_vals)`` resolve to the last.
  """
  assert len(branch_vals) > 0
  if len(branch_vals) == 1:
    return branch_vals[0]
  mid = len(branch_vals) // 2
  # Give `mid` the index dtype so the comparison/subtraction stays in that dtype.
  mid = np.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices)))
  return lax.select(lax.lt(indices, mid),
                    _select_tree(indices, branch_vals[:mid]),
                    _select_tree(indices - mid, branch_vals[mid:]))

def _cond_index_bcast_and_select_tree(indices, branch_vals):
  """Selects among ``branch_vals`` using (possibly batched) ``indices``.

  ``indices`` is broadcast along the leading dimensions of the branch values
  before selecting. If every value is a unit, the first one is returned.
  """
  if all(core.get_aval(x) is core.abstract_unit for x in branch_vals):
    return branch_vals[0]
  else:
    bcast_indices = lax.broadcast_in_dim(
        indices, np.shape(branch_vals[0]), list(range(np.ndim(indices))))
    return _select_tree(bcast_indices, branch_vals)

def _cond_batching_rule(args, dims, axis_name, branches, linear):
  """Batching (vmap) rule for ``cond_p``.

  All batched arguments are first moved to batch axis 0.
  - If the branch index itself is batched, every branch is evaluated and the
    results are combined elementwise with
    ``_cond_index_bcast_and_select_tree``; all outputs are then batched.
  - Otherwise a single ``cond_p`` is bound on the batched branches. An output
    is batched iff it is batched in any branch.
  """
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  orig_bat = [d is not batching.not_mapped for d in dims]
  del dims
  index, *ops = args
  index_bat, *bat = orig_bat

  # An output is batched if any branch produces it batched. Every branch is
  # then re-batched with those outputs instantiated, so that they agree.
  branches_out_bat = [batching.batch_jaxpr(jaxpr, size, bat, False, axis_name)[1]
                      for jaxpr in branches]
  out_bat = [any(bat) for bat in zip(*branches_out_bat)]

  branches_batched = tuple(batching.batch_jaxpr(jaxpr, size, bat, out_bat, axis_name)[0]
                           for jaxpr in branches)

  if index_bat:
    # Batched index: evaluate every branch and select per batch element.
    branch_outs = []
    for jaxpr in branches_batched:
      out = core.jaxpr_as_fun(jaxpr)(*ops)
      out = [batching.broadcast(x, size, 0) if not b else x
             for x, b in zip(out, out_bat)]
      branch_outs.append(out)
    return [_cond_index_bcast_and_select_tree(index, outs)
            for outs in zip(*branch_outs)], [0] * len(branch_outs[0])
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(
        index, *ops, branches=branches_batched, linear=linear)
    return out, out_dims

def _cond_jvp(primals, tangents, branches, linear):
  """JVP rule for ``cond_p``.

  The branch index must have a zero tangent. Each branch is replaced by its
  JVP jaxpr, which takes ``[operands, nonzero operand tangents]``. An output
  tangent is produced if it is nonzero in any branch, and the tangent operands
  are marked linear.

  Returns ``(primal_outs, tangent_outs)``, with ``ad_util.Zero`` for outputs
  whose tangent is zero in every branch.
  """
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]

  index_nz, *ops_nz = nonzeros
  assert index_nz is False

  branches_out_nz = [ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=False)[1]
                     for jaxpr in branches]
  out_nz = [any(nz) for nz in zip(*branches_out_nz)]

  branches_jvp = tuple(ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=out_nz)[0]
                       for jaxpr in branches)

  index, *ops = primals
  _, *ops_dot = tangents
  ops_dot = _prune_zeros(ops_dot)

  ops_lin = tuple(linear)
  linear_jvp = ops_lin + (True,) * len(ops_dot)
  out = cond_p.bind(
      index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp)
  out_primals, out_tangents = split_list(out, [len(out_nz)])
  out_tangents_iter = iter(out_tangents)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, out_nz)]
  return out_primals, out_tangents

def _cond_partial_eval(trace, *tracers, branches, linear):
  """Partial-evaluation rule for ``cond_p``.

  If the branch index is unknown, the whole ``cond_p`` is staged out.
  Otherwise each branch is split into two parts:
  - a known part (``branches_1``), bound immediately on the known inputs,
    which returns the outputs followed by residuals;
  - a staged part (``branches_2``), which takes the residuals first and then
    the operands.

  The per-branch residual lists are then combined by helpers defined
  elsewhere in this file (presumably so all branches share one signature).
  Returns output tracers whose known values come from the immediate bind and
  whose recipe is a new staged ``cond_p`` equation.
  """
  # A tracer is unknown when its partial value carries an abstract value.
  unknowns = [t.pval[0] is not None for t in tracers]
  index_uk, *ops_uk = unknowns

  if config.omnistaging_enabled:
    partial_eval_jaxpr = pe.partial_eval_jaxpr
  else:
    partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)

  if index_uk:
    # When the branch index is unknown, we stage out the whole cond.
    params = dict(branches=branches, linear=linear)
    return trace.default_process_primitive(cond_p, tracers, params)

  # An output is unknown if it is unknown in any branch.
  branches_out_uks = []
  for branch_jaxpr in branches:
    _, _, out_uks = partial_eval_jaxpr(branch_jaxpr, ops_uk, instantiate=False)
    branches_out_uks.append(out_uks)
  out_uks = [any(uks) for uks in zip(*branches_out_uks)]

  branches_1, branches_2, branch_res_avals = [], [], []
  for branch_jaxpr in branches:
    branch_jaxpr_1, branch_jaxpr_2, _ = partial_eval_jaxpr(
        branch_jaxpr, ops_uk, instantiate=out_uks)
    branch_num_res = len(branch_jaxpr_1.out_avals) - len(out_uks)

    # move residuals to the front
    move = [False] * len(ops_uk) + [True] * branch_num_res
    branch_jaxpr_2 = pe.move_binders_to_front(branch_jaxpr_2, move)

    # TODO(frostig,mattjj): pe.partial_eval_jaxpr should raise to shaped avals
    res_avals = _map(
        raise_to_shaped, branch_jaxpr_2.in_avals[:branch_num_res])

    branches_1.append(branch_jaxpr_1)
    branches_2.append(branch_jaxpr_2)
    branch_res_avals.append(res_avals)

  branches_1 = tuple(branches_1)
  branches_2 = tuple(branches_2)

  for jaxpr in branches_2[1:]:
    assert len(jaxpr.out_avals) == len(branches_2[0].out_avals)

  num_outs = len(branches_2[0].out_avals)

  all_res_avals, res_avals_per_branch = _merge_branch_residuals(
      branch_res_avals)

  branches_1 = _join_cond_outputs(
      branches_1, all_res_avals, res_avals_per_branch, num_outs)
  branches_2 = _join_cond_pe_staged_jaxpr_inputs(
      branches_2, all_res_avals, res_avals_per_branch)

  # TODO(frostig,mattjj): reinstate this assertion once pe.partial_eval_jaxpr
  # raises to shaped avals
  # for j in branches_1[1:]:
  #   assert j.out_avals == branches_1[0].out_avals
  num_res = len(all_res_avals)

  # Run the known parts now; their outputs end with the residuals.
  _, in_consts = unzip2([t.pval for t in tracers])
  out_consts_res = cond_p.bind(*in_consts, branches=branches_1, linear=linear)
  out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])

  # TODO(frostig,mattjj): remove raised_to_shaped of avals once
  # pe.partial_eval_jaxpr handles it
  out_avals = _map(raise_to_shaped, branches_2[0].out_avals)
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uks)]

  index_tracer = trace.instantiate_const(tracers[0])

  # Known operands are passed to the staged cond as unit literals.
  # NOTE(review): presumably the staged branches obtain what they need from
  # the residuals instead -- TODO confirm.
  ops_tracers = [trace.instantiate_const(t) if uk
                 else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns[1:], tracers[1:])]

  res_tracers = _map(trace.new_instantiated_const, res)

  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]

  # Staged cond operands: [index, residuals, operands]; residuals are nonlinear.
  linear_2 = (False,) * num_res + linear
  params = dict(branches=branches_2, linear=linear_2)
  eqn = pe.new_eqn_recipe(
      [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
      source_info_util.current())
  for t in out_tracers: t.recipe = eqn
  return out_tracers

# When partially evaluating conditionals, each branch produces residuals
# depending on the computation carried out by the branch, and a corresponding
# staged jaxpr that accepts those residuals as its first few inputs. The
# residual-producing branches are staged as jaxprs and bound right away in a
# conditional.
The residual-consuming jaxprs are assembled together in a jaxpr 921# conditional. The following helper functions ensure that both collections of 922# jaxprs (those evaluated and those staged) are valid for joint use under their 923# respective conditionals. 924# 925# In particular, the residuals derived from each original branch may have 926# distinct types. Because the branches of conditionals must have identical type 927# signatures, we join residuals together across branches into a common format. 928 929# In order to set up a type signature that all branches can conform to, it would 930# suffice to concatenate all branches' residuals. But concatenation can result 931# in redundant inputs and outputs, and might lead to memory allocation that 932# scales unnecessarily with the branch count. This function finds common 933# residual types across branches for reuse, so as to avoid redundant 934# allocation. It returns a list L of types (avals) representing the collection 935# of residuals merged according to type, and, for each branch, a lookup table to 936# match its residuals to their positions/types in L. 
Example input/output: 937# 938# [x], [y], [x, x] -> [x, y, x], [[0], [1], [0, 2]] 939# [x], [x], [x, x] -> [x, x], [[0], [0], [0, 1]] 940# [y, x, x], [x, z, y], [z, x] -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]] 941def _merge_branch_residuals(branch_res_avals): 942 def enumerate_equal(xs): 943 counts = {v: itertools.count() for v in set(xs)} 944 return [(x, next(counts[x])) for x in xs] 945 branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals) 946 all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals)) 947 indices = {v: i for i, v in enumerate(all_tagged_avals)} 948 branch_indices = [ 949 [indices[aval] for aval in avals] for avals in branch_res_tagged_avals] 950 all_avals = [x for x, _ in all_tagged_avals] 951 return all_avals, branch_indices 952 953# This function augments branch outputs to agree with the merged residual 954# format: each branch is made to return zero-filled values in the places of 955# residual outputs that it does not populate. 956def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr, 957 num_non_res_outputs): 958 def augment_jaxpr(jaxpr, res_indices): 959 @lu.wrap_init 960 def f_aug(*args): 961 outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args) 962 outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs]) 963 aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals) 964 aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals)) 965 return outs + list(aug_residuals) 966 967 return _make_closed_jaxpr(f_aug, jaxpr.in_avals) 968 969 return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr)) 970 971# This function augments branch inputs to agree with the merged residual format: 972# each branch is made to accept all residuals, even though it will ignore those 973# that it does not read. 
974def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals, 975 res_aval_indices_per_jaxpr): 976 newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_') 977 all_res_vars = _map(newvar, all_res_avals) 978 979 def augment_jaxpr(jaxpr, res_indices): 980 num_res = len(res_indices) 981 res_vars = jaxpr.jaxpr.invars[:num_res] 982 non_res_vars = jaxpr.jaxpr.invars[num_res:] 983 984 aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars))) 985 aug_invars = aug_res_vars + non_res_vars 986 jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars, 987 jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns) 988 jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts) 989 return jaxpr_aug 990 991 return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr)) 992 993def _ordered_unique(xs): 994 d = collections.OrderedDict((x, None) for x in xs) 995 return list(d.keys()) 996 997def _transpose_cond_jaxpr(jaxpr, num_res): 998 res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) 999 primal_avals = _map(raise_to_shaped, primal_avals) 1000 1001 @lu.wrap_init 1002 def transposed(*args): 1003 res, cts_out = split_list(args, [num_res]) 1004 primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals] 1005 cts_in = ad.backward_pass( 1006 jaxpr.jaxpr, jaxpr.consts, primals, cts_out) 1007 _, cts_in = split_list(cts_in, [num_res]) 1008 return _map(ad.instantiate_zeros_aval, primal_avals, cts_in) 1009 1010 return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals) 1011 1012def _cond_transpose(cts, *args, branches, linear): 1013 index, *ops = args 1014 in_avals = _map(raise_to_shaped, branches[0].in_avals) 1015 num_res = len(ops) - sum(linear) 1016 1017 branches_trans = tuple( 1018 _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches) 1019 lin_in_avals = [raise_to_shaped(a, weak_type=False) 1020 for a, l in zip(in_avals, linear) if l] 1021 assert all(core.typematch(out_aval, lin_in_aval) 1022 for jaxpr in branches_trans 1023 for out_aval, 
lin_in_aval in zip(jaxpr.out_avals, lin_in_avals)) 1024 1025 res = ops[:num_res] 1026 cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts) 1027 linear_trans = (False,) * num_res + (True,) * len(cts) 1028 1029 out = cond_p.bind( 1030 index, *res, *cts, branches=branches_trans, linear=linear_trans) 1031 assert all(_map(core.typecheck, lin_in_avals, out)) 1032 1033 out_iter = iter(out) 1034 out = [next(out_iter) if l else None for l in linear] 1035 assert next(out_iter, None) is None 1036 return [None] + out 1037 1038def _avals_short(avals): 1039 to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))() 1040 return ' '.join(_map(to_str, avals)) 1041 1042def _cond_typecheck(*avals, branches, linear): 1043 tc = partial(_typecheck_param, 'cond') 1044 tc(branches, 'branches', 'tuple of ClosedJaxpr', 1045 type(branches) is tuple and 1046 all(type(x) is core.ClosedJaxpr for x in branches)) 1047 tc(linear, 'linear', 'tuple of bool', 1048 type(linear) is tuple and all(type(x) is bool for x in linear)) 1049 1050 core.typecheck_assert( 1051 len(branches) > 0, 1052 'cond requires at least one branch function') 1053 core.typecheck_assert( 1054 len(linear) + 1 == len(avals), 1055 f'cond given {len(linear)} linear flags for ' 1056 f'{len(avals) - 1} non-predicate operands') 1057 1058 jaxpr0 = branches[0] 1059 jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals) 1060 jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals) 1061 1062 for i, jaxpr in enumerate(branches[1:]): 1063 core.typecheck_assert( 1064 len(jaxpr0.in_avals) == len(jaxpr.in_avals), 1065 f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, ' 1066 f'branch {i+1} takes {len(jaxpr.in_avals)}') 1067 core.typecheck_assert( 1068 len(jaxpr0.out_avals) == len(jaxpr.out_avals), 1069 f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, ' 1070 f'branch {i+1} outputs {len(jaxpr.out_avals)}') 1071 core.typecheck_assert( 1072 all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)), 1073 f'cond branches 
0 and {i+1} have mismatching input types: ' 1074 f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}') 1075 core.typecheck_assert( 1076 all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)), 1077 f'cond branches 0 and {i+1} have mismatching output types: ' 1078 f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}') 1079 1080 core.typecheck_assert( 1081 len(avals) == 1 + len(jaxpr0.in_avals), 1082 f'cond called with {len(avals) - 1} non-predicate operands, ' 1083 f'but branches take {len(jaxpr0.in_avals)} inputs') 1084 1085 index_aval, *op_avals = avals 1086 core.typecheck_assert( 1087 index_aval.dtype == np.int32, 1088 f'cond called with index of type {index_aval.dtype} instead of int32') 1089 core.typecheck_assert( 1090 all(_map(core.typecompat, jaxpr0.in_avals, op_avals)), 1091 f'cond branches take input types {jaxpr0_in_avals_str}, ' 1092 f'called with operands of type {_avals_short(op_avals)}') 1093 1094def cond_bind(*args, branches, linear): 1095 if not core.skip_checks: 1096 avals = _map(core.get_aval, args) 1097 _cond_typecheck(*avals, branches=branches, linear=linear) 1098 for jaxpr in branches: 1099 core.check_jaxpr(jaxpr.jaxpr) 1100 return core.Primitive.bind(cond_p, *args, branches=branches, linear=linear) 1101 1102cond_p = lax.Primitive('cond') 1103cond_p.multiple_results = True 1104cond_p.def_impl(partial(xla.apply_primitive, cond_p)) 1105cond_p.def_abstract_eval(_cond_abstract_eval) 1106cond_p.def_custom_bind(cond_bind) 1107ad.primitive_jvps[cond_p] = _cond_jvp 1108ad.primitive_transposes[cond_p] = _cond_transpose 1109pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval 1110batching.initial_style_batchers[cond_p] = _cond_batching_rule 1111xla.initial_style_translations[cond_p] = _cond_translation_rule 1112core.custom_typechecks[cond_p] = _cond_typecheck 1113 1114 1115### scan 1116 1117Carry = TypeVar('Carry') 1118X = TypeVar('X') 1119Y = TypeVar('Y') 1120 1121def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], 1122 init: Carry, 
1123 xs: X, 1124 length: Optional[int] = None, 1125 reverse: bool = False, 1126 unroll: int = 1) -> Tuple[Carry, Y]: 1127 """Scan a function over leading array axes while carrying along state. 1128 1129 The type signature in brief is 1130 1131 .. code-block:: haskell 1132 1133 scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b]) 1134 1135 where we use [t] here to denote the type t with an additional leading axis. 1136 That is, if t is an array type then [t] represents the type with an additional 1137 leading axis, and if t is a pytree (container) type with array leaves then [t] 1138 represents the type with the same pytree structure and corresponding leaves 1139 each with an additional leading axis. 1140 1141 When ``a`` is an array type or None, and ``b`` is an array type, the semantics 1142 of ``scan`` are given roughly by this Python implementation:: 1143 1144 def scan(f, init, xs, length=None): 1145 if xs is None: 1146 xs = [None] * length 1147 carry = init 1148 ys = [] 1149 for x in xs: 1150 carry, y = f(carry, x) 1151 ys.append(y) 1152 return carry, np.stack(ys) 1153 1154 Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree 1155 types, and so multiple arrays can be scanned over at once and produce multiple 1156 output arrays. (None is actually an empty pytree.) 1157 1158 Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to 1159 a single XLA While HLO. That makes it useful for reducing compilation times 1160 for jit-compiled functions, since native Python loop constructs in an ``@jit`` 1161 function are unrolled, leading to large XLA computations. 1162 1163 Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype 1164 across all iterations (and not just be consistent up to NumPy rank/shape 1165 broadcasting and dtype promotion rules, for example). 
In other words, the type 1166 ``c`` in the type signature above represents an array with a fixed shape and 1167 dtype (or a nested tuple/list/dict container data structure with a fixed 1168 structure and arrays with fixed shape and dtype at the leaves). 1169 1170 Args: 1171 f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning 1172 that ``f`` accepts two arguments where the first is a value of the loop 1173 carry and the second is a slice of ``xs`` along its leading axis, and that 1174 ``f`` returns a pair where the first element represents a new value for 1175 the loop carry and the second represents a slice of the output. 1176 init: an initial loop carry value of type ``c``, which can be a scalar, 1177 array, or any pytree (nested Python tuple/list/dict) thereof, representing 1178 the initial loop carry value. This value must have the same structure as 1179 the first element of the pair returned by ``f``. 1180 xs: the value of type ``[a]`` over which to scan along the leading axis, 1181 where ``[a]`` can be an array or any pytree (nested Python 1182 tuple/list/dict) thereof with consistent leading axis sizes. 1183 length: optional integer specifying the number of loop iterations, which 1184 must agree with the sizes of leading axes of the arrays in ``xs`` (but can 1185 be used to perform scans where no input ``xs`` are needed). 1186 reverse: optional boolean specifying whether to run the scan iteration 1187 forward (the default) or in reverse, equivalent to reversing the leading 1188 axes of the arrays in both ``xs`` and in ``ys``. 1189 unroll: optional positive int specifying, in the underlying operation of the 1190 scan primitive, how many scan iterations to unroll within a single 1191 iteration of a loop. 
1192 1193 Returns: 1194 A pair of type ``(c, [b])`` where the first element represents the final 1195 loop carry value and the second element represents the stacked outputs of 1196 the second output of ``f`` when scanned over the leading axis of the inputs. 1197 """ 1198 xs_flat, xs_tree = tree_flatten(xs) 1199 1200 try: 1201 lengths = [x.shape[0] for x in xs_flat] 1202 except AttributeError as err: 1203 msg = "scan got value with no leading axis to scan over: {}." 1204 raise ValueError( 1205 msg.format(', '.join(str(x) for x in xs_flat 1206 if not hasattr(x, 'shape')))) from err 1207 1208 if length is not None: 1209 length = int(length) 1210 if not all(length == l for l in lengths): 1211 msg = ("scan got `length` argument of {} which disagrees with " 1212 "leading axis sizes {}.") 1213 raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat])) 1214 else: 1215 unique_lengths = set(lengths) 1216 if len(unique_lengths) > 1: 1217 msg = "scan got values with different leading axis sizes: {}." 1218 raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat))) 1219 elif len(unique_lengths) == 0: 1220 msg = "scan got no values to scan over and `length` not provided." 
1221 raise ValueError(msg) 1222 else: 1223 length, = unique_lengths 1224 1225 if jax.api._jit_is_disabled(): 1226 carry = init 1227 ys = [] 1228 maybe_reversed = reversed if reverse else lambda x: x 1229 for i in maybe_reversed(range(length)): 1230 xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat] 1231 carry, y = f(carry, tree_unflatten(xs_tree, xs_slice)) 1232 ys.append(y) 1233 stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit 1234 else jax.numpy.stack((y, *ys))) 1235 stacked_y = tree_multimap(stack, *maybe_reversed(ys)) 1236 return carry, stacked_y 1237 1238 x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat] 1239 x_dtypes = [x.dtype for x in xs_flat] 1240 x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes)) 1241 1242 def _create_jaxpr(init): 1243 init_flat, init_tree = tree_flatten(init) 1244 in_flat, in_tree = tree_flatten((init, xs)) 1245 1246 carry_avals = tuple(_map(_abstractify, init_flat)) 1247 jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals) 1248 out_tree_children = out_tree.children() 1249 if len(out_tree_children) != 2: 1250 msg = "scan body output must be a pair, got {}." 1251 raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) 1252 carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves] 1253 return init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children 1254 1255 # The carry input and output avals must match exactly. However, we want to account for 1256 # the case when init contains weakly-typed values (e.g. Python scalars), with avals that 1257 # may not match the output despite being compatible by virtue of their weak type. 1258 # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if 1259 # necessary, a second time with modified init values. 
1260 init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) 1261 new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out) 1262 if changed: 1263 new_init = tree_unflatten(init_tree, new_init_flat) 1264 init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init) 1265 in_flat, jaxpr, consts, out_tree, out_tree_children = rest 1266 1267 _check_tree_and_avals("scan carry output and input", 1268 # Extract the subtree and avals for the first element of the return tuple 1269 out_tree_children[0], carry_avals_out, 1270 init_tree, carry_avals) 1271 1272 out = scan_p.bind(*itertools.chain(consts, in_flat), 1273 reverse=reverse, length=length, jaxpr=jaxpr, 1274 num_consts=len(consts), num_carry=len(init_flat), 1275 linear=(False,) * (len(consts) + len(in_flat)), 1276 unroll=unroll) 1277 return tree_unflatten(out_tree, out) 1278 1279def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear, 1280 f_impl, x_avals, y_avals): 1281 consts, init, xs = split_list(args, [num_consts, num_carry]) 1282 1283 carry = init 1284 ys = [] 1285 1286 for i in range(length): 1287 i_ = length - i - 1 if reverse else i 1288 x = _map(partial(_index_array, i_), x_avals, xs) 1289 out = f_impl(*consts, *carry, *x) 1290 carry, y = split_list(out, [num_carry]) 1291 ys.append(y) 1292 1293 ys = list(reversed(ys)) if reverse else ys 1294 ys = list(zip(*ys)) 1295 ys = _map(_stack, y_avals, ys) 1296 return (*carry, *ys) 1297 1298def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear, 1299 f_impl, x_avals, y_avals): 1300 consts, init, xs = split_list(args, [num_consts, num_carry]) 1301 1302 def cond_fun(vals): 1303 i, *_ = vals 1304 return i < length 1305 1306 def body_fun(vals): 1307 [i], carry, ys = split_list(vals, [1, num_carry]) 1308 i_ = length - i - 1 if reverse else i 1309 x = _map(partial(_dynamic_index_array, i_), x_avals, xs) 1310 out_flat = f_impl(*consts, *carry, *x) 1311 
carry_out, y_updates = split_list(out_flat, [num_carry]) 1312 ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates) 1313 return [i + 1] + carry_out + ys_out 1314 1315 ys_init = _map(partial(_empty_array, length), y_avals) 1316 if length == 0: 1317 return init + ys_init 1318 else: 1319 init_val = [lax._const(length, 0)] + init + ys_init 1320 _, *outs = while_loop(cond_fun, body_fun, init_val) 1321 return outs 1322 1323def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry, 1324 linear, block_length, f_impl, x_avals, y_avals): 1325 consts, init, xs = split_list(args, [num_consts, num_carry]) 1326 1327 num_blocks, rem = divmod(length, block_length) 1328 assert rem == 0 1329 1330 partition = partial(_partition_leading, num_blocks, block_length) 1331 xs_block = _map(partition, x_avals, xs) 1332 1333 prepend_aval = partial(_prepend_dim_to_aval, block_length) 1334 x_block_avals = _map(prepend_aval, x_avals) 1335 y_block_avals = _map(prepend_aval, y_avals) 1336 1337 f_impl_block = partial( 1338 _scan_impl_unrolled, reverse=reverse, length=block_length, 1339 num_consts=num_consts, num_carry=num_carry, linear=linear, 1340 f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) 1341 1342 outs = _scan_impl_loop( 1343 *consts, *init, *xs_block, reverse=reverse, length=num_blocks, 1344 num_consts=num_consts, num_carry=num_carry, linear=linear, 1345 f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals) 1346 1347 carry, ys_blocks = split_list(outs, [num_carry]) 1348 combine = partial(_combine_leading, num_blocks, block_length) 1349 ys = _map(combine, y_avals, ys_blocks) 1350 return (*carry, *ys) 1351 1352def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, 1353 unroll): 1354 _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) 1355 _, y_avals = split_list(jaxpr.out_avals, [num_carry]) 1356 f_impl = core.jaxpr_as_fun(jaxpr) 1357 1358 if unroll == 1: 1359 return _scan_impl_loop( 1360 *args, 
reverse=reverse, length=length, num_consts=num_consts, 1361 num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals, 1362 y_avals=y_avals) 1363 1364 consts, init, xs = split_list(args, [num_consts, num_carry]) 1365 num_blocks, rem = divmod(length, unroll) 1366 length_div = num_blocks * unroll 1367 1368 if rem > 0: 1369 if reverse: 1370 split = partial(_split_leading_dim, rem) 1371 xs_rem, xs = unzip2(_map(split, x_avals, xs)) 1372 else: 1373 split = partial(_split_leading_dim, length_div) 1374 xs, xs_rem = unzip2(_map(split, x_avals, xs)) 1375 1376 outs = _scan_impl_block_unrolled( 1377 *consts, *init, *xs, reverse=reverse, length=length_div, 1378 num_consts=num_consts, num_carry=num_carry, linear=linear, 1379 block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) 1380 1381 carry, ys = split_list(outs, [num_carry]) 1382 1383 if rem > 0: 1384 outs = _scan_impl_unrolled( 1385 *consts, *carry, *xs_rem, reverse=reverse, length=rem, 1386 num_consts=num_consts, num_carry=num_carry, linear=linear, 1387 f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) 1388 carry, ys_rem = split_list(outs, [num_carry]) 1389 if reverse: 1390 ys = _map(_concatenate, y_avals, ys_rem, ys) 1391 else: 1392 ys = _map(_concatenate, y_avals, ys, ys_rem) 1393 1394 return (*carry, *ys) 1395 1396def _stack(aval, vals): 1397 if aval is core.abstract_unit: 1398 return core.unit 1399 else: 1400 vals = [lax.expand_dims(x, (0,)) for x in vals] 1401 return lax.concatenate(vals, 0) 1402 1403def _concatenate(aval, x1, x2): 1404 if aval is core.abstract_unit: 1405 return core.unit 1406 else: 1407 return lax.concatenate([x1, x2], 0) 1408 1409def _split_leading_dim(i, aval, x): 1410 if aval is core.abstract_unit: 1411 return (core.unit, core.unit) 1412 else: 1413 assert x.ndim >= 1 1414 return (lax.slice_in_dim(x, 0, i), 1415 lax.slice_in_dim(x, i, x.shape[0])) 1416 1417def _dynamic_index_array(i, aval, x): 1418 if aval is core.abstract_unit: 1419 return core.unit 1420 else: 1421 return 
lax.dynamic_index_in_dim(x, i, keepdims=False) 1422 1423def _index_array(i, aval, x): 1424 if aval is core.abstract_unit: 1425 return core.unit 1426 else: 1427 return lax.index_in_dim(x, i, keepdims=False) 1428 1429def _empty_array(sz, aval): 1430 if aval is core.abstract_unit: 1431 return core.unit 1432 else: 1433 return lax.full((sz,) + aval.shape, 0, aval.dtype) 1434 1435def _update_array(i, aval, xs, x): 1436 if aval is core.abstract_unit: 1437 return core.unit 1438 else: 1439 return lax.dynamic_update_index_in_dim(xs, x, i, 0) 1440 1441def _partition_leading(sz0, sz1, aval, x): 1442 if aval is core.abstract_unit: 1443 return core.unit 1444 else: 1445 assert x.ndim >= 1 1446 assert x.shape[0] == sz0 * sz1 1447 return lax.reshape(x, (sz0, sz1, *x.shape[1:])) 1448 1449def _combine_leading(sz0, sz1, aval, x): 1450 if aval is core.abstract_unit: 1451 return core.unit 1452 else: 1453 assert x.ndim >= 2 1454 assert x.shape[0] == sz0 1455 assert x.shape[1] == sz1 1456 return lax.collapse(x, 0, 2) 1457 1458def _prepend_dim_to_aval(sz, aval): 1459 if aval is core.abstract_unit: 1460 return aval 1461 elif isinstance(aval, ShapedArray): 1462 return ShapedArray((sz, *aval.shape), aval.dtype) 1463 else: 1464 raise TypeError(f'Prepending dim {sz} to aval {aval}') 1465 1466def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, 1467 linear, unroll): 1468 carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) 1469 ys_avals = [ShapedArray((length,) + aval.shape, aval.dtype) 1470 if aval is not core.abstract_unit else aval for aval in y_avals] 1471 return carry_avals + ys_avals 1472 1473def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, 1474 linear, unroll): 1475 num_xs = len(jaxpr.in_avals) - num_carry - num_consts 1476 num_ys = len(jaxpr.out_avals) - num_carry 1477 nonzeros = [type(t) is not ad_util.Zero for t in tangents] 1478 const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry]) 1479 1480 # 
Fixpoint computation of which carry are not ad.zero: either 1481 # non-zero from init, or the carry out is non-zero. Each iteration promotes 1482 # at least one carry to non-zero. We need at most len(carry) iterations, 1483 # but we need one last iteration to prepare the jaxpr based on the final 1484 # carry_nz. 1485 carry_nz = init_nz 1486 for _ in range(1 + len(carry_nz)): 1487 nonzeros = const_nz + carry_nz + xs_nz 1488 jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr( 1489 jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys) 1490 carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:] 1491 if carry_nz_out == carry_nz: 1492 break 1493 else: 1494 carry_nz = _map(operator.or_, carry_nz, carry_nz_out) 1495 else: 1496 assert False, "Fixpoint not reached" 1497 1498 tangents = [ad.instantiate_zeros(t) if nz else t 1499 for t, nz in zip(tangents, nonzeros)] 1500 1501 consts, init, xs = split_list(primals, [num_consts, num_carry]) 1502 all_tangents = split_list(tangents, [num_consts, num_carry]) 1503 consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents) 1504 1505 jaxpr_jvp_rearranged = ad.rearrange_binders( 1506 jaxpr_jvp, 1507 [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)], 1508 [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)]) 1509 1510 consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry]) 1511 jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot) 1512 + init_linear + [True] * len(init_dot) 1513 + xs_linear + [True] * len(xs_dot)) 1514 1515 out_flat = scan_p.bind( 1516 *(consts + consts_dot + init + init_dot + xs + xs_dot), 1517 reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged, 1518 num_consts=num_consts + len(consts_dot), 1519 num_carry=num_carry + len(init_dot), 1520 linear=jaxpr_jvp_linear, unroll=unroll) 1521 1522 carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys]) 1523 primals_out = carry + ys 1524 
tangents_out_iter = iter(carry_dot + ys_dot) 1525 tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p) 1526 for p, nz in zip(primals_out, nonzeros_out)] 1527 return primals_out, tangents_out 1528 1529def _prune_zeros(ts): 1530 return [t for t in ts if type(t) is not ad_util.Zero] 1531 1532def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, 1533 jaxpr, linear, unroll): 1534 if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace: # type: ignore 1535 params = dict(reverse=reverse, length=length, num_consts=num_consts, 1536 num_carry=num_carry, jaxpr=jaxpr, linear=linear, 1537 unroll=unroll) 1538 return trace.default_process_primitive(scan_p, tracers, params) 1539 1540 num_ys = len(jaxpr.out_avals) - num_carry 1541 1542 unknowns = [t.pval[0] is not None for t in tracers] 1543 const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry]) 1544 1545 if config.omnistaging_enabled: 1546 partial_eval_jaxpr = pe.partial_eval_jaxpr 1547 else: 1548 partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type) 1549 1550 # Fixpoint computation of which carry are unknown (not a constant): either 1551 # unknown from init, or the carry out is unknown. Each iteration promotes 1552 # at least one carry to unknown. We need at most len(carry) iterations, 1553 # but we need one last iteration to prepare the jaxpr based on the final 1554 # carry_uk. 
1555 carry_uk = init_uk 1556 for _ in range(1 + len(carry_uk)): 1557 unknowns = const_uk + carry_uk + xs_uk 1558 jaxpr_1, jaxpr_2, out_uk = partial_eval_jaxpr( 1559 jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys) 1560 carry_uk_out = out_uk[:num_carry] 1561 if carry_uk_out == carry_uk: 1562 break 1563 else: 1564 carry_uk = _map(operator.or_, carry_uk, carry_uk_out) 1565 else: 1566 assert False, "Fixpoint not reached" 1567 num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals) 1568 1569 # The residuals are treated as extensive outputs of jaxpr_1 (and extensive 1570 # inputs to jaxpr_2), but residuals that are loop-invariant can be hoisted. 1571 # TODO(mattjj): hoist other loop-invariant values here too (instantiate=False) 1572 invariant_pvals = [pe.PartialVal.known(core.unit if uk else t.pval[1]) 1573 for uk, t in zip(unknowns[:num_consts], tracers[:num_consts])] 1574 other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]] 1575 in_pvals_1 = invariant_pvals + other_pvals 1576 jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr( 1577 lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1, 1578 instantiate=[True] * (num_carry + num_ys) + [False] * num_res) 1579 jaxpr_1_opt = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_1_opt), ()) 1580 num_consts_1 = num_consts + len(consts_1) 1581 # any now-known residuals are intensive, so we want to revise jaxpr_2 to take 1582 # those inputs as constants rather than as extensive inputs 1583 _, _, res_pvals = split_list(out_pvals_1, [num_carry, num_ys]) 1584 intensive_residuals = [const for pv, const in res_pvals if pv is None] 1585 move = [False] * len(jaxpr_1.in_avals) + [pv is None for pv, _ in res_pvals] 1586 jaxpr_2_opt = pe.move_binders_to_front(jaxpr_2, move) 1587 num_consts_2 = num_consts + len(intensive_residuals) 1588 1589 # As another optimization, for any extensive inputs that are just forwarded to 1590 # extensive outputs, to avoid a copy (looping over dynamic-update-slice) we'd 
1591 # rather just forward the input tracer. That means pruning some extensive 1592 # outputs from the jaxpr here, and updating out_flat below. 1593 extensive_invars = jaxpr_1_opt.jaxpr.invars[num_consts_1 + num_carry:] 1594 extensive_outvars = jaxpr_1_opt.jaxpr.outvars[num_carry:] 1595 extensive_avals = [core.unmapped_aval(length, 0, core.raise_to_shaped(v.aval)) 1596 for v in extensive_outvars] 1597 fwd_extensive = [num_consts + num_carry + extensive_invars.index(v) 1598 if v in extensive_invars else None for v in extensive_outvars] 1599 jaxpr_1_opt.jaxpr.outvars = ( 1600 jaxpr_1_opt.jaxpr.outvars[:num_carry] + 1601 [v for i, v in zip(fwd_extensive, extensive_outvars) if i is None]) 1602 1603 in_consts = (list(consts_1) + [core.unit] * num_consts + 1604 [core.unit if uk else t.pval[1] 1605 for uk, t in zip(unknowns[num_consts:], tracers[num_consts:])]) 1606 linear_1 = ([False] * len(consts_1) + [True] * num_consts + 1607 [lin or uk for uk, lin 1608 in zip(unknowns[num_consts:], linear[num_consts:])]) 1609 out_flat = scan_p.bind( 1610 *in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt, 1611 num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1), 1612 unroll=unroll) 1613 1614 # Propagate the forwarded extensive outputs using fwd_extensive. Any 1615 # numpy.ndarray inputs should be converted to JAX DeviceArrays. 
1616 out_carry, out_extensive = split_list(out_flat, [num_carry]) 1617 out_extensive_iter = iter(out_extensive) 1618 out_extensive = [next(out_extensive_iter) if i is None 1619 else _maybe_device_put(tracers[i].pval[1]) if tracers[i].is_known() 1620 else tracers[i] for i in fwd_extensive] 1621 assert all(a == core.raise_to_shaped(core.get_aval(out)) 1622 for a, out in zip(extensive_avals, out_extensive)) 1623 out_flat = out_carry + out_extensive 1624 1625 out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys]) 1626 extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None] 1627 1628 new_tracers = [trace.instantiate_const(t) if uk else trace.new_instantiated_literal(core.unit) 1629 for uk, t in zip(unknowns, tracers)] 1630 carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) 1631 ys_avals = _map(partial(_promote_aval_rank, length), y_avals) 1632 out_avals = carry_avals + ys_avals 1633 out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)] 1634 1635 out_consts = out_carry + ys 1636 int_res_tracers = _map(trace.new_instantiated_const, intensive_residuals) 1637 ext_res_tracers = _map(trace.new_instantiated_const, extensive_residuals) 1638 out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None) 1639 for pv, const in zip(out_pvs, out_consts)] 1640 linear_2 = ([False] * len(int_res_tracers) + 1641 [lin or not uk for uk, lin in zip(unknowns, linear)] + 1642 [False] * len(ext_res_tracers)) 1643 eqn = pe.new_eqn_recipe(int_res_tracers + new_tracers + ext_res_tracers, 1644 out_tracers, scan_p, 1645 dict(reverse=reverse, length=length, jaxpr=jaxpr_2_opt, 1646 num_consts=num_consts_2, 1647 num_carry=num_carry, linear=tuple(linear_2), 1648 unroll=unroll), 1649 source_info_util.current()) 1650 for t in out_tracers: t.recipe = eqn 1651 return out_tracers 1652 1653def _maybe_device_put(x): 1654 if isinstance(x, np.ndarray): 1655 return lax._device_put_raw(x) 1656 else: 1657 return x 1658 
1659def _promote_aval_rank(sz, aval): 1660 if aval is core.abstract_unit: 1661 return core.abstract_unit 1662 else: 1663 return ShapedArray((sz,) + aval.shape, aval.dtype) 1664 1665def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, 1666 linear, unroll): 1667 # we've only implemented transposing scans with specific lin/nonlin patterns 1668 consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry]) 1669 num_ires = len(consts_lin) - sum(consts_lin) 1670 num_eres = len(xs_lin) - sum(xs_lin) 1671 if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires): 1672 raise NotImplementedError 1673 if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: 1674 raise NotImplementedError 1675 if not all(init_lin): 1676 pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963 1677 1678 consts, _, xs = split_list(args, [num_consts, num_carry]) 1679 ires, _ = split_list(consts, [num_ires]) 1680 _, eres = split_list(xs, [sum(xs_lin)]) 1681 assert not any(ad.is_undefined_primal(r) for r in ires) 1682 assert not any(ad.is_undefined_primal(r) for r in eres) 1683 1684 carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) 1685 ys_avals = _map(partial(_promote_aval_rank, length), y_avals) 1686 ct_carry, ct_ys = split_list(cts, [num_carry]) 1687 ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry) 1688 ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys) 1689 ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts]) 1690 1691 # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b]) 1692 # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a]) 1693 jaxpr_trans = _transpose_scan_jaxpr( 1694 num_ires, num_consts - num_ires, num_eres, jaxpr) 1695 linear_trans = ([False] * num_ires + 1696 [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + 1697 [False] * num_eres) 1698 1699 outs = scan_p.bind( 1700 *(ires + ct_consts + 
ct_carry + ct_ys + eres), reverse=not reverse, 1701 length=length, jaxpr=jaxpr_trans, num_consts=num_ires, 1702 num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans), 1703 unroll=unroll) 1704 ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry]) 1705 return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres 1706 1707# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b) 1708# -> ([res1, CT c, CT b, res2] -> [CT c, CT a]) 1709def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr): 1710 num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2 1711 res1_avals, c_avals, a_avals, res2_avals = split_list( 1712 jaxpr.in_avals, [num_res1, num_c, num_a]) 1713 num_b = len(jaxpr.out_avals) 1714 b_avals = list(jaxpr.out_avals) 1715 1716 @lu.wrap_init 1717 def transposed(*res1_cbar_bbar_res2): 1718 res1, c_bar, b_bar, res2 = split_list( 1719 res1_cbar_bbar_res2, [num_res1, num_c, num_b]) 1720 primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] + 1721 [ad.UndefinedPrimal(aval) for aval in a_avals] + res2) 1722 cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar) 1723 _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a]) 1724 a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar) 1725 c_bar = _map(ad.instantiate_zeros_aval, c_avals, 1726 _map(ad.add_tangents, c_bar, new_c_bar)) 1727 return c_bar + a_bar 1728 return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals) 1729 1730def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]): 1731 if config.omnistaging_enabled: 1732 jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals) 1733 else: 1734 pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] 1735 jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True) 1736 out_avals, _ = unzip2(pvals_out) 1737 return core.ClosedJaxpr(jaxpr, consts) 1738 1739 1740def 
_scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_consts, 1741 num_carry, linear, unroll): 1742 num_ys = len(jaxpr.out_avals) - num_carry 1743 size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped} 1744 orig_batched = [d is not batching.not_mapped for d in dims] 1745 const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry]) 1746 1747 # Fixpoint computation of which carry are batched: either 1748 # batched from init, or the carry out is batched. Each iteration promotes 1749 # at least one carry to batched. We need at most len(carry) iterations, 1750 # but we need one last iteration to prepare the jaxpr based on the final 1751 # carry_batched. 1752 carry_batched = init_batched 1753 for _ in range(1 + len(carry_batched)): 1754 batched = const_batched + carry_batched + xs_batched 1755 jaxpr_batched, batched_out = batching.batch_jaxpr( 1756 jaxpr, size, batched, 1757 instantiate=carry_batched + [False] * num_ys, 1758 axis_name=axis_name) 1759 carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:] 1760 if carry_batched_out == carry_batched: 1761 break 1762 else: 1763 carry_batched = _map(operator.or_, carry_batched, carry_batched_out) 1764 else: 1765 assert False, "Fixpoint not reached" 1766 1767 consts, init, xs = split_list(args, [num_consts, num_carry]) 1768 consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry]) 1769 new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 1770 else x for x, d in zip(consts, consts_bdims)] 1771 new_init = [batching.broadcast(x, size, 0) if now_batched and not was_batched 1772 else batching.moveaxis(x, d, 0) if now_batched else x 1773 for x, d, was_batched, now_batched in 1774 zip(init, init_bdims, init_batched, carry_batched)] 1775 new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1 1776 else x for x, d in zip(xs, xs_bdims)] 1777 new_args = new_consts 
+ new_init + new_xs 1778 1779 outs = scan_p.bind( 1780 *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched, 1781 num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll) 1782 carry_bdims = [0 if b else batching.not_mapped for b in carry_batched] 1783 ys_bdims = [1 if b else batching.not_mapped for b in ys_batched] 1784 return outs, carry_bdims + ys_bdims 1785 1786def _scan_masking_rule(padded_vals, logical_shapes, reverse, length, 1787 jaxpr, num_consts, num_carry, linear, unroll): 1788 dynamic_length, = masking.shape_as_value((length,)) 1789 masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry) 1790 consts, init, xs = split_list(padded_vals, [num_consts, num_carry]) 1791 max_length, = {x.shape[0] for x in xs} 1792 const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry]) 1793 out_vals = scan_p.bind( 1794 *itertools.chain([dynamic_length] + consts, [0], init, xs), 1795 reverse=reverse, length=max_length, jaxpr=masked_jaxpr, 1796 num_consts=1 + num_consts, num_carry=1 + num_carry, 1797 linear=tuple([False] + const_linear + [False] + init_linear + xs_linear), 1798 unroll=unroll) 1799 return out_vals[1:] 1800 1801def _masked_scan_jaxpr(jaxpr, num_consts, num_carry): 1802 fun = core.jaxpr_as_fun(jaxpr) 1803 1804 @lu.wrap_init 1805 def masked(*args): 1806 [dynamic_length], consts, [i], carry, xs = split_list( 1807 args, [1, num_consts, 1, num_carry]) 1808 out = fun(*(consts + carry + xs)) 1809 new_carry, ys = split_list(out, [num_carry]) 1810 new_carry = [lax.select(i < dynamic_length, new_c, c) 1811 for new_c, c in zip(new_carry, carry)] 1812 return [i + 1] + new_carry + ys 1813 1814 aval = ShapedArray((), dtypes.int_) 1815 const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) 1816 return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals) 1817 1818def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry, 1819 jaxpr, 
linear, unroll): 1820 tc = partial(_typecheck_param, 'scan') 1821 tc(reverse, 'reverse', 'bool', type(reverse) is bool) 1822 tc(num_consts, 'num_consts', 'non-negative int', 1823 type(num_consts) is int and num_consts >= 0) 1824 tc(num_carry, 'num_carry', 'non-negative int', 1825 type(num_carry) is int and num_carry >= 0) 1826 tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr) 1827 tc(linear, 'linear', 'tuple of bool', 1828 type(linear) is tuple and all(type(x) is bool for x in linear)) 1829 tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0) 1830 1831 length_types = (int, masking.Poly) if bind_time else (int,) 1832 tc(length, 'length', 'non-negative int', 1833 type(length) in length_types and length >= 0) 1834 1835 core.typecheck_assert( 1836 len(linear) == len(avals), 1837 f'scan param linear has length {len(linear)} for {len(avals)} operands') 1838 1839 const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry]) 1840 const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list( 1841 jaxpr.in_avals, [num_consts, num_carry]) 1842 carry_avals_jaxpr, _ = split_list(jaxpr.out_avals, [num_carry]) 1843 x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals) 1844 1845 core.typecheck_assert( 1846 all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)), 1847 f'scan input carry input and output types mismatch: ' 1848 f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}') 1849 core.typecheck_assert( 1850 all(_map(core.typecompat, const_avals_jaxpr, const_avals)), 1851 f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n' 1852 f'called with consts of type\n{_avals_short(const_avals)}') 1853 core.typecheck_assert( 1854 all(_map(core.typecompat, init_avals_jaxpr, init_avals)), 1855 f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n' 1856 f'called with initial carry of type\n{_avals_short(init_avals)}') 1857 core.typecheck_assert( 
1858 all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)), 1859 f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n' 1860 f'called with sequence of type\n{_avals_short(x_avals)}') 1861 1862def scan_bind(*args, **params): 1863 if not core.skip_checks: 1864 avals = _map(core.get_aval, args) 1865 _scan_typecheck(True, *avals, **params) 1866 core.check_jaxpr(params['jaxpr'].jaxpr) 1867 return core.Primitive.bind(scan_p, *args, **params) 1868 1869scan_p = core.Primitive("scan") 1870scan_p.multiple_results = True 1871scan_p.def_custom_bind(scan_bind) 1872scan_p.def_impl(_scan_impl) 1873# scan_p.def_impl(partial(xla.apply_primitive, scan_p)) # TODO(mattjj): re-enable 1874scan_p.def_abstract_eval(_scan_abstract_eval) 1875ad.primitive_jvps[scan_p] = _scan_jvp 1876ad.primitive_transposes[scan_p] = _scan_transpose 1877pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval 1878xla.initial_style_translations[scan_p] = xla.lower_fun_initial_style(_scan_impl) 1879batching.initial_style_batchers[scan_p] = _scan_batching_rule 1880masking.masking_rules[scan_p] = _scan_masking_rule 1881core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) 1882 1883 1884def map(f, xs): 1885 """Map a function over leading array axes. 1886 1887 Like Python's builtin map, except inputs and outputs are in the form of 1888 stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you 1889 need to apply a function element by element for reduced memory usage or 1890 heterogeneous computation with other control flow primitives. 1891 1892 When ``xs`` is an array type, the semantics of ``map`` are given by this 1893 Python implementation:: 1894 1895 def map(f, xs): 1896 return np.stack([f(x) for x in xs]) 1897 1898 Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of 1899 the same advantages over a Python loop apply: ``xs`` may be an arbitrary 1900 nested pytree type, and the mapped computation is compiled only once. 
1901 1902 Args: 1903 f: a Python function to apply element-wise over the first axis or axes of 1904 ``xs``. 1905 xs: values over which to map along the leading axis. 1906 1907 Returns: 1908 Mapped values. 1909 """ 1910 g = lambda _, x: ((), f(x)) 1911 _, ys = scan(g, (), xs) 1912 return ys 1913 1914 1915def _concat_masking_rule(padded_vals, logical_shapes, dimension): 1916 result = lax.concatenate(padded_vals, dimension) # fragmented 1917 offset = 0 1918 for padded_val, logical_shape in zip(padded_vals, logical_shapes): 1919 result = _memcpy(dimension, logical_shape[dimension], padded_val, 1920 result, offset) 1921 offset = offset + logical_shape[dimension] 1922 return result 1923 1924def _memcpy(axis, num, src, dst, offset): 1925 def body(i, dst): 1926 update = lax.dynamic_index_in_dim(src, i, axis) 1927 return lax.dynamic_update_index_in_dim(dst, update, i + offset, axis) 1928 return fori_loop(0, num, body, dst) 1929 1930masking.masking_rules[lax.concatenate_p] = _concat_masking_rule # type: ignore 1931 1932 1933def _check_tree_and_avals(what, tree1, avals1, tree2, avals2): 1934 """Raises TypeError if (tree1, avals1) does not match (tree2, avals2). 1935 1936 Corresponding `tree` and `avals` must match in the sense that the number of 1937 leaves in `tree` must be equal to the length of `avals`. `what` will be 1938 prepended to details of the mismatch in TypeError. 
1939 """ 1940 if tree1 != tree2: 1941 raise TypeError( 1942 f"{what} must have same type structure, got {tree1} and {tree2}.") 1943 if not all(_map(core.typematch, avals1, avals2)): 1944 raise TypeError( 1945 f"{what} must have identical types, got\n" 1946 f"{tree_unflatten(tree1, avals1)}\nand\n" 1947 f"{tree_unflatten(tree2, avals2)}.") 1948 1949 1950def _check_tree(func_name, expected_name, actual_tree, expected_tree): 1951 if actual_tree != expected_tree: 1952 raise TypeError( 1953 f"{func_name}() output pytree structure must match {expected_name}, " 1954 f"got {actual_tree} and {expected_tree}.") 1955 1956 1957def _promote_weak_typed_inputs(in_vals, in_avals, out_avals): 1958 """Promote weakly-typed in_vals to be compatible with out_avals. 1959 1960 Args: 1961 in_vals : flattened list of input values. 1962 in_avals : corresponding list of avals. 1963 out_avals : list of target output avals. 1964 Returns: 1965 in_vals_new : flattened list of modified in_vals with no weak types. 1966 changed : bool; true if in_vals required modification. 1967 """ 1968 if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals): 1969 # Calling function is responsible for catching this. 
1970 return in_vals, False 1971 weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals)) 1972 if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)] 1973 if not weak_mismatches: 1974 return in_vals, False 1975 for i in weak_mismatches: 1976 new_dtype = dtypes.result_type(in_vals[i], out_avals[i]) 1977 in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype) 1978 return in_vals, True 1979 1980def _stop_gradient_fun(f): 1981 """Create a version of f() that stops all gradients.""" 1982 def wrapper(*args, **kwargs): 1983 args_flat, in_args_tree = tree_flatten((args, kwargs)) 1984 args_avals = tuple(_map(_abstractify, args_flat)) 1985 g = lambda a, b: f(*a, **b) 1986 jaxpr, consts, out_tree = _initial_style_jaxpr(g, in_args_tree, args_avals) 1987 all_args = _map(lax.stop_gradient, (*consts, *args_flat)) 1988 out = core.jaxpr_as_fun(jaxpr)(*all_args) 1989 return tree_unflatten(out_tree, out) 1990 return wrapper 1991 1992 1993_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s') 1994 1995 1996def _split_root_args(args, const_lengths): 1997 params_list = split_list(args, list(const_lengths)) 1998 return _RootTuple(*params_list[:-1]), params_list[-1] 1999 2000 2001def custom_root(f, initial_guess, solve, tangent_solve): 2002 """Differentiably solve for a roots of a function. 2003 2004 This is a low-level routine, mostly intended for internal use in JAX. 2005 Gradients of custom_root() are defined with respect to closed-over variables 2006 from the provided function ``f`` via the implicit function theorem: 2007 https://en.wikipedia.org/wiki/Implicit_function_theorem 2008 2009 Args: 2010 f: function for which to find a root. Should accept a single argument, 2011 return a tree of arrays with the same structure as its input. 2012 initial_guess: initial guess for a zero of f. 2013 solve: function to solve for the roots of f. 
Should take two positional 2014 arguments, f and initial_guess, and return a solution with the same 2015 structure as initial_guess such that func(solution) = 0. In other words, 2016 the following is assumed to be true (but not checked):: 2017 2018 solution = solve(f, initial_guess) 2019 error = f(solution) 2020 assert all(error == 0) 2021 2022 tangent_solve: function to solve the tangent system. Should take two 2023 positional arguments, a linear function ``g`` (the function ``f`` 2024 linearized at its root) and a tree of array(s) ``y`` with the same 2025 structure as initial_guess, and return a solution ``x`` such that 2026 ``g(x)=y``: 2027 2028 - For scalar ``y``, use ``lambda g, y: y / g(1.0)``. 2029 - For vector ``y``, you could use a linear solve with the Jacobian, if 2030 dimensionality of ``y`` is not too large: 2031 ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``. 2032 2033 Returns: 2034 The result of calling solve(f, initial_guess) with gradients defined via 2035 implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``. 
2036 """ 2037 guess_flat, in_args_tree = tree_flatten((initial_guess,)) 2038 guess_avals = tuple(_map(_abstractify, guess_flat)) 2039 f_jaxpr, f_consts, out_tree = _initial_style_jaxpr( 2040 f, in_args_tree, guess_avals) 2041 2042 in_tree, = treedef_children(in_args_tree) 2043 _check_tree("f", "initial_guess", out_tree, in_tree) 2044 2045 solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr( 2046 partial(solve, _stop_gradient_fun(f)), in_args_tree, guess_avals) 2047 _check_tree("solve", "initial_guess", solution_tree, in_tree) 2048 2049 def linearize_and_solve(x, b): 2050 unchecked_zeros, f_jvp = jax.linearize(f, x) 2051 return tangent_solve(f_jvp, b) 2052 2053 l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( 2054 linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2) 2055 _check_tree("tangent_solve", "x", out_tree, in_tree) 2056 2057 all_consts = [f_consts, solve_consts, l_and_s_consts] 2058 const_lengths = _RootTuple(*_map(len, all_consts)) 2059 jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr) 2060 2061 out_flat = _custom_root( 2062 const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat)) 2063 return tree_unflatten(out_tree, out_flat) 2064 2065 2066@partial(jax.custom_jvp, nondiff_argnums=(0, 1)) 2067def _custom_root(const_lengths, jaxprs, *args): 2068 params, initial_guess = _split_root_args(args, const_lengths) 2069 solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess)) 2070 return solution 2071 2072 2073@_custom_root.defjvp 2074def _root_jvp(const_lengths, jaxprs, primals, tangents): 2075 params, _ = _split_root_args(primals, const_lengths) 2076 solution = _custom_root(const_lengths, jaxprs, *primals) 2077 2078 params_dot, _ = _split_root_args(tangents, const_lengths) 2079 2080 # F(m, u) = 0 # system of equations in u, parameterized by m 2081 # # solution is u*(m) defined in a neighborhood 2082 # F(m, u*(m)) = 0 # satisfied in a neighborhood 2083 # 2084 # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) 
∂ u*(m) = 0 # implied by line above 2085 # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m)) # rearrange 2086 # 2087 # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]] # jvp 2088 2089 f = core.jaxpr_as_fun(jaxprs.f) 2090 linearize_and_solve = partial( 2091 core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s) 2092 f_at_solution = lambda *params: f(*itertools.chain(params, solution)) 2093 _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped( 2094 params.f, params_dot.f) 2095 solution_dot = _map( 2096 operator.neg, linearize_and_solve(*itertools.chain(solution, rhs))) 2097 2098 return solution, solution_dot 2099 2100 2101class _LinearSolveTuple(collections.namedtuple( 2102 '_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')): 2103 2104 def transpose(self): 2105 return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve) 2106 2107 2108def _split_linear_solve_args(args, const_lengths): 2109 params_list = split_list(args, list(const_lengths)) 2110 return _LinearSolveTuple(*params_list[:-1]), params_list[-1] 2111 2112 2113def _transpose_one_output(linear_fun, primals): 2114 transpose_fun = jax.linear_transpose(linear_fun, primals) 2115 def transposed_fun(x): 2116 (y,) = transpose_fun(x) 2117 return y 2118 return transposed_fun 2119 2120 2121def _flatten(args): 2122 return [x for arg in args for x in arg] 2123 2124 2125def _check_shapes(func_name, expected_name, actual, expected): 2126 actual_shapes = _map(np.shape, tree_leaves(actual)) 2127 expected_shapes = _map(np.shape, tree_leaves(expected)) 2128 if actual_shapes != expected_shapes: 2129 raise ValueError( 2130 f"{func_name}() output shapes must match {expected_name}, " 2131 f"got {actual_shapes} and {expected_shapes}") 2132 2133 2134def custom_linear_solve( 2135 matvec, b, solve, transpose_solve=None, symmetric=False): 2136 """Perform a matrix-free linear solve with implicitly defined gradients. 
2137 2138 This function allows for overriding or defining gradients for a linear 2139 solve directly via implicit differentiation at the solution, rather than by 2140 differentiating *through* the solve operation. This can sometimes be much faster 2141 or more numerically stable, or differentiating through the solve operation 2142 may not even be implemented (e.g., if ``solve`` uses ``lax.while_loop``). 2143 2144 Required invariant:: 2145 2146 x = solve(matvec, b) # solve the linear equation 2147 assert matvec(x) == b # not checked 2148 2149 Args: 2150 matvec: linear function to invert. Must be differentiable. 2151 b: constant right handle side of the equation. May be any nested structure 2152 of arrays. 2153 solve: higher level function that solves for solution to the linear 2154 equation, i.e., ``solve(matvec, x)) == x`` for all ``x`` of the same form 2155 as ``b``. This function need not be differentiable. 2156 transpose_solve: higher level function for solving the transpose linear 2157 equation, i.e., ``transpose_solve(vecmat, x) == x``, where ``vecmat`` is 2158 the transpose of the linear map ``matvec`` (computed automatically with 2159 autodiff). Required for backwards mode automatic differentiation, unless 2160 ``symmetric=True``, in which case ``solve`` provides the default value. 2161 symmetric: bool indicating if it is safe to assume the linear map 2162 corresponds to a symmetric matrix, i.e., ``matvec == vecmat``. 2163 2164 Returns: 2165 Result of ``solve(matvec, b)``, with gradients defined assuming that the 2166 solution ``x`` satisfies the linear equation ``matvec(x) == b``. 
2167 """ 2168 if transpose_solve is None and symmetric: 2169 transpose_solve = solve 2170 2171 b_flat, in_args_tree = tree_flatten((b,)) 2172 b_avals = tuple(_map(_abstractify, b_flat)) 2173 2174 tree, = treedef_children(in_args_tree) 2175 2176 def _shape_checked(fun, name): 2177 def f(x): 2178 y = fun(x) 2179 _check_shapes(name, "b", y, b_flat) 2180 return y 2181 return f 2182 2183 matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr( 2184 _shape_checked(matvec, "matvec"), in_args_tree, b_avals) 2185 _check_tree("matvec", "b", out_tree, tree) 2186 2187 solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr( 2188 _shape_checked(partial(solve, matvec), "solve"), in_args_tree, b_avals) 2189 _check_tree("solve", "b", out_tree, tree) 2190 2191 if transpose_solve is None: 2192 vecmat_jaxpr = tr_solve_jaxpr = None 2193 vecmat_consts = tr_solve_consts = [] 2194 else: 2195 if symmetric: 2196 vecmat = matvec 2197 vecmat_jaxpr = matvec_jaxpr 2198 vecmat_consts = matvec_consts 2199 else: 2200 vecmat = _transpose_one_output(matvec, b) 2201 vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr( 2202 vecmat, in_args_tree, b_avals) 2203 assert out_tree == tree 2204 2205 tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr( 2206 _shape_checked(partial(transpose_solve, vecmat), "transpose_solve"), 2207 in_args_tree, b_avals) 2208 _check_tree("transpose_solve", "b", out_tree, tree) 2209 2210 all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts] 2211 const_lengths = _LinearSolveTuple(*_map(len, all_consts)) 2212 jaxprs = _LinearSolveTuple( 2213 matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) 2214 2215 out_flat = linear_solve_p.bind( 2216 *(_flatten(all_consts) + b_flat), 2217 const_lengths=const_lengths, jaxprs=jaxprs) 2218 return tree_unflatten(tree, out_flat) 2219 2220 2221def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): 2222 return _map(raise_to_shaped, args[sum(const_lengths):]) 2223 2224 2225def 
_custom_linear_solve_impl(*args, const_lengths, jaxprs): 2226 params, b = _split_linear_solve_args(args, const_lengths) 2227 x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b)) 2228 return x 2229 2230 2231def _tangent_linear_map(func, params, params_dot, *x): 2232 """Compute the tangent of a linear map. 2233 2234 Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``, 2235 this function computes ``∂A @ x``. 2236 """ 2237 assert any(type(p) is not ad_util.Zero for p in params_dot) 2238 zeros = _map(ad_util.Zero.from_value, x) 2239 _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped( 2240 params + list(x), params_dot + zeros) 2241 return out_tangent 2242 2243 2244def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs): 2245 # A x - b = 0 2246 # ∂A x + A ∂x - ∂b = 0 2247 # ∂x = A^{-1} (∂b - ∂A x) 2248 2249 kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs) 2250 x = linear_solve_p.bind(*primals, **kwargs) 2251 2252 params, _ = _split_linear_solve_args(primals, const_lengths) 2253 params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths) 2254 2255 if all(type(p) is ad_util.Zero for p in params_dot.matvec): 2256 # no need to evaluate matvec_tangents 2257 rhs = b_dot 2258 else: 2259 matvec_tangents = _tangent_linear_map( 2260 core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x) 2261 rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents)) 2262 2263 x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs) 2264 2265 return x, x_dot 2266 2267 2268def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs): 2269 if jaxprs.transpose_solve is None: 2270 raise TypeError('transpose_solve required for backwards mode automatic ' 2271 'differentiation of custom_linear_solve') 2272 2273 params, b = _split_linear_solve_args(primals, const_lengths) 2274 assert all(ad.is_undefined_primal(x) for x in b) 2275 cotangent_b = linear_solve_p.bind( 2276 
*(_flatten(params.transpose()) + cotangent), 2277 const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose()) 2278 return [None] * sum(const_lengths) + cotangent_b 2279 2280 2281def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs): 2282 orig_bat = [d is not batching.not_mapped for d in dims] 2283 size, = { 2284 a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped 2285 } 2286 2287 params, b = _split_linear_solve_args(args, const_lengths) 2288 params_dims, b_dims = _split_linear_solve_args(dims, const_lengths) 2289 params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths) 2290 2291 (matvec, vecmat, solve, solve_t) = jaxprs 2292 (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat 2293 2294 # Fixpoint computation of which parts of x and b are batched; we need to 2295 # ensure this is consistent between all four jaxprs 2296 b_bat = orig_b_bat 2297 x_bat = [False] * len(solve.out_avals) 2298 for i in range(1 + len(orig_b_bat) + len(solve.out_avals)): 2299 # Apply vecmat and solve -> new batched parts of x 2300 solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr( 2301 solve, size, solve_bat + b_bat, instantiate=x_bat, axis_name=axis_name) 2302 if vecmat is None: 2303 vecmat_jaxpr_batched = None 2304 x_bat_out = solve_x_bat 2305 else: 2306 vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr( 2307 vecmat, size, vecmat_bat + b_bat, instantiate=x_bat, axis_name=axis_name) 2308 x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat) 2309 # Apply matvec and solve_t -> new batched parts of b 2310 matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr( 2311 matvec, size, matvec_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name) 2312 if solve_t is None: 2313 solve_t_jaxpr_batched = None 2314 b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat) 2315 else: 2316 solve_t_jaxpr_batched, solve_t_b_bat = batching.batch_jaxpr( 2317 solve_t, size, solve_t_bat + x_bat_out, 
instantiate=b_bat, axis_name=axis_name) 2318 b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat, 2319 orig_b_bat) 2320 if x_bat_out == x_bat and b_bat_out == b_bat: 2321 break 2322 else: 2323 x_bat = x_bat_out 2324 b_bat = b_bat_out 2325 else: 2326 assert False, "Fixedpoint not reached" 2327 2328 batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched, 2329 solve_jaxpr_batched, solve_t_jaxpr_batched) 2330 2331 # Move batched axes to the front 2332 new_params = [ 2333 batching.moveaxis(x, d, 0) 2334 if d is not batching.not_mapped and d != 0 else x 2335 for x, d in zip(_flatten(params), _flatten(params_dims)) 2336 ] 2337 # Broadcast out b if necessary 2338 new_b = [ 2339 batching.broadcast(x, size, 0) if now_bat and not was_bat else 2340 batching.moveaxis(x, d, 0) if now_bat and d != 0 else x 2341 for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat) 2342 ] 2343 2344 outs = linear_solve_p.bind( 2345 *(new_params + new_b), 2346 const_lengths=const_lengths, 2347 jaxprs=batched_jaxprs) 2348 out_dims = [0 if batched else batching.not_mapped for batched in b_bat] 2349 return outs, out_dims 2350 2351 2352linear_solve_p = core.Primitive('custom_linear_solve') 2353linear_solve_p.multiple_results = True 2354linear_solve_p.def_impl(_custom_linear_solve_impl) 2355linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval) 2356ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp 2357xla.initial_style_translations[linear_solve_p] = \ 2358 xla.lower_fun_initial_style(_custom_linear_solve_impl) 2359ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule 2360batching.initial_style_batchers[linear_solve_p] = _linear_solve_batching_rule 2361 2362 2363def _interleave(a, b, axis): 2364 """Given two Tensors of static shape, interleave them along the first axis.""" 2365 assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1 2366 a_pad = [(0, 0, 0)] * a.ndim 2367 b_pad = [(0, 0, 0)] * 
b.ndim 2368 a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1) 2369 b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1) 2370 return lax.add(lax.pad(a, lax._const(a, 0), a_pad), 2371 lax.pad(b, lax._const(b, 0), b_pad)) 2372 2373def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): 2374 """Performs a scan with an associative binary operation, in parallel. 2375 2376 For an introduction to associative scans, see [BLE1990]_. 2377 2378 Args: 2379 fn: A Python callable implementing an associative binary operation with 2380 signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it 2381 must satisfy the equation 2382 ``fn(a, fn(b, c)) == fn(fn(a, b), c)``. 2383 2384 The inputs and result are (possibly nested Python tree structures of) 2385 array(s) matching ``elems``. Each array has a dimension in place 2386 of the ``axis`` dimension. `fn` should be applied elementwise over 2387 the ``axis`` dimension (for example, by using :func:`jax.vmap` over the 2388 elementwise function.) 2389 2390 The result ``r`` has the same shape (and structure) as the two inputs 2391 ``a`` and ``b``. 2392 elems: A (possibly nested Python tree structure of) array(s), each with 2393 an ``axis`` dimension of size ``num_elems``. 2394 reverse: A boolean stating if the scan should be reversed with respect to 2395 the ``axis`` dimension. 2396 axis: an integer identifying the axis over which the scan should occur. 2397 2398 Returns: 2399 A (possibly nested Python tree structure of) array(s) of the same shape 2400 and structure as ``elems``, in which the ``k``'th element of ``axis`` is the 2401 result of recursively applying ``fn`` to combine the first ``k`` elements 2402 of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``, 2403 the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``. 
2404 2405 Example 1: partial sums of an array of numbers: 2406 2407 >>> lax.associative_scan(jnp.add, jnp.arange(0, 4)) 2408 [ 0, 1, 3, 6] 2409 2410 Example 2: partial products of an array of matrices 2411 2412 >>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2)) 2413 >>> partial_prods = lax.associative_scan(jnp.matmul, mats) 2414 >>> partial_prods.shape 2415 (4, 2, 2) 2416 2417 Example 3: reversed partial sums of an array of numbers 2418 2419 >>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True) 2420 [ 6, 6, 5, 3] 2421 2422 .. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.", 2423 Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon 2424 University. 2425 """ 2426 elems_flat, tree = tree_flatten(elems) 2427 2428 if reverse: 2429 elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat] 2430 2431 def combine(a_flat, b_flat): 2432 # Lower `fn` to operate on flattened sequences of elems. 2433 a = tree_unflatten(tree, a_flat) 2434 b = tree_unflatten(tree, b_flat) 2435 c = fn(a, b) 2436 c_flat, _ = tree_flatten(c) 2437 return c_flat 2438 2439 # Check that all inputs have a consistent leading dimension `num_elems`. 2440 axis = lax._canonicalize_axis(axis, elems_flat[0].ndim) 2441 num_elems = int(elems_flat[0].shape[axis]) 2442 if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]): 2443 raise ValueError('Array inputs to associative_scan must have the same ' 2444 'first dimension. (saw: {})' 2445 .format([elems.shape for elem in elems_flat])) 2446 2447 2448 # Summary of algorithm: 2449 # 2450 # Consider elements of `_scan(elems)` at odd indices. That's the same as first 2451 # summing successive pairs of elements of `elems` and performing a scan on 2452 # that half sized tensor. We perform the latter scan by recursion. 2453 # 2454 # Now consider the even elements of `_scan(elems)`. 
These can be computed 2455 # from the odd elements of `_scan(elems)` by adding each odd element of 2456 # `_scan(elems)` to the matching even element in the original `elems`. 2457 # 2458 # We return the odd and even elements interleaved. 2459 # 2460 # For the base case of the recursion we return the first element 2461 # of `elems` followed by the sum of the first two elements computed as 2462 # a (small two-down-to-one) reduction step. 2463 def _scan(elems): 2464 """Perform scan on `elems`.""" 2465 2466 num_elems = elems[0].shape[axis] 2467 2468 if num_elems < 2: 2469 return elems 2470 2471 # Combine adjacent pairs of elements. 2472 reduced_elems = combine( 2473 [lax.slice_in_dim(elem, 0, -1, stride=2, axis=axis) for elem in elems], 2474 [lax.slice_in_dim(elem, 1, None, stride=2, axis=axis) for elem in elems]) 2475 2476 # Recursively compute scan for partially reduced tensors. 2477 odd_elems = _scan(reduced_elems) 2478 2479 if num_elems % 2 == 0: 2480 even_elems = combine( 2481 [lax.slice_in_dim(e, 0, -1, axis=axis) for e in odd_elems], 2482 [lax.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems]) 2483 else: 2484 even_elems = combine( 2485 odd_elems, 2486 [lax.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems]) 2487 2488 # The first element of a scan is the same as the first element 2489 # of the original `elems`. 2490 even_elems = [ 2491 lax.concatenate([lax.slice_in_dim(elem, 0, 1, axis=axis), result], 2492 dimension=axis) 2493 for (elem, result) in zip(elems, even_elems)] 2494 return list(_map(partial(_interleave, axis=axis), even_elems, odd_elems)) 2495 2496 scans = _scan(elems_flat) 2497 2498 if reverse: 2499 scans = [lax.rev(scanned, [axis]) for scanned in scans] 2500 2501 return tree_unflatten(tree, scans) 2502 2503 2504# Cumulative reductions. 
2505 2506def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array: 2507 """Computes a cumulative sum along `axis`.""" 2508 return cumsum_p.bind(operand, axis=int(axis), reverse=bool(reverse)) 2509 2510def cumprod(operand: Array, axis: int = 0, reverse: bool = False) -> Array: 2511 """Computes a cumulative product along `axis`.""" 2512 return cumprod_p.bind(operand, axis=int(axis), reverse=bool(reverse)) 2513 2514def cummax(operand: Array, axis: int = 0, reverse: bool = False) -> Array: 2515 """Computes a cumulative maximum along `axis`.""" 2516 return cummax_p.bind(operand, axis=int(axis), reverse=bool(reverse)) 2517 2518def cummin(operand: Array, axis: int = 0, reverse: bool = False) -> Array: 2519 """Computes a cumulative minimum along `axis`.""" 2520 return cummin_p.bind(operand, axis=int(axis), reverse=bool(reverse)) 2521 2522def _cumred_shape_rule(x, *, axis: int, reverse: bool): 2523 if axis < 0 or axis >= x.ndim: 2524 raise ValueError( 2525 "axis {} is out of bounds for array of shape {}".format(axis, x.shape)) 2526 return x.shape 2527 2528def _cumsum_transpose_rule(t, operand, *, axis: int, reverse: bool): 2529 return [cumsum(t, axis=axis, reverse=not reverse)] 2530 2531 2532 2533def _cumred_tpu_translation_rule(window_reduce: Callable, x, *, 2534 axis: int, reverse: bool): 2535 # On TPU, an implementation using reduce_window is handled specially by the 2536 # compiler and is efficient. On other backends, it is O(n^2). 
2537 n = x.shape[axis] 2538 if n == 0: 2539 return x 2540 padding = [(0, 0)] * x.ndim 2541 padding[axis] = (0, n - 1) if reverse else (n - 1, 0) 2542 strides = [1] * x.ndim 2543 window_dims = [1] * x.ndim 2544 window_dims[axis] = n 2545 return window_reduce(x, window_dims, strides, padding) 2546 2547def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int, 2548 reverse: bool): 2549 operand, = batched_args 2550 bdim, = batch_dims 2551 axis = axis if axis < bdim else axis + 1 2552 return prim.bind(operand, axis=axis, reverse=reverse), bdim 2553 2554def _cumred_dtype_rule(name, operand, *args, **kw): 2555 if not dtypes.issubdtype(operand.dtype, np.number): 2556 raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes " 2557 "of number.".format(name, np.dtype(operand.dtype).name)) 2558 return dtypes.canonicalize_dtype(operand.dtype) 2559 2560cumsum_p = lax.standard_primitive( 2561 _cumred_shape_rule, partial(_cumred_dtype_rule, "cumsum"), 2562 'cumsum') 2563ad.deflinear2(cumsum_p, _cumsum_transpose_rule) 2564xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun( 2565 partial(_cumred_tpu_translation_rule, lax._reduce_window_sum), 2566 multiple_results=False) 2567batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p) 2568 2569 2570def _cumulative_reduction_primitive(name, reduce_window_fn): 2571 reducer_p = lax.standard_primitive( 2572 _cumred_shape_rule, partial(_cumred_dtype_rule, name), 2573 name) 2574 xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun( 2575 partial(_cumred_tpu_translation_rule, reduce_window_fn), 2576 multiple_results=False) 2577 batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p) 2578 return reducer_p 2579 2580 2581cumprod_p = _cumulative_reduction_primitive("cumprod", lax._reduce_window_prod) 2582cummax_p = _cumulative_reduction_primitive("cummax", lax._reduce_window_max) 2583cummin_p = _cumulative_reduction_primitive("cummin", 
lax._reduce_window_min) 2584 2585xla.translations[cumsum_p] = xla.lower_fun( 2586 partial(associative_scan, lax.add), multiple_results=False) 2587xla.translations[cumprod_p] = xla.lower_fun( 2588 partial(associative_scan, lax.mul), multiple_results=False) 2589xla.translations[cummin_p] = xla.lower_fun( 2590 partial(associative_scan, lax.min), multiple_results=False) 2591xla.translations[cummax_p] = xla.lower_fun( 2592 partial(associative_scan, lax.max), multiple_results=False) 2593 2594def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool, 2595 combine_fn: Callable): 2596 # Irrespective of backend, we always use the parallel prefix scan 2597 # implementation when differentiating because reduce_window is not 2598 # arbitrarily differentiable. 2599 return api.jvp(partial(associative_scan, combine_fn, axis=axis, 2600 reverse=reverse), 2601 primals, tangents) 2602 2603ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul) 2604ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min) 2605ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max) 2606 2607 2608@config.register_omnistaging_disabler 2609def omnistaging_disabler() -> None: 2610 global _initial_style_open_jaxpr, _initial_style_jaxpr, \ 2611 _initial_style_jaxprs_with_common_consts 2612 2613 @cache() 2614 def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals): 2615 in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] 2616 wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) 2617 with core.initial_style_staging(): # type: ignore 2618 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore 2619 wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore 2620 return jaxpr, out_pvals, consts, out_tree 2621 2622 @cache() 2623 def _initial_style_jaxpr(fun: Callable, in_tree, in_avals): 2624 jaxpr, out_pvals, consts, out_tree = _initial_style_open_jaxpr( 2625 fun, in_tree, 
in_avals) 2626 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) 2627 return closed_jaxpr, consts, out_tree() 2628 2629 @cache() 2630 def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable], 2631 in_tree, in_avals): 2632 # When staging the branches of a conditional into jaxprs, constants are 2633 # extracted from each branch and converted to jaxpr arguments. To use the 2634 # staged jaxprs as the branches to a conditional *primitive*, we need for 2635 # their (input) signatures to match. This function "joins" the staged jaxprs: 2636 # for each one, it makes another that accepts *all* constants, but only uses 2637 # those that it needs (dropping the rest). 2638 2639 jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([ 2640 _initial_style_open_jaxpr(fun, in_tree, in_avals) for fun in funs]) 2641 2642 newvar = core.gensym(jaxprs, suffix='_') 2643 all_const_avals = tuple( 2644 tuple(raise_to_shaped(core.get_aval(c)) for c in consts) 2645 for consts in all_consts) 2646 unused_const_vars = tuple( 2647 tuple(newvar(aval) for aval in const_avals) 2648 for const_avals in all_const_avals) 2649 2650 def pad_jaxpr_constvars(i, jaxpr): 2651 prefix = util.concatenate(unused_const_vars[:i]) 2652 suffix = util.concatenate(unused_const_vars[i+1:]) 2653 constvars = prefix + jaxpr.constvars + suffix 2654 return core.Jaxpr(constvars=constvars, invars=jaxpr.invars, 2655 outvars=jaxpr.outvars, eqns=jaxpr.eqns) 2656 2657 def type_and_const_convert_jaxpr(jaxpr, out_pvals): 2658 return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) 2659 2660 jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)] 2661 closed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals) 2662 2663 return (tuple(closed_jaxprs), 2664 tuple(util.concatenate(all_consts)), 2665 tuple(out_tree() for out_tree in all_out_trees)) 2666