Home
last modified time | relevance | path

Searched defs:out_tracers (Results 1 – 10 of 10) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
partial_eval.py
245 def post_process_call(self, primitive, out_tracers, params): argument
339 def post_process_custom_jvp_call(self, out_tracers, params): argument
379 def post_process_custom_vjp_call(self, out_tracers, params): argument
954 def to_jaxpr(self, in_tracers, out_tracers): argument
1087 def post_process_call(self, call_primitive, out_tracers, params): argument
1120 def post_process_map(self, map_primitive, out_tracers, params): argument
1141 def post_process_custom_jvp_call(self, out_tracers, params): argument
1163 def post_process_custom_vjp_call(self, out_tracers, params): argument
batching.py
168 def post_process_call(self, call_primitive, out_tracers, params): argument
216 def post_process_map(self, call_primitive, out_tracers, params): argument
243 def post_process_custom_jvp_call(self, out_tracers, params): argument
ad.py
322 def post_process_call(self, call_primitive, out_tracers, params): argument
356 def post_process_custom_jvp_call(self, out_tracers, params): argument
372 def post_process_custom_vjp_call(self, out_tracers, params): argument
masking.py
504 def post_process_call(self, call_primitive, out_tracers, params): argument
/dports/math/py-jax/jax-0.2.9/jax/
custom_derivatives.py
286 def post_process(self, trace, out_tracers, params): argument
587 def post_process(self, trace, out_tracers, params): argument
700 def jvp_post_process(self, trace, out_tracers, params): argument
716 def vjp_post_process(self, trace, out_tracers, params): argument
core.py
1280 def post_process(self, trace, out_tracers, params): argument
1307 def post_process(self, trace, out_tracers, params): argument
/dports/math/py-jax/jax-0.2.9/jax/experimental/
loops.py
457 def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True): argument
jet.py
145 def post_process_call(self, call_primitive, out_tracers, params): argument
maps.py
363 def post_process(self, trace, out_tracers, params): argument
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
jax2tf.py
771 def post_process_map(self, map_primitive, out_tracers, params): argument
781 def post_process_custom_jvp_call(self, out_tracers, params): argument
791 def post_process_custom_vjp_call(self, out_tracers, params): argument