Searched refs:trace_to_subjaxpr (Results 1 – 1 of 1) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/partial_eval.py
298 f = trace_to_subjaxpr(f, self.main, instantiate)
309 fun = trace_to_subjaxpr(fun, self.main, True)
323 jvp_ = trace_to_subjaxpr(jvp, self.main, True)
348 fun = trace_to_subjaxpr(fun, self.main, True)
362 fwd_ = trace_to_subjaxpr(fwd, self.main, True)
505 fun = trace_to_subjaxpr(fun, main, instantiate)
514 def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
1294 fun = trace_to_subjaxpr(fun, main, instantiate)
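The listing shows a calling-convention mismatch. Every call site passes the function being traced first (`trace_to_subjaxpr(fun, main, instantiate)`), but the definition at line 514 starts with `main`. That pattern fits a generator-based transformation. In JAX releases of this period, functions like this are wrapped with `jax.linear_util.transformation`; that is an assumption here, because the decorator line is not part of the listing. The plain-Python sketch below imitates the pattern. The `transformation` decorator and the `tag_outputs` generator are hypothetical illustrations, not JAX APIs. `tag_outputs` reuses the `instantiate: Union[bool, Sequence[bool]]` parameter shape only to show how a single bool can be broadcast to one flag per output.

```python
from typing import Any, Callable, Generator, Sequence, Union


def transformation(gen_fn: Callable[..., Generator]) -> Callable[..., Callable]:
    """Hypothetical stand-in for a generator-based wrapper such as
    jax.linear_util.transformation (greatly simplified, not the real code)."""

    def apply(fun: Callable, *static_params: Any) -> Callable:
        def wrapped(*args: Any) -> Any:
            # Static params (bound at wrap time) come first, then call-time args.
            gen = gen_fn(*static_params, *args)
            # First yield: the (args, kwargs) to pass to the wrapped function.
            new_args, new_kwargs = next(gen)
            ans = fun(*new_args, **new_kwargs)
            # Send the result back in; the generator's second yield is the output.
            return gen.send(ans)

        return wrapped

    return apply


@transformation
def tag_outputs(main: str, instantiate: Union[bool, Sequence[bool]], *args: Any):
    """Hypothetical example: pair selected outputs with a 'main' label."""
    ans = yield args, {}
    outs = list(ans)
    # A single bool applies to every output; a sequence gives one flag per output.
    if isinstance(instantiate, bool):
        instantiate = [instantiate] * len(outs)
    yield [(main, out) if inst else out for out, inst in zip(outs, instantiate)]


def f(x, y):
    return x + y, x * y


# Call shape mirrors the listing: function first, then the static parameters.
g = tag_outputs(f, "main0", [True, False])
print(g(2, 3))  # [('main0', 5), 6]
```

The static parameters (`main`, `instantiate`) are bound when the wrapper is created. The traced call's own arguments arrive after them, so the generator's signature lists `main` first even though callers pass the function first.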