
Searched refs:all_gather_dimension (Results 1 – 1 of 1) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
parallel.py
872 out_shape.insert(all_gather_dimension, axis_size)
878 if (platform == 'gpu') and (all_gather_dimension == 0):
880 new_shape.insert(all_gather_dimension, 1)
884 return xops.AllGather(x, all_gather_dimension=all_gather_dimension, shard_count=axis_size,
888 return lowering(c, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name,
894 new_shape.insert(all_gather_dimension, axis_size)
904 split_axis=all_gather_dimension,
911 if d <= all_gather_dimension:
912 all_gather_dimension += 1
917 all_gather_dimension=all_gather_dimension,
[all …]