/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | jaxpr_util.py | 28 def all_eqns(jaxpr: core.Jaxpr): argument 34 def collect_eqns(jaxpr: core.Jaxpr, key: Callable): argument 40 def histogram(jaxpr: core.Jaxpr, key: Callable, argument 45 def primitives(jaxpr: core.Jaxpr): argument 48 def primitives_by_source(jaxpr: core.Jaxpr): argument 54 def primitives_by_shape(jaxpr: core.Jaxpr): argument 61 def source_locations(jaxpr: core.Jaxpr): argument 68 def var_defs_and_refs(jaxpr: core.Jaxpr): argument 104 def vars_by_fanout(jaxpr: core.Jaxpr): argument
|
H A D | core.py | 77 class Jaxpr: class 104 def jaxprs_in_params(params) -> Iterator[Jaxpr]: 108 if isinstance(v, Jaxpr): 114 def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]: argument 124 jaxpr: Jaxpr 127 def __init__(self, jaxpr: Jaxpr, consts: Sequence): argument 202 def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None, argument 331 if type(param) in (Jaxpr, ClosedJaxpr)) 334 def eval_jaxpr(jaxpr: Jaxpr, consts, *args): argument 1405 def check_jaxpr(jaxpr: Jaxpr): argument [all …]
|
H A D | custom_derivatives.py | 314 jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], argument 333 jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], argument 605 fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], argument 628 fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], argument
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 31 from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue, 584 ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]: 647 def convert_constvars_jaxpr(jaxpr: Jaxpr): argument 650 lifted_jaxpr = Jaxpr(constvars=(), 656 def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int): argument 659 converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars, 847 def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr: argument 864 return core.Jaxpr(jaxpr.constvars, jaxpr.invars, 868 def _drop_invars(jaxpr: Jaxpr, drop: Tuple[bool, ...]): argument 889 new_jaxpr = core.Jaxpr((), new_invars, closed_jaxpr.jaxpr.outvars, [all …]
|
H A D | invertible_ad.py | 24 from ..core import raise_to_shaped, get_aval, Literal, Jaxpr 156 def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in): argument 284 eqn_jaxpr = Jaxpr([], variable_invars, eqn.outvars, [eqn])
|
H A D | xla.py | 194 outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None 195 def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr: argument 202 def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool: argument 208 if type(param) is core.Jaxpr: 532 def jaxpr_replicas(jaxpr: core.Jaxpr) -> int: argument
|
H A D | ad.py | 163 def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in): argument 659 new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | host_callback.py | 1092 def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool, argument 1093 has_output_token: bool) -> core.Jaxpr: 1128 new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns) 1219 call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) 1232 call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) 1248 call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) 1297 call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) 1340 core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), []) 1387 core.Jaxpr([], (new_body_invars_cond_constvars +
|
H A D | maps.py | 461 assert type(call_jaxpr) is core.Jaxpr 641 return core.Jaxpr(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, eqns)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 99 return core.Jaxpr(constvars=constvars, invars=jaxpr.invars, 439 cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars, 986 jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars, 2654 return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
|