Home
last modified time | relevance | path

Searched refs:AxisEnv (Results 1 – 5 of 5) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
xla.py (History | Annotate | Download)
281 built_c = primitive_computation(prim, AxisEnv(nreps, (), ()), backend,
346 axis_env = AxisEnv(1, (), ())
492 class AxisEnv(NamedTuple): class
498 def extend_axis_env(env: AxisEnv, name, size: int): argument
499 return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))
507 def axis_groups(axis_env: AxisEnv, name): argument
709 c, jaxpr, backend, AxisEnv(nreps, (), ()), xla_consts,
951 axis_env = AxisEnv(1, (), ())
pxla.py (History | Annotate | Download)
829 axis_env = xla.AxisEnv(num_global_replicas, (axis_name,), (global_axis_size,))
1467 axis_env = xla.AxisEnv(nreps=1, names=(), sizes=()) # All named axes have been vmapped
1472 axis_env = xla.AxisEnv(nreps=mesh.size,
1637 axis_env = xla.AxisEnv(num_devices, (axis_name,), (num_devices,))
sharded_jit.py (History | Annotate | Download)
163 axis_env = xla.AxisEnv(nrep, (), ())
/dports/math/py-jax/jax-0.2.9/jax/
api.py (History | Annotate | Download)
606 return xla.AxisEnv(nreps, (), ())
610 return xla.AxisEnv(nreps, names, sizes)
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
lax.py (History | Annotate | Download)
4940 axis_env = xla.AxisEnv(1, (), ()) # no parallel primitives inside reductions