/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | parallel.py | 517 def _replica_groups(axis_env, axis_name, axis_index_groups): argument 526 axis_env, platform): argument 568 def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env, argument 646 def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform): argument 698 axis_index_groups, axis_env, platform): argument 876 …n_rule(c, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, axis_env, platform): argument 940 def _axis_index_translation_rule(c, *, axis_name, axis_env, platform): argument 1030 axis_env, platform): argument
|
H A D | control_flow.py | 301 def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args, argument 731 def _cond_translation_rule(c, axis_env, name_stack, avals, backend, argument
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | maps.py | 450 def _xmap_translation_rule_replica(c, axis_env, argument 500 def _xla_tile_base_indices(c, axis_env, tile_shape, axes, axis_sizes): argument 517 def _xla_tile(c, axis_env, x, in_axes, axis_sizes): argument 530 def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend): argument 557 def _xmap_translation_rule_spmd(c, axis_env, argument
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 312 def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params): argument 405 def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args): argument 501 def axis_read(axis_env, axis_name): argument 895 def _xla_call_translation_rule(c, axis_env, argument 981 def f(c, axis_env, name_stack, avals, backend, *xla_args, **params): argument 1400 def _remat_translation_rule(c, axis_env, in_nodes, argument 1448 def _named_call_translation_rule(c, axis_env, in_nodes, name_stack, *, argument 1459 def _call_translation_rule(c, axis_env, in_nodes, name_stack, *, backend, argument
|
H A D | pxla.py | 1218 def _pmap_translation_rule(c, axis_env, argument 1249 def _xla_shard(c, aval, axis_env, x, in_axis): argument 1268 def _xla_unshard(c, aval, axis_env, out_axis, x, backend): argument 1304 def _unravel_index(c, axis_env): argument
|
H A D | sharded_jit.py | 199 def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack, argument
|