Searched refs:post_process_call (Results 1 – 7 of 7) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
partial_eval.py
  245   def post_process_call(self, primitive, out_tracers, params): member of JaxprTrace
  291   post_process_map = post_process_call
  1087  def post_process_call(self, call_primitive, out_tracers, params): member of DynamicJaxprTrace
ad.py
  322   def post_process_call(self, call_primitive, out_tracers, params): member of JVPTrace
  342   post_process_map = post_process_call
masking.py
  504   def post_process_call(self, call_primitive, out_tracers, params): member of MaskTrace
batching.py
  168   def post_process_call(self, call_primitive, out_tracers, params): member of BatchTrace
/dports/math/py-jax/jax-0.2.9/jax/experimental/
jet.py
  145   def post_process_call(self, call_primitive, out_tracers, params): member of JetTrace
/dports/math/py-jax/jax-0.2.9/jax/
core.py
  1281  return trace.post_process_call(self, out_tracers, params)
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
jax2tf.py
  755   def post_process_call(self, call_primitive: core.Primitive, member of TensorFlowTrace