Searched refs:bitcast_convert_type (Results 1 – 6 of 6) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/lax/ |
H A D | __init__.py | 60 bitcast_convert_type,
|
/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/ |
H A D | lax_numpy.py | 693 x = lax.bitcast_convert_type(x, int_type) 764 return lax.bitcast_convert_type(x1, int_type), x2 798 x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype) 826 x1 = lax.bitcast_convert_type(x1, dtype) 5056 return lax.bitcast_convert_type(arr, uint8).astype(dtype) 5057 return lax.bitcast_convert_type(arr, dtype) 5068 arr_bytes = lax.bitcast_convert_type(arr, dt_in) 5078 return lax.bitcast_convert_type(arr_bytes, uint8).astype(dtype) 5079 return lax.bitcast_convert_type(arr_bytes, dtype)
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | lax_reference.py | 162 def bitcast_convert_type(operand, dtype): function
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/ |
H A D | jax2tf_limitations.py | 267 def bitcast_convert_type(cls, harness: primitive_harness.Harness): member in Jax2TfLimitation
|
/dports/math/py-jax/jax-0.2.9/jax/_src/ |
H A D | random.py | 393 floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | lax.py | 441 def bitcast_convert_type(operand: Array, new_dtype: DType) -> Array: function 5708 signed = bitcast_convert_type(x, signed_dtype) 5709 unsigned = bitcast_convert_type(x, unsigned_dtype) 5710 flipped = bitcast_convert_type(
|