
Searched defs:all_gather_dimension (Results 1 – 1 of 1) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
parallel.py
860 def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
869 def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
876 def _all_gather_translation_rule(c, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_…
891 def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
897 def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_…
909 def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, ax…
923 def _all_gather_batched_collective(frame, vals_in, dims_in, all_gather_dimension, axis_name, axis_i…
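The listing shows only signatures, not bodies. The name `_all_gather_via_psum` suggests the common trick of building an all-gather out of a sum reduction. Each participant writes its value into its own slot of an otherwise-zero buffer of length `axis_size`, and an elementwise sum across participants (what `psum` computes) then leaves every slot filled. The pure-Python simulation below is a hypothetical sketch of that general idea, not the JAX code. Its assumptions: devices are modelled as list entries, each device holds a 1-D list, the gather always creates a new leading axis (it does not model `all_gather_dimension`), and `axis_index_groups` is ignored.

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]


def simulated_psum(contributions: List[Matrix]) -> Matrix:
    # Elementwise sum of every device's buffer; this stands in for psum,
    # whose result is identical on all participating devices.
    rows, cols = len(contributions[0]), len(contributions[0][0])
    return [
        [sum(buf[i][j] for buf in contributions) for j in range(cols)]
        for i in range(rows)
    ]


def all_gather_via_psum(per_device: List[Vector]) -> List[Matrix]:
    # Hypothetical sketch: per_device[k] is the value held by device k.
    axis_size = len(per_device)
    width = len(per_device[0])
    contributions = []
    for axis_index, x in enumerate(per_device):
        # Zero buffer with one row per device; only this device's row is set.
        buf = [[0.0] * width for _ in range(axis_size)]
        buf[axis_index] = list(x)
        contributions.append(buf)
    gathered = simulated_psum(contributions)
    # Every device receives its own copy of the same gathered result.
    return [[row[:] for row in gathered] for _ in range(axis_size)]


if __name__ == "__main__":
    print(all_gather_via_psum([[1, 2], [3, 4], [5, 6]]))
```

The zeros make the sum a pure placement operation: each slot receives exactly one non-zero contribution. The cost is that every participant reduces a buffer `axis_size` times larger than its own input.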