Home
last modified time | relevance | path

Searched refs:canonicalize_dtype (Results 1 – 21 of 21) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/
H A Drandom.py363 dtype = dtypes.canonicalize_dtype(dtype)
419 dtype = dtypes.canonicalize_dtype(dtype)
615 dtype = dtypes.canonicalize_dtype(dtype)
673 dtype = dtypes.canonicalize_dtype(dtype)
737 dtype = dtypes.canonicalize_dtype(dtype)
828 dtype = dtypes.canonicalize_dtype(dtype)
865 dtype = dtypes.canonicalize_dtype(dtype)
900 dtype = dtypes.canonicalize_dtype(dtype)
937 dtype = dtypes.canonicalize_dtype(dtype)
1069 dtype = dtypes.canonicalize_dtype(dtype)
[all …]
/dports/math/py-jax/jax-0.2.9/jax/
H A Dabstract_arrays.py38 dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
42 dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
H A Ddtypes.py69 def canonicalize_dtype(dtype): function
328 return canonicalize_dtype(dtype(args[0]))
330 return canonicalize_dtype(_least_upper_bound(*{_jax_type(arg) for arg in args}))
H A Dtest_util.py81 return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits
140 dtype = _dtypes.canonicalize_dtype(np.dtype(dtype))
178 assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
179 _dtypes.canonicalize_dtype(_dtype(y)))
821 self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x)),
822 _dtypes.canonicalize_dtype(_dtype(y)))
H A Dapi.py1070 return dtypes.canonicalize_dtype(dtypes.result_type(x))
H A Dcore.py976 self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/
H A Djax2tf_test.py54 v = tf.Variable(0.7, dtype=dtypes.canonicalize_dtype(jnp.float_))
147 default_float_type = dtypes.canonicalize_dtype(jnp.float_)
170 default_float_dtype = dtypes.canonicalize_dtype(jnp.float_)
207 x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
241 x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
H A Dtf_test_util.py102 dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(x))),
103 dtypes.canonicalize_dtype(to_numpy_dtype(jtu._dtype(y))))
H A Dprimitive_harness.py797 index_dtype = dtypes.canonicalize_dtype(index_dtype)
/dports/math/py-jax/jax-0.2.9/jax/lib/
H A Dxla_bridge.py286 return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))
291 return {dtypes.canonicalize_dtype(dtype)
299 return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
/dports/math/py-jax/jax-0.2.9/jax/_src/nn/
H A Dfunctions.py294 dtype = dtypes.canonicalize_dtype(dtype)
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py128 x = canonicalize_dtype(x)
150 def canonicalize_dtype(x): function
160 return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))
164 x, dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ]))
413 return [xb.constant(c, canonicalize_dtype(v.val))]
857 outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
H A Dpxla.py368 arg = xla.canonicalize_dtype(arg)
/dports/math/py-flax/flax-0.3.3/flax/optim/
H A Dadafactor.py105 self.dtype_momentum = jax.dtypes.canonicalize_dtype(dtype_momentum)
/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/
H A Dlax_numpy.py770 dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
1119 dtype = dtypes.canonicalize_dtype(float_)
1843 info = finfo(dtypes.canonicalize_dtype(dtype))
1904 a_dtype = dtypes.canonicalize_dtype(_dtype(a))
1977 dtype = dtypes.canonicalize_dtype(dtype)
2001 out_dtype = dtypes.canonicalize_dtype(out_dtype)
2073 dtype = a_dtype = dtypes.canonicalize_dtype(float_)
2784 dtype = dtype and dtypes.canonicalize_dtype(dtype)
3257 default_int = dtypes.canonicalize_dtype(np.int_)
4212 inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
[all …]
H A Dlinalg.py41 dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dlax.py423 new_dtype = dtypes.canonicalize_dtype(new_dtype)
457 new_dtype = dtypes.canonicalize_dtype(new_dtype)
1471 dtype = dtypes.canonicalize_dtype(dtype)
1477 dtype = dtypes.canonicalize_dtype(dtype)
1484 dtype = dtypes.canonicalize_dtype(dtype)
1494 dtype = dtypes.canonicalize_dtype(dtype)
1508 dtype = dtypes.canonicalize_dtype(dtype)
1525 dtype = dtypes.canonicalize_dtype(dtype)
4115 return dtypes.canonicalize_dtype(operand.dtype)
4362 return dtypes.canonicalize_dtype(operand.dtype)
[all …]
H A Dcontrol_flow.py179 lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
180 upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
756 mid = np.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices)))
2558 return dtypes.canonicalize_dtype(operand.dtype)
H A Dlinalg.py373 dtype = dtypes.canonicalize_dtype(dtype)
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dhost_callback.py876 canonical_flat_results = tuple(util.safe_map(xla.canonicalize_dtype, actual_flat_results))
904 canonical_flat_results = [xla.canonicalize_dtype(np.arange(12345, dtype=np.int8))]
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py541 return (core.ShapedArray((), dtypes.canonicalize_dtype(np.int_), weak_type=True),) * rank