Home
last modified time | relevance | path

Searched refs:ShapedArray (Results 1 – 23 of 23) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/
H A Dabstract_arrays.py29 ShapedArray = core.ShapedArray variable
39 return ShapedArray(np.shape(x), dtype)
43 return zeros_like_shaped_array(ShapedArray(np.shape(x), dtype))
58 assert isinstance(aval, ShapedArray)
63 ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
H A Dcore.py1030 class ShapedArray(UnshapedArray): class
1035 super(ShapedArray, self).__init__(dtype, weak_type=weak_type)
1066 return ShapedArray(self.shape, self.dtype, weak_type=False)
1092 class ConcreteArray(ShapedArray):
1122 return ShapedArray(self.shape, self.dtype,
1173 ShapedArray: lambda aval, weak_type: ShapedArray(aval.shape, aval.dtype, weak_type=weak_type)
1200 if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
1368 elif isinstance(aval, ShapedArray):
1371 return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype)
1378 elif isinstance(aval, ShapedArray):
[all …]
H A D__init__.py70 ShapedArray,
H A Dapi.py64 from .core import ConcreteArray, ShapedArray, raise_to_shaped
613 return ShapedArray(np.shape(x), dtypes.result_type(x))
682 if not isinstance(out_aval, xla.ShapedArray):
830 if isinstance(aval, ShapedArray):
1661 avals = map(partial(ShapedArray, dtype=np.float32), in_shapes)
1972 return core.ShapedArray(np.shape(x), dtypes.result_type(x))
2183 stacked_aval = ShapedArray((len(devices),) + avals[0].shape, avals[0].dtype)
2225 assert isinstance(aval, core.ShapedArray) and aval._num_buffers == 1
2356 return ShapedArray(np.shape(x), dtypes.result_type(x))
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpxla.py505 aval: ShapedArray, argument
530 assert type(aval) is ShapedArray
590 aval = ShapedArray(
1049 return ShapedArray(global_shape, local_aval.dtype)
1254 elif isinstance(aval, ShapedArray):
1273 elif isinstance(aval, ShapedArray):
1387 assert isinstance(aval, ShapedArray)
1392 return ShapedArray(tuple(shape), aval.dtype)
1397 assert isinstance(aval, ShapedArray)
1401 return ShapedArray(tuple(shape), aval.dtype)
[all …]
H A Dxla.py34 from ..core import (ConcreteArray, ShapedArray, AbstractToken,
104 ShapedArray: _make_array_shape,
114 def array_result_handler(device: Optional[Device], aval: core.ShapedArray): argument
123 ShapedArray: array_result_handler,
182 return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True)
978 return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
1029 aval: core.ShapedArray, argument
1071 def __init__(self, aval: core.ShapedArray, device: Optional[Device], argument
1092 assert type(aval) is ShapedArray
1349 def _lazy_force_computation(aval: core.ShapedArray, argument
H A Dbatching.py21 from ..core import ShapedArray, raise_to_shaped, Trace, Tracer
114 elif type(aval) is ShapedArray:
117 return ShapedArray(new_shape, aval.dtype)
H A Dmasking.py27 from ..core import ShapedArray, Trace, Tracer
441 return ShapedArray(self.polymorphic_shape, self.dtype)
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dlinalg.py29 from jax.core import Primitive, ShapedArray
365 if isinstance(operand, ShapedArray):
374 vl = vr = ShapedArray(batch_dims + (n, n), dtype)
375 w = ShapedArray(batch_dims + (n,), dtype)
465 if isinstance(operand, ShapedArray):
474 w = ShapedArray(batch_dims + (n,),
812 if isinstance(operand, ShapedArray):
820 perm = ShapedArray(batch_dims + (m,), jnp.int32)
1028 if isinstance(operand, ShapedArray):
1150 if isinstance(operand, ShapedArray):
[all …]
H A Dfft.py21 from jax.core import Primitive, ShapedArray
76 return ShapedArray(shape, dtype)
H A Dparallel.py29 from jax.core import ShapedArray, raise_to_shaped
553 scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
573 scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
779 return ShapedArray(tuple(shape), input_aval.dtype, weak_type=False)
895 return ShapedArray(new_shape, x_aval.dtype)
958 lambda *args, **params: ShapedArray((), np.int32))
1074 out_aval = ShapedArray((), np.int32)
1091 lambda *args, **params: ShapedArray((), np.int32))
H A Dlax.py1479 aval = ShapedArray(shape, dtype)
1501 aval = ShapedArray((N, M), dtype)
1518 aval = ShapedArray(shape, dtype)
2666 lhs: ShapedArray, rhs: ShapedArray, *, window_strides, padding, argument
4983 scalar = ShapedArray((), dtype)
5015 scalar = ShapedArray((), dtype)
5067 scalar = ShapedArray((), dtype)
5258 scalar = ShapedArray((), dtype)
5312 scalar = ShapedArray((), dtype)
5427 scalar = ShapedArray((), dtype)
[all …]
H A Dcontrol_flow.py38 from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
271 if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), np.bool_):
322 scalar = ShapedArray((), np.bool_)
1240 x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
1461 elif isinstance(aval, ShapedArray):
1462 return ShapedArray((sz, *aval.shape), aval.dtype)
1469 ys_avals = [ShapedArray((length,) + aval.shape, aval.dtype)
1663 return ShapedArray((sz,) + aval.shape, aval.dtype)
1814 aval = ShapedArray((), dtypes.int_)
/dports/math/py-flax/flax-0.3.3/flax/core/
H A Daxes_scan.py125 lambda x: pe.PartialVal.unknown(jax.ShapedArray(jnp.shape(x), jnp.result_type(x))),
128 lambda x: pe.PartialVal.unknown(jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))),
/dports/math/py-flax/flax-0.3.3/flax/
H A Djax_utils.py60 aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
114 in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
/dports/math/py-jax/jax-0.2.9/jax/_src/
H A Ddlpack.py68 aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
H A Drandom.py112 if all(isinstance(arg, core.ShapedArray) for arg in args):
114 aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32))
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dmaps.py572 assert isinstance(aval, core.ShapedArray)
576 return core.ShapedArray(tuple(shape), aval.dtype)
579 assert isinstance(aval, core.ShapedArray)
583 return core.ShapedArray(tuple(shape), aval.dtype)
H A Dloops.py589 if not safe_map(core.typecompat, cond_jaxpr.out_avals, [core.ShapedArray((), np.bool_)]):
H A Dhost_callback.py575 flat_results_aval = [core.ShapedArray(np.shape(r), dtypes.result_type(r))
644 def _values_to_avals(vals) -> Sequence[core.ShapedArray]:
/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/
H A Dlax_numpy.py5193 setattr(ShapedArray, "reshape", core.aval_method(_reshape))
5194 setattr(ShapedArray, "flatten", core.aval_method(ravel))
5195 setattr(ShapedArray, "T", core.aval_property(transpose))
5196 setattr(ShapedArray, "real", core.aval_property(real))
5197 setattr(ShapedArray, "imag", core.aval_property(imag))
5198 setattr(ShapedArray, "astype", core.aval_method(_astype))
5199 setattr(ShapedArray, "view", core.aval_method(_view))
5200 setattr(ShapedArray, "nbytes", core.aval_property(_nbytes))
5237 setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast))
5239 setattr(ShapedArray, "split", core.aval_method(split))
[all …]
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py398 return core.ShapedArray(raw_shape, dtype)
420 return core.ShapedArray(shape, dtype)
541 return (core.ShapedArray((), dtypes.canonicalize_dtype(np.int_), weak_type=True),) * rank
696 return TensorFlowTracer(self, val, core.ShapedArray(shape, dtype))
1076 return core.ShapedArray(aval.shape, np.int8)
2036 o_aval = core.ShapedArray((), to_jax_dtype(op.dtype))
2048 _out_aval=core.ShapedArray((), np.bool_),
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/
H A Dshape_poly_test.py70 return core.ShapedArray(masking.parse_spec(shape), np.float32)
72 return core.ShapedArray(shape, np.float32)