Home
last modified time | relevance | path

Searched refs:default_process_primitive (Results 1 – 3 of 3) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dhost_callback.py970 return trace.default_process_primitive(outside_call_p, args, params)
981 return trace.default_process_primitive(outside_call_p, args, params)
992 outs_known = trace.default_process_primitive(
998 outs_all_unknown = trace.default_process_primitive(outside_call_p, args, params)
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py504 return trace.default_process_primitive(while_p, tracers, params)
526 out_all: Sequence[pe.Tracer] = trace.default_process_primitive(while_p, tracers, params)
840 return trace.default_process_primitive(cond_p, tracers, params)
1538 return trace.default_process_primitive(scan_p, tracers, params)
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpartial_eval.py140 return self.default_process_primitive(primitive, tracers, params)
142 def default_process_primitive(self, primitive, tracers, params): member in JaxprTrace