Searched refs:trace_to_jaxpr_dynamic (Results 1 – 9 of 9) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 55 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) 810 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals) 914 jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
|
H A D | api.py | 645 jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals) 2082 jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | invertible_ad.py | 35 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
H A D | partial_eval.py | 406 _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals) 720 jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals) 1181 def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]): function
|
H A D | batching.py | 423 jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
H A D | xla.py | 956 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) 983 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
|
H A D | ad.py | 641 jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
|
H A D | pxla.py | 1672 return pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 67 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) 1732 jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
|