Searched refs:custom_partial_eval_rules (Results 1 – 4 of 4) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
  partial_eval.py
      137  if primitive in custom_partial_eval_rules:
      138      return custom_partial_eval_rules[primitive](self, *tracers, **params)
      399  custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}   [variable]
  ad.py
      728  pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
  control_flow.py
      544  pe.custom_partial_eval_rules[while_p] = _while_partial_eval
      1109 pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
      1877 pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
/dports/math/py-jax/jax-0.2.9/jax/experimental/
  host_callback.py
      1009 pe.custom_partial_eval_rules[outside_call_p] = _outside_call_partial_eval_rule
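The matches show one pattern. partial_eval.py defines a module-level dict (line 399) that maps each primitive to a rule. The trace's primitive handler consults that dict first (lines 137-138). Other modules register rules by assigning into it (ad.py, control_flow.py, host_callback.py). Below is a minimal pure-Python sketch of that dispatch. Only the dict-lookup shape comes from the listed lines. `Primitive`, `ToyTrace`, `default_process_primitive`, the rule body, and `some_param` are hypothetical stand-ins, not JAX's actual API.

```python
from typing import Any, Callable, Dict


class Primitive:
    """Hypothetical stand-in for core.Primitive: a named object hashed by identity."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"Primitive({self.name!r})"


# Module-level registry, shaped like line 399 of partial_eval.py:
# primitive -> rule, called as rule(trace, *tracers, **params).
custom_partial_eval_rules: Dict[Primitive, Callable[..., Any]] = {}


class ToyTrace:
    """Hypothetical trace whose process_primitive mirrors lines 137-138."""

    def default_process_primitive(self, primitive: Primitive, tracers, params):
        # Generic path for primitives that have no registered custom rule.
        return ("default", primitive.name, len(tracers))

    def process_primitive(self, primitive: Primitive, tracers, params):
        # A registered custom rule takes precedence over the generic path.
        if primitive in custom_partial_eval_rules:
            return custom_partial_eval_rules[primitive](self, *tracers, **params)
        return self.default_process_primitive(primitive, tracers, params)


# Registration at import time, analogous to
# `pe.custom_partial_eval_rules[while_p] = _while_partial_eval` in control_flow.py.
while_p = Primitive("while")
add_p = Primitive("add")


def _while_partial_eval(trace, *tracers, **params):
    # Hypothetical rule body: it only reports what it received.
    return ("custom", "while", len(tracers), sorted(params))


custom_partial_eval_rules[while_p] = _while_partial_eval

if __name__ == "__main__":
    trace = ToyTrace()
    print(trace.process_primitive(while_p, (1, 2), {"some_param": 0}))
    print(trace.process_primitive(add_p, (1, 2), {}))
```

Because the registry is a plain module-level dict, a rule is only visible once the module that registers it has been imported. That is why the assignments in the listing appear at module level in each file.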