Searched refs:all_gather_dimension (Results 1 – 1 of 1) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
parallel.py
  872: out_shape.insert(all_gather_dimension, axis_size)
  878: if (platform == 'gpu') and (all_gather_dimension == 0):
  880: new_shape.insert(all_gather_dimension, 1)
  884: return xops.AllGather(x, all_gather_dimension=all_gather_dimension, shard_count=axis_size,
  888: return lowering(c, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name,
  894: new_shape.insert(all_gather_dimension, axis_size)
  904: split_axis=all_gather_dimension,
  911: if d <= all_gather_dimension:
  912: all_gather_dimension += 1
  917: all_gather_dimension=all_gather_dimension,
  [all …]
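The hits at lines 872 and 894 suggest that the gathered result's shape is built by inserting the axis size at position all_gather_dimension. A minimal sketch of that shape arithmetic in plain Python follows. It is not the JAX source: the helper name `gathered_shape` is hypothetical, and the surrounding code in parallel.py is not shown in these results.

```python
def gathered_shape(shape, all_gather_dimension, axis_size):
    """Return `shape` with a new axis of length `axis_size` inserted.

    Hypothetical helper: mirrors the `insert(all_gather_dimension, axis_size)`
    pattern seen in the search hits, under the assumption that the input is
    the per-device shape before gathering.
    """
    out_shape = list(shape)
    out_shape.insert(all_gather_dimension, axis_size)
    return tuple(out_shape)


# Per-device shape (3, 5), gathered across 4 devices.
print(gathered_shape((3, 5), 0, 4))  # (4, 3, 5)
print(gathered_shape((3, 5), 1, 4))  # (3, 4, 5)
```

With all_gather_dimension equal to 0 the new axis is leading; any other value places it between existing axes.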