1 //! Implements portable horizontal integer vector arithmetic reductions.
2 
/// Implements horizontal wrapping arithmetic reductions (`wrapping_sum`,
/// `wrapping_product`) and the `Sum`/`Product` iterator traits for the
/// integer vector type `$id`.
///
/// Arguments:
///
/// * `[$elem_ty; $elem_count]`: element type and lane count of the vector.
///   `$elem_count` is accepted for uniformity of the invocation syntax but is
///   not used by this macro's expansion.
/// * `$id`: the vector type that receives the impls.
/// * `$ielem_ty`: scalar type produced by the reduction intrinsics; the
///   result is cast to `$elem_ty` before being returned.
/// * `$test_tt`: token forwarded to `test_if!` together with the generated
///   test module.
macro_rules! impl_reduction_integer_arithmetic {
    ([$elem_ty:ident; $elem_count:expr]: $id:ident | $ielem_ty:ident
     | $test_tt:tt) => {
        impl $id {
            /// Horizontal wrapping sum of the vector elements.
            ///
            /// Adds all lanes together using wrapping (modular) arithmetic.
            /// For an 8 element vector the result equals:
            ///
            /// > ((x0 + x1) + (x2 + x3)) + ((x4 + x5) + (x6 + x7))
            ///
            /// Wrapping addition is associative and commutative, so the
            /// order in which the lanes are combined does not change the
            /// result.
            ///
            /// If the sum overflows, the result is the mathematical sum
            /// reduced modulo `2^n`, where `n` is the number of bits of the
            /// element type (two's complement for signed element types).
            #[inline]
            pub fn wrapping_sum(self) -> $elem_ty {
                #[cfg(not(target_arch = "aarch64"))]
                {
                    use crate::llvm::simd_reduce_add_ordered;
                    // SAFETY(review): presumably the intrinsic only requires
                    // `self.0` to be a SIMD vector whose lanes match the
                    // accumulator type `$ielem_ty` - confirm against
                    // `crate::llvm`. `0` is the additive identity, so the
                    // accumulator does not affect the sum.
                    let reduced: $ielem_ty = unsafe {
                        simd_reduce_add_ordered(self.0, 0 as $ielem_ty)
                    };
                    reduced as $elem_ty
                }
                #[cfg(target_arch = "aarch64")]
                {
                    // FIXME: broken on AArch64
                    // https://github.com/rust-lang-nursery/packed_simd/issues/15
                    //
                    // Scalar fallback: accumulate lane by lane, starting
                    // from lane 0.
                    (1..$id::lanes()).fold(
                        self.extract(0) as $elem_ty,
                        |acc, lane| {
                            acc.wrapping_add(self.extract(lane) as $elem_ty)
                        },
                    )
                }
            }

            /// Horizontal wrapping product of the vector elements.
            ///
            /// Multiplies all lanes together using wrapping (modular)
            /// arithmetic. For an 8 element vector the result equals:
            ///
            /// > ((x0 * x1) * (x2 * x3)) * ((x4 * x5) * (x6 * x7))
            ///
            /// Wrapping multiplication is associative and commutative, so
            /// the order in which the lanes are combined does not change the
            /// result.
            ///
            /// If the product overflows, the result is the mathematical
            /// product reduced modulo `2^n`, where `n` is the number of bits
            /// of the element type (two's complement for signed element
            /// types).
            #[inline]
            pub fn wrapping_product(self) -> $elem_ty {
                #[cfg(not(target_arch = "aarch64"))]
                {
                    use crate::llvm::simd_reduce_mul_ordered;
                    // SAFETY(review): presumably the intrinsic only requires
                    // `self.0` to be a SIMD vector whose lanes match the
                    // accumulator type `$ielem_ty` - confirm against
                    // `crate::llvm`. `1` is the multiplicative identity, so
                    // the accumulator does not affect the product.
                    let reduced: $ielem_ty = unsafe {
                        simd_reduce_mul_ordered(self.0, 1 as $ielem_ty)
                    };
                    reduced as $elem_ty
                }
                #[cfg(target_arch = "aarch64")]
                {
                    // FIXME: broken on AArch64
                    // https://github.com/rust-lang-nursery/packed_simd/issues/15
                    //
                    // Scalar fallback: accumulate lane by lane, starting
                    // from lane 0.
                    (1..$id::lanes()).fold(
                        self.extract(0) as $elem_ty,
                        |acc, lane| {
                            acc.wrapping_mul(self.extract(lane) as $elem_ty)
                        },
                    )
                }
            }
        }

        /// Lane-wise sum of an iterator of vectors; an empty iterator
        /// yields the all-zeros vector.
        impl crate::iter::Sum for $id {
            #[inline]
            fn sum<I: Iterator<Item = $id>>(iter: I) -> $id {
                iter.fold($id::splat(0), crate::ops::Add::add)
            }
        }

        /// Lane-wise product of an iterator of vectors; an empty iterator
        /// yields the all-ones vector.
        impl crate::iter::Product for $id {
            #[inline]
            fn product<I: Iterator<Item = $id>>(iter: I) -> $id {
                iter.fold($id::splat(1), crate::ops::Mul::mul)
            }
        }

        /// Lane-wise sum of an iterator of vector references; delegates to
        /// the by-value implementation.
        impl<'a> crate::iter::Sum<&'a $id> for $id {
            #[inline]
            fn sum<I: Iterator<Item = &'a $id>>(iter: I) -> $id {
                iter.copied().sum()
            }
        }

        /// Lane-wise product of an iterator of vector references; delegates
        /// to the by-value implementation.
        impl<'a> crate::iter::Product<&'a $id> for $id {
            #[inline]
            fn product<I: Iterator<Item = &'a $id>>(iter: I) -> $id {
                iter.copied().product()
            }
        }

        test_if! {
            $test_tt:
            paste::item! {
                pub mod [<$id _reduction_int_arith>] {
                    use super::*;

                    // Builds a vector of ones in which every lane whose
                    // index is a multiple of `stride` holds two instead.
                    fn alternating(stride: usize) -> $id {
                        (0..$id::lanes())
                            .filter(|lane| lane % stride == 0)
                            .fold($id::splat(1 as $elem_ty), |v, lane| {
                                v.replace(lane, 2 as $elem_ty)
                            })
                    }

                    #[cfg_attr(not(target_arch = "wasm32"), test)]
                    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
                    fn wrapping_sum() {
                        let zeros = $id::splat(0 as $elem_ty);
                        assert_eq!(zeros.wrapping_sum(), 0 as $elem_ty);

                        let ones = $id::splat(1 as $elem_ty);
                        assert_eq!(
                            ones.wrapping_sum(),
                            $id::lanes() as $elem_ty
                        );

                        // Lanes 0, 2, 4, ... hold 2 and the others hold 1,
                        // so the sum is the lane count plus the number of
                        // twos (rounded up, which also covers one lane).
                        let v = alternating(2);
                        let twos = ($id::lanes() + 1) / 2;
                        assert_eq!(
                            v.wrapping_sum(),
                            ($id::lanes() + twos) as $elem_ty
                        );
                    }

                    #[cfg_attr(not(target_arch = "wasm32"), test)]
                    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
                    fn wrapping_sum_overflow() {
                        // Close enough to the maximum that summing all
                        // lanes must wrap for vectors with several lanes.
                        let start = $elem_ty::max_value()
                            - ($id::lanes() as $elem_ty / 2);

                        let v = $id::splat(start);
                        let vwrapping_sum = v.wrapping_sum();

                        // Scalar reference: `start` added `lanes()` times.
                        let wrapping_sum = (1..$id::lanes())
                            .fold(start, |acc, _| acc.wrapping_add(start));
                        assert_eq!(wrapping_sum, vwrapping_sum, "v = {:?}", v);
                    }

                    #[cfg_attr(not(target_arch = "wasm32"), test)]
                    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
                    fn wrapping_product() {
                        let zeros = $id::splat(0 as $elem_ty);
                        assert_eq!(zeros.wrapping_product(), 0 as $elem_ty);

                        let ones = $id::splat(1 as $elem_ty);
                        assert_eq!(ones.wrapping_product(), 1 as $elem_ty);

                        // Widen the stride for wide vectors so that at most
                        // four lanes hold 2: the expected product is then
                        // at most 16 and fits in every integer element
                        // type.
                        let stride = match $id::lanes() {
                            64 => 16,
                            32 => 8,
                            16 => 4,
                            _ => 2,
                        };
                        let v = alternating(stride);
                        // Number of lanes holding 2 (rounded up, which
                        // also covers single-lane vectors).
                        let twos = ($id::lanes() + stride - 1) / stride;
                        assert_eq!(
                            v.wrapping_product(),
                            2_usize.pow(twos as u32) as $elem_ty
                        );
                    }

                    #[cfg_attr(not(target_arch = "wasm32"), test)]
                    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
                    fn wrapping_product_overflow() {
                        // Close enough to the maximum that multiplying all
                        // lanes must wrap for vectors with several lanes.
                        let start = $elem_ty::max_value()
                            - ($id::lanes() as $elem_ty / 2);

                        let v = $id::splat(start);
                        let vmul = v.wrapping_product();

                        // Scalar reference: `start` multiplied `lanes()`
                        // times.
                        let mul = (1..$id::lanes())
                            .fold(start, |acc, _| acc.wrapping_mul(start));
                        assert_eq!(mul, vmul, "v = {:?}", v);
                    }
                }
            }
        }
    };
}
198