Searched refs:trace_to_subjaxpr (Results 1 – 1 of 1) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/partial_eval.py
  298   f = trace_to_subjaxpr(f, self.main, instantiate)
  309   fun = trace_to_subjaxpr(fun, self.main, True)
  323   jvp_ = trace_to_subjaxpr(jvp, self.main, True)
  348   fun = trace_to_subjaxpr(fun, self.main, True)
  362   fwd_ = trace_to_subjaxpr(fwd, self.main, True)
  505   fun = trace_to_subjaxpr(fun, main, instantiate)
  514   def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],   [function definition]
  1294  fun = trace_to_subjaxpr(fun, main, instantiate)
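One thing that stands out in these hits: the definition on line 514 takes `main` and `instantiate` but no function argument, yet every call site passes a function first (`f`, `fun`, `jvp`, `fwd`). In JAX sources of this era, `trace_to_subjaxpr` is a generator decorated with `@lu.transformation` (from `jax.linear_util`), and that decorator is what adds the leading function parameter. The sketch below is a simplified, hypothetical stand-in for that decorator, written in plain Python to show the calling convention only. The names `transformation`, `scale_outputs`, and `add` are illustrative, and the real `lu.transformation` returns a `WrappedFun` with a stack of transforms rather than a plain callable.

```python
from typing import Any, Callable, Generator


# Hypothetical, simplified stand-in for jax.linear_util.transformation.
# It only illustrates why call sites read `transform(fun, *static_args)`.
def transformation(gen: Callable[..., Generator]) -> Callable[..., Callable]:
    def apply(fun: Callable, *gen_args: Any) -> Callable:
        # `fun` is supplied by the caller; the generator never names it.
        def wrapped(*args: Any) -> Any:
            g = gen(*gen_args, *args)
            # First yield: the (possibly rewritten) args and kwargs for `fun`.
            new_args, new_kwargs = next(g)
            ans = fun(*new_args, **new_kwargs)
            # Send `fun`'s result back in; the second yield is the final result.
            return g.send(ans)
        return wrapped
    return apply


@transformation
def scale_outputs(factor: float, *xs: float):
    # Pre-processing step: pass the inputs through unchanged.
    ans = yield xs, {}
    # Post-processing step: scale whatever the wrapped function returned.
    yield ans * factor


def add(a: float, b: float) -> float:
    return a + b


# Function first, then the generator's own static arguments,
# mirroring `trace_to_subjaxpr(fun, self.main, True)` above.
f = scale_outputs(add, 10.0)
print(f(1.0, 2.0))  # 30.0
```

The same split explains the two yields inside `trace_to_subjaxpr`: the first hands tracers to the wrapped function, and the second returns what was built from its outputs.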