Searched refs:_all_gather_via_psum (Results 1 – 1 of 1) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/parallel.py
  851  return _all_gather_via_psum(x, all_gather_dimension=0, axis_name=axis_name,
  860  def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):  (function definition)
  887  lowering = xla.lower_fun(_all_gather_via_psum, multiple_results=False, parallel=True)
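As its name suggests, `_all_gather_via_psum` (defined at line 860 and passed to `xla.lower_fun` at line 887) builds an all-gather out of `psum`. One common way to do that: each participant writes its own shard into a zero-filled buffer that has one slot per participant, at the slot given by its axis index. A `psum` over the axis then fills every slot with exactly one shard. The sketch below is a minimal reimplementation under stated assumptions, not the JAX 0.2.9 code: the name `all_gather_via_psum` is hypothetical, it gathers only along a new leading axis (no `all_gather_dimension`), it ignores `axis_index_groups`, and it targets a current JAX release, using `jax.vmap` with a named axis so it runs on a single CPU.

```python
import jax
import jax.numpy as jnp
from jax import lax


def all_gather_via_psum(x, axis_name, axis_size):
    """Hypothetical sketch: gather `x` from every member of `axis_name`
    along a new leading axis, using only axis_index and psum."""
    # Position of this participant along the named axis (0 .. axis_size-1).
    idx = lax.axis_index(axis_name)
    # Zero buffer with one slot per participant.
    buf = jnp.zeros((axis_size,) + x.shape, dtype=x.dtype)
    # Write our shard into our own slot; every other slot stays zero.
    buf = lax.dynamic_update_index_in_dim(buf, x, idx, axis=0)
    # Summing the buffers across participants fills every slot with one shard.
    return lax.psum(buf, axis_name)


# Three "participants", each holding a shard of shape (2,).
shards = jnp.arange(6.0).reshape(3, 2)
out = jax.vmap(lambda s: all_gather_via_psum(s, "i", 3), axis_name="i")(shards)
# Every participant now holds the full (3, 2) array.
print(out.shape)  # (3, 3, 2)
```

The trade-off is bandwidth: each participant contributes a buffer `axis_size` times larger than its own shard, so a dedicated all-gather collective is generally preferable where the backend provides one.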