1 // Copyright 2018 Developers of the Rand project.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 
9 //! Math helper functions
10 
11 #[cfg(feature = "simd_support")] use packed_simd::*;
12 
13 
// Widening multiply: computes the full double-width product of two values,
// returned as a `(high, low)` pair of the original width.
pub(crate) trait WideningMultiply<RHS = Self> {
    // For all implementations below this is `(Self, Self)`: the high and low
    // halves of the product.
    type Output;

    // Returns the full product of `self * x` as `(high, low)` halves.
    fn wmul(self, x: RHS) -> Self::Output;
}
19 
// Implements `WideningMultiply` by casting up to a type twice as wide,
// multiplying there, and splitting the result back into halves.
macro_rules! wmul_impl {
    // Scalar implementation: `$ty` is the input type, `$wide` a type with at
    // least double the bit width, `$shift` the bit width of `$ty`.
    ($ty:ty, $wide:ty, $shift:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                // The wide product cannot overflow, so plain `*` is fine.
                let tmp = (self as $wide) * (x as $wide);
                // High half via shift, low half via truncating cast.
                ((tmp >> $shift) as $ty, tmp as $ty)
            }
        }
    };

    // simd bulk implementation
    // Note the trailing `,,` required at the call site: the repetition
    // consumes one comma per pair, and a second comma separates the list
    // from `$shift`.
    ($(($ty:ident, $wide:ident),)+, $shift:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // For supported vectors, this should compile to a couple
                    // supported multiply & swizzle instructions (no actual
                    // casting).
                    // TODO: optimize
                    let y: $wide = self.cast();
                    let x: $wide = x.cast();
                    let tmp = y * x;
                    let hi: $ty = (tmp >> $shift).cast();
                    let lo: $ty = tmp.cast();
                    (hi, lo)
                }
            }
        )+
    };
}
// Scalar widening multiplies via the next-larger integer type.
wmul_impl! { u8, u16, 8 }
wmul_impl! { u16, u32, 16 }
wmul_impl! { u32, u64, 32 }
// Emscripten lacks efficient u128 support, so u64 uses the split
// implementation below instead.
#[cfg(not(target_os = "emscripten"))]
wmul_impl! { u64, u128, 64 }
61 
62 // This code is a translation of the __mulddi3 function in LLVM's
63 // compiler-rt. It is an optimised variant of the common method
64 // `(a + b) * (c + d) = ac + ad + bc + bd`.
65 //
66 // For some reason LLVM can optimise the C version very well, but
67 // keeps shuffling registers in this Rust translation.
// Widening multiply without a double-width type: splits each operand into
// half-width halves and combines the four partial products, carefully
// carrying between them (see the `__mulddi3` note above).
macro_rules! wmul_impl_large {
    // Scalar implementation: `$half` is half the bit width of `$ty`.
    ($ty:ty, $half:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, b: $ty) -> Self::Output {
                // Mask selecting the low half of a value.
                const LOWER_MASK: $ty = !0 >> $half;
                // low*low partial product; `wrapping_mul` because the
                // compiler cannot see that masked operands cannot overflow.
                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
                let mut t = low >> $half;
                low &= LOWER_MASK;
                // Add hi(self)*lo(b) cross product, tracking carries in `t`.
                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                let mut high = t >> $half;
                t = low >> $half;
                low &= LOWER_MASK;
                // Add lo(self)*hi(b) cross product.
                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                high += t >> $half;
                // hi*hi partial product lands entirely in the high word.
                high += (self >> $half).wrapping_mul(b >> $half);

                (high, low)
            }
        }
    };

    // simd bulk implementation
    // Same algorithm per lane; `$scalar` is the lane type (for the mask
    // constant). Vector `*` wraps on overflow, so plain `*` suffices here.
    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, b: $ty) -> Self::Output {
                    // needs wrapping multiplication
                    const LOWER_MASK: $scalar = !0 >> $half;
                    let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
                    let mut t = low >> $half;
                    low &= LOWER_MASK;
                    t += (self >> $half) * (b & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    let mut high = t >> $half;
                    t = low >> $half;
                    low &= LOWER_MASK;
                    t += (b >> $half) * (self & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    high += t >> $half;
                    high += (self >> $half) * (b >> $half);

                    (high, low)
                }
            }
        )+
    };
}
// On Emscripten, u64 cannot widen to u128 efficiently; use the split
// algorithm. Everywhere else, u128 is the widest type and must use it.
#[cfg(target_os = "emscripten")]
wmul_impl_large! { u64, 32 }
#[cfg(not(target_os = "emscripten"))]
wmul_impl_large! { u128, 64 }
127 
// Implements `WideningMultiply` for `usize` by delegating to the
// fixed-width impl matching the target's pointer width.
macro_rules! wmul_impl_usize {
    ($ty:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                // Cast to the same-width fixed type, multiply, cast back.
                let (high, low) = (self as $ty).wmul(x as $ty);
                (high as usize, low as usize)
            }
        }
    };
}
#[cfg(target_pointer_width = "32")]
wmul_impl_usize! { u32 }
#[cfg(target_pointer_width = "64")]
wmul_impl_usize! { u64 }
145 
// SIMD `WideningMultiply` implementations (packed_simd vectors), gated on
// the `simd_support` feature.
#[cfg(feature = "simd_support")]
mod simd_wmul {
    use super::*;
    #[cfg(target_arch = "x86")] use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*;

    // u8 vectors that fit in a double-width vector: use the cast-up impl.
    wmul_impl! {
        (u8x2, u16x2),
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),,
        8
    }

    // u16 vectors: cast-up impl, except where x86 `mulhi`/`mullo`
    // intrinsics below provide a faster path.
    wmul_impl! { (u16x2, u32x2),, 16 }
    wmul_impl! { (u16x4, u32x4),, 16 }
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }

    // 16-bit lane widths allow use of the x86 `mulhi` instructions, which
    // means `wmul` can be implemented with only two instructions.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $intrinsic:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // Reinterpret (no conversion) the vectors as the raw
                    // x86 register type the intrinsics expect.
                    let b = $intrinsic::from_bits(x);
                    let a = $intrinsic::from_bits(self);
                    // SAFETY relies on the cfg(target_feature) gates at the
                    // invocation sites: the intrinsics are available.
                    let hi = $ty::from_bits(unsafe { $mulhi(a, b) });
                    let lo = $ty::from_bits(unsafe { $mullo(a, b) });
                    (hi, lo)
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    // FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::<u16x32>`
    // cannot use the same implementation.

    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),,
        32
    }

    // Widest vectors have no double-width counterpart; fall back to the
    // split-halves algorithm.
    // TODO: optimize, this seems to seriously slow things down
    wmul_impl_large! { (u8x64,) u8, 4 }
    wmul_impl_large! { (u16x32,) u16, 8 }
    wmul_impl_large! { (u32x16,) u32, 16 }
    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}
208 
/// Helper trait when dealing with scalar and SIMD floating point types.
pub(crate) trait FloatSIMDUtils {
    // `PartialOrd` for vectors compares lexicographically. We want to compare all
    // the individual SIMD lanes instead, and get the combined result over all
    // lanes. This is possible using something like `a.lt(b).all()`, but we
    // implement it as a trait so we can write the same code for `f32` and `f64`.
    // Only the comparison functions we need are implemented.

    // True iff every lane of `self` is strictly less than the matching lane.
    fn all_lt(self, other: Self) -> bool;
    // True iff every lane of `self` is less than or equal to the matching lane.
    fn all_le(self, other: Self) -> bool;
    // True iff every lane is finite (not NaN or infinite).
    fn all_finite(self) -> bool;

    // Per-lane boolean result type (`bool` for scalars, a mask vector for SIMD).
    type Mask;
    // Per-lane finiteness test.
    fn finite_mask(self) -> Self::Mask;
    // Per-lane `>` comparison.
    fn gt_mask(self, other: Self) -> Self::Mask;
    // Per-lane `>=` comparison.
    fn ge_mask(self, other: Self) -> Self::Mask;

    // Decrease all lanes where the mask is `true` to the next lower value
    // representable by the floating-point type. At least one of the lanes
    // must be set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    // Convert from int value. Conversion is done while retaining the numerical
    // value, not by retaining the binary representation.
    type UInt;
    fn cast_from_int(i: Self::UInt) -> Self;
}
235 
/// Implement functions available in std builds but missing from core primitives
// NOTE(review): `cfg(not(std))` tests a bare `std` cfg flag, not
// `feature = "std"`. Unless a build script sets `std`, this trait is always
// compiled — harmless, since inherent methods take precedence over trait
// methods, but worth confirming the intent.
#[cfg(not(std))]
// False positive: We are following `std` here.
#[allow(clippy::wrong_self_convention)]
pub(crate) trait Float: Sized {
    // True iff the value is NaN.
    fn is_nan(self) -> bool;
    // True iff the value is positive or negative infinity.
    fn is_infinite(self) -> bool;
    // True iff the value is neither NaN nor infinite.
    fn is_finite(self) -> bool;
}
245 
246 /// Implement functions on f32/f64 to give them APIs similar to SIMD types
/// Implement functions on f32/f64 to give them APIs similar to SIMD types
// A scalar float is treated as a one-lane vector; `index` must be 0.
pub(crate) trait FloatAsSIMD: Sized {
    // Number of lanes: always 1 for a scalar.
    #[inline(always)]
    fn lanes() -> usize {
        1
    }
    // "Broadcast" a scalar: identity for a one-lane vector.
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
    // Read lane `index` (must be 0): the value itself.
    #[inline(always)]
    fn extract(self, index: usize) -> Self {
        debug_assert_eq!(index, 0);
        self
    }
    // Replace lane `index` (must be 0): the whole value is replaced.
    #[inline(always)]
    fn replace(self, index: usize, new_value: Self) -> Self {
        debug_assert_eq!(index, 0);
        new_value
    }
}
267 
// Gives plain `bool` the lane-reduction API of SIMD mask types, so scalar
// and SIMD code paths can be written identically.
pub(crate) trait BoolAsSIMD: Sized {
    // True iff any lane is set.
    fn any(self) -> bool;
    // True iff all lanes are set.
    fn all(self) -> bool;
    // True iff no lane is set.
    fn none(self) -> bool;
}
273 
274 impl BoolAsSIMD for bool {
275     #[inline(always)]
any(self) -> bool276     fn any(self) -> bool {
277         self
278     }
279 
280     #[inline(always)]
all(self) -> bool281     fn all(self) -> bool {
282         self
283     }
284 
285     #[inline(always)]
none(self) -> bool286     fn none(self) -> bool {
287         !self
288     }
289 }
290 
// Implements the float helper traits for a scalar float type `$ty` with
// same-width unsigned integer type `$uty`.
macro_rules! scalar_float_impl {
    ($ty:ident, $uty:ident) => {
        #[cfg(not(std))]
        impl Float for $ty {
            #[inline]
            fn is_nan(self) -> bool {
                // NaN is the only value that compares unequal to itself.
                self != self
            }

            #[inline]
            fn is_infinite(self) -> bool {
                self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY
            }

            #[inline]
            fn is_finite(self) -> bool {
                !(self.is_nan() || self.is_infinite())
            }
        }

        // Scalar case: the "mask" is a plain bool and every `all_*`
        // reduction degenerates to the single comparison.
        impl FloatSIMDUtils for $ty {
            type Mask = bool;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self < other
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self <= other
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                self.is_finite()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self > other
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self >= other
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask, "At least one lane must be set");
                // Subtracting 1 from the bit pattern of a positive float
                // yields the next-lower representable value.
                <$ty>::from_bits(self.to_bits() - 1)
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                // Numeric conversion (not a reinterpretation of bits).
                i as $ty
            }
        }

        impl FloatAsSIMD for $ty {}
    };
}
360 
// Instantiate the scalar helpers for both float widths.
scalar_float_impl!(f32, u32);
scalar_float_impl!(f64, u64);
363 
364 
// Implements `FloatSIMDUtils` for a packed_simd float vector `$ty` with
// scalar element `$f_scalar`, mask vector `$mty` and same-width unsigned
// integer vector `$uty`.
#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = $mty;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                // Per-lane `<` reduced with `all()`.
                self.lt(other).all()
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self.le(other).all()
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.finite_mask().all()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                // This can possibly be done faster by checking bit patterns
                // NaN compares false against both bounds, so NaN lanes are
                // correctly excluded here too.
                let neg_inf = $ty::splat(::core::$f_scalar::NEG_INFINITY);
                let pos_inf = $ty::splat(::core::$f_scalar::INFINITY);
                self.gt(neg_inf) & self.lt(pos_inf)
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.gt(other)
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self.ge(other)
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $ty. This works even when the
                // current value is infinity.
                debug_assert!(mask.any(), "At least one lane must be set");
                <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask))
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                // Numeric (per-lane) conversion, not a bit reinterpretation.
                i.cast()
            }
        }
    };
}
424 
// Instantiate the SIMD helpers for every supported float vector width.
#[cfg(feature="simd_support")] simd_impl! { f32x2, f32, m32x2, u32x2 }
#[cfg(feature="simd_support")] simd_impl! { f32x4, f32, m32x4, u32x4 }
#[cfg(feature="simd_support")] simd_impl! { f32x8, f32, m32x8, u32x8 }
#[cfg(feature="simd_support")] simd_impl! { f32x16, f32, m32x16, u32x16 }
#[cfg(feature="simd_support")] simd_impl! { f64x2, f64, m64x2, u64x2 }
#[cfg(feature="simd_support")] simd_impl! { f64x4, f64, m64x4, u64x4 }
#[cfg(feature="simd_support")] simd_impl! { f64x8, f64, m64x8, u64x8 }
432