Home
last modified time | relevance | path

Searched refs:avals_out (Results 1 – 4 of 4) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dad.py365 avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
368 avals_out=avals_out)
641 jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
676 custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out)
684 def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out): argument
686 cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
H A Dpartial_eval.py406 _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
411 avals_out, _ = unzip2(pvals_out)
412 for aval_out in avals_out:
414 return avals_out
H A Dbatching.py485 avals_out, _ = unzip2(pvals_out)
/dports/math/py-jax/jax-0.2.9/jax/
H A Dcustom_derivatives.py619 avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
621 *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, avals_out=avals_out)