use std::any::TypeId;

// Calculate the sum of an expression consisting of just plus and minus, like `value = a + b - c + d`.
// The expression is rewritten to `value = a + (b - (c - d))` (note the flipped sign on d).
// After this the `$add` and `$sub` functions are used to make the calculation.
// For f32 using `_mm_add_ps` and `_mm_sub_ps`, the expression `value = a + b - c + d` becomes:
// ```let value = _mm_add_ps(a, _mm_sub_ps(b, _mm_sub_ps(c, d)));```
// Only plus and minus are supported, and all the terms must be plain scalar variables.
// Using array indices, like `value = temp[0] + temp[1]`, is not supported.
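// For illustration (derivable from the rules below), `calc_sum!(add, sub, a + b - c + d)` expands step by step as:
//   add(a, calc_sum!(add, sub, + b - c + d))
//   add(a, sub(b, calc_sum!(add, sub, - c + d)))
//   add(a, sub(b, sub(c, calc_sum!(add, sub, + d))))
//   add(a, sub(b, sub(c, d)))
// where `add` and `sub` stand in for whatever functions are passed as `$add` and `$sub`.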
macro_rules! calc_sum {
    ($add:ident, $sub:ident, + $acc:tt + $($rest:tt)*) => {
        $add($acc, calc_sum!($add, $sub, + $($rest)*))
    };
    ($add:ident, $sub:ident, + $acc:tt - $($rest:tt)*) => {
        $sub($acc, calc_sum!($add, $sub, - $($rest)*))
    };
    ($add:ident, $sub:ident, - $acc:tt + $($rest:tt)*) => {
        $sub($acc, calc_sum!($add, $sub, + $($rest)*))
    };
    ($add:ident, $sub:ident, - $acc:tt - $($rest:tt)*) => {
        $add($acc, calc_sum!($add, $sub, - $($rest)*))
    };
    ($add:ident, $sub:ident, $acc:tt + $($rest:tt)*) => {
        $add($acc, calc_sum!($add, $sub, + $($rest)*))
    };
    ($add:ident, $sub:ident, $acc:tt - $($rest:tt)*) => {
        $sub($acc, calc_sum!($add, $sub, - $($rest)*))
    };
    ($add:ident, $sub:ident, + $val:tt) => { $val };
    ($add:ident, $sub:ident, - $val:tt) => { $val };
}

// Calculate the sum of an expression consisting of just plus and minus, like a + b - c + d
macro_rules! calc_f32 {
    ($($tokens:tt)*) => { calc_sum!(_mm_add_ps, _mm_sub_ps, $($tokens)*) };
}

// Calculate the sum of an expression consisting of just plus and minus, like a + b - c + d
macro_rules! calc_f64 {
    ($($tokens:tt)*) => { calc_sum!(_mm_add_pd, _mm_sub_pd, $($tokens)*) };
}

// Helper function to assert we have the right float type
pub fn assert_f32<T: 'static>() {
    let id_f32 = TypeId::of::<f32>();
    let id_t = TypeId::of::<T>();
    assert!(id_t == id_f32, "Wrong float type, must be f32");
}

// Helper function to assert we have the right float type
pub fn assert_f64<T: 'static>() {
    let id_f64 = TypeId::of::<f64>();
    let id_t = TypeId::of::<T>();
    assert!(id_t == id_f64, "Wrong float type, must be f64");
}
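
// A minimal sketch of how these helpers can be used (the struct, its fields and the constructor
// below are hypothetical, for illustration only): a type that is generic over T but whose SSE
// implementation only handles one float width can assert the expected type in its constructor.
//
// struct SseButterfly1<T> {
//     direction: FftDirection,
//     _phantom: std::marker::PhantomData<T>,
// }
//
// impl<T: FftNum> SseButterfly1<T> {
//     pub fn new(direction: FftDirection) -> Self {
//         assert_f32::<T>(); // panics unless T is f32
//         Self { direction, _phantom: std::marker::PhantomData }
//     }
// }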

// Shuffle elements to interleave two contiguous sets of f32, from an array of simd vectors to a new array of simd vectors
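// For example (following the macro definition below), this statement:
// ```
// let values = interleave_complex_f32!(input, 2, {0, 1});
// ```
// is equivalent to:
// ```
// let values = [
//     extract_lo_lo_f32(input[0], input[2]),
//     extract_hi_hi_f32(input[0], input[2]),
//     extract_lo_lo_f32(input[1], input[3]),
//     extract_hi_hi_f32(input[1], input[3]),
// ];
// ```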
macro_rules! interleave_complex_f32 {
    ($input:ident, $offset:literal, { $($idx:literal),* }) => {
        [
            $(
                extract_lo_lo_f32($input[$idx], $input[$idx+$offset]),
                extract_hi_hi_f32($input[$idx], $input[$idx+$offset]),
            )*
        ]
    }
}

// Shuffle elements to separate two interleaved sets of f32, from an array of simd vectors to a new array of simd vectors
// This statement:
// ```
// let values = separate_interleaved_complex_f32!(input, {0, 2, 4});
// ```
// is equivalent to:
// ```
// let values = [
//     extract_lo_lo_f32(input[0], input[1]),
//     extract_lo_lo_f32(input[2], input[3]),
//     extract_lo_lo_f32(input[4], input[5]),
//     extract_hi_hi_f32(input[0], input[1]),
//     extract_hi_hi_f32(input[2], input[3]),
//     extract_hi_hi_f32(input[4], input[5]),
// ];
// ```
macro_rules! separate_interleaved_complex_f32 {
    ($input:ident, { $($idx:literal),* }) => {
        [
            $(
                extract_lo_lo_f32($input[$idx], $input[$idx+1]),
            )*
            $(
                extract_hi_hi_f32($input[$idx], $input[$idx+1]),
            )*
        ]
    }
}

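// A minimal usage sketch for the macro below (the struct name and its `length` field are
// hypothetical, for illustration only): the macro is invoked with the type that should get the
// Fft/Length/Direction impls, plus an expression computing its FFT length from `&self`, e.g.
//
// boilerplate_fft_sse_oop!(SseButterfly4, |this: &SseButterfly4<_>| this.length);
//
// The struct is expected to provide a `direction` field and a `perform_fft_out_of_place` method,
// which the generated impls call.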
macro_rules! boilerplate_fft_sse_oop {
    ($struct_name:ident, $len_fn:expr) => {
        impl<T: FftNum> Fft<T> for $struct_name<T> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                _scratch: &mut [Complex<T>],
            ) {
                if self.len() == 0 {
                    return;
                }

                if input.len() < self.len() || output.len() != input.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0);
                    return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here
                }

                let result = unsafe {
                    array_utils::iter_chunks_zipped(
                        input,
                        output,
                        self.len(),
                        |in_chunk, out_chunk| {
                            self.perform_fft_out_of_place(in_chunk, out_chunk, &mut [])
                        },
                    )
                };

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0);
                }
            }
            fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
                if self.len() == 0 {
                    return;
                }

                let required_scratch = self.get_inplace_scratch_len();
                if scratch.len() < required_scratch || buffer.len() < self.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                    return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here
                }

                let scratch = &mut scratch[..required_scratch];
                let result = unsafe {
                    array_utils::iter_chunks(buffer, self.len(), |chunk| {
                        self.perform_fft_out_of_place(chunk, scratch, &mut []);
                        chunk.copy_from_slice(scratch);
                    })
                };
                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                }
            }
            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                self.len()
            }
            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                0
            }
        }
        impl<T> Length for $struct_name<T> {
            #[inline(always)]
            fn len(&self) -> usize {
                $len_fn(self)
            }
        }
        impl<T> Direction for $struct_name<T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                self.direction
            }
        }
    };
}

/* Not used now, but maybe later for the mixed radixes etc
macro_rules! boilerplate_sse_fft {
    ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => {
        impl<T: FftNum> Fft<T> for $struct_name<T> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                scratch: &mut [Complex<T>],
            ) {
                if self.len() == 0 {
                    return;
                }

                let required_scratch = self.get_outofplace_scratch_len();
                if scratch.len() < required_scratch
                    || input.len() < self.len()
                    || output.len() != input.len()
                {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(
                        self.len(),
                        input.len(),
                        output.len(),
                        self.get_outofplace_scratch_len(),
                        scratch.len(),
                    );
                    return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here
                }

                let scratch = &mut scratch[..required_scratch];
                let result = array_utils::iter_chunks_zipped(
                    input,
                    output,
                    self.len(),
                    |in_chunk, out_chunk| {
                        self.perform_fft_out_of_place(in_chunk, out_chunk, scratch)
                    },
                );

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(
                        self.len(),
                        input.len(),
                        output.len(),
                        self.get_outofplace_scratch_len(),
                        scratch.len(),
                    );
                }
            }
            fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
                if self.len() == 0 {
                    return;
                }

                let required_scratch = self.get_inplace_scratch_len();
                if scratch.len() < required_scratch || buffer.len() < self.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                    return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here
                }

                let scratch = &mut scratch[..required_scratch];
                let result = array_utils::iter_chunks(buffer, self.len(), |chunk| {
                    self.perform_fft_inplace(chunk, scratch)
                });

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                }
            }
            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                $inplace_scratch_len_fn(self)
            }
            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                $out_of_place_scratch_len_fn(self)
            }
        }
        impl<T: FftNum> Length for $struct_name<T> {
            #[inline(always)]
            fn len(&self) -> usize {
                $len_fn(self)
            }
        }
        impl<T: FftNum> Direction for $struct_name<T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                self.direction
            }
        }
    };
}
*/

#[cfg(test)]
mod unit_tests {
    use core::arch::x86_64::*;

    #[test]
    fn test_calc_f32() {
        unsafe {
            let a = _mm_set_ps(1.0, 1.0, 1.0, 1.0);
            let b = _mm_set_ps(2.0, 2.0, 2.0, 2.0);
            let c = _mm_set_ps(3.0, 3.0, 3.0, 3.0);
            let d = _mm_set_ps(4.0, 4.0, 4.0, 4.0);
            let e = _mm_set_ps(5.0, 5.0, 5.0, 5.0);
            let f = _mm_set_ps(6.0, 6.0, 6.0, 6.0);
            let g = _mm_set_ps(7.0, 7.0, 7.0, 7.0);
            let h = _mm_set_ps(8.0, 8.0, 8.0, 8.0);
            let i = _mm_set_ps(9.0, 9.0, 9.0, 9.0);
            let expected: f32 = 1.0 + 2.0 - 3.0 + 4.0 - 5.0 + 6.0 - 7.0 - 8.0 + 9.0;
            let res = calc_f32!(a + b - c + d - e + f - g - h + i);
            let sum = std::mem::transmute::<__m128, [f32; 4]>(res);
            assert_eq!(sum[0], expected);
            assert_eq!(sum[1], expected);
            assert_eq!(sum[2], expected);
            assert_eq!(sum[3], expected);
        }
    }
    #[test]
    fn test_calc_f64() {
        unsafe {
            let a = _mm_set_pd(1.0, 1.0);
            let b = _mm_set_pd(2.0, 2.0);
            let c = _mm_set_pd(3.0, 3.0);
            let d = _mm_set_pd(4.0, 4.0);
            let e = _mm_set_pd(5.0, 5.0);
            let f = _mm_set_pd(6.0, 6.0);
            let g = _mm_set_pd(7.0, 7.0);
            let h = _mm_set_pd(8.0, 8.0);
            let i = _mm_set_pd(9.0, 9.0);
            let expected: f64 = 1.0 + 2.0 - 3.0 + 4.0 - 5.0 + 6.0 - 7.0 - 8.0 + 9.0;
            let res = calc_f64!(a + b - c + d - e + f - g - h + i);
            let sum = std::mem::transmute::<__m128d, [f64; 2]>(res);
            assert_eq!(sum[0], expected);
            assert_eq!(sum[1], expected);
        }
    }
}