Home
last modified time | relevance | path

Searched refs:Jaxpr (Results 1 – 10 of 10) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/
H A Djaxpr_util.py28 def all_eqns(jaxpr: core.Jaxpr): argument
34 def collect_eqns(jaxpr: core.Jaxpr, key: Callable): argument
40 def histogram(jaxpr: core.Jaxpr, key: Callable, argument
45 def primitives(jaxpr: core.Jaxpr): argument
48 def primitives_by_source(jaxpr: core.Jaxpr): argument
54 def primitives_by_shape(jaxpr: core.Jaxpr): argument
61 def source_locations(jaxpr: core.Jaxpr): argument
68 def var_defs_and_refs(jaxpr: core.Jaxpr): argument
104 def vars_by_fanout(jaxpr: core.Jaxpr): argument
H A Dcore.py77 class Jaxpr: class
104 def jaxprs_in_params(params) -> Iterator[Jaxpr]:
108 if isinstance(v, Jaxpr):
114 def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]: argument
124 jaxpr: Jaxpr
127 def __init__(self, jaxpr: Jaxpr, consts: Sequence): argument
202 def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None, argument
331 if type(param) in (Jaxpr, ClosedJaxpr))
334 def eval_jaxpr(jaxpr: Jaxpr, consts, *args): argument
1405 def check_jaxpr(jaxpr: Jaxpr): argument
[all …]
H A Dcustom_derivatives.py314 jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], argument
333 jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], argument
605 fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], argument
628 fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], argument
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpartial_eval.py31 from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
584 ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
647 def convert_constvars_jaxpr(jaxpr: Jaxpr): argument
650 lifted_jaxpr = Jaxpr(constvars=(),
656 def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int): argument
659 converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
847 def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr: argument
864 return core.Jaxpr(jaxpr.constvars, jaxpr.invars,
868 def _drop_invars(jaxpr: Jaxpr, drop: Tuple[bool, ...]): argument
889 new_jaxpr = core.Jaxpr((), new_invars, closed_jaxpr.jaxpr.outvars,
[all …]
H A Dinvertible_ad.py24 from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
156 def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in): argument
284 eqn_jaxpr = Jaxpr([], variable_invars, eqn.outvars, [eqn])
H A Dxla.py194 outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
195 def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr: argument
202 def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool: argument
208 if type(param) is core.Jaxpr:
532 def jaxpr_replicas(jaxpr: core.Jaxpr) -> int: argument
H A Dad.py163 def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in): argument
659 new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dhost_callback.py1092 def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool, argument
1093 has_output_token: bool) -> core.Jaxpr:
1128 new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns)
1219 call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
1232 call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
1248 call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
1297 call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
1340 core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), [])
1387 core.Jaxpr([], (new_body_invars_cond_constvars +
H A Dmaps.py461 assert type(call_jaxpr) is core.Jaxpr
641 return core.Jaxpr(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, eqns)
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py99 return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
439 cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
986 jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
2654 return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,