Searched refs:post_process_custom_jvp_call (Results 1 – 5 of 5) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | batching.py | 243 def post_process_custom_jvp_call(self, out_tracers, params): member in BatchTrace 265 post_process_custom_vjp_call = post_process_custom_jvp_call
|
H A D | partial_eval.py | 339 def post_process_custom_jvp_call(self, out_tracers, params): member in JaxprTrace 1141 def post_process_custom_jvp_call(self, out_tracers, params): member in DynamicJaxprTrace
|
H A D | ad.py | 356 def post_process_custom_jvp_call(self, out_tracers, params): member in JVPTrace
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 287 return trace.post_process_custom_jvp_call(out_tracers, params)
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 781 def post_process_custom_jvp_call(self, out_tracers, params): member in TensorFlowTrace
|