Home
last modified time | relevance | path

Searched refs:process_custom_jvp_call (Results 1 – 8 of 8) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dcallback.py160 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in CallbackTrace
H A Djet.py156 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in JetTrace
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpartial_eval.py306 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in JaxprTrace
1123 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in DynamicJaxprTrace
1340 def process_custom_jvp_call(self, prim, fun, jvp, tracers): function
1347 JaxprTrace.process_custom_jvp_call = process_custom_jvp_call
H A Dbatching.py232 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in BatchTrace
H A Dad.py344 def process_custom_jvp_call(self, _, __, f_jvp, tracers): member in JVPTrace
/dports/math/py-jax/jax-0.2.9/jax/
H A Dcustom_derivatives.py279 outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
693 outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
H A Dcore.py438 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in Trace
634 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in EvalTrace
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py774 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in TensorFlowTrace