Searched refs:trace_to_jaxpr_final (Results 1 – 5 of 5) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | maps.py | 286 jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals) 309 final_jaxpr, out_avals, final_consts = pe.trace_to_jaxpr_final(f, in_avals) 476 vectorized_jaxpr, _, consts = pe.trace_to_jaxpr_final(f, local_avals)
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | sharded_jit.py | 98 jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
|
H A D | pxla.py | 718 jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals) 1437 jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals) 1614 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
|
H A D | xla.py | 656 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
|
H A D | partial_eval.py | 1211 def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]): function
|