Searched refs:f_jvp (Results 1 – 6 of 6) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | ad.py | 297 f_jvp = jvp_subtrace(f, self.main) 302 f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp) 315 f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def) 318 result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params) 344 def process_custom_jvp_call(self, _, __, f_jvp, tracers): argument 351 outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in)) 638 f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros) 641 jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) 787 f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros) 791 jaxpr_out, _, consts = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/ |
H A D | control_flow_ops_test.py | 69 def f_jvp(primals, tangents): function 180 def f_jvp(primals, tangents): function 232 def f_jvp(primals, tangents): function
|
H A D | jax2tf_test.py | 192 def f_jvp(primals, tangents): function 281 def f_jvp(primals, tangents): function
|
/dports/math/py-autograd/autograd-1.3/autograd/ |
H A D | differential_operators.py | 126 f_jvp, _ = _make_vjp(f_vjp, vspace(grad_g_x).zeros()) 127 def ggnvp(v): return f_vjp(g_hvp(f_jvp(v)))
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | test_util.py | 232 def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS): argument 237 v_out, t_out = f_jvp(args, tangent)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 2050 unchecked_zeros, f_jvp = jax.linearize(f, x) 2051 return tangent_solve(f_jvp, b)
|