Searched refs:trace_to_jaxpr_final (Results 1 – 5 of 5) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/experimental/
maps.py
286 jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
309 final_jaxpr, out_avals, final_consts = pe.trace_to_jaxpr_final(f, in_avals)
476 vectorized_jaxpr, _, consts = pe.trace_to_jaxpr_final(f, local_avals)
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
sharded_jit.py
98 jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
pxla.py
718 jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals)
1437 jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
1614 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
xla.py
656 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
partial_eval.py
1211 def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]): (function)