/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 245 def post_process_call(self, primitive, out_tracers, params): argument 339 def post_process_custom_jvp_call(self, out_tracers, params): argument 379 def post_process_custom_vjp_call(self, out_tracers, params): argument 954 def to_jaxpr(self, in_tracers, out_tracers): argument 1087 def post_process_call(self, call_primitive, out_tracers, params): argument 1120 def post_process_map(self, map_primitive, out_tracers, params): argument 1141 def post_process_custom_jvp_call(self, out_tracers, params): argument 1163 def post_process_custom_vjp_call(self, out_tracers, params): argument
|
H A D | batching.py | 168 def post_process_call(self, call_primitive, out_tracers, params): argument 216 def post_process_map(self, call_primitive, out_tracers, params): argument 243 def post_process_custom_jvp_call(self, out_tracers, params): argument
|
H A D | ad.py | 322 def post_process_call(self, call_primitive, out_tracers, params): argument 356 def post_process_custom_jvp_call(self, out_tracers, params): argument 372 def post_process_custom_vjp_call(self, out_tracers, params): argument
|
H A D | masking.py | 504 def post_process_call(self, call_primitive, out_tracers, params): argument
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 286 def post_process(self, trace, out_tracers, params): argument 587 def post_process(self, trace, out_tracers, params): argument 700 def jvp_post_process(self, trace, out_tracers, params): argument 716 def vjp_post_process(self, trace, out_tracers, params): argument
|
H A D | core.py | 1280 def post_process(self, trace, out_tracers, params): argument 1307 def post_process(self, trace, out_tracers, params): argument
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | loops.py | 457 def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True): argument
|
H A D | jet.py | 145 def post_process_call(self, call_primitive, out_tracers, params): argument
|
H A D | maps.py | 363 def post_process(self, trace, out_tracers, params): argument
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 771 def post_process_map(self, map_primitive, out_tracers, params): argument 781 def post_process_custom_jvp_call(self, out_tracers, params): argument 791 def post_process_custom_vjp_call(self, out_tracers, params): argument
|