Searched refs:xla_shape (Results 1 – 6 of 6) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/_src/ |
H A D | dlpack.py | 66 xla_shape = getattr(buf, "xla_shape", buf.shape)() 67 assert not xla_shape.is_tuple() 68 aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | lazy.py | 208 xla_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (N, M)) 210 xops.Add(xops.Iota(c, xla_shape, 0), 212 xops.Iota(c, xla_shape, 1)) 217 xla_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (N, M)) 219 xops.Add(xops.Iota(c, xla_shape, 0), 221 xops.Iota(c, xla_shape, 1))
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 372 _check_special(prim.name, buf.xla_shape(), buf) 374 def _check_special(name, xla_shape, buf): argument 375 assert not xla_shape.is_tuple() 376 if dtypes.issubdtype(xla_shape.element_type(), np.inexact): 802 for xla_shape in aval_to_xla_shapes(a)] 973 def _array_aval_from_xla_shape(xla_shape): argument 977 assert not xla_shape.is_tuple() 978 return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype()) 1429 def zeros(xla_shape): argument 1430 if xla_shape.is_array(): [all …]
|
H A D | pxla.py | 1280 xla_shape = c.get_shape(x) 1281 dims = list(xla_shape.dimensions()) 1282 padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())),
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | maps.py | 531 xla_shape = c.get_shape(x) 532 x_dtype = xla_shape.numpy_dtype() 539 tile_shape = list(xla_shape.dimensions())
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | lax.py | 6010 xla_shape = xc.Shape.array_shape(c.get_shape(a).xla_element_type(), shape) 6011 return xops.RngUniform(a, b, xla_shape) 6033 xla_shape = xc.Shape.array_shape(etype, shape) 6034 return xops.Iota(c, xla_shape, dimension)
|