1 /****************************************************************************
2 
3 PGFFT: Pretty Good FFT (v1.8)
4 
5 Copyright (C) 2019, Victor Shoup
6 
7 See below for more details.
8 
9 ****************************************************************************/
10 
11 
12 
13 #pragma GCC diagnostic push
14 #pragma GCC diagnostic ignored "-Wunused-variable"
15 #pragma GCC diagnostic ignored "-Wunused-function"
16 #ifdef __GNUC__
17 #ifndef __clang__
18 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
19 #endif
20 #endif
21 
22 #define PGFFT_USE_TRUNCATED_BLUE (1)
23 // set to 0 to disable the truncated Bluestein
24 
25 #define PGFFT_USE_EXPLICIT_MUL (1)
26 // Set to 0 to disable explicit complex multiplication.
27 // The built-in complex multiplication routines are
28 // incredibly slow, because the standard requires special handling
29 // of non-finite complex values.
30 // To fix the problem, by default, PGFFT will override these routines
31 // with explicitly defined multiplication functions.
32 // Another way to solve this problem is to compile with the -ffast-math
33 // option (at least, on gcc, that's the right flag).
34 
35 
36 //============================================
37 
38 #ifndef PGFFT_DISABLE_SIMD
39 
40 #ifdef __AVX__
41 #define HAVE_AVX
42 //#warning "HAVE_AVX"
43 #endif
44 
45 #ifdef __AVX2__
46 #define HAVE_AVX2
47 //#warning "HAVE_AVX2"
48 #endif
49 
50 #if defined(HAVE_AVX) || defined(HAVE_AVX2)
51 #define USE_PD4
52 #endif
53 
54 #endif
55 
56 
57 
58 
59 #include <helib/PGFFT.h>
60 #include <cassert>
61 #include <cstdlib>
62 #include <limits>
63 
64 #ifdef USE_PD4
65 #include <immintrin.h>
66 #endif
67 
68 namespace helib {
69 
70 using std::vector;
71 using std::complex;
72 
73 template<class T>
74 using aligned_vector = PGFFT::aligned_vector<T>;
75 
76 typedef complex<double> cmplx_t;
77 typedef long double ldbl;
78 //typedef double ldbl;
79 
80 #ifdef USE_PD4
81 bool PGFFT::simd_enabled() { return true; }
82 #else
83 bool PGFFT::simd_enabled() { return false; }
84 #endif
85 
86 
87 #if (PGFFT_USE_EXPLICIT_MUL)
88 
89 static inline cmplx_t
90 MUL(cmplx_t a, cmplx_t b)
91 {
92    double x = a.real(), y = a.imag(), u = b.real(), v = b.imag();
93    return cmplx_t(x*u-y*v, x*v+y*u);
94 }
95 
96 static inline cmplx_t
97 CMUL(cmplx_t a, cmplx_t b)
98 {
99    double x = a.real(), y = a.imag(), u = b.real(), v = b.imag();
100    return cmplx_t(x*u+y*v, y*u-x*v);
101 }
102 
103 #else
104 
105 #define MUL(a, b) ((a) * (b))
106 #define CMUL(a, b) ((a) * std::conj(b))
107 
108 #endif
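
// Added commentary: a quick arithmetic check of the two conventions above.
// MUL multiplies; CMUL multiplies by the conjugate of its second argument:
//    MUL ((1,2), (3,4)) = (1*3 - 2*4, 1*4 + 2*3) = (-5, 10)
//    CMUL((1,2), (3,4)) = (1*3 + 2*4, 2*3 - 1*4) = (11,  2)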
109 
110 
111 #if (defined(__GNUC__) && (__GNUC__ >= 4))
112 
113 // on relatively modern versions of gcc, we can
114 // declare "restricted" pointers in C++
115 
116 #define RESTRICT __restrict
117 
118 #else
119 
120 #define RESTRICT
121 
122 #endif
123 
124 /**************************************************************
125 
126    Aligned allocation
127 
128 **************************************************************/
129 
130 #ifdef USE_PD4
131 
132 #define PGFFT_ALIGN (64)
133 
134 void *
135 PGFFT::aligned_allocate(std::size_t n, std::size_t nelts)
136 {
137    if (n > std::numeric_limits<std::size_t>::max() / nelts) return 0;
138    std::size_t sz = n * nelts;
139    std::size_t alignment = PGFFT_ALIGN;
140    if (sz > std::numeric_limits<std::size_t>::max() - alignment) return 0;
141 
142    sz += alignment;
143    char* buf = (char*) std::malloc(sz);
144 
145    if (!buf) return 0;
146 
147    int remainder = ((unsigned long long)buf) % alignment;
148    int offset = alignment - remainder;
149    char* ret = buf + offset;
150 
151    ret[-1] = offset;
152 
153    return ret;
154 }
155 
156 
157 void
158 PGFFT::aligned_deallocate(void *p)
159 {
160    if (!p) return;
161    char *cp = (char *) p;
162    int offset = cp[-1];
163    std::free(cp - offset);
164 }
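
// Added commentary: an illustrative trace of the scheme above.  Suppose
// std::malloc returns a pointer p with p % 64 == 16.  Then remainder = 16,
// offset = 48, and the caller receives ret = p + 48, which is 64-byte
// aligned.  The offset is stored in the byte just below ret, so
// aligned_deallocate can recover the original pointer as ret - ret[-1]
// and hand it back to std::free.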
165 
166 #else
167 
168 void *
169 PGFFT::aligned_allocate(std::size_t n, std::size_t sz)
170 {
171    if (n > std::numeric_limits<std::size_t>::max() / sz) return 0;
172    std::size_t size = n * sz;
173    return std::malloc(size);
174 }
175 
176 void
177 PGFFT::aligned_deallocate(void *p)
178 {
179    if (!p) return;
180    std::free(p);
181 }
182 
183 #endif
184 
185 /**************************************************************
186 
187    Packed Double abstraction layer
188 
189 **************************************************************/
190 
191 
192 
193 namespace {
194 
195 
196 
197 //=================== PD4 implementation ===============
198 
199 #if defined(USE_PD4)
200 
201 struct PD4 {
202    __m256d data;
203 
204 
205    PD4() = default;
206    PD4(double x) : data(_mm256_set1_pd(x)) { }
207    PD4(__m256d _data) : data(_data) { }
208    PD4(double d0, double d1, double d2, double d3)
209       : data(_mm256_set_pd(d3, d2, d1, d0)) { }
210 
211    static PD4 load(const double *p) { return _mm256_load_pd(p); }
212 
213    // load from unaligned address
214    static PD4 loadu(const double *p) { return _mm256_loadu_pd(p); }
215 };
216 
217 inline void
218 load(PD4& x, const double *p)
219 { x = PD4::load(p); }
220 
221 // load from unaligned address
222 inline void
223 loadu(PD4& x, const double *p)
224 { x = PD4::loadu(p); }
225 
226 inline void
227 store(double *p, PD4 a)
228 { _mm256_store_pd(p, a.data); }
229 
230 // store to unaligned address
231 inline void
232 storeu(double *p, PD4 a)
233 { _mm256_storeu_pd(p, a.data); }
234 
235 
236 // swap even/odd slots
237 // e.g., 0123 -> 1032
238 inline PD4
239 swap2(PD4 a)
240 { return _mm256_permute_pd(a.data, 0x5); }
241 
242 // 0123 -> 0022
243 inline PD4
244 dup2even(PD4 a)
245 { return _mm256_permute_pd(a.data, 0);   }
246 
247 // 0123 -> 1133
248 inline PD4
249 dup2odd(PD4 a)
250 { return _mm256_permute_pd(a.data, 0xf);   }
251 
252 // blend even/odd slots
253 // 0123, 4567 -> 0527
254 inline PD4
255 blend2(PD4 a, PD4 b)
256 { return _mm256_blend_pd(a.data, b.data, 0xa); }
257 
258 // 0123, 4567 -> 0426
259 inline PD4
260 blend_even(PD4 a, PD4 b)
261 { return _mm256_unpacklo_pd(a.data, b.data); }
262 
263 
264 // 0123, 4567 -> 1537
265 inline PD4
266 blend_odd(PD4 a, PD4 b)
267 { return _mm256_unpackhi_pd(a.data, b.data); }
268 
269 
270 inline void
271 clear(PD4& x)
272 { x.data = _mm256_setzero_pd(); }
273 
274 inline PD4
275 operator+(PD4 a, PD4 b)
276 { return _mm256_add_pd(a.data, b.data); }
277 
278 inline PD4
279 operator-(PD4 a, PD4 b)
280 { return _mm256_sub_pd(a.data, b.data); }
281 
282 inline PD4
283 operator*(PD4 a, PD4 b)
284 { return _mm256_mul_pd(a.data, b.data); }
285 
286 inline PD4
287 operator/(PD4 a, PD4 b)
288 { return _mm256_div_pd(a.data, b.data); }
289 
290 inline PD4&
291 operator+=(PD4& a, PD4 b)
292 { a = a + b; return a; }
293 
294 inline PD4&
295 operator-=(PD4& a, PD4 b)
296 { a = a - b; return a; }
297 
298 inline PD4&
299 operator*=(PD4& a, PD4 b)
300 { a = a * b; return a; }
301 
302 inline PD4&
303 operator/=(PD4& a, PD4 b)
304 { a = a / b; return a; }
305 
306 #ifdef HAVE_AVX2
307 
308 // a*b+c (fused)
309 inline PD4
310 fused_muladd(PD4 a, PD4 b, PD4 c)
311 { return _mm256_fmadd_pd(a.data, b.data, c.data); }
312 // NEEDS: FMA
313 
314 // a*b-c (fused)
315 inline PD4
316 fused_mulsub(PD4 a, PD4 b, PD4 c)
317 { return _mm256_fmsub_pd(a.data, b.data, c.data); }
318 // NEEDS: FMA
319 
320 // -a*b+c (fused)
321 inline PD4
322 fused_negmuladd(PD4 a, PD4 b, PD4 c)
323 { return _mm256_fnmadd_pd(a.data, b.data, c.data); }
324 // NEEDS: FMA
325 
326 // (a0,a1,a2,a3), (b0,b1,b2,b3), (c0,c1,c2,c3) ->
327 // (a0*b0-c0, a1*b1+c1, a2*b2-c2, a3*b3+c3)
328 inline PD4
329 fmaddsub(PD4 a, PD4 b, PD4 c)
330 { return _mm256_fmaddsub_pd(a.data, b.data, c.data); }
331 // NEEDS: FMA
332 // (plain addsub only needs AVX)
333 
334 // (a0,a1,a2,a3), (b0,b1,b2,b3), (c0,c1,c2,c3) ->
335 // (a0*b0+c0, a1*b1-c1, a2*b2+c2, a3*b3-c3)
336 inline PD4
337 fmsubadd(PD4 a, PD4 b, PD4 c)
338 { return _mm256_fmsubadd_pd(a.data, b.data, c.data); }
339 // NEEDS: FMA
340 // (there is no plain subadd)
341 #endif
342 
343 
344 #endif
345 
346 }
347 
348 
349 /***************************************************************
350 
351 
352 TRUNCATED FFT
353 
354 This code is derived from code originally developed
355 by David Harvey.  I include his original documentation,
356 annotated appropriately to highlight differences in
357 the implementation (see NOTEs).
358 
359 The DFT is defined as follows.
360 
361 Let the input sequence be a_0, ..., a_{N-1}.
362 
363 Let w = standard primitive N-th root of 1, i.e. w = g^(2^FFT62_MAX_LGN / N),
364 where g = some fixed element of Z/pZ of order 2^FFT62_MAX_LGN.
365 
366 Let Z = an element of (Z/pZ)^* (twisting parameter).
367 
368 Then the output sequence is
369   b_j = \sum_{0 <= i < N} Z^i a_i w^(ij'), for 0 <= j < N,
370 where j' is the length-lgN bit-reversal of j.
371 
372 Some of the FFT routines can operate on truncated sequences of certain
373 "admissible" sizes. A size parameter n is admissible if 1 <= n <= N, and n is
374 divisible by a certain power of 2. The precise power depends on the recursive
375 array decomposition of the FFT. The smallest admissible n' >= n can be
376 obtained via fft62_next_size().
377 
378 NOTE: the twisting parameter is not implemented.
379 NOTE: the next admissible size function is called FFTRoundUp.
380 
381 
382 Truncated FFT interface is as follows:
383 
384 xn and yn must be admissible sizes for N.
385 
386 Input in xp[] is a_0, a_1, ..., a_{xn-1}. Assumes a_i = 0 for xn <= i < N.
387 
388 Output in yp[] is b_0, ..., b_{yn-1}, i.e. only first yn outputs are computed.
389 
390 Twisting parameter Z is described by z and lgH. If z == 0, then Z = basic
391 2^lgH-th root of 1, and must have lgH >= lgN + 1. If z != 0, then Z = z
392 (and lgH is ignored).
393 
394 The buffers {xp,xn} and {yp,yn} may overlap, but only if xp == yp.
395 
396 Inputs are in [0, 2p), outputs are in [0, 2p).
397 
398 threads = number of OpenMP threads to use.
399 
400 
401 
402 Inverse truncated FFT interface is as follows.
403 
404 xn and yn must be admissible sizes for N, with yn <= xn.
405 
406 Input in xp[] is b_0, b_1, ..., b_{yn-1}, N*a_{yn}, ..., N*a_{xn-1}.
407 
408 Assumes a_i = 0 for xn <= i < N.
409 
410 Output in yp[] is N*a_0, ..., N*a_{yn-1}.
411 
412 Twisting parameter Z is described by z and lgH. If z == 0, then Z = basic
413 2^lgH-th root of 1, and must have lgH >= lgN + 1. If z != 0, then Z = z^(-1)
414 (and lgH is ignored).
415 
416 The buffers {xp,xn} and {yp,yn} may overlap, but only if xp == yp.
417 
418 Inputs are in [0, 4p), outputs are in [0, 4p).
419 
420 threads = number of OpenMP threads to use.
421 
422 (note: no function actually implements this interface in full generality!
423 This is because it is tricky (and not that useful) to implement the twisting
424 parameter when xn != yn.)
425 
426 NOTE: threads and twisting parameter are not used here.
427 NOTE: the code has been re-written and simplified so that
428   everything is done in place, so xp == yp.
429 
430 
431 ***************************************************************/
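
// NOTE (added commentary): in this file, the forward interface above is
// implemented by new_fft_short() (with new_fft() as the non-truncated
// wrapper), and the inverse interface by new_ifft_short1()/new_ifft_short2()
// (with new_ifft() as the wrapper).  Admissible sizes are produced by
// FFTRoundUp() below.  Per the conventions above, a forward transform
// followed by the inverse returns N times the original sequence.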
432 
433 
434 
435 #define PGFFT_FFT_RDUP (4)
436 // Currently, this should be at least 2 to support
437 // loop unrolling in the FFT implementation
438 
439 
440 static inline long
441 FFTRoundUp(long xn, long k)
442 {
443    long n = 1L << k;
444    if (xn <= 0) return n;
445    // default truncation value of 0 gets converted to n
446 
447    xn = ((xn+((1L << PGFFT_FFT_RDUP)-1)) >> PGFFT_FFT_RDUP) << PGFFT_FFT_RDUP;
448 
449    if (k >= 10) {
450       if (xn > n - (n >> 4)) xn = n;
451    }
452    else {
453       if (xn > n - (n >> 3)) xn = n;
454    }
455    // truncation just a bit below n does not really help
456    // at all, and can sometimes slow things down slightly, so round up
457    // to n.  This also takes care of cases where xn > n.
458    // Actually, for smallish n, we should round up sooner,
459    // at n-n/8, and for larger n, we should round up later,
460    // at n-n/16.  At least, experimentally, this is what I see.
461 
462    return xn;
463 }
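
// Worked example (illustrative): with k = 12 (n = 4096) and xn = 1000,
// xn is first rounded up to a multiple of 2^PGFFT_FFT_RDUP = 16, giving 1008.
// Since k >= 10, the cutoff is n - n/16 = 3840, and 1008 <= 3840, so
// FFTRoundUp(1000, 12) == 1008.  With xn = 4000 > 3840, the result would be
// rounded all the way up to 4096.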
464 
465 
466 
467 
468 #define fwd_butterfly(xx0, xx1, w)  \
469 do \
470 { \
471    cmplx_t x0_ = xx0; \
472    cmplx_t x1_ = xx1; \
473    cmplx_t t_  = x0_ -  x1_; \
474    xx0 = x0_ + x1_; \
475    xx1 = MUL(t_, w); \
476 }  \
477 while (0)
478 
479 
480 
481 #define fwd_butterfly0(xx0, xx1) \
482 do   \
483 {  \
484    cmplx_t x0_ = xx0;  \
485    cmplx_t x1_ = xx1;  \
486    xx0 = x0_ + x1_; \
487    xx1 = x0_ - x1_; \
488 }  \
489 while (0)
490 
491 
492 #define inv_butterfly0(xx0, xx1)  \
493 do   \
494 {  \
495    cmplx_t x0_ = xx0;  \
496    cmplx_t x1_ = xx1;  \
497    xx0 = x0_ + x1_;  \
498    xx1 = x0_ - x1_;  \
499 } while (0)
500 
501 
502 #define inv_butterfly(xx0, xx1, w)  \
503 do  \
504 {  \
505    cmplx_t x0_ = xx0;  \
506    cmplx_t x1_ = xx1;  \
507    cmplx_t t_ = CMUL(x1_, w);  \
508    xx0 = x0_ + t_;  \
509    xx1 = x0_ - t_;  \
510 } while (0)
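
// Added commentary: with twiddle factor w, the forward
// (decimation-in-frequency) butterfly maps
//    (x0, x1)  ->  (x0 + x1, (x0 - x1)*w),
// and the inverse butterfly maps
//    (x0, x1)  ->  (x0 + x1*conj(w), x0 - x1*conj(w)).
// Applying the inverse butterfly to the forward result yields (2*x0, 2*x1);
// the accumulated factor of N is exactly the N*a_i convention in the
// inverse-FFT interface documented above (and is cancelled by the 1/N
// scaling of Rb in the Bluestein routines).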
511 
512 
513 
514 #ifdef USE_PD4
515 
516 #ifdef HAVE_AVX2
517 
518 static inline PD4
519 complex_mul(PD4 ab, PD4 cd)
520 {
521    PD4 cc = dup2even(cd);
522    PD4 dd = dup2odd(cd);
523    PD4 ba = swap2(ab);
524    return fmaddsub(ab, cc, ba*dd);
525 }
526 
527 static inline PD4
528 complex_conj_mul(PD4 ab, PD4 cd)
529 // (ac+bd,bc-ad)
530 {
531    PD4 cc = dup2even(cd);
532    PD4 dd = dup2odd(cd);
533    PD4 ba = swap2(ab);
534    return fmsubadd(ab, cc, ba*dd);
535 }
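
// Added commentary: each PD4 holds two complex numbers as interleaved
// (re, im) pairs.  For one such pair, with ab = (a, b) representing a+b*i
// and cd = (c, d) representing c+d*i:
//    cc = dup2even(cd) = (c, c)
//    dd = dup2odd(cd)  = (d, d)
//    ba = swap2(ab)    = (b, a)
//    fmaddsub(ab, cc, ba*dd) = (a*c - b*d, b*c + a*d),
// which is (a+b*i)*(c+d*i).  complex_conj_mul uses fmsubadd instead,
// giving (a*c + b*d, b*c - a*d) = (a+b*i)*(c-d*i).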
536 
537 
538 #define MUL2(x_0, x_1, a_0, a_1, b_0, b_1) \
539 do { \
540    x_0 = complex_mul(a_0, b_0); \
541    x_1 = complex_mul(a_1, b_1); \
542 } while (0)
543 
544 #define CMUL2(x_0, x_1, a_0, a_1, b_0, b_1) \
545 do { \
546    x_0 = complex_conj_mul(a_0, b_0); \
547    x_1 = complex_conj_mul(a_1, b_1); \
548 } while (0)
549 
550 #else
551 // This code sequence works without FMA
552 #define MUL2(x_0, x_1, a_0, a_1, b_0, b_1) \
553 do { \
554     PD4 a_re_ = blend_even(a_0, a_1); \
555     PD4 a_im_ = blend_odd(a_0, a_1); \
556  \
557     PD4 b_re_ = blend_even(b_0, b_1); \
558     PD4 b_im_ = blend_odd(b_0, b_1); \
559  \
560     PD4 x_re_ = a_re_*b_re_ - a_im_*b_im_; \
561     PD4 x_im_ = a_re_*b_im_ + a_im_*b_re_; \
562  \
563     x_0 = blend_even(x_re_, x_im_); \
564     x_1 = blend_odd(x_re_, x_im_); \
565 } while (0)
566 
567 #define CMUL2(x_0, x_1, a_0, a_1, b_0, b_1) \
568 do { \
569     PD4 a_re_ = blend_even(a_0, a_1); \
570     PD4 a_im_ = blend_odd(a_0, a_1); \
571  \
572     PD4 b_re_ = blend_even(b_0, b_1); \
573     PD4 b_im_ = blend_odd(b_0, b_1); \
574  \
575     PD4 x_re_ = a_re_*b_re_ + a_im_*b_im_; \
576     PD4 x_im_ = a_im_*b_re_ - a_re_*b_im_; \
577  \
578     x_0 = blend_even(x_re_, x_im_); \
579     x_1 = blend_odd(x_re_, x_im_); \
580 } while (0)
581 
582 #endif
583 
584 
585 
586 static inline void
587 fwd_butterfly_loop_simd(
588    long size,
589    double * RESTRICT xp0,
590    double * RESTRICT xp1,
591    const double * RESTRICT wtab)
592 {
593   for (long j = 0; j < size; j += 4) {
594     PD4 x0_0 = PD4::load(xp0+2*(j+0));
595     PD4 x0_1 = PD4::load(xp0+2*(j+2));
596     PD4 x1_0 = PD4::load(xp1+2*(j+0));
597     PD4 x1_1 = PD4::load(xp1+2*(j+2));
598     PD4 w_0  = PD4::load(wtab+2*(j+0));
599     PD4 w_1  = PD4::load(wtab+2*(j+2));
600 
601     PD4 xx0_0 = x0_0 + x1_0;
602     PD4 xx0_1 = x0_1 + x1_1;
603 
604     PD4 diff_0 = x0_0 - x1_0;
605     PD4 diff_1 = x0_1 - x1_1;
606 
607     PD4 xx1_0, xx1_1;
608     MUL2(xx1_0, xx1_1, diff_0, diff_1, w_0, w_1);
609 
610     store(xp0+2*(j+0), xx0_0);
611     store(xp0+2*(j+2), xx0_1);
612     store(xp1+2*(j+0), xx1_0);
613     store(xp1+2*(j+2), xx1_1);
614   }
615 }
616 
617 static inline void
618 fwd_butterfly_loop(
619    long size,
620    cmplx_t * RESTRICT xp0,
621    cmplx_t * RESTRICT xp1,
622    const cmplx_t * RESTRICT wtab)
623 {
624    // NOTE: C++11 guarantees that these reinterpret_cast's work as expected
625    fwd_butterfly_loop_simd(
626       size,
627       reinterpret_cast<double*>(xp0),
628       reinterpret_cast<double*>(xp1),
629       reinterpret_cast<const double*>(wtab));
630 }
631 
632 static inline void
633 inv_butterfly_loop_simd(
634    long size,
635    double * RESTRICT xp0,
636    double * RESTRICT xp1,
637    const double * RESTRICT wtab)
638 {
639   for (long j = 0; j < size; j += 4) {
640     PD4 x0_0 = PD4::load(xp0+2*(j+0));
641     PD4 x0_1 = PD4::load(xp0+2*(j+2));
642     PD4 x1_0 = PD4::load(xp1+2*(j+0));
643     PD4 x1_1 = PD4::load(xp1+2*(j+2));
644     PD4 w_0  = PD4::load(wtab+2*(j+0));
645     PD4 w_1  = PD4::load(wtab+2*(j+2));
646 
647     PD4 t_0, t_1;
648     CMUL2(t_0, t_1, x1_0, x1_1, w_0, w_1);
649 
650     PD4 xx0_0 = x0_0 + t_0;
651     PD4 xx0_1 = x0_1 + t_1;
652 
653     PD4 xx1_0 = x0_0 - t_0;
654     PD4 xx1_1 = x0_1 - t_1;
655 
656     store(xp0+2*(j+0), xx0_0);
657     store(xp0+2*(j+2), xx0_1);
658     store(xp1+2*(j+0), xx1_0);
659     store(xp1+2*(j+2), xx1_1);
660   }
661 }
662 
663 static inline void
664 inv_butterfly_loop(
665    long size,
666    cmplx_t * RESTRICT xp0,
667    cmplx_t * RESTRICT xp1,
668    const cmplx_t * RESTRICT wtab)
669 {
670    // NOTE: C++11 guarantees that these reinterpret_cast's work as expected
671    inv_butterfly_loop_simd(
672       size,
673       reinterpret_cast<double*>(xp0),
674       reinterpret_cast<double*>(xp1),
675       reinterpret_cast<const double*>(wtab));
676 }
677 
678 #else
679 
680 static inline void
681 fwd_butterfly_loop(
682    long size,
683    cmplx_t * RESTRICT xp0,
684    cmplx_t * RESTRICT xp1,
685    const cmplx_t * RESTRICT wtab)
686 {
687    fwd_butterfly0(xp0[0+0], xp1[0+0]);
688    fwd_butterfly(xp0[0+1], xp1[0+1], wtab[0+1]);
689    fwd_butterfly(xp0[0+2], xp1[0+2], wtab[0+2]);
690    fwd_butterfly(xp0[0+3], xp1[0+3], wtab[0+3]);
691    for (long j = 4; j < size; j += 4) {
692      fwd_butterfly(xp0[j+0], xp1[j+0], wtab[j+0]);
693      fwd_butterfly(xp0[j+1], xp1[j+1], wtab[j+1]);
694      fwd_butterfly(xp0[j+2], xp1[j+2], wtab[j+2]);
695      fwd_butterfly(xp0[j+3], xp1[j+3], wtab[j+3]);
696    }
697 }
698 
699 static inline void
700 inv_butterfly_loop(
701    long size,
702    cmplx_t * RESTRICT xp0,
703    cmplx_t * RESTRICT xp1,
704    const cmplx_t * RESTRICT wtab)
705 {
706    inv_butterfly0(xp0[0+0], xp1[0+0]);
707    inv_butterfly(xp0[0+1], xp1[0+1], wtab[0+1]);
708    inv_butterfly(xp0[0+2], xp1[0+2], wtab[0+2]);
709    inv_butterfly(xp0[0+3], xp1[0+3], wtab[0+3]);
710    for (long j = 4; j < size; j += 4) {
711      inv_butterfly(xp0[j+0], xp1[j+0], wtab[j+0]);
712      inv_butterfly(xp0[j+1], xp1[j+1], wtab[j+1]);
713      inv_butterfly(xp0[j+2], xp1[j+2], wtab[j+2]);
714      inv_butterfly(xp0[j+3], xp1[j+3], wtab[j+3]);
715    }
716 }
717 
718 #endif
719 
720 
721 #if (defined(USE_PD4))
722 
723 
724 static inline void
725 mul_loop_simd(
726    long size,
727    double * RESTRICT xp,
728    const double * yp)
729 {
730   long j;
731   for (j = 0; j < size; j += 4) {
732     PD4 x_0 = PD4::load(xp+2*(j+0));
733     PD4 x_1 = PD4::load(xp+2*(j+2));
734     PD4 y_0 = PD4::load(yp+2*(j+0));
735     PD4 y_1 = PD4::load(yp+2*(j+2));
736 
737     PD4 z_0, z_1;
738     MUL2(z_0, z_1, x_0, x_1, y_0, y_1);
739 
740     store(xp+2*(j+0), z_0);
741     store(xp+2*(j+2), z_1);
742   }
743 }
744 
745 
746 static inline void
747 mul_loop(
748    long size,
749    cmplx_t * xp,
750    const cmplx_t * yp)
751 {
752    // NOTE: C++11 guarantees that these reinterpret_cast's work as expected
753    mul_loop_simd(
754       size,
755       reinterpret_cast<double*>(xp),
756       reinterpret_cast<const double*>(yp));
757 }
758 
759 
760 #else
761 
762 
763 static inline void
764 mul_loop(
765    long size,
766    cmplx_t * xp,
767    const cmplx_t * yp)
768 {
769   for (long j = 0; j < size; j++)
770     xp[j] = MUL(xp[j], yp[j]);
771 }
772 
773 #endif
774 
775 
776 // requires size divisible by 8
777 static void
778 new_fft_layer(cmplx_t* xp, long blocks, long size,
779               const cmplx_t* RESTRICT wtab)
780 {
781   size /= 2;
782 
783   do
784     {
785       cmplx_t* RESTRICT xp0 = xp;
786       cmplx_t* RESTRICT xp1 = xp + size;
787 
788       fwd_butterfly_loop(size, xp0, xp1, wtab);
789 
790       xp += 2 * size;
791     }
792   while (--blocks != 0);
793 }
794 
795 
796 
797 static void
798 new_fft_last_two_layers(cmplx_t* xp, long blocks, const cmplx_t* wtab)
799 {
800   // 4th root of unity
801   cmplx_t w = wtab[1];
802 
803   do
804     {
805       cmplx_t u0 = xp[0];
806       cmplx_t u1 = xp[1];
807       cmplx_t u2 = xp[2];
808       cmplx_t u3 = xp[3];
809 
810       cmplx_t v0 = u0 + u2;
811       cmplx_t v2 = u0 - u2;
812       cmplx_t v1 = u1 + u3;
813       cmplx_t t  = u1 - u3;
814 
815       //cmplx_t v3 = MUL(t, w);
816       // DIRT: relies on w == (0,-1)
817       cmplx_t v3(t.imag(), -t.real());
818 
819 
820       xp[0] = v0 + v1;
821       xp[1] = v0 - v1;
822       xp[2] = v2 + v3;
823       xp[3] = v2 - v3;
824 
825       xp += 4;
826     }
827   while (--blocks != 0);
828 }
829 
830 
831 static void
832 new_fft_base(cmplx_t* xp, long lgN, const vector<aligned_vector<cmplx_t>>& tab)
833 {
834   if (lgN == 0) return;
835 
836   if (lgN == 1)
837     {
838       cmplx_t x0 = xp[0];
839       cmplx_t x1 = xp[1];
840       xp[0] = x0 + x1;
841       xp[1] = x0 - x1;
842       return;
843     }
844 
845 
846   long N = 1L << lgN;
847 
848   for (long j = lgN, size = N, blocks = 1;
849        j > 2; j--, blocks <<= 1, size >>= 1)
850     new_fft_layer(xp, blocks, size, &tab[j][0]);
851 
852   new_fft_last_two_layers(xp, N/4, &tab[2][0]);
853 }
854 
855 
856 // Implements the truncated FFT interface, described above.
857 // All computations done in place, and xp should point to
858 // an array of size N, all of which may be overwritten
859 // during the computation.
860 
861 #define PGFFT_NEW_FFT_THRESH (10)
862 
863 static
864 void new_fft_short(cmplx_t* xp, long yn, long xn, long lgN,
865                    const vector<aligned_vector<cmplx_t>>& tab)
866 {
867   long N = 1L << lgN;
868 
869   if (yn == N)
870     {
871       if (xn == N && lgN <= PGFFT_NEW_FFT_THRESH)
872 	{
873 	  // no truncation
874 	  new_fft_base(xp, lgN, tab);
875 	  return;
876 	}
877     }
878 
879   // divide-and-conquer algorithm
880 
881   long half = N >> 1;
882 
883   if (yn <= half)
884     {
885       if (xn <= half)
886 	{
887 	  new_fft_short(xp, yn, xn, lgN - 1, tab);
888 	}
889       else
890 	{
891 	  xn -= half;
892 
893 	  // (X, Y) -> X + Y
894 	  for (long j = 0; j < xn; j++)
895 	    xp[j] = xp[j] + xp[j + half];
896 
897 	  new_fft_short(xp, yn, half, lgN - 1, tab);
898 	}
899     }
900   else
901     {
902       yn -= half;
903 
904       cmplx_t* RESTRICT xp0 = xp;
905       cmplx_t* RESTRICT xp1 = xp + half;
906       const cmplx_t* RESTRICT wtab = &tab[lgN][0];
907 
908       if (xn <= half)
909 	{
910 	  // X -> (X, w*X)
911 	  for (long j = 0; j < xn; j++)
912 	    xp1[j] = MUL(xp0[j], wtab[j]);
913 
914 	  new_fft_short(xp0, half, xn, lgN - 1, tab);
915 	  new_fft_short(xp1, yn, xn, lgN - 1, tab);
916 	}
917       else
918 	{
919 	  xn -= half;
920 
921 	  // (X, Y) -> (X + Y, w*(X - Y))
922           // DIRT: assumes xn is a multiple of 4
923           fwd_butterfly_loop(xn, xp0, xp1, wtab);
924 
925 	  // X -> (X, w*X)
926 	  for (long j = xn; j < half; j++)
927 	    xp1[j] = MUL(xp0[j], wtab[j]);
928 
929 	  new_fft_short(xp0, half, half, lgN - 1, tab);
930 	  new_fft_short(xp1, yn, half, lgN - 1, tab);
931 	}
932     }
933 }
934 
935 static void new_fft(cmplx_t* xp, long lgN, const vector<aligned_vector<cmplx_t>>& tab)
936 {
937    long N = 1L << lgN;
938    new_fft_short(xp, N, N, lgN, tab);
939 }
940 
941 
942 
943 
944 
945 // requires size divisible by 8
946 static void
947 new_ifft_layer(cmplx_t* xp, long blocks, long size,
948                const cmplx_t* RESTRICT wtab)
949 {
950 
951   size /= 2;
952 
953   do
954     {
955 
956       cmplx_t* RESTRICT xp0 = xp;
957       cmplx_t* RESTRICT xp1 = xp + size;
958 
959 
960       inv_butterfly_loop(size, xp0, xp1, wtab);
961 
962       xp += 2 * size;
963     }
964   while (--blocks != 0);
965 }
966 
967 static void
968 new_ifft_first_two_layers(cmplx_t* xp, long blocks, const cmplx_t* wtab)
969 {
970   // 4th root of unity
971   cmplx_t w = wtab[1];
972 
973   do
974     {
975       cmplx_t u0 = xp[0];
976       cmplx_t u1 = xp[1];
977       cmplx_t u2 = xp[2];
978       cmplx_t u3 = xp[3];
979 
980       cmplx_t v0 = u0 + u1;
981       cmplx_t v1 = u0 - u1;
982       cmplx_t v2 = u2 + u3;
983       cmplx_t t  = u2 - u3;
984 
985       //cmplx_t v3 = CMUL(t, w);
986       // DIRT: relies on w == (0,1)
987       cmplx_t v3(-t.imag(), t.real());
988 
989       xp[0] = v0 + v2;
990       xp[2] = v0 - v2;
991       xp[1] = v1 + v3;
992       xp[3] = v1 - v3;
993 
994       xp += 4;
995     }
996   while (--blocks != 0);
997 }
998 
999 
1000 static void
1001 new_ifft_base(cmplx_t* xp, long lgN, const vector<aligned_vector<cmplx_t>>& tab)
1002 {
1003   if (lgN == 0) return;
1004 
1005 
1006   if (lgN == 1)
1007     {
1008       cmplx_t x0 = xp[0];
1009       cmplx_t x1 = xp[1];
1010       xp[0] = x0 + x1;
1011       xp[1] = x0 - x1;
1012       return;
1013     }
1014 
1015 
1016   long blocks = 1L << (lgN - 2);
1017   new_ifft_first_two_layers(xp, blocks, &tab[2][0]);
1018   blocks >>= 1;
1019 
1020   long size = 8;
1021   for (long j = 3; j <= lgN; j++, blocks >>= 1, size <<= 1)
1022     new_ifft_layer(xp, blocks, size, &tab[j][0]);
1023 }
1024 
1025 static
1026 void new_ifft_short2(cmplx_t* yp, long yn, long lgN, const vector<aligned_vector<cmplx_t>>& tab);
1027 
1028 
1029 
1030 static
1031 void new_ifft_short1(cmplx_t* xp, long yn, long lgN, const vector<aligned_vector<cmplx_t>>& tab)
1032 
1033 // Implements truncated inverse FFT interface, but with xn==yn.
1034 // All computations are done in place.
1035 
1036 {
1037   long N = 1L << lgN;
1038 
1039   if (yn == N && lgN <= PGFFT_NEW_FFT_THRESH)
1040     {
1041       // no truncation
1042       new_ifft_base(xp, lgN, tab);
1043       return;
1044     }
1045 
1046   // divide-and-conquer algorithm
1047 
1048   long half = N >> 1;
1049 
1050   if (yn <= half)
1051     {
1052       // X -> 2X
1053       for (long j = 0; j < yn; j++)
1054       	xp[j] = 2.0 * xp[j];
1055 
1056       new_ifft_short1(xp, yn, lgN - 1, tab);
1057     }
1058   else
1059     {
1060       cmplx_t* RESTRICT xp0 = xp;
1061       cmplx_t* RESTRICT xp1 = xp + half;
1062       const cmplx_t* RESTRICT wtab = &tab[lgN][0];
1063 
1064       new_ifft_short1(xp0, half, lgN - 1, tab);
1065 
1066       yn -= half;
1067 
1068       // X -> (2X, w*X)
1069       for (long j = yn; j < half; j++)
1070 	{
1071 	  cmplx_t x0 = xp0[j];
1072 	  xp0[j] = 2.0 * x0;
1073 	  xp1[j] = MUL(x0, wtab[j]);
1074 	}
1075 
1076       new_ifft_short2(xp1, yn, lgN - 1, tab);
1077 
1078       // (X, Y) -> (X + Y/w, X - Y/w)
1079       {
1080         inv_butterfly_loop(yn, xp0, xp1, wtab);
1081       }
1082     }
1083 }
1084 
1085 
1086 
1087 static
1088 void new_ifft_short2(cmplx_t* xp, long yn, long lgN, const vector<aligned_vector<cmplx_t>>& tab)
1089 
1090 // Implements truncated inverse FFT interface, but with xn==N.
1091 // All computations are done in place.
1092 
1093 {
1094   long N = 1L << lgN;
1095 
1096   if (yn == N && lgN <= PGFFT_NEW_FFT_THRESH)
1097     {
1098       // no truncation
1099       new_ifft_base(xp, lgN, tab);
1100       return;
1101     }
1102 
1103   // divide-and-conquer algorithm
1104 
1105   long half = N >> 1;
1106 
1107   if (yn <= half)
1108     {
1109       // X -> 2X
1110       for (long j = 0; j < yn; j++)
1111      	xp[j] = 2.0 * xp[j];
1112       // (X, Y) -> X + Y
1113       for (long j = yn; j < half; j++)
1114 	xp[j] = xp[j] + xp[j + half];
1115 
1116       new_ifft_short2(xp, yn, lgN - 1, tab);
1117 
1118       // (X, Y) -> X - Y
1119       for (long j = 0; j < yn; j++)
1120 	xp[j] = xp[j] - xp[j + half];
1121     }
1122   else
1123     {
1124       cmplx_t* RESTRICT xp0 = xp;
1125       cmplx_t* RESTRICT xp1 = xp + half;
1126       const cmplx_t* RESTRICT wtab = &tab[lgN][0];
1127 
1128       new_ifft_short1(xp0, half, lgN - 1, tab);
1129 
1130       yn -= half;
1131 
1132 
1133       // (X, Y) -> (2X - Y, w*(X - Y))
1134       for (long j = yn; j < half; j++)
1135 	{
1136 	  cmplx_t x0 = xp0[j];
1137 	  cmplx_t x1 = xp1[j];
1138 	  cmplx_t u = x0 - x1;
1139 	  xp0[j] = x0 + u;
1140 	  xp1[j] = MUL(u, wtab[j]);
1141 	}
1142 
1143       new_ifft_short2(xp1, yn, lgN - 1, tab);
1144 
1145       // (X, Y) -> (X + Y/w, X - Y/w)
1146       {
1147         inv_butterfly_loop(yn, xp0, xp1, wtab);
1148       }
1149     }
1150 }
1151 
1152 
1153 
1154 static void
1155 new_ifft(cmplx_t* xp, long lgN, const vector<aligned_vector<cmplx_t>>& tab)
1156 {
1157    long N = 1L << lgN;
1158    new_ifft_short1(xp, N, lgN, tab);
1159 }
1160 
1161 
1162 static void
1163 compute_table(vector<aligned_vector<cmplx_t>>& tab, long k)
1164 {
1165   if (k < 2) return;
1166 
1167   const ldbl pi = std::atan(ldbl(1)) * 4.0;
1168 
1169   tab.resize(k+1);
1170   for (long s = 2; s <= k; s++) {
1171     long m = 1L << s;
1172     tab[s].resize(m/2);
1173     for (long j = 0; j < m/2; j++) {
1174       ldbl angle = -((2 * pi) * (ldbl(j)/ldbl(m)));
1175       tab[s][j] = cmplx_t(std::cos(angle), std::sin(angle));
1176     }
1177   }
1178 }
1179 
1180 static long
1181 RevInc(long a, long k)
1182 {
1183    long j, m;
1184 
1185    j = k;
1186    m = 1L << (k-1);
1187 
1188    while (j && (m & a)) {
1189       a ^= m;
1190       m >>= 1;
1191       j--;
1192    }
1193    if (j) a ^= m;
1194    return a;
1195 }
1196 
1197 static void
1198 BRC_init(long k, vector<long>& rev)
1199 {
1200    long n = (1L << k);
1201    rev.resize(n);
1202    long i, j;
1203    for (i = 0, j = 0; i < n; i++, j = RevInc(j, k))
1204       rev[i] = j;
1205 }
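
// Example (illustrative): RevInc steps through 0..2^k-1 in bit-reversed
// order, so for k = 3 BRC_init produces
//    rev = { 0, 4, 2, 6, 1, 5, 3, 7 },
// i.e. rev[i] is the k-bit reversal of i.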
1206 
1207 
1208 #define PGFFT_BRC_THRESH (11)
1209 #define PGFFT_BRC_Q (5)
1210 // Must have PGFFT_BRC_THRESH >= 2*PGFFT_BRC_Q
1211 // Should also have (1L << (2*PGFFT_BRC_Q)) small enough
1212 // so that we can fit that many cmplx_t's into the cache
1213 
1214 
1215 static
1216 void BasicBitReverseCopy(cmplx_t *B,
1217                          const cmplx_t *A, long k, const vector<long>& rev)
1218 {
1219    long n = 1L << k;
1220    long i, j;
1221 
1222    for (i = 0; i < n; i++)
1223       B[rev[i]] = A[i];
1224 }
1225 
1226 static void
1227 COBRA(cmplx_t * RESTRICT B, const cmplx_t * RESTRICT A, long k,
1228       const vector<long>& rev, const vector<long> rev1)
1229 {
1230    constexpr long q = PGFFT_BRC_Q;
1231    long k1 = k - 2*q;
1232 
1233    aligned_vector<cmplx_t> BRC_temp(1L << (2*q));
1234 
1235    cmplx_t * RESTRICT T = &BRC_temp[0];
1236    const long * RESTRICT rev_k1 = &rev[0];
1237    const long * RESTRICT rev_q = &rev1[0];
1238 
1239 
1240    for (long b = 0; b < (1L << k1); b++) {
1241       long b1 = rev_k1[b];
1242       for (long a = 0; a < (1L << q); a++) {
1243          long a1 = rev_q[a];
1244          cmplx_t *T_p = &T[a1 << q];
1245          const cmplx_t *A_p = &A[(a << (k1+q)) + (b << q)];
1246 #ifdef USE_PD4
1247          for (long c = 0; c < (1 << q); c += 4) {
1248             PD4 x0 = PD4::load(reinterpret_cast<const double*>(&A_p[c+0]));
1249             PD4 x1 = PD4::load(reinterpret_cast<const double*>(&A_p[c+2]));
1250             store(reinterpret_cast<double*>(&T_p[c+0]), x0);
1251             store(reinterpret_cast<double*>(&T_p[c+2]), x1);
1252          }
1253 #else
1254          for (long c = 0; c < (1L << q); c++) T_p[c] = A_p[c];
1255 #endif
1256       }
1257 
1258       for (long c = 0; c < (1L << q); c++) {
1259          long c1 = rev_q[c];
1260          cmplx_t *B_p = &B[(c1 << (k1+q)) + (b1 << q)];
1261          cmplx_t *T_p = &T[c];
1262          for (long a1 = 0; a1 < (1L << q); a1++)
1263             B_p[a1] = T_p[a1 << q];
1264       }
1265    }
1266 }
1267 
1268 
1269 static long
1270 pow2_precomp(long n, vector<long>& rev, vector<long>& rev1, vector<aligned_vector<cmplx_t>>& tab)
1271 {
1272    // k = least k such that 2^k >= n
1273    long k = 0;
1274    while ((1L << k) < n) k++;
1275 
1276    compute_table(tab, k);
1277 
1278 
1279    if (k <= PGFFT_BRC_THRESH) {
1280       BRC_init(k, rev);
1281    }
1282    else {
1283       long q = PGFFT_BRC_Q;
1284       long k1 = k - 2*q;
1285       BRC_init(k1, rev);
1286       BRC_init(q, rev1);
1287    }
1288 
1289 
1290    return k;
1291 }
1292 
1293 static void
1294 pow2_comp(const cmplx_t* src, cmplx_t* dst,
1295                   long n, long k, const vector<long>& rev, const vector<long>& rev1,
1296                   const vector<aligned_vector<cmplx_t>>& tab)
1297 {
1298    aligned_vector<cmplx_t> x;
1299    x.assign(src, src+n);
1300 
1301    new_fft(&x[0], k, tab);
1302 #if 0
1303    for (long i = 0; i < n; i++) dst[i] = x[i];
1304 #else
1305    if (k <= PGFFT_BRC_THRESH)
1306       BasicBitReverseCopy(&dst[0], &x[0], k, rev);
1307    else
1308       COBRA(&dst[0], &x[0], k, rev, rev1);
1309 #endif
1310 }
1311 
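// Added commentary on Bluestein's trick: using i*j = (i^2 + j^2 - (i-j)^2)/2,
// the DFT
//    b_j = sum_i a_i * W^(i*j),   where W = exp(-2*pi*I/n),
// can be rewritten as
//    b_j = W^(j^2/2) * sum_i (a_i * W^(i^2/2)) * W^(-(i-j)^2/2),
// i.e. a pointwise pre-multiplication by powers[], a convolution with the
// sequence Rb[] of W^(-m^2/2) values, and a pointwise post-multiplication.
// The convolution is carried out with a power-of-two FFT of length
// 2^k >= 2*n-1, which is what bluestein_precomp() below sets up.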
1312 static long
1313 bluestein_precomp(long n, aligned_vector<cmplx_t>& powers,
1314                   aligned_vector<cmplx_t>& Rb,
1315                   vector<aligned_vector<cmplx_t>>& tab)
1316 {
1317    // k = least k such that 2^k >= 2*n-1
1318    long k = 0;
1319    while ((1L << k) < 2*n-1) k++;
1320 
1321    compute_table(tab, k);
1322 
1323    const ldbl pi = std::atan(ldbl(1)) * 4.0;
1324 
1325    powers.resize(n);
1326    powers[0] = 1;
1327    long i_sqr = 0;
1328    for (long i = 1; i < n; i++) {
1329       // i^2 = (i-1)^2 + 2*i-1
1330       i_sqr = (i_sqr + 2*i - 1) % (2*n);
1331       ldbl angle = -((2 * pi) * (ldbl(i_sqr)/ldbl(2*n)));
1332       powers[i] = cmplx_t(std::cos(angle), std::sin(angle));
1333    }
1334 
1335    long N = 1L << k;
1336    Rb.resize(N);
1337    for (long i = 0; i < N; i++) Rb[i] = 0;
1338 
1339    Rb[n-1] = 1;
1340    i_sqr = 0;
1341    for (long i = 1; i < n; i++) {
1342       // i^2 = (i-1)^2 + 2*i-1
1343       i_sqr = (i_sqr + 2*i - 1) % (2*n);
1344       ldbl angle = (2 * pi) * (ldbl(i_sqr)/ldbl(2*n));
1345       Rb[n-1+i] = Rb[n-1-i] = cmplx_t(std::cos(angle), std::sin(angle));
1346    }
1347 
1348    new_fft(&Rb[0], k, tab);
1349 
1350    double Ninv = 1/double(N);
1351    for (long i = 0; i < N; i++)
1352       Rb[i] *= Ninv;
1353 
1354    return k;
1355 
1356 }
1357 
1358 
1359 static void
1360 bluestein_comp(const cmplx_t* src, cmplx_t* dst,
1361                   long n, long k, const aligned_vector<cmplx_t>& powers,
1362                   const aligned_vector<cmplx_t>& Rb,
1363                   const vector<aligned_vector<cmplx_t>>& tab)
1364 {
1365    long N = 1L << k;
1366 
1367    aligned_vector<cmplx_t> x(N);
1368 
1369    for (long i = 0; i < n; i++)
1370       x[i] = MUL(src[i], powers[i]);
1371 
1372    for (long i = n; i < N; i++)
1373       x[i] = 0;
1374 
1375    new_fft(&x[0], k, tab);
1376 
1377    // for (long i = 0; i < N; i++) x[i] = MUL(x[i], Rb[i]);
1378    mul_loop(N, &x[0], &Rb[0]);
1379 
1380    new_ifft(&x[0], k, tab);
1381 
1382    double Ninv = 1/double(N);
1383 
1384    for (long i = 0; i < n; i++)
1385       dst[i] = MUL(x[n-1+i], powers[i]);
1386 
1387 }
1388 
1389 
1390 
1391 
1392 static long
1393 bluestein_precomp1(long n, aligned_vector<cmplx_t>& powers,
1394                   aligned_vector<cmplx_t>& Rb,
1395                   vector<aligned_vector<cmplx_t>>& tab)
1396 {
1397    // k = least k such that 2^k >= 2*n-1
1398    long k = 0;
1399    while ((1L << k) < 2*n-1) k++;
1400 
1401    compute_table(tab, k);
1402 
1403    const ldbl pi = std::atan(ldbl(1)) * 4.0;
1404 
1405    powers.resize(n);
1406    powers[0] = 1;
1407    long i_sqr = 0;
1408 
1409    if (n % 2 == 0) {
1410       for (long i = 1; i < n; i++) {
1411 	 // i^2 = (i-1)^2 + 2*i-1
1412 	 i_sqr = (i_sqr + 2*i - 1) % (2*n);
1413 	 ldbl angle = -((2 * pi) * (ldbl(i_sqr)/ldbl(2*n)));
1414 	 powers[i] = cmplx_t(std::cos(angle), std::sin(angle));
1415       }
1416    }
1417    else {
1418       for (long i = 1; i < n; i++) {
1419 	 // i^2*((n+1)/2) = (i-1)^2*((n+1)/2) + i + ((n-1)/2) (mod n)
1420 	 i_sqr = (i_sqr + i + (n-1)/2) % n;
1421 	 ldbl angle = -((2 * pi) * (ldbl(i_sqr)/ldbl(n)));
1422 	 powers[i] = cmplx_t(std::cos(angle), std::sin(angle));
1423       }
1424    }
1425 
1426    long N = 1L << k;
1427    Rb.resize(N);
1428    for (long i = 0; i < N; i++) Rb[i] = 0;
1429 
1430    Rb[0] = 1;
1431    i_sqr = 0;
1432 
1433    if (n % 2 == 0) {
1434       for (long i = 1; i < n; i++) {
1435 	 // i^2 = (i-1)^2 + 2*i-1
1436 	 i_sqr = (i_sqr + 2*i - 1) % (2*n);
1437 	 ldbl angle = (2 * pi) * (ldbl(i_sqr)/ldbl(2*n));
1438 	 Rb[i] = cmplx_t(std::cos(angle), std::sin(angle));
1439       }
1440    }
1441    else {
1442       for (long i = 1; i < n; i++) {
1443 	 // i^2*((n+1)/2) = (i-1)^2*((n+1)/2) + i + ((n-1)/2) (mod n)
1444 	 i_sqr = (i_sqr + i + (n-1)/2) % n;
1445 	 ldbl angle = (2 * pi) * (ldbl(i_sqr)/ldbl(n));
1446 	 Rb[i] = cmplx_t(std::cos(angle), std::sin(angle));
1447       }
1448    }
1449 
1450    new_fft(&Rb[0], k, tab);
1451 
1452    double Ninv = 1/double(N);
1453    for (long i = 0; i < N; i++)
1454       Rb[i] *= Ninv;
1455 
1456    return k;
1457 
1458 }
1459 
1460 
1461 static void
1462 bluestein_comp1(const cmplx_t* src, cmplx_t* dst,
1463                   long n, long k, const aligned_vector<cmplx_t>& powers,
1464                   const aligned_vector<cmplx_t>& Rb,
1465                   const vector<aligned_vector<cmplx_t>>& tab)
1466 {
1467    long N = 1L << k;
1468 
1469    aligned_vector<cmplx_t> x(N);
1470 
1471    for (long i = 0; i < n; i++)
1472       x[i] = MUL(src[i], powers[i]);
1473 
1474    long len = FFTRoundUp(2*n-1, k);
1475    long ilen = FFTRoundUp(n, k);
1476 
1477    for (long i = n; i < ilen; i++)
1478       x[i] = 0;
1479 
1480    new_fft_short(&x[0], len, ilen, k, tab);
1481 
1482    // for (long i = 0; i < len; i++) x[i] = MUL(x[i], Rb[i]);
1483    mul_loop(len, &x[0], &Rb[0]);
1484 
1485    new_ifft_short1(&x[0], len, k, tab);
1486 
1487    double Ninv = 1/double(N);
1488 
1489    for (long i = 0; i < n-1; i++)
1490       dst[i] = MUL(x[i] + x[n+i], powers[i]);
1491 
1492    dst[n-1] = MUL(x[n-1], powers[n-1]);
1493 
1494 }
1495 
1496 #define PGFFT_STRATEGY_NULL  (0)
1497 #define PGFFT_STRATEGY_POW2  (1)
1498 #define PGFFT_STRATEGY_BLUE  (2)
1499 #define PGFFT_STRATEGY_TBLUE (3)
1500 
1501 static long choose_strategy(long n)
1502 {
1503    if (n == 1) return PGFFT_STRATEGY_NULL;
1504 
1505    if ((n & (n - 1)) == 0) return PGFFT_STRATEGY_POW2;
1506 
1507    if (!PGFFT_USE_TRUNCATED_BLUE) return PGFFT_STRATEGY_BLUE;
1508 
1509    // choose between Bluestein and truncated Bluestein
1510 
1511    // k = least k such that 2^k >= 2*n-1
1512    long k = 0;
1513    while ((1L << k) < 2*n-1) k++;
1514 
1515    long rdup = FFTRoundUp(2*n-1, k);
1516    if (rdup == (1L << k)) return PGFFT_STRATEGY_BLUE;
1517 
1518    return PGFFT_STRATEGY_TBLUE;
1519 }
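
// Worked example (illustrative): for n = 100, 2*n-1 = 199 and the least k
// with 2^k >= 199 is k = 8 (so N = 256).  FFTRoundUp(199, 8) rounds 199 up
// to 208, and 208 <= 256 - 256/8 = 224, so the rounded size stays below 256
// and the truncated Bluestein strategy (PGFFT_STRATEGY_TBLUE) is chosen.
// For a power of two such as n = 4096, PGFFT_STRATEGY_POW2 is chosen
// directly.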
1520 
1521 PGFFT::PGFFT(long n_)
1522 {
1523    assert(n_ > 0);
1524    n = n_;
1525 
1526    strategy = choose_strategy(n);
1527 
1528    //std::cout << strategy << "\n";
1529 
1530    switch (strategy) {
1531 
1532    case PGFFT_STRATEGY_NULL:
1533       break;
1534 
1535    case PGFFT_STRATEGY_POW2:
1536       k = pow2_precomp(n, rev, rev1, tab);
1537       break;
1538 
1539    case PGFFT_STRATEGY_BLUE:
1540       k = bluestein_precomp(n, powers, Rb, tab);
1541       break;
1542 
1543    case PGFFT_STRATEGY_TBLUE:
1544       k = bluestein_precomp1(n, powers, Rb, tab);
1545       break;
1546 
1547    default: ;
1548 
1549    }
1550 }
1551 
1552 void PGFFT::apply(const cmplx_t* src, cmplx_t* dst) const
1553 {
1554    switch (strategy) {
1555 
1556    case PGFFT_STRATEGY_NULL:
1557       break;
1558 
1559    case PGFFT_STRATEGY_POW2:
1560       pow2_comp(src, dst, n, k, rev, rev1, tab);
1561       break;
1562 
1563    case PGFFT_STRATEGY_BLUE:
1564       bluestein_comp(src, dst, n, k, powers, Rb, tab);
1565       break;
1566 
1567    case PGFFT_STRATEGY_TBLUE:
1568       bluestein_comp1(src, dst, n, k, powers, Rb, tab);
1569       break;
1570 
1571    default: ;
1572 
1573    }
1574 }
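
// Minimal usage sketch (added commentary, not part of the library; assumes
// the public declarations in <helib/PGFFT.h> match the definitions above):
//
//    helib::PGFFT fft(n);                  // precompute tables for size n
//    std::vector<std::complex<double>> a(n), b(n);
//    // ... fill a ...
//    fft.apply(a.data(), b.data());        // b = DFT of a, using exp(-2*pi*I/n)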
1575 
1576 }
1577 
1578 
1579 /****************************************************************************
1580 
1581 PGFFT: Pretty Good FFT (v1.8)
1582 
1583 Copyright (C) 2019, Victor Shoup
1584 
1585 All rights reserved.
1586 
1587 Redistribution and use in source and binary forms, with or without
1588 modification, are permitted provided that the following conditions are met:
1589 
1590 * Redistributions of source code must retain the above copyright notice, this
1591   list of conditions and the following disclaimer.
1592 * Redistributions in binary form must reproduce the above copyright notice,
1593   this list of conditions and the following disclaimer in the documentation
1594   and/or other materials provided with the distribution.
1595 
1596 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
1597 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
1598 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
1599 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
1600 FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
1601 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
1602 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
1603 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
1604 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
1605 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
1606 
1607 ****************************************************************************
1608 
1609 The logic of this code is derived from code originally developed by David Harvey,
1610 even though the code itself has been essentially rewritten from scratch.
1611 Here is David Harvey's original copyright notice.
1612 
1613 fft62: a library for number-theoretic transforms
1614 
1615 Copyright (C) 2013, David Harvey
1616 
1617 All rights reserved.
1618 
1619 Redistribution and use in source and binary forms, with or without
1620 modification, are permitted provided that the following conditions are met:
1621 
1622 * Redistributions of source code must retain the above copyright notice, this
1623   list of conditions and the following disclaimer.
1624 * Redistributions in binary form must reproduce the above copyright notice,
1625   this list of conditions and the following disclaimer in the documentation
1626   and/or other materials provided with the distribution.
1627 
1628 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
1629 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
1630 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
1631 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
1632 FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
1633 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
1634 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
1635 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
1636 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
1637 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
1638 
1639 ****************************************************************************/
1640 
1641 #pragma GCC diagnostic pop
1642