Searched refs:default_process_primitive (Results 1 – 3 of 3) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | host_callback.py | 970 return trace.default_process_primitive(outside_call_p, args, params) 981 return trace.default_process_primitive(outside_call_p, args, params) 992 outs_known = trace.default_process_primitive( 998 outs_all_unknown = trace.default_process_primitive(outside_call_p, args, params)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 504 return trace.default_process_primitive(while_p, tracers, params) 526 out_all: Sequence[pe.Tracer] = trace.default_process_primitive(while_p, tracers, params) 840 return trace.default_process_primitive(cond_p, tracers, params) 1538 return trace.default_process_primitive(scan_p, tracers, params)
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 140 return self.default_process_primitive(primitive, tracers, params) 142 def default_process_primitive(self, primitive, tracers, params): member in JaxprTrace
|