
Searched refs: _all_gather_via_psum (Results 1 – 1 of 1), sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
parallel.py
851 return _all_gather_via_psum(x, all_gather_dimension=0, axis_name=axis_name,
860 def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size): function
887 lowering = xla.lower_fun(_all_gather_via_psum, multiple_results=False, parallel=True)
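In these hits, line 851 calls `_all_gather_via_psum` with `all_gather_dimension=0`, and line 887 passes it to `xla.lower_fun` to build `lowering`. The name points to a common way of expressing an all-gather as a sum-reduction. Each participant writes its value into its own slot of a zero buffer, and `psum` over the axis then fills every slot. The code below is a minimal sketch of that idea under stated assumptions. It is not the jax 0.2.9 source. The helper name `all_gather_via_psum` is hypothetical. It only gathers along a new leading dimension (the `all_gather_dimension=0` case) and ignores `axis_index_groups`. It also uses `jax.vmap` with a named axis so that it runs on a single CPU. Under `jax.pmap` the same body would run across devices.

```python
import jax
import jax.numpy as jnp
from jax import lax


def all_gather_via_psum(x, axis_name, axis_size):
    """Hypothetical helper: all-gather along a new leading axis using only psum.

    Assumes gathering into dimension 0 and no axis_index_groups.
    """
    # Position of this participant along the named axis.
    idx = lax.axis_index(axis_name)
    # Zero buffer with one slot per participant.
    buf = jnp.zeros((axis_size,) + x.shape, dtype=x.dtype)
    # Write this participant's value into its own slot; every other slot stays zero.
    buf = buf.at[idx].set(x)
    # Summing over participants leaves slot k holding participant k's value.
    return lax.psum(buf, axis_name)


# Three participants, each holding one row of shape (2,).
xs = jnp.arange(6.0).reshape(3, 2)
gathered = jax.vmap(lambda x: all_gather_via_psum(x, "i", 3), axis_name="i")(xs)
print(gathered.shape)
```

Every participant ends up with the full stacked array, so each `gathered[k]` equals `xs`. That is the same shape of result (a new leading axis of length 3) that `lax.all_gather(x, "i")` produces with its default arguments.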