Home
last modified time | relevance | path

Searched refs:StagingJaxprTrace (Results 1 – 2 of 2) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpartial_eval.py167 if (self.main.trace_type is StagingJaxprTrace # type: ignore
1243 global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace
1292 trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
1345 assert self.main.trace_type is StagingJaxprTrace
1352 assert self.main.trace_type is StagingJaxprTrace
1358 class StagingJaxprTrace(JaxprTrace): pass class
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py1534 …if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace: # type: igno…