1 /****************************************************************************
2
3 PGFFT: Pretty Good FFT (v1.8)
4
5 Copyright (C) 2019, victor Shoup
6
7 See below for more details.
8
9 ****************************************************************************/
10
11
12
// Suppress "unused" warnings for the remainder of this file. Presumably this
// is because some static helpers / locals are only used in some build
// configurations (SIMD vs. scalar paths), and some locals are deliberately
// left unused (e.g. the 4th root of unity read in the radix-4 kernels).
// NOTE(review): the matching "#pragma GCC diagnostic pop" is presumably at the
// end of the file (not visible in this chunk) -- confirm.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-variable"
#pragma GCC diagnostic ignored "-Wunused-function"
#ifdef __GNUC__
#ifndef __clang__
// gcc only; presumably because older clang versions do not recognize
// this warning option.
#pragma GCC diagnostic ignored "-Wunused-but-set-variable"
#endif
#endif

#define PGFFT_USE_TRUNCATED_BLUE (1)
// set to 0 to disable the truncated Bluestein

#define PGFFT_USE_EXPLICIT_MUL (1)
// Set to 0 to disable explicit complex multiplication.
// The built-in complex multiplication routines are
// incredibly slow, because the standard requires special handling
// of non-finite complex values.
// To fix the problem, by default, PGFFT will override these routines
// with explicitly defined multiplication functions (MUL and CMUL).
// Another way to solve this problem is to compile with the -ffast-math
// option (at least, on gcc, that's the right flag).
34
35
36 //============================================
37
// SIMD feature detection. Define PGFFT_DISABLE_SIMD to force the scalar code.
#ifndef PGFFT_DISABLE_SIMD

#ifdef __AVX__
#define HAVE_AVX
//#warning "HAVE_AVX"
#endif

// HAVE_AVX2 gates the code paths that use fused multiply-add intrinsics
// (_mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_fmaddsub_pd, ...). Those need
// the FMA extension, which is separate from AVX2: with gcc/clang, -mavx2
// does NOT enable -mfma, and using the intrinsics without it fails to
// compile. So only enable this path when FMA is available too. MSVC does
// not define __FMA__, but its FMA intrinsics are usable with /arch:AVX2.
#if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER))
#define HAVE_AVX2
//#warning "HAVE_AVX2"
#endif

// PD4 (4 x double) code paths are used whenever AVX is available; without
// HAVE_AVX2 they fall back to the non-FMA complex multiplication sequences.
#if defined(HAVE_AVX) || defined(HAVE_AVX2)
#define USE_PD4
#endif

#endif
55
56
57
58
#include <helib/PGFFT.h>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <limits>

#ifdef USE_PD4
#include <immintrin.h>
#endif
67
68 namespace helib {
69
70 using std::vector;
71 using std::complex;
72
73 template<class T>
74 using aligned_vector = PGFFT::aligned_vector<T>;
75
76 typedef complex<double> cmplx_t;
77 typedef long double ldbl;
78 //typedef double ldbl;
79
80 #ifdef USE_PD4
simd_enabled()81 bool PGFFT::simd_enabled() { return true; }
82 #else
simd_enabled()83 bool PGFFT::simd_enabled() { return false; }
84 #endif
85
86
#if (PGFFT_USE_EXPLICIT_MUL)

// Complex product a*b computed directly from the components, without the
// special handling of non-finite values that std::complex's operator*
// performs.
static inline cmplx_t
MUL(cmplx_t a, cmplx_t b)
{
  const double ar = a.real(), ai = a.imag();
  const double br = b.real(), bi = b.imag();
  const double re = ar*br - ai*bi;
  const double im = ar*bi + ai*br;
  return cmplx_t(re, im);
}

// Product with the conjugate of the second argument: a * conj(b).
static inline cmplx_t
CMUL(cmplx_t a, cmplx_t b)
{
  const double ar = a.real(), ai = a.imag();
  const double br = b.real(), bi = b.imag();
  const double re = ar*br + ai*bi;
  const double im = ai*br - ar*bi;
  return cmplx_t(re, im);
}

#else

// Fall back to the standard library operators.
#define MUL(a, b) ((a) * (b))
#define CMUL(a, b) ((a) * std::conj(b))

#endif
109
110
#if (defined(__GNUC__) && (__GNUC__ >= 4))

// On relatively modern versions of gcc (and compilers that define __GNUC__,
// such as clang), we can declare "restricted" pointers in C++ via the
// __restrict extension, promising the compiler that they do not alias.

#define RESTRICT __restrict

#else

// Otherwise RESTRICT expands to nothing and no aliasing promise is made.

#define RESTRICT

#endif
123
124 /**************************************************************
125
126 Aligned allocation
127
128 **************************************************************/
129
130 #ifdef USE_PD4
131
132 #define PGFFT_ALIGN (64)
133
134 void *
aligned_allocate(std::size_t n,std::size_t nelts)135 PGFFT::aligned_allocate(std::size_t n, std::size_t nelts)
136 {
137 if (n > std::numeric_limits<std::size_t>::max() / nelts) return 0;
138 std::size_t sz = n * nelts;
139 std::size_t alignment = PGFFT_ALIGN;
140 if (sz > std::numeric_limits<std::size_t>::max() - alignment) return 0;
141
142 sz += alignment;
143 char* buf = (char*) std::malloc(sz);
144
145 if (!buf) return 0;
146
147 int remainder = ((unsigned long long)buf) % alignment;
148 int offset = alignment - remainder;
149 char* ret = buf + offset;
150
151 ret[-1] = offset;
152
153 return ret;
154 }
155
156
157 void
aligned_deallocate(void * p)158 PGFFT::aligned_deallocate(void *p)
159 {
160 if (!p) return;
161 char *cp = (char *) p;
162 int offset = cp[-1];
163 std::free(cp - offset);
164 }
165
166 #else
167
168 void *
aligned_allocate(std::size_t n,std::size_t sz)169 PGFFT::aligned_allocate(std::size_t n, std::size_t sz)
170 {
171 if (n > std::numeric_limits<std::size_t>::max() / sz) return 0;
172 std::size_t size = n * sz;
173 return std::malloc(size);
174 }
175
176 void
aligned_deallocate(void * p)177 PGFFT::aligned_deallocate(void *p)
178 {
179 if (!p) return;
180 std::free(p);
181 }
182
183 #endif
184
185 /**************************************************************
186
187 Packed Double abstraction layer
188
189 **************************************************************/
190
191
192
namespace {



//=================== PD4 implementation ===============
// PD4 is a thin wrapper around an AVX __m256d register holding four doubles.
// The FFT kernels use it to process two interleaved std::complex<double>
// values (re0, im0, re1, im1) per register.

#if defined(USE_PD4)

struct PD4 {
  __m256d data;


  PD4() = default;
  // broadcast x to all four slots
  PD4(double x) : data(_mm256_set1_pd(x)) { }
  PD4(__m256d _data) : data(_data) { }
  // slot i receives di (_mm256_set_pd takes its arguments high-to-low)
  PD4(double d0, double d1, double d2, double d3)
    : data(_mm256_set_pd(d3, d2, d1, d0)) { }

  // load from a 32-byte aligned address
  static PD4 load(const double *p) { return _mm256_load_pd(p); }

  // load from unaligned address
  static PD4 loadu(const double *p) { return _mm256_loadu_pd(p); }
};

// x = 4 doubles at p (p must be 32-byte aligned)
inline void
load(PD4& x, const double *p)
{ x = PD4::load(p); }

// load from unaligned address
inline void
loadu(PD4& x, const double *p)
{ x = PD4::loadu(p); }

// store 4 doubles to p (p must be 32-byte aligned)
inline void
store(double *p, PD4 a)
{ _mm256_store_pd(p, a.data); }

// store to unaligned address
inline void
storeu(double *p, PD4 a)
{ _mm256_storeu_pd(p, a.data); }


// swap even/odd slots
// e.g., 0123 -> 1032
// (for two packed complex values this swaps re and im of each)
inline PD4
swap2(PD4 a)
{ return _mm256_permute_pd(a.data, 0x5); }

// 0123 -> 0022
// (for packed complex values: broadcast the real parts)
inline PD4
dup2even(PD4 a)
{ return _mm256_permute_pd(a.data, 0); }

// 0123 -> 1133
// (for packed complex values: broadcast the imaginary parts)
inline PD4
dup2odd(PD4 a)
{ return _mm256_permute_pd(a.data, 0xf); }

// blend even/odd slots
// 0123, 4567 -> 0527
inline PD4
blend2(PD4 a, PD4 b)
{ return _mm256_blend_pd(a.data, b.data, 0xa); }

// 0123, 4567 -> 0426
inline PD4
blend_even(PD4 a, PD4 b)
{ return _mm256_unpacklo_pd(a.data, b.data); }


// 0123, 4567 -> 1537
inline PD4
blend_odd(PD4 a, PD4 b)
{ return _mm256_unpackhi_pd(a.data, b.data); }


// set all four slots to 0.0
inline void
clear(PD4& x)
{ x.data = _mm256_setzero_pd(); }

// element-wise arithmetic on the four slots

inline PD4
operator+(PD4 a, PD4 b)
{ return _mm256_add_pd(a.data, b.data); }

inline PD4
operator-(PD4 a, PD4 b)
{ return _mm256_sub_pd(a.data, b.data); }

inline PD4
operator*(PD4 a, PD4 b)
{ return _mm256_mul_pd(a.data, b.data); }

inline PD4
operator/(PD4 a, PD4 b)
{ return _mm256_div_pd(a.data, b.data); }

inline PD4&
operator+=(PD4& a, PD4 b)
{ a = a + b; return a; }

inline PD4&
operator-=(PD4& a, PD4 b)
{ a = a - b; return a; }

inline PD4&
operator*=(PD4& a, PD4 b)
{ a = a * b; return a; }

inline PD4&
operator/=(PD4& a, PD4 b)
{ a = a / b; return a; }

// The fused operations below use FMA intrinsics; HAVE_AVX2 is used here as
// the switch for FMA availability.
#ifdef HAVE_AVX2

// a*b+c (fused)
inline PD4
fused_muladd(PD4 a, PD4 b, PD4 c)
{ return _mm256_fmadd_pd(a.data, b.data, c.data); }
// NEEDS: FMA

// a*b-c (fused)
inline PD4
fused_mulsub(PD4 a, PD4 b, PD4 c)
{ return _mm256_fmsub_pd(a.data, b.data, c.data); }
// NEEDS: FMA

// -a*b+c (fused)
inline PD4
fused_negmuladd(PD4 a, PD4 b, PD4 c)
{ return _mm256_fnmadd_pd(a.data, b.data, c.data); }
// NEEDS: FMA

// (a0,a1,a2,a3), (b0,b1,b2,b3), (c0,c1,c2,c3) ->
// (a0*b0-c0, a1*b1+c1, a2*b2-c2, a3*b3+c3)
inline PD4
fmaddsub(PD4 a, PD4 b, PD4 c)
{ return _mm256_fmaddsub_pd(a.data, b.data, c.data); }
// NEEDS: FMA
// (plain addsub only needs AVX)

// (a0,a1,a2,a3), (b0,b1,b2,b3), (c0,c1,c2,c3) ->
// (a0*b0+c0, a1*b1-c1, a2*b2+c2, a3*b3-c3)
inline PD4
fmsubadd(PD4 a, PD4 b, PD4 c)
{ return _mm256_fmsubadd_pd(a.data, b.data, c.data); }
// NEEDS: FMA
// (there is no plain subadd)
#endif


#endif

}
347
348
349 /***************************************************************
350
351
352 TRUNCATED FFT
353
354 This code is derived from code originally developed
355 by David Harvey. I include his original documentation,
356 annotated appropriately to highlight differences in
the implementation (see NOTEs).
358
359 The DFT is defined as follows.
360
361 Let the input sequence be a_0, ..., a_{N-1}.
362
363 Let w = standard primitive N-th root of 1, i.e. w = g^(2^FFT62_MAX_LGN / N),
364 where g = some fixed element of Z/pZ of order 2^FFT62_MAX_LGN.
365
366 Let Z = an element of (Z/pZ)^* (twisting parameter).
367
368 Then the output sequence is
369 b_j = \sum_{0 <= i < N} Z^i a_i w^(ij'), for 0 <= j < N,
370 where j' is the length-lgN bit-reversal of j.
371
372 Some of the FFT routines can operate on truncated sequences of certain
373 "admissible" sizes. A size parameter n is admissible if 1 <= n <= N, and n is
374 divisible by a certain power of 2. The precise power depends on the recursive
375 array decomposition of the FFT. The smallest admissible n' >= n can be
376 obtained via fft62_next_size().
377
NOTE: the twisting parameter is not implemented.
NOTE: the next admissible size function is called FFTRoundUp.
380
381
382 Truncated FFT interface is as follows:
383
384 xn and yn must be admissible sizes for N.
385
386 Input in xp[] is a_0, a_1, ..., a_{xn-1}. Assumes a_i = 0 for xn <= i < N.
387
388 Output in yp[] is b_0, ..., b_{yn-1}, i.e. only first yn outputs are computed.
389
390 Twisting parameter Z is described by z and lgH. If z == 0, then Z = basic
391 2^lgH-th root of 1, and must have lgH >= lgN + 1. If z != 0, then Z = z
392 (and lgH is ignored).
393
394 The buffers {xp,xn} and {yp,yn} may overlap, but only if xp == yp.
395
396 Inputs are in [0, 2p), outputs are in [0, 2p).
397
398 threads = number of OpenMP threads to use.
399
400
401
402 Inverse truncated FFT interface is as follows.
403
404 xn and yn must be admissible sizes for N, with yn <= xn.
405
406 Input in xp[] is b_0, b_1, ..., b_{yn-1}, N*a_{yn}, ..., N*a_{xn-1}.
407
408 Assumes a_i = 0 for xn <= i < N.
409
410 Output in yp[] is N*a_0, ..., N*a_{yn-1}.
411
412 Twisting parameter Z is described by z and lgH. If z == 0, then Z = basic
413 2^lgH-th root of 1, and must have lgH >= lgN + 1. If z != 0, then Z = z^(-1)
414 (and lgH is ignored).
415
416 The buffers {xp,xn} and {yp,yn} may overlap, but only if xp == yp.
417
418 Inputs are in [0, 4p), outputs are in [0, 4p).
419
420 threads = number of OpenMP threads to use.
421
422 (note: no function actually implements this interface in full generality!
423 This is because it is tricky (and not that useful) to implement the twisting
424 parameter when xn != yn.)
425
426 NOTE: threads and twisting parameter are not used here.
427 NOTE: the code has been re-written and simplified so that
428 everything is done in place, so xp == yp.
429
430
431 ***************************************************************/
432
433
434
#define PGFFT_FFT_RDUP (4)
// Truncation lengths are rounded up to a multiple of 2^PGFFT_FFT_RDUP.
// Currently, this should be at least 2 to support
// loop unrolling in the FFT implementation


// Rounds a requested truncation length xn up to an admissible size for a
// transform of length n = 2^k.
//  - xn <= 0 (e.g. the default truncation value 0) means "no truncation"
//    and yields n.
//  - Otherwise xn is rounded up to a multiple of 2^PGFFT_FFT_RDUP.
//  - If the result is just below n (or above it), n is returned instead.
//    Truncating only slightly below n does not help and can sometimes slow
//    things down a little. Experimentally, the cutoff is n - n/8 for small
//    transforms (k < 10) and n - n/16 for larger ones.
static inline long
FFTRoundUp(long xn, long k)
{
  const long n = 1L << k;

  if (xn <= 0) return n;

  const long granule = 1L << PGFFT_FFT_RDUP;
  xn = (xn + granule - 1) / granule * granule;

  const long slack = (k >= 10) ? (n >> 4) : (n >> 3);
  return (xn > n - slack) ? n : xn;
}
464
465
466
467
// Forward radix-2 butterfly, in place on the lvalues xx0, xx1:
//   (x0, x1) -> (x0 + x1, (x0 - x1) * w)
#define fwd_butterfly(xx0, xx1, w) \
do \
{ \
   cmplx_t x0_ = xx0; \
   cmplx_t x1_ = xx1; \
   cmplx_t t_ = x0_ - x1_; \
   xx0 = x0_ + x1_; \
   xx1 = MUL(t_, w); \
} \
while (0)



// Forward butterfly with twiddle factor 1:
//   (x0, x1) -> (x0 + x1, x0 - x1)
#define fwd_butterfly0(xx0, xx1) \
do \
{ \
   cmplx_t x0_ = xx0; \
   cmplx_t x1_ = xx1; \
   xx0 = x0_ + x1_; \
   xx1 = x0_ - x1_; \
} \
while (0)


// Inverse butterfly with twiddle factor 1 (same arithmetic as
// fwd_butterfly0):
//   (x0, x1) -> (x0 + x1, x0 - x1)
#define inv_butterfly0(xx0, xx1) \
do \
{ \
   cmplx_t x0_ = xx0; \
   cmplx_t x1_ = xx1; \
   xx0 = x0_ + x1_; \
   xx1 = x0_ - x1_; \
} while (0)


// Inverse radix-2 butterfly, in place on the lvalues xx0, xx1:
//   (x0, x1) -> (x0 + x1*conj(w), x0 - x1*conj(w))
// (for a unit-modulus twiddle w, conj(w) == 1/w)
#define inv_butterfly(xx0, xx1, w) \
do \
{ \
   cmplx_t x0_ = xx0; \
   cmplx_t x1_ = xx1; \
   cmplx_t t_ = CMUL(x1_, w); \
   xx0 = x0_ + t_; \
   xx1 = x0_ - t_; \
} while (0)
511
512
513
#ifdef USE_PD4

#ifdef HAVE_AVX2

// Multiplies two packed complex pairs: ab = (a0, b0, a1, b1) holds
// a0 + i*b0 and a1 + i*b1; cd likewise. Per pair the result is
// (a+bi)(c+di) = (ac - bd, bc + ad).
static inline PD4
complex_mul(PD4 ab, PD4 cd)
{
   PD4 cc = dup2even(cd);   // (c, c) per pair
   PD4 dd = dup2odd(cd);    // (d, d) per pair
   PD4 ba = swap2(ab);      // (b, a) per pair
   // even slot: a*c - b*d, odd slot: b*c + a*d
   return fmaddsub(ab, cc, ba*dd);
}

// Packed complex multiply by the conjugate: (a+bi) * conj(c+di)
static inline PD4
complex_conj_mul(PD4 ab, PD4 cd)
// (ac+bd,bc-ad)
{
   PD4 cc = dup2even(cd);
   PD4 dd = dup2odd(cd);
   PD4 ba = swap2(ab);
   // even slot: a*c + b*d, odd slot: b*c - a*d
   return fmsubadd(ab, cc, ba*dd);
}


// x_0 = a_0 * b_0 and x_1 = a_1 * b_1, where each operand holds two
// interleaved complex values.
#define MUL2(x_0, x_1, a_0, a_1, b_0, b_1) \
do { \
   x_0 = complex_mul(a_0, b_0); \
   x_1 = complex_mul(a_1, b_1); \
} while (0)

// x_0 = a_0 * conj(b_0) and x_1 = a_1 * conj(b_1)
#define CMUL2(x_0, x_1, a_0, a_1, b_0, b_1) \
do { \
   x_0 = complex_conj_mul(a_0, b_0); \
   x_1 = complex_conj_mul(a_1, b_1); \
} while (0)

#else
// This code sequence works without FMA.
// The four complex values in (a_0, a_1) are split into a vector of real
// parts and a vector of imaginary parts (blend_even/blend_odd). The products
// are formed component-wise, and the results are re-interleaved with the
// same unpack operations, which restores the original order.
#define MUL2(x_0, x_1, a_0, a_1, b_0, b_1) \
do { \
   PD4 a_re_ = blend_even(a_0, a_1); \
   PD4 a_im_ = blend_odd(a_0, a_1); \
 \
   PD4 b_re_ = blend_even(b_0, b_1); \
   PD4 b_im_ = blend_odd(b_0, b_1); \
 \
   PD4 x_re_ = a_re_*b_re_ - a_im_*b_im_; \
   PD4 x_im_ = a_re_*b_im_ + a_im_*b_re_; \
 \
   x_0 = blend_even(x_re_, x_im_); \
   x_1 = blend_odd(x_re_, x_im_); \
} while (0)

// Same layout trick for a * conj(b): (re, im) = (ar*br + ai*bi, ai*br - ar*bi)
#define CMUL2(x_0, x_1, a_0, a_1, b_0, b_1) \
do { \
   PD4 a_re_ = blend_even(a_0, a_1); \
   PD4 a_im_ = blend_odd(a_0, a_1); \
 \
   PD4 b_re_ = blend_even(b_0, b_1); \
   PD4 b_im_ = blend_odd(b_0, b_1); \
 \
   PD4 x_re_ = a_re_*b_re_ + a_im_*b_im_; \
   PD4 x_im_ = a_im_*b_re_ - a_re_*b_im_; \
 \
   x_0 = blend_even(x_re_, x_im_); \
   x_1 = blend_odd(x_re_, x_im_); \
} while (0)

#endif
583
584
585
// AVX forward butterfly loop on interleaved (re, im) double arrays. For each
// complex index j < size:
//   (x0[j], x1[j]) -> (x0[j] + x1[j], (x0[j] - x1[j]) * w[j])
// Each iteration handles 4 complex values (two PD4 registers per array), so
// size is expected to be a multiple of 4. The aligned loads and stores
// require all three arrays to be 32-byte aligned.
static inline void
fwd_butterfly_loop_simd(
long size,
double * RESTRICT xp0,
double * RESTRICT xp1,
const double * RESTRICT wtab)
{
   for (long j = 0; j < size; j += 4) {
      // complex j, j+1 occupy doubles 2j..2j+3; complex j+2, j+3 follow
      PD4 x0_0 = PD4::load(xp0+2*(j+0));
      PD4 x0_1 = PD4::load(xp0+2*(j+2));
      PD4 x1_0 = PD4::load(xp1+2*(j+0));
      PD4 x1_1 = PD4::load(xp1+2*(j+2));
      PD4 w_0 = PD4::load(wtab+2*(j+0));
      PD4 w_1 = PD4::load(wtab+2*(j+2));

      PD4 xx0_0 = x0_0 + x1_0;
      PD4 xx0_1 = x0_1 + x1_1;

      PD4 diff_0 = x0_0 - x1_0;
      PD4 diff_1 = x0_1 - x1_1;

      // twiddle the difference
      PD4 xx1_0, xx1_1;
      MUL2(xx1_0, xx1_1, diff_0, diff_1, w_0, w_1);

      store(xp0+2*(j+0), xx0_0);
      store(xp0+2*(j+2), xx0_1);
      store(xp1+2*(j+0), xx1_0);
      store(xp1+2*(j+2), xx1_1);
   }
}
616
617 static inline void
fwd_butterfly_loop(long size,cmplx_t * RESTRICT xp0,cmplx_t * RESTRICT xp1,const cmplx_t * RESTRICT wtab)618 fwd_butterfly_loop(
619 long size,
620 cmplx_t * RESTRICT xp0,
621 cmplx_t * RESTRICT xp1,
622 const cmplx_t * RESTRICT wtab)
623 {
624 // NOTE: C++11 guarantees that these reinterpret_cast's work as expected
625 fwd_butterfly_loop_simd(
626 size,
627 reinterpret_cast<double*>(xp0),
628 reinterpret_cast<double*>(xp1),
629 reinterpret_cast<const double*>(wtab));
630 }
631
// AVX inverse butterfly loop on interleaved (re, im) double arrays. For each
// complex index j < size:
//   t = x1[j] * conj(w[j]);  (x0[j], x1[j]) -> (x0[j] + t, x0[j] - t)
// Each iteration handles 4 complex values, so size is expected to be a
// multiple of 4. All three arrays must be 32-byte aligned (aligned loads
// and stores).
static inline void
inv_butterfly_loop_simd(
long size,
double * RESTRICT xp0,
double * RESTRICT xp1,
const double * RESTRICT wtab)
{
   for (long j = 0; j < size; j += 4) {
      PD4 x0_0 = PD4::load(xp0+2*(j+0));
      PD4 x0_1 = PD4::load(xp0+2*(j+2));
      PD4 x1_0 = PD4::load(xp1+2*(j+0));
      PD4 x1_1 = PD4::load(xp1+2*(j+2));
      PD4 w_0 = PD4::load(wtab+2*(j+0));
      PD4 w_1 = PD4::load(wtab+2*(j+2));

      // t = x1 * conj(w)
      PD4 t_0, t_1;
      CMUL2(t_0, t_1, x1_0, x1_1, w_0, w_1);

      PD4 xx0_0 = x0_0 + t_0;
      PD4 xx0_1 = x0_1 + t_1;

      PD4 xx1_0 = x0_0 - t_0;
      PD4 xx1_1 = x0_1 - t_1;

      store(xp0+2*(j+0), xx0_0);
      store(xp0+2*(j+2), xx0_1);
      store(xp1+2*(j+0), xx1_0);
      store(xp1+2*(j+2), xx1_1);
   }
}
662
663 static inline void
inv_butterfly_loop(long size,cmplx_t * RESTRICT xp0,cmplx_t * RESTRICT xp1,const cmplx_t * RESTRICT wtab)664 inv_butterfly_loop(
665 long size,
666 cmplx_t * RESTRICT xp0,
667 cmplx_t * RESTRICT xp1,
668 const cmplx_t * RESTRICT wtab)
669 {
670 // NOTE: C++11 guarantees that these reinterpret_cast's work as expected
671 inv_butterfly_loop_simd(
672 size,
673 reinterpret_cast<double*>(xp0),
674 reinterpret_cast<double*>(xp1),
675 reinterpret_cast<const double*>(wtab));
676 }
677
678 #else
679
680 static inline void
fwd_butterfly_loop(long size,cmplx_t * RESTRICT xp0,cmplx_t * RESTRICT xp1,const cmplx_t * RESTRICT wtab)681 fwd_butterfly_loop(
682 long size,
683 cmplx_t * RESTRICT xp0,
684 cmplx_t * RESTRICT xp1,
685 const cmplx_t * RESTRICT wtab)
686 {
687 fwd_butterfly0(xp0[0+0], xp1[0+0]);
688 fwd_butterfly(xp0[0+1], xp1[0+1], wtab[0+1]);
689 fwd_butterfly(xp0[0+2], xp1[0+2], wtab[0+2]);
690 fwd_butterfly(xp0[0+3], xp1[0+3], wtab[0+3]);
691 for (long j = 4; j < size; j += 4) {
692 fwd_butterfly(xp0[j+0], xp1[j+0], wtab[j+0]);
693 fwd_butterfly(xp0[j+1], xp1[j+1], wtab[j+1]);
694 fwd_butterfly(xp0[j+2], xp1[j+2], wtab[j+2]);
695 fwd_butterfly(xp0[j+3], xp1[j+3], wtab[j+3]);
696 }
697 }
698
699 static inline void
inv_butterfly_loop(long size,cmplx_t * RESTRICT xp0,cmplx_t * RESTRICT xp1,const cmplx_t * RESTRICT wtab)700 inv_butterfly_loop(
701 long size,
702 cmplx_t * RESTRICT xp0,
703 cmplx_t * RESTRICT xp1,
704 const cmplx_t * RESTRICT wtab)
705 {
706 inv_butterfly0(xp0[0+0], xp1[0+0]);
707 inv_butterfly(xp0[0+1], xp1[0+1], wtab[0+1]);
708 inv_butterfly(xp0[0+2], xp1[0+2], wtab[0+2]);
709 inv_butterfly(xp0[0+3], xp1[0+3], wtab[0+3]);
710 for (long j = 4; j < size; j += 4) {
711 inv_butterfly(xp0[j+0], xp1[j+0], wtab[j+0]);
712 inv_butterfly(xp0[j+1], xp1[j+1], wtab[j+1]);
713 inv_butterfly(xp0[j+2], xp1[j+2], wtab[j+2]);
714 inv_butterfly(xp0[j+3], xp1[j+3], wtab[j+3]);
715 }
716 }
717
718 #endif
719
720
#if (defined(USE_PD4))


// AVX in-place complex pointwise product on interleaved (re, im) double
// arrays: x[j] = x[j] * y[j] for complex indices j < size.
// There is no tail handling (4 complex values per iteration), so size is
// expected to be a multiple of 4. Both arrays must be 32-byte aligned for
// the aligned loads and stores.
static inline void
mul_loop_simd(
long size,
double * RESTRICT xp,
const double * yp)
{
   long j;
   for (j = 0; j < size; j += 4) {
      PD4 x_0 = PD4::load(xp+2*(j+0));
      PD4 x_1 = PD4::load(xp+2*(j+2));
      PD4 y_0 = PD4::load(yp+2*(j+0));
      PD4 y_1 = PD4::load(yp+2*(j+2));

      PD4 z_0, z_1;
      MUL2(z_0, z_1, x_0, x_1, y_0, y_1);

      store(xp+2*(j+0), z_0);
      store(xp+2*(j+2), z_1);
   }
}
744
745
746 static inline void
mul_loop(long size,cmplx_t * xp,const cmplx_t * yp)747 mul_loop(
748 long size,
749 cmplx_t * xp,
750 const cmplx_t * yp)
751 {
752 // NOTE: C++11 guarantees that these reinterpret_cast's work as expected
753 mul_loop_simd(
754 size,
755 reinterpret_cast<double*>(xp),
756 reinterpret_cast<const double*>(yp));
757 }
758
759
760 #else
761
762
763 static inline void
mul_loop(long size,cmplx_t * xp,const cmplx_t * yp)764 mul_loop(
765 long size,
766 cmplx_t * xp,
767 const cmplx_t * yp)
768 {
769 for (long j = 0; j < size; j++)
770 xp[j] = MUL(xp[j], yp[j]);
771 }
772
773 #endif
774
775
776 // requires size divisible by 8
777 static void
new_fft_layer(cmplx_t * xp,long blocks,long size,const cmplx_t * RESTRICT wtab)778 new_fft_layer(cmplx_t* xp, long blocks, long size,
779 const cmplx_t* RESTRICT wtab)
780 {
781 size /= 2;
782
783 do
784 {
785 cmplx_t* RESTRICT xp0 = xp;
786 cmplx_t* RESTRICT xp1 = xp + size;
787
788 fwd_butterfly_loop(size, xp0, xp1, wtab);
789
790 xp += 2 * size;
791 }
792 while (--blocks != 0);
793 }
794
795
796
797 static void
new_fft_last_two_layers(cmplx_t * xp,long blocks,const cmplx_t * wtab)798 new_fft_last_two_layers(cmplx_t* xp, long blocks, const cmplx_t* wtab)
799 {
800 // 4th root of unity
801 cmplx_t w = wtab[1];
802
803 do
804 {
805 cmplx_t u0 = xp[0];
806 cmplx_t u1 = xp[1];
807 cmplx_t u2 = xp[2];
808 cmplx_t u3 = xp[3];
809
810 cmplx_t v0 = u0 + u2;
811 cmplx_t v2 = u0 - u2;
812 cmplx_t v1 = u1 + u3;
813 cmplx_t t = u1 - u3;
814
815 //cmplx_t v3 = MUL(t, w);
816 // DIRT: relies on w == (0,-1)
817 cmplx_t v3(t.imag(), -t.real());
818
819
820 xp[0] = v0 + v1;
821 xp[1] = v0 - v1;
822 xp[2] = v2 + v3;
823 xp[3] = v2 - v3;
824
825 xp += 4;
826 }
827 while (--blocks != 0);
828 }
829
830
831 static void
new_fft_base(cmplx_t * xp,long lgN,const vector<aligned_vector<cmplx_t>> & tab)832 new_fft_base(cmplx_t* xp, long lgN, const vector<aligned_vector<cmplx_t>>& tab)
833 {
834 if (lgN == 0) return;
835
836 if (lgN == 1)
837 {
838 cmplx_t x0 = xp[0];
839 cmplx_t x1 = xp[1];
840 xp[0] = x0 + x1;
841 xp[1] = x0 - x1;
842 return;
843 }
844
845
846 long N = 1L << lgN;
847
848 for (long j = lgN, size = N, blocks = 1;
849 j > 2; j--, blocks <<= 1, size >>= 1)
850 new_fft_layer(xp, blocks, size, &tab[j][0]);
851
852 new_fft_last_two_layers(xp, N/4, &tab[2][0]);
853 }
854
855
856 // Implements the truncated FFT interface, described above.
857 // All computations done in place, and xp should point to
858 // an array of size N, all of which may be overwitten
859 // during the computation.
860
861 #define PGFFT_NEW_FFT_THRESH (10)
862
863 static
new_fft_short(cmplx_t * xp,long yn,long xn,long lgN,const vector<aligned_vector<cmplx_t>> & tab)864 void new_fft_short(cmplx_t* xp, long yn, long xn, long lgN,
865 const vector<aligned_vector<cmplx_t>>& tab)
866 {
867 long N = 1L << lgN;
868
869 if (yn == N)
870 {
871 if (xn == N && lgN <= PGFFT_NEW_FFT_THRESH)
872 {
873 // no truncation
874 new_fft_base(xp, lgN, tab);
875 return;
876 }
877 }
878
879 // divide-and-conquer algorithm
880
881 long half = N >> 1;
882
883 if (yn <= half)
884 {
885 if (xn <= half)
886 {
887 new_fft_short(xp, yn, xn, lgN - 1, tab);
888 }
889 else
890 {
891 xn -= half;
892
893 // (X, Y) -> X + Y
894 for (long j = 0; j < xn; j++)
895 xp[j] = xp[j] + xp[j + half];
896
897 new_fft_short(xp, yn, half, lgN - 1, tab);
898 }
899 }
900 else
901 {
902 yn -= half;
903
904 cmplx_t* RESTRICT xp0 = xp;
905 cmplx_t* RESTRICT xp1 = xp + half;
906 const cmplx_t* RESTRICT wtab = &tab[lgN][0];
907
908 if (xn <= half)
909 {
910 // X -> (X, w*X)
911 for (long j = 0; j < xn; j++)
912 xp1[j] = MUL(xp0[j], wtab[j]);
913
914 new_fft_short(xp0, half, xn, lgN - 1, tab);
915 new_fft_short(xp1, yn, xn, lgN - 1, tab);
916 }
917 else
918 {
919 xn -= half;
920
921 // (X, Y) -> (X + Y, w*(X - Y))
922 // DIRT: assumes xn is a multiple of 4
923 fwd_butterfly_loop(xn, xp0, xp1, wtab);
924
925 // X -> (X, w*X)
926 for (long j = xn; j < half; j++)
927 xp1[j] = MUL(xp0[j], wtab[j]);
928
929 new_fft_short(xp0, half, half, lgN - 1, tab);
930 new_fft_short(xp1, yn, half, lgN - 1, tab);
931 }
932 }
933 }
934
new_fft(cmplx_t * xp,long lgN,const vector<aligned_vector<cmplx_t>> & tab)935 static void new_fft(cmplx_t* xp, long lgN, const vector<aligned_vector<cmplx_t>>& tab)
936 {
937 long N = 1L << lgN;
938 new_fft_short(xp, N, N, lgN, tab);
939 }
940
941
942
943
944
945 // requires size divisible by 8
946 static void
new_ifft_layer(cmplx_t * xp,long blocks,long size,const cmplx_t * RESTRICT wtab)947 new_ifft_layer(cmplx_t* xp, long blocks, long size,
948 const cmplx_t* RESTRICT wtab)
949 {
950
951 size /= 2;
952
953 do
954 {
955
956 cmplx_t* RESTRICT xp0 = xp;
957 cmplx_t* RESTRICT xp1 = xp + size;
958
959
960 inv_butterfly_loop(size, xp0, xp1, wtab);
961
962 xp += 2 * size;
963 }
964 while (--blocks != 0);
965 }
966
967 static void
new_ifft_first_two_layers(cmplx_t * xp,long blocks,const cmplx_t * wtab)968 new_ifft_first_two_layers(cmplx_t* xp, long blocks, const cmplx_t* wtab)
969 {
970 // 4th root of unity
971 cmplx_t w = wtab[1];
972
973 do
974 {
975 cmplx_t u0 = xp[0];
976 cmplx_t u1 = xp[1];
977 cmplx_t u2 = xp[2];
978 cmplx_t u3 = xp[3];
979
980 cmplx_t v0 = u0 + u1;
981 cmplx_t v1 = u0 - u1;
982 cmplx_t v2 = u2 + u3;
983 cmplx_t t = u2 - u3;
984
985 //cmplx_t v3 = CMUL(t, w);
986 // DIRT: relies on w == (0,1)
987 cmplx_t v3(-t.imag(), t.real());
988
989 xp[0] = v0 + v2;
990 xp[2] = v0 - v2;
991 xp[1] = v1 + v3;
992 xp[3] = v1 - v3;
993
994 xp += 4;
995 }
996 while (--blocks != 0);
997 }
998
999
1000 static void
new_ifft_base(cmplx_t * xp,long lgN,const vector<aligned_vector<cmplx_t>> & tab)1001 new_ifft_base(cmplx_t* xp, long lgN, const vector<aligned_vector<cmplx_t>>& tab)
1002 {
1003 if (lgN == 0) return;
1004
1005
1006 if (lgN == 1)
1007 {
1008 cmplx_t x0 = xp[0];
1009 cmplx_t x1 = xp[1];
1010 xp[0] = x0 + x1;
1011 xp[1] = x0 - x1;
1012 return;
1013 }
1014
1015
1016 long blocks = 1L << (lgN - 2);
1017 new_ifft_first_two_layers(xp, blocks, &tab[2][0]);
1018 blocks >>= 1;
1019
1020 long size = 8;
1021 for (long j = 3; j <= lgN; j++, blocks >>= 1, size <<= 1)
1022 new_ifft_layer(xp, blocks, size, &tab[j][0]);
1023 }
1024
// Forward declaration: new_ifft_short1 and new_ifft_short2 are mutually
// recursive.
static
void new_ifft_short2(cmplx_t* yp, long yn, long lgN, const vector<aligned_vector<cmplx_t>>& tab);



static
void new_ifft_short1(cmplx_t* xp, long yn, long lgN, const vector<aligned_vector<cmplx_t>>& tab)

// Implements truncated inverse FFT interface, but with xn==yn.
// All computations are done in place.
// The result is unnormalized (scaled by N = 2^lgN, per the interface above).
// yn is expected to be an admissible size.

{
   long N = 1L << lgN;

   if (yn == N && lgN <= PGFFT_NEW_FFT_THRESH)
   {
      // no truncation
      new_ifft_base(xp, lgN, tab);
      return;
   }

   // divide-and-conquer algorithm

   long half = N >> 1;

   if (yn <= half)
   {
      // Everything lives in the lower half; the doubling accounts for the
      // scale factor of the one layer that is skipped.
      // X -> 2X
      for (long j = 0; j < yn; j++)
      	xp[j] = 2.0 * xp[j];

      new_ifft_short1(xp, yn, lgN - 1, tab);
   }
   else
   {
      cmplx_t* RESTRICT xp0 = xp;
      cmplx_t* RESTRICT xp1 = xp + half;
      const cmplx_t* RESTRICT wtab = &tab[lgN][0];

      // lower half is fully known: invert it untruncated
      new_ifft_short1(xp0, half, lgN - 1, tab);

      yn -= half;

      // Fill in the upper-half inputs at [yn, half) from the lower half
      // before the truncated inverse of the upper half.
      // X -> (2X, w*X)
      for (long j = yn; j < half; j++)
      {
	 cmplx_t x0 = xp0[j];
	 xp0[j] = 2.0 * x0;
	 xp1[j] = MUL(x0, wtab[j]);
      }

      new_ifft_short2(xp1, yn, lgN - 1, tab);

      // (X, Y) -> (X + Y/w, X - Y/w)
      // (inv_butterfly_loop multiplies by conj(w) == 1/w for |w| == 1)
      {
         inv_butterfly_loop(yn, xp0, xp1, wtab);
      }
   }
}
1084
1085
1086
static
void new_ifft_short2(cmplx_t* xp, long yn, long lgN, const vector<aligned_vector<cmplx_t>>& tab)

// Implements truncated inverse FFT interface, but with xn==N.
// All computations are done in place.
// On entry xp[0 .. yn) holds transform values, and xp[yn .. N) holds
// N*a_i values (see the interface description above). yn is expected to be
// an admissible size.

{
   long N = 1L << lgN;

   if (yn == N && lgN <= PGFFT_NEW_FFT_THRESH)
   {
      // no truncation
      new_ifft_base(xp, lgN, tab);
      return;
   }

   // divide-and-conquer algorithm

   long half = N >> 1;

   if (yn <= half)
   {
      // X -> 2X
      for (long j = 0; j < yn; j++)
     	 xp[j] = 2.0 * xp[j];
      // (X, Y) -> X + Y
      for (long j = yn; j < half; j++)
	 xp[j] = xp[j] + xp[j + half];

      new_ifft_short2(xp, yn, lgN - 1, tab);

      // (X, Y) -> X - Y
      for (long j = 0; j < yn; j++)
	 xp[j] = xp[j] - xp[j + half];
   }
   else
   {
      cmplx_t* RESTRICT xp0 = xp;
      cmplx_t* RESTRICT xp1 = xp + half;
      const cmplx_t* RESTRICT wtab = &tab[lgN][0];

      // lower half is fully known: invert it untruncated
      new_ifft_short1(xp0, half, lgN - 1, tab);

      yn -= half;


      // (X, Y) -> (2X - Y, w*(X - Y))
      // (computed as X + (X - Y) to reuse the difference)
      for (long j = yn; j < half; j++)
      {
	 cmplx_t x0 = xp0[j];
	 cmplx_t x1 = xp1[j];
	 cmplx_t u = x0 - x1;
	 xp0[j] = x0 + u;
	 xp1[j] = MUL(u, wtab[j]);
      }

      new_ifft_short2(xp1, yn, lgN - 1, tab);

      // (X, Y) -> (X + Y/w, X - Y/w)
      // (inv_butterfly_loop multiplies by conj(w) == 1/w for |w| == 1)
      {
         inv_butterfly_loop(yn, xp0, xp1, wtab);
      }
   }
}
1151
1152
1153
1154 static void
new_ifft(cmplx_t * xp,long lgN,const vector<aligned_vector<cmplx_t>> & tab)1155 new_ifft(cmplx_t* xp, long lgN, const vector<aligned_vector<cmplx_t>>& tab)
1156 {
1157 long N = 1L << lgN;
1158 new_ifft_short1(xp, N, lgN, tab);
1159 }
1160
1161
1162 static void
compute_table(vector<aligned_vector<cmplx_t>> & tab,long k)1163 compute_table(vector<aligned_vector<cmplx_t>>& tab, long k)
1164 {
1165 if (k < 2) return;
1166
1167 const ldbl pi = std::atan(ldbl(1)) * 4.0;
1168
1169 tab.resize(k+1);
1170 for (long s = 2; s <= k; s++) {
1171 long m = 1L << s;
1172 tab[s].resize(m/2);
1173 for (long j = 0; j < m/2; j++) {
1174 ldbl angle = -((2 * pi) * (ldbl(j)/ldbl(m)));
1175 tab[s][j] = cmplx_t(std::cos(angle), std::sin(angle));
1176 }
1177 }
1178 }
1179
// Returns the successor of a in bit-reversed counting order on k bits, i.e.
// rev_k(rev_k(a) + 1 mod 2^k). Starting at bit k-1 and moving down, it clears
// the run of set bits and then sets the first clear bit, wrapping from 2^k - 1
// to 0. For k <= 0 there are no bits to count with, and a is returned
// unchanged. This also avoids evaluating 1L << -1, which is undefined.
static long
RevInc(long a, long k)
{
  if (k <= 0) return a;

  long j = k;
  long m = 1L << (k-1);

  while (j && (m & a)) {
    a ^= m;
    m >>= 1;
    j--;
  }
  if (j) a ^= m;
  return a;
}
1196
1197 static void
BRC_init(long k,vector<long> & rev)1198 BRC_init(long k, vector<long>& rev)
1199 {
1200 long n = (1L << k);
1201 rev.resize(n);
1202 long i, j;
1203 for (i = 0, j = 0; i < n; i++, j = RevInc(j, k))
1204 rev[i] = j;
1205 }
1206
1207
1208 #define PGFFT_BRC_THRESH (11)
1209 #define PGFFT_BRC_Q (5)
1210 // Must have PGFFT_BRC_THRESH >= 2*PGFFT_BRC_Q
1211 // Should also have (1L << (2*PGFFT_BRC_Q)) small enough
1212 // so that we can fit that many cmplx_t's into the cache
1213
1214
1215 static
BasicBitReverseCopy(cmplx_t * B,const cmplx_t * A,long k,const vector<long> & rev)1216 void BasicBitReverseCopy(cmplx_t *B,
1217 const cmplx_t *A, long k, const vector<long>& rev)
1218 {
1219 long n = 1L << k;
1220 long i, j;
1221
1222 for (i = 0; i < n; i++)
1223 B[rev[i]] = A[i];
1224 }
1225
1226 static void
COBRA(cmplx_t * RESTRICT B,const cmplx_t * RESTRICT A,long k,const vector<long> & rev,const vector<long> rev1)1227 COBRA(cmplx_t * RESTRICT B, const cmplx_t * RESTRICT A, long k,
1228 const vector<long>& rev, const vector<long> rev1)
1229 {
1230 constexpr long q = PGFFT_BRC_Q;
1231 long k1 = k - 2*q;
1232
1233 aligned_vector<cmplx_t> BRC_temp(1L << (2*q));
1234
1235 cmplx_t * RESTRICT T = &BRC_temp[0];
1236 const long * RESTRICT rev_k1 = &rev[0];
1237 const long * RESTRICT rev_q = &rev1[0];
1238
1239
1240 for (long b = 0; b < (1L << k1); b++) {
1241 long b1 = rev_k1[b];
1242 for (long a = 0; a < (1L << q); a++) {
1243 long a1 = rev_q[a];
1244 cmplx_t *T_p = &T[a1 << q];
1245 const cmplx_t *A_p = &A[(a << (k1+q)) + (b << q)];
1246 #ifdef USE_PD4
1247 for (long c = 0; c < (1 << q); c += 4) {
1248 PD4 x0 = PD4::load(reinterpret_cast<const double*>(&A_p[c+0]));
1249 PD4 x1 = PD4::load(reinterpret_cast<const double*>(&A_p[c+2]));
1250 store(reinterpret_cast<double*>(&T_p[c+0]), x0);
1251 store(reinterpret_cast<double*>(&T_p[c+2]), x1);
1252 }
1253 #else
1254 for (long c = 0; c < (1L << q); c++) T_p[c] = A_p[c];
1255 #endif
1256 }
1257
1258 for (long c = 0; c < (1L << q); c++) {
1259 long c1 = rev_q[c];
1260 cmplx_t *B_p = &B[(c1 << (k1+q)) + (b1 << q)];
1261 cmplx_t *T_p = &T[c];
1262 for (long a1 = 0; a1 < (1l << q); a1++)
1263 B_p[a1] = T_p[a1 << q];
1264 }
1265 }
1266 }
1267
1268
1269 static long
pow2_precomp(long n,vector<long> & rev,vector<long> & rev1,vector<aligned_vector<cmplx_t>> & tab)1270 pow2_precomp(long n, vector<long>& rev, vector<long>& rev1, vector<aligned_vector<cmplx_t>>& tab)
1271 {
1272 // k = least k such that 2^k >= n
1273 long k = 0;
1274 while ((1L << k) < n) k++;
1275
1276 compute_table(tab, k);
1277
1278
1279 if (k <= PGFFT_BRC_THRESH) {
1280 BRC_init(k, rev);
1281 }
1282 else {
1283 long q = PGFFT_BRC_Q;
1284 long k1 = k - 2*q;
1285 BRC_init(k1, rev);
1286 BRC_init(q, rev1);
1287 }
1288
1289
1290 return k;
1291 }
1292
1293 static void
pow2_comp(const cmplx_t * src,cmplx_t * dst,long n,long k,const vector<long> & rev,const vector<long> & rev1,const vector<aligned_vector<cmplx_t>> & tab)1294 pow2_comp(const cmplx_t* src, cmplx_t* dst,
1295 long n, long k, const vector<long>& rev, const vector<long>& rev1,
1296 const vector<aligned_vector<cmplx_t>>& tab)
1297 {
1298 aligned_vector<cmplx_t> x;
1299 x.assign(src, src+n);
1300
1301 new_fft(&x[0], k, tab);
1302 #if 0
1303 for (long i = 0; i < n; i++) dst[i] = x[i];
1304 #else
1305 if (k <= PGFFT_BRC_THRESH)
1306 BasicBitReverseCopy(&dst[0], &x[0], k, rev);
1307 else
1308 COBRA(&dst[0], &x[0], k, rev, rev1);
1309 #endif
1310 }
1311
1312 static long
bluestein_precomp(long n,aligned_vector<cmplx_t> & powers,aligned_vector<cmplx_t> & Rb,vector<aligned_vector<cmplx_t>> & tab)1313 bluestein_precomp(long n, aligned_vector<cmplx_t>& powers,
1314 aligned_vector<cmplx_t>& Rb,
1315 vector<aligned_vector<cmplx_t>>& tab)
1316 {
1317 // k = least k such that 2^k >= 2*n-1
1318 long k = 0;
1319 while ((1L << k) < 2*n-1) k++;
1320
1321 compute_table(tab, k);
1322
1323 const ldbl pi = std::atan(ldbl(1)) * 4.0;
1324
1325 powers.resize(n);
1326 powers[0] = 1;
1327 long i_sqr = 0;
1328 for (long i = 1; i < n; i++) {
1329 // i^2 = (i-1)^2 + 2*i-1
1330 i_sqr = (i_sqr + 2*i - 1) % (2*n);
1331 ldbl angle = -((2 * pi) * (ldbl(i_sqr)/ldbl(2*n)));
1332 powers[i] = cmplx_t(std::cos(angle), std::sin(angle));
1333 }
1334
1335 long N = 1L << k;
1336 Rb.resize(N);
1337 for (long i = 0; i < N; i++) Rb[i] = 0;
1338
1339 Rb[n-1] = 1;
1340 i_sqr = 0;
1341 for (long i = 1; i < n; i++) {
1342 // i^2 = (i-1)^2 + 2*i-1
1343 i_sqr = (i_sqr + 2*i - 1) % (2*n);
1344 ldbl angle = (2 * pi) * (ldbl(i_sqr)/ldbl(2*n));
1345 Rb[n-1+i] = Rb[n-1-i] = cmplx_t(std::cos(angle), std::sin(angle));
1346 }
1347
1348 new_fft(&Rb[0], k, tab);
1349
1350 double Ninv = 1/double(N);
1351 for (long i = 0; i < N; i++)
1352 Rb[i] *= Ninv;
1353
1354 return k;
1355
1356 }
1357
1358
1359 static void
bluestein_comp(const cmplx_t * src,cmplx_t * dst,long n,long k,const aligned_vector<cmplx_t> & powers,const aligned_vector<cmplx_t> & Rb,const vector<aligned_vector<cmplx_t>> & tab)1360 bluestein_comp(const cmplx_t* src, cmplx_t* dst,
1361 long n, long k, const aligned_vector<cmplx_t>& powers,
1362 const aligned_vector<cmplx_t>& Rb,
1363 const vector<aligned_vector<cmplx_t>>& tab)
1364 {
1365 long N = 1L << k;
1366
1367 aligned_vector<cmplx_t> x(N);
1368
1369 for (long i = 0; i < n; i++)
1370 x[i] = MUL(src[i], powers[i]);
1371
1372 for (long i = n; i < N; i++)
1373 x[i] = 0;
1374
1375 new_fft(&x[0], k, tab);
1376
1377 // for (long i = 0; i < N; i++) x[i] = MUL(x[i], Rb[i]);
1378 mul_loop(N, &x[0], &Rb[0]);
1379
1380 new_ifft(&x[0], k, tab);
1381
1382 double Ninv = 1/double(N);
1383
1384 for (long i = 0; i < n; i++)
1385 dst[i] = MUL(x[n-1+i], powers[i]);
1386
1387 }
1388
1389
1390
1391
1392 static long
bluestein_precomp1(long n,aligned_vector<cmplx_t> & powers,aligned_vector<cmplx_t> & Rb,vector<aligned_vector<cmplx_t>> & tab)1393 bluestein_precomp1(long n, aligned_vector<cmplx_t>& powers,
1394 aligned_vector<cmplx_t>& Rb,
1395 vector<aligned_vector<cmplx_t>>& tab)
1396 {
1397 // k = least k such that 2^k >= 2*n-1
1398 long k = 0;
1399 while ((1L << k) < 2*n-1) k++;
1400
1401 compute_table(tab, k);
1402
1403 const ldbl pi = std::atan(ldbl(1)) * 4.0;
1404
1405 powers.resize(n);
1406 powers[0] = 1;
1407 long i_sqr = 0;
1408
1409 if (n % 2 == 0) {
1410 for (long i = 1; i < n; i++) {
1411 // i^2 = (i-1)^2 + 2*i-1
1412 i_sqr = (i_sqr + 2*i - 1) % (2*n);
1413 ldbl angle = -((2 * pi) * (ldbl(i_sqr)/ldbl(2*n)));
1414 powers[i] = cmplx_t(std::cos(angle), std::sin(angle));
1415 }
1416 }
1417 else {
1418 for (long i = 1; i < n; i++) {
1419 // i^2*((n+1)/2) = (i-1)^2*((n+1)/2) + i + ((n-1)/2) (mod n)
1420 i_sqr = (i_sqr + i + (n-1)/2) % n;
1421 ldbl angle = -((2 * pi) * (ldbl(i_sqr)/ldbl(n)));
1422 powers[i] = cmplx_t(std::cos(angle), std::sin(angle));
1423 }
1424 }
1425
1426 long N = 1L << k;
1427 Rb.resize(N);
1428 for (long i = 0; i < N; i++) Rb[i] = 0;
1429
1430 Rb[0] = 1;
1431 i_sqr = 0;
1432
1433 if (n % 2 == 0) {
1434 for (long i = 1; i < n; i++) {
1435 // i^2 = (i-1)^2 + 2*i-1
1436 i_sqr = (i_sqr + 2*i - 1) % (2*n);
1437 ldbl angle = (2 * pi) * (ldbl(i_sqr)/ldbl(2*n));
1438 Rb[i] = cmplx_t(std::cos(angle), std::sin(angle));
1439 }
1440 }
1441 else {
1442 for (long i = 1; i < n; i++) {
1443 // i^2*((n+1)/2) = (i-1)^2*((n+1)/2) + i + ((n-1)/2) (mod n)
1444 i_sqr = (i_sqr + i + (n-1)/2) % n;
1445 ldbl angle = (2 * pi) * (ldbl(i_sqr)/ldbl(n));
1446 Rb[i] = cmplx_t(std::cos(angle), std::sin(angle));
1447 }
1448 }
1449
1450 new_fft(&Rb[0], k, tab);
1451
1452 double Ninv = 1/double(N);
1453 for (long i = 0; i < N; i++)
1454 Rb[i] *= Ninv;
1455
1456 return k;
1457
1458 }
1459
1460
1461 static void
bluestein_comp1(const cmplx_t * src,cmplx_t * dst,long n,long k,const aligned_vector<cmplx_t> & powers,const aligned_vector<cmplx_t> & Rb,const vector<aligned_vector<cmplx_t>> & tab)1462 bluestein_comp1(const cmplx_t* src, cmplx_t* dst,
1463 long n, long k, const aligned_vector<cmplx_t>& powers,
1464 const aligned_vector<cmplx_t>& Rb,
1465 const vector<aligned_vector<cmplx_t>>& tab)
1466 {
1467 long N = 1L << k;
1468
1469 aligned_vector<cmplx_t> x(N);
1470
1471 for (long i = 0; i < n; i++)
1472 x[i] = MUL(src[i], powers[i]);
1473
1474 long len = FFTRoundUp(2*n-1, k);
1475 long ilen = FFTRoundUp(n, k);
1476
1477 for (long i = n; i < ilen; i++)
1478 x[i] = 0;
1479
1480 new_fft_short(&x[0], len, ilen, k, tab);
1481
1482 // for (long i = 0; i < len; i++) x[i] = MUL(x[i], Rb[i]);
1483 mul_loop(len, &x[0], &Rb[0]);
1484
1485 new_ifft_short1(&x[0], len, k, tab);
1486
1487 double Ninv = 1/double(N);
1488
1489 for (long i = 0; i < n-1; i++)
1490 dst[i] = MUL(x[i] + x[n+i], powers[i]);
1491
1492 dst[n-1] = MUL(x[n-1], powers[n-1]);
1493
1494 }
1495
// Transform strategies, chosen once per length by choose_strategy() and
// dispatched on in PGFFT::PGFFT (precomputation) and PGFFT::apply.
// NULL : n == 1, nothing to precompute.
// POW2 : n is a power of two (pow2_precomp / pow2_comp).
// BLUE : Bluestein with a full 2^k-point convolution
//        (bluestein_precomp / bluestein_comp).
// TBLUE: Bluestein with truncated FFTs of FFTRoundUp lengths
//        (bluestein_precomp1 / bluestein_comp1).
#define PGFFT_STRATEGY_NULL (0)
#define PGFFT_STRATEGY_POW2 (1)
#define PGFFT_STRATEGY_BLUE (2)
#define PGFFT_STRATEGY_TBLUE (3)
1500
choose_strategy(long n)1501 static long choose_strategy(long n)
1502 {
1503 if (n == 1) return PGFFT_STRATEGY_NULL;
1504
1505 if ((n & (n - 1)) == 0) return PGFFT_STRATEGY_POW2;
1506
1507 if (!PGFFT_USE_TRUNCATED_BLUE) return PGFFT_STRATEGY_BLUE;
1508
1509 // choose between Bluestein and truncated Bluestein
1510
1511 // k = least k such that 2^k >= 2*n-1
1512 long k = 0;
1513 while ((1L << k) < 2*n-1) k++;
1514
1515 long rdup = FFTRoundUp(2*n-1, k);
1516 if (rdup == (1L << k)) return PGFFT_STRATEGY_BLUE;
1517
1518 return PGFFT_STRATEGY_TBLUE;
1519 }
1520
PGFFT(long n_)1521 PGFFT::PGFFT(long n_)
1522 {
1523 assert(n_ > 0);
1524 n = n_;
1525
1526 strategy = choose_strategy(n);
1527
1528 //std::cout << strategy << "\n";
1529
1530 switch (strategy) {
1531
1532 case PGFFT_STRATEGY_NULL:
1533 break;
1534
1535 case PGFFT_STRATEGY_POW2:
1536 k = pow2_precomp(n, rev, rev1, tab);
1537 break;
1538
1539 case PGFFT_STRATEGY_BLUE:
1540 k = bluestein_precomp(n, powers, Rb, tab);
1541 break;
1542
1543 case PGFFT_STRATEGY_TBLUE:
1544 k = bluestein_precomp1(n, powers, Rb, tab);
1545 break;
1546
1547 default: ;
1548
1549 }
1550 }
1551
apply(const cmplx_t * src,cmplx_t * dst) const1552 void PGFFT::apply(const cmplx_t* src, cmplx_t* dst) const
1553 {
1554 switch (strategy) {
1555
1556 case PGFFT_STRATEGY_NULL:
1557 break;
1558
1559 case PGFFT_STRATEGY_POW2:
1560 pow2_comp(src, dst, n, k, rev, rev1, tab);
1561 break;
1562
1563 case PGFFT_STRATEGY_BLUE:
1564 bluestein_comp(src, dst, n, k, powers, Rb, tab);
1565 break;
1566
1567 case PGFFT_STRATEGY_TBLUE:
1568 bluestein_comp1(src, dst, n, k, powers, Rb, tab);
1569 break;
1570
1571 default: ;
1572
1573 }
1574 }
1575
1576 }
1577
1578
1579 /****************************************************************************
1580
1581 PGFFT: Pretty Good FFT (v1.8)
1582
1583 Copyright (C) 2019, victor Shoup
1584
1585 All rights reserved.
1586
1587 Redistribution and use in source and binary forms, with or without
1588 modification, are permitted provided that the following conditions are met:
1589
1590 * Redistributions of source code must retain the above copyright notice, this
1591 list of conditions and the following disclaimer.
1592 * Redistributions in binary form must reproduce the above copyright notice,
1593 this list of conditions and the following disclaimer in the documentation
1594 and/or other materials provided with the distribution.
1595
1596 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
1597 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
1598 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
1599 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
1600 FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
1601 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
1602 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
1603 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
1604 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
1605 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
1606
1607 ****************************************************************************
1608
1609 The logic of this code is derived from code originally developed by David Harvey,
1610 even though the code itself has been essentially rewritten from scratch.
1611 Here is David Harvey's original copyright notice.
1612
1613 fft62: a library for number-theoretic transforms
1614
1615 Copyright (C) 2013, David Harvey
1616
1617 All rights reserved.
1618
1619 Redistribution and use in source and binary forms, with or without
1620 modification, are permitted provided that the following conditions are met:
1621
1622 * Redistributions of source code must retain the above copyright notice, this
1623 list of conditions and the following disclaimer.
1624 * Redistributions in binary form must reproduce the above copyright notice,
1625 this list of conditions and the following disclaimer in the documentation
1626 and/or other materials provided with the distribution.
1627
1628 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
1629 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
1630 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
1631 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
1632 FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
1633 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
1634 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
1635 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
1636 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
1637 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
1638
1639 ****************************************************************************/
1640
1641 #pragma GCC diagnostic pop
1642