Searched defs:post_process_call (Results 1 – 6 of 6) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
partial_eval.py | 245 def post_process_call(self, primitive, out_tracers, params): member in JaxprTrace
partial_eval.py | 1087 def post_process_call(self, call_primitive, out_tracers, params): member in DynamicJaxprTrace
batching.py | 168 def post_process_call(self, call_primitive, out_tracers, params): member in BatchTrace
masking.py | 504 def post_process_call(self, call_primitive, out_tracers, params): member in MaskTrace
ad.py | 322 def post_process_call(self, call_primitive, out_tracers, params): member in JVPTrace

/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
jet.py | 145 def post_process_call(self, call_primitive, out_tracers, params): member in JetTrace

/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
jax2tf.py | 755 def post_process_call(self, call_primitive: core.Primitive, member in TensorFlowTrace
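The same definition search can be reproduced against a local source tree without OpenGrok. The following is a minimal sketch using Python's standard `ast` module. The function name `find_method_defs` is hypothetical, and the root directory you pass in (for example, a locally unpacked `jax-0.2.9` tree) is an assumption. The function reports every method with the given name that is defined directly in a class body, together with its line number and enclosing class. This mirrors the "member in …" column of the listing.

```python
import ast
import os
import sys
from typing import List, Tuple


def find_method_defs(root: str, name: str) -> List[Tuple[str, int, str]]:
    """Return sorted (path, def-line, class name) tuples for methods called `name`.

    Hypothetical helper: only functions defined directly in a class body are
    reported; module-level functions with the same name are ignored.
    """
    hits: List[Tuple[str, int, str]] = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for fn in filenames:
            if not fn.endswith(".py"):
                continue
            path = os.path.join(dirpath, fn)
            try:
                with open(path, encoding="utf-8") as f:
                    tree = ast.parse(f.read(), filename=path)
            except (OSError, SyntaxError, UnicodeDecodeError, ValueError):
                # Skip unreadable or unparsable files instead of aborting the search.
                continue
            for node in ast.walk(tree):
                if isinstance(node, ast.ClassDef):
                    for item in node.body:
                        if (isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
                                and item.name == name):
                            # On Python 3.8+, lineno is the `def` line, not a decorator line.
                            hits.append((path, item.lineno, node.name))
    return sorted(hits)


if __name__ == "__main__":
    # Usage: python find_defs.py /path/to/jax-0.2.9/jax
    root = sys.argv[1] if len(sys.argv) > 1 else "."
    for path, line, cls in find_method_defs(root, "post_process_call"):
        print(f"{path} | {line} member in {cls}")
```

Because it parses the source rather than matching text, this approach skips module-level functions. It also attributes each hit to its enclosing class, which is how the listing distinguishes the two `partial_eval.py` results (JaxprTrace and DynamicJaxprTrace).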