/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | callback.py | 160 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in CallbackTrace
|
H A D | jet.py | 156 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in JetTrace
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 306 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in JaxprTrace 1123 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in DynamicJaxprTrace 1340 def process_custom_jvp_call(self, prim, fun, jvp, tracers): function 1347 JaxprTrace.process_custom_jvp_call = process_custom_jvp_call
|
H A D | batching.py | 232 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in BatchTrace
|
H A D | ad.py | 344 def process_custom_jvp_call(self, _, __, f_jvp, tracers): member in JVPTrace
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 279 outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore 693 outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
|
H A D | core.py | 438 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in Trace 634 def process_custom_jvp_call(self, primitive, fun, jvp, tracers): member in EvalTrace
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 774 def process_custom_jvp_call(self, prim, fun, jvp, tracers): member in TensorFlowTrace
|