Home
last modified time | relevance | path

Searched refs:primals_in (Results 1 – 4 of 4) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Djet.py273 x, = primals_in
309 x, = primals_in
362 x, = primals_in
393 x, = primals_in
408 x, = primals_in
421 x, = primals_in
431 x, = primals_in
443 x, y = primals_in
458 x, y = primals_in
472 x, = primals_in
[all …]
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dinvertible_ad.py183 map(write_primal, jaxpr.invars, primals_in)
188 primals_in = map(read_primal, eqn.invars)
213 tuple(map(abstract, x) for x in (primals_in, primals_out, primals_out)))
216 in_avals = map(abstract, primals_in + primals_out + primals_out)
231 unknowns = (map(ad.is_undefined_primal, primals_in) +
251 rec_primals_in, cts_out = split_list(ivjp(*primals_in, *primals_out, *cts_in),
257 in zip(primals_in, rec_primals_in, unknown_rec_primals_in)]
268 def synthesize_ivjp(eqn, unknown_primals, primals_in, primals_out, cts_in): argument
272 rec_primals_in = get_primitive_inverse(eqn.primitive)(primals_out, *primals_in)
278 primals_in = map(lambda p, rp, unknown: rp if unknown else p,
[all …]
H A Dad.py163 def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in): argument
200 map(write_primal, jaxpr.invars, primals_in)
285 primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
346 primals_in = map(core.full_lower, primals_in)
350 tangents_in = map(replace_float0s, primals_in, tangents_in)
351 outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
562 unknowns = map(is_undefined_primal, primals_in)
571 def do_transpose(primals_in, cotangents_in): argument
579 return cotangents_out[:len(primals_in)]
581 flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in))
[all …]
/dports/math/py-jax/jax-0.2.9/jax/
H A Dcustom_derivatives.py235 primals_in, tangents_in = split_list(args, [len(args) // 2])
236 py_primals = tree_unflatten(in_tree, primals_in)