Home
last modified time | relevance | path

Searched refs:trace_to_jaxpr_dynamic (Results 1 – 9 of 9) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/
H A Dcustom_derivatives.py55 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
810 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
914 jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
H A Dapi.py645 jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
2082 jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dinvertible_ad.py35 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
H A Dpartial_eval.py406 _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
720 jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
1181 def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]): function
H A Dbatching.py423 jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
H A Dxla.py956 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
983 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
H A Dad.py641 jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
H A Dpxla.py1672 return pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py67 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
1732 jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)