
Searched refs:pmax_p (Results 1 – 3 of 3) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
parallel.py
144 out_flat = pmax_p.bind(*leaves, axis_name=axis_name,
624 pmax_p = core.Primitive('pmax')
625 pmax_p.multiple_results = True
626 pmax_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
627 xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule, lax.max_p)
628 pxla.multi_host_supported_collectives.add(pmax_p)
629 batching.primitive_batchers[pmax_p] = partial(_collective_batcher, pmax_p)
630 batching.collective_rules[pmax_p] = \
631 partial(_batched_reduction_collective, pmax_p,
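
For context, pmax_p is the primitive behind the user-facing jax.lax.pmax collective, which computes an all-reduce maximum across a pmap-mapped axis. A minimal usage sketch follows; the axis name 'i' and the per-device data are illustrative assumptions, not part of the indexed source:

import jax
import jax.numpy as jnp
from jax import lax

def device_max(x):
    # lax.pmax binds pmax_p: an all-reduce max over the mapped axis 'i'
    return lax.pmax(x, axis_name='i')

n = jax.local_device_count()
x = jnp.arange(float(n))                    # one value per local device
y = jax.pmap(device_max, axis_name='i')(x)  # every device receives the global max
print(y)                                    # e.g. [n-1., n-1., ...]
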
/dports/math/py-jax/jax-0.2.9/jax/lax/
__init__.py
335 pmax_p,
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
jax2tf.py
831 lax.infeed_p, lax.outfeed_p, lax_parallel.pmax_p,