Home
last modified time | relevance | path

Searched defs:axis_env (Results 1 – 6 of 6) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dparallel.py517 def _replica_groups(axis_env, axis_name, axis_index_groups): argument
526 axis_env, platform): argument
568 def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env, argument
646 def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform): argument
698 axis_index_groups, axis_env, platform): argument
876 …n_rule(c, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, axis_env, platform): argument
940 def _axis_index_translation_rule(c, *, axis_name, axis_env, platform): argument
1030 axis_env, platform): argument
H A Dcontrol_flow.py301 def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args, argument
731 def _cond_translation_rule(c, axis_env, name_stack, avals, backend, argument
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dmaps.py450 def _xmap_translation_rule_replica(c, axis_env, argument
500 def _xla_tile_base_indices(c, axis_env, tile_shape, axes, axis_sizes): argument
517 def _xla_tile(c, axis_env, x, in_axes, axis_sizes): argument
530 def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend): argument
557 def _xmap_translation_rule_spmd(c, axis_env, argument
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py312 def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params): argument
405 def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args): argument
501 def axis_read(axis_env, axis_name): argument
895 def _xla_call_translation_rule(c, axis_env, argument
981 def f(c, axis_env, name_stack, avals, backend, *xla_args, **params): argument
1400 def _remat_translation_rule(c, axis_env, in_nodes, argument
1448 def _named_call_translation_rule(c, axis_env, in_nodes, name_stack, *, argument
1459 def _call_translation_rule(c, axis_env, in_nodes, name_stack, *, backend, argument
H A Dpxla.py1218 def _pmap_translation_rule(c, axis_env, argument
1249 def _xla_shard(c, aval, axis_env, x, in_axis): argument
1268 def _xla_unshard(c, aval, axis_env, out_axis, x, backend): argument
1304 def _unravel_index(c, axis_env): argument
H A Dsharded_jit.py199 def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack, argument