Home
last modified time | relevance | path

Searched refs:xla_shape (Results 1 – 6 of 6) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/
H A Ddlpack.py66 xla_shape = getattr(buf, "xla_shape", buf.shape)()
67 assert not xla_shape.is_tuple()
68 aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
/dports/math/py-jax/jax-0.2.9/jax/
H A Dlazy.py208 xla_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (N, M))
210 xops.Add(xops.Iota(c, xla_shape, 0),
212 xops.Iota(c, xla_shape, 1))
217 xla_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (N, M))
219 xops.Add(xops.Iota(c, xla_shape, 0),
221 xops.Iota(c, xla_shape, 1))
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py372 _check_special(prim.name, buf.xla_shape(), buf)
374 def _check_special(name, xla_shape, buf): argument
375 assert not xla_shape.is_tuple()
376 if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
802 for xla_shape in aval_to_xla_shapes(a)]
973 def _array_aval_from_xla_shape(xla_shape): argument
977 assert not xla_shape.is_tuple()
978 return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
1429 def zeros(xla_shape): argument
1430 if xla_shape.is_array():
[all …]
H A Dpxla.py1280 xla_shape = c.get_shape(x)
1281 dims = list(xla_shape.dimensions())
1282 padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())),
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dmaps.py531 xla_shape = c.get_shape(x)
532 x_dtype = xla_shape.numpy_dtype()
539 tile_shape = list(xla_shape.dimensions())
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dlax.py6010 xla_shape = xc.Shape.array_shape(c.get_shape(a).xla_element_type(), shape)
6011 return xops.RngUniform(a, b, xla_shape)
6033 xla_shape = xc.Shape.array_shape(etype, shape)
6034 return xops.Iota(c, xla_shape, dimension)