use std::any::TypeId;

// Calculate the sum of an expression consisting of just plus and minus, like `value = a + b - c + d`.
// The expression is rewritten to `value = a + (b - (c - d))` (note the flipped sign on d).
// After this, the `$add` and `$sub` functions are used to perform the calculation.
// For f32, using `_mm_add_ps` and `_mm_sub_ps`, the expression `value = a + b - c + d` becomes:
// ```let value = _mm_add_ps(a, _mm_sub_ps(b, _mm_sub_ps(c, d)));```
// Only plus and minus are supported, and all the terms must be plain variables.
// Using array indices, like `value = temp[0] + temp[1]`, is not supported.
macro_rules! calc_sum {
    ($add:ident, $sub:ident, + $acc:tt + $($rest:tt)*) => {
        $add($acc, calc_sum!($add, $sub, + $($rest)*))
    };
    ($add:ident, $sub:ident, + $acc:tt - $($rest:tt)*) => {
        $sub($acc, calc_sum!($add, $sub, - $($rest)*))
    };
    ($add:ident, $sub:ident, - $acc:tt + $($rest:tt)*) => {
        $sub($acc, calc_sum!($add, $sub, + $($rest)*))
    };
    ($add:ident, $sub:ident, - $acc:tt - $($rest:tt)*) => {
        $add($acc, calc_sum!($add, $sub, - $($rest)*))
    };
    ($add:ident, $sub:ident, $acc:tt + $($rest:tt)*) => {
        $add($acc, calc_sum!($add, $sub, + $($rest)*))
    };
    ($add:ident, $sub:ident, $acc:tt - $($rest:tt)*) => {
        $sub($acc, calc_sum!($add, $sub, - $($rest)*))
    };
    ($add:ident, $sub:ident, + $val:tt) => {$val};
    ($add:ident, $sub:ident, - $val:tt) => {$val};
}

// Calculate the sum of an expression consisting of just plus and minus, like a + b - c + d
macro_rules! calc_f32 {
    ($($tokens:tt)*) => { calc_sum!(_mm_add_ps, _mm_sub_ps, $($tokens)*) };
}

// Calculate the sum of an expression consisting of just plus and minus, like a + b - c + d
macro_rules! calc_f64 {
    ($($tokens:tt)*) => { calc_sum!(_mm_add_pd, _mm_sub_pd, $($tokens)*) };
}
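
// Worked examples of how the wrapper macros above expand, derived by hand from the
// `calc_sum` rules (illustrative only):
// ```
// calc_f32!(a + b - c)  =>  _mm_add_ps(a, _mm_sub_ps(b, c))
// calc_f64!(a - b + c)  =>  _mm_sub_pd(a, _mm_sub_pd(b, c))
// ```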

// Helper function to assert we have the right float type
pub fn assert_f32<T: 'static>() {
    let id_f32 = TypeId::of::<f32>();
    let id_t = TypeId::of::<T>();
    assert!(id_t == id_f32, "Wrong float type, must be f32");
}

// Helper function to assert we have the right float type
pub fn assert_f64<T: 'static>() {
    let id_f64 = TypeId::of::<f64>();
    let id_t = TypeId::of::<T>();
    assert!(id_t == id_f64, "Wrong float type, must be f64");
}
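
// Illustrative call site (the struct name is made up for this example; the actual
// call sites are presumably the f32/f64-specific implementations elsewhere in the crate):
// ```
// impl<T: FftNum> SseF32Butterfly4<T> {
//     pub fn new(direction: FftDirection) -> Self {
//         assert_f32::<T>(); // panics at construction time if T is not f32
//         /* ... */
//     }
// }
// ```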

// Shuffle elements to interleave two contiguous sets of f32, from an array of simd vectors to a new array of simd vectors
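// As a worked illustration of the expansion, this statement:
// ```
// let values = interleave_complex_f32!(input, 2, {0, 1});
// ```
// is equivalent to:
// ```
// let values = [
//    extract_lo_lo_f32(input[0], input[2]),
//    extract_hi_hi_f32(input[0], input[2]),
//    extract_lo_lo_f32(input[1], input[3]),
//    extract_hi_hi_f32(input[1], input[3]),
// ];
// ```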
macro_rules! interleave_complex_f32 {
    ($input:ident, $offset:literal, { $($idx:literal),* }) => {
        [
        $(
            extract_lo_lo_f32($input[$idx], $input[$idx+$offset]),
            extract_hi_hi_f32($input[$idx], $input[$idx+$offset]),
        )*
        ]
    }
}

// Shuffle elements to separate two interleaved sets of f32, from an array of simd vectors to a new array of simd vectors
// This statement:
// ```
// let values = separate_interleaved_complex_f32!(input, {0, 2, 4});
// ```
// is equivalent to:
// ```
// let values = [
//    extract_lo_lo_f32(input[0], input[1]),
//    extract_lo_lo_f32(input[2], input[3]),
//    extract_lo_lo_f32(input[4], input[5]),
//    extract_hi_hi_f32(input[0], input[1]),
//    extract_hi_hi_f32(input[2], input[3]),
//    extract_hi_hi_f32(input[4], input[5]),
// ];
// ```
macro_rules! separate_interleaved_complex_f32 {
    ($input:ident, { $($idx:literal),* }) => {
        [
        $(
            extract_lo_lo_f32($input[$idx], $input[$idx+1]),
        )*
        $(
            extract_hi_hi_f32($input[$idx], $input[$idx+1]),
        )*
        ]
    }
}

macro_rules! boilerplate_fft_sse_oop {
    ($struct_name:ident, $len_fn:expr) => {
        impl<T: FftNum> Fft<T> for $struct_name<T> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                _scratch: &mut [Complex<T>],
            ) {
                if self.len() == 0 {
                    return;
                }

                if input.len() < self.len() || output.len() != input.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0);
                    return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here
                }

                let result = unsafe {
                    array_utils::iter_chunks_zipped(
                        input,
                        output,
                        self.len(),
                        |in_chunk, out_chunk| {
                            self.perform_fft_out_of_place(in_chunk, out_chunk, &mut [])
                        },
                    )
                };

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0);
                }
            }
            fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
                if self.len() == 0 {
                    return;
                }

                let required_scratch = self.get_inplace_scratch_len();
                if scratch.len() < required_scratch || buffer.len() < self.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                    return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here
                }

                let scratch = &mut scratch[..required_scratch];
                let result = unsafe {
                    array_utils::iter_chunks(buffer, self.len(), |chunk| {
                        self.perform_fft_out_of_place(chunk, scratch, &mut []);
                        chunk.copy_from_slice(scratch);
                    })
                };
                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                }
            }
            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                self.len()
            }
            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                0
            }
        }
        impl<T> Length for $struct_name<T> {
            #[inline(always)]
            fn len(&self) -> usize {
                $len_fn(self)
            }
        }
        impl<T> Direction for $struct_name<T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                self.direction
            }
        }
    };
}
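
// Illustrative invocation of the boilerplate macro above (the struct name and its
// `length` field are made up for this example; the real call sites are in the files
// that define the SSE FFT structs). The macro assumes the struct is generic over
// `T: FftNum`, has a `direction` field of type `FftDirection`, and provides a
// `perform_fft_out_of_place` method:
// ```
// boilerplate_fft_sse_oop!(SseRadix4, |this: &SseRadix4<_>| this.length);
// ```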

/* Not used now, but maybe later for the mixed radixes etc
macro_rules! boilerplate_sse_fft {
    ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => {
        impl<T: FftNum> Fft<T> for $struct_name<T> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                scratch: &mut [Complex<T>],
            ) {
                if self.len() == 0 {
                    return;
                }

                let required_scratch = self.get_outofplace_scratch_len();
                if scratch.len() < required_scratch
                    || input.len() < self.len()
                    || output.len() != input.len()
                {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(
                        self.len(),
                        input.len(),
                        output.len(),
                        self.get_outofplace_scratch_len(),
                        scratch.len(),
                    );
                    return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here
                }

                let scratch = &mut scratch[..required_scratch];
                let result = array_utils::iter_chunks_zipped(
                    input,
                    output,
                    self.len(),
                    |in_chunk, out_chunk| {
                        self.perform_fft_out_of_place(in_chunk, out_chunk, scratch)
                    },
                );

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(
                        self.len(),
                        input.len(),
                        output.len(),
                        self.get_outofplace_scratch_len(),
                        scratch.len(),
                    );
                }
            }
            fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
                if self.len() == 0 {
                    return;
                }

                let required_scratch = self.get_inplace_scratch_len();
                if scratch.len() < required_scratch || buffer.len() < self.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                    return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here
                }

                let scratch = &mut scratch[..required_scratch];
                let result = array_utils::iter_chunks(buffer, self.len(), |chunk| {
                    self.perform_fft_inplace(chunk, scratch)
                });

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                }
            }
            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                $inplace_scratch_len_fn(self)
            }
            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                $out_of_place_scratch_len_fn(self)
            }
        }
        impl<T: FftNum> Length for $struct_name<T> {
            #[inline(always)]
            fn len(&self) -> usize {
                $len_fn(self)
            }
        }
        impl<T: FftNum> Direction for $struct_name<T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                self.direction
            }
        }
    };
}
*/

#[cfg(test)]
mod unit_tests {
    use core::arch::x86_64::*;

    #[test]
    fn test_calc_f32() {
        unsafe {
            let a = _mm_set_ps(1.0, 1.0, 1.0, 1.0);
            let b = _mm_set_ps(2.0, 2.0, 2.0, 2.0);
            let c = _mm_set_ps(3.0, 3.0, 3.0, 3.0);
            let d = _mm_set_ps(4.0, 4.0, 4.0, 4.0);
            let e = _mm_set_ps(5.0, 5.0, 5.0, 5.0);
            let f = _mm_set_ps(6.0, 6.0, 6.0, 6.0);
            let g = _mm_set_ps(7.0, 7.0, 7.0, 7.0);
            let h = _mm_set_ps(8.0, 8.0, 8.0, 8.0);
            let i = _mm_set_ps(9.0, 9.0, 9.0, 9.0);
            let expected: f32 = 1.0 + 2.0 - 3.0 + 4.0 - 5.0 + 6.0 - 7.0 - 8.0 + 9.0;
            let res = calc_f32!(a + b - c + d - e + f - g - h + i);
            let sum = std::mem::transmute::<__m128, [f32; 4]>(res);
            assert_eq!(sum[0], expected);
            assert_eq!(sum[1], expected);
            assert_eq!(sum[2], expected);
            assert_eq!(sum[3], expected);
        }
    }
    #[test]
    fn test_calc_f64() {
        unsafe {
            let a = _mm_set_pd(1.0, 1.0);
            let b = _mm_set_pd(2.0, 2.0);
            let c = _mm_set_pd(3.0, 3.0);
            let d = _mm_set_pd(4.0, 4.0);
            let e = _mm_set_pd(5.0, 5.0);
            let f = _mm_set_pd(6.0, 6.0);
            let g = _mm_set_pd(7.0, 7.0);
            let h = _mm_set_pd(8.0, 8.0);
            let i = _mm_set_pd(9.0, 9.0);
            let expected: f64 = 1.0 + 2.0 - 3.0 + 4.0 - 5.0 + 6.0 - 7.0 - 8.0 + 9.0;
            let res = calc_f64!(a + b - c + d - e + f - g - h + i);
            let sum = std::mem::transmute::<__m128d, [f64; 2]>(res);
            assert_eq!(sum[0], expected);
            assert_eq!(sum[1], expected);
        }
    }
}