Home
last modified time | relevance | path

Searched refs:lax_control_flow (Results 1 – 4 of 4) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dloops.py116 from jax._src.lax import control_flow as lax_control_flow unknown
514 return lax_control_flow.scan_p.bind(*itertools.chain(body_const_vals,
541 lax_control_flow._initial_style_jaxpr(
547 return lax_control_flow.cond_p.bind(
583 lax_control_flow._initial_style_jaxpr(cond_func_wrapped,
593 return lax_control_flow.while_p.bind(*itertools.chain(cond_consts,
H A Djet.py32 from jax._src.lax import control_flow as lax_control_flow unknown
252 return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis,
256 deflinear(lax_control_flow.cumsum_p)
257 jet_rules[lax_control_flow.cumprod_p] = partial(_cumulative_jet_rule,
259 jet_rules[lax_control_flow.cummax_p] = partial(_cumulative_jet_rule,
261 jet_rules[lax_control_flow.cummin_p] = partial(_cumulative_jet_rule,
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py35 from jax._src.lax import control_flow as lax_control_flow unknown
1740 tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl(
1741 functools.partial(lax_control_flow._cumred_tpu_translation_rule,
1743 tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
1744 functools.partial(lax_control_flow._cumred_tpu_translation_rule,
1751 tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl(
1754 tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl(
1906 tf_impl[lax_control_flow.cond_p] = _cond
1974 tf_impl[lax_control_flow.while_p] = _while
1977 tf_impl_with_avals[lax_control_flow.scan_p] = _convert_jax_impl(lax_control_flow._scan_impl)
[all …]
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/
H A Dprimitive_harness.py54 from jax._src.lax import control_flow as lax_control_flow unknown
1324 f_jax=lax_control_flow.cummin,
1353 lax_control_flow.cummin, lax_control_flow.cummax, lax_control_flow.cumsum,
1354 lax_control_flow.cumprod