Searched defs:process_custom_jvp_call (Results 1 – 7 of 7) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | callback.py | 160 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in CallbackTrace
|
H A D | jet.py | 156 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in JetTrace
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 306 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in JaxprTrace 1123 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in DynamicJaxprTrace 1340 def process_custom_jvp_call(self, prim, fun, jvp, tracers): function
|
H A D | batching.py | 232 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in BatchTrace
|
H A D | ad.py | 344 def process_custom_jvp_call(self, _, __, f_jvp, tracers): member in JVPTrace
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | core.py | 438 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in Trace 634 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in EvalTrace
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 774 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in TensorFlowTrace
|