Searched defs:all_gather_dimension (Results 1 – 1 of 1) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | parallel.py | 860 def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size): argument 869 def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size): argument 876 def _all_gather_translation_rule(c, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_… argument 891 def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size): argument 897 def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_… argument 909 def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, ax… argument 923 def _all_gather_batched_collective(frame, vals_in, dims_in, all_gather_dimension, axis_name, axis_i… argument
|