/dports/math/py-jax/jax-0.2.9/jax/_src/ |
H A D | random.py | 363 dtype = dtypes.canonicalize_dtype(dtype) 419 dtype = dtypes.canonicalize_dtype(dtype) 615 dtype = dtypes.canonicalize_dtype(dtype) 673 dtype = dtypes.canonicalize_dtype(dtype) 737 dtype = dtypes.canonicalize_dtype(dtype) 828 dtype = dtypes.canonicalize_dtype(dtype) 865 dtype = dtypes.canonicalize_dtype(dtype) 900 dtype = dtypes.canonicalize_dtype(dtype) 937 dtype = dtypes.canonicalize_dtype(dtype) 1069 dtype = dtypes.canonicalize_dtype(dtype) [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | abstract_arrays.py | 38 dtype = dtypes.canonicalize_dtype(dtypes.result_type(x)) 42 dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
|
H A D | dtypes.py | 69 def canonicalize_dtype(dtype): function 328 return canonicalize_dtype(dtype(args[0])) 330 return canonicalize_dtype(_least_upper_bound(*{_jax_type(arg) for arg in args}))
|
H A D | test_util.py | 81 return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits 140 dtype = _dtypes.canonicalize_dtype(np.dtype(dtype)) 178 assert (_dtypes.canonicalize_dtype(_dtype(x)) == 179 _dtypes.canonicalize_dtype(_dtype(y))) 821 self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x)), 822 _dtypes.canonicalize_dtype(_dtype(y)))
|
H A D | api.py | 1070 return dtypes.canonicalize_dtype(dtypes.result_type(x))
|
H A D | core.py | 976 self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/ |
H A D | jax2tf_test.py | 54 v = tf.Variable(0.7, dtype=dtypes.canonicalize_dtype(jnp.float_)) 147 default_float_type = dtypes.canonicalize_dtype(jnp.float_) 170 default_float_dtype = dtypes.canonicalize_dtype(jnp.float_) 207 x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_)) 241 x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
|
H A D | tf_test_util.py | 102 dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))), 103 dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y))))
|
H A D | primitive_harness.py | 797 index_dtype = dtypes.canonicalize_dtype(index_dtype)
|
/dports/math/py-jax/jax-0.2.9/jax/lib/ |
H A D | xla_bridge.py | 286 return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype)) 291 return {dtypes.canonicalize_dtype(dtype) 299 return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
|
/dports/math/py-jax/jax-0.2.9/jax/_src/nn/ |
H A D | functions.py | 294 dtype = dtypes.canonicalize_dtype(dtype)
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 128 x = canonicalize_dtype(x) 150 def canonicalize_dtype(x): function 160 return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x))) 164 x, dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ])) 413 return [xb.constant(c, canonicalize_dtype(v.val))] 857 outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
|
H A D | pxla.py | 368 arg = xla.canonicalize_dtype(arg)
|
/dports/math/py-flax/flax-0.3.3/flax/optim/ |
H A D | adafactor.py | 105 self.dtype_momentum = jax.dtypes.canonicalize_dtype(dtype_momentum)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/ |
H A D | lax_numpy.py | 770 dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2)) 1119 dtype = dtypes.canonicalize_dtype(float_) 1843 info = finfo(dtypes.canonicalize_dtype(dtype)) 1904 a_dtype = dtypes.canonicalize_dtype(_dtype(a)) 1977 dtype = dtypes.canonicalize_dtype(dtype) 2001 out_dtype = dtypes.canonicalize_dtype(out_dtype) 2073 dtype = a_dtype = dtypes.canonicalize_dtype(float_) 2784 dtype = dtype and dtypes.canonicalize_dtype(dtype) 3257 default_int = dtypes.canonicalize_dtype(np.int_) 4212 inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_)) [all …]
|
H A D | linalg.py | 41 dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | lax.py | 423 new_dtype = dtypes.canonicalize_dtype(new_dtype) 457 new_dtype = dtypes.canonicalize_dtype(new_dtype) 1471 dtype = dtypes.canonicalize_dtype(dtype) 1477 dtype = dtypes.canonicalize_dtype(dtype) 1484 dtype = dtypes.canonicalize_dtype(dtype) 1494 dtype = dtypes.canonicalize_dtype(dtype) 1508 dtype = dtypes.canonicalize_dtype(dtype) 1525 dtype = dtypes.canonicalize_dtype(dtype) 4115 return dtypes.canonicalize_dtype(operand.dtype) 4362 return dtypes.canonicalize_dtype(operand.dtype) [all …]
|
H A D | control_flow.py | 179 lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower)) 180 upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper)) 756 mid = np.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices))) 2558 return dtypes.canonicalize_dtype(operand.dtype)
|
H A D | linalg.py | 373 dtype = dtypes.canonicalize_dtype(dtype)
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | host_callback.py | 876 canonical_flat_results = tuple(util.safe_map(xla.canonicalize_dtype, actual_flat_results)) 904 canonical_flat_results = [xla.canonicalize_dtype(np.arange(12345, dtype=np.int8))]
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 541 return (core.ShapedArray((), dtypes.canonicalize_dtype(np.int_), weak_type=True),) * rank
|