1# coding=utf-8
2# Copyright 2019 Google LLC
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""
16Control flow primitives.
17"""
18
19
20import collections
21import functools
22import inspect
23import itertools
24import operator
25import os
26from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar
27
28import numpy as np
29
30import jax
31from jax import api
32from jax import core
33from jax import dtypes
34from jax._src import source_info_util
35from jax._src import util
36from jax._src.lax import lax
37from jax import linear_util as lu
38from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
39from jax.api_util import flatten_fun_nokwargs
40from jax.interpreters import ad
41from jax.interpreters import partial_eval as pe
42from jax.interpreters import xla
43from jax.interpreters import batching
44from jax.interpreters import masking
45from jax.lib import xla_bridge as xb
46from jax.lib import xla_client
47from jax._src.util import (partial, unzip2, unzip3, unzip4, safe_map, safe_zip,
48                           split_list, cache, extend_name_stack)
49from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
50                           treedef_children, treedef_tuple, tree_multimap,
51                           tree_leaves)
52from jax import ad_util
53from jax.config import config
54
# Short alias for the XLA client op-building namespace used by the translation
# rules in this module.
xops = xla_client.ops

# ``safe_map``/``safe_zip`` come from ``jax._src.util`` (presumably the
# length-checking variants of map/zip). Note that ``zip`` deliberately shadows
# the builtin for the rest of this module, so every ``zip(...)`` below is
# ``safe_zip``.
_map = safe_map
zip = safe_zip
_reduce = functools.reduce

# Type variable for the loop-carry type in ``while_loop``'s signature.
T = TypeVar('T')
# Loose annotation alias for array-like values (just ``Any``).
Array = Any
63
@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
  """Trace ``fun`` on abstract inputs into an open jaxpr.

  Args:
    fun: Python callable taking the (unflattened) arguments described by
      ``in_tree``.
    in_tree: pytree definition of ``fun``'s positional arguments.
    in_avals: tuple of abstract values, one per flattened argument leaf.

  Returns:
    ``(jaxpr, consts, out_tree)``: the traced jaxpr (with its constants still
    as constvars), the constant values it closes over, and the pytree
    definition of ``fun``'s output.
  """
  flat_fun, get_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  open_jaxpr, _, closed_over = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
  # The output tree is only known after tracing has run ``flat_fun``.
  return open_jaxpr, closed_over, get_out_tree()
69
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
  """Trace ``fun`` into a ``ClosedJaxpr`` whose constants are leading inputs.

  Returns:
    ``(closed_jaxpr, consts, out_tree)``. ``closed_jaxpr`` closes over no
    constants; instead its first ``len(consts)`` inputs must be fed ``consts``.
  """
  open_jaxpr, consts, out_tree = _initial_style_open_jaxpr(
      fun, in_tree, in_avals)
  lifted = pe.convert_constvars_jaxpr(open_jaxpr)
  return core.ClosedJaxpr(lifted, ()), consts, out_tree
75
@cache()
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
                                             in_tree, in_avals):
  """Stage each of ``funs`` into a closed jaxpr, all sharing one const signature.

  When staging the branches of a conditional into jaxprs, constants are
  extracted from each branch and converted to jaxpr arguments. To use the
  staged jaxprs as the branches of a conditional *primitive*, their (input)
  signatures must match. This "joins" the staged jaxprs: each one is rewritten
  to accept *all* functions' constants (concatenated in function order), using
  only its own and ignoring the rest.

  Returns:
    ``(closed_jaxprs, consts, out_trees)``: one closed jaxpr per function
    (constants lifted to leading inputs), the concatenation of every
    function's constants, and each function's output pytree definition.
  """
  open_jaxprs, consts_per_fun, out_trees = unzip3(
      _initial_style_open_jaxpr(fun, in_tree, in_avals) for fun in funs)

  # Fresh, unused binders that stand in for every *other* function's consts.
  fresh_var = core.gensym(open_jaxprs, suffix='_')
  placeholders = []
  for fun_consts in consts_per_fun:
    const_avals = [raise_to_shaped(core.get_aval(c)) for c in fun_consts]
    placeholders.append([fresh_var(aval) for aval in const_avals])

  def with_all_constvars(idx, jaxpr):
    # Splice this function's own constvars between the placeholders of the
    # functions before it and those of the functions after it.
    before = util.concatenate(placeholders[:idx])
    after = util.concatenate(placeholders[idx + 1:])
    return core.Jaxpr(constvars=[*before, *jaxpr.constvars, *after],
                      invars=jaxpr.invars, outvars=jaxpr.outvars,
                      eqns=jaxpr.eqns)

  all_consts = util.concatenate(consts_per_fun)
  padded = [with_all_constvars(idx, jaxpr)
            for idx, jaxpr in enumerate(open_jaxprs)]
  closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
                   for jaxpr in padded]
  return closed_jaxprs, all_consts, out_trees
107
def _abstractify(x):
  """Return the shaped (non-concrete) abstract value of ``x``."""
  aval = core.get_aval(x)
  return raise_to_shaped(aval)
110
111def _typecheck_param(prim, param, name, msg_required, pred):
112  msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
113         f'{msg_required} required:')
114  param_str = str(param)
115  sep = os.linesep if os.linesep in param_str else ' '
116  msg = sep.join([msg, param_str])
117  core.typecheck_assert(pred, msg)
118
119
120### fori_loop and while_loop
121
122def _fori_cond_fun(loop_carry):
123  i, upper, _ = loop_carry
124  return lax.lt(i, upper)
125
@cache()
def _fori_body_fun(body_fun):
  """Adapt ``body_fun(i, x)`` into a ``while_loop`` body over ``(i, upper, x)``.

  The returned body increments ``i`` by one (in ``i``'s dtype), passes
  ``upper`` through unchanged and replaces ``x`` with ``body_fun(i, x)``.
  """
  def while_body_fun(loop_carry):
    i, upper, x = loop_carry
    next_i = lax.add(i, lax._const(i, 1))
    new_x = body_fun(i, x)
    return next_i, upper, new_x
  return while_body_fun
132
@cache()
def _fori_scan_body_fun(body_fun):
  """Adapt ``body_fun(i, x)`` into a ``scan`` body over ``(i, upper, x)``.

  The returned function ignores the scanned-over element and emits ``None``
  as its per-iteration output; the carry is updated as in ``_fori_body_fun``.
  """
  def scanned_fun(loop_carry, _):
    i, upper, x = loop_carry
    next_i = lax.add(i, lax._const(i, 1))
    new_carry = (next_i, upper, body_fun(i, x))
    return new_carry, None
  return scanned_fun
139
def fori_loop(lower, upper, body_fun, init_val):
  """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.

  The type signature in brief is

  .. code-block:: haskell

    fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

  The semantics of ``fori_loop`` are given by this Python implementation::

    def fori_loop(lower, upper, body_fun, init_val):
      val = init_val
      for i in range(lower, upper):
        val = body_fun(i, val)
      return val

  Unlike that Python version, ``fori_loop`` is implemented in terms of a call to
  :func:`jax.lax.while_loop`. See the :func:`jax.lax.while_loop` documentation
  for more information.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Args:
    lower: an integer representing the loop index lower bound (inclusive)
    upper: an integer representing the loop index upper bound (exclusive)
    body_fun: function of type ``(int, a) -> a``.
    init_val: initial loop carry value of type ``a``.

  Returns:
    Loop value from the final iteration, of type ``a``.

  Raises:
    TypeError: if ``lower`` and ``upper`` do not share the same canonical
      dtype.
  """
  # TODO(phawkins): perhaps do more type checking here, better error messages.
  lower_dtype, upper_dtype = [dtypes.canonicalize_dtype(lax.dtype(bound))
                              for bound in (lower, upper)]
  if lower_dtype != upper_dtype:
    raise TypeError(
        "lower and upper arguments to fori_loop must have equal types, "
        "got {} and {}".format(lower_dtype.name, upper_dtype.name))

  # With a static trip count we could lower to ``scan`` instead, which supports
  # efficient reverse-mode differentiation. That path is currently switched
  # off, but the bounds are still probed with ``int``; a TypeError (e.g. for
  # abstract traced bounds) means they are not static.
  use_scan = False  # TODO(mattjj): re-enable this
  try:
    static_bounds = (int(lower), int(upper))
  except TypeError:
    static_bounds = None

  carry = (lower, upper, init_val)
  if use_scan and static_bounds is not None:
    start, stop = static_bounds
    (_, _, result), _ = scan(_fori_scan_body_fun(body_fun), carry, None,
                             length=stop - start)
    return result

  _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun), carry)
  return result
204
205
def while_loop(cond_fun: Callable[[T], bool],
               body_fun: Callable[[T], T],
               init_val: T) -> T:
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.

  Raises:
    TypeError: if ``cond_fun`` does not return a single boolean scalar. The
      match between ``body_fun``'s output and its input is checked by
      ``_check_tree_and_avals``.
  """
  # With jit disabled, first try running the loop as a plain Python loop on
  # concrete values.
  if jax.api._jit_is_disabled():
    try:
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val
    except core.ConcretizationTypeError:
      # Can't run this while_loop in Python (e.g. because there's a vmap
      # transformation on it), so we fall back to the primitive version.
      pass

  # Trace cond_fun and body_fun into closed jaxprs against the abstract values
  # of the flattened carry, and validate that cond_fun yields one bool scalar.
  # The carry is wrapped in a 1-tuple, so ``in_tree`` has exactly one child.
  def _create_jaxpr(init_val):
    init_vals, in_tree = tree_flatten((init_val,))
    init_avals = tuple(_map(_abstractify, init_vals))
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
    if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
      msg = "cond_fun must return a boolean scalar, but got pytree {}."
      raise TypeError(msg.format(cond_tree))
    # Weak typing is ignored here: only shape () and dtype bool matter.
    if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), np.bool_):
      msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
      raise TypeError(msg.format(cond_jaxpr.out_avals))
    return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree

  # The body input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
  init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
  if changed:
    new_init_val, = tree_unflatten(in_tree, new_init_vals)
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
  cond_jaxpr, cond_consts, body_consts, body_tree = rest

  # body_fun must map the carry type to itself (same pytree and avals).
  in_tree_children = in_tree.children()
  assert len(in_tree_children) == 1
  _check_tree_and_avals("body_fun output and input",
                        body_tree, body_jaxpr.out_avals,
                        in_tree_children[0], init_avals)
  # Operands are laid out as [cond_consts, body_consts, carry]; the *_nconsts
  # params tell the primitive where each group starts.
  outs = while_p.bind(*itertools.chain(cond_consts, body_consts, init_vals),
                      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
                      body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
  return tree_unflatten(body_tree, outs)
297
def _while_loop_abstract_eval(*args, **kwargs):
  """Abstract evaluation for ``while_p``: shaped avals of the body's outputs.

  Operand values are ignored; only the ``body_jaxpr`` param is consulted.
  """
  body_out_avals = kwargs["body_jaxpr"].out_avals
  return [raise_to_shaped(aval) for aval in body_out_avals]
300
def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
                                 cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts):
  """XLA lowering of ``while_p`` to a single XLA While.

  All operands (cond consts, body consts, carry) are packed into one XLA tuple
  that is threaded through the loop. The cond computation evaluates
  ``cond_jaxpr`` on (cond consts, carry). The body computation evaluates
  ``body_jaxpr`` on (body consts, carry) and returns the full tuple with only
  the carry replaced. Returns an XLA tuple of the final carry values.

  If ``cond_jaxpr`` produces a non-scalar predicate (``batched``), the loop runs
  while any predicate element is True. The body then keeps a carry element's old
  value wherever its predicate element is False (see ``_pred_bcast_select``).
  ``avals`` is not used.
  """
  cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
  batched = bool(cond_jaxpr.out_avals[0].shape)

  # Since jaxprs don't have tuples and have multiple return values, but we need
  # the HLO While loop to take a single tuple input and output a single boolean
  # (for the cond computation) or a single tuple output (for the body
  # computation), we build XLA computations that handle the tuple munging before
  # generating a Call into the computations formed from the jaxprs.

  init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)

  # Cond computation: x = cond consts, z = carry (body consts are unused here).
  cond_c = xb.make_computation_builder("cond_computation")
  cond_carry = xb.parameter(cond_c, 0, c.get_shape(init_carry))
  cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
  x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
  pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env,
                            _map(partial(xb.constant, cond_c), cond_jaxpr.consts),
                            extend_name_stack(name_stack, 'cond'), *(x + z))
  if batched:
    # Reduce the batched predicate with logical OR over all of its dimensions,
    # so the loop continues while any element still wants to run.
    scalar = ShapedArray((), np.bool_)
    or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
    pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_,
                         list(range(cond_jaxpr.out_avals[0].ndim)))

  # Body computation: x = cond consts, y = body consts, z = carry.
  body_c = xb.make_computation_builder("body_computation")
  body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
  body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
  x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
  new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
                            _map(partial(xb.constant, body_c), body_jaxpr.consts),
                            extend_name_stack(name_stack, 'body'), *(y + z))
  if batched:
    # Re-evaluate the per-element predicate on the incoming carry and only
    # take the new value where it is True; finished elements keep old values.
    body_pred, = xla.jaxpr_subcomp(body_c, cond_jaxpr.jaxpr, backend, axis_env,
                                   _map(partial(xb.constant, body_c), cond_jaxpr.consts),
                                   extend_name_stack(name_stack, 'body_pred'), *(x + z))
    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z, body_jaxpr.out_avals)
    assert _map(body_c.get_shape, new_z) == _map(body_c.get_shape, z) # no broadcast
  # Consts pass through unchanged; only the carry part of the tuple is updated.
  new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z)))

  ans = xops.While(cond_c.build(pred), body_c.build(new_carry), init_carry)
  ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
  _,  _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
  return xops.Tuple(c, z)
346
def _pred_bcast_select(c, pred, x, y, x_y_aval: core.AbstractValue):
  """Elementwise ``pred ? x : y`` in XLA builder ``c``, broadcasting ``pred``.

  ``pred``'s dimensions must be a prefix of ``x``'s (and ``y``'s); it is
  broadcast along the leading dimensions. Units pass ``x`` through, and tokens
  are joined with ``AfterAll`` instead of selected.
  """
  pred_dims = c.get_shape(pred).dimensions()
  x_dims = c.get_shape(x).dimensions()
  y_dims = c.get_shape(y).dimensions()
  assert x_dims == y_dims

  if x_y_aval is core.abstract_unit:
    return x
  if x_y_aval is core.abstract_token:
    return xops.AfterAll(c, [x, y])

  pred_rank = len(pred_dims)
  assert pred_dims == x_dims[:pred_rank] == y_dims[:pred_rank]
  leading_dims = list(range(pred_rank))
  return xops.Select(xops.BroadcastInDim(pred, x_dims, leading_dims), x, y)
360
def _while_loop_batching_rule(args, dims, axis_name,
                              cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  """Batching (vmap) rule for ``while_p``.

  Decides which carry elements must be batched, batches the cond/body jaxprs
  accordingly, moves every batched operand's batch dimension to axis 0, and
  binds a new ``while_p``.

  Returns:
    ``(outs, out_bdims)`` where each output is batched along axis 0 or is
    ``batching.not_mapped``.
  """
  # All mapped operands must share one batch size; unpacking the set fails
  # otherwise.
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])

  # Fixpoint computation of which carry are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_bat.
  carry_bat = init_bat
  for _ in range(1 + len(carry_bat)):
    batched = bconst_bat + carry_bat
    body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, size, batched, instantiate=carry_bat, axis_name=axis_name)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat,
        instantiate=bool(cond_jaxpr.out_avals[0].shape),
        axis_name=axis_name)
    # A batched predicate forces every carry to be batched: the lowering keeps
    # or updates each carry element per predicate element.
    carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
    if carry_bat_out == carry_bat:
      break
    else:
      carry_bat = _map(operator.or_, carry_bat, carry_bat_out)
  else:
    assert False, "Fixpoint not reached"

  # Move batch dimensions of consts to axis 0; unbatched consts are unchanged.
  consts, init = split_list(args, [cond_nconsts + body_nconsts])
  const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, const_dims)]
  # Carry that became batched only through the fixpoint is broadcast to a new
  # leading batch axis; carry that was already batched has its axis moved to 0.
  new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
              else batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
              for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)]

  outs = while_p.bind(*(new_consts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims
403
def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
                    body_jaxpr):
  """Forward-mode (JVP) rule for ``while_p``.

  Builds a new while loop whose carry is (primal carry, non-zero carry
  tangents). Its body is the JVP of ``body_jaxpr``, and its cond is
  ``cond_jaxpr`` extended with ignored binders for the tangent carry.

  Returns:
    ``(primal_outs, tangent_outs)``. Carries whose tangent is identically zero
    get an ``ad_util.Zero`` tangent.
  """
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  cconst_nz, bconst_nz, init_nz = split_list(nonzeros, [cond_nconsts, body_nconsts])

  # Fixpoint: a carry tangent is non-zero if it starts non-zero or if the body
  # can make it non-zero. Each iteration only turns more flags on.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    body_nonzeros = bconst_nz + carry_nz
    body_jvp, nonzeros_out = ad.jvp_jaxpr(
        body_jaxpr, body_nonzeros, instantiate=carry_nz)
    if nonzeros_out == carry_nz:
      break
    carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
  else:
    assert False, "Fixpoint not reached"

  # Materialize zeros for carries that are non-zero only after the fixpoint.
  nonzeros = cconst_nz + body_nonzeros
  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  # Tangents of cond consts are dropped: cond_jaxpr only decides control flow.
  cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts])
  _, bconst_dot, init_dot = split_list(tangents, [cond_nconsts, body_nconsts])
  bconst_dot = _prune_zeros(bconst_dot)
  init_dot = _prune_zeros(init_dot)

  num_carry = len(primals) - cond_nconsts - body_nconsts

  # Reorder the JVP body's binders from interleaved primal/tangent groups to
  # [bconst, bconst_dot, carry, carry_dot] inputs and [carry, carry_dot]
  # outputs, matching the operand layout of the new while loop below.
  body_jvp_rearranged = ad.rearrange_binders(
      body_jvp,
      [body_nconsts, num_carry], [len(bconst_dot), len(init_dot)],
      [num_carry], [len(init_dot)])

  # The cond sees the augmented carry too, so give it unused trailing invars
  # for the tangent carry.
  newvar = core.gensym([cond_jaxpr.jaxpr])
  invars_aug = (
      cond_jaxpr.jaxpr.invars + [newvar(core.get_aval(x)) for x in init_dot])
  cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
                                    invars_aug,
                                    cond_jaxpr.jaxpr.outvars,
                                    cond_jaxpr.jaxpr.eqns)
  cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts)

  out = while_p.bind(
      *(cconst + bconst + bconst_dot + init + init_dot),
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_augmented,
      body_nconsts=len(bconst) + len(bconst_dot),
      body_jaxpr=body_jvp_rearranged)

  # Re-expand the pruned tangent outputs, filling symbolic zeros back in.
  out_carry, out_carry_dot = split_list(out, [num_carry])
  out_tangents_iter = iter(out_carry_dot)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_carry, nonzeros_out)]
  return out_carry, out_tangents
457
def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
                        cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int,
                        body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]:
  """An implementation of partial evaluation for while.
  As long as some carry (and hence output) are known and the output
  of `cond_jaxpr` is known, we use a portion of the loop body to compute the known
  outputs of the `while_loop`. For the unknown outputs we generate Jaxpr to run
  the whole while, including recomputing the known parts.

  This means that we don't actually save any computation by partial
  evaluation if there are unknown outputs.

  What this achieves is that we can give a proper error for reverse
  differentiation of `while`, because in that use of partial evaluation the
  primal inputs are considered "known", and only the tangent computation is
  unknown (see issue #2129).
  """
  unknowns = [not t.pval.is_known() for t in tracers]
  params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
                body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)

  # Without omnistaging, partial_eval_jaxpr also needs the trace type.
  if config.omnistaging_enabled:
    partial_eval_jaxpr = pe.partial_eval_jaxpr
  else:
    partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)

  cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
  # Fixpoint computation of unknown carry. Each iteration promotes
  # at least one carry to unknown. We need one last iteration to prepare the jaxpr.
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    body_jaxpr_known, _, carry_out_uk = partial_eval_jaxpr(  # type: ignore
        body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
    if carry_out_uk == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
  else:
    assert False, "Fixpoint not reached"

  cond_jaxpr_known, _, cond_uk = partial_eval_jaxpr(  # type: ignore
      cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)

  if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
    # If conditional is unknown, or all inputs are known, or all are unknown,
    # just do the default processing.
    return trace.default_process_primitive(while_p, tracers, params)

  # Run the known part of the while. Prepare the inputs, as constants (if known), or
  # as core.unit.
  in_consts = [ core.unit if uk else t.pval.get_known()
                for uk, t in zip(cond_consts_uk + body_consts_uk + carry_uk,
                                 tracers)]
  # There should be no residuals for the cond_jaxpr_known
  assert 1 == len(cond_jaxpr_known.out_avals)
  # We ignore the residuals from the body_jaxpr_known, so the type of inputs matches
  # the type of outputs; residuals are at the end
  if len(body_jaxpr_known.out_avals) > len(body_jaxpr.out_avals):
    # TODO(necula): this is not quite enough; we should drop the residual computations also
    body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:len(body_jaxpr.out_avals)]
  out_known = while_p.bind(
    *in_consts,
    cond_nconsts=cond_nconsts,
    cond_jaxpr=cond_jaxpr_known,
    body_nconsts=body_nconsts,
    body_jaxpr=body_jaxpr_known)

  # Run the whole while_loop to get all the outputs, then merge with known ones
  # Known carry outputs are exposed as known values that keep the recipe of the
  # full (staged) while as their provenance.
  out_all: Sequence[pe.Tracer] = trace.default_process_primitive(while_p, tracers, params)
  out_tracers: Sequence[pe.Tracer] = [
    out_unknown if uk
    else pe.JaxprTracer(trace, pe.PartialVal.known(known), out_unknown.recipe)
    for uk, out_unknown, known in zip(carry_uk, out_all, out_known)]

  return out_tracers
533
534def _while_transpose_error(*_, **kwargs):
535  raise ValueError("Reverse-mode differentiation does not work for "
536                   "lax.while_loop or lax.fori_loop. "
537                   "Try using lax.scan instead.")
538
# The ``while`` primitive. Operands are laid out as
# [cond_consts, body_consts, init_carry]; params are cond_nconsts, cond_jaxpr,
# body_nconsts and body_jaxpr (see ``while_loop``). It returns one output per
# carry element.
while_p = lax.Primitive('while')
while_p.multiple_results = True
# Eager evaluation goes through ``xla.apply_primitive``.
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.initial_style_translations[while_p] = _while_loop_translation_rule
# Transposition (hence reverse-mode AD) is unsupported; this rule raises.
ad.primitive_transposes[while_p] = _while_transpose_error
batching.initial_style_batchers[while_p] = _while_loop_batching_rule
548
549
550### cond and switch
551
def switch(index, branches: Sequence[Callable], operand):
  """Apply exactly one of ``branches`` given by ``index``.

  If ``index`` is out of bounds, it is clamped to within bounds.

  Has the semantics of the following Python::

    def switch(index, branches, operand):
      index = clamp(0, index, len(branches) - 1)
      return branches[index](operand)

  Args:
    index: Integer scalar type, indicating which branch function to apply.
    branches: Sequence of functions (A -> B) to be applied based on ``index``.
    operand: Operand (A) input to whichever branch is applied.

  Returns:
    Value (B) of ``branches[index](operand)`` for the clamped ``index``.

  Raises:
    TypeError: if ``index`` is not a scalar or does not have an integer dtype,
      or if the branches' outputs disagree in structure or type.
    ValueError: if ``branches`` is empty.
  """
  index_shape = np.shape(index)
  if len(index_shape) != 0:
    raise TypeError(
        f"Branch index must be scalar, got {index} of shape {index_shape}.")

  try:
    index_dtype = dtypes.result_type(index)
  except TypeError as err:
    raise TypeError(f"Index type must be an integer, got {index}.") from err
  if index_dtype.kind not in 'iu':
    raise TypeError(
        f"Index type must be an integer, got {index} as {index_dtype}")

  branches = tuple(branches)
  num_branches = len(branches)
  if not num_branches:
    raise ValueError("Empty branch sequence")
  if num_branches == 1:
    # A single branch needs no primitive at all.
    return branches[0](operand)

  # Clamp the index into [0, num_branches - 1] as int32.
  index = lax.convert_element_type(index, np.int32)
  index = lax.clamp(np.array(0, np.int32), index,
                    np.array(num_branches - 1, np.int32))

  if (jax.api._jit_is_disabled() and
      isinstance(core.get_aval(index), ConcreteArray)):
    return branches[int(index)](operand)

  flat_ops, ops_tree = tree_flatten((operand,))
  ops_avals = tuple(_map(_abstractify, flat_ops))
  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      branches, ops_tree, ops_avals)

  # Every branch must agree with branch 0 on output pytree and avals.
  ref_tree, ref_avals = out_trees[0], jaxprs[0].out_avals
  for i in range(1, len(jaxprs)):
    _check_tree_and_avals(f"branch 0 and {i} outputs",
                          ref_tree, ref_avals,
                          out_trees[i], jaxprs[i].out_avals)

  linear = (False,) * (len(consts) + len(flat_ops))
  outs = cond_p.bind(
      index, *consts, *flat_ops, branches=tuple(jaxprs), linear=linear)
  return tree_unflatten(ref_tree, outs)
614
615
def cond(*args, **kwargs):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  ``cond()`` has equivalent semantics to this Python implementation::

    def cond(pred, true_fun, false_fun, operand):
      if pred:
        return true_fun(operand)
      else:
        return false_fun(operand)

  ``pred`` must be a scalar type. Integer and floating-point predicates are
  compared against zero.

  Functions ``true_fun``/``false_fun`` may not need to refer to an ``operand``
  to compute their result, but one must still be provided to the ``cond`` call
  and be accepted by both the branch functions, e.g.::

    jax.lax.cond(
        get_predicate_value(),
        lambda _: 23,
        lambda _: 42,
        operand=None)

  The former, deprecated signature
  ``cond(pred, true_operand, true_fun, false_operand, false_fun)`` is still
  accepted.

  Args:
    pred: Boolean scalar type, indicating which branch function to apply.
    true_fun: Function (A -> B), to be applied if ``pred`` is True.
    false_fun: Function (A -> B), to be applied if ``pred`` is False.
    operand: Operand (A) input to either branch depending on ``pred``. The type
      can be a scalar, array, or any pytree (nested Python tuple/list/dict)
      thereof.

  Returns:
    Value (B) of either ``true_fun(operand)`` or ``false_fun(operand)``,
    depending on the value of ``pred``. The type can be a scalar, array, or any
    pytree (nested Python tuple/list/dict) thereof.
  """
  # Calls that bind to the old five-argument signature go to the legacy path.
  legacy_signature = inspect.signature(_cond_with_per_branch_args)
  try:
    legacy_call = legacy_signature.bind(*args, **kwargs)
  except TypeError:
    legacy_call = None

  if legacy_call is None:
    return _cond(*args, **kwargs)
  return _cond_with_per_branch_args(*legacy_call.args)
663
def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
  """Implementation of :func:`cond` for the ``(pred, true_fun, false_fun, operand)`` form.

  Validates ``pred`` (a scalar boolean, or a number compared against zero),
  stages both branches with shared constants, and binds ``cond_p`` with
  branch 0 = false and branch 1 = true.

  Raises:
    TypeError: if ``pred`` is not a scalar or has a non-boolean, non-numeric
      dtype, or if the branch outputs disagree in structure or type.
  """
  pred_shape = np.shape(pred)
  if len(pred_shape) != 0:
    raise TypeError(
        f"Pred must be a scalar, got {pred} of shape {pred_shape}.")

  bad_pred_msg = "Pred type must be either boolean or number, got {}."
  try:
    pred_dtype = dtypes.result_type(pred)
  except TypeError as err:
    raise TypeError(bad_pred_msg.format(pred)) from err

  if pred_dtype.kind != 'b':
    if pred_dtype.kind not in 'iuf':
      raise TypeError(bad_pred_msg.format(pred_dtype))
    # Numeric predicates follow Python truthiness: nonzero means True.
    pred = pred != 0

  if jax.api._jit_is_disabled() and isinstance(core.get_aval(pred), ConcreteArray):
    chosen_fun = true_fun if pred else false_fun
    return chosen_fun(operand)

  flat_ops, ops_tree = tree_flatten((operand,))
  ops_avals = tuple(_map(_abstractify, flat_ops))

  (true_jaxpr, false_jaxpr), consts, (out_tree, false_out_tree) = (
      _initial_style_jaxprs_with_common_consts(
          (true_fun, false_fun), ops_tree, ops_avals))

  _check_tree_and_avals("true_fun and false_fun output",
                        out_tree, true_jaxpr.out_avals,
                        false_out_tree, false_jaxpr.out_avals)

  # The int32 index selects branches[index]: 0 -> false_jaxpr, 1 -> true_jaxpr.
  index = lax.convert_element_type(pred, np.int32)
  linear = (False,) * (len(consts) + len(flat_ops))
  outs = cond_p.bind(index, *consts, *flat_ops,
                     branches=(false_jaxpr, true_jaxpr), linear=linear)
  return tree_unflatten(out_tree, outs)
707
def _cond_with_per_branch_args(pred,
                               true_operand, true_fun: Callable,
                               false_operand, false_fun: Callable):
  """Conditionally apply ``true_fun`` or ``false_fun`` (deprecated signature).

  Has equivalent semantics to this Python implementation::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
        return false_fun(false_operand)

  ``pred`` has to be a scalar type; collection types (list, tuple) are not
  supported. Both operands are passed to ``_cond`` as one pair, and each branch
  picks out its own element.
  """
  def apply_true(operands):
    return true_fun(operands[0])

  def apply_false(operands):
    return false_fun(operands[1])

  return _cond(pred, apply_true, apply_false, (true_operand, false_operand))
727
def _cond_abstract_eval(*args, **kwargs):
  """Abstract evaluation for ``cond_p``: shaped output avals of branch 0.

  ``cond``/``switch`` in this module check that all branches agree on their
  outputs before binding, so the first branch is used as the representative.
  """
  first_branch = kwargs["branches"][0]
  return [raise_to_shaped(aval) for aval in first_branch.out_avals]
730
def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
                           index, *args, branches, linear):
  """XLA translation rule for ``cond_p``.

  Packs all operands into one XLA tuple and builds one XLA computation per
  branch; each computation unpacks the tuple and lowers its branch jaxpr.
  Emits an XLA ``Conditional`` that selects a computation by ``index``.
  """
  del linear  # Unused.

  packed_operands = xops.Tuple(c, args)
  packed_shape = c.get_shape(packed_operands)

  def build_branch(branch_name, closed_jaxpr):
    sub_c = xb.make_computation_builder(branch_name + '_comp')
    param = xb.parameter(sub_c, 0, packed_shape)
    num_inputs = len(closed_jaxpr.in_avals)
    unpacked = [xops.GetTupleElement(param, k) for k in range(num_inputs)]
    const_ops = [xb.constant(sub_c, value) for value in closed_jaxpr.consts]
    sub_name_stack = extend_name_stack(name_stack, branch_name + '_fun')
    results = xla.jaxpr_subcomp(sub_c, closed_jaxpr.jaxpr, backend, axis_env,
                                const_ops, sub_name_stack, *unpacked)
    return sub_c.build(xops.Tuple(sub_c, results))

  computations = []
  for k, closed_jaxpr in enumerate(branches):
    computations.append(build_branch(f'branch_{k}', closed_jaxpr))
  # Every branch receives the same packed operand tuple.
  operands = [packed_operands] * len(branches)
  return xops.Conditional(index, computations, operands)
750
def _select_tree(indices, branch_vals):
  """Pick, elementwise, ``branch_vals[indices]`` via a balanced tree of selects.

  The candidates are split in half recursively. Elements whose index is below
  the split point come from the lower half; the rest come from the upper half,
  with the index shifted down by the split point. Index values are presumably
  expected to lie in ``[0, len(branch_vals))``.
  """
  assert len(branch_vals) > 0
  count = len(branch_vals)
  if count == 1:
    return branch_vals[0]
  split = count // 2
  # Match the index dtype so the comparison and subtraction stay in one type.
  index_dtype = dtypes.canonicalize_dtype(lax.dtype(indices))
  split_const = np.array(split, index_dtype)
  in_lower_half = lax.lt(indices, split_const)
  lower = _select_tree(indices, branch_vals[:split])
  upper = _select_tree(indices - split_const, branch_vals[split:])
  return lax.select(in_lower_half, lower, upper)
760
def _cond_index_bcast_and_select_tree(indices, branch_vals):
  """Combine per-branch values elementwise according to a (batched) index.

  If every value is a unit, the first one is returned unchanged. Otherwise the
  index is broadcast to the shape of the branch values, with its own
  dimensions mapped to the leading ones. ``_select_tree`` then picks, per
  element, the value from the indexed branch.
  """
  all_units = all(core.get_aval(val) is core.abstract_unit
                  for val in branch_vals)
  if all_units:
    return branch_vals[0]
  target_shape = np.shape(branch_vals[0])
  leading_dims = list(range(np.ndim(indices)))
  bcast_indices = lax.broadcast_in_dim(indices, target_shape, leading_dims)
  return _select_tree(bcast_indices, branch_vals)
768
def _cond_batching_rule(args, dims, axis_name, branches, linear):
  """Batching rule for ``cond_p``.

  Args:
    args: operands of the cond; the branch index first, then the branch inputs.
    dims: for each operand, its batch dimension or ``batching.not_mapped``.
    axis_name: forwarded to ``batching.batch_jaxpr``.
    branches: tuple of branch ClosedJaxprs.
    linear: tuple of linearity flags for the non-index operands.

  Returns:
    A pair ``(outs, out_dims)``.
    - Batched index: every branch is evaluated, the results are combined
      elementwise by index, and all outputs are batched along axis 0.
    - Unbatched index: a single ``cond_p`` over the batched branches is bound,
      and outputs that no branch batches are reported as ``not_mapped``.
  """
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  # All mapped operands must agree on a single batch size.
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  # Move every batch axis to position 0 so the rest of the rule can assume it.
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  orig_bat = [d is not batching.not_mapped for d in dims]
  del dims
  index, *ops = args
  index_bat, *bat = orig_bat

  # First pass: find which outputs each branch would produce batched.
  branches_out_bat = [batching.batch_jaxpr(jaxpr, size, bat, False, axis_name)[1]
                      for jaxpr in branches]
  # An output is batched if any branch batches it, so all branches agree.
  out_bat = [any(bat) for bat in zip(*branches_out_bat)]

  # Second pass: batch each branch again, passing out_bat so that (presumably
  # via batch_jaxpr's instantiate argument) every branch batches the same outputs.
  branches_batched = tuple(batching.batch_jaxpr(jaxpr, size, bat, out_bat, axis_name)[0]
                           for jaxpr in branches)

  if index_bat:
    # The index varies across the batch, so no single branch can be chosen:
    # run every branch and select per element.
    branch_outs = []
    for jaxpr in branches_batched:
      out = core.jaxpr_as_fun(jaxpr)(*ops)
      # Give unbatched outputs a leading batch axis of length `size`.
      out = [batching.broadcast(x, size, 0) if not b else x
             for x, b in zip(out, out_bat)]
      branch_outs.append(out)
    return [_cond_index_bcast_and_select_tree(index, outs)
            for outs in zip(*branch_outs)], [0] * len(branch_outs[0])
  else:
    # The same branch applies to the whole batch: keep a single batched cond.
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(
        index, *ops, branches=branches_batched, linear=linear)
    return out, out_dims
800
def _cond_jvp(primals, tangents, branches, linear):
  """JVP rule for ``cond_p``.

  Builds JVP versions of all branches that agree on which outputs have nonzero
  tangents. Then binds a single ``cond_p`` over the primal operands followed by
  the nonzero tangent operands.

  Returns:
    ``(out_primals, out_tangents)``. Outputs whose tangent is zero in every
    branch get a symbolic ``ad_util.Zero``.
  """
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]

  index_nz, *ops_nz = nonzeros
  # The branch index (typechecked as int32) must never carry a tangent.
  assert index_nz is False

  # First pass: which outputs each branch's JVP would produce as nonzero.
  branches_out_nz = [ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=False)[1]
                     for jaxpr in branches]
  # Take the union across branches so all JVP branches share one signature.
  out_nz = [any(nz) for nz in zip(*branches_out_nz)]

  # Second pass: rebuild the JVPs, instantiating tangents for every out_nz output.
  branches_jvp = tuple(ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=out_nz)[0]
                       for jaxpr in branches)

  index, *ops = primals
  _, *ops_dot = tangents
  # Presumably drops the symbolic-zero tangents, which the JVP branches do not take.
  ops_dot = _prune_zeros(ops_dot)

  # Primal operands keep their linearity flags; tangent operands are linear.
  ops_lin = tuple(linear)
  linear_jvp = ops_lin + (True,) * len(ops_dot)
  out = cond_p.bind(
      index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp)
  # Results are all primal outputs followed by tangents for the out_nz outputs.
  out_primals, out_tangents = split_list(out, [len(out_nz)])
  # Re-expand the tangents, filling symbolic zeros where out_nz is False.
  out_tangents_iter = iter(out_tangents)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, out_nz)]
  return out_primals, out_tangents
827
def _cond_partial_eval(trace, *tracers, branches, linear):
  """Partial evaluation rule for ``cond_p``.

  If the branch index is unknown, the whole cond is staged out. Otherwise
  each branch is split into two parts:
  - a known part, evaluated now by binding ``cond_p`` on the known inputs;
  - an unknown part, which consumes the known part's residuals and is staged
    out as a new ``cond_p`` equation on ``trace``.
  Residuals are merged across branches so that all branches of each
  conditional share one type signature.

  Returns:
    A list of ``pe.JaxprTracer`` outputs.
  """
  # A tracer counts as unknown when its partial value carries an aval.
  unknowns = [t.pval[0] is not None for t in tracers]
  index_uk, *ops_uk = unknowns

  # Without omnistaging, partial_eval_jaxpr is also given the trace type.
  if config.omnistaging_enabled:
    partial_eval_jaxpr = pe.partial_eval_jaxpr
  else:
    partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)

  if index_uk:
    # When the branch index is unknown, we stage out the whole cond.
    params = dict(branches=branches, linear=linear)
    return trace.default_process_primitive(cond_p, tracers, params)

  # First pass: which outputs are unknown in some branch.
  branches_out_uks = []
  for branch_jaxpr in branches:
    _, _, out_uks = partial_eval_jaxpr(branch_jaxpr, ops_uk, instantiate=False)
    branches_out_uks.append(out_uks)
  out_uks = [any(uks) for uks in zip(*branches_out_uks)]

  # Second pass: split every branch with the agreed out_uks.
  # - branch_jaxpr_1 (known part): returns one value per original output,
  #   followed by its residuals.
  # - branch_jaxpr_2 (unknown part): consumes those residuals.
  branches_1, branches_2, branch_res_avals = [], [], []
  for branch_jaxpr in branches:
    branch_jaxpr_1, branch_jaxpr_2, _ = partial_eval_jaxpr(
        branch_jaxpr, ops_uk, instantiate=out_uks)
    branch_num_res = len(branch_jaxpr_1.out_avals) - len(out_uks)

    # move residuals to the front
    move = [False] * len(ops_uk) + [True] * branch_num_res
    branch_jaxpr_2 = pe.move_binders_to_front(branch_jaxpr_2, move)

    # TODO(frostig,mattjj): pe.partial_eval_jaxpr should raise to shaped avals
    res_avals = _map(
        raise_to_shaped, branch_jaxpr_2.in_avals[:branch_num_res])

    branches_1.append(branch_jaxpr_1)
    branches_2.append(branch_jaxpr_2)
    branch_res_avals.append(res_avals)

  branches_1 = tuple(branches_1)
  branches_2 = tuple(branches_2)

  for jaxpr in branches_2[1:]:
    assert len(jaxpr.out_avals) == len(branches_2[0].out_avals)

  num_outs = len(branches_2[0].out_avals)

  # Give every branch the same residual signature (see the helpers below).
  all_res_avals, res_avals_per_branch = _merge_branch_residuals(
      branch_res_avals)

  # Known parts: pad residual outputs (with zeros) to the merged signature.
  branches_1 = _join_cond_outputs(
      branches_1, all_res_avals, res_avals_per_branch, num_outs)
  # Staged parts: accept the whole merged residual list as leading inputs.
  branches_2 = _join_cond_pe_staged_jaxpr_inputs(
      branches_2, all_res_avals, res_avals_per_branch)

  # TODO(frostig,mattjj): reinstate this assertion once pe.partial_eval_jaxpr
  # raises to shaped avals
  # for j in branches_1[1:]:
  #   assert j.out_avals == branches_1[0].out_avals
  num_res = len(all_res_avals)

  # Evaluate the known part now, on the known values (index included).
  _, in_consts = unzip2([t.pval for t in tracers])
  out_consts_res = cond_p.bind(*in_consts, branches=branches_1, linear=linear)
  # The residuals are the trailing num_res results.
  out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])

  # TODO(frostig,mattjj): remove raised_to_shaped of avals once
  # pe.partial_eval_jaxpr handles it
  out_avals = _map(raise_to_shaped, branches_2[0].out_avals)
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uks)]

  # Inputs of the staged cond are ordered as: index, residuals, operands.
  # Known operands are replaced by unit literals.
  index_tracer = trace.instantiate_const(tracers[0])

  ops_tracers = [trace.instantiate_const(t) if uk
                 else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns[1:], tracers[1:])]

  res_tracers = _map(trace.new_instantiated_const, res)

  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]

  # Residual inputs are non-linear; operand linearity is unchanged.
  linear_2 = (False,) * num_res + linear
  params = dict(branches=branches_2, linear=linear_2)
  eqn = pe.new_eqn_recipe(
      [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
      source_info_util.current())
  # Make the staged equation the recipe of every output tracer.
  for t in out_tracers: t.recipe = eqn
  return out_tracers
915
916# When partially evaluating conditionals, each branch produces residuals
917# depending on the computation carried out by the branch, and a corresponding
918# staged jaxpr that accepts those residuals as its first few inputs. The
# residual-producing branches are staged as jaxprs and bound right away in a
# conditional. The residual-consuming jaxprs are assembled into a second
# conditional, which is staged out as an equation of the enclosing jaxpr.
# The following helper functions ensure that both collections of
922# jaxprs (those evaluated and those staged) are valid for joint use under their
923# respective conditionals.
924#
925# In particular, the residuals derived from each original branch may have
926# distinct types. Because the branches of conditionals must have identical type
927# signatures, we join residuals together across branches into a common format.
928
929# In order to set up a type signature that all branches can conform to, it would
930# suffice to concatenate all branches' residuals. But concatenation can result
931# in redundant inputs and outputs, and might lead to memory allocation that
932# scales unnecessarily with the branch count. This function finds common
933# residual types across branches for reuse, so as to avoid redundant
934# allocation. It returns a list L of types (avals) representing the collection
935# of residuals merged according to type, and, for each branch, a lookup table to
936# match its residuals to their positions/types in L. Example input/output:
937#
938# [x], [y], [x, x]             -> [x, y, x],    [[0], [1], [0, 2]]
939# [x], [x], [x, x]             -> [x, x],       [[0], [0], [0, 1]]
940# [y, x, x], [x, z, y], [z, x] -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]]
941def _merge_branch_residuals(branch_res_avals):
942  def enumerate_equal(xs):
943    counts = {v: itertools.count() for v in set(xs)}
944    return [(x, next(counts[x])) for x in xs]
945  branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals)
946  all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
947  indices = {v: i for i, v in enumerate(all_tagged_avals)}
948  branch_indices = [
949      [indices[aval] for aval in avals] for avals in branch_res_tagged_avals]
950  all_avals = [x for x, _ in all_tagged_avals]
951  return all_avals, branch_indices
952
953# This function augments branch outputs to agree with the merged residual
954# format: each branch is made to return zero-filled values in the places of
955# residual outputs that it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
                       num_non_res_outputs):
  """Pad each branch's residual outputs out to the merged residual signature.

  Args:
    jaxprs: branch ClosedJaxprs. Each returns ``num_non_res_outputs`` regular
      outputs followed by that branch's residuals.
    all_res_avals: merged residual avals shared by all branches.
    res_aval_indices_per_jaxpr: for each branch, the slot in ``all_res_avals``
      of each of its residuals.
    num_non_res_outputs: number of leading non-residual outputs.

  Returns:
    A tuple with one ClosedJaxpr per branch. Each returns the regular outputs
    followed by one value per entry of ``all_res_avals``. Slots a branch does
    not produce are filled with ``ad_util.zeros_like_aval``.
  """
  def pad_branch(closed_jaxpr, slots):
    @lu.wrap_init
    def padded_fun(*args):
      results = core.jaxpr_as_fun(closed_jaxpr)(*args)
      outs, residuals = split_list(results, [num_non_res_outputs])
      filled = [ad_util.zeros_like_aval(aval) for aval in all_res_avals]
      for slot, value in zip(slots, residuals):
        filled[slot] = value
      return outs + filled

    return _make_closed_jaxpr(padded_fun, closed_jaxpr.in_avals)

  return tuple(_map(pad_branch, jaxprs, res_aval_indices_per_jaxpr))
970
971# This function augments branch inputs to agree with the merged residual format:
972# each branch is made to accept all residuals, even though it will ignore those
973# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
                                      res_aval_indices_per_jaxpr):
  """Widen each staged branch to accept the full merged residual list.

  One fresh variable is created per merged residual aval. For each branch, its
  own leading residual binders are placed into their merged slots. The
  remaining slots get the fresh variables, which the branch never reads.

  Returns:
    A tuple with one ClosedJaxpr per branch. Their invars are the merged
    residuals followed by the original non-residual invars.
  """
  newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
  shared_res_vars = _map(newvar, all_res_avals)

  def widen_inputs(closed_jaxpr, slots):
    inner = closed_jaxpr.jaxpr
    num_res = len(slots)
    own_res_vars = inner.invars[:num_res]
    other_vars = inner.invars[num_res:]

    invars = list(shared_res_vars)
    for slot, var in zip(slots, own_res_vars):
      invars[slot] = var

    widened = core.Jaxpr(inner.constvars, invars + other_vars,
                         inner.outvars, inner.eqns)
    return core.ClosedJaxpr(widened, closed_jaxpr.consts)

  return tuple(_map(widen_inputs, jaxprs, res_aval_indices_per_jaxpr))
992
993def _ordered_unique(xs):
994  d = collections.OrderedDict((x, None) for x in xs)
995  return list(d.keys())
996
def _transpose_cond_jaxpr(jaxpr, num_res):
  """Build the transpose of a single cond branch.

  Args:
    jaxpr: a branch ClosedJaxpr. Its first ``num_res`` inputs are residuals;
      the remaining inputs are treated as undefined (linear) primals.
    num_res: number of leading residual inputs.

  Returns:
    A ClosedJaxpr taking ``(*residuals, *output_cotangents)``. It returns one
    cotangent per non-residual input, with symbolic zeros instantiated.
  """
  res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
  primal_avals = _map(raise_to_shaped, primal_avals)

  @lu.wrap_init
  def transposed(*args):
    res, cts_out = split_list(args, [num_res])
    # Residuals are known values; the linear inputs are undefined primals.
    primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
    cts_in = ad.backward_pass(
        jaxpr.jaxpr, jaxpr.consts, primals, cts_out)
    # Discard residual cotangents; keep only those for the linear inputs.
    _, cts_in = split_list(cts_in, [num_res])
    return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)

  return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
1011
def _cond_transpose(cts, *args, branches, linear):
  """Transpose rule for ``cond_p``.

  This rule assumes the non-linear operands (residuals) precede the linear
  ones: the first ``len(ops) - sum(linear)`` operands are forwarded as
  residuals to each transposed branch.

  Args:
    cts: output cotangents, possibly symbolic zeros.
    *args: the branch index followed by the cond operands.
    branches: tuple of branch ClosedJaxprs.
    linear: linearity flag per non-index operand.

  Returns:
    ``None`` for the index, followed by one entry per operand: its cotangent
    if the operand is linear, ``None`` otherwise.
  """
  index, *ops = args
  in_avals = _map(raise_to_shaped, branches[0].in_avals)
  num_res = len(ops) - sum(linear)

  branches_trans = tuple(
      _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches)
  # Expected cotangent types: the linear inputs' avals without weak typing.
  lin_in_avals = [raise_to_shaped(a, weak_type=False)
                  for a, l in zip(in_avals, linear) if l]
  assert all(core.typematch(out_aval, lin_in_aval)
             for jaxpr in branches_trans
             for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

  res = ops[:num_res]
  # Materialize symbolic-zero output cotangents before binding.
  cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
  linear_trans = (False,) * num_res + (True,) * len(cts)

  out = cond_p.bind(
      index, *res, *cts, branches=branches_trans, linear=linear_trans)
  assert all(_map(core.typecheck, lin_in_avals, out))

  # Scatter the computed cotangents back to the linear operand positions.
  out_iter = iter(out)
  out = [next(out_iter) if l else None for l in linear]
  assert next(out_iter, None) is None
  return [None] + out
1037
1038def _avals_short(avals):
1039  to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))()
1040  return ' '.join(_map(to_str, avals))
1041
def _cond_typecheck(*avals, branches, linear):
  """Typecheck rule for ``cond_p``.

  ``avals`` holds the abstract values of the operands, branch index first.
  The rule checks that:
  - ``branches`` and ``linear`` have the expected parameter types;
  - there is at least one branch;
  - all branches agree in input/output count and type;
  - the operand count matches both the linear flags and the branch inputs;
  - the index is int32;
  - the operand types are compatible with the branch input types.

  Parameter-type failures are reported via ``_typecheck_param``; all other
  failures via ``core.typecheck_assert``.
  """
  tc = partial(_typecheck_param, 'cond')
  tc(branches, 'branches', 'tuple of ClosedJaxpr',
     type(branches) is tuple and
     all(type(x) is core.ClosedJaxpr for x in branches))
  tc(linear, 'linear', 'tuple of bool',
     type(linear) is tuple and all(type(x) is bool for x in linear))

  core.typecheck_assert(
      len(branches) > 0,
      'cond requires at least one branch function')
  core.typecheck_assert(
      len(linear) + 1 == len(avals),
      f'cond given {len(linear)} linear flags for '
      f'{len(avals) - 1} non-predicate operands')

  jaxpr0 = branches[0]
  jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
  jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)

  # Compare every other branch against branch 0; i + 1 is its absolute index.
  for i, jaxpr in enumerate(branches[1:]):
    core.typecheck_assert(
        len(jaxpr0.in_avals) == len(jaxpr.in_avals),
        f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, '
        f'branch {i+1} takes {len(jaxpr.in_avals)}')
    core.typecheck_assert(
        len(jaxpr0.out_avals) == len(jaxpr.out_avals),
        f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
        f'branch {i+1} outputs {len(jaxpr.out_avals)}')
    core.typecheck_assert(
        all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)),
        f'cond branches 0 and {i+1} have mismatching input types: '
        f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
    core.typecheck_assert(
        all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)),
        f'cond branches 0 and {i+1} have mismatching output types: '
        f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')

  core.typecheck_assert(
      len(avals) == 1 + len(jaxpr0.in_avals),
      f'cond called with {len(avals) - 1} non-predicate operands, '
      f'but branches take {len(jaxpr0.in_avals)} inputs')

  index_aval, *op_avals = avals
  core.typecheck_assert(
      index_aval.dtype == np.int32,
      f'cond called with index of type {index_aval.dtype} instead of int32')
  # Operands are checked with core.typecompat, presumably a looser relation
  # than the core.typematch used between branches above.
  core.typecheck_assert(
      all(_map(core.typecompat, jaxpr0.in_avals, op_avals)),
      f'cond branches take input types {jaxpr0_in_avals_str}, '
      f'called with operands of type {_avals_short(op_avals)}')
1093
def cond_bind(*args, branches, linear):
  """Custom bind for ``cond_p`` that validates its inputs before binding.

  Unless ``core.skip_checks`` is set, the operand avals are checked against
  the branches with ``_cond_typecheck``, and every branch jaxpr is checked
  with ``core.check_jaxpr``.
  """
  checks_enabled = not core.skip_checks
  if checks_enabled:
    operand_avals = [core.get_aval(arg) for arg in args]
    _cond_typecheck(*operand_avals, branches=branches, linear=linear)
    for closed_jaxpr in branches:
      core.check_jaxpr(closed_jaxpr.jaxpr)
  return core.Primitive.bind(cond_p, *args, branches=branches, linear=linear)
1101
# The cond primitive. Binds take the branch index followed by the branch
# operands; `branches` and `linear` are parameters.
cond_p = lax.Primitive('cond')
cond_p.multiple_results = True
# Eager evaluation is delegated to xla.apply_primitive.
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_abstract_eval(_cond_abstract_eval)
# Binding goes through cond_bind, which typechecks unless core.skip_checks.
cond_p.def_custom_bind(cond_bind)
# Register the transformation rules, XLA lowering, and typecheck rule.
ad.primitive_jvps[cond_p] = _cond_jvp
ad.primitive_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.initial_style_batchers[cond_p] = _cond_batching_rule
xla.initial_style_translations[cond_p] = _cond_translation_rule
core.custom_typechecks[cond_p] = _cond_typecheck
1113
1114
1115### scan
1116
# Type variables for the `scan` signature: the loop carry, the per-iteration
# input slice, and the per-iteration output slice.
Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')
1120
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
         init: Carry,
         xs: X,
         length: Optional[int] = None,
         reverse: bool = False,
         unroll: int = 1) -> Tuple[Carry, Y]:
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When ``a`` is an array type or None, and ``b`` is an array type, the semantics
  of ``scan`` are given roughly by this Python implementation::

    def scan(f, init, xs, length=None):
      if xs is None:
        xs = [None] * length
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays. (None is actually an empty pytree.)

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
  across all iterations (and not just be consistent up to NumPy rank/shape
  broadcasting and dtype promotion rules, for example). In other words, the type
  ``c`` in the type signature above represents an array with a fixed shape and
  dtype (or a nested tuple/list/dict container data structure with a fixed
  structure and arrays with fixed shape and dtype at the leaves).

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value. This value must have the same structure as
      the first element of the pair returned by ``f``.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.
    length: optional integer specifying the number of loop iterations, which
      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
      be used to perform scans where no input ``xs`` are needed).
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse, equivalent to reversing the leading
      axes of the arrays in both ``xs`` and in ``ys``.
    unroll: optional positive int specifying, in the underlying operation of the
      scan primitive, how many scan iterations to unroll within a single
      iteration of a loop.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.
  """
  # Every leaf of `xs` is scanned over along its leading axis.
  xs_flat, xs_tree = tree_flatten(xs)

  try:
    lengths = [x.shape[0] for x in xs_flat]
  except AttributeError as err:
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(
      msg.format(', '.join(str(x) for x in xs_flat
                           if not hasattr(x, 'shape')))) from err

  # Determine the trip count. An explicit `length` must match every leading
  # axis; otherwise all leading axes must agree, and there must be at least one.
  if length is not None:
    length = int(length)
    if not all(length == l for l in lengths):
      msg = ("scan got `length` argument of {} which disagrees with "
             "leading axis sizes {}.")
      raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
  else:
    unique_lengths = set(lengths)
    if len(unique_lengths) > 1:
      msg = "scan got values with different leading axis sizes: {}."
      raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
    elif len(unique_lengths) == 0:
      msg = "scan got no values to scan over and `length` not provided."
      raise ValueError(msg)
    else:
      length, = unique_lengths

  # With jit disabled, run the loop eagerly in Python and stack the outputs.
  if jax.api._jit_is_disabled():
    carry = init
    ys = []
    maybe_reversed = reversed if reverse else lambda x: x
    for i in maybe_reversed(range(length)):
      xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
      carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
      ys.append(y)
    # Units pass through; arrays are stacked along a new leading axis.
    stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
                            else jax.numpy.stack((y, *ys)))
    stacked_y = tree_multimap(stack, *maybe_reversed(ys))
    return carry, stacked_y

  # Abstract values of one slice of each `xs` leaf (leading axis removed).
  x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
  x_dtypes = [x.dtype for x in xs_flat]
  x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))

  # Trace `f` on abstract carry values (taken from `init`) and x slices.
  # Parameterized by `init` so it can be re-run with a promoted init below.
  def _create_jaxpr(init):
    init_flat, init_tree = tree_flatten(init)
    in_flat, in_tree = tree_flatten((init, xs))

    carry_avals = tuple(_map(_abstractify, init_flat))
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
    # `f` must return a (carry, y) pair.
    out_tree_children = out_tree.children()
    if len(out_tree_children) != 2:
      msg = "scan body output must be a pair, got {}."
      raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
    carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves]
    return init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children

  # The carry input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
  init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out)
  if changed:
    new_init = tree_unflatten(init_tree, new_init_flat)
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init)
  in_flat, jaxpr, consts, out_tree, out_tree_children = rest

  _check_tree_and_avals("scan carry output and input",
                        # Extract the subtree and avals for the first element of the return tuple
                        out_tree_children[0], carry_avals_out,
                        init_tree, carry_avals)

  # Operands: consts first, then the flattened (init, xs). None is marked linear.
  out = scan_p.bind(*itertools.chain(consts, in_flat),
                    reverse=reverse, length=length, jaxpr=jaxpr,
                    num_consts=len(consts), num_carry=len(init_flat),
                    linear=(False,) * (len(consts) + len(in_flat)),
                    unroll=unroll)
  return tree_unflatten(out_tree, out)
1278
def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
                        f_impl, x_avals, y_avals):
  """Run ``length`` scan iterations as straight-line (fully unrolled) code.

  ``args`` are the consts, then the initial carry, then the scanned inputs.
  When ``reverse`` is set, rows are visited from ``length - 1`` down to ``0``.
  Per-step outputs are stacked in row order. ``linear`` is accepted for
  signature compatibility and is unused.

  Returns:
    A tuple of the final carry followed by the stacked outputs.
  """
  consts, carry, xs = split_list(args, [num_consts, num_carry])

  step_rows = range(length - 1, -1, -1) if reverse else range(length)
  per_step_ys = []
  for row in step_rows:
    x_slices = _map(partial(_index_array, row), x_avals, xs)
    carry, y = split_list(f_impl(*consts, *carry, *x_slices), [num_carry])
    per_step_ys.append(y)

  # Restore row order before stacking when we walked the rows backwards.
  if reverse:
    per_step_ys.reverse()
  stacked = _map(_stack, y_avals, list(zip(*per_step_ys)))
  return (*carry, *stacked)
1297
def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
                    f_impl, x_avals, y_avals):
  """Run a scan as a ``while_loop`` over ``length`` iterations.

  The loop state is ``[i, *carry, *ys]``:
  - ``i`` is an iteration counter;
  - ``carry`` is the loop carry;
  - ``ys`` are preallocated output buffers, one row of which is written per
    iteration.

  Under ``reverse``, iteration ``i`` reads and writes row ``length - i - 1``.
  ``linear`` is accepted for signature compatibility and is unused.

  Returns:
    A list of the final carry followed by the output buffers.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])
  ys_init = _map(partial(_empty_array, length), y_avals)

  # A zero-trip scan just returns the initial carry and the empty buffers.
  if length == 0:
    return init + ys_init

  def keep_going(state):
    counter, *_ = state
    return counter < length

  def step(state):
    [counter], carry, ys = split_list(state, [1, num_carry])
    row = length - counter - 1 if reverse else counter
    x_slices = _map(partial(_dynamic_index_array, row), x_avals, xs)
    carry_out, y_rows = split_list(f_impl(*consts, *carry, *x_slices),
                                   [num_carry])
    ys_out = _map(partial(_update_array, row), y_avals, ys, y_rows)
    return [counter + 1] + carry_out + ys_out

  start = [lax._const(length, 0)] + init + ys_init
  _, *outs = while_loop(keep_going, step, start)
  return outs
1322
def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
                              linear, block_length, f_impl, x_avals, y_avals):
  """Run a scan as a loop over blocks of ``block_length`` unrolled iterations.

  ``length`` must be a multiple of ``block_length``. The scanned inputs are
  reshaped to ``(num_blocks, block_length, ...)`` and looped over block by
  block; each block's body is ``_scan_impl_unrolled``. The blocked outputs are
  then flattened back to a leading axis of ``length``.

  Returns:
    A tuple of the final carry followed by the stacked outputs.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])

  num_blocks, leftover = divmod(length, block_length)
  assert leftover == 0

  xs_blocked = [_partition_leading(num_blocks, block_length, aval, x)
                for aval, x in zip(x_avals, xs)]

  x_block_avals = [_prepend_dim_to_aval(block_length, aval) for aval in x_avals]
  y_block_avals = [_prepend_dim_to_aval(block_length, aval) for aval in y_avals]

  run_one_block = partial(
      _scan_impl_unrolled, reverse=reverse, length=block_length,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

  outs = _scan_impl_loop(
      *consts, *init, *xs_blocked, reverse=reverse, length=num_blocks,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=run_one_block, x_avals=x_block_avals, y_avals=y_block_avals)

  carry, ys_blocked = split_list(outs, [num_carry])
  ys = [_combine_leading(num_blocks, block_length, aval, y)
        for aval, y in zip(y_avals, ys_blocked)]
  return (*carry, *ys)
1351
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
               unroll):
  """Implementation rule for ``scan_p``.

  With ``unroll == 1`` the scan runs as a plain loop of ``length`` iterations.
  Otherwise it runs in two phases:
  1. ``length // unroll`` blocks of ``unroll`` iterations are run in a loop
     whose body is unrolled.
  2. The remaining ``length % unroll`` iterations are then run fully unrolled.
     Under ``reverse`` these are the first ``rem`` rows of the inputs (visited
     last in reverse order); otherwise they are the last ``rem`` rows.

  The two output pieces are concatenated back in row order.
  """
  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  _, y_avals = split_list(jaxpr.out_avals, [num_carry])
  f_impl = core.jaxpr_as_fun(jaxpr)

  if unroll == 1:
    return _scan_impl_loop(
        *args, reverse=reverse, length=length, num_consts=num_consts,
        num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
        y_avals=y_avals)

  consts, init, xs = split_list(args, [num_consts, num_carry])
  num_blocks, rem = divmod(length, unroll)
  length_div = num_blocks * unroll

  # Split off the `rem` rows that do not fill a whole block. They sit at the
  # front when reversed (visited last) and at the back otherwise.
  if rem > 0:
    if reverse:
      split = partial(_split_leading_dim, rem)
      xs_rem, xs = unzip2(_map(split, x_avals, xs))
    else:
      split = partial(_split_leading_dim, length_div)
      xs, xs_rem = unzip2(_map(split, x_avals, xs))

  outs = _scan_impl_block_unrolled(
      *consts, *init, *xs, reverse=reverse, length=length_div,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

  carry, ys = split_list(outs, [num_carry])

  # Finish the leftover iterations from the carry the blocked loop produced.
  if rem > 0:
    outs = _scan_impl_unrolled(
        *consts, *carry, *xs_rem, reverse=reverse, length=rem,
        num_consts=num_consts, num_carry=num_carry, linear=linear,
        f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
    carry, ys_rem = split_list(outs, [num_carry])
    if reverse:
      ys = _map(_concatenate, y_avals, ys_rem, ys)
    else:
      ys = _map(_concatenate, y_avals, ys, ys_rem)

  return (*carry, *ys)
1395
def _stack(aval, vals):
  """Stack ``vals`` along a new leading axis; a unit aval yields ``core.unit``."""
  if aval is not core.abstract_unit:
    expanded = [lax.expand_dims(val, (0,)) for val in vals]
    return lax.concatenate(expanded, 0)
  return core.unit
1402
def _concatenate(aval, x1, x2):
  """Concatenate ``x1`` and ``x2`` along axis 0; a unit aval yields ``core.unit``."""
  if aval is not core.abstract_unit:
    return lax.concatenate([x1, x2], 0)
  return core.unit
1408
def _split_leading_dim(i, aval, x):
  """Split ``x`` at row ``i`` of its leading axis into ``(x[:i], x[i:])``.

  A unit aval yields ``(core.unit, core.unit)``.
  """
  if aval is core.abstract_unit:
    return core.unit, core.unit
  assert x.ndim >= 1
  head = lax.slice_in_dim(x, 0, i)
  tail = lax.slice_in_dim(x, i, x.shape[0])
  return head, tail
1416
def _dynamic_index_array(i, aval, x):
  """Take row ``i`` (possibly traced) of ``x``'s leading axis, dropping that axis."""
  return (core.unit if aval is core.abstract_unit
          else lax.dynamic_index_in_dim(x, i, keepdims=False))
1422
def _index_array(i, aval, x):
  """Take static row ``i`` of ``x``'s leading axis, dropping that axis."""
  return (core.unit if aval is core.abstract_unit
          else lax.index_in_dim(x, i, keepdims=False))
1428
def _empty_array(sz, aval):
  """Return a zero-filled buffer of ``sz`` rows, each shaped like ``aval``."""
  if aval is core.abstract_unit:
    return core.unit
  buffer_shape = (sz,) + aval.shape
  return lax.full(buffer_shape, 0, aval.dtype)
1434
def _update_array(i, aval, xs, x):
  """Write ``x`` into row ``i`` of the leading axis of ``xs``."""
  return (core.unit if aval is core.abstract_unit
          else lax.dynamic_update_index_in_dim(xs, x, i, 0))
1440
def _partition_leading(sz0, sz1, aval, x):
  """Reshape a leading axis of size ``sz0 * sz1`` into two axes ``(sz0, sz1)``."""
  if aval is core.abstract_unit:
    return core.unit
  assert x.ndim >= 1
  assert x.shape[0] == sz0 * sz1
  new_shape = (sz0, sz1, *x.shape[1:])
  return lax.reshape(x, new_shape)
1448
def _combine_leading(sz0, sz1, aval, x):
  # Merge the two leading axes (sz0, sz1) of x into a single leading axis.
  if aval is core.abstract_unit:
    return core.unit
  assert x.ndim >= 2
  assert x.shape[0] == sz0
  assert x.shape[1] == sz1
  return lax.collapse(x, 0, 2)
1457
def _prepend_dim_to_aval(sz, aval):
  # Aval of `sz` stacked copies of `aval`; units are returned unchanged and any
  # other non-ShapedArray aval is rejected.
  if aval is core.abstract_unit:
    return aval
  if not isinstance(aval, ShapedArray):
    raise TypeError(f'Prepending dim {sz} to aval {aval}')
  return ShapedArray((sz, *aval.shape), aval.dtype)
1465
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
                        linear, unroll):
  # Output avals of scan: the body's carry avals as-is, followed by each
  # per-iteration output aval with a leading axis of size `length` added
  # (unit avals pass through untouched).
  body_out_avals = list(jaxpr.out_avals)
  carry_avals = body_out_avals[:num_carry]
  stacked_avals = []
  for y_aval in body_out_avals[num_carry:]:
    if y_aval is core.abstract_unit:
      stacked_avals.append(y_aval)
    else:
      stacked_avals.append(ShapedArray((length, *y_aval.shape), y_aval.dtype))
  return carry_avals + stacked_avals
1472
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
              linear, unroll):
  """JVP rule for ``scan_p``.

  Emits a single scan whose body is the JVP of ``jaxpr``. The nonzero
  const/carry/xs tangents are threaded alongside the primals. Tangents that
  are symbolically zero (``ad_util.Zero``) are pruned from the operands, and
  the corresponding outputs are returned as ``ad_util.Zero`` again.

  Returns:
    A pair ``(primals_out, tangents_out)`` with one tangent per primal output.
  """
  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry
  # Which operands have a materialized (not symbolic-zero) tangent.
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])

  # Fixpoint computation of which carry are not ad.zero: either
  # non-zero from init, or the carry out is non-zero. Each iteration promotes
  # at least one carry to non-zero. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_nz.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    nonzeros = const_nz + carry_nz + xs_nz
    jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
        jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
    carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:]
    if carry_nz_out == carry_nz:
      break
    else:
      carry_nz = _map(operator.or_, carry_nz, carry_nz_out)
  else:
    assert False, "Fixpoint not reached"

  # Materialize zeros for every operand now marked nonzero. This covers carries
  # whose initial tangent was a symbolic zero but which the fixpoint promoted.
  # As a result, each nonzero carry has a concrete tangent to feed the loop.
  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  consts, init, xs = split_list(primals, [num_consts, num_carry])
  all_tangents = split_list(tangents, [num_consts, num_carry])
  # Drop the remaining symbolic zeros; only materialized tangents are operands.
  consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)

  # Regroup the JVP jaxpr's binders to match the operand order used in
  # scan_p.bind below (consts, consts_dot, init, init_dot, xs, xs_dot). Its
  # outputs are regrouped to match the split of out_flat (carry, carry_dot,
  # ys, ys_dot).
  jaxpr_jvp_rearranged = ad.rearrange_binders(
      jaxpr_jvp,
      [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)],
      [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)])

  # Tangent operands are linear; primal operands keep their original flags.
  consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
  jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot)
                           + init_linear + [True] * len(init_dot)
                           + xs_linear + [True] * len(xs_dot))

  out_flat = scan_p.bind(
      *(consts + consts_dot + init + init_dot + xs + xs_dot),
      reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
      num_consts=num_consts + len(consts_dot),
      num_carry=num_carry + len(init_dot),
      linear=jaxpr_jvp_linear, unroll=unroll)

  carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
  primals_out = carry + ys
  # Re-insert symbolic zeros for outputs whose tangent is known to be zero.
  tangents_out_iter = iter(carry_dot + ys_dot)
  tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(primals_out, nonzeros_out)]
  return primals_out, tangents_out
1528
1529def _prune_zeros(ts):
1530  return [t for t in ts if type(t) is not ad_util.Zero]
1531
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
                       jaxpr, linear, unroll):
  """Partial-evaluation rule for ``scan_p``.

  Splits one scan into two scans:

  - A "known" scan over the known operands, which is bound immediately.
  - An "unknown" scan over the unknown operands, which is staged out as a new
    equation recipe producing the returned tracers.

  Residuals computed by the known scan feed the unknown scan. Residuals that
  turn out to be loop-invariant are passed to the unknown scan as consts
  rather than stacked per iteration. Extensive inputs that the known scan
  would merely copy to extensive outputs are forwarded directly instead.
  """
  # Without omnistaging, a staging trace stages the whole scan unchanged.
  if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace:  # type: ignore
    params = dict(reverse=reverse, length=length, num_consts=num_consts,
                  num_carry=num_carry, jaxpr=jaxpr, linear=linear,
                  unroll=unroll)
    return trace.default_process_primitive(scan_p, tracers, params)

  num_ys = len(jaxpr.out_avals) - num_carry

  # A tracer is unknown when its PartialVal carries an abstract value.
  unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  if config.omnistaging_enabled:
    partial_eval_jaxpr = pe.partial_eval_jaxpr
  else:
    partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)

  # Fixpoint computation of which carry are unknown (not a constant): either
  # unknown from init, or the carry out is unknown. Each iteration promotes
  # at least one carry to unknown. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_uk.
  carry_uk = init_uk
  for _ in range(1 + len(carry_uk)):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out = out_uk[:num_carry]
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  # jaxpr_1 emits jaxpr_2's outputs (carry, ys) followed by num_res residuals.
  num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)

  # The residuals are treated as extensive outputs of jaxpr_1 (and extensive
  # inputs to jaxpr_2), but residuals that are loop-invariant can be hoisted.
  # TODO(mattjj): hoist other loop-invariant values here too (instantiate=False)
  invariant_pvals = [pe.PartialVal.known(core.unit if uk else t.pval[1])
                     for uk, t in zip(unknowns[:num_consts], tracers[:num_consts])]
  other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]]
  in_pvals_1 = invariant_pvals + other_pvals
  jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
      lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
      instantiate=[True] * (num_carry + num_ys) + [False] * num_res)
  jaxpr_1_opt = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_1_opt), ())
  num_consts_1 = num_consts + len(consts_1)
  # any now-known residuals are intensive, so we want to revise jaxpr_2 to take
  # those inputs as constants rather than as extensive inputs
  _, _, res_pvals = split_list(out_pvals_1, [num_carry, num_ys])
  intensive_residuals = [const for pv, const in res_pvals if pv is None]
  move = [False] * len(jaxpr_1.in_avals) + [pv is None for pv, _ in res_pvals]
  jaxpr_2_opt = pe.move_binders_to_front(jaxpr_2, move)
  num_consts_2 = num_consts + len(intensive_residuals)

  # As another optimization, for any extensive inputs that are just forwarded to
  # extensive outputs, to avoid a copy (looping over dynamic-update-slice) we'd
  # rather just forward the input tracer. That means pruning some extensive
  # outputs from the jaxpr here, and updating out_flat below.
  extensive_invars = jaxpr_1_opt.jaxpr.invars[num_consts_1 + num_carry:]
  extensive_outvars = jaxpr_1_opt.jaxpr.outvars[num_carry:]
  extensive_avals = [core.unmapped_aval(length, 0, core.raise_to_shaped(v.aval))
                     for v in extensive_outvars]
  # fwd_extensive[k] is the index into `tracers` of the xs operand forwarded
  # to the k-th extensive output, or None if that output is computed.
  fwd_extensive = [num_consts + num_carry + extensive_invars.index(v)
                   if v in extensive_invars else None for v in extensive_outvars]
  jaxpr_1_opt.jaxpr.outvars = (
      jaxpr_1_opt.jaxpr.outvars[:num_carry] +
      [v for i, v in zip(fwd_extensive, extensive_outvars) if i is None])

  # Operands of the known scan come in three groups:
  #   1. the hoisted consts_1;
  #   2. unit placeholders for the original const binders (known const values
  #      were given to trace_to_jaxpr above as known PartialVals);
  #   3. the known carry/xs values, with units in place of unknown ones.
  in_consts = (list(consts_1) + [core.unit] * num_consts +
               [core.unit if uk else t.pval[1]
                for uk, t in zip(unknowns[num_consts:], tracers[num_consts:])])
  # Unit placeholders (uk) are marked linear along with originally-linear ones.
  linear_1 = ([False] * len(consts_1) + [True] * num_consts +
              [lin or uk for uk, lin
               in zip(unknowns[num_consts:], linear[num_consts:])])
  out_flat = scan_p.bind(
      *in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt,
      num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1),
      unroll=unroll)

  # Propagate the forwarded extensive outputs using fwd_extensive. Any
  # numpy.ndarray inputs should be converted to JAX DeviceArrays.
  out_carry, out_extensive = split_list(out_flat, [num_carry])
  out_extensive_iter = iter(out_extensive)
  out_extensive = [next(out_extensive_iter) if i is None
                   else _maybe_device_put(tracers[i].pval[1]) if tracers[i].is_known()
                   else tracers[i] for i in fwd_extensive]
  assert all(a == core.raise_to_shaped(core.get_aval(out))
             for a, out in zip(extensive_avals, out_extensive))
  out_flat = out_carry + out_extensive

  # out_flat again mirrors jaxpr_1's full output list: carry, ys, residuals.
  out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys])
  extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None]

  # Operands of the staged (unknown) scan: unknown tracers are instantiated,
  # known ones become unit literals.
  new_tracers = [trace.instantiate_const(t) if uk else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  out_consts = out_carry + ys
  int_res_tracers = _map(trace.new_instantiated_const, intensive_residuals)
  ext_res_tracers = _map(trace.new_instantiated_const, extensive_residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]
  # Residual operands are nonlinear; unit placeholders for known operands are
  # marked linear along with originally-linear ones.
  linear_2 = ([False] * len(int_res_tracers) +
              [lin or not uk for uk, lin in zip(unknowns, linear)] +
              [False] * len(ext_res_tracers))
  eqn = pe.new_eqn_recipe(int_res_tracers + new_tracers + ext_res_tracers,
                          out_tracers, scan_p,
                          dict(reverse=reverse, length=length, jaxpr=jaxpr_2_opt,
                               num_consts=num_consts_2,
                               num_carry=num_carry, linear=tuple(linear_2),
                               unroll=unroll),
                          source_info_util.current())
  for t in out_tracers: t.recipe = eqn
  return out_tracers
1652
1653def _maybe_device_put(x):
1654  if isinstance(x, np.ndarray):
1655    return lax._device_put_raw(x)
1656  else:
1657    return x
1658
def _promote_aval_rank(sz, aval):
  # Add a leading axis of size `sz` to a shaped aval; units pass through.
  return (core.abstract_unit if aval is core.abstract_unit
          else ShapedArray((sz, *aval.shape), aval.dtype))
1664
def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
                    linear, unroll):
  """Transpose rule for ``scan_p``.

  Only scans with the following operand layout are supported:

  - consts laid out as ``[ires, linear consts]``;
  - xs laid out as ``[linear xs, eres]``.

  Here ``ires`` and ``eres`` are non-linear residual operands, and they must
  be known (not undefined primals). The transpose runs as a scan in the
  opposite direction. Its carry holds the accumulated const cotangents
  followed by the carry cotangents.

  Returns:
    One entry per operand: ``None`` for residuals, otherwise the const,
    initial-carry and xs cotangents.

  Raises:
    NotImplementedError: if ``linear`` does not follow the layout above.
  """
  # we've only implemented transposing scans with specific lin/nonlin patterns
  consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
  num_ires = len(consts_lin) - sum(consts_lin)
  num_eres = len(xs_lin) - sum(xs_lin)
  if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires):
    raise NotImplementedError
  if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
    raise NotImplementedError
  if not all(init_lin):
    pass  # TODO(mattjj): error check https://github.com/google/jax/issues/1963

  consts, _, xs = split_list(args, [num_consts, num_carry])
  ires, _ = split_list(consts, [num_ires])
  _, eres = split_list(xs, [sum(xs_lin)])
  assert not any(ad.is_undefined_primal(r) for r in ires)
  assert not any(ad.is_undefined_primal(r) for r in eres)

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  ct_carry, ct_ys = split_list(cts, [num_carry])
  # Materialize symbolic-zero cotangents; they become scan carry/xs operands.
  ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
  ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
  # Const cotangents start at zero and are accumulated in the transposed carry.
  ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts])

  #       jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
  # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
  jaxpr_trans = _transpose_scan_jaxpr(
      num_ires, num_consts - num_ires, num_eres, jaxpr)
  linear_trans = ([False] * num_ires +
                  [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
                  [False] * num_eres)

  outs = scan_p.bind(
      *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
      length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
      num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans),
      unroll=unroll)
  ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
  return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
1706
# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
#                         -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
  """Build the body of the transposed scan.

  The inputs of ``jaxpr`` are grouped as follows:

  - ``res1`` and ``res2``: non-linear residual inputs;
  - ``c``: linear consts, whose cotangents are accumulated;
  - ``a``: linear xs.

  ``b`` stands for all outputs of ``jaxpr``. The returned closed jaxpr runs
  ``ad.backward_pass`` on ``jaxpr`` with cotangents ``b_bar``. It adds the
  resulting ``c`` cotangents to the incoming accumulator ``c_bar`` and
  returns ``c_bar + a_bar``.
  """
  num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
  res1_avals, c_avals, a_avals, res2_avals = split_list(
      jaxpr.in_avals, [num_res1, num_c, num_a])
  num_b = len(jaxpr.out_avals)
  b_avals = list(jaxpr.out_avals)

  @lu.wrap_init
  def transposed(*res1_cbar_bbar_res2):
    res1, c_bar, b_bar, res2 = split_list(
        res1_cbar_bbar_res2, [num_res1, num_c, num_b])
    # Linear inputs are UndefinedPrimals so backward_pass computes their cts.
    primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
               [ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
    cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar)
    _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
    # Materialize symbolic zeros so every output has a concrete value.
    a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
    c_bar = _map(ad.instantiate_zeros_aval, c_avals,
                _map(ad.add_tangents, c_bar, new_c_bar))
    return c_bar + a_bar
  return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)
1729
def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
  """Trace ``traceable`` on abstract inputs and return a ``ClosedJaxpr``.

  Args:
    traceable: the wrapped function to trace.
    in_avals: abstract values of its positional inputs.
  Returns:
    A ``core.ClosedJaxpr`` pairing the traced jaxpr with its captured consts.
  """
  if config.omnistaging_enabled:
    # The output avals are not needed by callers, only the jaxpr and consts.
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
  else:
    # Pre-omnistaging path: every input is an unknown PartialVal and
    # instantiate=True forces every output into the traced jaxpr.
    pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
  return core.ClosedJaxpr(jaxpr, consts)
1738
1739
def _scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_consts,
                        num_carry, linear, unroll):
  """Batching (vmap) rule for ``scan_p``.

  Batched consts and carries have their batch dimension moved to axis 0.
  Batched xs have it moved to axis 1, because axis 0 of xs is the scanned
  axis. Some initial carries are unbatched but must be batched according to
  the fixpoint below; those are broadcast to ``size`` along axis 0.

  Returns:
    ``(outs, out_dims)``, where batched carries are at dim 0 and batched ys
    are at dim 1.
  """
  num_ys = len(jaxpr.out_avals) - num_carry
  # All mapped operands must share one batch size; the unpack enforces that.
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry])

  # Fixpoint computation of which carry are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_batched.
  carry_batched = init_batched
  for _ in range(1 + len(carry_batched)):
    batched = const_batched + carry_batched + xs_batched
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, size, batched,
        instantiate=carry_batched + [False] * num_ys,
        axis_name=axis_name)
    carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
    if carry_batched_out == carry_batched:
      break
    else:
      carry_batched = _map(operator.or_, carry_batched, carry_batched_out)
  else:
    assert False, "Fixpoint not reached"

  consts, init, xs = split_list(args, [num_consts, num_carry])
  consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, consts_bdims)]
  # Carries promoted to batched by the fixpoint are broadcast; already-batched
  # ones are moved so the batch dim is at 0.
  new_init = [batching.broadcast(x, size, 0) if now_batched and not was_batched
              else batching.moveaxis(x, d, 0) if now_batched else x
              for x, d, was_batched, now_batched in
              zip(init, init_bdims, init_batched, carry_batched)]
  new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1
            else x for x, d in zip(xs, xs_bdims)]
  new_args = new_consts + new_init + new_xs

  outs = scan_p.bind(
      *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
      num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll)
  carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
  ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
  return outs, carry_bdims + ys_bdims
1785
def _scan_masking_rule(padded_vals, logical_shapes, reverse, length,
                       jaxpr, num_consts, num_carry, linear, unroll):
  # Masking rule for scan. The scan iterates over the full padded length, with
  # two extra operands: the logical length, prepended as a const, and an
  # iteration counter, prepended as a carry. The masked body keeps the old
  # carry once the counter reaches the logical length.
  (dynamic_length,) = masking.shape_as_value((length,))
  masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
  consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
  (max_length,) = {x.shape[0] for x in xs}
  const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
  operands = [dynamic_length, *consts, 0, *init, *xs]
  masked_linear = (False, *const_linear, False, *init_linear, *xs_linear)
  out_vals = scan_p.bind(
      *operands, reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
      num_consts=num_consts + 1, num_carry=num_carry + 1,
      linear=masked_linear, unroll=unroll)
  # Drop the iteration counter, which is the first output.
  return out_vals[1:]
1800
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
  # Wrap a scan body so that it takes two extra operands: a leading const
  # (the dynamic length) and a leading carry (an iteration counter). The new
  # carry is selected only while counter < dynamic length; afterwards the old
  # carry is kept.
  body = core.jaxpr_as_fun(jaxpr)

  @lu.wrap_init
  def masked(*args):
    [dynamic_length], consts, [i], carry, xs = split_list(
        args, [1, num_consts, 1, num_carry])
    new_carry, ys = split_list(body(*consts, *carry, *xs), [num_carry])
    kept_carry = [lax.select(i < dynamic_length, fresh, old)
                  for fresh, old in zip(new_carry, carry)]
    return [i + 1, *kept_carry, *ys]

  index_aval = ShapedArray((), dtypes.int_)
  const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  in_avals = [index_aval, *const_avals, index_aval, *carry_avals, *x_avals]
  return _make_closed_jaxpr(masked, in_avals)
1817
def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
                    jaxpr, linear, unroll):
  """Typecheck the parameters and operand avals of a scan application.

  Args:
    bind_time: True when called from ``scan_bind``, in which case ``length``
      may also be a ``masking.Poly``. False when used as the registered
      jaxpr typecheck rule, where only an int length is accepted.
    *avals: abstract values of the operands, ordered consts, carry, xs.
    reverse, length, num_consts, num_carry, jaxpr, linear, unroll: the scan
      parameters being checked.

  Failures are reported through ``_typecheck_param`` (malformed parameters)
  and ``core.typecheck_assert`` (operand/jaxpr type mismatches).
  """
  tc = partial(_typecheck_param, 'scan')
  tc(reverse, 'reverse', 'bool', type(reverse) is bool)
  tc(num_consts, 'num_consts', 'non-negative int',
     type(num_consts) is int and num_consts >= 0)
  tc(num_carry, 'num_carry', 'non-negative int',
     type(num_carry) is int and num_carry >= 0)
  tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr)
  tc(linear, 'linear', 'tuple of bool',
     type(linear) is tuple and all(type(x) is bool for x in linear))
  tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0)

  length_types = (int, masking.Poly) if bind_time else (int,)
  tc(length, 'length', 'non-negative int',
     type(length) in length_types and length >= 0)

  core.typecheck_assert(
      len(linear) == len(avals),
      f'scan param linear has length {len(linear)} for {len(avals)} operands')

  const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry])
  const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list(
      jaxpr.in_avals, [num_consts, num_carry])
  carry_avals_jaxpr, _ = split_list(jaxpr.out_avals, [num_carry])
  # Per-iteration view of each xs aval (presumably the aval with the leading
  # scanned axis of size `length` removed), compared to the body's xs inputs.
  x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)

  # The body's carry output types must match its carry input types exactly.
  core.typecheck_assert(
      all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)),
      f'scan input carry input and output types mismatch: '
      f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}')
  core.typecheck_assert(
      all(_map(core.typecompat, const_avals_jaxpr, const_avals)),
      f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n'
      f'called with consts of type\n{_avals_short(const_avals)}')
  core.typecheck_assert(
      all(_map(core.typecompat, init_avals_jaxpr, init_avals)),
      f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n'
      f'called with initial carry of type\n{_avals_short(init_avals)}')
  core.typecheck_assert(
      all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)),
      f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
      f'called with sequence of type\n{_avals_short(x_avals)}')
1861
def scan_bind(*args, **params):
  # Custom bind for scan_p. Unless core.skip_checks is set, the operand avals
  # and the body jaxpr are validated before the primitive is bound.
  if core.skip_checks:
    return core.Primitive.bind(scan_p, *args, **params)
  operand_avals = [core.get_aval(arg) for arg in args]
  _scan_typecheck(True, *operand_avals, **params)
  core.check_jaxpr(params['jaxpr'].jaxpr)
  return core.Primitive.bind(scan_p, *args, **params)
1868
# The scan primitive and its transformation rules.
scan_p = core.Primitive("scan")
scan_p.multiple_results = True
# Binding goes through scan_bind, which typechecks the operands unless
# core.skip_checks is set.
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(_scan_impl)
# scan_p.def_impl(partial(xla.apply_primitive, scan_p))  # TODO(mattjj): re-enable
scan_p.def_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.primitive_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
# The XLA translation is derived from _scan_impl via lower_fun_initial_style.
xla.initial_style_translations[scan_p] = xla.lower_fun_initial_style(_scan_impl)
batching.initial_style_batchers[scan_p] = _scan_batching_rule
masking.masking_rules[scan_p] = _scan_masking_rule
# Jaxpr typechecking reuses _scan_typecheck with bind_time=False, so only an
# int length is accepted there.
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
1882
1883
def map(f, xs):
  """Map a function over leading array axes.

  This works like Python's builtin ``map``, except that the inputs and
  outputs are stacked arrays rather than iterables. Prefer the ``jax.vmap``
  transform in most cases. Use this function instead when you need to apply
  ``f`` one element at a time, either to reduce memory usage or to mix in
  heterogeneous computation with other control flow primitives.

  For an array ``xs``, ``map`` means the same as this Python code::

    def map(f, xs):
      return np.stack([f(x) for x in xs])

  ``map`` is built on ``scan``, so it shares scan's advantages over a Python
  loop. ``xs`` may be an arbitrary nested pytree, and the body is compiled
  only once.

  Args:
    f: Python function applied to each slice of ``xs`` taken along the
      leading axis (or axes).
    xs: values to map over along the leading axis.

  Returns:
    The stacked results of applying ``f`` to every slice.
  """
  def step(_, x):
    # No state is threaded between iterations; only f(x) is collected.
    return (), f(x)
  _, stacked = scan(step, (), xs)
  return stacked
1913
1914
def _concat_masking_rule(padded_vals, logical_shapes, dimension):
  # Concatenate the padded operands, then copy each operand's logical prefix
  # into place so the valid entries end up contiguous along `dimension`.
  result = lax.concatenate(padded_vals, dimension)  # fragmented
  offset = 0
  for val, shape in zip(padded_vals, logical_shapes):
    size = shape[dimension]
    result = _memcpy(dimension, size, val, result, offset)
    offset = offset + size
  return result
1923
def _memcpy(axis, num, src, dst, offset):
  # Copy slices 0..num-1 of `src` along `axis` into `dst`, at positions
  # offset..offset+num-1, using a fori_loop.
  def copy_one(i, acc):
    piece = lax.dynamic_index_in_dim(src, i, axis)
    return lax.dynamic_update_index_in_dim(acc, piece, i + offset, axis)
  return fori_loop(0, num, copy_one, dst)
1929
# Register the masking rule for lax.concatenate.
masking.masking_rules[lax.concatenate_p] = _concat_masking_rule  # type: ignore
1931
1932
def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
  """Raise TypeError unless (tree1, avals1) and (tree2, avals2) agree.

  ``tree1`` and ``tree2`` are pytree definitions. ``avals1`` and ``avals2``
  are their flattened leaf avals, with one entry per tree leaf. ``what``
  describes the values being compared and prefixes the error message.
  """
  if tree1 != tree2:
    raise TypeError(
        f"{what} must have same type structure, got {tree1} and {tree2}.")
  types_match = all(_map(core.typematch, avals1, avals2))
  if types_match:
    return
  lhs = tree_unflatten(tree1, avals1)
  rhs = tree_unflatten(tree2, avals2)
  raise TypeError(f"{what} must have identical types, got\n{lhs}\nand\n{rhs}.")
1948
1949
1950def _check_tree(func_name, expected_name, actual_tree, expected_tree):
1951  if actual_tree != expected_tree:
1952    raise TypeError(
1953        f"{func_name}() output pytree structure must match {expected_name}, "
1954        f"got {actual_tree} and {expected_tree}.")
1955
1956
1957def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
1958  """Promote weakly-typed in_vals to be compatible with out_avals.
1959
1960  Args:
1961    in_vals : flattened list of input values.
1962    in_avals : corresponding list of avals.
1963    out_avals : list of target output avals.
1964  Returns:
1965    in_vals_new : flattened list of modified in_vals with no weak types.
1966    changed : bool; true if in_vals required modification.
1967  """
1968  if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals):
1969    # Calling function is responsible for catching this.
1970    return in_vals, False
1971  weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals))
1972                    if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)]
1973  if not weak_mismatches:
1974    return in_vals, False
1975  for i in weak_mismatches:
1976    new_dtype = dtypes.result_type(in_vals[i], out_avals[i])
1977    in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype)
1978  return in_vals, True
1979
def _stop_gradient_fun(f):
  """Return a wrapper around ``f`` through which no gradients flow.

  The wrapper traces ``f`` to a jaxpr on the abstract values of its
  arguments. It then applies ``lax.stop_gradient`` to every captured const
  and every flattened argument, and evaluates the jaxpr on those values.
  """
  def wrapper(*args, **kwargs):
    flat_args, args_tree = tree_flatten((args, kwargs))
    flat_avals = tuple(_abstractify(a) for a in flat_args)
    call_f = lambda a, b: f(*a, **b)
    jaxpr, consts, out_tree = _initial_style_jaxpr(call_f, args_tree, flat_avals)
    stopped = [lax.stop_gradient(v) for v in (*consts, *flat_args)]
    outs = core.jaxpr_as_fun(jaxpr)(*stopped)
    return tree_unflatten(out_tree, outs)
  return wrapper
1991
1992
1993_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')
1994
1995
def _split_root_args(args, const_lengths):
  # Split the flat argument list into per-jaxpr const groups (as a _RootTuple)
  # followed by the remaining operands.
  *const_groups, rest = split_list(args, list(const_lengths))
  return _RootTuple(*const_groups), rest
1999
2000
def custom_root(f, initial_guess, solve, tangent_solve):
  """Differentiably solve for the roots of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of custom_root() are defined with respect to closed-over variables
  from the provided function ``f`` via the implicit function theorem:
  https://en.wikipedia.org/wiki/Implicit_function_theorem

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that f(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.

  Raises:
    TypeError: if ``f``, ``solve`` or ``tangent_solve`` returns a pytree whose
      structure differs from that of ``initial_guess``.
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
      f, in_args_tree, guess_avals)

  in_tree, = treedef_children(in_args_tree)
  _check_tree("f", "initial_guess", out_tree, in_tree)

  # solve only sees a gradient-stopped f; derivatives come from the custom JVP
  # registered on _custom_root instead.
  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      partial(solve, _stop_gradient_fun(f)), in_args_tree, guess_avals)
  _check_tree("solve", "initial_guess", solution_tree, in_tree)

  def linearize_and_solve(x, b):
    # Linearize f at x and solve the tangent system g(dx) = b.
    unchecked_zeros, f_jvp = jax.linearize(f, x)
    return tangent_solve(f_jvp, b)

  l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
      linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
  _check_tree("tangent_solve", "x", out_tree, in_tree)

  all_consts = [f_consts, solve_consts, l_and_s_consts]
  const_lengths = _RootTuple(*_map(len, all_consts))
  jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)

  out_flat = _custom_root(
      const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat))
  return tree_unflatten(out_tree, out_flat)
2064
2065
@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def _custom_root(const_lengths, jaxprs, *args):
  # Primal rule: evaluate the user's solve jaxpr on its closed-over consts
  # followed by the flattened initial guess.
  params, initial_guess = _split_root_args(args, const_lengths)
  solve_fun = core.jaxpr_as_fun(jaxprs.solve)
  return solve_fun(*params.solve, *initial_guess)
2071
2072
@_custom_root.defjvp
def _root_jvp(const_lengths, jaxprs, primals, tangents):
  """Custom JVP of ``_custom_root`` via the implicit function theorem.

  The primal solution is recomputed with ``_custom_root``. Only the tangents
  of ``f``'s closed-over consts (``params_dot.f``) are used; tangents of the
  initial guess and of the other jaxprs' consts are ignored. These tangents
  are pushed through ``f`` with the solution held fixed. The resulting
  right-hand side is solved with the linearize-and-solve (``l_and_s``)
  jaxpr, and the result is negated.
  """
  params, _ = _split_root_args(primals, const_lengths)
  solution = _custom_root(const_lengths, jaxprs, *primals)

  params_dot, _ = _split_root_args(tangents, const_lengths)

  # F(m, u) = 0      # system of equations in u, parameterized by m
  #                  # solution is u*(m) defined in a neighborhood
  # F(m, u*(m)) = 0  # satisfied in a neighborhood
  #
  # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
  # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
  #
  # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

  f = core.jaxpr_as_fun(jaxprs.f)
  linearize_and_solve = partial(
      core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
  # f as a function of its consts only, with u fixed at the solution.
  f_at_solution = lambda *params: f(*itertools.chain(params, solution))
  # rhs = ∂_0 F(m, u*(m))[v]
  _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
      params.f, params_dot.f)
  # Apply -(∂_1 F)^{-1} by solving the linearized system at the solution.
  solution_dot = _map(
      operator.neg, linearize_and_solve(*itertools.chain(solution, rhs)))

  return solution, solution_dot
2099
2100
2101class _LinearSolveTuple(collections.namedtuple(
2102    '_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')):
2103
2104  def transpose(self):
2105    return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve)
2106
2107
def _split_linear_solve_args(args, const_lengths):
  """Splits flat operands into per-jaxpr constant groups and the remainder.

  Returns a ``_LinearSolveTuple`` of groups sized by ``const_lengths`` and the
  list of operands that follow all constants.
  """
  *const_groups, remainder = split_list(args, list(const_lengths))
  return _LinearSolveTuple(*const_groups), remainder
2111
2112
2113def _transpose_one_output(linear_fun, primals):
2114  transpose_fun = jax.linear_transpose(linear_fun, primals)
2115  def transposed_fun(x):
2116    (y,) = transpose_fun(x)
2117    return y
2118  return transposed_fun
2119
2120
2121def _flatten(args):
2122  return [x for arg in args for x in arg]
2123
2124
def _check_shapes(func_name, expected_name, actual, expected):
  """Raises ``ValueError`` unless the leaf shapes of two pytrees agree.

  Leaves are compared pairwise, in flattening order, via ``np.shape``.
  """
  actual_shapes = [np.shape(leaf) for leaf in tree_leaves(actual)]
  expected_shapes = [np.shape(leaf) for leaf in tree_leaves(expected)]
  if actual_shapes == expected_shapes:
    return
  raise ValueError(
      f"{func_name}() output shapes must match {expected_name}, "
      f"got {actual_shapes} and {expected_shapes}")
2132
2133
def custom_linear_solve(
    matvec, b, solve, transpose_solve=None, symmetric=False):
  """Perform a matrix-free linear solve with implicitly defined gradients.

  This function allows for overriding or defining gradients for a linear
  solve directly via implicit differentiation at the solution, rather than by
  differentiating *through* the solve operation. This can sometimes be much faster
  or more numerically stable, or differentiating through the solve operation
  may not even be implemented (e.g., if ``solve`` uses ``lax.while_loop``).

  Required invariant::

      x = solve(matvec, b)  # solve the linear equation
      assert matvec(x) == b  # not checked

  Args:
    matvec: linear function to invert. Must be differentiable.
    b: constant right hand side of the equation. May be any nested structure
      of arrays.
    solve: higher level function that solves for solution to the linear
      equation, i.e., ``matvec(solve(matvec, x)) == x`` for all ``x`` of the
      same form as ``b``. This function need not be differentiable.
    transpose_solve: higher level function for solving the transpose linear
      equation, i.e., ``vecmat(transpose_solve(vecmat, x)) == x``, where
      ``vecmat`` is the transpose of the linear map ``matvec`` (computed
      automatically with autodiff). Required for backwards mode automatic
      differentiation, unless ``symmetric=True``, in which case ``solve``
      provides the default value.
    symmetric: bool indicating if it is safe to assume the linear map
      corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.

  Returns:
    Result of ``solve(matvec, b)``, with gradients defined assuming that the
    solution ``x`` satisfies the linear equation ``matvec(x) == b``.

  Raises:
    ValueError: if ``matvec``, ``solve`` or ``transpose_solve`` return leaves
      whose shapes differ from the leaves of ``b``. ``_check_tree`` also
      rejects outputs whose tree structure differs from ``b``'s.
  """
  if transpose_solve is None and symmetric:
    transpose_solve = solve

  b_flat, in_args_tree = tree_flatten((b,))
  b_avals = tuple(_map(_abstractify, b_flat))

  tree, = treedef_children(in_args_tree)

  def _shape_checked(fun, name):
    # Wraps `fun` so that, at trace time, its output leaf shapes are checked
    # against the leaves of `b`.
    def f(x):
      y = fun(x)
      _check_shapes(name, "b", y, b_flat)
      return y
    return f

  matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
      _shape_checked(matvec, "matvec"), in_args_tree, b_avals)
  _check_tree("matvec", "b", out_tree, tree)

  solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
      _shape_checked(partial(solve, matvec), "solve"), in_args_tree, b_avals)
  _check_tree("solve", "b", out_tree, tree)

  if transpose_solve is None:
    # Without a transpose solver the primitive cannot be transposed; the
    # transpose rule raises if reverse-mode AD is attempted.
    vecmat_jaxpr = tr_solve_jaxpr = None
    vecmat_consts = tr_solve_consts = []
  else:
    if symmetric:
      vecmat = matvec
      vecmat_jaxpr = matvec_jaxpr
      vecmat_consts = matvec_consts
    else:
      vecmat = _transpose_one_output(matvec, b)
      vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
          vecmat, in_args_tree, b_avals)
      # vecmat comes from transposing matvec, so its output tree is expected
      # to match b's; this is an internal invariant, not user input checking.
      assert out_tree == tree

    tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
        _shape_checked(partial(transpose_solve, vecmat), "transpose_solve"),
        in_args_tree, b_avals)
    _check_tree("transpose_solve", "b", out_tree, tree)

  # Constants of all four jaxprs are passed as leading operands of the
  # primitive, followed by the leaves of b; const_lengths records the split.
  all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
  const_lengths = _LinearSolveTuple(*_map(len, all_consts))
  jaxprs = _LinearSolveTuple(
      matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)

  out_flat = linear_solve_p.bind(
      *(_flatten(all_consts) + b_flat),
      const_lengths=const_lengths, jaxprs=jaxprs)
  return tree_unflatten(tree, out_flat)
2219
2220
def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
  """Abstract eval: the outputs take the shaped avals of the ``b`` operands."""
  num_consts = sum(const_lengths)
  b_avals = args[num_consts:]
  return [raise_to_shaped(aval) for aval in b_avals]
2223
2224
def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
  """Impl rule: evaluates the ``solve`` jaxpr on its constants and ``b``."""
  consts, b = _split_linear_solve_args(args, const_lengths)
  solve_fn = core.jaxpr_as_fun(jaxprs.solve)
  return solve_fn(*(consts.solve + b))
2229
2230
def _tangent_linear_map(func, params, params_dot, *x):
  """Compute the tangent of a linear map.

  Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
  this function computes ``∂A @ x``: a JVP in which only ``params`` carry
  tangents and ``x`` is held fixed (symbolic-zero tangents).
  """
  # Callers skip this when every params tangent is a symbolic zero.
  assert not all(type(p) is ad_util.Zero for p in params_dot)
  x_tangents = [ad_util.Zero.from_value(v) for v in x]
  primals_in = params + list(x)
  tangents_in = params_dot + x_tangents
  _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
      primals_in, tangents_in)
  return out_tangent
2242
2243
def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
  """JVP rule for ``linear_solve_p``.

  From A x - b = 0 we get ∂A x + A ∂x - ∂b = 0, hence
  ∂x = A^{-1} (∂b - ∂A x). The inverse is applied by binding the same
  primitive again on the tangent right-hand side.
  """
  bind_params = dict(const_lengths=const_lengths, jaxprs=jaxprs)
  x = linear_solve_p.bind(*primals, **bind_params)

  params, _ = _split_linear_solve_args(primals, const_lengths)
  params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)

  matvec_has_tangent = any(
      type(p) is not ad_util.Zero for p in params_dot.matvec)
  if matvec_has_tangent:
    # ∂A x: tangent of matvec with respect to its constants at the solution.
    matvec_tangents = _tangent_linear_map(
        core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x)
    rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))
  else:
    # A is constant, so ∂A x vanishes and the right-hand side is just ∂b.
    rhs = b_dot

  x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **bind_params)
  return x, x_dot
2266
2267
def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
  """Transpose rule for ``linear_solve_p`` with respect to ``b``.

  The cotangent of ``b`` is obtained by binding the primitive on the
  transposed system (``matvec``/``vecmat`` and ``solve``/``transpose_solve``
  swapped). Constants receive ``None`` cotangents.

  Raises:
    TypeError: if no ``transpose_solve`` jaxpr is available.
  """
  if jaxprs.transpose_solve is None:
    raise TypeError('transpose_solve required for backwards mode automatic '
                    'differentiation of custom_linear_solve')

  params, b = _split_linear_solve_args(primals, const_lengths)
  # Only b is treated as linear input: every b operand must be undefined.
  assert all(ad.is_undefined_primal(x) for x in b)
  transposed_params = _flatten(params.transpose())
  cotangent_b = linear_solve_p.bind(
      *(transposed_params + cotangent),
      const_lengths=const_lengths.transpose(),
      jaxprs=jaxprs.transpose())
  const_cotangents = [None] * sum(const_lengths)
  return const_cotangents + cotangent_b
2279
2280
def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs):
  """Batching rule for ``linear_solve_p``.

  Runs a fixpoint iteration to decide which leaves of ``x`` (the solution)
  and ``b`` (the right-hand side) must be batched so that all four jaxprs
  agree. It then batches the jaxprs, moves operand batch axes to 0, and
  re-binds the primitive.

  Args:
    args: flat operands: the constants of all jaxprs followed by b's leaves.
    dims: batch dimension of each operand, or ``batching.not_mapped``.
    axis_name: forwarded to ``batching.batch_jaxpr``.
    const_lengths: ``_LinearSolveTuple`` of constant counts per jaxpr.
    jaxprs: ``_LinearSolveTuple`` of closed jaxprs; ``vecmat`` and
      ``transpose_solve`` may be ``None``.

  Returns:
    ``(outs, out_dims)``; every batched output has its batch axis at 0.
  """
  orig_bat = [d is not batching.not_mapped for d in dims]
  size, = {
      a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped
  }

  params, b = _split_linear_solve_args(args, const_lengths)
  params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
  params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

  (matvec, vecmat, solve, solve_t) = jaxprs
  (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

  # Fixpoint computation of which parts of x and b are batched; we need to
  # ensure this is consistent between all four jaxprs
  b_bat = orig_b_bat
  x_bat = [False] * len(solve.out_avals)
  for _ in range(1 + len(orig_b_bat) + len(solve.out_avals)):
    # Apply vecmat and solve -> new batched parts of x
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
        solve, size, solve_bat + b_bat, instantiate=x_bat, axis_name=axis_name)
    if vecmat is None:
      vecmat_jaxpr_batched = None
      x_bat_out = solve_x_bat
    else:
      vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
          vecmat, size, vecmat_bat + b_bat, instantiate=x_bat, axis_name=axis_name)
      x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat)
    # Apply matvec and solve_t -> new batched parts of b
    matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
        matvec, size, matvec_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name)
    if solve_t is None:
      solve_t_jaxpr_batched = None
      b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
    else:
      solve_t_jaxpr_batched, solve_t_b_bat = batching.batch_jaxpr(
          solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name)
      b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
                      orig_b_bat)
    # The primitive's abstract eval gives its outputs (x) the avals of the b
    # operands, so each leaf of x must be batched exactly when the matching
    # leaf of b is. Batch both whenever either needs it.
    xb_bat_out = _map(operator.or_, x_bat_out, b_bat_out)
    if xb_bat_out == x_bat and xb_bat_out == b_bat:
      break
    x_bat = b_bat = xb_bat_out
  else:
    assert False, "Fixedpoint not reached"

  batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched,
                                     solve_jaxpr_batched, solve_t_jaxpr_batched)

  # Move batched axes to the front
  new_params = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(_flatten(params), _flatten(params_dims))
  ]
  # Broadcast out b if necessary
  new_b = [
      batching.broadcast(x, size, 0) if now_bat and not was_bat else
      batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
      for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
  ]

  outs = linear_solve_p.bind(
      *(new_params + new_b),
      const_lengths=const_lengths,
      jaxprs=batched_jaxprs)
  # The outputs are the leaves of x (equal to b_bat after the merge above).
  out_dims = [0 if batched else batching.not_mapped for batched in x_bat]
  return outs, out_dims
2350
2351
# Primitive behind `custom_linear_solve`. Its operands are the constants of
# the four jaxprs followed by the leaves of b, and it returns multiple
# results (the leaves of x).
linear_solve_p = core.Primitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
# XLA lowering traces the impl rule, which evaluates the `solve` jaxpr.
xla.initial_style_translations[linear_solve_p] = \
    xla.lower_fun_initial_style(_custom_linear_solve_impl)
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.initial_style_batchers[linear_solve_p] = _linear_solve_batching_rule
2361
2362
2363def _interleave(a, b, axis):
2364  """Given two Tensors of static shape, interleave them along the first axis."""
2365  assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1
2366  a_pad = [(0, 0, 0)] * a.ndim
2367  b_pad = [(0, 0, 0)] * b.ndim
2368  a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1)
2369  b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1)
2370  return lax.add(lax.pad(a, lax._const(a, 0), a_pad),
2371                 lax.pad(b, lax._const(b, 0), b_pad))
2372
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
  """Performs a scan with an associative binary operation, in parallel.

  For an introduction to associative scans, see [BLE1990]_.

  Args:
    fn: A Python callable implementing an associative binary operation with
      signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it
      must satisfy the equation
      ``fn(a, fn(b, c)) == fn(fn(a, b), c)``.

      The inputs and result are (possibly nested Python tree structures of)
      array(s) matching ``elems``. Each array has a dimension in place
      of the ``axis`` dimension. `fn` should be applied elementwise over
      the ``axis`` dimension (for example, by using :func:`jax.vmap` over the
      elementwise function.)

      The result ``r`` has the same shape (and structure) as the two inputs
      ``a`` and ``b``.
    elems: A (possibly nested Python tree structure of) array(s), each with
      an ``axis`` dimension of size ``num_elems``.
    reverse: A boolean stating if the scan should be reversed with respect to
      the ``axis`` dimension.
    axis: an integer identifying the axis over which the scan should occur.
      Negative values count from the last dimension.

  Returns:
    A (possibly nested Python tree structure of) array(s) of the same shape
    and structure as ``elems``, in which the ``k``'th element of ``axis`` is the
    result of recursively applying ``fn`` to combine the first ``k`` elements
    of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``,
    the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.

  Raises:
    ValueError: if the leaves of ``elems`` do not all have the same size along
      ``axis``.

  Example 1: partial sums of an array of numbers:

  >>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
  [ 0, 1, 3, 6]

  Example 2: partial products of an array of matrices

  >>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
  >>> partial_prods = lax.associative_scan(jnp.matmul, mats)
  >>> partial_prods.shape
  (4, 2, 2)

  Example 3: reversed partial sums of an array of numbers

  >>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
  [ 6, 6, 5, 3]

  .. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
    Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
    University.
  """
  elems_flat, tree = tree_flatten(elems)

  # Canonicalize `axis` before its first use so that `lax.rev` below receives
  # a non-negative dimension (it rejects negative ones), making
  # `reverse=True` work with e.g. `axis=-1`.
  axis = lax._canonicalize_axis(axis, elems_flat[0].ndim)

  if reverse:
    elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat]

  def combine(a_flat, b_flat):
    # Lower `fn` to operate on flattened sequences of elems.
    a = tree_unflatten(tree, a_flat)
    b = tree_unflatten(tree, b_flat)
    c = fn(a, b)
    c_flat, _ = tree_flatten(c)
    return c_flat

  # Check that all inputs have a consistent dimension `num_elems` along `axis`.
  num_elems = int(elems_flat[0].shape[axis])
  if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
    raise ValueError('Array inputs to associative_scan must have the same '
                     'first dimension. (saw: {})'
                     .format([elem.shape for elem in elems_flat]))


  # Summary of algorithm:
  #
  # Consider elements of `_scan(elems)` at odd indices. That's the same as first
  # summing successive pairs of elements of `elems` and performing a scan on
  # that half sized tensor. We perform the latter scan by recursion.
  #
  # Now consider the even elements of `_scan(elems)`. These can be computed
  # from the odd elements of `_scan(elems)` by adding each odd element of
  # `_scan(elems)` to the matching even element in the original `elems`.
  #
  # We return the odd and even elements interleaved.
  #
  # For the base case of the recursion we return the first element
  # of `elems` followed by the sum of the first two elements computed as
  # a (small two-down-to-one) reduction step.
  def _scan(elems):
    """Perform scan on `elems`."""

    num_elems = elems[0].shape[axis]

    if num_elems < 2:
      return elems

    # Combine adjacent pairs of elements.
    reduced_elems = combine(
      [lax.slice_in_dim(elem, 0, -1, stride=2, axis=axis) for elem in elems],
      [lax.slice_in_dim(elem, 1, None, stride=2, axis=axis) for elem in elems])

    # Recursively compute scan for partially reduced tensors.
    odd_elems = _scan(reduced_elems)

    if num_elems % 2 == 0:
      even_elems = combine(
        [lax.slice_in_dim(e, 0, -1, axis=axis) for e in odd_elems],
        [lax.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])
    else:
      even_elems = combine(
        odd_elems,
        [lax.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])

    # The first element of a scan is the same as the first element
    # of the original `elems`.
    even_elems = [
      lax.concatenate([lax.slice_in_dim(elem, 0, 1, axis=axis), result],
                      dimension=axis)
      for (elem, result) in zip(elems, even_elems)]
    return list(_map(partial(_interleave, axis=axis), even_elems, odd_elems))

  scans = _scan(elems_flat)

  if reverse:
    scans = [lax.rev(scanned, [axis]) for scanned in scans]

  return tree_unflatten(tree, scans)
2502
2503
2504# Cumulative reductions.
2505
def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative sum along `axis`.

  With ``reverse=True`` the sum is accumulated from the end of ``axis``.
  """
  params = dict(axis=int(axis), reverse=bool(reverse))
  return cumsum_p.bind(operand, **params)
2509
def cumprod(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative product along `axis`.

  With ``reverse=True`` the product is accumulated from the end of ``axis``.
  """
  axis_index, backwards = int(axis), bool(reverse)
  return cumprod_p.bind(operand, axis=axis_index, reverse=backwards)
2513
def cummax(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative maximum along `axis`.

  With ``reverse=True`` the maximum is accumulated from the end of ``axis``.
  """
  axis_index, backwards = int(axis), bool(reverse)
  return cummax_p.bind(operand, axis=axis_index, reverse=backwards)
2517
def cummin(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative minimum along `axis`.

  With ``reverse=True`` the minimum is accumulated from the end of ``axis``.
  """
  params = dict(axis=int(axis), reverse=bool(reverse))
  return cummin_p.bind(operand, **params)
2521
2522def _cumred_shape_rule(x, *, axis: int, reverse: bool):
2523  if axis < 0 or axis >= x.ndim:
2524    raise ValueError(
2525        "axis {} is out of bounds for array of shape {}".format(axis, x.shape))
2526  return x.shape
2527
def _cumsum_transpose_rule(t, operand, *, axis: int, reverse: bool):
  """Transpose of cumsum: a cumulative sum of ``t`` in the opposite direction."""
  flipped = not reverse
  return [cumsum(t, axis=axis, reverse=flipped)]
2530
2531
2532
2533def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
2534                                 axis: int, reverse: bool):
2535  # On TPU, an implementation using reduce_window is handled specially by the
2536  # compiler and is efficient. On other backends, it is O(n^2).
2537  n = x.shape[axis]
2538  if n == 0:
2539    return x
2540  padding = [(0, 0)] * x.ndim
2541  padding[axis] = (0, n - 1) if reverse else (n - 1, 0)
2542  strides = [1] * x.ndim
2543  window_dims = [1] * x.ndim
2544  window_dims[axis] = n
2545  return window_reduce(x, window_dims, strides, padding)
2546
2547def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int,
2548                       reverse: bool):
2549  operand, = batched_args
2550  bdim, = batch_dims
2551  axis = axis if axis < bdim else axis + 1
2552  return prim.bind(operand, axis=axis, reverse=reverse), bdim
2553
2554def _cumred_dtype_rule(name, operand, *args, **kw):
2555  if not dtypes.issubdtype(operand.dtype, np.number):
2556    raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes "
2557                    "of number.".format(name, np.dtype(operand.dtype).name))
2558  return dtypes.canonicalize_dtype(operand.dtype)
2559
# cumsum primitive. It is linear, so AD is defined through deflinear2 with
# `_cumsum_transpose_rule`. On TPU it lowers to a reduce_window sum; its
# batching rule is the shared `_cumred_batch_rule`.
cumsum_p = lax.standard_primitive(
  _cumred_shape_rule, partial(_cumred_dtype_rule, "cumsum"),
  'cumsum')
ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
  partial(_cumred_tpu_translation_rule, lax._reduce_window_sum),
  multiple_results=False)
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
2568
2569
def _cumulative_reduction_primitive(name, reduce_window_fn):
  """Creates and registers a cumulative-reduction primitive called ``name``.

  The primitive gets the shared cumulative shape, dtype and batching rules,
  and a TPU lowering that uses ``reduce_window_fn``. It does not register a
  JVP or a non-TPU lowering.
  """
  prim = lax.standard_primitive(
    _cumred_shape_rule, partial(_cumred_dtype_rule, name), name)
  tpu_rule = partial(_cumred_tpu_translation_rule, reduce_window_fn)
  xla.backend_specific_translations['tpu'][prim] = xla.lower_fun(
    tpu_rule, multiple_results=False)
  batching.primitive_batchers[prim] = partial(_cumred_batch_rule, prim)
  return prim
2579
2580
# The non-linear cumulative reductions share one construction; on TPU each
# lowers to the matching reduce_window op.
cumprod_p = _cumulative_reduction_primitive("cumprod", lax._reduce_window_prod)
cummax_p = _cumulative_reduction_primitive("cummax", lax._reduce_window_max)
cummin_p = _cumulative_reduction_primitive("cummin", lax._reduce_window_min)

# Default (non-TPU) lowering for all four cumulative reductions: the parallel
# prefix scan `associative_scan` with the corresponding binary op.
xla.translations[cumsum_p] = xla.lower_fun(
  partial(associative_scan, lax.add), multiple_results=False)
xla.translations[cumprod_p] = xla.lower_fun(
  partial(associative_scan, lax.mul), multiple_results=False)
xla.translations[cummin_p] = xla.lower_fun(
  partial(associative_scan, lax.min), multiple_results=False)
xla.translations[cummax_p] = xla.lower_fun(
  partial(associative_scan, lax.max), multiple_results=False)
2593
def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  """JVP rule for cumulative reductions, via ``associative_scan``.

  Whatever the backend, the derivative is taken through the parallel prefix
  scan, because reduce_window is not arbitrarily differentiable.
  """
  scan_fn = partial(associative_scan, combine_fn, axis=axis, reverse=reverse)
  return api.jvp(scan_fn, primals, tangents)
2602
# JVPs of the non-linear cumulative reductions differentiate through
# `associative_scan` with the matching combine function. cumsum is linear
# and gets its AD rules from deflinear2 instead.
ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min)
ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)
2606
2607
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Rebinds the initial-style tracing helpers for non-omnistaging mode.

  Registered with ``config.register_omnistaging_disabler`` (presumably invoked
  when omnistaging is turned off). It replaces the module-level
  ``_initial_style_open_jaxpr``, ``_initial_style_jaxpr`` and
  ``_initial_style_jaxprs_with_common_consts`` with versions that trace with
  ``pe.trace_to_jaxpr`` under ``core.initial_style_staging()``.
  """
  global _initial_style_open_jaxpr, _initial_style_jaxpr, \
      _initial_style_jaxprs_with_common_consts

  @cache()
  def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
    """Traces ``fun`` with all inputs unknown.

    Returns the open jaxpr, its output partial values, its constants, and the
    ``out_tree`` thunk produced by ``flatten_fun_nokwargs``.
    """
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    with core.initial_style_staging():  # type: ignore
      jaxpr, out_pvals, consts = pe.trace_to_jaxpr(  # type: ignore
        wrapped_fun, in_pvals, instantiate=True, stage_out=False)  # type: ignore
    return jaxpr, out_pvals, consts, out_tree

  @cache()
  def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
    """Like ``_initial_style_open_jaxpr`` but returns a closed jaxpr.

    Constvars are turned into ordinary inputs (``pe.convert_constvars_jaxpr``),
    so the constants are returned separately alongside the output treedef.
    """
    jaxpr, out_pvals, consts, out_tree = _initial_style_open_jaxpr(
        fun, in_tree, in_avals)
    closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
    return closed_jaxpr, consts, out_tree()

  @cache()
  def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
                                              in_tree, in_avals):
    # When staging the branches of a conditional into jaxprs, constants are
    # extracted from each branch and converted to jaxpr arguments. To use the
    # staged jaxprs as the branches to a conditional *primitive*, we need for
    # their (input) signatures to match. This function "joins" the staged jaxprs:
    # for each one, it makes another that accepts *all* constants, but only uses
    # those that it needs (dropping the rest).

    jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([
        _initial_style_open_jaxpr(fun, in_tree, in_avals) for fun in funs])

    # Fresh (unused) variables standing in for the other branches' constants.
    newvar = core.gensym(jaxprs, suffix='_')
    all_const_avals = tuple(
        tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
        for consts in all_consts)
    unused_const_vars = tuple(
        tuple(newvar(aval) for aval in const_avals)
        for const_avals in all_const_avals)

    def pad_jaxpr_constvars(i, jaxpr):
      # Branch i keeps its own constvars in position i; every other branch's
      # constants are represented by unused placeholder vars.
      prefix = util.concatenate(unused_const_vars[:i])
      suffix = util.concatenate(unused_const_vars[i+1:])
      constvars = prefix + jaxpr.constvars + suffix
      return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
                        outvars=jaxpr.outvars, eqns=jaxpr.eqns)

    def type_and_const_convert_jaxpr(jaxpr, out_pvals):
      # `out_pvals` is accepted but not used here.
      return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

    jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
    closed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals)

    return (tuple(closed_jaxprs),
            tuple(util.concatenate(all_consts)),
            tuple(out_tree() for out_tree in all_out_trees))
2666