Searched refs:AxisEnv (Results 1 – 5 of 5) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 281 built_c = primitive_computation(prim, AxisEnv(nreps, (), ()), backend, 346 axis_env = AxisEnv(1, (), ()) 492 class AxisEnv(NamedTuple): class 498 def extend_axis_env(env: AxisEnv, name, size: int): argument 499 return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,)) 507 def axis_groups(axis_env: AxisEnv, name): argument 709 c, jaxpr, backend, AxisEnv(nreps, (), ()), xla_consts, 951 axis_env = AxisEnv(1, (), ())
|
H A D | pxla.py | 829 axis_env = xla.AxisEnv(num_global_replicas, (axis_name,), (global_axis_size,)) 1467 axis_env = xla.AxisEnv(nreps=1, names=(), sizes=()) # All named axes have been vmapped 1472 axis_env = xla.AxisEnv(nreps=mesh.size, 1637 axis_env = xla.AxisEnv(num_devices, (axis_name,), (num_devices,))
|
H A D | sharded_jit.py | 163 axis_env = xla.AxisEnv(nrep, (), ())
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | api.py | 606 return xla.AxisEnv(nreps, (), ()) 610 return xla.AxisEnv(nreps, names, sizes)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | lax.py | 4940 axis_env = xla.AxisEnv(1, (), ()) # no parallel primitives inside reductions
|