// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Math helper functions

#[cfg(feature="simd_support")]
use packed_simd::*;
#[cfg(feature="std")]
use distributions::ziggurat_tables;
#[cfg(feature="std")]
use Rng;


pub trait WideningMultiply<RHS = Self> {
    type Output;

    fn wmul(self, x: RHS) -> Self::Output;
}

macro_rules! wmul_impl {
    ($ty:ty, $wide:ty, $shift:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                let tmp = (self as $wide) * (x as $wide);
                ((tmp >> $shift) as $ty, tmp as $ty)
            }
        }
    };

    // simd bulk implementation
    ($(($ty:ident, $wide:ident),)+, $shift:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // For supported vectors, this should compile to a couple
                    // of supported multiply & swizzle instructions (no actual
                    // casting).
                    // TODO: optimize
                    let y: $wide = self.cast();
                    let x: $wide = x.cast();
                    let tmp = y * x;
                    let hi: $ty = (tmp >> $shift).cast();
                    let lo: $ty = tmp.cast();
                    (hi, lo)
                }
            }
        )+
    };
}
wmul_impl! { u8, u16, 8 }
wmul_impl! { u16, u32, 16 }
wmul_impl! { u32, u64, 32 }
#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
wmul_impl! { u64, u128, 64 }
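
// A minimal sanity check, added here as an illustration (not part of the
// original source): `wmul` must return the high and low halves of the full
// double-width product. Expected values computed by hand.
#[cfg(test)]
mod wmul_scalar_tests {
    use super::WideningMultiply;

    #[test]
    fn wmul_u32_splits_the_product() {
        // (2^32 - 1)^2 = 2^64 - 2^33 + 1 = 0xFFFF_FFFE_0000_0001
        assert_eq!(u32::max_value().wmul(u32::max_value()),
                   (0xFFFF_FFFE, 0x0000_0001));
        // A product that fits in the low half leaves the high half zero.
        assert_eq!(3u32.wmul(4), (0, 12));
    }
}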

// This code is a translation of the __mulddi3 function in LLVM's
// compiler-rt. It is an optimised variant of the schoolbook method: split
// each operand into high and low halves, so that
// `(a*2^h + b) * (c*2^h + d) = a*c*2^(2h) + (a*d + b*c)*2^h + b*d`,
// then sum the partial products with carry propagation.
//
// For some reason LLVM can optimise the C version very well, but
// keeps shuffling registers in this Rust translation.
macro_rules! wmul_impl_large {
    ($ty:ty, $half:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, b: $ty) -> Self::Output {
                const LOWER_MASK: $ty = !0 >> $half;
                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
                let mut t = low >> $half;
                low &= LOWER_MASK;
                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                let mut high = t >> $half;
                t = low >> $half;
                low &= LOWER_MASK;
                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                high += t >> $half;
                high += (self >> $half).wrapping_mul(b >> $half);

                (high, low)
            }
        }
    };

    // simd bulk implementation
    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, b: $ty) -> Self::Output {
                    // needs wrapping multiplication
                    const LOWER_MASK: $scalar = !0 >> $half;
                    let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
                    let mut t = low >> $half;
                    low &= LOWER_MASK;
                    t += (self >> $half) * (b & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    let mut high = t >> $half;
                    t = low >> $half;
                    low &= LOWER_MASK;
                    t += (b >> $half) * (self & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    high += t >> $half;
                    high += (self >> $half) * (b >> $half);

                    (high, low)
                }
            }
        )+
    };
}
#[cfg(not(all(rustc_1_26, not(target_os = "emscripten"))))]
wmul_impl_large! { u64, 32 }
#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
wmul_impl_large! { u128, 64 }
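
// Another illustrative test (an addition, not from the original source):
// whichever `u64` implementation the `cfg` attributes above select, the
// result must match the mathematical 128-bit product.
#[cfg(test)]
mod wmul_large_tests {
    use super::WideningMultiply;

    #[test]
    fn wmul_u64_matches_the_128_bit_product() {
        // (2^64 - 1) * 2 = 2^65 - 2, i.e. high = 1, low = 2^64 - 2.
        assert_eq!((!0u64).wmul(2), (1, !0u64 - 1));
    }
}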

macro_rules! wmul_impl_usize {
    ($ty:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                let (high, low) = (self as $ty).wmul(x as $ty);
                (high as usize, low as usize)
            }
        }
    }
}
#[cfg(target_pointer_width = "32")]
wmul_impl_usize! { u32 }
#[cfg(target_pointer_width = "64")]
wmul_impl_usize! { u64 }

#[cfg(all(feature = "simd_support", feature = "nightly"))]
mod simd_wmul {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;
    use super::*;

    wmul_impl! {
        (u8x2, u16x2),
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),,
        8
    }

    wmul_impl! { (u16x2, u32x2),, 16 }
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x4, u32x4),, 16 }
    #[cfg(not(target_feature = "sse4.2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }

    // 16-bit lane widths allow use of the x86 `mulhi` instructions, which
    // means `wmul` can be implemented with only two instructions.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $intrinsic:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    let b = $intrinsic::from_bits(x);
                    let a = $intrinsic::from_bits(self);
                    let hi = $ty::from_bits(unsafe { $mulhi(a, b) });
                    let lo = $ty::from_bits(unsafe { $mullo(a, b) });
                    (hi, lo)
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x4, __m64, _mm_mulhi_pu16, _mm_mullo_pi16 }
    #[cfg(target_feature = "sse4.2")]
    wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    // FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::<u16x32>`
    // cannot use the same implementation.

    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),,
        32
    }

    // TODO: optimize, this seems to seriously slow things down
    wmul_impl_large! { (u8x64,) u8, 4 }
    wmul_impl_large! { (u16x32,) u16, 8 }
    wmul_impl_large! { (u32x16,) u32, 16 }
    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}
#[cfg(all(feature = "simd_support", feature = "nightly"))]
pub use self::simd_wmul::*;

/// Helper trait when dealing with scalar and SIMD floating point types.
pub(crate) trait FloatSIMDUtils {
    // `PartialOrd` for vectors compares lexicographically. We want to compare all
    // the individual SIMD lanes instead, and get the combined result over all
    // lanes. This is possible using something like `a.lt(b).all()`, but we
    // implement it as a trait so we can write the same code for `f32` and `f64`.
    // Only the comparison functions we need are implemented.
    fn all_lt(self, other: Self) -> bool;
    fn all_le(self, other: Self) -> bool;
    fn all_finite(self) -> bool;

    type Mask;
    fn finite_mask(self) -> Self::Mask;
    fn gt_mask(self, other: Self) -> Self::Mask;
    fn ge_mask(self, other: Self) -> Self::Mask;

    // Decrease all lanes where the mask is `true` to the next lower value
    // representable by the floating-point type. At least one of the lanes
    // must be set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    // Convert from int value. Conversion is done while retaining the numerical
    // value, not by retaining the binary representation.
    type UInt;
    fn cast_from_int(i: Self::UInt) -> Self;
}

/// Implement functions available in std builds but missing from core primitives
#[cfg(not(std))]
pub(crate) trait Float : Sized {
    type Bits;

    fn is_nan(self) -> bool;
    fn is_infinite(self) -> bool;
    fn is_finite(self) -> bool;
    fn to_bits(self) -> Self::Bits;
    fn from_bits(v: Self::Bits) -> Self;
}

/// Implement functions on f32/f64 to give them APIs similar to SIMD types
pub(crate) trait FloatAsSIMD : Sized {
    #[inline(always)]
    fn lanes() -> usize { 1 }
    #[inline(always)]
    fn splat(scalar: Self) -> Self { scalar }
    #[inline(always)]
    fn extract(self, index: usize) -> Self { debug_assert_eq!(index, 0); self }
    #[inline(always)]
    fn replace(self, index: usize, new_value: Self) -> Self { debug_assert_eq!(index, 0); new_value }
}

pub(crate) trait BoolAsSIMD : Sized {
    fn any(self) -> bool;
    fn all(self) -> bool;
    fn none(self) -> bool;
}

impl BoolAsSIMD for bool {
    #[inline(always)]
    fn any(self) -> bool { self }
    #[inline(always)]
    fn all(self) -> bool { self }
    #[inline(always)]
    fn none(self) -> bool { !self }
}

macro_rules! scalar_float_impl {
    ($ty:ident, $uty:ident) => {
        #[cfg(not(std))]
        impl Float for $ty {
            type Bits = $uty;

            #[inline]
            fn is_nan(self) -> bool {
                self != self
            }

            #[inline]
            fn is_infinite(self) -> bool {
                self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY
            }

            #[inline]
            fn is_finite(self) -> bool {
                !(self.is_nan() || self.is_infinite())
            }

            #[inline]
            fn to_bits(self) -> Self::Bits {
                unsafe { ::core::mem::transmute(self) }
            }

            #[inline]
            fn from_bits(v: Self::Bits) -> Self {
                // It turns out the safety issues with sNaN were overblown! Hooray!
                unsafe { ::core::mem::transmute(v) }
            }
        }

        impl FloatSIMDUtils for $ty {
            type Mask = bool;
            #[inline(always)]
            fn all_lt(self, other: Self) -> bool { self < other }
            #[inline(always)]
            fn all_le(self, other: Self) -> bool { self <= other }
            #[inline(always)]
            fn all_finite(self) -> bool { self.is_finite() }
            #[inline(always)]
            fn finite_mask(self) -> Self::Mask { self.is_finite() }
            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask { self > other }
            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask { self >= other }
            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask, "At least one lane must be set");
                <$ty>::from_bits(self.to_bits() - 1)
            }
            type UInt = $uty;
            fn cast_from_int(i: Self::UInt) -> Self { i as $ty }
        }

        impl FloatAsSIMD for $ty {}
    }
}

scalar_float_impl!(f32, u32);
scalar_float_impl!(f64, u64);
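
// Illustrative check (an addition, not from the original source): on
// scalars, `decrease_masked` steps to the next representable value below
// `self` by subtracting one from the bit pattern.
#[cfg(test)]
mod scalar_float_tests {
    use super::FloatSIMDUtils;

    #[test]
    fn decrease_masked_steps_down_one_ulp() {
        let x = 1.0f64;
        let y = x.decrease_masked(true);
        assert!(y < x);
        assert_eq!(y.to_bits(), x.to_bits() - 1);
    }
}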


#[cfg(feature="simd_support")]
macro_rules! simd_impl {
    ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = $mty;
            #[inline(always)]
            fn all_lt(self, other: Self) -> bool { self.lt(other).all() }
            #[inline(always)]
            fn all_le(self, other: Self) -> bool { self.le(other).all() }
            #[inline(always)]
            fn all_finite(self) -> bool { self.finite_mask().all() }
            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                // This can possibly be done faster by checking bit patterns
                let neg_inf = $ty::splat(::core::$f_scalar::NEG_INFINITY);
                let pos_inf = $ty::splat(::core::$f_scalar::INFINITY);
                self.gt(neg_inf) & self.lt(pos_inf)
            }
            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask { self.gt(other) }
            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask { self.ge(other) }
            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $ty. This works even when the
                // current value is infinity.
                debug_assert!(mask.any(), "At least one lane must be set");
                <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask))
            }
            type UInt = $uty;
            fn cast_from_int(i: Self::UInt) -> Self { i.cast() }
        }
    }
}

#[cfg(feature="simd_support")] simd_impl! { f32x2, f32, m32x2, u32x2 }
#[cfg(feature="simd_support")] simd_impl! { f32x4, f32, m32x4, u32x4 }
#[cfg(feature="simd_support")] simd_impl! { f32x8, f32, m32x8, u32x8 }
#[cfg(feature="simd_support")] simd_impl! { f32x16, f32, m32x16, u32x16 }
#[cfg(feature="simd_support")] simd_impl! { f64x2, f64, m64x2, u64x2 }
#[cfg(feature="simd_support")] simd_impl! { f64x4, f64, m64x4, u64x4 }
#[cfg(feature="simd_support")] simd_impl! { f64x8, f64, m64x8, u64x8 }

/// Calculates ln(gamma(x)) (natural logarithm of the gamma
/// function) using the Lanczos approximation.
///
/// The approximation expresses the gamma function as:
/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)`
/// `g` is an arbitrary constant; we use the approximation with `g=5`.
///
/// Noting that `gamma(z+1) = z*gamma(z)` and applying `ln` to both sides:
/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)`
///
/// `Ag(z)` is an infinite series with coefficients that can be calculated
/// ahead of time - we use just the first 6 terms, which is good enough
/// for most purposes.
#[cfg(feature="std")]
pub fn log_gamma(x: f64) -> f64 {
    // precalculated 6 coefficients for the first 6 terms of the series
    let coefficients: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    // (x+0.5)*ln(x+g+0.5)-(x+g+0.5)
    let tmp = x + 5.5;
    let log = (x + 0.5) * tmp.ln() - tmp;

    // the first few terms of the series for Ag(x)
    let mut a = 1.000000000190015;
    let mut denom = x;
    for coeff in &coefficients {
        denom += 1.0;
        a += coeff / denom;
    }

    // get everything together
    // a is Ag(x)
    // 2.5066... is sqrt(2*pi)
    log + (2.5066282746310005 * a / x).ln()
}
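
// A small numeric sanity check, added as an illustration (not part of the
// original source): for positive integers, gamma(n) = (n-1)!, so
// ln(gamma(5)) should be ln(24) and ln(gamma(1)) should be 0. The tolerance
// is loose relative to the quoted accuracy of the 6-term approximation.
#[cfg(all(test, feature = "std"))]
mod log_gamma_tests {
    use super::log_gamma;

    #[test]
    fn log_gamma_matches_small_factorials() {
        assert!((log_gamma(5.0) - 24f64.ln()).abs() < 1e-7);
        assert!(log_gamma(1.0).abs() < 1e-7);
    }
}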

/// Sample a random number using the Ziggurat method (specifically the
/// ZIGNOR variant from Doornik 2005). Most of the arguments are
/// directly from the paper:
///
/// * `rng`: source of randomness
/// * `symmetric`: whether this is a symmetric distribution, or one-sided with
///   P(x < 0) = 0.
/// * `x_tab`: the $x_i$ abscissae.
/// * `f_tab`: precomputed values of the PDF at the $x_i$ (i.e. $f(x_i)$).
/// * `pdf`: the probability density function
/// * `zero_case`: manual sampling from the tail when we choose the
///    bottom box (i.e. i == 0)

// the perf improvement (25-50%) is definitely worth the extra code
// size from force-inlining.
#[cfg(feature="std")]
#[inline(always)]
pub fn ziggurat<R: Rng + ?Sized, P, Z>(
            rng: &mut R,
            symmetric: bool,
            x_tab: ziggurat_tables::ZigTable,
            f_tab: ziggurat_tables::ZigTable,
            mut pdf: P,
            mut zero_case: Z)
            -> f64 where P: FnMut(f64) -> f64, Z: FnMut(&mut R, f64) -> f64 {
    use distributions::float::IntoFloat;
    loop {
        // As an optimisation we re-implement the conversion to a f64.
        // From the remaining 12 most significant bits we use 8 to construct `i`.
        // This saves us generating a whole extra random number, while the added
        // precision of using 64 bits for f64 does not buy us much.
        let bits = rng.next_u64();
        let i = bits as usize & 0xff;

        let u = if symmetric {
            // Convert to a value in the range [2,4) and subtract to get [-1,1)
            // We can't convert to an open range directly, that would require
            // subtracting `3.0 - EPSILON`, which is not representable.
            // It is possible with an extra step, but an open range does not
            // seem necessary for the ziggurat algorithm anyway.
            (bits >> 12).into_float_with_exponent(1) - 3.0
        } else {
            // Convert to a value in the range [1,2) and subtract to get (0,1)
            (bits >> 12).into_float_with_exponent(0)
            - (1.0 - ::core::f64::EPSILON / 2.0)
        };
        let x = u * x_tab[i];

        let test_x = if symmetric { x.abs() } else { x };

        // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i])
        if test_x < x_tab[i + 1] {
            return x;
        }
        if i == 0 {
            return zero_case(rng, u);
        }
        // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
        if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) {
            return x;
        }
    }
}
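
// Hedged usage sketch (an addition, not from the original source), assuming
// the `ZIG_NORM_X`/`ZIG_NORM_F` tables exported by `ziggurat_tables` and the
// crate-root `thread_rng` re-export: drawing from the standard normal
// distribution. The `zero_case` below is a placeholder returning 0.0; the
// real tail-sampling algorithm lives in `distributions::normal`, so this
// only checks that the produced samples are finite.
#[cfg(all(test, feature = "std"))]
mod ziggurat_tests {
    use super::ziggurat;
    use distributions::ziggurat_tables;
    use thread_rng;

    #[test]
    fn ziggurat_normal_samples_are_finite() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let x = ziggurat(&mut rng, true,
                             &ziggurat_tables::ZIG_NORM_X,
                             &ziggurat_tables::ZIG_NORM_F,
                             // unnormalised standard normal PDF, per ZIGNOR
                             |x| (-x * x / 2.0).exp(),
                             // placeholder tail sampler (see lead-in comment)
                             |_, _| 0.0);
            assert!(x.is_finite());
        }
    }
}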