Searched defs:primals_in (Results 1 – 3 of 3) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | jet.py | 192 def zero_prop(prim, primals_in, series_in, **params): argument 272 def deriv_prop(prim, deriv, primals_in, series_in): argument 308 def _erf_inv_rule(primals_in, series_in): argument 361 def _exp_taylor(primals_in, series_in): argument 372 def _pow_taylor(primals_in, series_in): argument 386 def _integer_pow_taylor(primals_in, series_in, *, y): argument 407 def _expit_taylor(primals_in, series_in): argument 420 def _tanh_taylor(primals_in, series_in): argument 430 def _log_taylor(primals_in, series_in): argument 442 def _atan2_taylor(primals_in, series_in): argument [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | invertible_ad.py | 156 def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in): argument 268 def synthesize_ivjp(eqn, unknown_primals, primals_in, primals_out, cts_in): argument
|
H A D | ad.py | 163 def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in): argument 557 def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals): argument 571 def do_transpose(primals_in, cotangents_in): argument 656 def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out): argument
|