Searched refs:fwd_jaxpr_thunk (Results 1 – 3 of 3) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/
custom_derivatives.py
  605: fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
  611: fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() # consts can be tracers!
  628: fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
  643: fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers
  656: fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
  733: fwd_jaxpr_thunk = pe._memoize(lambda: _initial_style_jaxpr(fwd, in_avals))
  736: fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd, out_trees=out_trees,

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
partial_eval.py
  361: def fwd_jaxpr_thunk():
  372: fwd_jaxpr_thunk=fwd_jaxpr_thunk,
  1148: fwd_jaxpr_thunk = _memoize(
  1156: fwd_jaxpr_thunk=fwd_jaxpr_thunk,

/dports/math/py-jax/jax-0.2.9/jax/experimental/
host_callback.py
  1288: fwd_jaxpr_thunk=unreachable_thunk,