Searched refs:primals_in (Results 1 – 4 of 4) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | jet.py | 273 x, = primals_in 309 x, = primals_in 362 x, = primals_in 393 x, = primals_in 408 x, = primals_in 421 x, = primals_in 431 x, = primals_in 443 x, y = primals_in 458 x, y = primals_in 472 x, = primals_in [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | invertible_ad.py | 183 map(write_primal, jaxpr.invars, primals_in) 188 primals_in = map(read_primal, eqn.invars) 213 tuple(map(abstract, x) for x in (primals_in, primals_out, primals_out))) 216 in_avals = map(abstract, primals_in + primals_out + primals_out) 231 unknowns = (map(ad.is_undefined_primal, primals_in) + 251 rec_primals_in, cts_out = split_list(ivjp(*primals_in, *primals_out, *cts_in), 257 in zip(primals_in, rec_primals_in, unknown_rec_primals_in)] 268 def synthesize_ivjp(eqn, unknown_primals, primals_in, primals_out, cts_in): argument 272 rec_primals_in = get_primitive_inverse(eqn.primitive)(primals_out, *primals_in) 278 primals_in = map(lambda p, rp, unknown: rp if unknown else p, [all …]
|
H A D | ad.py | 163 def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in): argument 200 map(write_primal, jaxpr.invars, primals_in) 285 primal_out, tangent_out = jvp(primals_in, tangents_in, **params) 346 primals_in = map(core.full_lower, primals_in) 350 tangents_in = map(replace_float0s, primals_in, tangents_in) 351 outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in)) 562 unknowns = map(is_undefined_primal, primals_in) 571 def do_transpose(primals_in, cotangents_in): argument 579 return cotangents_out[:len(primals_in)] 581 flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in)) [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 235 primals_in, tangents_in = split_list(args, [len(args) // 2]) 236 py_primals = tree_unflatten(in_tree, primals_in)
|