Searched refs:pmax_p (Results 1 – 3 of 3) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
parallel.py
  144 out_flat = pmax_p.bind(*leaves, axis_name=axis_name,
  624 pmax_p = core.Primitive('pmax') variable
  625 pmax_p.multiple_results = True
  626 pmax_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
  627 xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule, lax.max_p)
  628 pxla.multi_host_supported_collectives.add(pmax_p)
  629 batching.primitive_batchers[pmax_p] = partial(_collective_batcher, pmax_p)
  630 batching.collective_rules[pmax_p] = \
  631 partial(_batched_reduction_collective, pmax_p,

/dports/math/py-jax/jax-0.2.9/jax/lax/
__init__.py
  335 pmax_p,

/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
jax2tf.py
  831 lax.infeed_p, lax.outfeed_p, lax_parallel.pmax_p,