1 #[cfg(feature = "bytemuck")] 2 use bytemuck::{Pod, Zeroable}; 3 use core::{ 4 cmp::Ordering, 5 fmt::{ 6 Binary, Debug, Display, Error, Formatter, LowerExp, LowerHex, Octal, UpperExp, UpperHex, 7 }, 8 iter::{Product, Sum}, 9 num::{FpCategory, ParseFloatError}, 10 ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign}, 11 str::FromStr, 12 }; 13 #[cfg(feature = "serde")] 14 use serde::{Deserialize, Serialize}; 15 #[cfg(feature = "zerocopy")] 16 use zerocopy::{AsBytes, FromBytes}; 17 18 pub(crate) mod convert; 19 20 /// A 16-bit floating point type implementing the [`bfloat16`] format. 21 /// 22 /// The [`bfloat16`] floating point format is a truncated 16-bit version of the IEEE 754 standard 23 /// `binary32`, a.k.a [`f32`]. [`bf16`] has approximately the same dynamic range as [`f32`] by 24 /// having a lower precision than [`f16`][crate::f16]. While [`f16`][crate::f16] has a precision of 25 /// 11 bits, [`bf16`] has a precision of only 8 bits. 26 /// 27 /// Like [`f16`][crate::f16], [`bf16`] does not offer arithmetic operations as it is intended for 28 /// compact storage rather than calculations. Operations should be performed with [`f32`] or 29 /// higher-precision types and converted to/from [`bf16`] as necessary. 30 /// 31 /// [`bfloat16`]: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format 32 #[allow(non_camel_case_types)] 33 #[derive(Clone, Copy, Default)] 34 #[repr(transparent)] 35 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 36 #[cfg_attr(feature = "bytemuck", derive(Zeroable, Pod))] 37 #[cfg_attr(feature = "zerocopy", derive(AsBytes, FromBytes))] 38 pub struct bf16(u16); 39 40 impl bf16 { 41 /// Constructs a [`bf16`] value from the raw bits. 42 #[inline] from_bits(bits: u16) -> bf1643 pub const fn from_bits(bits: u16) -> bf16 { 44 bf16(bits) 45 } 46 47 /// Constructs a [`bf16`] value from a 32-bit floating point value. 48 /// 49 /// If the 32-bit value is too large to fit, ±∞ will result. 
NaN values are preserved. 50 /// Subnormal values that are too tiny to be represented will result in ±0. All other values 51 /// are truncated and rounded to the nearest representable value. 52 #[inline] from_f32(value: f32) -> bf1653 pub fn from_f32(value: f32) -> bf16 { 54 bf16(convert::f32_to_bf16(value)) 55 } 56 57 /// Constructs a [`bf16`] value from a 64-bit floating point value. 58 /// 59 /// If the 64-bit value is to large to fit, ±∞ will result. NaN values are preserved. 60 /// 64-bit subnormal values are too tiny to be represented and result in ±0. Exponents that 61 /// underflow the minimum exponent will result in subnormals or ±0. All other values are 62 /// truncated and rounded to the nearest representable value. 63 #[inline] from_f64(value: f64) -> bf1664 pub fn from_f64(value: f64) -> bf16 { 65 bf16(convert::f64_to_bf16(value)) 66 } 67 68 /// Converts a [`bf16`] into the underlying bit representation. 69 #[inline] to_bits(self) -> u1670 pub const fn to_bits(self) -> u16 { 71 self.0 72 } 73 74 /// Returns the memory representation of the underlying bit representation as a byte array in 75 /// little-endian byte order. 76 /// 77 /// # Examples 78 /// 79 /// ```rust 80 /// # use half::prelude::*; 81 /// let bytes = bf16::from_f32(12.5).to_le_bytes(); 82 /// assert_eq!(bytes, [0x48, 0x41]); 83 /// ``` 84 #[inline] to_le_bytes(self) -> [u8; 2]85 pub const fn to_le_bytes(self) -> [u8; 2] { 86 self.0.to_le_bytes() 87 } 88 89 /// Returns the memory representation of the underlying bit representation as a byte array in 90 /// big-endian (network) byte order. 
91 /// 92 /// # Examples 93 /// 94 /// ```rust 95 /// # use half::prelude::*; 96 /// let bytes = bf16::from_f32(12.5).to_be_bytes(); 97 /// assert_eq!(bytes, [0x41, 0x48]); 98 /// ``` 99 #[inline] to_be_bytes(self) -> [u8; 2]100 pub const fn to_be_bytes(self) -> [u8; 2] { 101 self.0.to_be_bytes() 102 } 103 104 /// Returns the memory representation of the underlying bit representation as a byte array in 105 /// native byte order. 106 /// 107 /// As the target platform's native endianness is used, portable code should use 108 /// [`to_be_bytes`][bf16::to_be_bytes] or [`to_le_bytes`][bf16::to_le_bytes], as appropriate, 109 /// instead. 110 /// 111 /// # Examples 112 /// 113 /// ```rust 114 /// # use half::prelude::*; 115 /// let bytes = bf16::from_f32(12.5).to_ne_bytes(); 116 /// assert_eq!(bytes, if cfg!(target_endian = "big") { 117 /// [0x41, 0x48] 118 /// } else { 119 /// [0x48, 0x41] 120 /// }); 121 /// ``` 122 #[inline] to_ne_bytes(self) -> [u8; 2]123 pub const fn to_ne_bytes(self) -> [u8; 2] { 124 self.0.to_ne_bytes() 125 } 126 127 /// Creates a floating point value from its representation as a byte array in little endian. 128 /// 129 /// # Examples 130 /// 131 /// ```rust 132 /// # use half::prelude::*; 133 /// let value = bf16::from_le_bytes([0x48, 0x41]); 134 /// assert_eq!(value, bf16::from_f32(12.5)); 135 /// ``` 136 #[inline] from_le_bytes(bytes: [u8; 2]) -> bf16137 pub const fn from_le_bytes(bytes: [u8; 2]) -> bf16 { 138 bf16::from_bits(u16::from_le_bytes(bytes)) 139 } 140 141 /// Creates a floating point value from its representation as a byte array in big endian. 
142 /// 143 /// # Examples 144 /// 145 /// ```rust 146 /// # use half::prelude::*; 147 /// let value = bf16::from_be_bytes([0x41, 0x48]); 148 /// assert_eq!(value, bf16::from_f32(12.5)); 149 /// ``` 150 #[inline] from_be_bytes(bytes: [u8; 2]) -> bf16151 pub const fn from_be_bytes(bytes: [u8; 2]) -> bf16 { 152 bf16::from_bits(u16::from_be_bytes(bytes)) 153 } 154 155 /// Creates a floating point value from its representation as a byte array in native endian. 156 /// 157 /// As the target platform's native endianness is used, portable code likely wants to use 158 /// [`from_be_bytes`][bf16::from_be_bytes] or [`from_le_bytes`][bf16::from_le_bytes], as 159 /// appropriate instead. 160 /// 161 /// # Examples 162 /// 163 /// ```rust 164 /// # use half::prelude::*; 165 /// let value = bf16::from_ne_bytes(if cfg!(target_endian = "big") { 166 /// [0x41, 0x48] 167 /// } else { 168 /// [0x48, 0x41] 169 /// }); 170 /// assert_eq!(value, bf16::from_f32(12.5)); 171 /// ``` 172 #[inline] from_ne_bytes(bytes: [u8; 2]) -> bf16173 pub const fn from_ne_bytes(bytes: [u8; 2]) -> bf16 { 174 bf16::from_bits(u16::from_ne_bytes(bytes)) 175 } 176 177 /// Converts a [`bf16`] value into an [`f32`] value. 178 /// 179 /// This conversion is lossless as all values can be represented exactly in [`f32`]. 180 #[inline] to_f32(self) -> f32181 pub fn to_f32(self) -> f32 { 182 convert::bf16_to_f32(self.0) 183 } 184 185 /// Converts a [`bf16`] value into an [`f64`] value. 186 /// 187 /// This conversion is lossless as all values can be represented exactly in [`f64`]. 188 #[inline] to_f64(self) -> f64189 pub fn to_f64(self) -> f64 { 190 convert::bf16_to_f64(self.0) 191 } 192 193 /// Returns `true` if this value is NaN and `false` otherwise. 
194 /// 195 /// # Examples 196 /// 197 /// ```rust 198 /// # use half::prelude::*; 199 /// 200 /// let nan = bf16::NAN; 201 /// let f = bf16::from_f32(7.0_f32); 202 /// 203 /// assert!(nan.is_nan()); 204 /// assert!(!f.is_nan()); 205 /// ``` 206 #[inline] is_nan(self) -> bool207 pub const fn is_nan(self) -> bool { 208 self.0 & 0x7FFFu16 > 0x7F80u16 209 } 210 211 /// Returns `true` if this value is ±∞ and `false` otherwise. 212 /// 213 /// # Examples 214 /// 215 /// ```rust 216 /// # use half::prelude::*; 217 /// 218 /// let f = bf16::from_f32(7.0f32); 219 /// let inf = bf16::INFINITY; 220 /// let neg_inf = bf16::NEG_INFINITY; 221 /// let nan = bf16::NAN; 222 /// 223 /// assert!(!f.is_infinite()); 224 /// assert!(!nan.is_infinite()); 225 /// 226 /// assert!(inf.is_infinite()); 227 /// assert!(neg_inf.is_infinite()); 228 /// ``` 229 #[inline] is_infinite(self) -> bool230 pub const fn is_infinite(self) -> bool { 231 self.0 & 0x7FFFu16 == 0x7F80u16 232 } 233 234 /// Returns `true` if this number is neither infinite nor NaN. 235 /// 236 /// # Examples 237 /// 238 /// ```rust 239 /// # use half::prelude::*; 240 /// 241 /// let f = bf16::from_f32(7.0f32); 242 /// let inf = bf16::INFINITY; 243 /// let neg_inf = bf16::NEG_INFINITY; 244 /// let nan = bf16::NAN; 245 /// 246 /// assert!(f.is_finite()); 247 /// 248 /// assert!(!nan.is_finite()); 249 /// assert!(!inf.is_finite()); 250 /// assert!(!neg_inf.is_finite()); 251 /// ``` 252 #[inline] is_finite(self) -> bool253 pub const fn is_finite(self) -> bool { 254 self.0 & 0x7F80u16 != 0x7F80u16 255 } 256 257 /// Returns `true` if the number is neither zero, infinite, subnormal, or NaN. 
258 /// 259 /// # Examples 260 /// 261 /// ```rust 262 /// # use half::prelude::*; 263 /// 264 /// let min = bf16::MIN_POSITIVE; 265 /// let max = bf16::MAX; 266 /// let lower_than_min = bf16::from_f32(1.0e-39_f32); 267 /// let zero = bf16::from_f32(0.0_f32); 268 /// 269 /// assert!(min.is_normal()); 270 /// assert!(max.is_normal()); 271 /// 272 /// assert!(!zero.is_normal()); 273 /// assert!(!bf16::NAN.is_normal()); 274 /// assert!(!bf16::INFINITY.is_normal()); 275 /// // Values between 0 and `min` are subnormal. 276 /// assert!(!lower_than_min.is_normal()); 277 /// ``` 278 #[inline] is_normal(self) -> bool279 pub const fn is_normal(self) -> bool { 280 let exp = self.0 & 0x7F80u16; 281 exp != 0x7F80u16 && exp != 0 282 } 283 284 /// Returns the floating point category of the number. 285 /// 286 /// If only one property is going to be tested, it is generally faster to use the specific 287 /// predicate instead. 288 /// 289 /// # Examples 290 /// 291 /// ```rust 292 /// use std::num::FpCategory; 293 /// # use half::prelude::*; 294 /// 295 /// let num = bf16::from_f32(12.4_f32); 296 /// let inf = bf16::INFINITY; 297 /// 298 /// assert_eq!(num.classify(), FpCategory::Normal); 299 /// assert_eq!(inf.classify(), FpCategory::Infinite); 300 /// ``` classify(self) -> FpCategory301 pub const fn classify(self) -> FpCategory { 302 let exp = self.0 & 0x7F80u16; 303 let man = self.0 & 0x007Fu16; 304 match (exp, man) { 305 (0, 0) => FpCategory::Zero, 306 (0, _) => FpCategory::Subnormal, 307 (0x7F80u16, 0) => FpCategory::Infinite, 308 (0x7F80u16, _) => FpCategory::Nan, 309 _ => FpCategory::Normal, 310 } 311 } 312 313 /// Returns a number that represents the sign of `self`. 
314 /// 315 /// * 1.0 if the number is positive, +0.0 or [`INFINITY`][bf16::INFINITY] 316 /// * −1.0 if the number is negative, −0.0` or [`NEG_INFINITY`][bf16::NEG_INFINITY] 317 /// * [`NAN`][bf16::NAN] if the number is NaN 318 /// 319 /// # Examples 320 /// 321 /// ```rust 322 /// # use half::prelude::*; 323 /// 324 /// let f = bf16::from_f32(3.5_f32); 325 /// 326 /// assert_eq!(f.signum(), bf16::from_f32(1.0)); 327 /// assert_eq!(bf16::NEG_INFINITY.signum(), bf16::from_f32(-1.0)); 328 /// 329 /// assert!(bf16::NAN.signum().is_nan()); 330 /// ``` signum(self) -> bf16331 pub const fn signum(self) -> bf16 { 332 if self.is_nan() { 333 self 334 } else if self.0 & 0x8000u16 != 0 { 335 Self::NEG_ONE 336 } else { 337 Self::ONE 338 } 339 } 340 341 /// Returns `true` if and only if `self` has a positive sign, including +0.0, NaNs with a 342 /// positive sign bit and +∞. 343 /// 344 /// # Examples 345 /// 346 /// ```rust 347 /// # use half::prelude::*; 348 /// 349 /// let nan = bf16::NAN; 350 /// let f = bf16::from_f32(7.0_f32); 351 /// let g = bf16::from_f32(-7.0_f32); 352 /// 353 /// assert!(f.is_sign_positive()); 354 /// assert!(!g.is_sign_positive()); 355 /// // NaN can be either positive or negative 356 /// assert!(nan.is_sign_positive() != nan.is_sign_negative()); 357 /// ``` 358 #[inline] is_sign_positive(self) -> bool359 pub const fn is_sign_positive(self) -> bool { 360 self.0 & 0x8000u16 == 0 361 } 362 363 /// Returns `true` if and only if `self` has a negative sign, including −0.0, NaNs with a 364 /// negative sign bit and −∞. 
365 /// 366 /// # Examples 367 /// 368 /// ```rust 369 /// # use half::prelude::*; 370 /// 371 /// let nan = bf16::NAN; 372 /// let f = bf16::from_f32(7.0f32); 373 /// let g = bf16::from_f32(-7.0f32); 374 /// 375 /// assert!(!f.is_sign_negative()); 376 /// assert!(g.is_sign_negative()); 377 /// // NaN can be either positive or negative 378 /// assert!(nan.is_sign_positive() != nan.is_sign_negative()); 379 /// ``` 380 #[inline] is_sign_negative(self) -> bool381 pub const fn is_sign_negative(self) -> bool { 382 self.0 & 0x8000u16 != 0 383 } 384 385 /// Returns a number composed of the magnitude of `self` and the sign of `sign`. 386 /// 387 /// Equal to `self` if the sign of `self` and `sign` are the same, otherwise equal to `-self`. 388 /// If `self` is NaN, then NaN with the sign of `sign` is returned. 389 /// 390 /// # Examples 391 /// 392 /// ``` 393 /// # use half::prelude::*; 394 /// let f = bf16::from_f32(3.5); 395 /// 396 /// assert_eq!(f.copysign(bf16::from_f32(0.42)), bf16::from_f32(3.5)); 397 /// assert_eq!(f.copysign(bf16::from_f32(-0.42)), bf16::from_f32(-3.5)); 398 /// assert_eq!((-f).copysign(bf16::from_f32(0.42)), bf16::from_f32(3.5)); 399 /// assert_eq!((-f).copysign(bf16::from_f32(-0.42)), bf16::from_f32(-3.5)); 400 /// 401 /// assert!(bf16::NAN.copysign(bf16::from_f32(1.0)).is_nan()); 402 /// ``` 403 #[inline] copysign(self, sign: bf16) -> bf16404 pub const fn copysign(self, sign: bf16) -> bf16 { 405 bf16((sign.0 & 0x8000u16) | (self.0 & 0x7FFFu16)) 406 } 407 408 /// Returns the maximum of the two numbers. 409 /// 410 /// If one of the arguments is NaN, then the other argument is returned. 
411 /// 412 /// # Examples 413 /// 414 /// ``` 415 /// # use half::prelude::*; 416 /// let x = bf16::from_f32(1.0); 417 /// let y = bf16::from_f32(2.0); 418 /// 419 /// assert_eq!(x.max(y), y); 420 /// ``` 421 #[inline] max(self, other: bf16) -> bf16422 pub fn max(self, other: bf16) -> bf16 { 423 if other > self && !other.is_nan() { 424 other 425 } else { 426 self 427 } 428 } 429 430 /// Returns the minimum of the two numbers. 431 /// 432 /// If one of the arguments is NaN, then the other argument is returned. 433 /// 434 /// # Examples 435 /// 436 /// ``` 437 /// # use half::prelude::*; 438 /// let x = bf16::from_f32(1.0); 439 /// let y = bf16::from_f32(2.0); 440 /// 441 /// assert_eq!(x.min(y), x); 442 /// ``` 443 #[inline] min(self, other: bf16) -> bf16444 pub fn min(self, other: bf16) -> bf16 { 445 if other < self && !other.is_nan() { 446 other 447 } else { 448 self 449 } 450 } 451 452 /// Restrict a value to a certain interval unless it is NaN. 453 /// 454 /// Returns `max` if `self` is greater than `max`, and `min` if `self` is less than `min`. 455 /// Otherwise this returns `self`. 456 /// 457 /// Note that this function returns NaN if the initial value was NaN as well. 458 /// 459 /// # Panics 460 /// Panics if `min > max`, `min` is NaN, or `max` is NaN. 
461 /// 462 /// # Examples 463 /// 464 /// ``` 465 /// # use half::prelude::*; 466 /// assert!(bf16::from_f32(-3.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(-2.0)); 467 /// assert!(bf16::from_f32(0.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(0.0)); 468 /// assert!(bf16::from_f32(2.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(1.0)); 469 /// assert!(bf16::NAN.clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)).is_nan()); 470 /// ``` 471 #[inline] clamp(self, min: bf16, max: bf16) -> bf16472 pub fn clamp(self, min: bf16, max: bf16) -> bf16 { 473 assert!(min <= max); 474 let mut x = self; 475 if x < min { 476 x = min; 477 } 478 if x > max { 479 x = max; 480 } 481 x 482 } 483 484 /// Approximate number of [`bf16`] significant digits in base 10 485 pub const DIGITS: u32 = 2; 486 /// [`bf16`] 487 /// [machine epsilon](https://en.wikipedia.org/wiki/Machine_epsilon) value 488 /// 489 /// This is the difference between 1.0 and the next largest representable number. 
490 pub const EPSILON: bf16 = bf16(0x3C00u16); 491 /// [`bf16`] positive Infinity (+∞) 492 pub const INFINITY: bf16 = bf16(0x7F80u16); 493 /// Number of [`bf16`] significant digits in base 2 494 pub const MANTISSA_DIGITS: u32 = 8; 495 /// Largest finite [`bf16`] value 496 pub const MAX: bf16 = bf16(0x7F7F); 497 /// Maximum possible [`bf16`] power of 10 exponent 498 pub const MAX_10_EXP: i32 = 38; 499 /// Maximum possible [`bf16`] power of 2 exponent 500 pub const MAX_EXP: i32 = 128; 501 /// Smallest finite [`bf16`] value 502 pub const MIN: bf16 = bf16(0xFF7F); 503 /// Minimum possible normal [`bf16`] power of 10 exponent 504 pub const MIN_10_EXP: i32 = -37; 505 /// One greater than the minimum possible normal [`bf16`] power of 2 exponent 506 pub const MIN_EXP: i32 = -125; 507 /// Smallest positive normal [`bf16`] value 508 pub const MIN_POSITIVE: bf16 = bf16(0x0080u16); 509 /// [`bf16`] Not a Number (NaN) 510 pub const NAN: bf16 = bf16(0x7FC0u16); 511 /// [`bf16`] negative infinity (-∞). 
512 pub const NEG_INFINITY: bf16 = bf16(0xFF80u16); 513 /// The radix or base of the internal representation of [`bf16`] 514 pub const RADIX: u32 = 2; 515 516 /// Minimum positive subnormal [`bf16`] value 517 pub const MIN_POSITIVE_SUBNORMAL: bf16 = bf16(0x0001u16); 518 /// Maximum subnormal [`bf16`] value 519 pub const MAX_SUBNORMAL: bf16 = bf16(0x007Fu16); 520 521 /// [`bf16`] 1 522 pub const ONE: bf16 = bf16(0x3F80u16); 523 /// [`bf16`] 0 524 pub const ZERO: bf16 = bf16(0x0000u16); 525 /// [`bf16`] -0 526 pub const NEG_ZERO: bf16 = bf16(0x8000u16); 527 /// [`bf16`] -1 528 pub const NEG_ONE: bf16 = bf16(0xBF80u16); 529 530 /// [`bf16`] Euler's number (ℯ) 531 pub const E: bf16 = bf16(0x402Eu16); 532 /// [`bf16`] Archimedes' constant (π) 533 pub const PI: bf16 = bf16(0x4049u16); 534 /// [`bf16`] 1/π 535 pub const FRAC_1_PI: bf16 = bf16(0x3EA3u16); 536 /// [`bf16`] 1/√2 537 pub const FRAC_1_SQRT_2: bf16 = bf16(0x3F35u16); 538 /// [`bf16`] 2/π 539 pub const FRAC_2_PI: bf16 = bf16(0x3F23u16); 540 /// [`bf16`] 2/√π 541 pub const FRAC_2_SQRT_PI: bf16 = bf16(0x3F90u16); 542 /// [`bf16`] π/2 543 pub const FRAC_PI_2: bf16 = bf16(0x3FC9u16); 544 /// [`bf16`] π/3 545 pub const FRAC_PI_3: bf16 = bf16(0x3F86u16); 546 /// [`bf16`] π/4 547 pub const FRAC_PI_4: bf16 = bf16(0x3F49u16); 548 /// [`bf16`] π/6 549 pub const FRAC_PI_6: bf16 = bf16(0x3F06u16); 550 /// [`bf16`] π/8 551 pub const FRAC_PI_8: bf16 = bf16(0x3EC9u16); 552 /// [`bf16`] 10 553 pub const LN_10: bf16 = bf16(0x4013u16); 554 /// [`bf16`] 2 555 pub const LN_2: bf16 = bf16(0x3F31u16); 556 /// [`bf16`] ₁₀ℯ 557 pub const LOG10_E: bf16 = bf16(0x3EDEu16); 558 /// [`bf16`] ₁₀2 559 pub const LOG10_2: bf16 = bf16(0x3E9Au16); 560 /// [`bf16`] ₂ℯ 561 pub const LOG2_E: bf16 = bf16(0x3FB9u16); 562 /// [`bf16`] ₂10 563 pub const LOG2_10: bf16 = bf16(0x4055u16); 564 /// [`bf16`] √2 565 pub const SQRT_2: bf16 = bf16(0x3FB5u16); 566 } 567 568 impl From<bf16> for f32 { 569 #[inline] from(x: bf16) -> f32570 fn from(x: bf16) -> f32 { 
x.to_f32()
    }
}

impl From<bf16> for f64 {
    #[inline]
    fn from(x: bf16) -> f64 {
        x.to_f64()
    }
}

impl From<i8> for bf16 {
    #[inline]
    fn from(x: i8) -> bf16 {
        // Convert to f32, then to bf16. Every i8 value fits in bf16's 8-bit significand, so the
        // result is exact.
        bf16::from_f32(f32::from(x))
    }
}

impl From<u8> for bf16 {
    #[inline]
    fn from(x: u8) -> bf16 {
        // Convert to f32, then to bf16. Every u8 value fits in bf16's 8-bit significand, so the
        // result is exact.
        bf16::from_f32(f32::from(x))
    }
}

impl PartialEq for bf16 {
    fn eq(&self, other: &bf16) -> bool {
        if self.is_nan() || other.is_nan() {
            // NaN compares unequal to everything, including itself.
            false
        } else {
            // Identical bit patterns are equal. In addition, +0.0 and -0.0 are equal: they are
            // the only pair whose magnitudes (sign bit masked off) are both zero.
            (self.0 == other.0) || ((self.0 | other.0) & 0x7FFFu16 == 0)
        }
    }
}

// The encoding is sign-magnitude. For two non-negative values the raw bits order like the
// numbers; for two negative values the order is reversed; for mixed signs the sign decides,
// except that +0.0 and -0.0 are equal.
impl PartialOrd for bf16 {
    fn partial_cmp(&self, other: &bf16) -> Option<Ordering> {
        if self.is_nan() || other.is_nan() {
            None
        } else {
            let neg = self.0 & 0x8000u16 != 0;
            let other_neg = other.0 & 0x8000u16 != 0;
            match (neg, other_neg) {
                (false, false) => Some(self.0.cmp(&other.0)),
                (false, true) => {
                    if (self.0 | other.0) & 0x7FFFu16 == 0 {
                        Some(Ordering::Equal)
                    } else {
                        Some(Ordering::Greater)
                    }
                }
                (true, false) => {
                    if (self.0 | other.0) & 0x7FFFu16 == 0 {
                        Some(Ordering::Equal)
                    } else {
                        Some(Ordering::Less)
                    }
                }
                (true, true) => Some(other.0.cmp(&self.0)),
            }
        }
    }

    fn lt(&self, other: &bf16) -> bool {
        if self.is_nan() || other.is_nan() {
            false
        } else {
            let neg = self.0 & 0x8000u16 != 0;
            let other_neg = other.0 & 0x8000u16 != 0;
            match (neg, other_neg) {
                (false, false) => self.0 < other.0,
                (false, true) => false,
                (true, false) => (self.0 | other.0) & 0x7FFFu16 != 0,
                (true, true) => self.0 > other.0,
            }
        }
    }

    fn le(&self, other: &bf16) -> bool {
        if self.is_nan() || other.is_nan() {
            false
        } else {
            let neg = self.0 & 0x8000u16 != 0;
            let other_neg = other.0 & 0x8000u16 != 0;
            match (neg, other_neg) {
                (false, false) => self.0 <= other.0,
                (false, true) => (self.0 | other.0) & 0x7FFFu16 == 0,
                (true, false) => true,
                (true, true) => self.0 >= other.0,
            }
        }
    }

    fn gt(&self, other: &bf16) -> bool {
        if self.is_nan() || other.is_nan() {
            false
        } else {
            let neg = self.0 & 0x8000u16 != 0;
            let other_neg = other.0 & 0x8000u16 != 0;
            match (neg, other_neg) {
                (false, false) => self.0 > other.0,
                (false, true) => (self.0 | other.0) & 0x7FFFu16 != 0,
                (true, false) => false,
                (true, true) => self.0 < other.0,
            }
        }
    }

    fn ge(&self, other: &bf16) -> bool {
        if self.is_nan() || other.is_nan() {
            false
        } else {
            let neg = self.0 & 0x8000u16 != 0;
            let other_neg = other.0 & 0x8000u16 != 0;
            match (neg, other_neg) {
                (false, false) => self.0 >= other.0,
                (false, true) => true,
                (true, false) => (self.0 | other.0) & 0x7FFFu16 == 0,
                (true, true) => self.0 <= other.0,
            }
        }
    }
}

impl FromStr for bf16 {
    type Err = ParseFloatError;

    // NOTE: parsing goes through `f32`, so the text is rounded twice (decimal -> f32 -> bf16).
    // In rare near-halfway cases this can differ from a single correctly-rounded parse.
    fn from_str(src: &str) -> Result<bf16, ParseFloatError> {
        f32::from_str(src).map(bf16::from_f32)
    }
}

impl Debug for bf16 {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        // Forward the caller's formatter so flags such as precision (`{:.2?}`) and width are
        // honored; `write!(f, "{:?}", ..)` would format with default options and drop them.
        Debug::fmt(&self.to_f32(), f)
    }
}

impl Display for bf16 {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        // Forward the caller's formatter so `{:.3}`, `{:>8}` etc. behave as they do for `f32`.
        Display::fmt(&self.to_f32(), f)
    }
}

impl LowerExp for bf16 {
    fn fmt(&self, f: &mut Formatter<'_>) ->
Result<(), Error> { 718 write!(f, "{:e}", self.to_f32()) 719 } 720 } 721 722 impl UpperExp for bf16 { fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>723 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { 724 write!(f, "{:E}", self.to_f32()) 725 } 726 } 727 728 impl Binary for bf16 { fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>729 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { 730 write!(f, "{:b}", self.0) 731 } 732 } 733 734 impl Octal for bf16 { fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>735 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { 736 write!(f, "{:o}", self.0) 737 } 738 } 739 740 impl LowerHex for bf16 { fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>741 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { 742 write!(f, "{:x}", self.0) 743 } 744 } 745 746 impl UpperHex for bf16 { fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>747 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { 748 write!(f, "{:X}", self.0) 749 } 750 } 751 752 impl Neg for bf16 { 753 type Output = Self; 754 neg(self) -> Self::Output755 fn neg(self) -> Self::Output { 756 Self(self.0 ^ 0x8000) 757 } 758 } 759 760 impl Add for bf16 { 761 type Output = Self; 762 add(self, rhs: Self) -> Self::Output763 fn add(self, rhs: Self) -> Self::Output { 764 Self::from_f32(Self::to_f32(self) + Self::to_f32(rhs)) 765 } 766 } 767 768 impl Add<&bf16> for bf16 { 769 type Output = <bf16 as Add<bf16>>::Output; 770 771 #[inline] add(self, rhs: &bf16) -> Self::Output772 fn add(self, rhs: &bf16) -> Self::Output { 773 self.add(*rhs) 774 } 775 } 776 777 impl Add<&bf16> for &bf16 { 778 type Output = <bf16 as Add<bf16>>::Output; 779 780 #[inline] add(self, rhs: &bf16) -> Self::Output781 fn add(self, rhs: &bf16) -> Self::Output { 782 (*self).add(*rhs) 783 } 784 } 785 786 impl Add<bf16> for &bf16 { 787 type Output = <bf16 as Add<bf16>>::Output; 788 789 #[inline] add(self, rhs: bf16) -> Self::Output790 fn add(self, rhs: bf16) -> 
Self::Output { 791 (*self).add(rhs) 792 } 793 } 794 795 impl AddAssign for bf16 { 796 #[inline] add_assign(&mut self, rhs: Self)797 fn add_assign(&mut self, rhs: Self) { 798 *self = (*self).add(rhs); 799 } 800 } 801 802 impl AddAssign<&bf16> for bf16 { 803 #[inline] add_assign(&mut self, rhs: &bf16)804 fn add_assign(&mut self, rhs: &bf16) { 805 *self = (*self).add(rhs); 806 } 807 } 808 809 impl Sub for bf16 { 810 type Output = Self; 811 sub(self, rhs: Self) -> Self::Output812 fn sub(self, rhs: Self) -> Self::Output { 813 Self::from_f32(Self::to_f32(self) - Self::to_f32(rhs)) 814 } 815 } 816 817 impl Sub<&bf16> for bf16 { 818 type Output = <bf16 as Sub<bf16>>::Output; 819 820 #[inline] sub(self, rhs: &bf16) -> Self::Output821 fn sub(self, rhs: &bf16) -> Self::Output { 822 self.sub(*rhs) 823 } 824 } 825 826 impl Sub<&bf16> for &bf16 { 827 type Output = <bf16 as Sub<bf16>>::Output; 828 829 #[inline] sub(self, rhs: &bf16) -> Self::Output830 fn sub(self, rhs: &bf16) -> Self::Output { 831 (*self).sub(*rhs) 832 } 833 } 834 835 impl Sub<bf16> for &bf16 { 836 type Output = <bf16 as Sub<bf16>>::Output; 837 838 #[inline] sub(self, rhs: bf16) -> Self::Output839 fn sub(self, rhs: bf16) -> Self::Output { 840 (*self).sub(rhs) 841 } 842 } 843 844 impl SubAssign for bf16 { 845 #[inline] sub_assign(&mut self, rhs: Self)846 fn sub_assign(&mut self, rhs: Self) { 847 *self = (*self).sub(rhs); 848 } 849 } 850 851 impl SubAssign<&bf16> for bf16 { 852 #[inline] sub_assign(&mut self, rhs: &bf16)853 fn sub_assign(&mut self, rhs: &bf16) { 854 *self = (*self).sub(rhs); 855 } 856 } 857 858 impl Mul for bf16 { 859 type Output = Self; 860 mul(self, rhs: Self) -> Self::Output861 fn mul(self, rhs: Self) -> Self::Output { 862 Self::from_f32(Self::to_f32(self) * Self::to_f32(rhs)) 863 } 864 } 865 866 impl Mul<&bf16> for bf16 { 867 type Output = <bf16 as Mul<bf16>>::Output; 868 869 #[inline] mul(self, rhs: &bf16) -> Self::Output870 fn mul(self, rhs: &bf16) -> Self::Output { 871 self.mul(*rhs) 872 } 
}

// Reference and compound-assignment forms all delegate, via operator syntax, to the by-value
// `Mul`/`Div` impls, which compute in f32 and round back to bf16.
impl Mul<&bf16> for &bf16 {
    type Output = <bf16 as Mul<bf16>>::Output;

    #[inline]
    fn mul(self, rhs: &bf16) -> Self::Output {
        *self * *rhs
    }
}

impl Mul<bf16> for &bf16 {
    type Output = <bf16 as Mul<bf16>>::Output;

    #[inline]
    fn mul(self, rhs: bf16) -> Self::Output {
        *self * rhs
    }
}

impl MulAssign for bf16 {
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}

impl MulAssign<&bf16> for bf16 {
    #[inline]
    fn mul_assign(&mut self, rhs: &bf16) {
        *self = *self * *rhs;
    }
}

impl Div for bf16 {
    type Output = Self;

    fn div(self, rhs: Self) -> Self::Output {
        // Divide at f32 precision, then round the quotient back to bf16.
        let quotient = self.to_f32() / rhs.to_f32();
        bf16::from_f32(quotient)
    }
}

impl Div<&bf16> for bf16 {
    type Output = <bf16 as Div<bf16>>::Output;

    #[inline]
    fn div(self, rhs: &bf16) -> Self::Output {
        self / *rhs
    }
}

impl Div<&bf16> for &bf16 {
    type Output = <bf16 as Div<bf16>>::Output;

    #[inline]
    fn div(self, rhs: &bf16) -> Self::Output {
        *self / *rhs
    }
}

impl Div<bf16> for &bf16 {
    type Output = <bf16 as Div<bf16>>::Output;

    #[inline]
    fn div(self, rhs: bf16) -> Self::Output {
        *self / rhs
    }
}

impl DivAssign for bf16 {
    #[inline]
    fn div_assign(&mut self, rhs: Self) {
        *self = *self / rhs;
    }
}

impl DivAssign<&bf16> for bf16 {
    #[inline]
    fn div_assign(&mut self, rhs: &bf16) {
        *self = *self / *rhs;
    }
}

impl Rem for bf16 {
type Output = Self;

    // Computed in f32 and rounded back to bf16.
    fn rem(self, rhs: Self) -> Self::Output {
        Self::from_f32(Self::to_f32(self) % Self::to_f32(rhs))
    }
}

impl Rem<&bf16> for bf16 {
    type Output = <bf16 as Rem<bf16>>::Output;

    #[inline]
    fn rem(self, rhs: &bf16) -> Self::Output {
        self.rem(*rhs)
    }
}

impl Rem<&bf16> for &bf16 {
    type Output = <bf16 as Rem<bf16>>::Output;

    #[inline]
    fn rem(self, rhs: &bf16) -> Self::Output {
        (*self).rem(*rhs)
    }
}

impl Rem<bf16> for &bf16 {
    type Output = <bf16 as Rem<bf16>>::Output;

    #[inline]
    fn rem(self, rhs: bf16) -> Self::Output {
        (*self).rem(rhs)
    }
}

impl RemAssign for bf16 {
    #[inline]
    fn rem_assign(&mut self, rhs: Self) {
        *self = (*self).rem(rhs);
    }
}

impl RemAssign<&bf16> for bf16 {
    #[inline]
    fn rem_assign(&mut self, rhs: &bf16) {
        *self = (*self).rem(rhs);
    }
}

/// Multiplies [`bf16`] values, accumulating in [`f32`] and rounding to [`bf16`] once at the end.
impl Product for bf16 {
    #[inline]
    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
        bf16::from_f32(iter.map(|f| f.to_f32()).product())
    }
}

/// Multiplies borrowed [`bf16`] values, accumulating in [`f32`] and rounding to [`bf16`] once at
/// the end.
impl<'a> Product<&'a bf16> for bf16 {
    #[inline]
    fn product<I: Iterator<Item = &'a bf16>>(iter: I) -> Self {
        bf16::from_f32(iter.map(|f| f.to_f32()).product())
    }
}

/// Sums [`bf16`] values, accumulating in [`f32`] and rounding to [`bf16`] once at the end.
impl Sum for bf16 {
    #[inline]
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        bf16::from_f32(iter.map(|f| f.to_f32()).sum())
    }
}

/// Sums borrowed [`bf16`] values, accumulating in [`f32`] and rounding to [`bf16`] once at the
/// end.
///
/// # Examples
///
/// ```rust
/// # use half::prelude::*;
/// let values = [bf16::from_f32(1.5), bf16::from_f32(2.0), bf16::from_f32(4.0)];
/// let total: bf16 = values.iter().sum();
/// assert_eq!(total, bf16::from_f32(7.5));
/// ```
impl<'a> Sum<&'a bf16> for bf16 {
    #[inline]
    fn sum<I: Iterator<Item = &'a bf16>>(iter: I) -> Self {
        // Must be `sum()`: this impl backs `iter().sum()` over `&bf16` and previously
        // multiplied the values instead.
        bf16::from_f32(iter.map(|f| f.to_f32()).sum())
    }
}

#[allow(
    clippy::cognitive_complexity,
    clippy::float_cmp,
    clippy::neg_cmp_op_on_partial_ord
)]
#[cfg(test)]
mod test {
    use super::*;
    use core::cmp::Ordering;
    #[cfg(feature = "num-traits")]
    use num_traits::{AsPrimitive, FromPrimitive, ToPrimitive};
    use quickcheck_macros::quickcheck;

    #[cfg(feature = "num-traits")]
    #[test]
    fn as_primitive() {
        let two = bf16::from_f32(2.0);
        assert_eq!(<i32 as AsPrimitive<bf16>>::as_(2), two);
        assert_eq!(<bf16 as AsPrimitive<i32>>::as_(two), 2);

        assert_eq!(<f32 as AsPrimitive<bf16>>::as_(2.0), two);
        assert_eq!(<bf16 as AsPrimitive<f32>>::as_(two), 2.0);

        assert_eq!(<f64 as AsPrimitive<bf16>>::as_(2.0), two);
        assert_eq!(<bf16 as AsPrimitive<f64>>::as_(two), 2.0);
    }

    #[cfg(feature = "num-traits")]
    #[test]
    fn to_primitive() {
        let two = bf16::from_f32(2.0);
        assert_eq!(ToPrimitive::to_i32(&two).unwrap(), 2i32);
        assert_eq!(ToPrimitive::to_f32(&two).unwrap(), 2.0f32);
        assert_eq!(ToPrimitive::to_f64(&two).unwrap(), 2.0f64);
    }

    #[cfg(feature = "num-traits")]
    #[test]
    fn from_primitive() {
        let two = bf16::from_f32(2.0);
        assert_eq!(<bf16 as FromPrimitive>::from_i32(2).unwrap(), two);
        assert_eq!(<bf16 as FromPrimitive>::from_f32(2.0).unwrap(), two);
        assert_eq!(<bf16 as FromPrimitive>::from_f64(2.0).unwrap(), two);
    }

    #[test]
    fn test_bf16_consts_from_f32() {
        let one = bf16::from_f32(1.0);
        let zero = bf16::from_f32(0.0);
        let neg_zero = bf16::from_f32(-0.0);
        let neg_one = bf16::from_f32(-1.0);
        let inf = bf16::from_f32(core::f32::INFINITY);
        let neg_inf = bf16::from_f32(core::f32::NEG_INFINITY);
        let nan =
bf16::from_f32(core::f32::NAN); 1087 1088 assert_eq!(bf16::ONE, one); 1089 assert_eq!(bf16::ZERO, zero); 1090 assert!(zero.is_sign_positive()); 1091 assert_eq!(bf16::NEG_ZERO, neg_zero); 1092 assert!(neg_zero.is_sign_negative()); 1093 assert_eq!(bf16::NEG_ONE, neg_one); 1094 assert!(neg_one.is_sign_negative()); 1095 assert_eq!(bf16::INFINITY, inf); 1096 assert_eq!(bf16::NEG_INFINITY, neg_inf); 1097 assert!(nan.is_nan()); 1098 assert!(bf16::NAN.is_nan()); 1099 1100 let e = bf16::from_f32(core::f32::consts::E); 1101 let pi = bf16::from_f32(core::f32::consts::PI); 1102 let frac_1_pi = bf16::from_f32(core::f32::consts::FRAC_1_PI); 1103 let frac_1_sqrt_2 = bf16::from_f32(core::f32::consts::FRAC_1_SQRT_2); 1104 let frac_2_pi = bf16::from_f32(core::f32::consts::FRAC_2_PI); 1105 let frac_2_sqrt_pi = bf16::from_f32(core::f32::consts::FRAC_2_SQRT_PI); 1106 let frac_pi_2 = bf16::from_f32(core::f32::consts::FRAC_PI_2); 1107 let frac_pi_3 = bf16::from_f32(core::f32::consts::FRAC_PI_3); 1108 let frac_pi_4 = bf16::from_f32(core::f32::consts::FRAC_PI_4); 1109 let frac_pi_6 = bf16::from_f32(core::f32::consts::FRAC_PI_6); 1110 let frac_pi_8 = bf16::from_f32(core::f32::consts::FRAC_PI_8); 1111 let ln_10 = bf16::from_f32(core::f32::consts::LN_10); 1112 let ln_2 = bf16::from_f32(core::f32::consts::LN_2); 1113 let log10_e = bf16::from_f32(core::f32::consts::LOG10_E); 1114 // core::f32::consts::LOG10_2 requires rustc 1.43.0 1115 let log10_2 = bf16::from_f32(2f32.log10()); 1116 let log2_e = bf16::from_f32(core::f32::consts::LOG2_E); 1117 // core::f32::consts::LOG2_10 requires rustc 1.43.0 1118 let log2_10 = bf16::from_f32(10f32.log2()); 1119 let sqrt_2 = bf16::from_f32(core::f32::consts::SQRT_2); 1120 1121 assert_eq!(bf16::E, e); 1122 assert_eq!(bf16::PI, pi); 1123 assert_eq!(bf16::FRAC_1_PI, frac_1_pi); 1124 assert_eq!(bf16::FRAC_1_SQRT_2, frac_1_sqrt_2); 1125 assert_eq!(bf16::FRAC_2_PI, frac_2_pi); 1126 assert_eq!(bf16::FRAC_2_SQRT_PI, frac_2_sqrt_pi); 1127 assert_eq!(bf16::FRAC_PI_2, 
frac_pi_2); 1128 assert_eq!(bf16::FRAC_PI_3, frac_pi_3); 1129 assert_eq!(bf16::FRAC_PI_4, frac_pi_4); 1130 assert_eq!(bf16::FRAC_PI_6, frac_pi_6); 1131 assert_eq!(bf16::FRAC_PI_8, frac_pi_8); 1132 assert_eq!(bf16::LN_10, ln_10); 1133 assert_eq!(bf16::LN_2, ln_2); 1134 assert_eq!(bf16::LOG10_E, log10_e); 1135 assert_eq!(bf16::LOG10_2, log10_2); 1136 assert_eq!(bf16::LOG2_E, log2_e); 1137 assert_eq!(bf16::LOG2_10, log2_10); 1138 assert_eq!(bf16::SQRT_2, sqrt_2); 1139 } 1140 1141 #[test] test_bf16_consts_from_f64()1142 fn test_bf16_consts_from_f64() { 1143 let one = bf16::from_f64(1.0); 1144 let zero = bf16::from_f64(0.0); 1145 let neg_zero = bf16::from_f64(-0.0); 1146 let inf = bf16::from_f64(core::f64::INFINITY); 1147 let neg_inf = bf16::from_f64(core::f64::NEG_INFINITY); 1148 let nan = bf16::from_f64(core::f64::NAN); 1149 1150 assert_eq!(bf16::ONE, one); 1151 assert_eq!(bf16::ZERO, zero); 1152 assert_eq!(bf16::NEG_ZERO, neg_zero); 1153 assert_eq!(bf16::INFINITY, inf); 1154 assert_eq!(bf16::NEG_INFINITY, neg_inf); 1155 assert!(nan.is_nan()); 1156 assert!(bf16::NAN.is_nan()); 1157 1158 let e = bf16::from_f64(core::f64::consts::E); 1159 let pi = bf16::from_f64(core::f64::consts::PI); 1160 let frac_1_pi = bf16::from_f64(core::f64::consts::FRAC_1_PI); 1161 let frac_1_sqrt_2 = bf16::from_f64(core::f64::consts::FRAC_1_SQRT_2); 1162 let frac_2_pi = bf16::from_f64(core::f64::consts::FRAC_2_PI); 1163 let frac_2_sqrt_pi = bf16::from_f64(core::f64::consts::FRAC_2_SQRT_PI); 1164 let frac_pi_2 = bf16::from_f64(core::f64::consts::FRAC_PI_2); 1165 let frac_pi_3 = bf16::from_f64(core::f64::consts::FRAC_PI_3); 1166 let frac_pi_4 = bf16::from_f64(core::f64::consts::FRAC_PI_4); 1167 let frac_pi_6 = bf16::from_f64(core::f64::consts::FRAC_PI_6); 1168 let frac_pi_8 = bf16::from_f64(core::f64::consts::FRAC_PI_8); 1169 let ln_10 = bf16::from_f64(core::f64::consts::LN_10); 1170 let ln_2 = bf16::from_f64(core::f64::consts::LN_2); 1171 let log10_e = bf16::from_f64(core::f64::consts::LOG10_E); 
1172 // core::f64::consts::LOG10_2 requires rustc 1.43.0 1173 let log10_2 = bf16::from_f64(2f64.log10()); 1174 let log2_e = bf16::from_f64(core::f64::consts::LOG2_E); 1175 // core::f64::consts::LOG2_10 requires rustc 1.43.0 1176 let log2_10 = bf16::from_f64(10f64.log2()); 1177 let sqrt_2 = bf16::from_f64(core::f64::consts::SQRT_2); 1178 1179 assert_eq!(bf16::E, e); 1180 assert_eq!(bf16::PI, pi); 1181 assert_eq!(bf16::FRAC_1_PI, frac_1_pi); 1182 assert_eq!(bf16::FRAC_1_SQRT_2, frac_1_sqrt_2); 1183 assert_eq!(bf16::FRAC_2_PI, frac_2_pi); 1184 assert_eq!(bf16::FRAC_2_SQRT_PI, frac_2_sqrt_pi); 1185 assert_eq!(bf16::FRAC_PI_2, frac_pi_2); 1186 assert_eq!(bf16::FRAC_PI_3, frac_pi_3); 1187 assert_eq!(bf16::FRAC_PI_4, frac_pi_4); 1188 assert_eq!(bf16::FRAC_PI_6, frac_pi_6); 1189 assert_eq!(bf16::FRAC_PI_8, frac_pi_8); 1190 assert_eq!(bf16::LN_10, ln_10); 1191 assert_eq!(bf16::LN_2, ln_2); 1192 assert_eq!(bf16::LOG10_E, log10_e); 1193 assert_eq!(bf16::LOG10_2, log10_2); 1194 assert_eq!(bf16::LOG2_E, log2_e); 1195 assert_eq!(bf16::LOG2_10, log2_10); 1196 assert_eq!(bf16::SQRT_2, sqrt_2); 1197 } 1198 1199 #[test] test_nan_conversion_to_smaller()1200 fn test_nan_conversion_to_smaller() { 1201 let nan64 = f64::from_bits(0x7FF0_0000_0000_0001u64); 1202 let neg_nan64 = f64::from_bits(0xFFF0_0000_0000_0001u64); 1203 let nan32 = f32::from_bits(0x7F80_0001u32); 1204 let neg_nan32 = f32::from_bits(0xFF80_0001u32); 1205 let nan32_from_64 = nan64 as f32; 1206 let neg_nan32_from_64 = neg_nan64 as f32; 1207 let nan16_from_64 = bf16::from_f64(nan64); 1208 let neg_nan16_from_64 = bf16::from_f64(neg_nan64); 1209 let nan16_from_32 = bf16::from_f32(nan32); 1210 let neg_nan16_from_32 = bf16::from_f32(neg_nan32); 1211 1212 assert!(nan64.is_nan() && nan64.is_sign_positive()); 1213 assert!(neg_nan64.is_nan() && neg_nan64.is_sign_negative()); 1214 assert!(nan32.is_nan() && nan32.is_sign_positive()); 1215 assert!(neg_nan32.is_nan() && neg_nan32.is_sign_negative()); 1216 
assert!(nan32_from_64.is_nan() && nan32_from_64.is_sign_positive()); 1217 assert!(neg_nan32_from_64.is_nan() && neg_nan32_from_64.is_sign_negative()); 1218 assert!(nan16_from_64.is_nan() && nan16_from_64.is_sign_positive()); 1219 assert!(neg_nan16_from_64.is_nan() && neg_nan16_from_64.is_sign_negative()); 1220 assert!(nan16_from_32.is_nan() && nan16_from_32.is_sign_positive()); 1221 assert!(neg_nan16_from_32.is_nan() && neg_nan16_from_32.is_sign_negative()); 1222 } 1223 1224 #[test] test_nan_conversion_to_larger()1225 fn test_nan_conversion_to_larger() { 1226 let nan16 = bf16::from_bits(0x7F81u16); 1227 let neg_nan16 = bf16::from_bits(0xFF81u16); 1228 let nan32 = f32::from_bits(0x7F80_0001u32); 1229 let neg_nan32 = f32::from_bits(0xFF80_0001u32); 1230 let nan32_from_16 = f32::from(nan16); 1231 let neg_nan32_from_16 = f32::from(neg_nan16); 1232 let nan64_from_16 = f64::from(nan16); 1233 let neg_nan64_from_16 = f64::from(neg_nan16); 1234 let nan64_from_32 = f64::from(nan32); 1235 let neg_nan64_from_32 = f64::from(neg_nan32); 1236 1237 assert!(nan16.is_nan() && nan16.is_sign_positive()); 1238 assert!(neg_nan16.is_nan() && neg_nan16.is_sign_negative()); 1239 assert!(nan32.is_nan() && nan32.is_sign_positive()); 1240 assert!(neg_nan32.is_nan() && neg_nan32.is_sign_negative()); 1241 assert!(nan32_from_16.is_nan() && nan32_from_16.is_sign_positive()); 1242 assert!(neg_nan32_from_16.is_nan() && neg_nan32_from_16.is_sign_negative()); 1243 assert!(nan64_from_16.is_nan() && nan64_from_16.is_sign_positive()); 1244 assert!(neg_nan64_from_16.is_nan() && neg_nan64_from_16.is_sign_negative()); 1245 assert!(nan64_from_32.is_nan() && nan64_from_32.is_sign_positive()); 1246 assert!(neg_nan64_from_32.is_nan() && neg_nan64_from_32.is_sign_negative()); 1247 } 1248 1249 #[test] test_bf16_to_f32()1250 fn test_bf16_to_f32() { 1251 let f = bf16::from_f32(7.0); 1252 assert_eq!(f.to_f32(), 7.0f32); 1253 1254 // 7.1 is NOT exactly representable in 16-bit, it's rounded 1255 let f = 
bf16::from_f32(7.1); 1256 let diff = (f.to_f32() - 7.1f32).abs(); 1257 // diff must be <= 4 * EPSILON, as 7 has two more significant bits than 1 1258 assert!(diff <= 4.0 * bf16::EPSILON.to_f32()); 1259 1260 let tiny32 = f32::from_bits(0x0001_0000u32); 1261 assert_eq!(bf16::from_bits(0x0001).to_f32(), tiny32); 1262 assert_eq!(bf16::from_bits(0x0005).to_f32(), 5.0 * tiny32); 1263 1264 assert_eq!(bf16::from_bits(0x0001), bf16::from_f32(tiny32)); 1265 assert_eq!(bf16::from_bits(0x0005), bf16::from_f32(5.0 * tiny32)); 1266 } 1267 1268 #[test] test_bf16_to_f64()1269 fn test_bf16_to_f64() { 1270 let f = bf16::from_f64(7.0); 1271 assert_eq!(f.to_f64(), 7.0f64); 1272 1273 // 7.1 is NOT exactly representable in 16-bit, it's rounded 1274 let f = bf16::from_f64(7.1); 1275 let diff = (f.to_f64() - 7.1f64).abs(); 1276 // diff must be <= 4 * EPSILON, as 7 has two more significant bits than 1 1277 assert!(diff <= 4.0 * bf16::EPSILON.to_f64()); 1278 1279 let tiny64 = 2.0f64.powi(-133); 1280 assert_eq!(bf16::from_bits(0x0001).to_f64(), tiny64); 1281 assert_eq!(bf16::from_bits(0x0005).to_f64(), 5.0 * tiny64); 1282 1283 assert_eq!(bf16::from_bits(0x0001), bf16::from_f64(tiny64)); 1284 assert_eq!(bf16::from_bits(0x0005), bf16::from_f64(5.0 * tiny64)); 1285 } 1286 1287 #[test] test_comparisons()1288 fn test_comparisons() { 1289 let zero = bf16::from_f64(0.0); 1290 let one = bf16::from_f64(1.0); 1291 let neg_zero = bf16::from_f64(-0.0); 1292 let neg_one = bf16::from_f64(-1.0); 1293 1294 assert_eq!(zero.partial_cmp(&neg_zero), Some(Ordering::Equal)); 1295 assert_eq!(neg_zero.partial_cmp(&zero), Some(Ordering::Equal)); 1296 assert!(zero == neg_zero); 1297 assert!(neg_zero == zero); 1298 assert!(!(zero != neg_zero)); 1299 assert!(!(neg_zero != zero)); 1300 assert!(!(zero < neg_zero)); 1301 assert!(!(neg_zero < zero)); 1302 assert!(zero <= neg_zero); 1303 assert!(neg_zero <= zero); 1304 assert!(!(zero > neg_zero)); 1305 assert!(!(neg_zero > zero)); 1306 assert!(zero >= neg_zero); 1307 
assert!(neg_zero >= zero); 1308 1309 assert_eq!(one.partial_cmp(&neg_zero), Some(Ordering::Greater)); 1310 assert_eq!(neg_zero.partial_cmp(&one), Some(Ordering::Less)); 1311 assert!(!(one == neg_zero)); 1312 assert!(!(neg_zero == one)); 1313 assert!(one != neg_zero); 1314 assert!(neg_zero != one); 1315 assert!(!(one < neg_zero)); 1316 assert!(neg_zero < one); 1317 assert!(!(one <= neg_zero)); 1318 assert!(neg_zero <= one); 1319 assert!(one > neg_zero); 1320 assert!(!(neg_zero > one)); 1321 assert!(one >= neg_zero); 1322 assert!(!(neg_zero >= one)); 1323 1324 assert_eq!(one.partial_cmp(&neg_one), Some(Ordering::Greater)); 1325 assert_eq!(neg_one.partial_cmp(&one), Some(Ordering::Less)); 1326 assert!(!(one == neg_one)); 1327 assert!(!(neg_one == one)); 1328 assert!(one != neg_one); 1329 assert!(neg_one != one); 1330 assert!(!(one < neg_one)); 1331 assert!(neg_one < one); 1332 assert!(!(one <= neg_one)); 1333 assert!(neg_one <= one); 1334 assert!(one > neg_one); 1335 assert!(!(neg_one > one)); 1336 assert!(one >= neg_one); 1337 assert!(!(neg_one >= one)); 1338 } 1339 1340 #[test] 1341 #[allow(clippy::erasing_op, clippy::identity_op)] round_to_even_f32()1342 fn round_to_even_f32() { 1343 // smallest positive subnormal = 0b0.0000_001 * 2^-126 = 2^-133 1344 let min_sub = bf16::from_bits(1); 1345 let min_sub_f = (-133f32).exp2(); 1346 assert_eq!(bf16::from_f32(min_sub_f).to_bits(), min_sub.to_bits()); 1347 assert_eq!(f32::from(min_sub).to_bits(), min_sub_f.to_bits()); 1348 1349 // 0.0000000_011111 rounded to 0.0000000 (< tie, no rounding) 1350 // 0.0000000_100000 rounded to 0.0000000 (tie and even, remains at even) 1351 // 0.0000000_100001 rounded to 0.0000001 (> tie, rounds up) 1352 assert_eq!( 1353 bf16::from_f32(min_sub_f * 0.49).to_bits(), 1354 min_sub.to_bits() * 0 1355 ); 1356 assert_eq!( 1357 bf16::from_f32(min_sub_f * 0.50).to_bits(), 1358 min_sub.to_bits() * 0 1359 ); 1360 assert_eq!( 1361 bf16::from_f32(min_sub_f * 0.51).to_bits(), 1362 min_sub.to_bits() * 1 
1363 ); 1364 1365 // 0.0000001_011111 rounded to 0.0000001 (< tie, no rounding) 1366 // 0.0000001_100000 rounded to 0.0000010 (tie and odd, rounds up to even) 1367 // 0.0000001_100001 rounded to 0.0000010 (> tie, rounds up) 1368 assert_eq!( 1369 bf16::from_f32(min_sub_f * 1.49).to_bits(), 1370 min_sub.to_bits() * 1 1371 ); 1372 assert_eq!( 1373 bf16::from_f32(min_sub_f * 1.50).to_bits(), 1374 min_sub.to_bits() * 2 1375 ); 1376 assert_eq!( 1377 bf16::from_f32(min_sub_f * 1.51).to_bits(), 1378 min_sub.to_bits() * 2 1379 ); 1380 1381 // 0.0000010_011111 rounded to 0.0000010 (< tie, no rounding) 1382 // 0.0000010_100000 rounded to 0.0000010 (tie and even, remains at even) 1383 // 0.0000010_100001 rounded to 0.0000011 (> tie, rounds up) 1384 assert_eq!( 1385 bf16::from_f32(min_sub_f * 2.49).to_bits(), 1386 min_sub.to_bits() * 2 1387 ); 1388 assert_eq!( 1389 bf16::from_f32(min_sub_f * 2.50).to_bits(), 1390 min_sub.to_bits() * 2 1391 ); 1392 assert_eq!( 1393 bf16::from_f32(min_sub_f * 2.51).to_bits(), 1394 min_sub.to_bits() * 3 1395 ); 1396 1397 assert_eq!( 1398 bf16::from_f32(250.49f32).to_bits(), 1399 bf16::from_f32(250.0).to_bits() 1400 ); 1401 assert_eq!( 1402 bf16::from_f32(250.50f32).to_bits(), 1403 bf16::from_f32(250.0).to_bits() 1404 ); 1405 assert_eq!( 1406 bf16::from_f32(250.51f32).to_bits(), 1407 bf16::from_f32(251.0).to_bits() 1408 ); 1409 assert_eq!( 1410 bf16::from_f32(251.49f32).to_bits(), 1411 bf16::from_f32(251.0).to_bits() 1412 ); 1413 assert_eq!( 1414 bf16::from_f32(251.50f32).to_bits(), 1415 bf16::from_f32(252.0).to_bits() 1416 ); 1417 assert_eq!( 1418 bf16::from_f32(251.51f32).to_bits(), 1419 bf16::from_f32(252.0).to_bits() 1420 ); 1421 assert_eq!( 1422 bf16::from_f32(252.49f32).to_bits(), 1423 bf16::from_f32(252.0).to_bits() 1424 ); 1425 assert_eq!( 1426 bf16::from_f32(252.50f32).to_bits(), 1427 bf16::from_f32(252.0).to_bits() 1428 ); 1429 assert_eq!( 1430 bf16::from_f32(252.51f32).to_bits(), 1431 bf16::from_f32(253.0).to_bits() 1432 ); 1433 } 1434 
1435 #[test] 1436 #[allow(clippy::erasing_op, clippy::identity_op)] round_to_even_f64()1437 fn round_to_even_f64() { 1438 // smallest positive subnormal = 0b0.0000_001 * 2^-126 = 2^-133 1439 let min_sub = bf16::from_bits(1); 1440 let min_sub_f = (-133f64).exp2(); 1441 assert_eq!(bf16::from_f64(min_sub_f).to_bits(), min_sub.to_bits()); 1442 assert_eq!(f64::from(min_sub).to_bits(), min_sub_f.to_bits()); 1443 1444 // 0.0000000_011111 rounded to 0.0000000 (< tie, no rounding) 1445 // 0.0000000_100000 rounded to 0.0000000 (tie and even, remains at even) 1446 // 0.0000000_100001 rounded to 0.0000001 (> tie, rounds up) 1447 assert_eq!( 1448 bf16::from_f64(min_sub_f * 0.49).to_bits(), 1449 min_sub.to_bits() * 0 1450 ); 1451 assert_eq!( 1452 bf16::from_f64(min_sub_f * 0.50).to_bits(), 1453 min_sub.to_bits() * 0 1454 ); 1455 assert_eq!( 1456 bf16::from_f64(min_sub_f * 0.51).to_bits(), 1457 min_sub.to_bits() * 1 1458 ); 1459 1460 // 0.0000001_011111 rounded to 0.0000001 (< tie, no rounding) 1461 // 0.0000001_100000 rounded to 0.0000010 (tie and odd, rounds up to even) 1462 // 0.0000001_100001 rounded to 0.0000010 (> tie, rounds up) 1463 assert_eq!( 1464 bf16::from_f64(min_sub_f * 1.49).to_bits(), 1465 min_sub.to_bits() * 1 1466 ); 1467 assert_eq!( 1468 bf16::from_f64(min_sub_f * 1.50).to_bits(), 1469 min_sub.to_bits() * 2 1470 ); 1471 assert_eq!( 1472 bf16::from_f64(min_sub_f * 1.51).to_bits(), 1473 min_sub.to_bits() * 2 1474 ); 1475 1476 // 0.0000010_011111 rounded to 0.0000010 (< tie, no rounding) 1477 // 0.0000010_100000 rounded to 0.0000010 (tie and even, remains at even) 1478 // 0.0000010_100001 rounded to 0.0000011 (> tie, rounds up) 1479 assert_eq!( 1480 bf16::from_f64(min_sub_f * 2.49).to_bits(), 1481 min_sub.to_bits() * 2 1482 ); 1483 assert_eq!( 1484 bf16::from_f64(min_sub_f * 2.50).to_bits(), 1485 min_sub.to_bits() * 2 1486 ); 1487 assert_eq!( 1488 bf16::from_f64(min_sub_f * 2.51).to_bits(), 1489 min_sub.to_bits() * 3 1490 ); 1491 1492 assert_eq!( 1493 
bf16::from_f64(250.49f64).to_bits(), 1494 bf16::from_f64(250.0).to_bits() 1495 ); 1496 assert_eq!( 1497 bf16::from_f64(250.50f64).to_bits(), 1498 bf16::from_f64(250.0).to_bits() 1499 ); 1500 assert_eq!( 1501 bf16::from_f64(250.51f64).to_bits(), 1502 bf16::from_f64(251.0).to_bits() 1503 ); 1504 assert_eq!( 1505 bf16::from_f64(251.49f64).to_bits(), 1506 bf16::from_f64(251.0).to_bits() 1507 ); 1508 assert_eq!( 1509 bf16::from_f64(251.50f64).to_bits(), 1510 bf16::from_f64(252.0).to_bits() 1511 ); 1512 assert_eq!( 1513 bf16::from_f64(251.51f64).to_bits(), 1514 bf16::from_f64(252.0).to_bits() 1515 ); 1516 assert_eq!( 1517 bf16::from_f64(252.49f64).to_bits(), 1518 bf16::from_f64(252.0).to_bits() 1519 ); 1520 assert_eq!( 1521 bf16::from_f64(252.50f64).to_bits(), 1522 bf16::from_f64(252.0).to_bits() 1523 ); 1524 assert_eq!( 1525 bf16::from_f64(252.51f64).to_bits(), 1526 bf16::from_f64(253.0).to_bits() 1527 ); 1528 } 1529 1530 impl quickcheck::Arbitrary for bf16 { arbitrary(g: &mut quickcheck::Gen) -> Self1531 fn arbitrary(g: &mut quickcheck::Gen) -> Self { 1532 bf16(u16::arbitrary(g)) 1533 } 1534 } 1535 1536 #[quickcheck] qc_roundtrip_bf16_f32_is_identity(f: bf16) -> bool1537 fn qc_roundtrip_bf16_f32_is_identity(f: bf16) -> bool { 1538 let roundtrip = bf16::from_f32(f.to_f32()); 1539 if f.is_nan() { 1540 roundtrip.is_nan() && f.is_sign_negative() == roundtrip.is_sign_negative() 1541 } else { 1542 f.0 == roundtrip.0 1543 } 1544 } 1545 1546 #[quickcheck] qc_roundtrip_bf16_f64_is_identity(f: bf16) -> bool1547 fn qc_roundtrip_bf16_f64_is_identity(f: bf16) -> bool { 1548 let roundtrip = bf16::from_f64(f.to_f64()); 1549 if f.is_nan() { 1550 roundtrip.is_nan() && f.is_sign_negative() == roundtrip.is_sign_negative() 1551 } else { 1552 f.0 == roundtrip.0 1553 } 1554 } 1555 } 1556