15import itertools as it
16from collections import namedtuple
17import contextlib
18import functools
19from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
20                    List, Union, cast, Type, no_type_check)
21from weakref import ref
23import numpy as np
25from .. import core
26from .. import dtypes
27from .. import linear_util as lu
28from ..ad_util import Zero
29from .._src.util import (unzip2, safe_zip, safe_map, toposort, partial,
30                         split_list, cache, as_hashable_function)
31from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
32                    unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
33                    dropvar, ConcreteArray, raise_to_shaped)
34from jax._src import source_info_util
35from ..config import config
# Within this module, `map` and `zip` refer to the `safe_map` / `safe_zip`
# helpers imported from `.._src.util`, not to the Python builtins.
map = safe_map
zip = safe_zip
def identity(x):
  """Return the argument `x` unchanged."""
  return x
class PartialVal(tuple):
  """Partial value: either a known value or an unknown (abstract) value.

  Represented as a pair `(aval_opt, const)` of one of two kinds:
  * `(None, <Constant>)` indicates a known value, either a Python regular
    value, or a Tracer.
  * `(<AbstractValue>, *)` indicates an unknown value characterized by an
    abstract value. The second component is then a unit: `unknown` stores
    `core.unit` there, and `__new__` checks that its aval is
    `core.abstract_unit`.
  """
  def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
    pv, const = xs
    if not core.skip_checks:
      # type checks
      assert isinstance(pv, (AbstractValue, type(None))), xs
      assert isinstance(const, core.Tracer) or type(const) is Zero or core.valid_jaxtype(const), xs
      # invariant checks: an unknown value carries a unit as its constant.
      if isinstance(pv, AbstractValue):
        assert get_aval(const) == core.abstract_unit, xs
    return tuple.__new__(cls, xs)

  @classmethod
  def known(cls, const: core.Value) -> 'PartialVal':
    """Build a partial value whose value `const` is known."""
    # Use `cls` so that subclasses get instances of their own type back.
    return cls((None, const))

  @classmethod
  def unknown(cls, aval: AbstractValue) -> 'PartialVal':
    """Build a partial value that is unknown and described only by `aval`."""
    return cls((aval, core.unit))

  def is_known(self) -> bool:
    """True when this partial value holds a known constant."""
    return self[0] is None

  def get_known(self) -> Optional[core.Value]:
    """Get the known value, if known, else None."""
    return self[1] if self[0] is None else None

  def get_aval(self) -> AbstractValue:
    """Get AbstractValue directly (if unknown) or from the constant (known)."""
    known = self.get_known()
    if known is not None:
      return get_aval(known)
    else:
      return self[0]

  def merge_with_known(self, val: core.Value) -> core.Value:
    """Either the stored known value, or the given 'val'."""
    known = self.get_known()
    return known if known is not None else val
class JaxprTrace(Trace):
  """Trace that partially evaluates a computation.

  Every value is a JaxprTracer holding a PartialVal. A primitive whose inputs
  are all known is evaluated right away (see `default_process_primitive`).
  Otherwise its outputs are unknown tracers whose `recipe` records the
  operation, and `tracers_to_jaxpr` later assembles those recipes into a
  Jaxpr.
  """
  def pure(self, val) -> 'JaxprTracer':
    # Values entering this trace are treated as known constants.
    return self.new_const(val)

  def lift(self, val) -> 'JaxprTracer':
    return self.new_const(val)

  def sublift(self, val) -> 'JaxprTracer':
    # Keep the tracer's partial value. The FreeVar recipe makes it an
    # environment input when `tracers_to_jaxpr` builds the jaxpr.
    return JaxprTracer(self, val.pval, FreeVar(val))

  def new_const(self, val) -> 'JaxprTracer':
    """Wrap `val` as a known tracer of this trace (recipe: `unit`)."""
    # Guard: a tracer of this very level must never be wrapped as a constant.
    if isinstance(val, Tracer) and val._trace.level == self.level:
      raise Exception
    return JaxprTracer(self, PartialVal.known(val), unit)

  def new_instantiated_literal(self, val) -> 'JaxprTracer':
    """Unknown tracer for `val` that appears as a Literal in the jaxpr."""
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), Literal(val))

  def new_instantiated_const(self, val) -> 'JaxprTracer':
    """Unknown tracer for `val` that becomes a constvar of the jaxpr."""
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), ConstVar(val))

  def new_arg(self, pval: PartialVal) -> 'JaxprTracer':
    """Tracer for a function input.

    Known inputs become known constants; unknown inputs become lambda-bound
    inputs of the jaxpr.
    """
    const = pval.get_known()
    if const is None:
      return JaxprTracer(self, pval, LambdaBinding())
    else:
      return self.new_const(const)

  def instantiate_const(self, tracer) -> Tracer:
    """Turn a known tracer into an unknown tracer backed by its constant.

    Unknown tracers are returned unchanged. A constant whose type is in
    `core.literalable_types` and whose shape is `()` is staged as a Literal.
    Any other constant is staged as a constvar.
    """
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      if type(const) in core.literalable_types and np.shape(const) == ():
        return self.new_instantiated_literal(const)
      else:
        return self.new_instantiated_const(const)

  def instantiate_const_abstracted(self, tracer) -> 'JaxprTracer':
    """Like `instantiate_const`, but always stages a constvar and raises the
    constant's aval to the shaped level with `raise_to_shaped`."""
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      # NOTE(review): the second argument presumably marks Python scalars as
      # weakly typed; confirm against core.raise_to_shaped.
      aval = raise_to_shaped(get_aval(const), np.isscalar(const))
      return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))

  def process_primitive(self, primitive, tracers, params):
    # Primitives with a registered custom rule bypass the default handling.
    if primitive in custom_partial_eval_rules:
      return custom_partial_eval_rules[primitive](self, *tracers, **params)
    else:
      return self.default_process_primitive(primitive, tracers, params)

  def default_process_primitive(self, primitive, tracers, params):
    """By default, if all the input tracers are known, then execute the primitive
    and all the outputs are known. Otherwise, all the outputs are unknown."""
    consts = [t.pval.get_known() for t in tracers]
    if all(c is not None for c in consts):
      return primitive.bind(*consts, **params)
    # Some input is unknown: stage the primitive, with known inputs turned into
    # literals/constvars so they can be referenced from the equation.
    tracers = map(self.instantiate_const, tracers)
    avals = [t.aval for t in tracers]
    out_aval = primitive.abstract_eval(*avals, **params)
    source = source_info_util.current()
    if primitive.multiple_results:
      out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                     for aval in out_aval]
      # One recipe shared by all outputs; it holds them through weak refs.
      eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, source)
      for t in out_tracers: t.recipe = eqn
      return out_tracers
    else:
      out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
      out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
                                         params, source)
      return out_tracer

  # We use process_call to handle both call and map primitives.
  def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
    """Partially evaluate a call (or map) primitive applied to `f`.

    The known part of `f` is evaluated by binding `primitive` on the known
    inputs. The unknown part is staged as a new `primitive` equation. Its
    `call_jaxpr` takes, in order, the residual constants, the environment
    tracers and the unknown inputs. Known and unknown output tracers are
    recombined with `_zip_knowns` according to `out_unknowns`.
    """
    if not config.omnistaging_enabled:
      if (self.main.trace_type is StagingJaxprTrace  # type: ignore
          and primitive in staged_out_calls):        # type: ignore
        tracers = map(self.instantiate_const_abstracted, tracers)

    if primitive in call_partial_eval_rules:
      return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)

    in_pvals = [t.pval for t in tracers]
    if primitive.map_primitive:
      # Unknown mapped inputs are seen by the body at their per-element aval.
      mapped_aval = partial(core.mapped_aval, params['axis_size'])
      in_pvals = [pval if pval.is_known() or in_axis is None
                  else PartialVal.unknown(mapped_aval(in_axis, pval[0]))
                  for pval, in_axis in zip(in_pvals, params['in_axes'])]

      def app(f, *args):
        f, num_outputs = count_outputs(f)
        out_axes_thunk = params['out_axes_thunk']
        # Extra outputs beyond the declared out_axes (the residual consts)
        # are mapped along axis 0.
        @as_hashable_function(closure=out_axes_thunk)
        def new_out_axes_thunk():
          out_axes = out_axes_thunk()
          return out_axes + (0,) * (num_outputs() - len(out_axes))
        pe_params = dict(params, out_axes_thunk=new_out_axes_thunk)
        return primitive.bind(f, *args, **pe_params)
    else:
      app = partial(primitive.bind, **params)
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
        f, in_pvals, app, instantiate=False)
    if primitive.map_primitive:
      # Unknown outputs of a map get their avals unmapped along out_axis.
      unmapped_aval = partial(core.unmapped_aval, params['axis_size'])
      out_axes = params['out_axes_thunk']()
      out_pvals = [pval if pval.is_known() else
                   PartialVal.unknown(unmapped_aval(out_axis, pval[0])) if out_axis is not None else
                   PartialVal.unknown(pval[0])
                   for pval, out_axis in zip(out_pvals, out_axes)]

    # Skip known invars and outvars, and lift constants as regular invars
    in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers))
    out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
    jaxpr = _drop_invars(jaxpr, in_knowns)
    jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True)

    # Known tracers get propagated as if they were constants
    known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
                         if pval.is_known()]

    # Unknown tracers need to have the jaxpr set up as their recipe
    unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
                           if not pval.is_known()]
    unknown_tracers_in = [t for t in tracers if not t.pval.is_known()]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *unknown_tracers_in)

    # Set up new params
    new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    if primitive.map_primitive:
      in_axes = params['in_axes']
      # NOTE: const_tracers are added as map outputs, and we always map them
      #       along axis 0 (see `new_out_axes_thunk` above).
      new_in_axes = ((0,) * len(const_tracers) +
                     (None,) * len(env_tracers) +
                     tuple(axis for axis, t in zip(in_axes, tracers)
                           if not t.pval.is_known()))
      new_out_axes = tuple(axis for axis, pval in zip(out_axes, out_pvals)
                           if not pval.is_known())
      new_params = dict(new_params, in_axes=new_in_axes, out_axes=new_out_axes)
      del new_params['out_axes_thunk']
    update_params = call_param_updaters.get(primitive)
    if update_params:
      new_params = update_params(new_params, [not t.pval.is_known() for t in tracers])

    eqn = new_eqn_recipe(in_tracers, unknown_tracers_out, primitive, new_params,
                         source_info_util.current())
    for t in unknown_tracers_out: t.recipe = eqn
    return _zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)

  # Map primitives share the call handling (see the map_primitive branches).
  process_map = process_call

  # We use post_process_call to handle both call and map primitives.
  def post_process_call(self, primitive, out_tracers, params):
    """Handle JaxprTracers of this trace that are outputs of a call's body.

    Returns `(out, todo)`. `out` is the outputs' constant parts followed by the
    jaxpr's constants. `todo` rebuilds output tracers in a fresh JaxprTrace
    whose recipe is a new `primitive` equation over the constants and the
    environment. For map primitives, `todo` is paired with an out-axes
    transform that maps the extra constant outputs along axis 0.
    """
    jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
    out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
    out = out_pv_consts + consts
    nconsts = len(consts)
    del consts, out_pv_consts
    main = self.main

    if primitive.map_primitive:
      out_axes = params['out_axes_thunk']()
      sz = params['axis_size']
      out_pvs = [None if pv is None else core.unmapped_aval(sz, ax, pv)
                 for pv, ax in zip(out_pvs, out_axes)]

    def todo(x):
      # `x` holds the values of `out`: per-output constants, then the consts.
      n = len(jaxpr.outvars)
      out_pv_consts, consts = x[:n], x[n:]
      trace = JaxprTrace(main, core.cur_sublevel())
      const_tracers = map(trace.new_instantiated_const, consts)
      out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
                     for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
      in_tracers = (*const_tracers, *map(trace.full_raise, env))

      new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
      if primitive.map_primitive:
        # NOTE: We've assigned axis 0 to const tracers below, in out_axes_transform.
        new_in_axes = (0,) * len(const_tracers) + (None,) * len(env)
        new_params = dict(new_params, in_axes=new_in_axes, out_axes=out_axes)
        del new_params['out_axes_thunk']
      update_params = call_param_updaters.get(primitive)
      if update_params:
        new_params = update_params(new_params, [])

      eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
                           source_info_util.current())
      for t in out_tracers:
        t.recipe = eqn
      return out_tracers

    if primitive.map_primitive:
      def out_axes_transform(out_axes):
        return out_axes + (0,) * nconsts
      todo = (todo, out_axes_transform)

    return out, todo

  post_process_map = post_process_call

  def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
                   app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]],
                   instantiate: bool):
    """Partially evaluate f on a sequence of PartialVals.

    `app` is applied to the wrapped function and the constant parts of
    `pvals` (e.g. a primitive's `bind`). Returns
    `(jaxpr, out_pvals, consts, env_tracers)`: the staged jaxpr, the output
    PartialVals, the residual constants for the jaxpr's constvars, and the
    environment values raised into this trace.
    """
    in_avals, in_consts = unzip2(pvals)
    f = trace_to_subjaxpr(f, self.main, instantiate)
    f, aux = partial_eval_wrapper(f, tuple(in_avals))
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
    # app's flat outputs are the output constants followed by the residuals.
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvs = map(PartialVal, zip(out_avals, out_consts))
    env_tracers = map(self.full_raise, env)
    return jaxpr, out_pvs, consts, env_tracers

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Stage a custom_jvp call with every input treated as unknown.

    The resulting equation uses `prim.initial_style` with the traced
    `fun_jaxpr` and a memoized thunk that traces the `jvp` rule on demand.
    """
    tracers = map(self.instantiate_const_abstracted, tracers)
    in_avals, in_consts = unzip2(t.pval for t in tracers)  # in_consts are units
    fun = trace_to_subjaxpr(fun, self.main, True)
    fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
    out_flat = prim.bind(fun, jvp, *in_consts)
    out_avals, jaxpr, env = aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvals = map(PartialVal, zip(out_avals, out_consts))  # out_consts are units
    env_tracers = map(self.full_raise, env)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *tracers)
    closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

    @_memoize
    def jvp_jaxpr_thunk():
      # The jvp rule takes primals and tangents, hence the doubled avals.
      jvp_ = trace_to_subjaxpr(jvp, self.main, True)
      jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2)
      out_flat = jvp_.call_wrapped(*(in_consts * 2))  # in_consts are units
      out_avals, jaxpr, env = aux()
      _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
      converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
      return converted_jaxpr, (*consts, *env)

    eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
                         dict(fun_jaxpr=closed_jaxpr,
                              jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                              num_consts=len(consts) + len(env)),
                         source_info_util.current())
    for t in out_tracers: t.recipe = eqn
    return out_tracers

  def post_process_custom_jvp_call(self, out_tracers, params):
    # This path should only be reachable if we expose a partial eval API
    # unrelated to autodiff, since we raise an error when differentiation with
    # respect to values over which a custom_jvp function closes is detected.
    raise NotImplementedError  # TODO(mattjj)

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    """Stage a custom_vjp call with every input treated as unknown.

    The resulting equation uses `prim.initial_style` with the traced
    `fun_jaxpr`, a memoized thunk that traces `fwd` on demand, and `bwd` and
    `out_trees` passed through as params.
    """
    tracers = map(self.instantiate_const_abstracted, tracers)
    in_avals, in_consts = unzip2(t.pval for t in tracers)  # in_consts are units
    fun = trace_to_subjaxpr(fun, self.main, True)
    fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
    out_flat = prim.bind(fun, fwd, bwd, *in_consts, out_trees=out_trees)
    out_avals, jaxpr, env = aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvals = map(PartialVal, zip(out_avals, out_consts))  # out_consts are units
    env_tracers = map(self.full_raise, env)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *tracers)
    closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

    @_memoize
    def fwd_jaxpr_thunk():
      fwd_ = trace_to_subjaxpr(fwd, self.main, True)
      fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals))
      out_flat = fwd_.call_wrapped(*in_consts)  # in_consts are units
      out_avals, jaxpr, env = aux()
      _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
      converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
      return converted_jaxpr, (*consts, *env)

    eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
                         dict(fun_jaxpr=closed_jaxpr,
                              fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                              num_consts=len(consts) + len(env),
                              bwd=bwd, out_trees=out_trees),
                         source_info_util.current())
    for t in out_tracers: t.recipe = eqn
    return out_tracers

  def post_process_custom_vjp_call(self, out_tracers, params):
    # This path should only be reachable if we expose a partial eval API
    # unrelated to autodiff, since we raise an error when differentiation with
    # respect to values over which a custom_vjp function closes is detected.
    raise NotImplementedError  # TODO(mattjj)
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
  """Generator that packs flat (aval, const) inputs into PartialVals.

  Callers use it as an auxiliary-output transformation:
  `f, aux = partial_eval_wrapper(f, tuple(in_avals))`, then call `f` with the
  constant parts.

  Protocol:
  * It first passes the list of PartialVals as the single positional argument.
  * It then receives `jaxpr, (out_pvals, consts, env)` from the wrapped
    function.
  * It finally yields the output constants followed by the residual constants,
    with `(out_pvs, jaxpr, env)` as auxiliary data.
  """
  in_pvals = [PartialVal(pair) for pair in zip(pvs, consts)]
  jaxpr, (out_pvals, residual_consts, env) = yield (in_pvals,), {}
  out_pvs, out_consts = unzip2(out_pvals)
  flat_out = (*out_consts, *residual_consts)
  yield flat_out, (out_pvs, jaxpr, env)
def count_outputs(*args, **kwargs):
  """Generator that passes arguments through and reports the output count.

  Used as an auxiliary-output transformation: `f, num_outputs =
  count_outputs(f)`. The results are forwarded unchanged, and their length is
  the auxiliary value.
  """
  results = yield args, kwargs
  num_results = len(results)
  yield results, num_results
# Per-primitive partial-eval rules. JaxprTrace.process_primitive calls these
# (with the trace, the tracers and the params) instead of
# default_process_primitive.
custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
# Per-primitive rules for call/map primitives, consulted first by
# JaxprTrace.process_call.
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
# Per-primitive hooks that rewrite the params of the staged-out call equation
# built in JaxprTrace.process_call and post_process_call.
call_param_updaters: Dict[core.Primitive, Callable] = {}
def abstract_eval_fun(fun, *avals, **params):
  """Trace `fun` on the abstract inputs `avals` and return its output avals.

  `params` are forwarded to `lu.wrap_init` together with `fun`.
  """
  wrapped = lu.wrap_init(fun, params)
  if not config.omnistaging_enabled:
    unknown_pvals = [PartialVal.unknown(aval) for aval in avals]
    _, out_pvals, _ = trace_to_jaxpr(wrapped, unknown_pvals,
                                     instantiate=True, stage_out=True)  # type: ignore
    avals_out, _ = unzip2(out_pvals)
  else:
    _, avals_out, _ = trace_to_jaxpr_dynamic(wrapped, avals)
  # Every output must be abstract (the legacy path uses instantiate=True).
  assert all(isinstance(aval_out, AbstractValue) for aval_out in avals_out)
  return avals_out
# The kinds of recipe a JaxprTracer may carry, as dispatched on in
# tracers_to_jaxpr:
# - an equation;
# - a lambda-bound input;
# - a free (environment) variable;
# - a constvar;
# - a literal;
# - `unit`, which JaxprTrace.new_const uses for known constants.
JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar',
                          'ConstVar', Literal, core.Unit]
class JaxprTracer(Tracer):
  """Tracer of a JaxprTrace: a PartialVal plus the recipe that produces it.

  `recipe` may be `None` while the tracer is being set up; the staging code
  assigns the equation recipe right after constructing the outputs.
  """
  __slots__ = ['pval', 'recipe']

  def __init__(self, trace: JaxprTrace, pval: PartialVal,
               recipe: Optional[JaxprTracerRecipe]):
    assert isinstance(pval, PartialVal)
    _, const = pval
    # A known constant must not be a tracer from this trace's level or above.
    if isinstance(const, Tracer) and const._trace.level >= trace.level:
      message = "Tracer from a higher level: {} in trace {}".format(const, trace)
      raise core.escaped_tracer_error(const, message)
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  def __repr__(self):
    return f'Traced<{self.aval}:{self._trace}>'

  @property
  def aval(self) -> AbstractValue:
    return self.pval.get_aval()

  @property
  def parents(self) -> Sequence['JaxprTracer']:
    """Input tracers of this tracer's equation, or an empty list otherwise."""
    recipe = self.recipe
    return recipe.invars if isinstance(recipe, JaxprEqnRecipe) else []

  def full_lower(self):
    known = self.pval.get_known()
    return self if known is None else core.full_lower(known)

  def is_known(self):
    return self.pval.is_known()
# TODO(necula): this could return a ClosedJaxpr with out_pvals
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
                   instantiate: Union[bool, Sequence[bool]] = False,
                   ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
  """Traces a function into a Jaxpr, given PartialVals for inputs.

  Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
  computation that depends on unknown inputs. The `out_pvals` are the PartialVal
  for the outputs. The intermediate values that depend only on known inputs and
  are needed to compute the output of `jaxpr` are in `consts` and are passed in
  as the constvars of the `jaxpr`. The handling of the known outputs depends on
  `instantiate`.

  For example, given `fun` defined as follows::

    def fun(ki, ui):  # ki will be a known input in this example
      ka = ki + 2
      kb = ka + 3
      return (kb, ui + ka)

  with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
  computation that depends on unknown inputs is `ui + ka` and will be the only
  computation in the body of the `jaxpr`. This computation depends on the known
  intermediate value `ka`, which will be computed statically. Currently, such
  constants are either embedded in the Jaxpr if they are scalars, or passed as a
  constvar to `jaxpr`, and then the value of the actual constant will be in
  `consts`:

  When `instantiate=False` we get::

    jaxpr =
      { lambda ka ; ki ui.
        let c = add ui ka
        in (*, c) }   # known outputs are `*`
    out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
    consts = [3]  # the constant for `ka`

  When `instantiate=True` we get::

    jaxpr =
      { lambda ka kb ; ki ui.
        let c = add ui ka
        in (kb, c) }   # known outputs are explicit
    out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
    consts = [3, 6]  # values for `ka` and `kb` constvars
  """
  with core.new_main(JaxprTrace) as main:
    fun = trace_to_subjaxpr(fun, main, instantiate)
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    # A top-level trace has no enclosing values to close over.
    assert not env
    # Drop local references before the `with` exits; presumably this lets
    # new_main's leak checking see no lingering references to `main`. Confirm.
    del main, fun, env

  return jaxpr, out_pvals, consts
def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
                      pvals: Sequence[PartialVal]):
  """Generator that traces a function under `main` into a jaxpr.

  Callers use it as a transformation: `trace_to_subjaxpr(fun, main,
  instantiate)` returns a wrapped function that is then called with `pvals`.
  Protocol:
  * It yields one input tracer per PartialVal as the wrapped function's
    arguments.
  * It receives the function's outputs.
  * It finally yields `jaxpr, (out_pvals, consts, env)` as computed by
    `tracers_to_jaxpr`.

  Args:
    main: the main trace under which the new JaxprTrace is created.
    instantiate: a single bool for all outputs, or one bool per output. Outputs
      marked True are made unknown via `JaxprTrace.instantiate_const`.
    pvals: the PartialVals of the function inputs.
  """
  assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
  trace = JaxprTrace(main, core.cur_sublevel())
  in_tracers = map(trace.new_arg, pvals)
  ans = yield in_tracers, {}
  instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate
  out_tracers = map(trace.full_raise, map(core.full_lower, ans))
  out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
  jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  # Release references before suspending at the final yield, presumably so
  # the suspended generator does not keep the trace/tracers alive. Confirm.
  del trace, in_tracers, out_tracers
  yield jaxpr, (out_pvals, consts, env)
def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
  """Raise `tracer` into `trace` and make it unknown, if `instantiate` is set.

  When `instantiate` is false the tracer is returned untouched.
  """
  if not instantiate:
    return tracer
  raised = trace.full_raise(tracer)
  return trace.instantiate_const(raised)
# Recipe for a tracer lifted from a lower sublevel (JaxprTrace.sublift);
# tracers_to_jaxpr turns it into an environment input.
FreeVar = namedtuple('FreeVar', ['val'])
# Recipe for a constant staged into the jaxpr as a constvar.
ConstVar = namedtuple('ConstVar', ['val'])
# Recipe for a function input bound by the jaxpr's lambda (JaxprTrace.new_arg).
LambdaBinding = namedtuple('LambdaBinding', [])
class JaxprEqnRecipe(NamedTuple):
  """Deferred jaxpr equation, shared as the recipe of all its output tracers."""
  # Unique token; tracers_to_jaxpr emits each recipe once, keyed on this.
  eqn_id: object
  # Tracers for the equation's inputs.
  invars: Sequence[JaxprTracer]
  # Weak references to the output tracers; recipe_to_eqn turns dead
  # references into dropvars.
  outvars: 'Sequence[ref[JaxprTracer]]'
  primitive: core.Primitive
  params: Dict[str, Any]
  source_info: Optional[source_info_util.Traceback]
def new_eqn_recipe(invars: Sequence[JaxprTracer],
                   outvars: Sequence[JaxprTracer],
                   primitive: core.Primitive,
                   params: Dict[str, Any],
                   source_info: Optional[source_info_util.Traceback]
                  ) -> JaxprEqnRecipe:
  """Constructs a new JaxprEqnRecipe.

  Params:
    invars: the tracers for the primitive inputs.
    outvars: the tracers for the primitive outputs; only weak references to
      them are stored in the recipe.
    primitive: the primitive.
    params: the primitive params.
    source_info: the source location to attach to the equation, if any.
  """
  # TODO(necula): move these checks to core.check_jaxpr, and call in more places
  if primitive.call_primitive or primitive.map_primitive:
    assert "call_jaxpr" in params
  if primitive.map_primitive:
    # Map primitives carry one in_axes and one donated_invars entry per
    # call_jaxpr input.
    for key in ("in_axes", "donated_invars"):
      assert (key in params and
              len(params[key]) == len(params["call_jaxpr"].invars))
  return JaxprEqnRecipe(eqn_id=object(),
                        invars=tuple(invars),
                        outvars=map(ref, outvars),
                        primitive=primitive,
                        params=params,
                        source_info=source_info)
def recipe_to_eqn(getvar: Callable[[JaxprTracer], core.Atom],
                  recipe: JaxprEqnRecipe) -> core.JaxprEqn:
  """Materialize `recipe` as a JaxprEqn, naming tracers through `getvar`.

  Output tracers that are no longer alive become `core.dropvar`.
  """
  # Dereference the weak output refs first, holding strong refs meanwhile.
  live_outs = [out_ref() for out_ref in recipe.outvars]
  in_atoms = [getvar(tracer) for tracer in recipe.invars]
  out_atoms = [cast(core.Var, getvar(tracer)) if tracer is not None
               else core.dropvar
               for tracer in live_outs]
  return new_jaxpr_eqn(in_atoms, out_atoms, recipe.primitive, recipe.params,
                       recipe.source_info)
def tracers_to_jaxpr(
  in_tracers: Sequence[JaxprTracer],
  out_tracers: Sequence[JaxprTracer]
  ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
  """Constructs Jaxpr given tracers for inputs and outputs.

  Params:
    in_tracers: the tracers that were created for the function inputs
    out_tracers: the tracers that were output by the function.

  Returns: a triple of a `Jaxpr`, a sequence of constant values corresponding
    to the `constvars` in the returned Jaxpr, and a sequence of environment
    values. The vars for the environment values have been prepended to the
    Jaxpr's `invars`.

  Raises:
    the error built by `core.escaped_tracer_error` if a lambda-bound tracer
    reachable from `out_tracers` is not one of `in_tracers`; TypeError for an
    unrecognized recipe.
  """
  newvar = core.gensym()
  # Maps id(tracer) to the atom (Var, Literal or unitvar) standing for it.
  t_to_var: Dict[int, core.Atom] = {}
  def getvar(t: JaxprTracer) -> core.Atom:
    var = t_to_var.get(id(t))
    if var is None:
      # Known tracers are represented by unit-typed vars.
      aval = t.pval.get_aval() if not t.pval.is_known() else abstract_unit
      var = t_to_var[id(t)] = newvar(aval)
    return var
  # Presumably orders every tracer after its `parents`, so equations are
  # emitted in dependency order (core.check_jaxpr validates the result).
  sorted_tracers = toposort(out_tracers)
  invars = map(getvar, in_tracers)
  eqns: List[core.JaxprEqn] = []
  env: Dict[core.Var, Any] = {}
  consts: Dict[core.Var, Any] = {}
  # The same constant object, seen through several tracers, gets one constvar.
  const_to_var: Dict[int, core.Var] = {}
  def getconstvar(c):
    var = const_to_var.get(id(c))
    if var is None:
      var = const_to_var[id(c)] = newvar(get_aval(c))
    return var
  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, JaxprEqnRecipe):
      # A multi-output equation is shared by its outputs; emit it once.
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(recipe_to_eqn(getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, LambdaBinding):
      if not any(t is in_tracer for in_tracer in in_tracers):
        raise core.escaped_tracer_error(
            t, "Tracer not among input tracers {}".format(t))
      assert in_tracers, "Lambda binding with no args"
    elif isinstance(recipe, FreeVar):
      env[cast(core.Var, getvar(t))] = recipe.val
    elif isinstance(recipe, ConstVar):
      v = t_to_var[id(t)] = getconstvar(recipe.val)
      consts[v] = recipe.val
    elif isinstance(recipe, Literal):
      t_to_var[id(t)] = recipe
    elif recipe is unit:
      t_to_var[id(t)] = unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = unzip2(env.items())
  const_vars, const_vals = unzip2(consts.items())
  # The env_vars are pre-pended to the invars
  jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
  core.skip_checks or core.check_jaxpr(jaxpr)
  return jaxpr, const_vals, env_vals
def convert_constvars_jaxpr(jaxpr: Jaxpr):
  """Moves the constvars to the start of invars.

  Returns a new Jaxpr with no constvars whose invars are the original
  constvars followed by the original invars; outvars and eqns are shared.
  """
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  lifted_invars = jaxpr.constvars + jaxpr.invars
  lifted_jaxpr = Jaxpr(constvars=(), invars=lifted_invars,
                       outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  if not core.skip_checks:
    core.check_jaxpr(lifted_jaxpr)
  return lifted_jaxpr
def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int):
  """Turn the first `num_env_vars` invars of `jaxpr` into constvars.

  The converted vars are appended after the existing constvars; the remaining
  invars, outvars and eqns are carried over into the new Jaxpr.
  """
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  env_vars, remaining_invars = split_list(jaxpr.invars, [num_env_vars])
  converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
                          invars=remaining_invars,
                          outvars=jaxpr.outvars,
                          eqns=jaxpr.eqns)
  if not core.skip_checks:
    core.check_jaxpr(converted_jaxpr)
  return converted_jaxpr
def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]:
  """Return `(known_aval, unknown_aval)`, with a unit in the unused slot."""
  if unknown:
    return abstract_unit, aval
  return aval, abstract_unit
def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
                       instantiate: Union[bool, Sequence[bool]],
                       ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
  """Specializes a Jaxpr given an indication of which inputs are known.

  Returns: (jaxpr_known, jaxpr_unknown, out_unknowns).

  `out_unknowns` specifies which outputs are unknown (depend on some unknown inputs).
  `jaxpr_known` takes the same inputs as `jaxpr`, ignores the unknown inputs,
  and performs *all* the computation in `jaxpr` that depends only on the known inputs.
  Outputs correspond to those of `jaxpr`, with the unknown ones replaced with `*`,
  appended with the known residuals (the intermediate computations in `jaxpr`
  that depend only on known inputs and that are needed to compute the unknown outputs).

  `jaxpr_unknown` takes the same inputs as `jaxpr` (with the known ones typed
  as units) along with the known residuals computed by `jaxpr_known`, and
  returns the same outputs as `jaxpr` with the known outputs replaced by `*`.

  Roughly, `jaxpr(ki, ui)` is decomposed assuming `ki` and `ui` are the known and respectively
  unknown inputs into:

    jaxpr(ki, ui) = let kout, _, kresidual = jaxpr_known(ki, *)
                    let _, uout = jaxpr_unknown(ki, ui, kresidual)
                    in (kout, uout)

  For example, if `jaxpr` is `lambda ki, ui: let ka = ki + 2
                                              in (ki + 3, ui + ka)`
  then
    `jaxpr_known` = lambda ki, ui: let ka = ki + 2
                                    in (ki + 3, *, ka)
    `jaxpr_unknown` = lambda ki, ui, ka: (*, ui + ka)

  Note that if instantiate is True for a given output, then jaxpr_known always returns a
  unit in its place. So when instantiate is True, the expectation is that one doesn't
  run `jaxpr_known` for any of its outputs, but only to generate residuals that allow
  obtaining the full outputs once `jaxpr_unknown` is run. Outputs known ahead of time will
  simply get passed as residual constants and returned immediately.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  # `fun` records what it learned about the unknown part (jaxpr_2) in `cell`
  # while being traced into jaxpr_1.
  cell = []
  def fun(*vals):
    pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
             for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
    jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
    out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
    cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
    return out_consts_2 + consts_2

  # For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the
  # known inputs.
  in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
  jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
  # `fun` must have been traced exactly once.
  (out_pvs_2, jaxpr_2, num_res), = cell
  assert len(jaxpr_2.constvars) == num_res

  #   jaxpr :: a -> b
  # jaxpr_1 :: a1 -> [b1, res]
  # jaxpr_2 :: res | a2 -> b2
  # jaxpr_2 :: [a2, res] -> b2
  jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
  jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
  # Known inputs are passed to jaxpr_unknown as units.
  for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
    if not unknown:
      var.aval = abstract_unit

  uk_out = [pv is not None for pv in out_pvs_2]

  # Sanity checks:
  # - there is one out_unknowns flag per output of `jaxpr`;
  # - jaxpr_1 returns exactly `num_res` residuals after the original outputs.
  assert len(uk_out) == len(jaxpr.out_avals), (len(uk_out), len(jaxpr.out_avals))
  res_avals = out_avals[len(jaxpr.out_avals):]
  assert len(res_avals) == num_res

  return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
# Call primitive for rematerialization; its partial-eval rule below stages out
# the whole callee so the unknown outputs are computed by re-running it.
remat_call_p = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind


def _remat_partial_eval(trace, _, f, tracers, params):
  """Partial-evaluation rule for `remat_call_p` (registered right below).

  Instead of splitting the callee into known/unknown parts and saving
  residuals, the unknown outputs are produced by a single `remat_call`
  equation over all (instantiated) inputs, i.e. by re-running the callee.
  Known outputs are computed now and returned as constant tracers.

  Args:
    trace: the `JaxprTrace` performing partial evaluation.
    _: the call primitive being bound (unused).
    f: the wrapped function being called.
    tracers: input `JaxprTracer`s of the call.
    params: call parameters; `params['concrete']` selects whether input avals
      are kept as-is or raised to the Shaped level before tracing.

  Returns:
    A list of output tracers in the callee's output order, mixing known
    (constant) and unknown (recipe-carrying) tracers.
  """
  concrete = params['concrete']

  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  if concrete:
    instantiated_tracers = map(trace.instantiate_const, tracers)
  else:
    instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  in_pvals = [t.pval for t in instantiated_tracers]
  if config.omnistaging_enabled:
    jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
      f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
  else:
    with core.initial_style_staging():  # type: ignore
      jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
        f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)

  # Convert consts to inputs, since they may contain Tracer instances.
  jaxpr = convert_constvars_jaxpr(jaxpr)
  const_tracers = map(trace.new_instantiated_const, consts)

  # Since we traced with everything marked as unknown, but we need to know which
  # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
  closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
  # Input order of `jaxpr` after convert_constvars_jaxpr: consts (known), then
  # env tracers, then the original call inputs.
  in_unknowns = ([False] * len(consts) +
                 [not t.is_known() for t in it.chain(env_tracers, tracers)])
  if config.omnistaging_enabled:
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False)  # type: ignore
  else:
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type)  # type: ignore
  out_knowns = [not b for b in out_unknowns]
  out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)

  # Next, we need values for the outputs that should be known. Since consts
  # weren't passed through Python for evaluation, we need to evaluate jaxpr_known,
  # minus the residual outputs that we don't need. When `concrete=True`, as an
  # optimization we can avoid redoing *some* redundant FLOPs, namely those that
  # produced concrete avals at the output, simply by using those as computed
  # values. For the use case of inverse-mode ad in op-by-op ("eager mode")
  # evaluation, all the primal outputs should be concrete (thus not recomputed).
  to_compute = [type(pval[0]) is not ConcreteArray
                for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
  # jaxpr_unknown has one output per callee output; whatever jaxpr_known
  # returns beyond that count are residuals.
  num_outputs = len(jaxpr_unknown.out_avals)
  num_res = len(jaxpr_known.out_avals) - num_outputs
  jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True)
  # Outputs not in `to_compute` are replaced by units (not dropped), so the
  # positions still line up with `out_known_pvals` below.
  jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
  _, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
  reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
  out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)

  # Known outputs should keep propagating as constants
  assert all(pv.is_known() for pv in out_known_pvals)
  known_output_tracers = [trace.new_const(pval.get_known())
                          for pval in out_known_pvals]
  # Unknown outputs get wrapped in tracers with the appropriate recipe
  unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
                            for out_pval in out_unknown_pvals]

  # dce jaxpr outputs
  new_jaxpr = _dce_jaxpr(closed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
  new_params = dict(params, call_jaxpr=new_jaxpr)

  # set up eqn for unknown outputs
  in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
  eqn = new_eqn_recipe(in_tracers, unknown_output_tracers, remat_call_p, new_params,
                       source_info_util.current())
  for t in unknown_output_tracers: t.recipe = eqn
  return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
833def _partition_knowns(pvals, unknowns: Sequence[bool]):
834  return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
835          [e for e, unknown in zip(pvals, unknowns) if unknown])
837def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
838  known_iter, unknown_iter = iter(known_list), iter(unknown_list)
839  return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]
def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> ClosedJaxpr:
  """Dead-code-eliminate `closed_jaxpr`, keeping only the selected outputs.

  `outputs[i]` says whether output i is needed. The work is delegated to
  `_dce_open_jaxpr` on the underlying open jaxpr; the consts are carried over
  unchanged.
  """
  # `_dce_open_jaxpr` is annotated to take a tuple of flags, so convert here.
  output_flags = tuple(outputs)
  pruned = _dce_open_jaxpr(closed_jaxpr.jaxpr, output_flags, drop_outputs)
  return core.ClosedJaxpr(pruned, closed_jaxpr.consts)
def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
  """Remove equations of `jaxpr` that do not contribute to the needed outputs.

  Args:
    jaxpr: the open jaxpr to prune (not modified).
    outputs: one flag per output of `jaxpr`; True means the output is needed.
    drop_outputs: if True, unneeded outputs are removed from the outvars;
      otherwise they are replaced by `unitvar`.

  Returns:
    A new Jaxpr with the same constvars and invars, the adjusted outvars, and
    only the equations whose outputs are (transitively) needed, in their
    original order.
  """
  # This dead-code elimination is pretty rudimentary, and in particular doesn't
  # nontrivially DCE through scan, call, or other higher-order primitives.
  # TODO(mattjj): better DCE
  if drop_outputs:
    new_outvars = [v for v, keep in zip(jaxpr.outvars, outputs) if keep]
  else:
    new_outvars = [v if keep else unitvar
                   for v, keep in zip(jaxpr.outvars, outputs)]

  # Walk the equations backwards, growing the set of live variables.
  live = {v for v in new_outvars if type(v) is not Literal}
  kept_reversed = []
  for eqn in reversed(jaxpr.eqns):
    if live.isdisjoint(eqn.outvars):
      continue
    kept_reversed.append(eqn)
    live.update(v for v in eqn.invars if type(v) is not Literal)
  kept_reversed.reverse()
  return core.Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars, kept_reversed)
def _drop_invars(jaxpr: Jaxpr, drop: Tuple[bool, ...]):
  """Build a new Jaxpr without the invars whose `drop` flag is True.

  Constvars, outvars and equations are reused as-is.
  """
  kept_invars = [v for v, should_drop in zip(jaxpr.invars, drop) if not should_drop]
  return core.Jaxpr(jaxpr.constvars, kept_invars, jaxpr.outvars, jaxpr.eqns)
def _reconstruct_pval(pval1: PartialVal, const2: core.Value):
  """Return a known PartialVal standing for `pval1`.

  A known `pval1` is returned unchanged. For an unknown one, a ConcreteArray
  aval supplies the value directly; otherwise `const2` is used as the value.
  """
  if pval1.is_known():
    return pval1
  pv1 = pval1[0]
  if type(pv1) is not ConcreteArray:
    return PartialVal.known(const2)
  return PartialVal.known(pv1.val)  # pytype: disable=attribute-error
def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr:
  """Reorder the `invars` to move to front the ones for which `to_move` is True.

  Relative order is kept within the moved and the unmoved groups. The jaxpr
  must have no constvars, and `to_move` must have one flag per input.
  """
  inner = closed_jaxpr.jaxpr
  assert not inner.constvars
  assert len(closed_jaxpr.in_avals) == len(to_move)
  reordered_invars = _move_to_front(inner.invars, to_move)
  reordered = core.Jaxpr((), reordered_invars, inner.outvars, inner.eqns)
  return core.ClosedJaxpr(reordered, closed_jaxpr.consts)
894def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence:
895  return ([elt for elt, move in zip(lst, to_move) if move] +
896          [elt for elt, move in zip(lst, to_move) if not move])
class DynamicJaxprTracer(core.Tracer):
  """Tracer created by `DynamicJaxprTrace`; it carries only an abstract value.

  The jaxpr variable a tracer stands for is looked up by `id(tracer)` in the
  `tracer_to_var` map of its trace's current `JaxprStackFrame`.
  """
  __slots__ = ['aval']

  def __init__(self, trace, aval, line_info=None):
    self._trace = trace
    self._line_info = line_info
    self.aval = aval

  def full_lower(self):
    return self

  def _contents(self):
    return ()

  def _origin_msg(self):
    """Explain, for error messages, where this non-concrete value came from.

    Uses the current frame's `find_progenitors` to report either the traced
    function's input positions it depends on, or the operations (with source
    lines) that turned it into a tracer.
    """
    invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)
    if invar_pos:
      origin = (f"While tracing the function {self._trace.main.source_info}, "
                "this concrete value was not available in Python because it "
                "depends on the value of the arguments to "
                f"{self._trace.main.source_info} at flattened positions {invar_pos}, "
                "and the computation of these values is being staged out "
                "(that is, delayed rather than executed eagerly).\n\n"
                "You can use transformation parameters such as `static_argnums` "
                "for `jit` to avoid tracing particular arguments of transformed "
                "functions, though at the cost of more recompiles.")
    elif progenitor_eqns:
      msgs = [f"  operation {core.pp_eqn(eqn, print_shapes=True)}\n"
              f"    from line {source_info_util.summarize(eqn.source_info)}"
              for eqn in progenitor_eqns]
      origin = (f"While tracing the function {self._trace.main.source_info}, "
                "this value became a tracer due to JAX operations on these lines:"
                "\n\n" + "\n\n".join(msgs))
    else:
      origin = ("The error occurred while tracing the function "
                f"{self._trace.main.source_info}.")
    return origin

  def _assert_live(self) -> None:
    """Raise an escaped-tracer error if the trace's jaxpr stack is empty."""
    if not self._trace.main.jaxpr_stack:  # type: ignore
      raise core.escaped_tracer_error(self, None)
class JaxprStackFrame:
  """Mutable state accumulated while tracing a single (sub)jaxpr.

  Tracers are keyed by `id` in `tracer_to_var`; constants are keyed by `id`
  in `constid_to_var`, and `constvar_to_val` maps each constvar to its value.
  """
  __slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
               'tracers', 'eqns', 'invars']

  def __init__(self):
    self.newvar = core.gensym()
    self.tracer_to_var = {}     # id(tracer) -> Var
    self.constid_to_var = {}    # id(const value) -> constvar
    self.constvar_to_val = {}   # constvar -> const value
    self.tracers = []   # circ refs, frame->tracer->trace->main->frame,
    self.eqns = []      # cleared when we pop frame from main
    self.invars = []

  def to_jaxpr(self, in_tracers, out_tracers):
    """Assemble the recorded equations into `(jaxpr, out_avals, constvals)`.

    Scalar literal-able constants are inlined by `_inline_literals`.
    """
    lookup = self.tracer_to_var
    invars = [lookup[id(t)] for t in in_tracers]
    outvars = [lookup[id(t)] for t in out_tracers]
    constvars, constvals = unzip2(self.constvar_to_val.items())
    raw_jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
    jaxpr, constvals = _inline_literals(raw_jaxpr, constvals)
    return jaxpr, [t.aval for t in out_tracers], constvals

  def find_progenitors(self, tracer):
    """Find what `tracer`'s value was computed from.

    Returns `(None, None)` if the frame has no variable for `tracer`.
    Otherwise returns `(invar_positions, const_eqns)`: the positions in
    `self.invars` that the value transitively depends on, and the recorded
    equations that read any constvar it transitively depends on.
    """
    target = self.tracer_to_var.get(id(tracer))
    if not target:
      return None, None
    live = {target}
    # Backwards walk: replace each produced variable by the equation's inputs.
    for eqn in reversed(self.eqns):
      hit = live.intersection(eqn.outvars)
      if not hit:
        continue
      live -= hit
      live.update(eqn.invars)
    positions = [i for i, v in enumerate(self.invars) if v in live]
    live_constvars = live.intersection(self.constvar_to_val)
    const_eqns = [eqn for eqn in self.eqns
                  if not live_constvars.isdisjoint(eqn.invars)]
    return positions, const_eqns
def _inline_literals(jaxpr, constvals):
  """Return a renamed copy of `jaxpr` with scalar constants inlined as Literals.

  A constvar whose bound value (from `constvals`) has a type in
  `core.literalable_types` and an empty `np.shape` is removed from the binders
  and replaced by `Literal(value)` wherever it is used (equation inputs and
  jaxpr outputs). All other variables are renamed via a fresh `core.gensym()`.
  Equation outputs that are never read, by any equation or as a jaxpr output,
  become `dropvar`.

  Returns:
    `(new_jaxpr, new_constvals)`, where `new_constvals` holds, in order, the
    values of the constvars that were kept as binders.
  """
  consts = dict(zip(jaxpr.constvars, constvals))
  newvar = core.gensym()
  # Old Var -> new Var, allocating the new one on first lookup. The order of
  # first lookups below therefore determines the order of `newvar` calls.
  class var(dict):
    def __missing__(self, v):
      new_v = self[v] = newvar(v.aval)
      return new_v
  var = var()

  # Literal for a constvar bound to a scalar literal-able value, else None.
  def lit(var: core.Var) -> Optional[Any]:
    val = consts.get(var)
    if type(val) in core.literalable_types and not np.shape(val):
      return Literal(val)
    else:
      return None

  # Every variable read anywhere; unread equation outputs become dropvar.
  used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
  new_constvars = [var[v] for v in jaxpr.constvars if not lit(v)]
  new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
  new_invars = [var[v] for v in jaxpr.invars]
  new_eqns = [new_jaxpr_eqn([lit(v) or var[v] for v in eqn.invars],
                            [var[v] if v in used else dropvar for v in eqn.outvars],
                            eqn.primitive, eqn.params, eqn.source_info)
              for eqn in jaxpr.eqns]
  new_outvars = [lit(v) or var[v] for v in jaxpr.outvars]
  new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
  return new_jaxpr, new_constvals
class DynamicJaxprTrace(core.Trace):
  """Trace that stages every operation it sees into the current jaxpr frame.

  All bookkeeping (tracers, variables, constants, equations) lives in the
  innermost `JaxprStackFrame` of `self.main.jaxpr_stack`, exposed as `frame`.
  """
  __slots__ = []  # type: ignore

  @property
  def frame(self):
    """The innermost `JaxprStackFrame` on this trace's main."""
    return self.main.jaxpr_stack[-1]  # pytype: disable=attribute-error

  def new_arg(self, aval):
    """Create a tracer for a new input (invar) of the jaxpr being built."""
    tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
    self.frame.tracers.append(tracer)
    self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval)
    self.frame.invars.append(var)
    return tracer

  def new_const(self, val):
    """Create a tracer standing for the constant `val`, bound to a constvar."""
    aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
    tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
    self.frame.tracers.append(tracer)
    var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(val)
    self.frame.constvar_to_val[var] = val
    return tracer

  pure = lift = sublift = new_const

  def getvar(self, tracer):
    """Return `tracer`'s variable in the current frame.

    Raises `core.escaped_tracer_error(tracer)` if the frame has none.
    """
    var = self.frame.tracer_to_var.get(id(tracer))
    if var is None:
      raise core.escaped_tracer_error(tracer)
    return var

  def makevar(self, tracer):
    """Allocate the variable for a new output `tracer` (only once per tracer)."""
    var = self.frame.tracer_to_var.get(id(tracer))
    assert var is None, "a jaxpr variable must be created only once per tracer"
    self.frame.tracers.append(tracer)
    var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval)
    return var

  def getconstvar(self, c):
    """Return the constvar for `c`, allocating one per distinct `id(c)`."""
    var = self.frame.constid_to_var.get(id(c))
    if var is None:
      var = self.frame.constid_to_var[id(c)] = self.frame.newvar(get_aval(c))
    return var

  def instantiate_const(self, val):
    """Return `val` if it is a tracer of this very trace, else a const tracer."""
    if (isinstance(val, Tracer) and val._trace.main is self.main
        and val._trace.sublevel == self.sublevel):
      return val
    else:
      return self.new_const(val)

  def process_primitive(self, primitive, tracers, params):
    """Append one equation applying `primitive` to `tracers`."""
    avals = [t.aval for t in tracers]
    out_avals = primitive.abstract_eval(*avals, **params)
    out_avals = [out_avals] if not primitive.multiple_results else out_avals
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers if primitive.multiple_results else out_tracers.pop()

  def process_call(self, call_primitive, f, tracers, params):
    """Trace `f` into a sub-jaxpr and stage out one call equation for it.

    The callee's constvars become leading inputs of the call. If the traced
    callee has no equations, it is evaluated inline instead.
    """
    in_avals = [t.aval for t in tracers]
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
    if not jaxpr.eqns:
      return core.eval_jaxpr(jaxpr, consts, *tracers)
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    update_params = call_param_updaters.get(call_primitive)
    if update_params:
      new_params = update_params(new_params, [True] * len(tracers))
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
                        new_params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_call(self, call_primitive, out_tracers, params):
    assert False  # unreachable

  def process_map(self, map_primitive, f, tracers, params):
    """Trace a mapped `f` and stage out one map-primitive equation.

    The body is traced on avals with the axes in `params['in_axes']` removed,
    inside an axis environment extended with `axis_name`; output avals are
    re-expanded along the axes returned by `params['out_axes_thunk']`.
    """
    in_avals = [t.aval for t in tracers]
    axis_name, axis_size = params['axis_name'], params['axis_size']
    reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a)
                        if in_axis is not None else a
                        for a, in_axis in zip(in_avals, params['in_axes'])]
    with core.extend_axis_env(axis_name, axis_size, None):  # type: ignore
      jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
          f, self.main, reduced_in_avals)
    out_axes = params['out_axes_thunk']()
    out_avals = [core.unmapped_aval(axis_size, out_axis, a)
                 if out_axis is not None else a
                 for a, out_axis in zip(reduced_out_avals, out_axes)]
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    # Constants are passed unmapped, ahead of the original inputs.
    new_in_axes = (None,) * len(consts) + params['in_axes']
    new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes,
                      call_jaxpr=convert_constvars_jaxpr(jaxpr))
    del new_params['out_axes_thunk']
    update_params = call_param_updaters.get(map_primitive)
    if update_params:
      new_params = update_params(new_params, [True] * len(tracers))
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
                        new_params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_map(self, map_primitive, out_tracers, params):
    assert False  # unreachable

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Stage out a custom_jvp call as one `prim.initial_style` equation.

    `fun` is traced now; `jvp` is traced lazily (memoized) on `in_avals`
    repeated twice, keeping only the `(jaxpr, consts)` of the result.
    """
    in_avals = [t.aval for t in tracers]
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
    closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
    jvp_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(jvp, self.main, 2 * in_avals)[::2])
    # Attach source info to the outputs, like the other process_* methods.
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                        dict(fun_jaxpr=closed_fun_jaxpr,
                             jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                             num_consts=len(consts)),
                        source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_custom_jvp_call(self, out_tracers, params):
    assert False  # unreachable

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    """Stage out a custom_vjp call as one `prim.initial_style` equation.

    `fun` is traced now; `fwd` is traced lazily (memoized), keeping only the
    `(jaxpr, consts)` of the result. `bwd` and `out_trees` are stored as params.
    """
    in_avals = [t.aval for t in tracers]
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
    closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
    fwd_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(fwd, self.main, in_avals)[::2])
    # Attach source info to the outputs, like the other process_* methods.
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                        dict(fun_jaxpr=closed_fun_jaxpr,
                             fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                             num_consts=len(consts),
                             bwd=bwd, out_trees=out_trees),
                        source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_custom_vjp_call(self, out_tracers, params):
    assert False  # unreachable
def _memoize(thunk):
  """Wrap `thunk` so it runs at most once, under the trace state captured now.

  The trace state is copied when `_memoize` is called. On the first call of the
  returned function, that copy is installed while `thunk` runs and the
  caller's trace state is restored afterwards, even if `thunk` raises. A
  successful result is cached and returned by every later call.
  """
  results = []
  captured_state = core.thread_local_state.trace_state.copy()

  def memoized():
    if results:
      return results[0]
    outer_state = core.thread_local_state.trace_state
    core.thread_local_state.trace_state = captured_state
    try:
      value = thunk()
    finally:
      core.thread_local_state.trace_state = outer_state
    results.append(value)
    return results[0]

  return memoized
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
  """Trace `fun` on abstract inputs into a jaxpr under a new dynamic main.

  Requires omnistaging. The main's `source_info` (used in tracer error
  messages) is set from `fun.f`.

  Returns:
    `(jaxpr, out_avals, consts)` as produced by `trace_to_subjaxpr_dynamic`.
  """
  assert config.omnistaging_enabled
  with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
    main.source_info = fun_sourceinfo(fun.f)  # type: ignore
    main.jaxpr_stack = ()  # type: ignore
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
    # Drop local references before `new_main` exits (presumably so the main
    # trace can be released and any leak check passes).
    del main, fun
  return jaxpr, out_avals, consts
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
                              in_avals: Sequence[AbstractValue]):
  """Trace `fun` into a fresh `JaxprStackFrame` pushed onto `main.jaxpr_stack`.

  Args:
    fun: the function to trace.
    main: the main trace to trace under; must have a `jaxpr_stack` attribute.
    in_avals: abstract values of the positional inputs.

  Returns:
    `(jaxpr, out_avals, consts)`: the traced jaxpr (scalar literal-able
    constants inlined), its output avals, and the values of its remaining
    constvars.
  """
  frame = JaxprStackFrame()
  with extend_jaxpr_stack(main, frame):
    trace = DynamicJaxprTrace(main, core.cur_sublevel())
    in_tracers = map(trace.new_arg, in_avals)
    ans = fun.call_wrapped(*in_tracers)
    # Lift outputs (e.g. plain constants) into this trace before reading vars.
    out_tracers = map(trace.full_raise, ans)
    jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
    # Break local references (frame <-> tracers <-> trace) before the frame is
    # popped from the stack.
    del fun, main, trace, frame, in_tracers, out_tracers, ans
  return jaxpr, out_avals, consts
def extend_jaxpr_stack(main, frame):
  """Push `frame` onto `main.jaxpr_stack` for the duration of the context.

  On exit the top of the stack must still be `frame`; it is then popped.
  """
  stack_before = main.jaxpr_stack
  main.jaxpr_stack = stack_before + (frame,)
  try:
    yield
  finally:
    top = main.jaxpr_stack[-1]
    assert top is frame
    main.jaxpr_stack = main.jaxpr_stack[:-1]
def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
  """Like `trace_to_jaxpr_dynamic`, but the main is created by `core.new_base_main`.

  NOTE(review): `core.new_base_main` is not visible here; presumably it installs
  `DynamicJaxprTrace` as the base (bottom-most) main trace. Confirm in core.

  Returns:
    `(jaxpr, out_avals, consts)` as produced by `trace_to_subjaxpr_dynamic`.
  """
  assert config.omnistaging_enabled
  with core.new_base_main(DynamicJaxprTrace) as main:  # type: ignore
    main.source_info = fun_sourceinfo(fun.f)  # type: ignore
    main.jaxpr_stack = ()  # type: ignore
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
    # Drop local references before the main context exits.
    del fun, main
  return jaxpr, out_avals, consts
def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
  """Partially evaluate `fun` on `in_pvals` via `trace_to_jaxpr`.

  This provides a partial evaluation behavior used by Flax. We can't use
  `trace_to_jaxpr` directly because of an interaction with the current
  custom_derivatives.py, which we work around by running it under a dynamic
  `core.EvalTrace` main.

  Returns:
    The result of `trace_to_jaxpr(fun, in_pvals)`.
  """
  # TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
  assert config.omnistaging_enabled
  # Deliberately not binding the main trace to a name (previously `as _`): a
  # live local reference would keep it alive while `new_main` exits. The other
  # tracing entry points in this module `del main` for the same reason.
  with core.new_main(core.EvalTrace, dynamic=True):  # type: ignore
    return trace_to_jaxpr(fun, in_pvals)
def fun_sourceinfo(fun):
  """Describe `fun` as "name at filename:lineno" for error messages.

  All layers of `functools.partial` are unwrapped. CPython does not flatten a
  partial whose inner partial carries a `__dict__` (for example after
  `functools.update_wrapper`), so a single unwrap is not enough. Returns
  "<unknown>" when the underlying callable has no Python code object, e.g.
  builtins or callable instances.
  """
  while isinstance(fun, functools.partial):
    fun = fun.func
  try:
    code = fun.__code__
    name = fun.__name__
  except AttributeError:
    return "<unknown>"
  return f"{name} at {code.co_filename}:{code.co_firstlineno}"
def omnistaging_disabler() -> None:
  """Install this module's legacy (non-omnistaging) implementations.

  Rebinds the module globals `trace_to_jaxpr`, `partial_eval_jaxpr`,
  `staged_out_calls` and `StagingJaxprTrace`, and overrides `JaxprTrace`'s
  custom_jvp/custom_vjp call handling.
  """
  global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace

  def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
                     instantiate: Union[bool, Sequence[bool]] = False,
                     stage_out=False, bottom=False,
                     trace_type: Optional[Type[Trace]] = None,
                     ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
    """Traces a function into a Jaxpr, given PartialVals for inputs.

    Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
    computation that depends on unknown inputs. The `out_pvals` are the PartialVal
    for the outputs. The intermediate values that depend only on known inputs and
    are needed to compute the output of `jaxpr` are in `consts` and are passed in
    as the constvars of the `jaxpr`. The handling of the known outputs depends on
    `instantiate`.

    `trace_type` defaults to `StagingJaxprTrace` when `stage_out` is True and to
    `JaxprTrace` otherwise; `bottom` is forwarded to `core.new_main`.

    For example, given `fun` defined as follows::

      def fun(ki, ui):  # ki will be a known input in this example
        ka = ki + 2
        kb = ka + 3
        return (kb, ui + ka)

    with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
    computation that depends on unknown inputs is `ui + ka` and will be the only
    computation in the body of the `jaxpr`. This computation depends on the known
    intermediate value `ka`, which will be computed statically. Currently, such
    constants are either embedded in the Jaxpr if they are scalars, or passed as a
    constvar to `jaxpr`, and then the value of the actual constant will be in
    `consts`:

    When `instantiate=False` we get::

      jaxpr =
        { lambda ka ; ki ui.
          let c = add ui ka
          in (*, c) }   # known outputs are `*`
      out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
      consts = [3]  # the constant for `ka`

    When `instantiate=True` we get::

      jaxpr =
        { lambda ka kb ; ki ui.
          let c = add ui ka
          in (kb, c) }   # known outputs are explicit
      out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
      consts = [3, 6]  # values for `ka` and `kb` constvars
    """
    trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
    with core.new_main(trace_type, bottom=bottom) as main:
      fun = trace_to_subjaxpr(fun, main, instantiate)
      jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      # A top-level trace is expected to close over no env tracers.
      assert not env
      # Drop the local reference before `new_main` exits.
      del main

    return jaxpr, out_pvals, consts
  def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
                        instantiate: Union[bool, Sequence[bool]],
                        trace_type: Optional[Type[core.Trace]]
                        ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
    """Legacy (non-omnistaging) split of `jaxpr` into known and unknown parts.

    Returns `(jaxpr_known, jaxpr_unknown, out_unknowns)`:
      * `jaxpr_known` takes the inputs (unit-typed where unknown) and returns the
        known-side outputs followed by residuals;
      * `jaxpr_unknown` takes the inputs (unit-typed where known) followed by the
        residuals and returns the unknown-side outputs;
      * `out_unknowns[i]` is True when output i is unknown.
    """
    f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

    # Captures (out_pvs, jaxpr_unknown, num_residuals) from the inner trace.
    cell = []
    def fun(*vals):
      pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
              for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
      jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
                                                      trace_type=trace_type)
      out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
      cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
      # Known outputs followed by residuals (the constants of jaxpr_2).
      return out_consts_2 + consts_2

    # The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores
    # the avals, and it will never actually reach any primitives, because the `fun` above will
    # execute the jaxpr with the right avals (it reconstructs `pvals` inside).
    pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
            for aval, uk in zip(jaxpr.in_avals, unknowns)]
    jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
    (out_pvs_2, jaxpr_2, num_res), = cell
    assert len(jaxpr_2.constvars) == num_res

    #   jaxpr :: a -> b
    # jaxpr_1 :: a1 -> [b1, res]
    # jaxpr_2 :: res | a2 -> b2
    # jaxpr_2 :: [a2, res] -> b2
    jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
    jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
    for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
      if not unknown:
        var.aval = abstract_unit

    uk_out = [pv is not None for pv in out_pvs_2]

    return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
1340  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
1341    # See comment at top of `JaxprTrace`. This method should be reachable
1342    # only when we stage out, and in that case we drop the custom differentiation
1343    # rules, because we do not need them.
1344    if not config.omnistaging_enabled:
1345      assert self.main.trace_type is StagingJaxprTrace
1346    return fun.call_wrapped(*tracers)
1347  JaxprTrace.process_custom_jvp_call = process_custom_jvp_call
1349  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
1350    # See comment in the above process_custom_jvp_call method.
1351    if not config.omnistaging_enabled:
1352      assert self.main.trace_type is StagingJaxprTrace
1353    return fun.call_wrapped(*tracers)
1354  JaxprTrace.process_custom_vjp_call = process_custom_vjp_call
  # Rebinds the module-level `staged_out_calls` (declared global above) to an
  # empty set; NOTE(review): its users are not visible here.
  staged_out_calls = set()

  # Marker subclass of JaxprTrace: `trace_to_jaxpr(..., stage_out=True)` selects
  # it as the trace type, and the custom_jvp/vjp overrides check for it.
  class StagingJaxprTrace(JaxprTrace): pass