Searched refs:post_process_call (Results 1 – 7 of 7) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 245 def post_process_call(self, primitive, out_tracers, params): member in JaxprTrace 291 post_process_map = post_process_call 1087 def post_process_call(self, call_primitive, out_tracers, params): member in DynamicJaxprTrace
|
H A D | ad.py | 322 def post_process_call(self, call_primitive, out_tracers, params): member in JVPTrace 342 post_process_map = post_process_call
|
H A D | masking.py | 504 def post_process_call(self, call_primitive, out_tracers, params): member in MaskTrace
|
H A D | batching.py | 168 def post_process_call(self, call_primitive, out_tracers, params): member in BatchTrace
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | jet.py | 145 def post_process_call(self, call_primitive, out_tracers, params): member in JetTrace
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | core.py | 1281 return trace.post_process_call(self, out_tracers, params)
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 755 def post_process_call(self, call_primitive: core.Primitive, member in TensorFlowTrace
|