/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | invertible_ad.py | 62 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun_flat, in_pvals) 219 ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr( 222 ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr( # type: ignore 323 jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
|
H A D | partial_eval.py | 409 _, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in, 459 def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], function 712 jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate) 1227 return trace_to_jaxpr(fun, in_pvals) 1243 global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace 1245 def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], function 1311 jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate, 1322 jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
|
H A D | ad.py | 101 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals) 720 jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True) 723 jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, 791 jaxpr_out, _, consts = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
|
H A D | sharded_jit.py | 102 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, # type: ignore
|
H A D | xla.py | 661 jaxpr, pvals, consts = pe.trace_to_jaxpr( # type: ignore 961 jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True, 988 jaxpr, _, consts = pe.trace_to_jaxpr(
|
H A D | batching.py | 484 jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
|
H A D | pxla.py | 731 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore
|
/dports/math/py-flax/flax-0.3.3/flax/ |
H A D | jax_utils.py | 118 _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals) 121 _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)
|
/dports/math/py-flax/flax-0.3.3/flax/core/ |
H A D | axes_scan.py | 133 _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | api.py | 648 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( 1982 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, 2088 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( 2518 jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True) 2521 jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
|
H A D | custom_derivatives.py | 676 jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True, 813 jaxpr, _, consts = pe.trace_to_jaxpr(rule, ans_pvals, instantiate=True) 919 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 1576 jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr( 1735 jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True) 2618 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore
|
H A D | lax.py | 1151 jaxpr, _, consts = pe.trace_to_jaxpr(comp, (pval, pval), instantiate=False) 1161 jaxpr, _, consts = pe.trace_to_jaxpr(flat_comp, tuple(pvals),
|