Searched refs:avals_out (Results 1 – 4 of 4) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | ad.py | 365 avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] 368 avals_out=avals_out) 641 jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) 676 custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out) 684 def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out): argument 686 cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
|
H A D | partial_eval.py | 406 _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals) 411 avals_out, _ = unzip2(pvals_out) 412 for aval_out in avals_out: 414 return avals_out
|
H A D | batching.py | 485 avals_out, _ = unzip2(pvals_out)
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 619 avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] 621 *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, avals_out=avals_out)
|