Home
last modified time | relevance | path

Searched defs:primals_in (Results 1 – 3 of 3) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/experimental/
jet.py
192 def zero_prop(prim, primals_in, series_in, **params): argument
272 def deriv_prop(prim, deriv, primals_in, series_in): argument
308 def _erf_inv_rule(primals_in, series_in): argument
361 def _exp_taylor(primals_in, series_in): argument
372 def _pow_taylor(primals_in, series_in): argument
386 def _integer_pow_taylor(primals_in, series_in, *, y): argument
407 def _expit_taylor(primals_in, series_in): argument
420 def _tanh_taylor(primals_in, series_in): argument
430 def _log_taylor(primals_in, series_in): argument
442 def _atan2_taylor(primals_in, series_in): argument
[all …]
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
invertible_ad.py
156 def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in): argument
268 def synthesize_ivjp(eqn, unknown_primals, primals_in, primals_out, cts_in): argument
ad.py
163 def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in): argument
557 def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals): argument
571 def do_transpose(primals_in, cotangents_in): argument
656 def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out): argument