Home
last modified time | relevance | path

Searched refs:f_jvp (Results 1 – 6 of 6) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dad.py297 f_jvp = jvp_subtrace(f, self.main)
302 f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
315 f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
318 result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
344 def process_custom_jvp_call(self, _, __, f_jvp, tracers): argument
351 outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
638 f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
641 jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
787 f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
791 jaxpr_out, _, consts = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/
H A Dcontrol_flow_ops_test.py69 def f_jvp(primals, tangents): function
180 def f_jvp(primals, tangents): function
232 def f_jvp(primals, tangents): function
H A Djax2tf_test.py192 def f_jvp(primals, tangents): function
281 def f_jvp(primals, tangents): function
/dports/math/py-autograd/autograd-1.3/autograd/
H A Ddifferential_operators.py126 f_jvp, _ = _make_vjp(f_vjp, vspace(grad_g_x).zeros())
127 def ggnvp(v): return f_vjp(g_hvp(f_jvp(v)))
/dports/math/py-jax/jax-0.2.9/jax/
H A Dtest_util.py232 def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS): argument
237 v_out, t_out = f_jvp(args, tangent)
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py2050 unchecked_zeros, f_jvp = jax.linearize(f, x)
2051 return tangent_solve(f_jvp, b)