1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import itertools as it
16from collections import namedtuple
17import contextlib
18import functools
19from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
20                    List, Union, cast, Type, no_type_check)
21from weakref import ref
22
23import numpy as np
24
25from .. import core
26from .. import dtypes
27from .. import linear_util as lu
28from ..ad_util import Zero
29from .._src.util import (unzip2, safe_zip, safe_map, toposort, partial,
30                         split_list, cache, as_hashable_function)
31from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
32                    unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
33                    dropvar, ConcreteArray, raise_to_shaped)
34from jax._src import source_info_util
35from ..config import config
36
# Shadow the builtins with the utilities imported above, so that every `map` /
# `zip` call in this module goes through `safe_map` / `safe_zip`.
map = safe_map
zip = safe_zip
def identity(x):
  """Returns the argument unchanged."""
  return x
40
class PartialVal(tuple):
  """A value that is either known, or unknown and described abstractly.

  Stored as the pair `(aval_opt, const)`, which takes one of two forms:
  * `(None, <Constant>)`: the value is known; the constant is either a regular
    Python value or a Tracer.
  * `(<AbstractValue>, *)`: the value is unknown and is characterized only by
    its abstract value.
  """
  def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
    aval_opt, value = xs
    if not core.skip_checks:
      # Type checks on the two slots.
      assert isinstance(aval_opt, (AbstractValue, type(None))), xs
      assert (isinstance(value, core.Tracer) or type(value) is Zero
              or core.valid_jaxtype(value)), xs
      # Invariant: an unknown value carries a constant whose aval is the
      # abstract unit.
      if isinstance(aval_opt, AbstractValue):
        assert get_aval(value) == core.abstract_unit, xs
    return tuple.__new__(cls, xs)

  @classmethod
  def known(cls, const: core.Value) -> 'PartialVal':
    """Builds the PartialVal for the known value `const`."""
    return PartialVal((None, const))

  @classmethod
  def unknown(cls, aval: AbstractValue) -> 'PartialVal':
    """Builds the PartialVal for an unknown value with abstract value `aval`."""
    return PartialVal((aval, core.unit))

  def is_known(self) -> bool:
    """Whether this is a known value, i.e. no abstract value is stored."""
    return self[0] is None

  def get_known(self) -> Optional[core.Value]:
    """Returns the stored constant if the value is known, otherwise None."""
    if self[0] is None:
      return self[1]
    return None

  def get_aval(self) -> AbstractValue:
    """Returns the stored AbstractValue (unknown) or the constant's (known)."""
    const = self.get_known()
    if const is None:
      return self[0]
    return get_aval(const)

  def merge_with_known(self, val: core.Value) -> core.Value:
    """Returns the stored known value when there is one, else the given `val`."""
    const = self.get_known()
    if const is None:
      return val
    return const
88
89
90class JaxprTrace(Trace):
91  def pure(self, val) -> 'JaxprTracer':
92    return self.new_const(val)
93
94  def lift(self, val) -> 'JaxprTracer':
95    return self.new_const(val)
96
97  def sublift(self, val) -> 'JaxprTracer':
98    return JaxprTracer(self, val.pval, FreeVar(val))
99
100  def new_const(self, val) -> 'JaxprTracer':
101    if isinstance(val, Tracer) and val._trace.level == self.level:
102      raise Exception
103    return JaxprTracer(self, PartialVal.known(val), unit)
104
105  def new_instantiated_literal(self, val) -> 'JaxprTracer':
106    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), Literal(val))
107
108  def new_instantiated_const(self, val) -> 'JaxprTracer':
109    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), ConstVar(val))
110
111  def new_arg(self, pval: PartialVal) -> 'JaxprTracer':
112    const = pval.get_known()
113    if const is None:
114      return JaxprTracer(self, pval, LambdaBinding())
115    else:
116      return self.new_const(const)
117
118  def instantiate_const(self, tracer) -> Tracer:
119    const = tracer.pval.get_known()
120    if const is None:
121      return tracer
122    else:
123      if type(const) in core.literalable_types and np.shape(const) == ():
124        return self.new_instantiated_literal(const)
125      else:
126        return self.new_instantiated_const(const)
127
128  def instantiate_const_abstracted(self, tracer) -> 'JaxprTracer':
129    const = tracer.pval.get_known()
130    if const is None:
131      return tracer
132    else:
133      aval = raise_to_shaped(get_aval(const), np.isscalar(const))
134      return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))
135
136  def process_primitive(self, primitive, tracers, params):
137    if primitive in custom_partial_eval_rules:
138      return custom_partial_eval_rules[primitive](self, *tracers, **params)
139    else:
140      return self.default_process_primitive(primitive, tracers, params)
141
142  def default_process_primitive(self, primitive, tracers, params):
143    """By default, if all the input tracers are known, then execute the primitive
144    and all the ouputs are known. Otherwise, all the outputs are unknown."""
145    consts = [t.pval.get_known() for t in tracers]
146    if all(c is not None for c in consts):
147      return primitive.bind(*consts, **params)
148    tracers = map(self.instantiate_const, tracers)
149    avals = [t.aval for t in tracers]
150    out_aval = primitive.abstract_eval(*avals, **params)
151    source = source_info_util.current()
152    if primitive.multiple_results:
153      out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
154                     for aval in out_aval]
155      eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, source)
156      for t in out_tracers: t.recipe = eqn
157      return out_tracers
158    else:
159      out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
160      out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
161                                         params, source)
162      return out_tracer
163
164  # We use process_call to handle both call and map primitives.
  def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
    """Partially evaluates a call or map primitive applied to `f`.

    Unless a rule in `call_partial_eval_rules` takes over, `f` is partially
    evaluated through `self.partial_eval`, which binds `primitive` on the
    inputs' constants. The computation on unknown values is then recorded as a
    single equation of `primitive` whose `call_jaxpr` param is the resulting
    jaxpr.

    Args:
      primitive: the call or map primitive being processed.
      f: the wrapped function that the primitive calls.
      tracers: the input tracers, any mix of known and unknown.
      params: the primitive's params. For map primitives this reads
        'axis_size', 'in_axes' and 'out_axes_thunk'.

    Returns:
      The output tracers, assembled by `_zip_knowns` from the known and the
      unknown output tracers according to which outputs are unknown.
    """
    if not config.omnistaging_enabled:
      if (self.main.trace_type is StagingJaxprTrace  # type: ignore
          and primitive in staged_out_calls):        # type: ignore
        tracers = map(self.instantiate_const_abstracted, tracers)

    # A primitive-specific rule, when registered, handles the call entirely.
    if primitive in call_partial_eval_rules:
      return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)

    in_pvals = [t.pval for t in tracers]
    if primitive.map_primitive:
      # For every unknown input that has a mapped axis, replace its aval with
      # `core.mapped_aval(axis_size, in_axis, aval)`.
      mapped_aval = partial(core.mapped_aval, params['axis_size'])
      in_pvals = [pval if pval.is_known() or in_axis is None
                  else PartialVal.unknown(mapped_aval(in_axis, pval[0]))
                  for pval, in_axis in zip(in_pvals, params['in_axes'])]

      def app(f, *args):
        f, num_outputs = count_outputs(f)
        out_axes_thunk = params['out_axes_thunk']
        # Extend the original out_axes with a 0 for every output beyond those
        # that the original thunk describes.
        @as_hashable_function(closure=out_axes_thunk)
        def new_out_axes_thunk():
          out_axes = out_axes_thunk()
          return out_axes + (0,) * (num_outputs() - len(out_axes))
        pe_params = dict(params, out_axes_thunk=new_out_axes_thunk)
        return primitive.bind(f, *args, **pe_params)
    else:
      app = partial(primitive.bind, **params)
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
        f, in_pvals, app, instantiate=False)
    if primitive.map_primitive:
      # Unknown outputs with a mapped axis get their aval converted back with
      # `core.unmapped_aval`; unknown outputs without one keep their aval.
      unmapped_aval = partial(core.unmapped_aval, params['axis_size'])
      out_axes = params['out_axes_thunk']()
      out_pvals = [pval if pval.is_known() else
                   PartialVal.unknown(unmapped_aval(out_axis, pval[0])) if out_axis is not None else
                   PartialVal.unknown(pval[0])
                   for pval, out_axis in zip(out_pvals, out_axes)]

    # Skip known invars and outvars, and lift constants as regular invars
    in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers))
    out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
    jaxpr = _drop_invars(jaxpr, in_knowns)
    jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True)

    # Known tracers get propagated as if they were constants
    known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
                         if pval.is_known()]

    # Unknown tracers need to have the jaxpr set up as their recipe
    unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
                           if not pval.is_known()]
    unknown_tracers_in = [t for t in tracers if not t.pval.is_known()]
    const_tracers = map(self.new_instantiated_const, consts)
    # Equation inputs are ordered: constants, then environment, then unknown
    # arguments.
    in_tracers = (*const_tracers, *env_tracers, *unknown_tracers_in)

    # Set up new params
    new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    if primitive.map_primitive:
      in_axes = params['in_axes']
      # NOTE: const_tracers are added as map outputs, and we always map them
      #       along axis 0 (see `new_out_axes_thunk` above).
      new_in_axes = ((0,) * len(const_tracers) +
                     (None,) * len(env_tracers) +
                     tuple(axis for axis, t in zip(in_axes, tracers)
                           if not t.pval.is_known()))
      new_out_axes = tuple(axis for axis, pval in zip(out_axes, out_pvals)
                           if not pval.is_known())
      # The staged equation carries concrete `out_axes` rather than a thunk.
      new_params = dict(new_params, in_axes=new_in_axes, out_axes=new_out_axes)
      del new_params['out_axes_thunk']
    # Give the primitive a chance to adjust its params, passing it which of the
    # original inputs are unknown.
    update_params = call_param_updaters.get(primitive)
    if update_params:
      new_params = update_params(new_params, [not t.pval.is_known() for t in tracers])

    eqn = new_eqn_recipe(in_tracers, unknown_tracers_out, primitive, new_params,
                         source_info_util.current())
    for t in unknown_tracers_out: t.recipe = eqn
    return _zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
241
  # Map primitives share the implementation of `process_call`.
  process_map = process_call
243
244  # We use post_process_call to handle both call and map primitives.
  def post_process_call(self, primitive, out_tracers, params):
    """Handles tracers of this trace that are returned out of a call or map.

    Builds a jaxpr (with no invars of its own) for `out_tracers` and returns
    the values to pass out of the call together with a continuation that
    rebuilds tracers from them.

    Args:
      primitive: the call or map primitive being post-processed.
      out_tracers: the tracers of this trace that the called function returned.
      params: the primitive's params. For map primitives this reads
        'out_axes_thunk' and 'axis_size'.

    Returns:
      A pair `(out, todo)`. `out` holds the constants of the outputs' pvals
      followed by the jaxpr's constants. `todo` maps such a sequence of values
      to new output tracers; for map primitives it is instead the pair
      `(todo, out_axes_transform)`.
    """
    jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
    out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
    out = out_pv_consts + consts
    nconsts = len(consts)
    # `todo` below uses its own `consts` / `out_pv_consts`, taken from its
    # argument.
    del consts, out_pv_consts
    main = self.main

    if primitive.map_primitive:
      out_axes = params['out_axes_thunk']()
      sz = params['axis_size']
      # Known outputs (pv is None) stay None; the others get their aval
      # converted with `core.unmapped_aval`.
      out_pvs = [None if pv is None else core.unmapped_aval(sz, ax, pv)
                 for pv, ax in zip(out_pvs, out_axes)]

    def todo(x):
      # `x` has the layout of `out`: the first `n` entries are the outputs'
      # constants, the remaining ones are the jaxpr constants.
      n = len(jaxpr.outvars)
      out_pv_consts, consts = x[:n], x[n:]
      trace = JaxprTrace(main, core.cur_sublevel())
      const_tracers = map(trace.new_instantiated_const, consts)
      out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
                     for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
      in_tracers = (*const_tracers, *map(trace.full_raise, env))

      new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
      if primitive.map_primitive:
        # NOTE: We've assigned axis 0 to const tracers below, in out_axes_transform.
        new_in_axes = (0,) * len(const_tracers) + (None,) * len(env)
        new_params = dict(new_params, in_axes=new_in_axes, out_axes=out_axes)
        del new_params['out_axes_thunk']
      # No original inputs remain here, hence the empty list.
      update_params = call_param_updaters.get(primitive)
      if update_params:
        new_params = update_params(new_params, [])

      eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
                           source_info_util.current())
      for t in out_tracers:
        t.recipe = eqn
      return out_tracers

    if primitive.map_primitive:
      # The jaxpr constants appended to `out` are extra outputs of the map;
      # give each of them out axis 0.
      def out_axes_transform(out_axes):
        return out_axes + (0,) * nconsts
      todo = (todo, out_axes_transform)

    return out, todo
290
  # Map primitives share the implementation of `post_process_call`.
  post_process_map = post_process_call
292
293  def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
294                   app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]],
295                   instantiate: bool):
296    """Partially evaluate f on a sequence of PartialVals."""
297    in_avals, in_consts = unzip2(pvals)
298    f = trace_to_subjaxpr(f, self.main, instantiate)
299    f, aux = partial_eval_wrapper(f, tuple(in_avals))
300    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
301    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
302    out_pvs = map(PartialVal, zip(out_avals, out_consts))
303    env_tracers = map(self.full_raise, env)
304    return jaxpr, out_pvs, consts, env_tracers
305
  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Stages out a custom_jvp call as one `prim.initial_style` equation.

    All inputs are first instantiated with `instantiate_const_abstracted`, so
    every input is unknown. `fun` is traced to a jaxpr right away; `jvp` is only
    traced when the memoized `jvp_jaxpr_thunk` is called.

    Args:
      prim: the custom_jvp call primitive; it is bound to trace `fun`, and its
        `initial_style` attribute is the primitive of the recorded equation.
      fun: the wrapped primal function.
      jvp: the wrapped JVP rule.
      tracers: the input tracers.

    Returns:
      The unknown output tracers, all sharing the recorded equation.
    """
    tracers = map(self.instantiate_const_abstracted, tracers)
    in_avals, in_consts = unzip2(t.pval for t in tracers)  # in_consts are units
    fun = trace_to_subjaxpr(fun, self.main, True)
    fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
    out_flat = prim.bind(fun, jvp, *in_consts)
    out_avals, jaxpr, env = aux()
    # `out_flat` is the output constants followed by one value per constvar.
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvals = map(PartialVal, zip(out_avals, out_consts))  # out_consts are units
    env_tracers = map(self.full_raise, env)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *tracers)
    closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

    @_memoize
    def jvp_jaxpr_thunk():
      # NOTE: `aux`, `out_flat`, `out_avals`, `jaxpr`, `env` and `consts` are
      # locals of the thunk here and do not rebind the outer names.
      jvp_ = trace_to_subjaxpr(jvp, self.main, True)
      # The JVP rule takes twice as many inputs as `fun` (presumably primals
      # followed by tangents -- confirm against the custom_jvp definition).
      jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2)
      out_flat = jvp_.call_wrapped(*(in_consts * 2))  # in_consts are units
      out_avals, jaxpr, env = aux()
      _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
      converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
      return converted_jaxpr, (*consts, *env)

    # `num_consts` counts the leading equation inputs that come from
    # `const_tracers` and `env_tracers`.
    eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
                         dict(fun_jaxpr=closed_jaxpr,
                              jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                              num_consts=len(consts) + len(env)),
                         source_info_util.current())
    for t in out_tracers: t.recipe = eqn
    return out_tracers
338
  def post_process_custom_jvp_call(self, out_tracers, params):
    """Not implemented; always raises `NotImplementedError`."""
    # This path should only be reachable if we expose a partial eval API
    # unrelated to autodiff, since we raise an error when differentiation with
    # respect to values over which a custom_jvp function closes is detected.
    raise NotImplementedError  # TODO(mattjj)
344
  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    """Stages out a custom_vjp call as one `prim.initial_style` equation.

    All inputs are first instantiated with `instantiate_const_abstracted`, so
    every input is unknown. `fun` is traced to a jaxpr right away; `fwd` is only
    traced when the memoized `fwd_jaxpr_thunk` is called. `bwd` and `out_trees`
    are stored in the equation params as given.

    Args:
      prim: the custom_vjp call primitive; it is bound to trace `fun`, and its
        `initial_style` attribute is the primitive of the recorded equation.
      fun: the wrapped primal function.
      fwd: the wrapped forward rule.
      bwd: the backward rule.
      tracers: the input tracers.
      out_trees: passed through to `prim.bind` and to the equation params.

    Returns:
      The unknown output tracers, all sharing the recorded equation.
    """
    tracers = map(self.instantiate_const_abstracted, tracers)
    in_avals, in_consts = unzip2(t.pval for t in tracers)  # in_consts are units
    fun = trace_to_subjaxpr(fun, self.main, True)
    fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
    out_flat = prim.bind(fun, fwd, bwd, *in_consts, out_trees=out_trees)
    out_avals, jaxpr, env = aux()
    # `out_flat` is the output constants followed by one value per constvar.
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvals = map(PartialVal, zip(out_avals, out_consts))  # out_consts are units
    env_tracers = map(self.full_raise, env)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *tracers)
    closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

    @_memoize
    def fwd_jaxpr_thunk():
      # NOTE: `aux`, `out_flat`, `out_avals`, `jaxpr`, `env` and `consts` are
      # locals of the thunk here and do not rebind the outer names.
      fwd_ = trace_to_subjaxpr(fwd, self.main, True)
      fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals))
      out_flat = fwd_.call_wrapped(*in_consts)  # in_consts are units
      out_avals, jaxpr, env = aux()
      _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
      converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
      return converted_jaxpr, (*consts, *env)

    # `num_consts` counts the leading equation inputs that come from
    # `const_tracers` and `env_tracers`.
    eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
                         dict(fun_jaxpr=closed_jaxpr,
                              fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                              num_consts=len(consts) + len(env),
                              bwd=bwd, out_trees=out_trees),
                         source_info_util.current())
    for t in out_tracers: t.recipe = eqn
    return out_tracers
378
  def post_process_custom_vjp_call(self, out_tracers, params):
    """Not implemented; always raises `NotImplementedError`."""
    # This path should only be reachable if we expose a partial eval API
    # unrelated to autodiff, since we raise an error when differentiation with
    # respect to values over which a custom_vjp function closes is detected.
    raise NotImplementedError  # TODO(mattjj)
384
385
@lu.transformation_with_aux
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
  """Adapts a function on PartialVals to one taking and returning flat values.

  The inputs are paired with `pvs` to form PartialVals. The wrapped function
  yields `jaxpr, (out_pvals, consts, env)`; the main output is the outputs'
  constants followed by `consts`, and the auxiliary output is
  `(out_pvs, jaxpr, env)`.
  """
  py_args = [PartialVal(pair) for pair in zip(pvs, consts)]
  jaxpr, (out_pvals, consts, env) = yield (py_args,), {}
  out_pvs, out_consts = unzip2(out_pvals)
  flat_out = (*out_consts, *consts)
  yield flat_out, (out_pvs, jaxpr, env)
393
@lu.transformation_with_aux
def count_outputs(*args, **kwargs):
  """Passes the call through and reports `len` of its result as auxiliary output."""
  outs = yield args, kwargs
  num_outs = len(outs)
  yield outs, num_outs
398
# Per-primitive registries consulted by `JaxprTrace`:
# * `custom_partial_eval_rules`: replaces the default rule in
#   `process_primitive`; called as `rule(trace, *tracers, **params)`.
# * `call_partial_eval_rules`: replaces the generic handling in `process_call`;
#   called as `rule(trace, primitive, f, tracers, params)`.
# * `call_param_updaters`: adjusts the params of a staged-out call equation;
#   called as `updater(new_params, unknown_inputs)`.
custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
call_param_updaters: Dict[core.Primitive, Callable] = {}
402
403
def abstract_eval_fun(fun, *avals, **params):
  """Returns the output abstract values of `fun` traced on `avals`.

  `params` are attached to `fun` through `lu.wrap_init`. Which tracing function
  is used depends on `config.omnistaging_enabled`.
  """
  if not config.omnistaging_enabled:
    unknown_pvals = [PartialVal.unknown(aval) for aval in avals]
    _, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), unknown_pvals,
                                    instantiate=True, stage_out=True)  # type: ignore
    avals_out, _ = unzip2(pvals_out)
  else:
    _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
  for aval_out in avals_out:
    assert isinstance(aval_out, AbstractValue)  # instantiate=True
  return avals_out
415
416
# The kinds of value a `JaxprTracer.recipe` may hold (it may also be None
# while the tracer is being set up). The string entries are forward references
# to types defined further down in this module.
JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar',
                          'ConstVar', Literal, core.Unit]
419
class JaxprTracer(Tracer):
  """Tracer of a `JaxprTrace`: a `PartialVal` plus the recipe that produced it."""
  __slots__ = ['pval', 'recipe']

  def __init__(self, trace: JaxprTrace, pval: PartialVal,
               recipe: Optional[JaxprTracerRecipe]):
    assert isinstance(pval, PartialVal)
    _, const = pval
    # A constant that is itself a tracer must come from a strictly lower level.
    const_is_tracer = isinstance(const, Tracer)
    if const_is_tracer and const._trace.level >= trace.level:
      raise core.escaped_tracer_error(
          const, "Tracer from a higher level: {} in trace {}".format(const, trace))
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  def __repr__(self):
    return 'Traced<{}:{}>'.format(self.aval, self._trace)

  @property
  def aval(self) -> AbstractValue:
    """The abstract value of this tracer, as given by its pval."""
    pval = self.pval
    return pval.get_aval()

  @property
  def parents(self) -> Sequence['JaxprTracer']:
    """The input tracers of the equation recipe, or `[]` for other recipes."""
    recipe = self.recipe
    if not isinstance(recipe, JaxprEqnRecipe):
      return []
    return recipe.invars

  def full_lower(self):
    """Returns the fully lowered known value, or this tracer if unknown."""
    known = self.pval.get_known()
    if known is None:
      return self
    return core.full_lower(known)

  def is_known(self):
    """Whether the underlying pval is known."""
    pval = self.pval
    return pval.is_known()
457
458# TODO(necula): this could return a ClosedJaxpr with out_pvals
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
                   instantiate: Union[bool, Sequence[bool]] = False,
                   ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
  """Traces a function into a Jaxpr, given PartialVals for inputs.

  Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
  computation that depends on unknown inputs. The `out_pvals` are the PartialVal
  for the outputs. The intermediate values that depend only on known inputs and
  are needed to compute the output of `jaxpr` are in `consts` and are passed in
  as the constvars of the `jaxpr`. The handling of the known outputs depends on
  `instantiate`.

  For example, given `fun` defined as follows::

    def fun(ki, ui):  # ki will be a known input in this example
      ka = ki + 2
      kb = ka + 3
      return (kb, ui + ka)

  with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
  computation that depends on unknown inputs is `ui + ka` and will be the only
  computation in the body of the `jaxpr`. This computation depends on the known
  intermediate value `ka`, which will be computed statically. Currently, such
  constants are either embedded in the Jaxpr if they are scalars, or passed as a
  constvar to `jaxpr`, and then the value of the actual constant will be in
  `consts`:

  When `instantiate=False` we get::

    jaxpr =
      { lambda ka ; ki ui.
        let c = add ui ka
        in (*, c) }   # known outputs are `*`
    out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
    consts = [3]  # the constant for `ka`

  When `instantiate=True` we get::

    jaxpr =
      { lambda ka kb ; ki ui.
        let c = add ui ka
        in (kb, c) }   # known output are explicit
    out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
    consts = [3, 6]  # values for `ka` and `kb` constvars

  Args:
    fun: the wrapped function to trace.
    pvals: one PartialVal per input of `fun`.
    instantiate: a single bool applied to every output, or one bool per output,
      saying whether a known output is turned into an unknown one.
  """
  with core.new_main(JaxprTrace) as main:
    fun = trace_to_subjaxpr(fun, main, instantiate)
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    # A top-level trace has no enclosing trace to supply environment values.
    assert not env
    # Drop the local references before leaving the `new_main` context
    # (presumably so the main trace is not kept alive by this frame -- confirm
    # against `core.new_main`).
    del main, fun, env

  return jaxpr, out_pvals, consts
511
512
@lu.transformation
def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
                      pvals: Sequence[PartialVal]):
  """Transformation that runs the wrapped function under a `JaxprTrace` of `main`.

  The transformed function takes the sequence `pvals` as its single argument
  and produces `jaxpr, (out_pvals, consts, env)`.

  Args:
    main: the main trace that the new `JaxprTrace` belongs to.
    instantiate: a single bool applied to every output, or one bool per output,
      saying whether that output is passed through `instantiate_const`.
    pvals: one PartialVal per input of the wrapped function.
  """
  assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
  trace = JaxprTrace(main, core.cur_sublevel())
  in_tracers = map(trace.new_arg, pvals)
  ans = yield in_tracers, {}
  # Broadcast a single flag to one flag per output.
  instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate
  out_tracers = map(trace.full_raise, map(core.full_lower, ans))
  out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
  jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  # Release the trace and its tracers before yielding the result.
  del trace, in_tracers, out_tracers
  yield jaxpr, (out_pvals, consts, env)
527
def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
  """Instantiates `tracer` in `trace` when `instantiate` is set.

  Otherwise `tracer` is returned untouched.
  """
  if not instantiate:
    return tracer
  raised = trace.full_raise(tracer)
  return trace.instantiate_const(raised)
533
534
# Non-equation recipes that a JaxprTracer may carry:
# * FreeVar(val): set by `JaxprTrace.sublift`; `tracers_to_jaxpr` records `val`
#   as an environment value.
# * ConstVar(val): set when a constant is instantiated; `tracers_to_jaxpr`
#   records `val` as a jaxpr constant.
# * LambdaBinding(): set by `JaxprTrace.new_arg` for an unknown argument.
FreeVar = namedtuple('FreeVar', ['val'])
ConstVar = namedtuple('ConstVar', ['val'])
LambdaBinding = namedtuple('LambdaBinding', [])
class JaxprEqnRecipe(NamedTuple):
  """Recipe for one jaxpr equation, shared by all of its output tracers."""
  # Fresh `object()` per recipe; `tracers_to_jaxpr` uses it to emit each
  # equation only once.
  eqn_id: object
  # Tracers for the equation inputs (held strongly).
  invars: Sequence[JaxprTracer]
  # Weak references to the output tracers; an output whose reference is dead
  # becomes `core.dropvar` in `recipe_to_eqn`.
  outvars: 'Sequence[ref[JaxprTracer]]'
  primitive: core.Primitive
  params: Dict[str, Any]
  source_info: Optional[source_info_util.Traceback]
545
def new_eqn_recipe(invars: Sequence[JaxprTracer],
                   outvars: Sequence[JaxprTracer],
                   primitive: core.Primitive,
                   params: Dict[str, Any],
                   source_info: Optional[source_info_util.Traceback]
                  ) -> JaxprEqnRecipe:
  """Constructs a new JaxprEqnRecipe.

  Params:
    invars: the tracers for the primitive inputs.
    outvars: the tracers for the primitive outputs; only weak references to
      them are stored in the recipe.
    primitive: the primitive.
    params: the primitive params
    source_info: the source information to attach to the equation.
  """
  # TODO(necula): move these checks to core.check_jaxpr, and call in more places
  takes_call_jaxpr = primitive.call_primitive or primitive.map_primitive
  if takes_call_jaxpr:
    assert "call_jaxpr" in params
  if primitive.map_primitive:
    # Both per-input params must have one entry per invar of the called jaxpr.
    for key in ("in_axes", "donated_invars"):
      assert (key in params and
              len(params[key]) == len(params["call_jaxpr"].invars))
  return JaxprEqnRecipe(eqn_id=object(), invars=tuple(invars),
                        outvars=map(ref, outvars), primitive=primitive,
                        params=params, source_info=source_info)
570
571
def recipe_to_eqn(getvar: Callable[[JaxprTracer], core.Atom],
                  recipe: JaxprEqnRecipe) -> core.JaxprEqn:
  """Converts an equation recipe into a jaxpr equation.

  `getvar` maps each tracer to its jaxpr atom. An output tracer that is no
  longer alive is replaced with `core.dropvar`.
  """
  # Dereference the outputs first, then resolve inputs before outputs.
  live_outs = [out_ref() for out_ref in recipe.outvars]
  invars = [getvar(in_tracer) for in_tracer in recipe.invars]
  outvars = []
  for out_tracer in live_outs:
    if out_tracer is None:
      outvars.append(core.dropvar)
    else:
      outvars.append(cast(core.Var, getvar(out_tracer)))
  return new_jaxpr_eqn(invars, outvars, recipe.primitive, recipe.params,
                       recipe.source_info)
580
def tracers_to_jaxpr(
  in_tracers: Sequence[JaxprTracer],
  out_tracers: Sequence[JaxprTracer]
  ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
  """Constructs Jaxpr given tracers for inputs and outputs.

  Params:
    in_tracers: the tracers that were created for the function inputs
    out_tracers: the tracers that were output by the function.

  Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
    the `constvars` in the returned Jaxpr, and a list of environment values.
    The vars for the environment values have been prepended to the Jaxpr's
    `invars`.

  Raises:
    The error built by `core.escaped_tracer_error` if a tracer with a
    `LambdaBinding` recipe is reached that is not one of `in_tracers`;
    `TypeError` for a recipe of an unrecognized kind.
  """
  newvar = core.gensym()
  # Tracers are keyed by `id`, so each tracer object maps to exactly one atom.
  t_to_var: Dict[int, core.Atom] = {}
  def getvar(t: JaxprTracer) -> core.Atom:
    var = t_to_var.get(id(t))
    if var is None:
      # Known tracers are given the abstract unit as their aval.
      aval = t.pval.get_aval() if not t.pval.is_known() else abstract_unit
      var = t_to_var[id(t)] = newvar(aval)
    return var
  sorted_tracers = toposort(out_tracers)
  invars = map(getvar, in_tracers)
  eqns: List[core.JaxprEqn] = []
  env: Dict[core.Var, Any] = {}
  consts: Dict[core.Var, Any] = {}
  # Constants are keyed by `id` as well: the same constant object shared by
  # several tracers gets a single constvar.
  const_to_var: Dict[int, core.Var] = {}
  def getconstvar(c):
    var = const_to_var.get(id(c))
    if var is None:
      var = const_to_var[id(c)] = newvar(get_aval(c))
    return var
  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, JaxprEqnRecipe):
      # All outputs of an equation share one recipe; emit the equation once.
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(recipe_to_eqn(getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, LambdaBinding):
      if not any(t is in_tracer for in_tracer in in_tracers):
        raise core.escaped_tracer_error(
            t, "Tracer not among input tracers {}".format(t))
      assert in_tracers, "Lambda binding with no args"
    elif isinstance(recipe, FreeVar):
      env[cast(core.Var, getvar(t))] = recipe.val
    elif isinstance(recipe, ConstVar):
      v = t_to_var[id(t)] = getconstvar(recipe.val)
      consts[v] = recipe.val
    elif isinstance(recipe, Literal):
      # The literal itself stands in for the tracer's atom.
      t_to_var[id(t)] = recipe
    elif recipe is unit:
      t_to_var[id(t)] = unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = unzip2(env.items())
  const_vars, const_vals = unzip2(consts.items())
  # The env_vars are pre-pended to the invars
  jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
  core.skip_checks or core.check_jaxpr(jaxpr)
  return jaxpr, const_vals, env_vals
645
@cache()
def convert_constvars_jaxpr(jaxpr: Jaxpr):
  """Moves the constvars to the start of invars.

  The returned Jaxpr has no constvars; outvars and eqns are those of `jaxpr`.
  """
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  all_invars = jaxpr.constvars + jaxpr.invars
  lifted_jaxpr = Jaxpr(constvars=(), invars=all_invars,
                       outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  if not core.skip_checks:
    core.check_jaxpr(lifted_jaxpr)
  return lifted_jaxpr
655
def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int):
  """Turns the environment invars of `jaxpr` into constvars.

  `jaxpr.invars` is split at `num_env_vars` with `split_list`; the first part
  is appended to the constvars and the second part becomes the new invars.
  """
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  env_vars, remaining_invars = split_list(jaxpr.invars, [num_env_vars])
  new_constvars = jaxpr.constvars + env_vars
  converted_jaxpr = Jaxpr(constvars=new_constvars, invars=remaining_invars,
                          outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  if not core.skip_checks:
    core.check_jaxpr(converted_jaxpr)
  return converted_jaxpr
663
664
def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]:
  """Returns the pair `(known-side aval, unknown-side aval)` for `aval`.

  The side that does not apply gets `abstract_unit`.
  """
  if unknown:
    return abstract_unit, aval
  return aval, abstract_unit
667
def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
                       instantiate: Union[bool, Sequence[bool]],
                       ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
  """Specializes a Jaxpr given an indication of which inputs are known.

  Returns: (jaxpr_known, jaxpr_unknown, out_unknowns).

  `out_unknowns` specifies which outputs are unknown (depend on some unknown inputs).
  `jaxpr_known` takes the same inputs as `jaxpr`, ignores the unknown inputs,
  and performs *all* the computation in `jaxpr` that depends only on the known inputs.
  Outputs correspond to those of `jaxpr`, with the unknown ones replaced with `*`,
  appended with the known residuals (the intermediate computations in `jaxpr`
  that depend only on known inputs and that are needed to compute the unknown outputs).

  `jaxpr_unknown` takes the same inputs as `jaxpr` along with the known residuals
  computed by `jaxpr_known` and returns the same outputs as `jaxpr` with the known
  outputs replaced by `*`.

  Roughly, `jaxpr(ki, ui)` is decomposed assuming `ki` and `ui` are the known and respectively
  unknown inputs into:

    jaxpr(ki, ui) = let kout, _, kresidual = jaxpr_known(ki, *)
                    let _, uout = jaxpr_unknown(ki, ui, kresidual)
                    in (kout, uout)

  For example, if `jaxpr` is lambda ki, ui: let ka = ki + 2
                                            in (ki + 3, ui + ka)
  then
    `jaxpr_known` = lambda ki, ui: let ka = ki + 2
                                    in (ki + 3, *, ka)
    `jaxpr_unknown` = lambda ki, ui, ka: (*, ui + ka)

  Note that if instantiate is True for a given output, then jaxpr_known always returns a
  unit in its place. So when instantiate is True, the expectation is that one doesn't
  run `jaxpr_known` for any of its outputs, but only to generate residuals that will allow
  to obtain the full outputs once `jaxpr_unknown` is run. Outputs known ahead of time will
  simply get passed as residual constants and returned immediately.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  # `fun` is traced exactly once below; it stashes the results of the inner
  # partial evaluation here so they can be read back after that trace.
  cell = []
  def fun(*vals):
    pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
            for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
    jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
    out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
    cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
    # The outputs of jaxpr_1: the outputs' constants followed by the residuals.
    return out_consts_2 + consts_2

  # For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the
  # known inputs.
  in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
  jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
  (out_pvs_2, jaxpr_2, num_res), = cell
  assert len(jaxpr_2.constvars) == num_res

  #   jaxpr :: a -> b
  # jaxpr_1 :: a1 -> [b1, res]
  # jaxpr_2 :: res | a2 -> b2
  # jaxpr_2 :: [a2, res] -> b2
  jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
  # Move the residuals from the front of the invars to the back.
  # NOTE(review): this and the `var.aval` assignment below mutate, in place,
  # the object returned by the `@cache()`-decorated `convert_constvars_jaxpr`
  # and its vars -- confirm nothing else relies on the cached value.
  jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
  for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
    if not unknown:
      var.aval = abstract_unit

  # An output is unknown exactly when its pval stores an abstract value.
  uk_out = [pv is not None for pv in out_pvs_2]

  # NOTE(review): `in_avals_1`, `out_avals_1` and `in_avals_2` are computed
  # here but are not used by the return statement below.
  in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
  out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
  # out_avals_1 and in_avals_2 need the residuals added
  res_avals = out_avals[len(jaxpr.out_avals):]
  assert len(res_avals) == num_res
  out_avals_1 = [*out_avals_1, *res_avals]
  in_avals_2 = [*in_avals_2, *res_avals]

  return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
745
746
# `remat_call` is a call primitive; `remat_call` below is its `bind` entry
# point, and its implementation rule is the generic `core.call_impl`.
remat_call_p = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind
remat_call_p.def_impl(core.call_impl)
750
def _remat_partial_eval(trace, _, f, tracers, params):
  """Partial-evaluation rule for `remat_call_p`.

  Stages the whole of `f` into a jaxpr (with every input instantiated),
  determines which outputs are known, computes values for the known outputs,
  and records a single `remat_call_p` equation as the recipe for the unknown
  outputs.

  Args:
    trace: the trace performing partial evaluation; must provide
      `instantiate_const`, `instantiate_const_abstracted`, `partial_eval`,
      `new_instantiated_const` and `new_const`.
    _: unused positional argument.
    f: the function being called.
    tracers: input tracers of `trace`.
    params: parameters of the `remat_call_p` bind; `params['concrete']` is read
      here, and the rest are forwarded to the staged equation.

  Returns:
    A list of output tracers in the original output order, mixing known
    (constant) tracers and unknown tracers that share one equation recipe.
  """
  concrete = params['concrete']

  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  if concrete:
    instantiated_tracers = map(trace.instantiate_const, tracers)
  else:
    instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  in_pvals = [t.pval for t in instantiated_tracers]
  # Both branches make the same `partial_eval` call; the non-omnistaging one
  # additionally wraps it in `core.initial_style_staging()`.
  if config.omnistaging_enabled:
    jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
      f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
  else:
    with core.initial_style_staging():  # type: ignore
      jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
        f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)

  # Convert consts to inputs, since they may contain Tracer instances.
  jaxpr = convert_constvars_jaxpr(jaxpr)
  const_tracers = map(trace.new_instantiated_const, consts)

  # Since we traced with everything marked as unknown, but we need to know which
  # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
  closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
  # The converted jaxpr takes (consts, env, original inputs) in that order;
  # the former constants are always treated as known.
  in_unknowns = ([False] * len(consts) +
                 [not t.is_known() for t in it.chain(env_tracers, tracers)])
  if config.omnistaging_enabled:
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False)  # type: ignore
  else:
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type)  # type: ignore
  out_knowns = [not b for b in out_unknowns]
  out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)

  # Next, we need values for the outputs that should be known. Since consts
  # weren't passed through Python for evaluation, we need to evaluate jaxpr_known,
  # minus the residual outputs that we don't need. When `concrete=True`, as an
  # optimization we can avoid redoing *some* redundant FLOPs, namely those that
  # produced concrete avals at the output, simply by using those as computed
  # values. For the use case of inverse-mode ad in op-by-op ("eager mode")
  # evaluation, all the primal outputs should be concrete (thus not recomputed).
  # One flag per *known* output: True when its aval is not a ConcreteArray and
  # the value therefore has to be computed by running `jaxpr_known_comp`.
  to_compute = [type(pval[0]) is not ConcreteArray
                for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
  num_outputs = len(jaxpr_unknown.out_avals)
  num_res = len(jaxpr_known.out_avals) - num_outputs
  # First drop the unknown and residual outputs, then unit-out the known
  # outputs that do not need computing.
  jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True)
  jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
  _, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
  reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
  out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)

  # Known outputs should keep propagating as constants
  assert all(pv.is_known() for pv in out_known_pvals)
  known_output_tracers = [trace.new_const(pval.get_known())
                          for pval in out_known_pvals]
  # Unknown outputs get wrapped in tracers with the appropriate recipe
  unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
                            for out_pval in out_unknown_pvals]

  # dce jaxpr outputs
  new_jaxpr = _dce_jaxpr(closed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
  new_params = dict(params, call_jaxpr=new_jaxpr)

  # set up eqn for unknown outputs
  in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
  eqn = new_eqn_recipe(in_tracers, unknown_output_tracers, remat_call_p, new_params,
                       source_info_util.current())
  for t in unknown_output_tracers: t.recipe = eqn
  # Interleave known and unknown tracers back into the original output order.
  return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
832
833def _partition_knowns(pvals, unknowns: Sequence[bool]):
834  return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
835          [e for e, unknown in zip(pvals, unknowns) if unknown])
836
837def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
838  known_iter, unknown_iter = iter(known_list), iter(unknown_list)
839  return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]
840
841
def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> ClosedJaxpr:
  """Dead-code-eliminate a closed jaxpr, keeping the outputs flagged in `outputs`.

  Delegates to `_dce_open_jaxpr` on the underlying jaxpr and re-wraps the
  result with the original constants.
  """
  # The mask is passed as a tuple; `_dce_open_jaxpr` is decorated with a cache,
  # presumably so that it can be used as part of a hashable key.
  keep_mask = tuple(outputs)
  pruned = _dce_open_jaxpr(closed_jaxpr.jaxpr, keep_mask, drop_outputs)
  return core.ClosedJaxpr(pruned, closed_jaxpr.consts)
845
@cache()
def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
  """Return a copy of `jaxpr` without the equations its kept outputs don't need.

  Outputs whose flag in `outputs` is false are either removed
  (`drop_outputs=True`) or replaced by `unitvar` (`drop_outputs=False`).
  """
  # This dead-code elimination is pretty rudimentary, and in particular doesn't
  # nontrivially DCE through scan, call, or other higher-order primitives.
  # TODO(mattjj): better DCE
  flagged_outvars = zip(jaxpr.outvars, outputs)
  if drop_outputs:
    kept_outvars = [v for v, keep in flagged_outvars if keep]
  else:
    kept_outvars = [v if keep else unitvar for v, keep in flagged_outvars]

  # Walk the equations backwards, keeping each one that defines a live variable
  # and marking its (non-literal) inputs as live in turn.
  live = {v for v in kept_outvars if type(v) is not Literal}
  kept_eqns = []
  for eqn in reversed(jaxpr.eqns):
    if live.isdisjoint(eqn.outvars):
      continue
    kept_eqns.append(eqn)
    live.update(v for v in eqn.invars if type(v) is not Literal)
  kept_eqns.reverse()
  return core.Jaxpr(jaxpr.constvars, jaxpr.invars,
                    kept_outvars, kept_eqns)
866
@cache()
def _drop_invars(jaxpr: Jaxpr, drop: Tuple[bool, ...]):
  """Return `jaxpr` with the input binders flagged in `drop` removed.

  Constvars, outvars and equations are passed through unchanged.
  """
  kept_invars = [invar for invar, dropped in zip(jaxpr.invars, drop)
                 if not dropped]
  return core.Jaxpr(jaxpr.constvars, kept_invars, jaxpr.outvars, jaxpr.eqns)
871
872
def _reconstruct_pval(pval1: PartialVal, const2: core.Value):
  """Turn `pval1` into a known `PartialVal`.

  An already-known `pval1` is returned as is. Otherwise the value comes from
  the `ConcreteArray` aval when there is one, and from `const2` if not.
  """
  aval, _ = pval1
  if pval1.is_known():
    return pval1
  if type(aval) is ConcreteArray:
    value = aval.val  # pytype: disable=attribute-error
  else:
    value = const2
  return PartialVal.known(value)
882
883
def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr:
  """Reorder the `invars` to move to front the ones for which `to_move` is True.

  The jaxpr must have no constvars, and `to_move` must have one flag per input.
  Outvars, equations and consts are carried over unchanged.
  """
  assert not closed_jaxpr.jaxpr.constvars
  assert len(closed_jaxpr.in_avals) == len(to_move)
  inner = closed_jaxpr.jaxpr
  reordered_invars = _move_to_front(inner.invars, to_move)
  reordered = core.Jaxpr((), reordered_invars, inner.outvars, inner.eqns)
  return core.ClosedJaxpr(reordered, closed_jaxpr.consts)
893
894def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence:
895  return ([elt for elt, move in zip(lst, to_move) if move] +
896          [elt for elt, move in zip(lst, to_move) if not move])
897
898
class DynamicJaxprTracer(core.Tracer):
  """Tracer created by `DynamicJaxprTrace`.

  It stores its abstract value in `aval`, together with the owning trace and
  optional line information describing where it was created.
  """
  __slots__ = ['aval']

  def __init__(self, trace, aval, line_info=None):
    self.aval = aval
    self._trace = trace
    self._line_info = line_info

  def full_lower(self):
    """There is nothing to lower to: the tracer itself is returned."""
    return self

  def _contents(self):
    """This tracer exposes no contents."""
    return ()

  def _origin_msg(self):
    """Build the text explaining where this tracer came from, for error messages."""
    arg_positions, source_eqns = self._trace.frame.find_progenitors(self)
    traced_fun = self._trace.main.source_info

    # Case 1: the value depends on arguments of the traced function.
    if arg_positions:
      return (f"While tracing the function {traced_fun}, "
              "this concrete value was not available in Python because it "
              "depends on the value of the arguments to "
              f"{traced_fun} at flattened positions {arg_positions}, "
              "and the computation of these values is being staged out "
              "(that is, delayed rather than executed eagerly).\n\n"
              "You can use transformation parameters such as `static_argnums` "
              "for `jit` to avoid tracing particular arguments of transformed "
              "functions, though at the cost of more recompiles.")

    # Case 2: the value was produced by equations that consume constants.
    if source_eqns:
      described = []
      for eqn in source_eqns:
        op_text = core.pp_eqn(eqn, print_shapes=True)
        line_text = source_info_util.summarize(eqn.source_info)
        described.append(f"  operation {op_text}\n"
                         f"    from line {line_text}")
      return (f"While tracing the function {traced_fun}, "
              "this value became a tracer due to JAX operations on these lines:"
              "\n\n" + "\n\n".join(described))

    # Case 3: nothing more specific is known.
    return ("The error occured while tracing the function "
            f"{traced_fun}.")

  def _assert_live(self) -> None:
    """Raise an escaped-tracer error once the main's jaxpr stack is empty."""
    stack = self._trace.main.jaxpr_stack  # type: ignore
    if stack:
      return
    raise core.escaped_tracer_error(self, None)
940
class JaxprStackFrame:
  """Mutable state accumulated while one jaxpr is being traced.

  It holds the variable generator, the tracer-id -> variable and
  constant-id -> variable maps, constant values, the equations recorded so
  far, and the input variables.
  """
  __slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
               'tracers', 'eqns', 'invars']

  def __init__(self):
    self.newvar = core.gensym()
    self.tracer_to_var = {}
    self.constid_to_var = {}
    self.constvar_to_val = {}
    # `tracers` and `eqns`: circ refs, frame->tracer->trace->main->frame,
    # cleared when we pop frame from main
    self.tracers = []
    self.eqns = []
    self.invars = []

  def to_jaxpr(self, in_tracers, out_tracers):
    """Assemble `(jaxpr, out_avals, constvals)` from the recorded state.

    Constants are passed through `_inline_literals` before being returned.
    """
    var_of = self.tracer_to_var
    invars = [var_of[id(tracer)] for tracer in in_tracers]
    outvars = [var_of[id(tracer)] for tracer in out_tracers]
    constvars, constvals = unzip2(self.constvar_to_val.items())
    raw_jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
    jaxpr, constvals = _inline_literals(raw_jaxpr, constvals)
    out_avals = [tracer.aval for tracer in out_tracers]
    return jaxpr, out_avals, constvals

  def find_progenitors(self, tracer):
    """Trace `tracer` back to the inputs and constants it depends on.

    Returns `(invar_positions, const_eqns)`: the positions in `self.invars`
    that the tracer's variable depends on, and the equations that consume a
    constant it depends on. Returns `(None, None)` when the tracer has no
    variable in this frame.
    """
    var = self.tracer_to_var.get(id(tracer))
    if not var:
      return None, None

    # Walk backwards, replacing each frontier variable by the inputs of the
    # equation that produced it.
    frontier = {var}
    for eqn in reversed(self.eqns):
      produced = frontier.intersection(eqn.outvars)
      if produced:
        frontier -= produced
        frontier.update(eqn.invars)

    invar_positions = [pos for pos, invar in enumerate(self.invars)
                       if invar in frontier]
    const_frontier = frontier.intersection(self.constvar_to_val)
    const_eqns = [eqn for eqn in self.eqns
                  if not const_frontier.isdisjoint(eqn.invars)]
    return invar_positions, const_eqns
977
def _inline_literals(jaxpr, constvals):
  """Rebuild `jaxpr` with scalar constants inlined as `Literal`s.

  Every variable is renamed with a fresh `core.gensym()`. Constants whose type
  is in `core.literalable_types` and whose shape is empty are replaced by
  `Literal`s at their use sites and removed from the constvars; equation
  outputs that are never used are replaced by `dropvar`.

  Returns:
    `(new_jaxpr, new_constvals)`, where `new_constvals` holds the values of the
    constants that were not inlined, in their original order.
  """
  consts = dict(zip(jaxpr.constvars, constvals))
  newvar = core.gensym()
  # Dict that lazily allocates a fresh variable (with the same aval) the first
  # time an old variable is looked up. The class name is immediately rebound to
  # its single instance.
  class var(dict):
    def __missing__(self, v):
      new_v = self[v] = newvar(v.aval)
      return new_v
  var = var()

  def lit(var: core.Var) -> Optional[Any]:
    # Returns a Literal for an inlinable constant, else None. Non-constant
    # variables yield `val is None` and therefore None.
    val = consts.get(var)
    if type(val) in core.literalable_types and not np.shape(val):
      return Literal(val)
    else:
      return None

  # Variables read by some equation or returned as an output.
  used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
  new_constvars = [var[v] for v in jaxpr.constvars if not lit(v)]
  new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
  new_invars = [var[v] for v in jaxpr.invars]
  # NOTE(review): `lit(v) or var[v]` relies on Literal instances being truthy.
  new_eqns = [new_jaxpr_eqn([lit(v) or var[v] for v in eqn.invars],
                            [var[v] if v in used else dropvar for v in eqn.outvars],
                            eqn.primitive, eqn.params, eqn.source_info)
              for eqn in jaxpr.eqns]
  new_outvars = [lit(v) or var[v] for v in jaxpr.outvars]
  new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
  return new_jaxpr, new_constvals
1005
class DynamicJaxprTrace(core.Trace):
  """Trace that records every operation as an equation in the current frame.

  All state lives in the innermost `JaxprStackFrame` of `self.main`; see the
  `frame` property.
  """
  __slots__ = []  # type: ignore

  @property
  def frame(self):
    """The innermost `JaxprStackFrame` on this trace's main."""
    return self.main.jaxpr_stack[-1]  # pytype: disable=attribute-error

  def new_arg(self, aval):
    """Create a tracer for a new jaxpr input with abstract value `aval`."""
    tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
    self.frame.tracers.append(tracer)
    self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval)
    self.frame.invars.append(var)
    return tracer

  def new_const(self, val):
    """Create a tracer for the constant `val` and record its value in the frame."""
    aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
    tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
    self.frame.tracers.append(tracer)
    var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(val)
    self.frame.constvar_to_val[var] = val
    return tracer

  pure = lift = sublift = new_const

  def getvar(self, tracer):
    """Return the variable bound to `tracer`, or raise an escaped-tracer error."""
    var = self.frame.tracer_to_var.get(id(tracer))
    if var is None:
      raise core.escaped_tracer_error(tracer)
    return var

  def makevar(self, tracer):
    """Allocate the (unique) variable for a freshly created `tracer`."""
    var = self.frame.tracer_to_var.get(id(tracer))
    assert var is None, "a jaxpr variable must be created only once per tracer"
    self.frame.tracers.append(tracer)
    var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval)
    return var

  def getconstvar(self, c):
    """Return the variable for constant `c`, keyed by `id(c)`, creating it if needed."""
    var = self.frame.constid_to_var.get(id(c))
    if var is None:
      var = self.frame.constid_to_var[id(c)] = self.frame.newvar(get_aval(c))
    return var

  def instantiate_const(self, val):
    """Return `val` if it is already a tracer of this trace, else wrap it as a constant."""
    if (isinstance(val, Tracer) and val._trace.main is self.main
        and val._trace.sublevel == self.sublevel):
      return val
    else:
      return self.new_const(val)

  def process_primitive(self, primitive, tracers, params):
    """Record `primitive` as an equation; outputs are typed by its abstract eval."""
    avals = [t.aval for t in tracers]
    out_avals = primitive.abstract_eval(*avals, **params)
    out_avals = [out_avals] if not primitive.multiple_results else out_avals
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers if primitive.multiple_results else out_tracers.pop()

  def process_call(self, call_primitive, f, tracers, params):
    """Trace `f` to a sub-jaxpr and record it as a `call_primitive` equation.

    If the sub-jaxpr has no equations it is evaluated directly on `tracers`
    instead of being recorded.
    """
    in_avals = [t.aval for t in tracers]
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
    if not jaxpr.eqns:
      return core.eval_jaxpr(jaxpr, consts, *tracers)
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    # The sub-jaxpr's constants become leading inputs of the call equation.
    new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    update_params = call_param_updaters.get(call_primitive)
    if update_params:
      new_params = update_params(new_params, [True] * len(tracers))
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
                        new_params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_call(self, call_primitive, out_tracers, params):
    assert False  # unreachable

  def process_map(self, map_primitive, f, tracers, params):
    """Trace `f` on mapped avals and record it as a `map_primitive` equation."""
    in_avals = [t.aval for t in tracers]
    axis_name, axis_size = params['axis_name'], params['axis_size']
    reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a)
                        if in_axis is not None else a
                        for a, in_axis in zip(in_avals, params['in_axes'])]
    with core.extend_axis_env(axis_name, axis_size, None):  # type: ignore
      jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
          f, self.main, reduced_in_avals)
    # The thunk is only called after `f` has been traced.
    out_axes = params['out_axes_thunk']()
    out_avals = [core.unmapped_aval(params['axis_size'], out_axis, a)
                 if out_axis is not None else a
                 for a, out_axis in zip(reduced_out_avals, out_axes)]
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    # Constants are prepended as inputs with a `None` in_axis each.
    new_in_axes = (None,) * len(consts) + params['in_axes']
    new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes,
                      call_jaxpr=convert_constvars_jaxpr(jaxpr))
    del new_params['out_axes_thunk']
    update_params = call_param_updaters.get(map_primitive)
    if update_params:
      new_params = update_params(new_params, [True] * len(tracers))
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
                        new_params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_map(self, map_primitive, out_tracers, params):
    assert False  # unreachable

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Record a custom-JVP call as a `prim.initial_style` equation.

    `fun` is traced immediately; the JVP rule is traced lazily through a
    memoized thunk stored in the equation parameters.
    """
    in_avals = [t.aval for t in tracers]
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
    closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
    jvp_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(jvp, self.main, 2 * in_avals)[::2])
    # Use one source_info for both the output tracers and the equation, as the
    # other process_* methods do, so the tracers carry line information.
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                        dict(fun_jaxpr=closed_fun_jaxpr,
                             jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                             num_consts=len(consts)),
                        source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_custom_jvp_call(self, out_tracers, params):
    assert False  # unreachable

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    """Record a custom-VJP call as a `prim.initial_style` equation.

    `fun` is traced immediately; the forward rule is traced lazily through a
    memoized thunk, and `bwd` and `out_trees` are stored as given.
    """
    in_avals = [t.aval for t in tracers]
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
    closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
    fwd_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(fwd, self.main, in_avals)[::2])
    # Use one source_info for both the output tracers and the equation, as the
    # other process_* methods do, so the tracers carry line information.
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                        dict(fun_jaxpr=closed_fun_jaxpr,
                             fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                             num_consts=len(consts),
                             bwd=bwd, out_trees=out_trees),
                        source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_custom_vjp_call(self, out_tracers, params):
    assert False  # unreachable
1165
def _memoize(thunk):
  """Wrap `thunk` so that it runs at most once and its result is reused.

  A copy of the trace state is taken when the wrapper is created; the first
  call runs `thunk` with that copy installed and restores the caller's state
  afterwards, even if `thunk` raises.
  """
  results = []
  captured_state = core.thread_local_state.trace_state.copy()

  def memoized():
    if results:
      return results[0]
    outer_state = core.thread_local_state.trace_state
    core.thread_local_state.trace_state = captured_state
    try:
      results.append(thunk())
    finally:
      core.thread_local_state.trace_state = outer_state
    return results[0]

  return memoized
1179
1180
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
  """Trace `fun` to a jaxpr under a new dynamic `DynamicJaxprTrace` main.

  Requires omnistaging to be enabled.

  Returns:
    `(jaxpr, out_avals, consts)` as produced by `trace_to_subjaxpr_dynamic`.
  """
  assert config.omnistaging_enabled
  with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
    # Attributes read elsewhere: `source_info` for tracer origin messages, and
    # `jaxpr_stack` as the stack of frames, which starts out empty.
    main.source_info = fun_sourceinfo(fun.f)  # type: ignore
    main.jaxpr_stack = ()  # type: ignore
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
    # Drop the local references before the context manager exits.
    # NOTE(review): presumably so that `main` is not kept alive by this frame.
    del main, fun
  return jaxpr, out_avals, consts
1189
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
                              in_avals: Sequence[AbstractValue]):
  """Trace `fun` to a jaxpr in a new `JaxprStackFrame` pushed onto `main`.

  One input tracer is created per element of `in_avals`, `fun` is called on
  them, and the frame's recorded state is converted with `to_jaxpr`.

  Returns:
    `(jaxpr, out_avals, consts)`.
  """
  frame = JaxprStackFrame()
  with extend_jaxpr_stack(main, frame):
    trace = DynamicJaxprTrace(main, core.cur_sublevel())
    in_tracers = map(trace.new_arg, in_avals)
    ans = fun.call_wrapped(*in_tracers)
    out_tracers = map(trace.full_raise, ans)
    jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
    # Drop every local that refers to the trace machinery before leaving the
    # block. NOTE(review): presumably to break the frame/tracer/trace/main
    # reference cycle mentioned in JaxprStackFrame.__init__.
    del fun, main, trace, frame, in_tracers, out_tracers, ans
  return jaxpr, out_avals, consts
1201
@contextlib.contextmanager
def extend_jaxpr_stack(main, frame):
  """Context manager that pushes `frame` onto `main.jaxpr_stack` for its body.

  On exit the frame must still be on top of the stack; it is then popped.
  """
  stack_before = main.jaxpr_stack
  main.jaxpr_stack = stack_before + (frame,)
  try:
    yield
  finally:
    current = main.jaxpr_stack
    assert frame is current[-1]
    main.jaxpr_stack = current[:-1]
1210
def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
  """Trace `fun` to a jaxpr with `DynamicJaxprTrace` installed as the base main.

  Same steps as `trace_to_jaxpr_dynamic`, except that the main trace is created
  with `core.new_base_main` rather than `core.new_main(..., dynamic=True)`.
  Requires omnistaging to be enabled.

  Returns:
    `(jaxpr, out_avals, consts)` as produced by `trace_to_subjaxpr_dynamic`.
  """
  assert config.omnistaging_enabled
  with core.new_base_main(DynamicJaxprTrace) as main:  # type: ignore
    main.source_info = fun_sourceinfo(fun.f)  # type: ignore
    main.jaxpr_stack = ()  # type: ignore
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
    # Drop the local references before the context manager exits.
    # NOTE(review): presumably so that `main` is not kept alive by this frame.
    del fun, main
  return jaxpr, out_avals, consts
1219
def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
  """Run `trace_to_jaxpr(fun, in_pvals)` under a new dynamic `core.EvalTrace` main.

  Requires omnistaging to be enabled. Returns whatever `trace_to_jaxpr` returns.
  """
  # This function provides a partial evaluation behavior used by Flax. We can't
  # use trace_to_jaxpr directly because of an interaction with the current
  # custom_derivatives.py, which we work around by adding the EvalTrace.
  # TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
  assert config.omnistaging_enabled
  with core.new_main(core.EvalTrace, dynamic=True) as _:  # type: ignore
    return trace_to_jaxpr(fun, in_pvals)
1228
def fun_sourceinfo(fun):
  """Describe `fun` as ``"<name> at <file>:<first line>"``.

  A `functools.partial` is described by the function it wraps. Objects without
  `__code__` or `__name__` yield ``"<unknown>"``.
  """
  target = fun.func if isinstance(fun, functools.partial) else fun
  try:
    code = target.__code__
    location = f"{code.co_filename}:{code.co_firstlineno}"
    return f"{target.__name__} at {location}"
  except AttributeError:
    return "<unknown>"
1238
1239
@config.register_omnistaging_disabler
@no_type_check
def omnistaging_disabler() -> None:
  """Rebind this module's tracing entry points to alternative implementations.

  Registered through `config.register_omnistaging_disabler`. When run, it
  replaces the module globals `trace_to_jaxpr`, `partial_eval_jaxpr`,
  `staged_out_calls` and `StagingJaxprTrace`, and overrides
  `JaxprTrace.process_custom_jvp_call` / `process_custom_vjp_call`.
  """
  global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace

  def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
                     instantiate: Union[bool, Sequence[bool]] = False,
                     stage_out=False, bottom=False,
                     trace_type: Optional[Type[Trace]] = None,
                     ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
    """Traces a function into a Jaxpr, given PartialVals for inputs.

    Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
    computation that depends on unknown inputs. The `out_pvals` are the PartialVal
    for the outputs. The intermediate values that depend only on known inputs and
    are needed to compute the output of `jaxpr` are in `consts` and are passed in
    as the constvars of the `jaxpr`. The handling of the known outputs depends on
    `instantiate`.

    For example, given `fun` defined as follows::

      def fun(ki, ui):  # ki will be a known input in this example
        ka = ki + 2
        kb = ka + 3
        return (kb, ui + ka)

    with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
    computation that depends on unknown inputs is `ui + ka` and will be the only
    computation in the body of the `jaxpr`. This computation depends on the known
    intermediate value `ka`, which will be computed statically. Currently, such
    constants are either embedded in the Jaxpr if they are scalars, or passed as a
    constvar to `jaxpr`, and then the value of the actual constant will be in
    `consts`:

    When `instantiate=False` we get::

      jaxpr =
        { lambda ka ; ki ui.
          let c = add ui ka
          in (*, c) }   # known outputs are `*`
      out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
      consts = [3]  # the constant for `ka`

    When `instantiate=True` we get::

      jaxpr =
        { lambda ka kb ; ki ui.
          let c = add ui ka
          in (kb, c) }   # known output are explicit
      out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
      consts = [3, 6]  # values for `ka` and `kb` constvars
    """
    # An explicit `trace_type` wins; otherwise `stage_out` selects the staging
    # trace class.
    trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
    with core.new_main(trace_type, bottom=bottom) as main:
      fun = trace_to_subjaxpr(fun, main, instantiate)
      jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      assert not env
      del main

    return jaxpr, out_pvals, consts

  def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
                        instantiate: Union[bool, Sequence[bool]],
                        trace_type: Optional[Type[core.Trace]]
                        ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
    """Split `jaxpr` into a known part and an unknown part.

    `unknowns` flags which inputs are unknown. Returns
    `(jaxpr_1, jaxpr_2, uk_out)`: two closed jaxprs and one flag per output
    saying whether that output is unknown. See the shape comments below for how
    residuals connect the two.
    """
    f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

    # `fun` runs exactly once while jaxpr_1 is traced; `cell` carries its
    # by-products out of the trace.
    cell = []
    def fun(*vals):
      pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
              for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
      jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
                                                      trace_type=trace_type)
      out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
      cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
      # Known outputs followed by the residuals needed by jaxpr_2.
      return out_consts_2 + consts_2

    # The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores
    # the avals, and it will never actually reach any primitives, because the `fun` above will
    # execute the jaxpr with the right avals (it reconstructs `pvals` inside).
    pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
            for aval, uk in zip(jaxpr.in_avals, unknowns)]
    jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
    (out_pvs_2, jaxpr_2, num_res), = cell
    assert len(jaxpr_2.constvars) == num_res

    #   jaxpr :: a -> b
    # jaxpr_1 :: a1 -> [b1, res]
    # jaxpr_2 :: res | a2 -> b2
    # jaxpr_2 :: [a2, res] -> b2
    jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
    # After conversion the residuals come first; rotate them to the end.
    jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
    # Known inputs of jaxpr_2 are given a unit aval.
    for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
      if not unknown:
        var.aval = abstract_unit

    # An output is unknown exactly when its PartialVal carries an aval.
    uk_out = [pv is not None for pv in out_pvs_2]

    return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    # See comment at top of `JaxprTrace`. This method should be reachable
    # only when we stage out, and in that case we drop the custom differentiation
    # rules, because we do not need them.
    if not config.omnistaging_enabled:
      assert self.main.trace_type is StagingJaxprTrace
    return fun.call_wrapped(*tracers)
  JaxprTrace.process_custom_jvp_call = process_custom_jvp_call

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    # See comment in the above process_custom_jvp_call method.
    if not config.omnistaging_enabled:
      assert self.main.trace_type is StagingJaxprTrace
    return fun.call_wrapped(*tracers)
  JaxprTrace.process_custom_vjp_call = process_custom_vjp_call

  staged_out_calls = set()

  # Marker subclass used to distinguish staging traces (see `stage_out` above).
  class StagingJaxprTrace(JaxprTrace): pass
1359