Home
last modified time | relevance | path

Searched refs:trace_to_jaxpr (Results 1 – 13 of 13) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dinvertible_ad.py62 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun_flat, in_pvals)
219 ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
222 ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr( # type: ignore
323 jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
H A Dpartial_eval.py409 _, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in,
459 def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], function
712 jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
1227 return trace_to_jaxpr(fun, in_pvals)
1243 global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace
1245 def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], function
1311 jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
1322 jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
H A Dad.py101 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
720 jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
723 jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
791 jaxpr_out, _, consts = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
H A Dsharded_jit.py102 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, # type: ignore
H A Dxla.py661 jaxpr, pvals, consts = pe.trace_to_jaxpr( # type: ignore
961 jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True,
988 jaxpr, _, consts = pe.trace_to_jaxpr(
H A Dbatching.py484 jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
H A Dpxla.py731 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore
/dports/math/py-flax/flax-0.3.3/flax/
H A Djax_utils.py118 _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
121 _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)
/dports/math/py-flax/flax-0.3.3/flax/core/
H A Daxes_scan.py133 _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
/dports/math/py-jax/jax-0.2.9/jax/
H A Dapi.py648 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
1982 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
2088 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
2518 jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
2521 jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
H A Dcustom_derivatives.py676 jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
813 jaxpr, _, consts = pe.trace_to_jaxpr(rule, ans_pvals, instantiate=True)
919 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py1576 jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
1735 jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
2618 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore
H A Dlax.py1151 jaxpr, _, consts = pe.trace_to_jaxpr(comp, (pval, pval), instantiate=False)
1161 jaxpr, _, consts = pe.trace_to_jaxpr(flat_comp, tuple(pvals),