Home
last modified time | relevance | path

Searched refs:parallel_translations (Results 1 – 4 of 4) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dparallel.py598 xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule, lax.add_p) # type: ignore
627 xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule, lax.max_p)
638 xla.parallel_translations[pmin_p] = partial(_allreduce_translation_rule, lax.min_p)
678 xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
783 xla.parallel_translations[all_to_all_p] = _all_to_all_translation_rule
934 xla.parallel_translations[all_gather_p] = _all_gather_translation_rule
955 xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
1035 out_tup = xla.parallel_translations[psum_p](
1042 xla.parallel_translations[pdot_p] = _pdot_translation_rule
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/
H A Dprimitives_test.py123 | set(xla.parallel_translations))
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py454 elif eqn.primitive in parallel_translations:
455 rule = parallel_translations[eqn.primitive]
571 if eqn.primitive in parallel_translations:
912 parallel_translations: Dict[core.Primitive, Callable] = {} variable
H A Dpxla.py1692 if eqn.primitive in xla.parallel_translations: