Searched refs:lax_control_flow (Results 1 – 4 of 4) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | loops.py | 116 from jax._src.lax import control_flow as lax_control_flow unknown 514 return lax_control_flow.scan_p.bind(*itertools.chain(body_const_vals, 541 lax_control_flow._initial_style_jaxpr( 547 return lax_control_flow.cond_p.bind( 583 lax_control_flow._initial_style_jaxpr(cond_func_wrapped, 593 return lax_control_flow.while_p.bind(*itertools.chain(cond_consts,
|
H A D | jet.py | 32 from jax._src.lax import control_flow as lax_control_flow unknown 252 return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis, 256 deflinear(lax_control_flow.cumsum_p) 257 jet_rules[lax_control_flow.cumprod_p] = partial(_cumulative_jet_rule, 259 jet_rules[lax_control_flow.cummax_p] = partial(_cumulative_jet_rule, 261 jet_rules[lax_control_flow.cummin_p] = partial(_cumulative_jet_rule,
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 35 from jax._src.lax import control_flow as lax_control_flow unknown 1740 tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl( 1741 functools.partial(lax_control_flow._cumred_tpu_translation_rule, 1743 tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl( 1744 functools.partial(lax_control_flow._cumred_tpu_translation_rule, 1751 tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl( 1754 tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl( 1906 tf_impl[lax_control_flow.cond_p] = _cond 1974 tf_impl[lax_control_flow.while_p] = _while 1977 tf_impl_with_avals[lax_control_flow.scan_p] = _convert_jax_impl(lax_control_flow._scan_impl) [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/ |
H A D | primitive_harness.py | 54 from jax._src.lax import control_flow as lax_control_flow unknown 1324 f_jax=lax_control_flow.cummin, 1353 lax_control_flow.cummin, lax_control_flow.cummax, lax_control_flow.cumsum, 1354 lax_control_flow.cumprod
|