Searched refs:fwd_jaxpr_thunk (Results 1 – 3 of 3) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/
custom_derivatives.py
605 fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
611 fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() # consts can be tracers!
628 fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
643 fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers
656 fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
733 fwd_jaxpr_thunk = pe._memoize(lambda: _initial_style_jaxpr(fwd, in_avals))
736 fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd, out_trees=out_trees,
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
partial_eval.py
361 def fwd_jaxpr_thunk():
372 fwd_jaxpr_thunk=fwd_jaxpr_thunk,
1148 fwd_jaxpr_thunk = _memoize(
1156 fwd_jaxpr_thunk=fwd_jaxpr_thunk,
/dports/math/py-jax/jax-0.2.9/jax/experimental/
host_callback.py
1288 fwd_jaxpr_thunk=unreachable_thunk,