Searched refs:jvpfun (Results 1 – 2 of 2) sorted by relevance
/dports/math/py-autograd/autograd-1.3/autograd/ |
H A D | core.py | 132 jvps_dict = {argnum : translate_jvp(jvpfun, fun, argnum) 133 for argnum, jvpfun in zip(argnums, jvpfuns)} 140 def translate_jvp(jvpfun, fun, argnum): argument 141 if jvpfun is None: 143 elif jvpfun == 'same': 146 elif callable(jvpfun): 147 return jvpfun 149 raise Exception("Bad JVP '{}' for '{}'".format(jvpfun, fun.__name__))
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | ad.py | 43 return jvpfun(jvp_subtrace(fun), instantiate) 46 return jvpfun(fun, instantiate), aux 50 def jvpfun(instantiate, primals, tangents): function 92 jvpfun = jvp(traceable) 94 jvpfun, aux = jvp(traceable, has_aux=True) 100 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
|