Searched refs: post_process_custom_jvp_call (results 1–5 of 5, sorted by relevance)

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
  batching.py
    243  def post_process_custom_jvp_call(self, out_tracers, params):  (member of BatchTrace)
    265  post_process_custom_vjp_call = post_process_custom_jvp_call
  partial_eval.py
    339  def post_process_custom_jvp_call(self, out_tracers, params):  (member of JaxprTrace)
    1141  def post_process_custom_jvp_call(self, out_tracers, params):  (member of DynamicJaxprTrace)
  ad.py
    356  def post_process_custom_jvp_call(self, out_tracers, params):  (member of JVPTrace)
/dports/math/py-jax/jax-0.2.9/jax/
  custom_derivatives.py
    287  return trace.post_process_custom_jvp_call(out_tracers, params)
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
  jax2tf.py
    781  def post_process_custom_jvp_call(self, out_tracers, params):  (member of TensorFlowTrace)