/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | abstract_arrays.py | 29 ShapedArray = core.ShapedArray variable 39 return ShapedArray(np.shape(x), dtype) 43 return zeros_like_shaped_array(ShapedArray(np.shape(x), dtype)) 58 assert isinstance(aval, ShapedArray) 63 ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
H A D | core.py | 1030 class ShapedArray(UnshapedArray): class 1035 super(ShapedArray, self).__init__(dtype, weak_type=weak_type) 1066 return ShapedArray(self.shape, self.dtype, weak_type=False) 1092 class ConcreteArray(ShapedArray): 1122 return ShapedArray(self.shape, self.dtype, 1173 ShapedArray: lambda aval, weak_type: ShapedArray(aval.shape, aval.dtype, weak_type=weak_type) 1200 if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray) 1368 elif isinstance(aval, ShapedArray): 1371 return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype) 1378 elif isinstance(aval, ShapedArray): [all …]
|
H A D | __init__.py | 70 ShapedArray,
|
H A D | api.py | 64 from .core import ConcreteArray, ShapedArray, raise_to_shaped 613 return ShapedArray(np.shape(x), dtypes.result_type(x)) 682 if not isinstance(out_aval, xla.ShapedArray): 830 if isinstance(aval, ShapedArray): 1661 avals = map(partial(ShapedArray, dtype=np.float32), in_shapes) 1972 return core.ShapedArray(np.shape(x), dtypes.result_type(x)) 2183 stacked_aval = ShapedArray((len(devices),) + avals[0].shape, avals[0].dtype) 2225 assert isinstance(aval, core.ShapedArray) and aval._num_buffers == 1 2356 return ShapedArray(np.shape(x), dtypes.result_type(x))
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | pxla.py | 505 aval: ShapedArray, argument 530 assert type(aval) is ShapedArray 590 aval = ShapedArray( 1049 return ShapedArray(global_shape, local_aval.dtype) 1254 elif isinstance(aval, ShapedArray): 1273 elif isinstance(aval, ShapedArray): 1387 assert isinstance(aval, ShapedArray) 1392 return ShapedArray(tuple(shape), aval.dtype) 1397 assert isinstance(aval, ShapedArray) 1401 return ShapedArray(tuple(shape), aval.dtype) [all …]
|
H A D | xla.py | 34 from ..core import (ConcreteArray, ShapedArray, AbstractToken, 104 ShapedArray: _make_array_shape, 114 def array_result_handler(device: Optional[Device], aval: core.ShapedArray): argument 123 ShapedArray: array_result_handler, 182 return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True) 978 return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype()) 1029 aval: core.ShapedArray, argument 1071 def __init__(self, aval: core.ShapedArray, device: Optional[Device], argument 1092 assert type(aval) is ShapedArray 1349 def _lazy_force_computation(aval: core.ShapedArray, argument
|
H A D | batching.py | 21 from ..core import ShapedArray, raise_to_shaped, Trace, Tracer 114 elif type(aval) is ShapedArray: 117 return ShapedArray(new_shape, aval.dtype)
|
H A D | masking.py | 27 from ..core import ShapedArray, Trace, Tracer 441 return ShapedArray(self.polymorphic_shape, self.dtype)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | linalg.py | 29 from jax.core import Primitive, ShapedArray 365 if isinstance(operand, ShapedArray): 374 vl = vr = ShapedArray(batch_dims + (n, n), dtype) 375 w = ShapedArray(batch_dims + (n,), dtype) 465 if isinstance(operand, ShapedArray): 474 w = ShapedArray(batch_dims + (n,), 812 if isinstance(operand, ShapedArray): 820 perm = ShapedArray(batch_dims + (m,), jnp.int32) 1028 if isinstance(operand, ShapedArray): 1150 if isinstance(operand, ShapedArray): [all …]
|
H A D | fft.py | 21 from jax.core import Primitive, ShapedArray 76 return ShapedArray(shape, dtype)
|
H A D | parallel.py | 29 from jax.core import ShapedArray, raise_to_shaped 553 scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype()) 573 scalar = ShapedArray((), c.get_shape(x).numpy_dtype()) 779 return ShapedArray(tuple(shape), input_aval.dtype, weak_type=False) 895 return ShapedArray(new_shape, x_aval.dtype) 958 lambda *args, **params: ShapedArray((), np.int32)) 1074 out_aval = ShapedArray((), np.int32) 1091 lambda *args, **params: ShapedArray((), np.int32))
|
H A D | lax.py | 1479 aval = ShapedArray(shape, dtype) 1501 aval = ShapedArray((N, M), dtype) 1518 aval = ShapedArray(shape, dtype) 2666 lhs: ShapedArray, rhs: ShapedArray, *, window_strides, padding, argument 4983 scalar = ShapedArray((), dtype) 5015 scalar = ShapedArray((), dtype) 5067 scalar = ShapedArray((), dtype) 5258 scalar = ShapedArray((), dtype) 5312 scalar = ShapedArray((), dtype) 5427 scalar = ShapedArray((), dtype) [all …]
|
H A D | control_flow.py | 38 from jax.core import ConcreteArray, ShapedArray, raise_to_shaped 271 if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), np.bool_): 322 scalar = ShapedArray((), np.bool_) 1240 x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes)) 1461 elif isinstance(aval, ShapedArray): 1462 return ShapedArray((sz, *aval.shape), aval.dtype) 1469 ys_avals = [ShapedArray((length,) + aval.shape, aval.dtype) 1663 return ShapedArray((sz,) + aval.shape, aval.dtype) 1814 aval = ShapedArray((), dtypes.int_)
|
/dports/math/py-flax/flax-0.3.3/flax/core/ |
H A D | axes_scan.py | 125 lambda x: pe.PartialVal.unknown(jax.ShapedArray(jnp.shape(x), jnp.result_type(x))), 128 lambda x: pe.PartialVal.unknown(jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))),
|
/dports/math/py-flax/flax-0.3.3/flax/ |
H A D | jax_utils.py | 60 aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype) 114 in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
|
/dports/math/py-jax/jax-0.2.9/jax/_src/ |
H A D | dlpack.py | 68 aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
|
H A D | random.py | 112 if all(isinstance(arg, core.ShapedArray) for arg in args): 114 aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32))
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | maps.py | 572 assert isinstance(aval, core.ShapedArray) 576 return core.ShapedArray(tuple(shape), aval.dtype) 579 assert isinstance(aval, core.ShapedArray) 583 return core.ShapedArray(tuple(shape), aval.dtype)
|
H A D | loops.py | 589 if not safe_map(core.typecompat, cond_jaxpr.out_avals, [core.ShapedArray((), np.bool_)]):
|
H A D | host_callback.py | 575 flat_results_aval = [core.ShapedArray(np.shape(r), dtypes.result_type(r)) 644 def _values_to_avals(vals) -> Sequence[core.ShapedArray]:
|
/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/ |
H A D | lax_numpy.py | 5193 setattr(ShapedArray, "reshape", core.aval_method(_reshape)) 5194 setattr(ShapedArray, "flatten", core.aval_method(ravel)) 5195 setattr(ShapedArray, "T", core.aval_property(transpose)) 5196 setattr(ShapedArray, "real", core.aval_property(real)) 5197 setattr(ShapedArray, "imag", core.aval_property(imag)) 5198 setattr(ShapedArray, "astype", core.aval_method(_astype)) 5199 setattr(ShapedArray, "view", core.aval_method(_view)) 5200 setattr(ShapedArray, "nbytes", core.aval_property(_nbytes)) 5237 setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast)) 5239 setattr(ShapedArray, "split", core.aval_method(split)) [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 398 return core.ShapedArray(raw_shape, dtype) 420 return core.ShapedArray(shape, dtype) 541 return (core.ShapedArray((), dtypes.canonicalize_dtype(np.int_), weak_type=True),) * rank 696 return TensorFlowTracer(self, val, core.ShapedArray(shape, dtype)) 1076 return core.ShapedArray(aval.shape, np.int8) 2036 o_aval = core.ShapedArray((), to_jax_dtype(op.dtype)) 2048 _out_aval=core.ShapedArray((), np.bool_),
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/ |
H A D | shape_poly_test.py | 70 return core.ShapedArray(masking.parse_spec(shape), np.float32) 72 return core.ShapedArray(shape, np.float32)
|