
Searched refs:bitcast_convert_type (Results 1 – 6 of 6) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/lax/
  __init__.py
      60  bitcast_convert_type,
/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/
  lax_numpy.py
     693  x = lax.bitcast_convert_type(x, int_type)
     764  return lax.bitcast_convert_type(x1, int_type), x2
     798  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)
     826  x1 = lax.bitcast_convert_type(x1, dtype)
    5056  return lax.bitcast_convert_type(arr, uint8).astype(dtype)
    5057  return lax.bitcast_convert_type(arr, dtype)
    5068  arr_bytes = lax.bitcast_convert_type(arr, dt_in)
    5078  return lax.bitcast_convert_type(arr_bytes, uint8).astype(dtype)
    5079  return lax.bitcast_convert_type(arr_bytes, dtype)
/dports/math/py-jax/jax-0.2.9/jax/
  lax_reference.py
     162  def bitcast_convert_type(operand, dtype):  (function)
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/
  jax2tf_limitations.py
     267  def bitcast_convert_type(cls, harness: primitive_harness.Harness):  (member in Jax2TfLimitation)
/dports/math/py-jax/jax-0.2.9/jax/_src/
  random.py
     393  floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
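The random.py match at line 393 appears to use a common trick for turning random bits into uniform floats. The random bits are shifted into the mantissa, and the exponent bits of 1.0 are ORed in so the value lies in [1, 2). The result is bitcast to the float type, and 1 is subtracted to land in [0, 1). The sketch below reproduces that idea for float32 with NumPy. It uses ndarray.view as a stand-in for lax.bitcast_convert_type, since both reinterpret the bits of a same-width dtype. The function name bits_to_unit_float32 is hypothetical, and this is not the JAX source.

```python
import numpy as np

NMANT = 23                      # float32 mantissa width
ONE_BITS = np.uint32(0x3F800000)  # bit pattern of float32 1.0 (sign 0, biased exponent 127)


def bits_to_unit_float32(bits: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map a 1-D array of uint32 random bits to float32 in [0, 1)."""
    bits = np.asarray(bits, dtype=np.uint32)
    # Keep the top 23 random bits and move them into the mantissa field.
    mantissa = bits >> np.uint32(32 - NMANT)
    # OR in the exponent of 1.0 so every bit pattern decodes to a value in [1, 2).
    float_bits = mantissa | ONE_BITS
    # Reinterpret the bits as float32 (NumPy analogue of a bitcast), then shift to [0, 1).
    return float_bits.view(np.float32) - np.float32(1.0)


print(bits_to_unit_float32(np.array([0, 0x80000000, 0xFFFFFFFF], dtype=np.uint32)))
```

Because the exponent is fixed, every output is an exact multiple of 2**-23, so the values are evenly spaced over [0, 1) rather than using the finer resolution float32 has near zero.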
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
  lax.py
     441  def bitcast_convert_type(operand: Array, new_dtype: DType) -> Array:  (function)
    5708  signed = bitcast_convert_type(x, signed_dtype)
    5709  unsigned = bitcast_convert_type(x, unsigned_dtype)
    5710  flipped = bitcast_convert_type(
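Lines 5708 to 5710 of lax.py bitcast the same float to both a signed and an unsigned integer type and then build a "flipped" value. This looks like the standard trick for mapping floats to integers whose ordering matches the float ordering. The last line is truncated in this listing. Below is a NumPy sketch of that general technique for float32. The name float32_sort_key is hypothetical, and the flip formula is an assumption rather than a copy of the elided JAX code.

Non-negative floats keep their bit pattern as a signed integer, which already sorts correctly. Negative floats have a negative signed pattern that runs in the wrong order, so they are remapped so that larger magnitudes give smaller keys.

```python
import numpy as np


def float32_sort_key(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: int32 keys whose order matches float32 order (1-D input)."""
    x = np.asarray(x, dtype=np.float32)
    signed = x.view(np.int32)     # same bits read as signed int32
    unsigned = x.view(np.uint32)  # same bits read as unsigned uint32
    # For negative inputs (sign bit set), INT32_MAX - u wraps modulo 2**32.
    # Read back as int32, this is -1 for -0.0 and more negative for larger magnitudes.
    flipped = (np.uint32(np.iinfo(np.int32).max) - unsigned).view(np.int32)
    # Non-negative floats already compare correctly as signed integers.
    return np.where(signed < 0, flipped, signed)


vals = np.array([np.nan, 1.0, -np.inf, -0.0, 0.0, -2.5, np.inf, 3.0], dtype=np.float32)
print(vals[np.argsort(float32_sort_key(vals), kind="stable")])
```

With these keys, -0.0 sorts just before +0.0, and a positive NaN sorts after +inf.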