1 /*
2 This file is part of pocketfft.
3 
4 Copyright (C) 2010-2019 Max-Planck-Society
5 Copyright (C) 2019 Peter Bell
6 
7 For the odd-sized DCT-IV transforms:
8   Copyright (C) 2003, 2007-14 Matteo Frigo
9   Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology
10 
11 Authors: Martin Reinecke, Peter Bell
12 
13 All rights reserved.
14 
15 Redistribution and use in source and binary forms, with or without modification,
16 are permitted provided that the following conditions are met:
17 
18 * Redistributions of source code must retain the above copyright notice, this
19   list of conditions and the following disclaimer.
20 * Redistributions in binary form must reproduce the above copyright notice, this
21   list of conditions and the following disclaimer in the documentation and/or
22   other materials provided with the distribution.
23 * Neither the name of the copyright holder nor the names of its contributors may
24   be used to endorse or promote products derived from this software without
25   specific prior written permission.
26 
27 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
28 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
29 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
31 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
32 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
33 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
34 ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
35 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
36 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 */
38 #ifndef PYTHONIC_INCLUDE_NUMPY_FFT_POCKETFFT_HPP
39 #define PYTHONIC_INCLUDE_NUMPY_FFT_POCKETFFT_HPP
40 #ifndef POCKETFFT_HDRONLY_H
41 #define POCKETFFT_HDRONLY_H
42 
43 #ifndef __cplusplus
44 #error This file is C++ and requires a C++ compiler.
45 #endif
46 
47 #if !(__cplusplus >= 201103L || _MSVC_LANG + 0L >= 201103L)
48 #error This file requires at least C++11 support.
49 #endif
50 
51 #ifndef POCKETFFT_CACHE_SIZE
52 #define POCKETFFT_CACHE_SIZE 16
53 #endif
54 
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <new>
#include <stdexcept>
#include <utility>
#include <vector>
#if POCKETFFT_CACHE_SIZE != 0
#include <array>
#include <mutex>
#endif

#ifndef POCKETFFT_NO_MULTITHREADING
#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

#ifdef POCKETFFT_PTHREADS
#include <pthread.h>
#endif
#endif
79 
80 #if defined(__GNUC__)
81 #define POCKETFFT_NOINLINE __attribute__((noinline))
82 #define POCKETFFT_RESTRICT __restrict__
83 #elif defined(_MSC_VER)
84 #define POCKETFFT_NOINLINE __declspec(noinline)
85 #define POCKETFFT_RESTRICT __restrict
86 #else
87 #define POCKETFFT_NOINLINE
88 #define POCKETFFT_RESTRICT
89 #endif
90 
91 namespace pocketfft
92 {
93 
94   namespace detail
95   {
96     using std::size_t;
97     using std::ptrdiff_t;
98 
// Always use std:: for <cmath> functions.
// These deleted templates hide the global cos/sin/sqrt inside this
// namespace, so an unqualified single-argument call resolves to a deleted
// function and is rejected at compile time instead of silently picking a
// C-library overload.
template <class Arg>
Arg cos(Arg) = delete;

template <class Arg>
Arg sin(Arg) = delete;

template <class Arg>
Arg sqrt(Arg) = delete;
106 
// Per-dimension extents of an array (one entry per axis).
using shape_t = std::vector<size_t>;
// Per-dimension strides of an array (one entry per axis, may be negative).
using stride_t = std::vector<ptrdiff_t>;

// Named values for a boolean transform direction.
// NOTE(review): presumably passed as the `fwd` flag of the transforms - confirm.
constexpr bool FORWARD = true;
constexpr bool BACKWARD = false;
111 
// Only enable vector support for GCC >= 5.0, Clang >= 5.0 and AppleClang >= 9.1.
// The Intel compiler is excluded explicitly because it also defines __GNUC__.
113 #ifndef POCKETFFT_NO_VECTORS
114 #define POCKETFFT_NO_VECTORS
115 #if defined(__INTEL_COMPILER)
116 // do nothing. This is necessary because this compiler also sets __GNUC__.
117 #elif defined(__clang__)
118 // AppleClang has their own version numbering
119 #ifdef __apple_build_version__
120 #if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1)
121 #undef POCKETFFT_NO_VECTORS
122 #endif
123 #elif __clang_major__ >= 5
124 #undef POCKETFFT_NO_VECTORS
125 #endif
126 #elif defined(__GNUC__)
127 #if __GNUC__ >= 5
128 #undef POCKETFFT_NO_VECTORS
129 #endif
130 #endif
131 #endif
132 
// Compile-time vector length associated with element type T.
// The primary template yields 1; float and double are specialized below when
// a supported SIMD instruction set is enabled.
// NOTE(review): presumably the number of T lanes per SIMD register - confirm.
template <typename T>
struct VLEN {
  static constexpr size_t val = 1;
};

#ifndef POCKETFFT_NO_VECTORS
// Width constant from which the float/double values are derived:
//   AVX-512 -> 16 / 8,  AVX -> 8 / 4,  SSE2 or VSX -> 4 / 2.
// If none of these instruction sets is available, vector support is
// switched off for the remainder of the file.
#if defined(__AVX512F__)
#define POCKETFFT_VLEN_BYTES 64
#elif defined(__AVX__)
#define POCKETFFT_VLEN_BYTES 32
#elif defined(__SSE2__) || defined(__VSX__)
#define POCKETFFT_VLEN_BYTES 16
#else
#define POCKETFFT_NO_VECTORS
#endif
#endif

#ifdef POCKETFFT_VLEN_BYTES
template <>
struct VLEN<float> {
  static constexpr size_t val = POCKETFFT_VLEN_BYTES / 4;
};
template <>
struct VLEN<double> {
  static constexpr size_t val = POCKETFFT_VLEN_BYTES / 8;
};
#undef POCKETFFT_VLEN_BYTES
#endif
179 
// Minimal owning buffer holding `size()` objects of type T.
//
// Storage comes straight from the C allocator: elements are NOT constructed
// or destroyed, so the contents are uninitialized after construction and
// after resize(). When vector support is enabled the storage is aligned to
// 64 bytes. The class is movable but not copyable.
template <typename T>
class arr
{
private:
  T *p;
  size_t sz;

  // Returns num * sizeof(T) + extra; throws std::bad_alloc if that value
  // does not fit into size_t (instead of silently wrapping around).
  static size_t nbytes(size_t num, size_t extra)
  {
    if (num > (std::numeric_limits<size_t>::max() - extra) / sizeof(T))
      throw std::bad_alloc();
    return num * sizeof(T) + extra;
  }

#if defined(POCKETFFT_NO_VECTORS)
  // No SIMD: plain malloc/free is sufficient.
  static T *ralloc(size_t num)
  {
    if (num == 0)
      return nullptr;
    void *res = malloc(nbytes(num, 0));
    if (!res)
      throw std::bad_alloc();
    return reinterpret_cast<T *>(res);
  }
  static void dealloc(T *ptr)
  {
    free(ptr);
  }
// C++17 in principle has "aligned_alloc", but unfortunately not everywhere ...
#elif(__cplusplus >= 201703L) &&                                               \
    ((!defined(__MINGW32__)) || defined(_GLIBCXX_HAVE_ALIGNED_ALLOC)) &&       \
    (!defined(__APPLE__))
  static T *ralloc(size_t num)
  {
    if (num == 0)
      return nullptr;
    // aligned_alloc() requires the size to be a multiple of the alignment,
    // so round the request up to the next multiple of 64.
    const size_t bytes = nbytes(num, 63) & ~size_t(63);
    void *res = aligned_alloc(64, bytes);
    if (!res)
      throw std::bad_alloc();
    return reinterpret_cast<T *>(res);
  }
  static void dealloc(T *ptr)
  {
    free(ptr);
  }
#else // portable emulation
  // Over-allocate by 64 bytes, return the next 64-byte boundary strictly
  // above the malloc'ed address, and remember the malloc'ed address in the
  // pointer-sized slot directly in front of the returned block.
  static T *ralloc(size_t num)
  {
    if (num == 0)
      return nullptr;
    void *ptr = malloc(nbytes(num, 64));
    if (!ptr)
      throw std::bad_alloc();
    T *res = reinterpret_cast<T *>(
        (reinterpret_cast<size_t>(ptr) & ~(size_t(63))) + 64);
    (reinterpret_cast<void **>(res))[-1] = ptr;
    return res;
  }
  static void dealloc(T *ptr)
  {
    if (ptr)
      free((reinterpret_cast<void **>(ptr))[-1]);
  }
#endif

public:
  // Creates an empty buffer (null data pointer, size 0).
  arr() : p(nullptr), sz(0)
  {
  }
  // Allocates room for n elements; throws std::bad_alloc on failure.
  arr(size_t n) : p(ralloc(n)), sz(n)
  {
  }
  // Takes over the storage of `other`, which is left empty.
  arr(arr &&other) noexcept : p(other.p), sz(other.sz)
  {
    other.p = nullptr;
    other.sz = 0;
  }
  // Copying is not supported (already implied by the user-declared move
  // constructor; spelled out here for readers).
  arr(const arr &) = delete;
  arr &operator=(const arr &) = delete;

  ~arr()
  {
    dealloc(p);
  }

  // Changes the capacity to n elements. Existing contents are discarded
  // (not copied) unless n equals the current size, in which case nothing
  // happens. If the allocation fails, std::bad_alloc is thrown and the
  // buffer is left empty.
  void resize(size_t n)
  {
    if (n == sz)
      return;
    dealloc(p);
    // Reset first: if ralloc() throws, the destructor must not free the
    // old block a second time.
    p = nullptr;
    sz = 0;
    p = ralloc(n);
    sz = n;
  }

  // Unchecked element access.
  T &operator[](size_t idx)
  {
    return p[idx];
  }
  const T &operator[](size_t idx) const
  {
    return p[idx];
  }

  // Pointer to the first element (nullptr when empty).
  T *data()
  {
    return p;
  }
  const T *data() const
  {
    return p;
  }

  // Number of elements the buffer can hold.
  size_t size() const
  {
    return sz;
  }
};
287 
// Lightweight complex number with public real part `r` and imaginary part
// `i`. Mixed-type arithmetic is supported; binary operators return a cmplx
// over the type the underlying scalar expression produces.
template <typename T>
struct cmplx {
  T r, i;

  // Leaves both components uninitialized.
  cmplx()
  {
  }
  cmplx(T r_, T i_) : r(r_), i(i_)
  {
  }

  // Overwrites both components.
  void Set(T r_, T i_)
  {
    r = r_;
    i = i_;
  }
  // Overwrites the real part and zeroes the imaginary part.
  void Set(T r_)
  {
    r = r_;
    i = T(0);
  }

  cmplx &operator+=(const cmplx &other)
  {
    r += other.r;
    i += other.i;
    return *this;
  }

  // Scales both components by a scalar.
  template <typename T2>
  cmplx &operator*=(T2 other)
  {
    r *= other;
    i *= other;
    return *this;
  }

  // In-place complex multiplication. The new imaginary part is computed
  // into a temporary before either member is overwritten, so both results
  // are formed from the old values.
  template <typename T2>
  cmplx &operator*=(const cmplx<T2> &other)
  {
    const T im = r * other.i + i * other.r;
    r = r * other.r - i * other.i;
    i = im;
    return *this;
  }

  template <typename T2>
  cmplx &operator+=(const cmplx<T2> &other)
  {
    r += other.r;
    i += other.i;
    return *this;
  }

  template <typename T2>
  cmplx &operator-=(const cmplx<T2> &other)
  {
    r -= other.r;
    i -= other.i;
    return *this;
  }

  // Complex times scalar.
  template <typename T2>
  auto operator*(const T2 &other) const -> cmplx<decltype(r *other)>
  {
    using Tres = cmplx<decltype(r * other)>;
    return Tres(r * other, i * other);
  }

  template <typename T2>
  auto operator+(const cmplx<T2> &other) const
      -> cmplx<decltype(r + other.r)>
  {
    using Tres = cmplx<decltype(r + other.r)>;
    return Tres(r + other.r, i + other.i);
  }

  template <typename T2>
  auto operator-(const cmplx<T2> &other) const
      -> cmplx<decltype(r + other.r)>
  {
    using Tres = cmplx<decltype(r + other.r)>;
    return Tres(r - other.r, i - other.i);
  }

  // Complex times complex.
  template <typename T2>
  auto operator*(const cmplx<T2> &other) const
      -> cmplx<decltype(r + other.r)>
  {
    using Tres = cmplx<decltype(r + other.r)>;
    return Tres(r * other.r - i * other.i, r * other.i + i * other.r);
  }

  // Direction-dependent product: for fwd == true the result is
  // (*this) * conj(other), otherwise the plain product (*this) * other.
  template <bool fwd, typename T2>
  auto special_mul(const cmplx<T2> &other) const
      -> cmplx<decltype(r + other.r)>
  {
    using Tres = cmplx<decltype(r + other.r)>;
    if (fwd)
      return Tres(r * other.r + i * other.i, i * other.r - r * other.i);
    return Tres(r * other.r - i * other.i, r * other.i + i * other.r);
  }
};
// "Plus/minus" butterfly: stores c + d in a and c - d in b.
// c and d are taken by value, so they may be copies of a and/or b.
template <typename T>
inline void PM(T &a, T &b, T c, T d)
{
  const T sum = c + d;
  const T diff = c - d;
  a = sum;
  b = diff;
}
// In-place butterfly: afterwards a holds (a + b) and b holds (a - b),
// both computed from the values on entry.
template <typename T>
inline void PMINPLACE(T &a, T &b)
{
  const T a_before = a; // a is overwritten first, keep its old value
  a += b;
  b = a_before - b;
}
// In-place butterfly with swapped signs: afterwards a holds (a - b) and
// b holds (a + b), both computed from the values on entry.
template <typename T>
inline void MPINPLACE(T &a, T &b)
{
  const T a_before = a; // a is overwritten first, keep its old value
  a -= b;
  b = a_before + b;
}
394     template <typename T>
conj(const cmplx<T> & a)395     cmplx<T> conj(const cmplx<T> &a)
396     {
397       return {a.r, -a.i};
398     }
399     template <bool fwd, typename T, typename T2>
special_mul(const cmplx<T> & v1,const cmplx<T2> & v2,cmplx<T> & res)400     void special_mul(const cmplx<T> &v1, const cmplx<T2> &v2, cmplx<T> &res)
401     {
402       res =
403           fwd ? cmplx<T>(v1.r * v2.r + v1.i * v2.i, v1.i * v2.r - v1.r * v2.i)
404               : cmplx<T>(v1.r * v2.r - v1.i * v2.i, v1.r * v2.i + v1.i * v2.r);
405     }
406 
407     template <typename T>
ROT90(cmplx<T> & a)408     void ROT90(cmplx<T> &a)
409     {
410       auto tmp_ = a.r;
411       a.r = -a.i;
412       a.i = tmp_;
413     }
414     template <bool fwd, typename T>
ROTX90(cmplx<T> & a)415     void ROTX90(cmplx<T> &a)
416     {
417       auto tmp_ = fwd ? -a.r : a.r;
418       a.r = fwd ? a.i : -a.i;
419       a.i = tmp_;
420     }
421 
422     //
423     // twiddle factor section
424     //
425     template <typename T>
426     class sincos_2pibyn
427     {
428     private:
429       using Thigh = typename std::conditional<(sizeof(T) > sizeof(double)), T,
430                                               double>::type;
431       size_t N, mask, shift;
432       arr<cmplx<Thigh>> v1, v2;
433 
calc(size_t x,size_t n,Thigh ang)434       static cmplx<Thigh> calc(size_t x, size_t n, Thigh ang)
435       {
436         x <<= 3;
437         if (x < 4 * n) // first half
438         {
439           if (x < 2 * n) // first quadrant
440           {
441             if (x < n)
442               return cmplx<Thigh>(std::cos(Thigh(x) * ang),
443                                   std::sin(Thigh(x) * ang));
444             return cmplx<Thigh>(std::sin(Thigh(2 * n - x) * ang),
445                                 std::cos(Thigh(2 * n - x) * ang));
446           } else // second quadrant
447           {
448             x -= 2 * n;
449             if (x < n)
450               return cmplx<Thigh>(-std::sin(Thigh(x) * ang),
451                                   std::cos(Thigh(x) * ang));
452             return cmplx<Thigh>(-std::cos(Thigh(2 * n - x) * ang),
453                                 std::sin(Thigh(2 * n - x) * ang));
454           }
455         } else {
456           x = 8 * n - x;
457           if (x < 2 * n) // third quadrant
458           {
459             if (x < n)
460               return cmplx<Thigh>(std::cos(Thigh(x) * ang),
461                                   -std::sin(Thigh(x) * ang));
462             return cmplx<Thigh>(std::sin(Thigh(2 * n - x) * ang),
463                                 -std::cos(Thigh(2 * n - x) * ang));
464           } else // fourth quadrant
465           {
466             x -= 2 * n;
467             if (x < n)
468               return cmplx<Thigh>(-std::sin(Thigh(x) * ang),
469                                   -std::cos(Thigh(x) * ang));
470             return cmplx<Thigh>(-std::cos(Thigh(2 * n - x) * ang),
471                                 -std::sin(Thigh(2 * n - x) * ang));
472           }
473         }
474       }
475 
476     public:
sincos_2pibyn(size_t n)477       POCKETFFT_NOINLINE sincos_2pibyn(size_t n) : N(n)
478       {
479         constexpr auto pi = 3.141592653589793238462643383279502884197L;
480         Thigh ang = Thigh(0.25L * pi / n);
481         size_t nval = (n + 2) / 2;
482         shift = 1;
483         while ((size_t(1) << shift) * (size_t(1) << shift) < nval)
484           ++shift;
485         mask = (size_t(1) << shift) - 1;
486         v1.resize(mask + 1);
487         v1[0].Set(Thigh(1), Thigh(0));
488         for (size_t i = 1; i < v1.size(); ++i)
489           v1[i] = calc(i, n, ang);
490         v2.resize((nval + mask) / (mask + 1));
491         v2[0].Set(Thigh(1), Thigh(0));
492         for (size_t i = 1; i < v2.size(); ++i)
493           v2[i] = calc(i * (mask + 1), n, ang);
494       }
495 
operator [](size_t idx) const496       cmplx<T> operator[](size_t idx) const
497       {
498         if (2 * idx <= N) {
499           auto x1 = v1[idx & mask], x2 = v2[idx >> shift];
500           return cmplx<T>(T(x1.r * x2.r - x1.i * x2.i),
501                           T(x1.r * x2.i + x1.i * x2.r));
502         }
503         idx = N - idx;
504         auto x1 = v1[idx & mask], x2 = v2[idx >> shift];
505         return cmplx<T>(T(x1.r * x2.r - x1.i * x2.i),
506                         -T(x1.r * x2.i + x1.i * x2.r));
507       }
508     };
509 
510     struct util // hack to avoid duplicate symbols
511         {
largest_prime_factorpocketfft::detail::util512       static POCKETFFT_NOINLINE size_t largest_prime_factor(size_t n)
513       {
514         size_t res = 1;
515         while ((n & 1) == 0) {
516           res = 2;
517           n >>= 1;
518         }
519         for (size_t x = 3; x * x <= n; x += 2)
520           while ((n % x) == 0) {
521             res = x;
522             n /= x;
523           }
524         if (n > 1)
525           res = n;
526         return res;
527       }
528 
cost_guesspocketfft::detail::util529       static POCKETFFT_NOINLINE double cost_guess(size_t n)
530       {
531         constexpr double lfp = 1.1; // penalty for non-hardcoded larger factors
532         size_t ni = n;
533         double result = 0.;
534         while ((n & 1) == 0) {
535           result += 2;
536           n >>= 1;
537         }
538         for (size_t x = 3; x * x <= n; x += 2)
539           while ((n % x) == 0) {
540             result += (x <= 5)
541                           ? double(x)
542                           : lfp * double(x); // penalize larger prime factors
543             n /= x;
544           }
545         if (n > 1)
546           result += (n <= 5) ? double(n) : lfp * double(n);
547         return result * double(ni);
548       }
549 
550       /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */
good_size_cmplxpocketfft::detail::util551       static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n)
552       {
553         if (n <= 12)
554           return n;
555 
556         size_t bestfac = 2 * n;
557         for (size_t f11 = 1; f11 < bestfac; f11 *= 11)
558           for (size_t f117 = f11; f117 < bestfac; f117 *= 7)
559             for (size_t f1175 = f117; f1175 < bestfac; f1175 *= 5) {
560               size_t x = f1175;
561               while (x < n)
562                 x *= 2;
563               for (;;) {
564                 if (x < n)
565                   x *= 3;
566                 else if (x > n) {
567                   if (x < bestfac)
568                     bestfac = x;
569                   if (x & 1)
570                     break;
571                   x >>= 1;
572                 } else
573                   return n;
574               }
575             }
576         return bestfac;
577       }
578 
579       /* returns the smallest composite of 2, 3, 5 which is >= n */
good_size_realpocketfft::detail::util580       static POCKETFFT_NOINLINE size_t good_size_real(size_t n)
581       {
582         if (n <= 6)
583           return n;
584 
585         size_t bestfac = 2 * n;
586         for (size_t f5 = 1; f5 < bestfac; f5 *= 5) {
587           size_t x = f5;
588           while (x < n)
589             x *= 2;
590           for (;;) {
591             if (x < n)
592               x *= 3;
593             else if (x > n) {
594               if (x < bestfac)
595                 bestfac = x;
596               if (x & 1)
597                 break;
598               x >>= 1;
599             } else
600               return n;
601           }
602         }
603         return bestfac;
604       }
605 
prodpocketfft::detail::util606       static size_t prod(const shape_t &shape)
607       {
608         size_t res = 1;
609         for (auto sz : shape)
610           res *= sz;
611         return res;
612       }
613 
sanity_checkpocketfft::detail::util614       static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape,
615                                                   const stride_t &stride_in,
616                                                   const stride_t &stride_out,
617                                                   bool inplace)
618       {
619         auto ndim = shape.size();
620         if (ndim < 1)
621           throw std::runtime_error("ndim must be >= 1");
622         if ((stride_in.size() != ndim) || (stride_out.size() != ndim))
623           throw std::runtime_error("stride dimension mismatch");
624         if (inplace && (stride_in != stride_out))
625           throw std::runtime_error("stride mismatch");
626       }
627 
sanity_checkpocketfft::detail::util628       static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape,
629                                                   const stride_t &stride_in,
630                                                   const stride_t &stride_out,
631                                                   bool inplace,
632                                                   const shape_t &axes)
633       {
634         sanity_check(shape, stride_in, stride_out, inplace);
635         auto ndim = shape.size();
636         shape_t tmp(ndim, 0);
637         for (auto ax : axes) {
638           if (ax >= ndim)
639             throw std::invalid_argument("bad axis number");
640           if (++tmp[ax] > 1)
641             throw std::invalid_argument("axis specified repeatedly");
642         }
643       }
644 
sanity_checkpocketfft::detail::util645       static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape,
646                                                   const stride_t &stride_in,
647                                                   const stride_t &stride_out,
648                                                   bool inplace, size_t axis)
649       {
650         sanity_check(shape, stride_in, stride_out, inplace);
651         if (axis >= shape.size())
652           throw std::invalid_argument("bad axis number");
653       }
654 
655 #ifdef POCKETFFT_NO_MULTITHREADING
thread_countpocketfft::detail::util656       static size_t thread_count(size_t /*nthreads*/, const shape_t & /*shape*/,
657                                  size_t /*axis*/, size_t /*vlen*/)
658       {
659         return 1;
660       }
661 #else
thread_countpocketfft::detail::util662       static size_t thread_count(size_t nthreads, const shape_t &shape,
663                                  size_t axis, size_t vlen)
664       {
665         if (nthreads == 1)
666           return 1;
667         size_t size = prod(shape);
668         size_t parallel = size / (shape[axis] * vlen);
669         if (shape[axis] < 1000)
670           parallel /= 4;
671         size_t max_threads =
672             nthreads == 0 ? std::thread::hardware_concurrency() : nthreads;
673         return std::max(size_t(1), std::min(parallel, max_threads));
674       }
675 #endif
676     };
677 
// Helpers for running a callable on several threads.
// With POCKETFFT_NO_MULTITHREADING defined everything collapses to a direct
// call; otherwise work is dispatched to a process-wide thread pool.
namespace threading
{

#ifdef POCKETFFT_NO_MULTITHREADING

// Index of the calling thread within the current thread_map() call.
constexpr inline size_t thread_id()
{
  return 0;
}
// Number of threads taking part in the current thread_map() call.
constexpr inline size_t num_threads()
{
  return 1;
}

// Runs f once on the calling thread.
template <typename Func>
void thread_map(size_t /* nthreads */, Func f)
{
  f();
}

#else

// Per-thread slot holding the index assigned by thread_map() (0 by default).
inline size_t &thread_id()
{
  static thread_local size_t thread_id_ = 0;
  return thread_id_;
}
// Per-thread slot holding the thread count of the current thread_map()
// call (1 by default).
inline size_t &num_threads()
{
  static thread_local size_t num_threads_ = 1;
  return num_threads_;
}
// Default pool size: hardware_concurrency(), but never less than 1.
static const size_t max_threads =
    std::max(1u, std::thread::hardware_concurrency());

// One-shot countdown: wait() blocks until count_down() has been called
// n times.
class latch
{
  std::atomic<size_t> num_left_;
  std::mutex mut_;
  std::condition_variable completed_;
  using lock_t = std::unique_lock<std::mutex>;

public:
  latch(size_t n) : num_left_(n)
  {
  }

  // Decrements the counter; wakes all waiters when it reaches zero.
  void count_down()
  {
    lock_t lock(mut_);
    if (--num_left_)
      return;
    completed_.notify_all();
  }

  // Blocks until the counter is zero.
  void wait()
  {
    lock_t lock(mut_);
    completed_.wait(lock, [this] { return is_ready(); });
  }
  bool is_ready()
  {
    return num_left_ == 0;
  }
};

// Mutex-protected FIFO with blocking pop and a shutdown flag.
template <typename T>
class concurrent_queue
{
  std::queue<T> q_;
  std::mutex mut_;
  std::condition_variable item_added_;
  bool shutdown_;
  using lock_t = std::unique_lock<std::mutex>;

public:
  concurrent_queue() : shutdown_(false)
  {
  }

  // Appends val and wakes one waiting consumer.
  // Throws std::runtime_error if the queue has been shut down.
  void push(T val)
  {
    {
      lock_t lock(mut_);
      if (shutdown_)
        throw std::runtime_error("Item added to queue after shutdown");
      // Qualified std::move: an unqualified call would depend on ADL and
      // fail to compile for element types outside namespace std.
      q_.push(std::move(val));
    }
    item_added_.notify_one();
  }

  // Blocks until an item is available or the queue is shut down.
  // Returns true and moves the front item into val, or returns false if
  // the queue is shut down and empty. Items still queued at shutdown are
  // handed out before false is returned.
  bool pop(T &val)
  {
    lock_t lock(mut_);
    item_added_.wait(lock, [this] { return (!q_.empty() || shutdown_); });
    if (q_.empty())
      return false; // We are shutting down

    val = std::move(q_.front());
    q_.pop();
    return true;
  }

  // Marks the queue as shut down and wakes all waiting consumers.
  void shutdown()
  {
    {
      lock_t lock(mut_);
      shutdown_ = true;
    }
    item_added_.notify_all();
  }

  // Clears the shutdown flag. The flag is written without taking the lock.
  // NOTE(review): presumably only called while no consumer thread is
  // running (see thread_pool::restart) - confirm.
  void restart()
  {
    shutdown_ = false;
  }
};

// Fixed-size pool of worker threads executing queued std::function tasks.
class thread_pool
{
  concurrent_queue<std::function<void()>> work_queue_;
  std::vector<std::thread> threads_;

  // Worker loop: run tasks until the queue reports shutdown.
  void worker_main()
  {
    std::function<void()> work;
    while (work_queue_.pop(work))
      work();
  }

  // Starts one worker per slot of threads_. If starting a thread fails,
  // the workers started so far are shut down and the exception propagates.
  void create_threads()
  {
    size_t nthreads = threads_.size();
    for (size_t i = 0; i < nthreads; ++i) {
      try {
        threads_[i] = std::thread([this] { worker_main(); });
      } catch (...) {
        shutdown();
        throw;
      }
    }
  }

public:
  explicit thread_pool(size_t nthreads) : threads_(nthreads)
  {
    create_threads();
  }

  thread_pool() : thread_pool(max_threads)
  {
  }

  ~thread_pool()
  {
    shutdown();
  }

  // Queues a task for execution on one of the workers.
  void submit(std::function<void()> work)
  {
    work_queue_.push(std::move(work));
  }

  // Shuts the queue down and joins all workers.
  void shutdown()
  {
    work_queue_.shutdown();
    for (auto &thread : threads_)
      if (thread.joinable())
        thread.join();
  }

  // Re-opens the queue and starts a fresh set of workers after shutdown().
  void restart()
  {
    work_queue_.restart();
    create_threads();
  }
};

// Returns the process-wide pool, created on first use.
// With POCKETFFT_PTHREADS, fork handlers are registered once that shut the
// pool down before fork() and restart it in both parent and child.
inline thread_pool &get_pool()
{
  static thread_pool pool;
#ifdef POCKETFFT_PTHREADS
  static std::once_flag f;
  std::call_once(f, [] {
    pthread_atfork(+[] { get_pool().shutdown(); }, // prepare
                   +[] { get_pool().restart(); },  // parent
                   +[] { get_pool().restart(); }   // child
                   );
  });
#endif

  return pool;
}

/** Map a function f over nthreads.
 *
 * Runs f() nthreads times on the pool and blocks until every invocation
 * has finished; inside f, thread_id() and num_threads() identify the
 * invocation. nthreads == 0 selects max_threads; nthreads == 1 calls f
 * directly on the calling thread. If any invocation throws, one of the
 * exceptions is rethrown on the calling thread after all have finished.
 */
template <typename Func>
void thread_map(size_t nthreads, Func f)
{
  if (nthreads == 0)
    nthreads = max_threads;

  if (nthreads == 1) {
    f();
    return;
  }

  auto &pool = get_pool();
  latch counter(nthreads);
  std::exception_ptr ex;
  std::mutex ex_mut;
  size_t submitted = 0;
  try {
    for (; submitted < nthreads; ++submitted) {
      const size_t i = submitted;
      pool.submit([&f, &counter, &ex, &ex_mut, i, nthreads] {
        thread_id() = i;
        num_threads() = nthreads;
        try {
          f();
        } catch (...) {
          std::lock_guard<std::mutex> lock(ex_mut);
          ex = std::current_exception();
        }
        counter.count_down();
      });
    }
  } catch (...) {
    // Tasks that were already queued refer to the locals of this frame.
    // Account for the tasks that were never queued and wait for the
    // queued ones before letting the frame unwind.
    for (size_t i = submitted; i < nthreads; ++i)
      counter.count_down();
    counter.wait();
    throw;
  }
  counter.wait();
  if (ex)
    std::rethrow_exception(ex);
}

#endif
}
908 
909     //
910     // complex FFTPACK transforms
911     //
912 
913     template <typename T0>
914     class cfftp
915     {
916     private:
917       struct fctdata {
918         size_t fct;
919         cmplx<T0> *tw, *tws;
920       };
921 
922       size_t length;
923       arr<cmplx<T0>> mem;
924       std::vector<fctdata> fact;
925 
add_factor(size_t factor)926       void add_factor(size_t factor)
927       {
928         fact.push_back({factor, nullptr, nullptr});
929       }
930 
931       template <bool fwd, typename T>
pass2(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const cmplx<T0> * POCKETFFT_RESTRICT wa) const932       void pass2(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
933                  T *POCKETFFT_RESTRICT ch,
934                  const cmplx<T0> *POCKETFFT_RESTRICT wa) const
935       {
936         auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
937                       -> T &{ return ch[a + ido * (b + l1 * c)]; };
938         auto CC = [cc, ido](size_t a, size_t b, size_t c)
939                       -> const T &{ return cc[a + ido * (b + 2 * c)]; };
940         auto WA =
941             [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };
942 
943         if (ido == 1)
944           for (size_t k = 0; k < l1; ++k) {
945             CH(0, k, 0) = CC(0, 0, k) + CC(0, 1, k);
946             CH(0, k, 1) = CC(0, 0, k) - CC(0, 1, k);
947           }
948         else
949           for (size_t k = 0; k < l1; ++k) {
950             CH(0, k, 0) = CC(0, 0, k) + CC(0, 1, k);
951             CH(0, k, 1) = CC(0, 0, k) - CC(0, 1, k);
952             for (size_t i = 1; i < ido; ++i) {
953               CH(i, k, 0) = CC(i, 0, k) + CC(i, 1, k);
954               special_mul<fwd>(CC(i, 0, k) - CC(i, 1, k), WA(0, i),
955                                CH(i, k, 1));
956             }
957           }
958       }
959 
// Helper macros for pass3. They are expanded inside the loops of pass3 and
// rely on the local names k, i, CC, CH and WA being in scope there.
//
// POCKETFFT_PREP3(idx): loads input 0 of the current butterfly into t0, feeds
// inputs 1 and 2 through PM (defined elsewhere; presumably sum/difference)
// into t1/t2, and stores output 0 as t0 + t1. The macro declares variables
// without opening a scope, so it can be expanded only once per block.
#define POCKETFFT_PREP3(idx)                                                   \
  T t0 = CC(idx, 0, k), t1, t2;                                                \
  PM(t1, t2, CC(idx, 1, k), CC(idx, 2, k));                                    \
  CH(idx, k, 0) = t0 + t1;
// POCKETFFT_PARTSTEP3a(u1, u2, twr, twi): produces outputs u1 and u2 for
// element 0 from t0/t1/t2; no multiplication by WA is performed.
#define POCKETFFT_PARTSTEP3a(u1, u2, twr, twi)                                 \
  {                                                                            \
    T ca = t0 + t1 * twr;                                                      \
    T cb{-t2.i * twi, t2.r * twi};                                             \
    PM(CH(0, k, u1), CH(0, k, u2), ca, cb);                                    \
  }
// POCKETFFT_PARTSTEP3b(u1, u2, twr, twi): same intermediate values as the "a"
// variant, but for element i; ca + cb and ca - cb are passed through
// special_mul<fwd> (defined elsewhere) together with WA(u - 1, i) before
// landing in CH(i, k, u).
#define POCKETFFT_PARTSTEP3b(u1, u2, twr, twi)                                 \
  {                                                                            \
    T ca = t0 + t1 * twr;                                                      \
    T cb{-t2.i * twi, t2.r * twi};                                             \
    special_mul<fwd>(ca + cb, WA(u1 - 1, i), CH(i, k, u1));                    \
    special_mul<fwd>(ca - cb, WA(u2 - 1, i), CH(i, k, u2));                    \
  }
      // Pass for a factor of 3 (selected by pass_all when ip == 3).
      //
      // fwd : compile-time direction flag; it flips the sign of tw1i and is
      //       forwarded to special_mul.
      // ido : number of elements per butterfly group (fastest index).
      // l1  : number of butterfly groups (loop bound for k).
      // cc  : input, read as cc[a + ido * (b + 3 * c)].
      // ch  : output, written as ch[a + ido * (b + l1 * c)].
      // wa  : per-pass factors, read as wa[i - 1 + x * (ido - 1)] for i >= 1.
      template <bool fwd, typename T>
      void pass3(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
                 T *POCKETFFT_RESTRICT ch,
                 const cmplx<T0> *POCKETFFT_RESTRICT wa) const
      {
        // Numerically cos(2*pi/3) and -/+ sin(2*pi/3).
        constexpr T0 tw1r = -0.5,
                     tw1i = (fwd ? -1 : 1) *
                            T0(0.8660254037844386467637231707529362L);

        // Index helpers for output, input and the wa table.
        auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
                      -> T &{ return ch[a + ido * (b + l1 * c)]; };
        auto CC = [cc, ido](size_t a, size_t b, size_t c)
                      -> const T &{ return cc[a + ido * (b + 3 * c)]; };
        auto WA =
            [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };

        // With ido == 1 only element 0 exists, which never touches WA.
        if (ido == 1)
          for (size_t k = 0; k < l1; ++k) {
            POCKETFFT_PREP3(0)
            POCKETFFT_PARTSTEP3a(1, 2, tw1r, tw1i)
          }
        else
          for (size_t k = 0; k < l1; ++k) {
            // Element 0 is handled separately (WA is indexed with i - 1).
            {
              POCKETFFT_PREP3(0)
              POCKETFFT_PARTSTEP3a(1, 2, tw1r, tw1i)
            }
            for (size_t i = 1; i < ido; ++i) {
              POCKETFFT_PREP3(i)
              POCKETFFT_PARTSTEP3b(1, 2, tw1r, tw1i)
            }
          }
      }

#undef POCKETFFT_PARTSTEP3b
#undef POCKETFFT_PARTSTEP3a
#undef POCKETFFT_PREP3
1014 
1015       template <bool fwd, typename T>
pass4(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const cmplx<T0> * POCKETFFT_RESTRICT wa) const1016       void pass4(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
1017                  T *POCKETFFT_RESTRICT ch,
1018                  const cmplx<T0> *POCKETFFT_RESTRICT wa) const
1019       {
1020         auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
1021                       -> T &{ return ch[a + ido * (b + l1 * c)]; };
1022         auto CC = [cc, ido](size_t a, size_t b, size_t c)
1023                       -> const T &{ return cc[a + ido * (b + 4 * c)]; };
1024         auto WA =
1025             [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };
1026 
1027         if (ido == 1)
1028           for (size_t k = 0; k < l1; ++k) {
1029             T t1, t2, t3, t4;
1030             PM(t2, t1, CC(0, 0, k), CC(0, 2, k));
1031             PM(t3, t4, CC(0, 1, k), CC(0, 3, k));
1032             ROTX90<fwd>(t4);
1033             PM(CH(0, k, 0), CH(0, k, 2), t2, t3);
1034             PM(CH(0, k, 1), CH(0, k, 3), t1, t4);
1035           }
1036         else
1037           for (size_t k = 0; k < l1; ++k) {
1038             {
1039               T t1, t2, t3, t4;
1040               PM(t2, t1, CC(0, 0, k), CC(0, 2, k));
1041               PM(t3, t4, CC(0, 1, k), CC(0, 3, k));
1042               ROTX90<fwd>(t4);
1043               PM(CH(0, k, 0), CH(0, k, 2), t2, t3);
1044               PM(CH(0, k, 1), CH(0, k, 3), t1, t4);
1045             }
1046             for (size_t i = 1; i < ido; ++i) {
1047               T t1, t2, t3, t4;
1048               T cc0 = CC(i, 0, k), cc1 = CC(i, 1, k), cc2 = CC(i, 2, k),
1049                 cc3 = CC(i, 3, k);
1050               PM(t2, t1, cc0, cc2);
1051               PM(t3, t4, cc1, cc3);
1052               ROTX90<fwd>(t4);
1053               CH(i, k, 0) = t2 + t3;
1054               special_mul<fwd>(t1 + t4, WA(0, i), CH(i, k, 1));
1055               special_mul<fwd>(t2 - t3, WA(1, i), CH(i, k, 2));
1056               special_mul<fwd>(t1 - t4, WA(2, i), CH(i, k, 3));
1057             }
1058           }
1059       }
1060 
// Helper macros for pass5. They are expanded inside the loops of pass5 and
// rely on the local names k, i, CC, CH and WA being in scope there.
//
// POCKETFFT_PREP5(idx): loads input 0 into t0, pairs inputs (1,4) and (2,3)
// through PM (defined elsewhere; presumably sum/difference) into t1/t4 and
// t2/t3, and stores output 0 as the component-wise sum t0 + t1 + t2. The macro
// declares variables without opening a scope, so it can be expanded only once
// per block.
#define POCKETFFT_PREP5(idx)                                                   \
  T t0 = CC(idx, 0, k), t1, t2, t3, t4;                                        \
  PM(t1, t4, CC(idx, 1, k), CC(idx, 4, k));                                    \
  PM(t2, t3, CC(idx, 2, k), CC(idx, 3, k));                                    \
  CH(idx, k, 0).r = t0.r + t1.r + t2.r;                                        \
  CH(idx, k, 0).i = t0.i + t1.i + t2.i;

// POCKETFFT_PARTSTEP5a(u1, u2, twar, twbr, twai, twbi): produces outputs u1
// and u2 for element 0 (WA is not used).
// IMPORTANT: twai and twbi must be passed with an explicit leading sign
// (e.g. +tw1i, -tw1i). The body writes "twai * x twbi * y" with no operator in
// between, so the sign token of twbi becomes the binary +/- joining the two
// products.
#define POCKETFFT_PARTSTEP5a(u1, u2, twar, twbr, twai, twbi)                   \
  {                                                                            \
    T ca, cb;                                                                  \
    ca.r = t0.r + twar * t1.r + twbr * t2.r;                                   \
    ca.i = t0.i + twar * t1.i + twbr * t2.i;                                   \
    cb.i = twai * t4.r twbi * t3.r;                                            \
    cb.r = -(twai * t4.i twbi * t3.i);                                         \
    PM(CH(0, k, u1), CH(0, k, u2), ca, cb);                                    \
  }

// POCKETFFT_PARTSTEP5b(...): same intermediate values as the "a" variant, but
// for element i; ca + cb and ca - cb go through special_mul<fwd> (defined
// elsewhere) together with WA(u - 1, i). The same signed-argument convention
// applies. Note that da and db are declared but not used by this macro.
#define POCKETFFT_PARTSTEP5b(u1, u2, twar, twbr, twai, twbi)                   \
  {                                                                            \
    T ca, cb, da, db;                                                          \
    ca.r = t0.r + twar * t1.r + twbr * t2.r;                                   \
    ca.i = t0.i + twar * t1.i + twbr * t2.i;                                   \
    cb.i = twai * t4.r twbi * t3.r;                                            \
    cb.r = -(twai * t4.i twbi * t3.i);                                         \
    special_mul<fwd>(ca + cb, WA(u1 - 1, i), CH(i, k, u1));                    \
    special_mul<fwd>(ca - cb, WA(u2 - 1, i), CH(i, k, u2));                    \
  }
      // Pass for a factor of 5 (selected by pass_all when ip == 5).
      //
      // fwd : compile-time direction flag; it flips the sign of tw1i/tw2i and
      //       is forwarded to special_mul.
      // ido : number of elements per butterfly group (fastest index).
      // l1  : number of butterfly groups.
      // cc  : input, read as cc[a + ido * (b + 5 * c)].
      // ch  : output, written as ch[a + ido * (b + l1 * c)].
      // wa  : per-pass factors, read as wa[i - 1 + x * (ido - 1)] for i >= 1.
      template <bool fwd, typename T>
      void pass5(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
                 T *POCKETFFT_RESTRICT ch,
                 const cmplx<T0> *POCKETFFT_RESTRICT wa) const
      {
        // Numerically cos(2*pi*m/5) and -/+ sin(2*pi*m/5) for m = 1, 2.
        constexpr T0 tw1r = T0(0.3090169943749474241022934171828191L),
                     tw1i = (fwd ? -1 : 1) *
                            T0(0.9510565162951535721164393333793821L),
                     tw2r = T0(-0.8090169943749474241022934171828191L),
                     tw2i = (fwd ? -1 : 1) *
                            T0(0.5877852522924731291687059546390728L);

        // Index helpers for output, input and the wa table.
        auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
                      -> T &{ return ch[a + ido * (b + l1 * c)]; };
        auto CC = [cc, ido](size_t a, size_t b, size_t c)
                      -> const T &{ return cc[a + ido * (b + 5 * c)]; };
        auto WA =
            [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };

        // With ido == 1 only element 0 exists, which never touches WA.
        if (ido == 1)
          for (size_t k = 0; k < l1; ++k) {
            POCKETFFT_PREP5(0)
            POCKETFFT_PARTSTEP5a(1, 4, tw1r, tw2r, +tw1i, +tw2i)
                POCKETFFT_PARTSTEP5a(2, 3, tw2r, tw1r, +tw2i, -tw1i)
          }
        else
          for (size_t k = 0; k < l1; ++k) {
            // Element 0 is handled separately (WA is indexed with i - 1).
            {
              POCKETFFT_PREP5(0)
              POCKETFFT_PARTSTEP5a(1, 4, tw1r, tw2r, +tw1i, +tw2i)
                  POCKETFFT_PARTSTEP5a(2, 3, tw2r, tw1r, +tw2i, -tw1i)
            }
            for (size_t i = 1; i < ido; ++i) {
              POCKETFFT_PREP5(i)
              POCKETFFT_PARTSTEP5b(1, 4, tw1r, tw2r, +tw1i, +tw2i)
                  POCKETFFT_PARTSTEP5b(2, 3, tw2r, tw1r, +tw2i, -tw1i)
            }
          }
      }

#undef POCKETFFT_PARTSTEP5b
#undef POCKETFFT_PARTSTEP5a
#undef POCKETFFT_PREP5
1131 
// Helper macros for pass7. They are expanded inside the loops of pass7 and
// rely on the local names k, i, CC, CH and WA being in scope there.
//
// POCKETFFT_PREP7(idx): loads input 0 into t1, pairs inputs (1,6), (2,5) and
// (3,4) through PM (defined elsewhere; presumably sum/difference) into t2/t7,
// t3/t6 and t4/t5, and stores output 0 as the component-wise sum
// t1 + t2 + t3 + t4. The macro declares variables without opening a scope, so
// it can be expanded only once per block.
#define POCKETFFT_PREP7(idx)                                                   \
  T t1 = CC(idx, 0, k), t2, t3, t4, t5, t6, t7;                                \
  PM(t2, t7, CC(idx, 1, k), CC(idx, 6, k));                                    \
  PM(t3, t6, CC(idx, 2, k), CC(idx, 5, k));                                    \
  PM(t4, t5, CC(idx, 3, k), CC(idx, 4, k));                                    \
  CH(idx, k, 0).r = t1.r + t2.r + t3.r + t4.r;                                 \
  CH(idx, k, 0).i = t1.i + t2.i + t3.i + t4.i;

// POCKETFFT_PARTSTEP7a0(..., out1, out2): builds ca from the x-weighted
// t2..t4 and cb from the y-weighted t5..t7, then hands them to PM with the
// given output lvalues. u1 and u2 are not used in this macro's body.
// IMPORTANT: y1..y3 must be passed with an explicit leading sign (e.g. +tw1i,
// -tw3i). The body writes the products side by side with no operator, so the
// sign tokens of y2 and y3 become the binary +/- joining the terms.
#define POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, out1, out2)      \
  {                                                                            \
    T ca, cb;                                                                  \
    ca.r = t1.r + x1 * t2.r + x2 * t3.r + x3 * t4.r;                           \
    ca.i = t1.i + x1 * t2.i + x2 * t3.i + x3 * t4.i;                           \
    cb.i = y1 * t7.r y2 * t6.r y3 * t5.r;                                      \
    cb.r = -(y1 * t7.i y2 * t6.i y3 * t5.i);                                   \
    PM(out1, out2, ca, cb);                                                    \
  }
// POCKETFFT_PARTSTEP7a: element-0 variant; writes directly to CH(0, k, u1)
// and CH(0, k, u2).
#define POCKETFFT_PARTSTEP7a(u1, u2, x1, x2, x3, y1, y2, y3)                   \
  POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, CH(0, k, u1),          \
                        CH(0, k, u2))
// POCKETFFT_PARTSTEP7: element-i variant; computes into temporaries da/db and
// then passes each through special_mul<fwd> (defined elsewhere) together with
// WA(u - 1, i) into CH(i, k, u).
#define POCKETFFT_PARTSTEP7(u1, u2, x1, x2, x3, y1, y2, y3)                    \
  {                                                                            \
    T da, db;                                                                  \
    POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, da, db)              \
        special_mul<fwd>(da, WA(u1 - 1, i), CH(i, k, u1));                     \
    special_mul<fwd>(db, WA(u2 - 1, i), CH(i, k, u2));                         \
  }

      // Pass for a factor of 7 (selected by pass_all when ip == 7).
      //
      // fwd : compile-time direction flag; it flips the sign of tw1i..tw3i and
      //       is forwarded to special_mul.
      // ido : number of elements per butterfly group (fastest index).
      // l1  : number of butterfly groups.
      // cc  : input, read as cc[a + ido * (b + 7 * c)].
      // ch  : output, written as ch[a + ido * (b + l1 * c)].
      // wa  : per-pass factors, read as wa[i - 1 + x * (ido - 1)] for i >= 1.
      template <bool fwd, typename T>
      void pass7(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
                 T *POCKETFFT_RESTRICT ch,
                 const cmplx<T0> *POCKETFFT_RESTRICT wa) const
      {
        // Numerically cos(2*pi*m/7) and -/+ sin(2*pi*m/7) for m = 1, 2, 3.
        constexpr T0 tw1r = T0(0.6234898018587335305250048840042398L),
                     tw1i = (fwd ? -1 : 1) *
                            T0(0.7818314824680298087084445266740578L),
                     tw2r = T0(-0.2225209339563144042889025644967948L),
                     tw2i = (fwd ? -1 : 1) *
                            T0(0.9749279121818236070181316829939312L),
                     tw3r = T0(-0.9009688679024191262361023195074451L),
                     tw3i = (fwd ? -1 : 1) *
                            T0(0.433883739117558120475768332848359L);

        // Index helpers for output, input and the wa table.
        auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
                      -> T &{ return ch[a + ido * (b + l1 * c)]; };
        auto CC = [cc, ido](size_t a, size_t b, size_t c)
                      -> const T &{ return cc[a + ido * (b + 7 * c)]; };
        auto WA =
            [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };

        // With ido == 1 only element 0 exists, which never touches WA.
        if (ido == 1)
          for (size_t k = 0; k < l1; ++k) {
            POCKETFFT_PREP7(0)
            POCKETFFT_PARTSTEP7a(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i)
                POCKETFFT_PARTSTEP7a(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i,
                                     -tw1i)
                    POCKETFFT_PARTSTEP7a(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i,
                                         +tw2i)
          }
        else
          for (size_t k = 0; k < l1; ++k) {
            // Element 0 is handled separately (WA is indexed with i - 1).
            {
              POCKETFFT_PREP7(0)
              POCKETFFT_PARTSTEP7a(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i)
                  POCKETFFT_PARTSTEP7a(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i,
                                       -tw1i)
                      POCKETFFT_PARTSTEP7a(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i,
                                           +tw2i)
            }
            for (size_t i = 1; i < ido; ++i) {
              POCKETFFT_PREP7(i)
              POCKETFFT_PARTSTEP7(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i)
              POCKETFFT_PARTSTEP7(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i, -tw1i)
              POCKETFFT_PARTSTEP7(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i, +tw2i)
            }
          }
      }

#undef POCKETFFT_PARTSTEP7
#undef POCKETFFT_PARTSTEP7a0
#undef POCKETFFT_PARTSTEP7a
#undef POCKETFFT_PREP7
1214 
1215       template <bool fwd, typename T>
ROTX45(T & a) const1216       void ROTX45(T &a) const
1217       {
1218         constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L);
1219         if (fwd) {
1220           auto tmp_ = a.r;
1221           a.r = hsqt2 * (a.r + a.i);
1222           a.i = hsqt2 * (a.i - tmp_);
1223         } else {
1224           auto tmp_ = a.r;
1225           a.r = hsqt2 * (a.r - a.i);
1226           a.i = hsqt2 * (a.i + tmp_);
1227         }
1228       }
1229       template <bool fwd, typename T>
ROTX135(T & a) const1230       void ROTX135(T &a) const
1231       {
1232         constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L);
1233         if (fwd) {
1234           auto tmp_ = a.r;
1235           a.r = hsqt2 * (a.i - a.r);
1236           a.i = hsqt2 * (-tmp_ - a.i);
1237         } else {
1238           auto tmp_ = a.r;
1239           a.r = hsqt2 * (-a.r - a.i);
1240           a.i = hsqt2 * (tmp_ - a.i);
1241         }
1242       }
1243 
      // Pass for a factor of 8 (selected by pass_all when ip == 8).
      //
      // Built from PM / PMINPLACE / ROTX90 (defined elsewhere; presumably
      // sum-difference helpers and a 90-degree rotation) plus ROTX45 and
      // ROTX135, so no explicit trigonometric constants appear here.
      //
      // fwd : compile-time direction flag, forwarded to the ROTX* helpers and
      //       to special_mul.
      // ido : number of elements per butterfly group (fastest index).
      // l1  : number of butterfly groups.
      // cc  : input, read as cc[a + ido * (b + 8 * c)].
      // ch  : output, written as ch[a + ido * (b + l1 * c)].
      // wa  : per-pass factors, read as wa[i - 1 + x * (ido - 1)] for i >= 1.
      template <bool fwd, typename T>
      void pass8(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
                 T *POCKETFFT_RESTRICT ch,
                 const cmplx<T0> *POCKETFFT_RESTRICT wa) const
      {
        // Index helpers for output, input and the wa table.
        auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
                      -> T &{ return ch[a + ido * (b + l1 * c)]; };
        auto CC = [cc, ido](size_t a, size_t b, size_t c)
                      -> const T &{ return cc[a + ido * (b + 8 * c)]; };
        auto WA =
            [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };

        // With ido == 1 only element 0 exists, which never touches WA.
        if (ido == 1)
          for (size_t k = 0; k < l1; ++k) {
            T a0, a1, a2, a3, a4, a5, a6, a7;
            // Odd-numbered inputs (1, 3, 5, 7) first.
            PM(a1, a5, CC(0, 1, k), CC(0, 5, k));
            PM(a3, a7, CC(0, 3, k), CC(0, 7, k));
            PMINPLACE(a1, a3);
            ROTX90<fwd>(a3);

            ROTX90<fwd>(a7);
            PMINPLACE(a5, a7);
            ROTX45<fwd>(a5);
            ROTX135<fwd>(a7);

            // Even-numbered inputs (0, 2, 4, 6), then combine with the odd
            // half into the eight outputs.
            PM(a0, a4, CC(0, 0, k), CC(0, 4, k));
            PM(a2, a6, CC(0, 2, k), CC(0, 6, k));
            PM(CH(0, k, 0), CH(0, k, 4), a0 + a2, a1);
            PM(CH(0, k, 2), CH(0, k, 6), a0 - a2, a3);
            ROTX90<fwd>(a6);
            PM(CH(0, k, 1), CH(0, k, 5), a4 + a6, a5);
            PM(CH(0, k, 3), CH(0, k, 7), a4 - a6, a7);
          }
        else
          for (size_t k = 0; k < l1; ++k) {
            // Element 0: same sequence as the ido == 1 branch above.
            {
              T a0, a1, a2, a3, a4, a5, a6, a7;
              PM(a1, a5, CC(0, 1, k), CC(0, 5, k));
              PM(a3, a7, CC(0, 3, k), CC(0, 7, k));
              PMINPLACE(a1, a3);
              ROTX90<fwd>(a3);

              ROTX90<fwd>(a7);
              PMINPLACE(a5, a7);
              ROTX45<fwd>(a5);
              ROTX135<fwd>(a7);

              PM(a0, a4, CC(0, 0, k), CC(0, 4, k));
              PM(a2, a6, CC(0, 2, k), CC(0, 6, k));
              PM(CH(0, k, 0), CH(0, k, 4), a0 + a2, a1);
              PM(CH(0, k, 2), CH(0, k, 6), a0 - a2, a3);
              ROTX90<fwd>(a6);
              PM(CH(0, k, 1), CH(0, k, 5), a4 + a6, a5);
              PM(CH(0, k, 3), CH(0, k, 7), a4 - a6, a7);
            }
            // Elements 1 .. ido-1: outputs 1..7 additionally go through
            // special_mul<fwd> (defined elsewhere) with WA(u - 1, i), where u
            // is the output index.
            for (size_t i = 1; i < ido; ++i) {
              T a0, a1, a2, a3, a4, a5, a6, a7;
              PM(a1, a5, CC(i, 1, k), CC(i, 5, k));
              PM(a3, a7, CC(i, 3, k), CC(i, 7, k));
              ROTX90<fwd>(a7);
              PMINPLACE(a1, a3);
              ROTX90<fwd>(a3);
              PMINPLACE(a5, a7);
              ROTX45<fwd>(a5);
              ROTX135<fwd>(a7);
              PM(a0, a4, CC(i, 0, k), CC(i, 4, k));
              PM(a2, a6, CC(i, 2, k), CC(i, 6, k));
              PMINPLACE(a0, a2);
              // Output 0 is stored directly, without a WA factor.
              CH(i, k, 0) = a0 + a1;
              special_mul<fwd>(a0 - a1, WA(3, i), CH(i, k, 4));
              special_mul<fwd>(a2 + a3, WA(1, i), CH(i, k, 2));
              special_mul<fwd>(a2 - a3, WA(5, i), CH(i, k, 6));
              ROTX90<fwd>(a6);
              PMINPLACE(a4, a6);
              special_mul<fwd>(a4 + a5, WA(0, i), CH(i, k, 1));
              special_mul<fwd>(a4 - a5, WA(4, i), CH(i, k, 5));
              special_mul<fwd>(a6 + a7, WA(2, i), CH(i, k, 3));
              special_mul<fwd>(a6 - a7, WA(6, i), CH(i, k, 7));
            }
          }
      }
1325 
1326 #define POCKETFFT_PREP11(idx)                                                  \
1327   T t1 = CC(idx, 0, k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11;              \
1328   PM(t2, t11, CC(idx, 1, k), CC(idx, 10, k));                                  \
1329   PM(t3, t10, CC(idx, 2, k), CC(idx, 9, k));                                   \
1330   PM(t4, t9, CC(idx, 3, k), CC(idx, 8, k));                                    \
1331   PM(t5, t8, CC(idx, 4, k), CC(idx, 7, k));                                    \
1332   PM(t6, t7, CC(idx, 5, k), CC(idx, 6, k));                                    \
1333   CH(idx, k, 0).r = t1.r + t2.r + t3.r + t4.r + t5.r + t6.r;                   \
1334   CH(idx, k, 0).i = t1.i + t2.i + t3.i + t4.i + t5.i + t6.i;
1335 
1336 #define POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, \
1337                                out1, out2)                                     \
1338   {                                                                            \
1339     T ca = t1 + t2 * x1 + t3 * x2 + t4 * x3 + t5 * x4 + t6 * x5, cb;           \
1340     cb.i = y1 * t11.r y2 * t10.r y3 * t9.r y4 * t8.r y5 * t7.r;                \
1341     cb.r = -(y1 * t11.i y2 * t10.i y3 * t9.i y4 * t8.i y5 * t7.i);             \
1342     PM(out1, out2, ca, cb);                                                    \
1343   }
1344 #define POCKETFFT_PARTSTEP11a(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5)  \
1345   POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5,       \
1346                          CH(0, k, u1), CH(0, k, u2))
1347 #define POCKETFFT_PARTSTEP11(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5)   \
1348   {                                                                            \
1349     T da, db;                                                                  \
1350     POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, da, \
1351                            db)                                                 \
1352         special_mul<fwd>(da, WA(u1 - 1, i), CH(i, k, u1));                     \
1353     special_mul<fwd>(db, WA(u2 - 1, i), CH(i, k, u2));                         \
1354   }
1355 
1356       template <bool fwd, typename T>
pass11(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const cmplx<T0> * POCKETFFT_RESTRICT wa) const1357       void pass11(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
1358                   T *POCKETFFT_RESTRICT ch,
1359                   const cmplx<T0> *POCKETFFT_RESTRICT wa) const
1360       {
1361         constexpr T0
1362             tw1r = T0(0.8412535328311811688618116489193677L),
1363             tw1i = (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L),
1364             tw2r = T0(0.4154150130018864255292741492296232L),
1365             tw2i = (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L),
1366             tw3r = T0(-0.1423148382732851404437926686163697L),
1367             tw3i = (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L),
1368             tw4r = T0(-0.6548607339452850640569250724662936L),
1369             tw4i = (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L),
1370             tw5r = T0(-0.9594929736144973898903680570663277L),
1371             tw5i = (fwd ? -1 : 1) * T0(0.2817325568414296977114179153466169L);
1372 
1373         auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
1374                       -> T &{ return ch[a + ido * (b + l1 * c)]; };
1375         auto CC = [cc, ido](size_t a, size_t b, size_t c)
1376                       -> const T &{ return cc[a + ido * (b + 11 * c)]; };
1377         auto WA =
1378             [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };
1379 
1380         if (ido == 1)
1381           for (size_t k = 0; k < l1; ++k) {
1382             POCKETFFT_PREP11(0)
1383             POCKETFFT_PARTSTEP11a(1, 10, tw1r, tw2r, tw3r, tw4r, tw5r, +tw1i,
1384                                   +tw2i, +tw3i, +tw4i, +tw5i)
1385                 POCKETFFT_PARTSTEP11a(2, 9, tw2r, tw4r, tw5r, tw3r, tw1r, +tw2i,
1386                                       +tw4i, -tw5i, -tw3i, -tw1i)
1387                     POCKETFFT_PARTSTEP11a(3, 8, tw3r, tw5r, tw2r, tw1r, tw4r,
1388                                           +tw3i, -tw5i, -tw2i, +tw1i, +tw4i)
1389                         POCKETFFT_PARTSTEP11a(4, 7, tw4r, tw3r, tw1r, tw5r,
1390                                               tw2r, +tw4i, -tw3i, +tw1i, +tw5i,
1391                                               -tw2i)
1392                             POCKETFFT_PARTSTEP11a(5, 6, tw5r, tw1r, tw4r, tw2r,
1393                                                   tw3r, +tw5i, -tw1i, +tw4i,
1394                                                   -tw2i, +tw3i)
1395           }
1396         else
1397           for (size_t k = 0; k < l1; ++k) {
1398             {
1399               POCKETFFT_PREP11(0)
1400               POCKETFFT_PARTSTEP11a(1, 10, tw1r, tw2r, tw3r, tw4r, tw5r, +tw1i,
1401                                     +tw2i, +tw3i, +tw4i, +tw5i)
1402                   POCKETFFT_PARTSTEP11a(2, 9, tw2r, tw4r, tw5r, tw3r, tw1r,
1403                                         +tw2i, +tw4i, -tw5i, -tw3i, -tw1i)
1404                       POCKETFFT_PARTSTEP11a(3, 8, tw3r, tw5r, tw2r, tw1r, tw4r,
1405                                             +tw3i, -tw5i, -tw2i, +tw1i, +tw4i)
1406                           POCKETFFT_PARTSTEP11a(4, 7, tw4r, tw3r, tw1r, tw5r,
1407                                                 tw2r, +tw4i, -tw3i, +tw1i,
1408                                                 +tw5i, -tw2i)
1409                               POCKETFFT_PARTSTEP11a(5, 6, tw5r, tw1r, tw4r,
1410                                                     tw2r, tw3r, +tw5i, -tw1i,
1411                                                     +tw4i, -tw2i, +tw3i)
1412             }
1413             for (size_t i = 1; i < ido; ++i) {
1414               POCKETFFT_PREP11(i)
1415               POCKETFFT_PARTSTEP11(1, 10, tw1r, tw2r, tw3r, tw4r, tw5r, +tw1i,
1416                                    +tw2i, +tw3i, +tw4i, +tw5i)
1417               POCKETFFT_PARTSTEP11(2, 9, tw2r, tw4r, tw5r, tw3r, tw1r, +tw2i,
1418                                    +tw4i, -tw5i, -tw3i, -tw1i)
1419               POCKETFFT_PARTSTEP11(3, 8, tw3r, tw5r, tw2r, tw1r, tw4r, +tw3i,
1420                                    -tw5i, -tw2i, +tw1i, +tw4i)
1421               POCKETFFT_PARTSTEP11(4, 7, tw4r, tw3r, tw1r, tw5r, tw2r, +tw4i,
1422                                    -tw3i, +tw1i, +tw5i, -tw2i)
1423               POCKETFFT_PARTSTEP11(5, 6, tw5r, tw1r, tw4r, tw2r, tw3r, +tw5i,
1424                                    -tw1i, +tw4i, -tw2i, +tw3i)
1425             }
1426           }
1427       }
1428 
1429 #undef PARTSTEP11
1430 #undef PARTSTEP11a0
1431 #undef PARTSTEP11a
1432 #undef POCKETFFT_PREP11
1433 
      // Generic pass for a factor `ip` that has no dedicated passN routine
      // (pass_all uses it for every ip other than 2, 3, 4, 5, 7, 8 and 11).
      //
      // Unlike the specialised passes, the final result is written back into
      // `cc`; `ch` only holds intermediate values. pass_all accounts for this
      // with an extra pointer swap after calling passg.
      //
      // fwd   : compile-time direction flag; it conjugates the csarr entries
      //         and is forwarded to special_mul (defined elsewhere).
      // ido   : number of elements per butterfly group (fastest index).
      // ip    : the factor handled by this pass.
      // l1    : number of butterfly groups.
      // cc    : input on entry, result on exit.
      // ch    : scratch buffer.
      // wa    : per-pass factors, read as wa[(j - 1) * (ido - 1) + i - 1].
      // csarr : ip entries; presumably the ip-th roots of unity (see
      //         comp_twiddle's tws table) - confirm.
      template <bool fwd, typename T>
      void passg(size_t ido, size_t ip, size_t l1, T *POCKETFFT_RESTRICT cc,
                 T *POCKETFFT_RESTRICT ch,
                 const cmplx<T0> *POCKETFFT_RESTRICT wa,
                 const cmplx<T0> *POCKETFFT_RESTRICT csarr) const
      {
        const size_t cdim = ip;
        size_t ipph = (ip + 1) / 2;
        size_t idl1 = ido * l1;

        // Index helpers. CC and CX address the same buffer (cc) with two
        // different layouts: CC is the read-only input layout, CX the writable
        // output layout. CX2/CH2 are flattened two-index views of cc/ch.
        auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
                      -> T &{ return ch[a + ido * (b + l1 * c)]; };
        auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c)
                      -> const T &{ return cc[a + ido * (b + cdim * c)]; };
        auto CX = [cc, ido, l1](size_t a, size_t b, size_t c)
                      -> T &{ return cc[a + ido * (b + l1 * c)]; };
        auto CX2 =
            [cc, idl1](size_t a, size_t b) -> T &{ return cc[a + idl1 * b]; };
        auto CH2 = [ch, idl1](size_t a, size_t b)
                       -> const T &{ return ch[a + idl1 * b]; };

        // Local copy of csarr with entry 0 forced to 1 and, for the forward
        // direction, the imaginary parts negated.
        arr<cmplx<T0>> wal(ip);
        wal[0] = cmplx<T0>(1., 0.);
        for (size_t i = 1; i < ip; ++i)
          wal[i] = cmplx<T0>(csarr[i].r, fwd ? -csarr[i].i : csarr[i].i);

        // Stage 1: copy input 0 and run PM (defined elsewhere; presumably
        // sum/difference) over the input pairs (j, ip - j), all into ch.
        for (size_t k = 0; k < l1; ++k)
          for (size_t i = 0; i < ido; ++i)
            CH(i, k, 0) = CC(i, 0, k);
        for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)
          for (size_t k = 0; k < l1; ++k)
            for (size_t i = 0; i < ido; ++i)
              PM(CH(i, k, j), CH(i, k, jc), CC(i, j, k), CC(i, jc, k));
        // Stage 2: output 0 is the sum of ch slots 0 .. ipph-1.
        for (size_t k = 0; k < l1; ++k)
          for (size_t i = 0; i < ido; ++i) {
            T tmp = CH(i, k, 0);
            for (size_t j = 1; j < ipph; ++j)
              tmp += CH(i, k, j);
            CX(i, k, 0) = tmp;
          }
        // Stage 3: for each output pair (l, lc = ip - l), accumulate slot l
        // from the real parts of wal times ch slots 0 .. ipph-1, and slot lc
        // from the imaginary parts of wal times ch slots ip-1 .. ipph.
        for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) {
          // Initialise with the terms for j = 0, 1 and 2.
          for (size_t ik = 0; ik < idl1; ++ik) {
            CX2(ik, l).r = CH2(ik, 0).r + wal[l].r * CH2(ik, 1).r +
                           wal[2 * l].r * CH2(ik, 2).r;
            CX2(ik, l).i = CH2(ik, 0).i + wal[l].r * CH2(ik, 1).i +
                           wal[2 * l].r * CH2(ik, 2).i;
            CX2(ik, lc).r = -wal[l].i * CH2(ik, ip - 1).i -
                            wal[2 * l].i * CH2(ik, ip - 2).i;
            CX2(ik, lc).i =
                wal[l].i * CH2(ik, ip - 1).r + wal[2 * l].i * CH2(ik, ip - 2).r;
          }

          // iwal is the wal index for the current j: it advances by l per j
          // and is wrapped back by ip when it exceeds ip.
          // NOTE(review): the wrap test is `> ip`, and wal has only ip
          // entries, so iwal == ip would read past the end. Presumably this
          // cannot occur because ip is prime on this path - confirm.
          size_t iwal = 2 * l;
          size_t j = 3, jc = ip - 3;
          // Main loop: two values of j per iteration.
          for (; j < ipph - 1; j += 2, jc -= 2) {
            iwal += l;
            if (iwal > ip)
              iwal -= ip;
            cmplx<T0> xwal = wal[iwal];
            iwal += l;
            if (iwal > ip)
              iwal -= ip;
            cmplx<T0> xwal2 = wal[iwal];
            for (size_t ik = 0; ik < idl1; ++ik) {
              CX2(ik, l).r +=
                  CH2(ik, j).r * xwal.r + CH2(ik, j + 1).r * xwal2.r;
              CX2(ik, l).i +=
                  CH2(ik, j).i * xwal.r + CH2(ik, j + 1).i * xwal2.r;
              CX2(ik, lc).r -=
                  CH2(ik, jc).i * xwal.i + CH2(ik, jc - 1).i * xwal2.i;
              CX2(ik, lc).i +=
                  CH2(ik, jc).r * xwal.i + CH2(ik, jc - 1).r * xwal2.i;
            }
          }
          // Remainder loop: any j left over, one at a time.
          for (; j < ipph; ++j, --jc) {
            iwal += l;
            if (iwal > ip)
              iwal -= ip;
            cmplx<T0> xwal = wal[iwal];
            for (size_t ik = 0; ik < idl1; ++ik) {
              CX2(ik, l).r += CH2(ik, j).r * xwal.r;
              CX2(ik, l).i += CH2(ik, j).i * xwal.r;
              CX2(ik, lc).r -= CH2(ik, jc).i * xwal.i;
              CX2(ik, lc).i += CH2(ik, jc).r * xwal.i;
            }
          }
        }

        // Stage 4 (shuffling and twiddling): combine each slot pair (j, jc)
        // in cc through PM; for ido > 1, elements i >= 1 additionally go
        // through special_mul with the matching wa entry.
        if (ido == 1)
          for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)
            for (size_t ik = 0; ik < idl1; ++ik) {
              // Copy first: PM writes to the same slots it reads from.
              T t1 = CX2(ik, j), t2 = CX2(ik, jc);
              PM(CX2(ik, j), CX2(ik, jc), t1, t2);
            }
        else {
          for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)
            for (size_t k = 0; k < l1; ++k) {
              // Element 0 gets no wa factor.
              T t1 = CX(0, k, j), t2 = CX(0, k, jc);
              PM(CX(0, k, j), CX(0, k, jc), t1, t2);
              for (size_t i = 1; i < ido; ++i) {
                T x1, x2;
                PM(x1, x2, CX(i, k, j), CX(i, k, jc));
                size_t idij = (j - 1) * (ido - 1) + i - 1;
                special_mul<fwd>(x1, wa[idij], CX(i, k, j));
                idij = (jc - 1) * (ido - 1) + i - 1;
                special_mul<fwd>(x2, wa[idij], CX(i, k, jc));
              }
            }
        }
      }
1546 
1547       template <bool fwd, typename T>
pass_all(T c[],T0 fct) const1548       void pass_all(T c[], T0 fct) const
1549       {
1550         if (length == 1) {
1551           c[0] *= fct;
1552           return;
1553         }
1554         size_t l1 = 1;
1555         arr<T> ch(length);
1556         T *p1 = c, *p2 = ch.data();
1557 
1558         for (size_t k1 = 0; k1 < fact.size(); k1++) {
1559           size_t ip = fact[k1].fct;
1560           size_t l2 = ip * l1;
1561           size_t ido = length / l2;
1562           if (ip == 4)
1563             pass4<fwd>(ido, l1, p1, p2, fact[k1].tw);
1564           else if (ip == 8)
1565             pass8<fwd>(ido, l1, p1, p2, fact[k1].tw);
1566           else if (ip == 2)
1567             pass2<fwd>(ido, l1, p1, p2, fact[k1].tw);
1568           else if (ip == 3)
1569             pass3<fwd>(ido, l1, p1, p2, fact[k1].tw);
1570           else if (ip == 5)
1571             pass5<fwd>(ido, l1, p1, p2, fact[k1].tw);
1572           else if (ip == 7)
1573             pass7<fwd>(ido, l1, p1, p2, fact[k1].tw);
1574           else if (ip == 11)
1575             pass11<fwd>(ido, l1, p1, p2, fact[k1].tw);
1576           else {
1577             passg<fwd>(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws);
1578             std::swap(p1, p2);
1579           }
1580           std::swap(p1, p2);
1581           l1 = l2;
1582         }
1583         if (p1 != c) {
1584           if (fct != 1.)
1585             for (size_t i = 0; i < length; ++i)
1586               c[i] = ch[i] * fct;
1587           else
1588             memcpy(c, p1, length * sizeof(T));
1589         } else if (fct != 1.)
1590           for (size_t i = 0; i < length; ++i)
1591             c[i] *= fct;
1592       }
1593 
1594     public:
1595       template <typename T>
exec(T c[],T0 fct,bool fwd) const1596       void exec(T c[], T0 fct, bool fwd) const
1597       {
1598         fwd ? pass_all<true>(c, fct) : pass_all<false>(c, fct);
1599       }
1600 
1601     private:
factorize()1602       POCKETFFT_NOINLINE void factorize()
1603       {
1604         size_t len = length;
1605         while ((len & 7) == 0) {
1606           add_factor(8);
1607           len >>= 3;
1608         }
1609         while ((len & 3) == 0) {
1610           add_factor(4);
1611           len >>= 2;
1612         }
1613         if ((len & 1) == 0) {
1614           len >>= 1;
1615           // factor 2 should be at the front of the factor list
1616           add_factor(2);
1617           std::swap(fact[0].fct, fact.back().fct);
1618         }
1619         for (size_t divisor = 3; divisor * divisor <= len; divisor += 2)
1620           while ((len % divisor) == 0) {
1621             add_factor(divisor);
1622             len /= divisor;
1623           }
1624         if (len > 1)
1625           add_factor(len);
1626       }
1627 
twsize() const1628       size_t twsize() const
1629       {
1630         size_t twsize = 0, l1 = 1;
1631         for (size_t k = 0; k < fact.size(); ++k) {
1632           size_t ip = fact[k].fct, ido = length / (l1 * ip);
1633           twsize += (ip - 1) * (ido - 1);
1634           if (ip > 11)
1635             twsize += ip;
1636           l1 *= ip;
1637         }
1638         return twsize;
1639       }
1640 
comp_twiddle()1641       void comp_twiddle()
1642       {
1643         sincos_2pibyn<T0> twiddle(length);
1644         size_t l1 = 1;
1645         size_t memofs = 0;
1646         for (size_t k = 0; k < fact.size(); ++k) {
1647           size_t ip = fact[k].fct, ido = length / (l1 * ip);
1648           fact[k].tw = mem.data() + memofs;
1649           memofs += (ip - 1) * (ido - 1);
1650           for (size_t j = 1; j < ip; ++j)
1651             for (size_t i = 1; i < ido; ++i)
1652               fact[k].tw[(j - 1) * (ido - 1) + i - 1] = twiddle[j * l1 * i];
1653           if (ip > 11) {
1654             fact[k].tws = mem.data() + memofs;
1655             memofs += ip;
1656             for (size_t j = 0; j < ip; ++j)
1657               fact[k].tws[j] = twiddle[j * l1 * ido];
1658           }
1659           l1 *= ip;
1660         }
1661       }
1662 
1663     public:
cfftp(size_t length_)1664       POCKETFFT_NOINLINE cfftp(size_t length_) : length(length_)
1665       {
1666         if (length == 0)
1667           throw std::runtime_error("zero-length FFT requested");
1668         if (length == 1)
1669           return;
1670         factorize();
1671         mem.resize(twsize());
1672         comp_twiddle();
1673       }
1674     };
1675 
1676     //
1677     // real-valued FFTPACK transforms
1678     //
1679 
1680     template <typename T0>
1681     class rfftp
1682     {
1683     private:
1684       struct fctdata {
1685         size_t fct;
1686         T0 *tw, *tws;
1687       };
1688 
1689       size_t length;
1690       arr<T0> mem;
1691       std::vector<fctdata> fact;
1692 
add_factor(size_t factor)1693       void add_factor(size_t factor)
1694       {
1695         fact.push_back({factor, nullptr, nullptr});
1696       }
1697 
1698       /* (a+ib) = conj(c+id) * (e+if) */
1699       template <typename T1, typename T2, typename T3>
MULPM(T1 & a,T1 & b,T2 c,T2 d,T3 e,T3 f) const1700       inline void MULPM(T1 &a, T1 &b, T2 c, T2 d, T3 e, T3 f) const
1701       {
1702         a = c * e + d * f;
1703         b = c * f - d * e;
1704       }
1705 
1706       template <typename T>
radf2(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const T0 * POCKETFFT_RESTRICT wa) const1707       void radf2(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
1708                  T *POCKETFFT_RESTRICT ch,
1709                  const T0 *POCKETFFT_RESTRICT wa) const
1710       {
1711         auto WA =
1712             [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };
1713         auto CC = [cc, ido, l1](size_t a, size_t b, size_t c)
1714                       -> const T &{ return cc[a + ido * (b + l1 * c)]; };
1715         auto CH = [ch, ido](size_t a, size_t b, size_t c)
1716                       -> T &{ return ch[a + ido * (b + 2 * c)]; };
1717 
1718         for (size_t k = 0; k < l1; k++)
1719           PM(CH(0, 0, k), CH(ido - 1, 1, k), CC(0, k, 0), CC(0, k, 1));
1720         if ((ido & 1) == 0)
1721           for (size_t k = 0; k < l1; k++) {
1722             CH(0, 1, k) = -CC(ido - 1, k, 1);
1723             CH(ido - 1, 0, k) = CC(ido - 1, k, 0);
1724           }
1725         if (ido <= 2)
1726           return;
1727         for (size_t k = 0; k < l1; k++)
1728           for (size_t i = 2; i < ido; i += 2) {
1729             size_t ic = ido - i;
1730             T tr2, ti2;
1731             MULPM(tr2, ti2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1),
1732                   CC(i, k, 1));
1733             PM(CH(i - 1, 0, k), CH(ic - 1, 1, k), CC(i - 1, k, 0), tr2);
1734             PM(CH(i, 0, k), CH(ic, 1, k), ti2, CC(i, k, 0));
1735           }
1736       }
1737 
// a2=a+b; b2=i*(b-a);
// In place, using the old values on the right-hand side:
//   rx <- rx+ry, ix <- ix+iy, ry <- ix-iy, iy <- ry-rx
#define POCKETFFT_REARRANGE(rx, ix, ry, iy)                                    \
  {                                                                            \
    auto rearr_sum_r = rx + ry;                                                \
    auto rearr_dif_r = ry - rx;                                                \
    auto rearr_sum_i = ix + iy;                                                \
    auto rearr_dif_i = ix - iy;                                                \
    rx = rearr_sum_r;                                                          \
    ix = rearr_sum_i;                                                          \
    ry = rearr_dif_i;                                                          \
    iy = rearr_dif_r;                                                          \
  }
1747 
1748       template <typename T>
radf3(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const T0 * POCKETFFT_RESTRICT wa) const1749       void radf3(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
1750                  T *POCKETFFT_RESTRICT ch,
1751                  const T0 *POCKETFFT_RESTRICT wa) const
1752       {
1753         constexpr T0 taur = -0.5,
1754                      taui = T0(0.8660254037844386467637231707529362L);
1755 
1756         auto WA =
1757             [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };
1758         auto CC = [cc, ido, l1](size_t a, size_t b, size_t c)
1759                       -> const T &{ return cc[a + ido * (b + l1 * c)]; };
1760         auto CH = [ch, ido](size_t a, size_t b, size_t c)
1761                       -> T &{ return ch[a + ido * (b + 3 * c)]; };
1762 
1763         for (size_t k = 0; k < l1; k++) {
1764           T cr2 = CC(0, k, 1) + CC(0, k, 2);
1765           CH(0, 0, k) = CC(0, k, 0) + cr2;
1766           CH(0, 2, k) = taui * (CC(0, k, 2) - CC(0, k, 1));
1767           CH(ido - 1, 1, k) = CC(0, k, 0) + taur * cr2;
1768         }
1769         if (ido == 1)
1770           return;
1771         for (size_t k = 0; k < l1; k++)
1772           for (size_t i = 2; i < ido; i += 2) {
1773             size_t ic = ido - i;
1774             T di2, di3, dr2, dr3;
1775             MULPM(dr2, di2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1),
1776                   CC(i, k, 1)); // d2=conj(WA0)*CC1
1777             MULPM(dr3, di3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2),
1778                   CC(i, k, 2)); // d3=conj(WA1)*CC2
1779             POCKETFFT_REARRANGE(dr2, di2, dr3, di3);
1780             CH(i - 1, 0, k) = CC(i - 1, k, 0) + dr2; // c add
1781             CH(i, 0, k) = CC(i, k, 0) + di2;
1782             T tr2 = CC(i - 1, k, 0) + taur * dr2; // c add
1783             T ti2 = CC(i, k, 0) + taur * di2;
1784             T tr3 = taui * dr3; // t3 = taui*i*(d3-d2)?
1785             T ti3 = taui * di3;
1786             PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr2, tr3); // PM(i) = t2+t3
1787             PM(CH(i, 2, k), CH(ic, 1, k), ti3, ti2); // PM(ic) = conj(t2-t3)
1788           }
1789       }
1790 
1791       template <typename T>
radf4(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const T0 * POCKETFFT_RESTRICT wa) const1792       void radf4(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
1793                  T *POCKETFFT_RESTRICT ch,
1794                  const T0 *POCKETFFT_RESTRICT wa) const
1795       {
1796         constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L);
1797 
1798         auto WA =
1799             [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };
1800         auto CC = [cc, ido, l1](size_t a, size_t b, size_t c)
1801                       -> const T &{ return cc[a + ido * (b + l1 * c)]; };
1802         auto CH = [ch, ido](size_t a, size_t b, size_t c)
1803                       -> T &{ return ch[a + ido * (b + 4 * c)]; };
1804 
1805         for (size_t k = 0; k < l1; k++) {
1806           T tr1, tr2;
1807           PM(tr1, CH(0, 2, k), CC(0, k, 3), CC(0, k, 1));
1808           PM(tr2, CH(ido - 1, 1, k), CC(0, k, 0), CC(0, k, 2));
1809           PM(CH(0, 0, k), CH(ido - 1, 3, k), tr2, tr1);
1810         }
1811         if ((ido & 1) == 0)
1812           for (size_t k = 0; k < l1; k++) {
1813             T ti1 = -hsqt2 * (CC(ido - 1, k, 1) + CC(ido - 1, k, 3));
1814             T tr1 = hsqt2 * (CC(ido - 1, k, 1) - CC(ido - 1, k, 3));
1815             PM(CH(ido - 1, 0, k), CH(ido - 1, 2, k), CC(ido - 1, k, 0), tr1);
1816             PM(CH(0, 3, k), CH(0, 1, k), ti1, CC(ido - 1, k, 2));
1817           }
1818         if (ido <= 2)
1819           return;
1820         for (size_t k = 0; k < l1; k++)
1821           for (size_t i = 2; i < ido; i += 2) {
1822             size_t ic = ido - i;
1823             T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3,
1824                 tr4;
1825             MULPM(cr2, ci2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1),
1826                   CC(i, k, 1));
1827             MULPM(cr3, ci3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2),
1828                   CC(i, k, 2));
1829             MULPM(cr4, ci4, WA(2, i - 2), WA(2, i - 1), CC(i - 1, k, 3),
1830                   CC(i, k, 3));
1831             PM(tr1, tr4, cr4, cr2);
1832             PM(ti1, ti4, ci2, ci4);
1833             PM(tr2, tr3, CC(i - 1, k, 0), cr3);
1834             PM(ti2, ti3, CC(i, k, 0), ci3);
1835             PM(CH(i - 1, 0, k), CH(ic - 1, 3, k), tr2, tr1);
1836             PM(CH(i, 0, k), CH(ic, 3, k), ti1, ti2);
1837             PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr3, ti4);
1838             PM(CH(i, 2, k), CH(ic, 1, k), tr4, ti3);
1839           }
1840       }
1841 
1842       template <typename T>
radf5(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const T0 * POCKETFFT_RESTRICT wa) const1843       void radf5(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
1844                  T *POCKETFFT_RESTRICT ch,
1845                  const T0 *POCKETFFT_RESTRICT wa) const
1846       {
1847         constexpr T0 tr11 = T0(0.3090169943749474241022934171828191L),
1848                      ti11 = T0(0.9510565162951535721164393333793821L),
1849                      tr12 = T0(-0.8090169943749474241022934171828191L),
1850                      ti12 = T0(0.5877852522924731291687059546390728L);
1851 
1852         auto WA =
1853             [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };
1854         auto CC = [cc, ido, l1](size_t a, size_t b, size_t c)
1855                       -> const T &{ return cc[a + ido * (b + l1 * c)]; };
1856         auto CH = [ch, ido](size_t a, size_t b, size_t c)
1857                       -> T &{ return ch[a + ido * (b + 5 * c)]; };
1858 
1859         for (size_t k = 0; k < l1; k++) {
1860           T cr2, cr3, ci4, ci5;
1861           PM(cr2, ci5, CC(0, k, 4), CC(0, k, 1));
1862           PM(cr3, ci4, CC(0, k, 3), CC(0, k, 2));
1863           CH(0, 0, k) = CC(0, k, 0) + cr2 + cr3;
1864           CH(ido - 1, 1, k) = CC(0, k, 0) + tr11 * cr2 + tr12 * cr3;
1865           CH(0, 2, k) = ti11 * ci5 + ti12 * ci4;
1866           CH(ido - 1, 3, k) = CC(0, k, 0) + tr12 * cr2 + tr11 * cr3;
1867           CH(0, 4, k) = ti12 * ci5 - ti11 * ci4;
1868         }
1869         if (ido == 1)
1870           return;
1871         for (size_t k = 0; k < l1; ++k)
1872           for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) {
1873             T di2, di3, di4, di5, dr2, dr3, dr4, dr5;
1874             MULPM(dr2, di2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1),
1875                   CC(i, k, 1));
1876             MULPM(dr3, di3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2),
1877                   CC(i, k, 2));
1878             MULPM(dr4, di4, WA(2, i - 2), WA(2, i - 1), CC(i - 1, k, 3),
1879                   CC(i, k, 3));
1880             MULPM(dr5, di5, WA(3, i - 2), WA(3, i - 1), CC(i - 1, k, 4),
1881                   CC(i, k, 4));
1882             POCKETFFT_REARRANGE(dr2, di2, dr5, di5);
1883             POCKETFFT_REARRANGE(dr3, di3, dr4, di4);
1884             CH(i - 1, 0, k) = CC(i - 1, k, 0) + dr2 + dr3;
1885             CH(i, 0, k) = CC(i, k, 0) + di2 + di3;
1886             T tr2 = CC(i - 1, k, 0) + tr11 * dr2 + tr12 * dr3;
1887             T ti2 = CC(i, k, 0) + tr11 * di2 + tr12 * di3;
1888             T tr3 = CC(i - 1, k, 0) + tr12 * dr2 + tr11 * dr3;
1889             T ti3 = CC(i, k, 0) + tr12 * di2 + tr11 * di3;
1890             T tr5 = ti11 * dr5 + ti12 * dr4;
1891             T ti5 = ti11 * di5 + ti12 * di4;
1892             T tr4 = ti12 * dr5 - ti11 * dr4;
1893             T ti4 = ti12 * di5 - ti11 * di4;
1894             PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr2, tr5);
1895             PM(CH(i, 2, k), CH(ic, 1, k), ti5, ti2);
1896             PM(CH(i - 1, 4, k), CH(ic - 1, 3, k), tr3, tr4);
1897             PM(CH(i, 4, k), CH(ic, 3, k), ti4, ti3);
1898           }
1899       }
1900 
1901 #undef POCKETFFT_REARRANGE
1902 
1903       template <typename T>
radfg(size_t ido,size_t ip,size_t l1,T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const T0 * POCKETFFT_RESTRICT wa,const T0 * POCKETFFT_RESTRICT csarr) const1904       void radfg(size_t ido, size_t ip, size_t l1, T *POCKETFFT_RESTRICT cc,
1905                  T *POCKETFFT_RESTRICT ch, const T0 *POCKETFFT_RESTRICT wa,
1906                  const T0 *POCKETFFT_RESTRICT csarr) const
1907       {
1908         const size_t cdim = ip;
1909         size_t ipph = (ip + 1) / 2;
1910         size_t idl1 = ido * l1;
1911 
1912         auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c)
1913                       -> T &{ return cc[a + ido * (b + cdim * c)]; };
1914         auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
1915                       -> const T &{ return ch[a + ido * (b + l1 * c)]; };
1916         auto C1 = [cc, ido, l1](size_t a, size_t b, size_t c)
1917                       -> T &{ return cc[a + ido * (b + l1 * c)]; };
1918         auto C2 =
1919             [cc, idl1](size_t a, size_t b) -> T &{ return cc[a + idl1 * b]; };
1920         auto CH2 =
1921             [ch, idl1](size_t a, size_t b) -> T &{ return ch[a + idl1 * b]; };
1922 
1923         if (ido > 1) {
1924           for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 114
1925           {
1926             size_t is = (j - 1) * (ido - 1), is2 = (jc - 1) * (ido - 1);
1927             for (size_t k = 0; k < l1; ++k) // 113
1928             {
1929               size_t idij = is;
1930               size_t idij2 = is2;
1931               for (size_t i = 1; i <= ido - 2; i += 2) // 112
1932               {
1933                 T t1 = C1(i, k, j), t2 = C1(i + 1, k, j), t3 = C1(i, k, jc),
1934                   t4 = C1(i + 1, k, jc);
1935                 T x1 = wa[idij] * t1 + wa[idij + 1] * t2,
1936                   x2 = wa[idij] * t2 - wa[idij + 1] * t1,
1937                   x3 = wa[idij2] * t3 + wa[idij2 + 1] * t4,
1938                   x4 = wa[idij2] * t4 - wa[idij2 + 1] * t3;
1939                 PM(C1(i, k, j), C1(i + 1, k, jc), x3, x1);
1940                 PM(C1(i + 1, k, j), C1(i, k, jc), x2, x4);
1941                 idij += 2;
1942                 idij2 += 2;
1943               }
1944             }
1945           }
1946         }
1947 
1948         for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 123
1949           for (size_t k = 0; k < l1; ++k)                    // 122
1950             MPINPLACE(C1(0, k, jc), C1(0, k, j));
1951 
1952         // everything in C
1953         // memset(ch,0,ip*l1*ido*sizeof(double));
1954 
1955         for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) // 127
1956         {
1957           for (size_t ik = 0; ik < idl1; ++ik) // 124
1958           {
1959             CH2(ik, l) =
1960                 C2(ik, 0) + csarr[2 * l] * C2(ik, 1) + csarr[4 * l] * C2(ik, 2);
1961             CH2(ik, lc) = csarr[2 * l + 1] * C2(ik, ip - 1) +
1962                           csarr[4 * l + 1] * C2(ik, ip - 2);
1963           }
1964           size_t iang = 2 * l;
1965           size_t j = 3, jc = ip - 3;
1966           for (; j < ipph - 3; j += 4, jc -= 4) // 126
1967           {
1968             iang += l;
1969             if (iang >= ip)
1970               iang -= ip;
1971             T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1];
1972             iang += l;
1973             if (iang >= ip)
1974               iang -= ip;
1975             T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1];
1976             iang += l;
1977             if (iang >= ip)
1978               iang -= ip;
1979             T0 ar3 = csarr[2 * iang], ai3 = csarr[2 * iang + 1];
1980             iang += l;
1981             if (iang >= ip)
1982               iang -= ip;
1983             T0 ar4 = csarr[2 * iang], ai4 = csarr[2 * iang + 1];
1984             for (size_t ik = 0; ik < idl1; ++ik) // 125
1985             {
1986               CH2(ik, l) += ar1 * C2(ik, j) + ar2 * C2(ik, j + 1) +
1987                             ar3 * C2(ik, j + 2) + ar4 * C2(ik, j + 3);
1988               CH2(ik, lc) += ai1 * C2(ik, jc) + ai2 * C2(ik, jc - 1) +
1989                              ai3 * C2(ik, jc - 2) + ai4 * C2(ik, jc - 3);
1990             }
1991           }
1992           for (; j < ipph - 1; j += 2, jc -= 2) // 126
1993           {
1994             iang += l;
1995             if (iang >= ip)
1996               iang -= ip;
1997             T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1];
1998             iang += l;
1999             if (iang >= ip)
2000               iang -= ip;
2001             T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1];
2002             for (size_t ik = 0; ik < idl1; ++ik) // 125
2003             {
2004               CH2(ik, l) += ar1 * C2(ik, j) + ar2 * C2(ik, j + 1);
2005               CH2(ik, lc) += ai1 * C2(ik, jc) + ai2 * C2(ik, jc - 1);
2006             }
2007           }
2008           for (; j < ipph; ++j, --jc) // 126
2009           {
2010             iang += l;
2011             if (iang >= ip)
2012               iang -= ip;
2013             T0 ar = csarr[2 * iang], ai = csarr[2 * iang + 1];
2014             for (size_t ik = 0; ik < idl1; ++ik) // 125
2015             {
2016               CH2(ik, l) += ar * C2(ik, j);
2017               CH2(ik, lc) += ai * C2(ik, jc);
2018             }
2019           }
2020         }
2021         for (size_t ik = 0; ik < idl1; ++ik) // 101
2022           CH2(ik, 0) = C2(ik, 0);
2023         for (size_t j = 1; j < ipph; ++j)      // 129
2024           for (size_t ik = 0; ik < idl1; ++ik) // 128
2025             CH2(ik, 0) += C2(ik, j);
2026 
2027         // everything in CH at this point!
2028         // memset(cc,0,ip*l1*ido*sizeof(double));
2029 
2030         for (size_t k = 0; k < l1; ++k)    // 131
2031           for (size_t i = 0; i < ido; ++i) // 130
2032             CC(i, 0, k) = CH(i, k, 0);
2033 
2034         for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 137
2035         {
2036           size_t j2 = 2 * j - 1;
2037           for (size_t k = 0; k < l1; ++k) // 136
2038           {
2039             CC(ido - 1, j2, k) = CH(0, k, j);
2040             CC(0, j2 + 1, k) = CH(0, k, jc);
2041           }
2042         }
2043 
2044         if (ido == 1)
2045           return;
2046 
2047         for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 140
2048         {
2049           size_t j2 = 2 * j - 1;
2050           for (size_t k = 0; k < l1; ++k) // 139
2051             for (size_t i = 1, ic = ido - i - 2; i <= ido - 2;
2052                  i += 2, ic -= 2) // 138
2053             {
2054               CC(i, j2 + 1, k) = CH(i, k, j) + CH(i, k, jc);
2055               CC(ic, j2, k) = CH(i, k, j) - CH(i, k, jc);
2056               CC(i + 1, j2 + 1, k) = CH(i + 1, k, j) + CH(i + 1, k, jc);
2057               CC(ic + 1, j2, k) = CH(i + 1, k, jc) - CH(i + 1, k, j);
2058             }
2059         }
2060       }
2061 
2062       template <typename T>
radb2(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const T0 * POCKETFFT_RESTRICT wa) const2063       void radb2(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
2064                  T *POCKETFFT_RESTRICT ch,
2065                  const T0 *POCKETFFT_RESTRICT wa) const
2066       {
2067         auto WA =
2068             [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };
2069         auto CC = [cc, ido](size_t a, size_t b, size_t c)
2070                       -> const T &{ return cc[a + ido * (b + 2 * c)]; };
2071         auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
2072                       -> T &{ return ch[a + ido * (b + l1 * c)]; };
2073 
2074         for (size_t k = 0; k < l1; k++)
2075           PM(CH(0, k, 0), CH(0, k, 1), CC(0, 0, k), CC(ido - 1, 1, k));
2076         if ((ido & 1) == 0)
2077           for (size_t k = 0; k < l1; k++) {
2078             CH(ido - 1, k, 0) = 2 * CC(ido - 1, 0, k);
2079             CH(ido - 1, k, 1) = -2 * CC(0, 1, k);
2080           }
2081         if (ido <= 2)
2082           return;
2083         for (size_t k = 0; k < l1; ++k)
2084           for (size_t i = 2; i < ido; i += 2) {
2085             size_t ic = ido - i;
2086             T ti2, tr2;
2087             PM(CH(i - 1, k, 0), tr2, CC(i - 1, 0, k), CC(ic - 1, 1, k));
2088             PM(ti2, CH(i, k, 0), CC(i, 0, k), CC(ic, 1, k));
2089             MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), ti2,
2090                   tr2);
2091           }
2092       }
2093 
2094       template <typename T>
radb3(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const T0 * POCKETFFT_RESTRICT wa) const2095       void radb3(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
2096                  T *POCKETFFT_RESTRICT ch,
2097                  const T0 *POCKETFFT_RESTRICT wa) const
2098       {
2099         constexpr T0 taur = -0.5,
2100                      taui = T0(0.8660254037844386467637231707529362L);
2101 
2102         auto WA =
2103             [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };
2104         auto CC = [cc, ido](size_t a, size_t b, size_t c)
2105                       -> const T &{ return cc[a + ido * (b + 3 * c)]; };
2106         auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
2107                       -> T &{ return ch[a + ido * (b + l1 * c)]; };
2108 
2109         for (size_t k = 0; k < l1; k++) {
2110           T tr2 = 2 * CC(ido - 1, 1, k);
2111           T cr2 = CC(0, 0, k) + taur * tr2;
2112           CH(0, k, 0) = CC(0, 0, k) + tr2;
2113           T ci3 = 2 * taui * CC(0, 2, k);
2114           PM(CH(0, k, 2), CH(0, k, 1), cr2, ci3);
2115         }
2116         if (ido == 1)
2117           return;
2118         for (size_t k = 0; k < l1; k++)
2119           for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) {
2120             T tr2 =
2121                 CC(i - 1, 2, k) + CC(ic - 1, 1, k); // t2=CC(I) + conj(CC(ic))
2122             T ti2 = CC(i, 2, k) - CC(ic, 1, k);
2123             T cr2 = CC(i - 1, 0, k) + taur * tr2; // c2=CC +taur*t2
2124             T ci2 = CC(i, 0, k) + taur * ti2;
2125             CH(i - 1, k, 0) = CC(i - 1, 0, k) + tr2; // CH=CC+t2
2126             CH(i, k, 0) = CC(i, 0, k) + ti2;
2127             T cr3 = taui * (CC(i - 1, 2, k) -
2128                             CC(ic - 1, 1, k)); // c3=taui*(CC(i)-conj(CC(ic)))
2129             T ci3 = taui * (CC(i, 2, k) + CC(ic, 1, k));
2130             T di2, di3, dr2, dr3;
2131             PM(dr3, dr2, cr2, ci3); // d2= (cr2-ci3, ci2+cr3) = c2+i*c3
2132             PM(di2, di3, ci2, cr3); // d3= (cr2+ci3, ci2-cr3) = c2-i*c3
2133             MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), di2,
2134                   dr2); // ch = WA*d2
2135             MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), di3,
2136                   dr3);
2137           }
2138       }
2139 
2140       template <typename T>
radb4(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const T0 * POCKETFFT_RESTRICT wa) const2141       void radb4(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
2142                  T *POCKETFFT_RESTRICT ch,
2143                  const T0 *POCKETFFT_RESTRICT wa) const
2144       {
2145         constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L);
2146 
2147         auto WA =
2148             [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };
2149         auto CC = [cc, ido](size_t a, size_t b, size_t c)
2150                       -> const T &{ return cc[a + ido * (b + 4 * c)]; };
2151         auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
2152                       -> T &{ return ch[a + ido * (b + l1 * c)]; };
2153 
2154         for (size_t k = 0; k < l1; k++) {
2155           T tr1, tr2;
2156           PM(tr2, tr1, CC(0, 0, k), CC(ido - 1, 3, k));
2157           T tr3 = 2 * CC(ido - 1, 1, k);
2158           T tr4 = 2 * CC(0, 2, k);
2159           PM(CH(0, k, 0), CH(0, k, 2), tr2, tr3);
2160           PM(CH(0, k, 3), CH(0, k, 1), tr1, tr4);
2161         }
2162         if ((ido & 1) == 0)
2163           for (size_t k = 0; k < l1; k++) {
2164             T tr1, tr2, ti1, ti2;
2165             PM(ti1, ti2, CC(0, 3, k), CC(0, 1, k));
2166             PM(tr2, tr1, CC(ido - 1, 0, k), CC(ido - 1, 2, k));
2167             CH(ido - 1, k, 0) = tr2 + tr2;
2168             CH(ido - 1, k, 1) = sqrt2 * (tr1 - ti1);
2169             CH(ido - 1, k, 2) = ti2 + ti2;
2170             CH(ido - 1, k, 3) = -sqrt2 * (tr1 + ti1);
2171           }
2172         if (ido <= 2)
2173           return;
2174         for (size_t k = 0; k < l1; ++k)
2175           for (size_t i = 2; i < ido; i += 2) {
2176             T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3,
2177                 tr4;
2178             size_t ic = ido - i;
2179             PM(tr2, tr1, CC(i - 1, 0, k), CC(ic - 1, 3, k));
2180             PM(ti1, ti2, CC(i, 0, k), CC(ic, 3, k));
2181             PM(tr4, ti3, CC(i, 2, k), CC(ic, 1, k));
2182             PM(tr3, ti4, CC(i - 1, 2, k), CC(ic - 1, 1, k));
2183             PM(CH(i - 1, k, 0), cr3, tr2, tr3);
2184             PM(CH(i, k, 0), ci3, ti2, ti3);
2185             PM(cr4, cr2, tr1, tr4);
2186             PM(ci2, ci4, ti1, ti4);
2187             MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), ci2,
2188                   cr2);
2189             MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), ci3,
2190                   cr3);
2191             MULPM(CH(i, k, 3), CH(i - 1, k, 3), WA(2, i - 2), WA(2, i - 1), ci4,
2192                   cr4);
2193           }
2194       }
2195 
2196       template <typename T>
radb5(size_t ido,size_t l1,const T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const T0 * POCKETFFT_RESTRICT wa) const2197       void radb5(size_t ido, size_t l1, const T *POCKETFFT_RESTRICT cc,
2198                  T *POCKETFFT_RESTRICT ch,
2199                  const T0 *POCKETFFT_RESTRICT wa) const
2200       {
2201         constexpr T0 tr11 = T0(0.3090169943749474241022934171828191L),
2202                      ti11 = T0(0.9510565162951535721164393333793821L),
2203                      tr12 = T0(-0.8090169943749474241022934171828191L),
2204                      ti12 = T0(0.5877852522924731291687059546390728L);
2205 
2206         auto WA =
2207             [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };
2208         auto CC = [cc, ido](size_t a, size_t b, size_t c)
2209                       -> const T &{ return cc[a + ido * (b + 5 * c)]; };
2210         auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
2211                       -> T &{ return ch[a + ido * (b + l1 * c)]; };
2212 
2213         for (size_t k = 0; k < l1; k++) {
2214           T ti5 = CC(0, 2, k) + CC(0, 2, k);
2215           T ti4 = CC(0, 4, k) + CC(0, 4, k);
2216           T tr2 = CC(ido - 1, 1, k) + CC(ido - 1, 1, k);
2217           T tr3 = CC(ido - 1, 3, k) + CC(ido - 1, 3, k);
2218           CH(0, k, 0) = CC(0, 0, k) + tr2 + tr3;
2219           T cr2 = CC(0, 0, k) + tr11 * tr2 + tr12 * tr3;
2220           T cr3 = CC(0, 0, k) + tr12 * tr2 + tr11 * tr3;
2221           T ci4, ci5;
2222           MULPM(ci5, ci4, ti5, ti4, ti11, ti12);
2223           PM(CH(0, k, 4), CH(0, k, 1), cr2, ci5);
2224           PM(CH(0, k, 3), CH(0, k, 2), cr3, ci4);
2225         }
2226         if (ido == 1)
2227           return;
2228         for (size_t k = 0; k < l1; ++k)
2229           for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) {
2230             T tr2, tr3, tr4, tr5, ti2, ti3, ti4, ti5;
2231             PM(tr2, tr5, CC(i - 1, 2, k), CC(ic - 1, 1, k));
2232             PM(ti5, ti2, CC(i, 2, k), CC(ic, 1, k));
2233             PM(tr3, tr4, CC(i - 1, 4, k), CC(ic - 1, 3, k));
2234             PM(ti4, ti3, CC(i, 4, k), CC(ic, 3, k));
2235             CH(i - 1, k, 0) = CC(i - 1, 0, k) + tr2 + tr3;
2236             CH(i, k, 0) = CC(i, 0, k) + ti2 + ti3;
2237             T cr2 = CC(i - 1, 0, k) + tr11 * tr2 + tr12 * tr3;
2238             T ci2 = CC(i, 0, k) + tr11 * ti2 + tr12 * ti3;
2239             T cr3 = CC(i - 1, 0, k) + tr12 * tr2 + tr11 * tr3;
2240             T ci3 = CC(i, 0, k) + tr12 * ti2 + tr11 * ti3;
2241             T ci4, ci5, cr5, cr4;
2242             MULPM(cr5, cr4, tr5, tr4, ti11, ti12);
2243             MULPM(ci5, ci4, ti5, ti4, ti11, ti12);
2244             T dr2, dr3, dr4, dr5, di2, di3, di4, di5;
2245             PM(dr4, dr3, cr3, ci4);
2246             PM(di3, di4, ci3, cr4);
2247             PM(dr5, dr2, cr2, ci5);
2248             PM(di2, di5, ci2, cr5);
2249             MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), di2,
2250                   dr2);
2251             MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), di3,
2252                   dr3);
2253             MULPM(CH(i, k, 3), CH(i - 1, k, 3), WA(2, i - 2), WA(2, i - 1), di4,
2254                   dr4);
2255             MULPM(CH(i, k, 4), CH(i - 1, k, 4), WA(3, i - 2), WA(3, i - 1), di5,
2256                   dr5);
2257           }
2258       }
2259 
2260       template <typename T>
radbg(size_t ido,size_t ip,size_t l1,T * POCKETFFT_RESTRICT cc,T * POCKETFFT_RESTRICT ch,const T0 * POCKETFFT_RESTRICT wa,const T0 * POCKETFFT_RESTRICT csarr) const2261       void radbg(size_t ido, size_t ip, size_t l1, T *POCKETFFT_RESTRICT cc,
2262                  T *POCKETFFT_RESTRICT ch, const T0 *POCKETFFT_RESTRICT wa,
2263                  const T0 *POCKETFFT_RESTRICT csarr) const
2264       {
             // Backward butterfly for a general factor ip. exec() calls this in
             // its non-r2hc branch for every factor other than 2, 3, 4 and 5.
             //   ido, l1 : extents of the inner (i) and outer (k) loops
             //   cc      : input data; it is also overwritten as scratch
             //             through the C2 accessor below
             //   ch      : holds the result when the function returns
             //   wa      : twiddle factors, read as (wa[idij], wa[idij + 1])
             //   csarr   : extra factors, read as pairs
             //             (csarr[2 * m], csarr[2 * m + 1])
2265         const size_t cdim = ip;
2266         size_t ipph = (ip + 1) / 2;
2267         size_t idl1 = ido * l1;
2268 
             // Index helpers. CC/C1/C2 address cc, CH/CH2 address ch.
             // CC uses the layout a + ido * (b + cdim * c); C1 and CH use
             // a + ido * (b + l1 * c); C2 and CH2 use a + idl1 * b, i.e. the
             // same storage as C1/CH with the first two indices flattened.
2269         auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c)
2270                       -> const T &{ return cc[a + ido * (b + cdim * c)]; };
2271         auto CH = [ch, ido, l1](size_t a, size_t b, size_t c)
2272                       -> T &{ return ch[a + ido * (b + l1 * c)]; };
2273         auto C1 = [cc, ido, l1](size_t a, size_t b, size_t c)
2274                       -> const T &{ return cc[a + ido * (b + l1 * c)]; };
2275         auto C2 =
2276             [cc, idl1](size_t a, size_t b) -> T &{ return cc[a + idl1 * b]; };
2277         auto CH2 =
2278             [ch, idl1](size_t a, size_t b) -> T &{ return ch[a + idl1 * b]; };
2279 
             // Slice 0 of the input is copied to CH unchanged.
2280         for (size_t k = 0; k < l1; ++k)    // 102
2281           for (size_t i = 0; i < ido; ++i) // 101
2282             CH(i, k, 0) = CC(i, 0, k);
             // i == 0 entries of slices j and jc = ip - j: twice the input
             // values found at CC(ido - 1, 2j - 1, k) and CC(0, 2j, k).
2283         for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 108
2284         {
2285           size_t j2 = 2 * j - 1;
2286           for (size_t k = 0; k < l1; ++k) {
2287             CH(0, k, j) = 2 * CC(ido - 1, j2, k);
2288             CH(0, k, jc) = 2 * CC(0, j2 + 1, k);
2289           }
2290         }
2291 
             // Remaining entries (only present when ido != 1): sums and
             // differences of input values at mirrored positions i and ic.
2292         if (ido != 1) {
2293           for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 111
2294           {
2295             size_t j2 = 2 * j - 1;
2296             for (size_t k = 0; k < l1; ++k)
                   // i is already initialised to 1 when ic is initialised,
                   // so ic starts at ido - 3.
2297               for (size_t i = 1, ic = ido - i - 2; i <= ido - 2;
2298                    i += 2, ic -= 2) // 109
2299               {
2300                 CH(i, k, j) = CC(i, j2 + 1, k) + CC(ic, j2, k);
2301                 CH(i, k, jc) = CC(i, j2 + 1, k) - CC(ic, j2, k);
2302                 CH(i + 1, k, j) = CC(i + 1, j2 + 1, k) - CC(ic + 1, j2, k);
2303                 CH(i + 1, k, jc) = CC(i + 1, j2 + 1, k) + CC(ic + 1, j2, k);
2304               }
2305           }
2306         }
             // For every pair (l, lc) accumulate a csarr-weighted combination
             // of the CH2 slices into C2, i.e. into cc. The first two terms
             // are written directly; the remaining j are processed four at a
             // time, then two at a time, then one at a time. iang is advanced
             // by l for each successive j and reduced by ip when it exceeds ip.
             // NOTE(review): the wrap test is `>` rather than `>=`, so
             // iang == ip would read csarr[2 * ip]; presumably unreachable for
             // the ip values passed here - confirm.
2307         for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) {
2308           for (size_t ik = 0; ik < idl1; ++ik) {
2309             C2(ik, l) = CH2(ik, 0) + csarr[2 * l] * CH2(ik, 1) +
2310                         csarr[4 * l] * CH2(ik, 2);
2311             C2(ik, lc) = csarr[2 * l + 1] * CH2(ik, ip - 1) +
2312                          csarr[4 * l + 1] * CH2(ik, ip - 2);
2313           }
2314           size_t iang = 2 * l;
2315           size_t j = 3, jc = ip - 3;
               // four terms per iteration
2316           for (; j < ipph - 3; j += 4, jc -= 4) {
2317             iang += l;
2318             if (iang > ip)
2319               iang -= ip;
2320             T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1];
2321             iang += l;
2322             if (iang > ip)
2323               iang -= ip;
2324             T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1];
2325             iang += l;
2326             if (iang > ip)
2327               iang -= ip;
2328             T0 ar3 = csarr[2 * iang], ai3 = csarr[2 * iang + 1];
2329             iang += l;
2330             if (iang > ip)
2331               iang -= ip;
2332             T0 ar4 = csarr[2 * iang], ai4 = csarr[2 * iang + 1];
2333             for (size_t ik = 0; ik < idl1; ++ik) {
2334               C2(ik, l) += ar1 * CH2(ik, j) + ar2 * CH2(ik, j + 1) +
2335                            ar3 * CH2(ik, j + 2) + ar4 * CH2(ik, j + 3);
2336               C2(ik, lc) += ai1 * CH2(ik, jc) + ai2 * CH2(ik, jc - 1) +
2337                             ai3 * CH2(ik, jc - 2) + ai4 * CH2(ik, jc - 3);
2338             }
2339           }
               // two terms per iteration
2340           for (; j < ipph - 1; j += 2, jc -= 2) {
2341             iang += l;
2342             if (iang > ip)
2343               iang -= ip;
2344             T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1];
2345             iang += l;
2346             if (iang > ip)
2347               iang -= ip;
2348             T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1];
2349             for (size_t ik = 0; ik < idl1; ++ik) {
2350               C2(ik, l) += ar1 * CH2(ik, j) + ar2 * CH2(ik, j + 1);
2351               C2(ik, lc) += ai1 * CH2(ik, jc) + ai2 * CH2(ik, jc - 1);
2352             }
2353           }
               // remaining single terms
2354           for (; j < ipph; ++j, --jc) {
2355             iang += l;
2356             if (iang > ip)
2357               iang -= ip;
2358             T0 war = csarr[2 * iang], wai = csarr[2 * iang + 1];
2359             for (size_t ik = 0; ik < idl1; ++ik) {
2360               C2(ik, l) += war * CH2(ik, j);
2361               C2(ik, lc) += wai * CH2(ik, jc);
2362             }
2363           }
2364         }
             // Add slices 1 .. ipph - 1 of CH onto slice 0.
2365         for (size_t j = 1; j < ipph; ++j)
2366           for (size_t ik = 0; ik < idl1; ++ik)
2367             CH2(ik, 0) += CH2(ik, j);
             // i == 0 entries: combine the scratch values C1(0, k, j) and
             // C1(0, k, jc) into CH(0, k, jc) and CH(0, k, j). PM is a macro
             // defined elsewhere in this file; presumably it stores the sum and
             // the difference of its last two arguments - confirm.
2368         for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 124
2369           for (size_t k = 0; k < l1; ++k)
2370             PM(CH(0, k, jc), CH(0, k, j), C1(0, k, j), C1(0, k, jc));
2371 
             // With ido == 1 there are no entries with i >= 1 left to process.
2372         if (ido == 1)
2373           return;
2374 
             // Entries with i >= 1: recombine the scratch values held in cc
             // (read through C1) into CH, two consecutive i at a time.
2375         for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 127
2376           for (size_t k = 0; k < l1; ++k)
2377             for (size_t i = 1; i <= ido - 2; i += 2) {
2378               CH(i, k, j) = C1(i, k, j) - C1(i + 1, k, jc);
2379               CH(i, k, jc) = C1(i, k, j) + C1(i + 1, k, jc);
2380               CH(i + 1, k, j) = C1(i + 1, k, j) + C1(i, k, jc);
2381               CH(i + 1, k, jc) = C1(i + 1, k, j) - C1(i, k, jc);
2382             }
2383 
2384         // All in CH
2385 
             // Twiddle step: each pair (CH(i, k, j), CH(i + 1, k, j)) is
             // multiplied, as a complex number, by (wa[idij], wa[idij + 1]).
             // `is` is the offset of slice j's factors inside wa.
2386         for (size_t j = 1; j < ip; ++j) {
2387           size_t is = (j - 1) * (ido - 1);
2388           for (size_t k = 0; k < l1; ++k) {
2389             size_t idij = is;
2390             for (size_t i = 1; i <= ido - 2; i += 2) {
2391               T t1 = CH(i, k, j), t2 = CH(i + 1, k, j);
2392               CH(i, k, j) = wa[idij] * t1 - wa[idij + 1] * t2;
2393               CH(i + 1, k, j) = wa[idij] * t2 + wa[idij + 1] * t1;
2394               idij += 2;
2395             }
2396           }
2397         }
2398       }
2399 
2400       template <typename T>
copy_and_norm(T * c,T * p1,size_t n,T0 fct) const2401       void copy_and_norm(T *c, T *p1, size_t n, T0 fct) const
2402       {
2403         if (p1 != c) {
2404           if (fct != 1.)
2405             for (size_t i = 0; i < n; ++i)
2406               c[i] = fct * p1[i];
2407           else
2408             memcpy(c, p1, n * sizeof(T));
2409         } else if (fct != 1.)
2410           for (size_t i = 0; i < n; ++i)
2411             c[i] *= fct;
2412       }
2413 
2414     public:
2415       template <typename T>
exec(T c[],T0 fct,bool r2hc) const2416       void exec(T c[], T0 fct, bool r2hc) const
2417       {
2418         if (length == 1) {
2419           c[0] *= fct;
2420           return;
2421         }
2422         size_t n = length, nf = fact.size();
2423         arr<T> ch(n);
2424         T *p1 = c, *p2 = ch.data();
2425 
2426         if (r2hc)
2427           for (size_t k1 = 0, l1 = n; k1 < nf; ++k1) {
2428             size_t k = nf - k1 - 1;
2429             size_t ip = fact[k].fct;
2430             size_t ido = n / l1;
2431             l1 /= ip;
2432             if (ip == 4)
2433               radf4(ido, l1, p1, p2, fact[k].tw);
2434             else if (ip == 2)
2435               radf2(ido, l1, p1, p2, fact[k].tw);
2436             else if (ip == 3)
2437               radf3(ido, l1, p1, p2, fact[k].tw);
2438             else if (ip == 5)
2439               radf5(ido, l1, p1, p2, fact[k].tw);
2440             else {
2441               radfg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws);
2442               std::swap(p1, p2);
2443             }
2444             std::swap(p1, p2);
2445           }
2446         else
2447           for (size_t k = 0, l1 = 1; k < nf; k++) {
2448             size_t ip = fact[k].fct, ido = n / (ip * l1);
2449             if (ip == 4)
2450               radb4(ido, l1, p1, p2, fact[k].tw);
2451             else if (ip == 2)
2452               radb2(ido, l1, p1, p2, fact[k].tw);
2453             else if (ip == 3)
2454               radb3(ido, l1, p1, p2, fact[k].tw);
2455             else if (ip == 5)
2456               radb5(ido, l1, p1, p2, fact[k].tw);
2457             else
2458               radbg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws);
2459             std::swap(p1, p2);
2460             l1 *= ip;
2461           }
2462 
2463         copy_and_norm(c, p1, n, fct);
2464       }
2465 
2466     private:
factorize()2467       void factorize()
2468       {
2469         size_t len = length;
2470         while ((len % 4) == 0) {
2471           add_factor(4);
2472           len >>= 2;
2473         }
2474         if ((len % 2) == 0) {
2475           len >>= 1;
2476           // factor 2 should be at the front of the factor list
2477           add_factor(2);
2478           std::swap(fact[0].fct, fact.back().fct);
2479         }
2480         for (size_t divisor = 3; divisor * divisor <= len; divisor += 2)
2481           while ((len % divisor) == 0) {
2482             add_factor(divisor);
2483             len /= divisor;
2484           }
2485         if (len > 1)
2486           add_factor(len);
2487       }
2488 
twsize() const2489       size_t twsize() const
2490       {
2491         size_t twsz = 0, l1 = 1;
2492         for (size_t k = 0; k < fact.size(); ++k) {
2493           size_t ip = fact[k].fct, ido = length / (l1 * ip);
2494           twsz += (ip - 1) * (ido - 1);
2495           if (ip > 5)
2496             twsz += 2 * ip;
2497           l1 *= ip;
2498         }
2499         return twsz;
2500       }
2501 
comp_twiddle()2502       void comp_twiddle()
2503       {
2504         sincos_2pibyn<T0> twid(length);
2505         size_t l1 = 1;
2506         T0 *ptr = mem.data();
2507         for (size_t k = 0; k < fact.size(); ++k) {
2508           size_t ip = fact[k].fct, ido = length / (l1 * ip);
2509           if (k < fact.size() - 1) // last factor doesn't need twiddles
2510           {
2511             fact[k].tw = ptr;
2512             ptr += (ip - 1) * (ido - 1);
2513             for (size_t j = 1; j < ip; ++j)
2514               for (size_t i = 1; i <= (ido - 1) / 2; ++i) {
2515                 fact[k].tw[(j - 1) * (ido - 1) + 2 * i - 2] =
2516                     twid[j * l1 * i].r;
2517                 fact[k].tw[(j - 1) * (ido - 1) + 2 * i - 1] =
2518                     twid[j * l1 * i].i;
2519               }
2520           }
2521           if (ip > 5) // special factors required by *g functions
2522           {
2523             fact[k].tws = ptr;
2524             ptr += 2 * ip;
2525             fact[k].tws[0] = 1.;
2526             fact[k].tws[1] = 0.;
2527             for (size_t i = 2, ic = 2 * ip - 2; i <= ic; i += 2, ic -= 2) {
2528               fact[k].tws[i] = twid[i / 2 * (length / ip)].r;
2529               fact[k].tws[i + 1] = twid[i / 2 * (length / ip)].i;
2530               fact[k].tws[ic] = twid[i / 2 * (length / ip)].r;
2531               fact[k].tws[ic + 1] = -twid[i / 2 * (length / ip)].i;
2532             }
2533           }
2534           l1 *= ip;
2535         }
2536       }
2537 
2538     public:
rfftp(size_t length_)2539       POCKETFFT_NOINLINE rfftp(size_t length_) : length(length_)
2540       {
2541         if (length == 0)
2542           throw std::runtime_error("zero-length FFT requested");
2543         if (length == 1)
2544           return;
2545         factorize();
2546         mem.resize(twsize());
2547         comp_twiddle();
2548       }
2549     };
2550 
2551     //
2552     // complex Bluestein transforms
2553     //
2554 
2555     template <typename T0>
2556     class fftblue
2557     {
2558     private:
2559       size_t n, n2;
2560       cfftp<T0> plan;
2561       arr<cmplx<T0>> mem;
2562       cmplx<T0> *bk, *bkf;
2563 
2564       template <bool fwd, typename T>
fft(cmplx<T> c[],T0 fct) const2565       void fft(cmplx<T> c[], T0 fct) const
2566       {
2567         arr<cmplx<T>> akf(n2);
2568 
2569         /* initialize a_k and FFT it */
2570         for (size_t m = 0; m < n; ++m)
2571           special_mul<fwd>(c[m], bk[m], akf[m]);
2572         auto zero = akf[0] * T0(0);
2573         for (size_t m = n; m < n2; ++m)
2574           akf[m] = zero;
2575 
2576         plan.exec(akf.data(), 1., true);
2577 
2578         /* do the convolution */
2579         akf[0] = akf[0].template special_mul<!fwd>(bkf[0]);
2580         for (size_t m = 1; m < (n2 + 1) / 2; ++m) {
2581           akf[m] = akf[m].template special_mul<!fwd>(bkf[m]);
2582           akf[n2 - m] = akf[n2 - m].template special_mul<!fwd>(bkf[m]);
2583         }
2584         if ((n2 & 1) == 0)
2585           akf[n2 / 2] = akf[n2 / 2].template special_mul<!fwd>(bkf[n2 / 2]);
2586 
2587         /* inverse FFT */
2588         plan.exec(akf.data(), 1., false);
2589 
2590         /* multiply by b_k */
2591         for (size_t m = 0; m < n; ++m)
2592           c[m] = akf[m].template special_mul<fwd>(bk[m]) * fct;
2593       }
2594 
2595     public:
fftblue(size_t length)2596       POCKETFFT_NOINLINE fftblue(size_t length)
2597           : n(length), n2(util::good_size_cmplx(n * 2 - 1)), plan(n2),
2598             mem(n + n2 / 2 + 1), bk(mem.data()), bkf(mem.data() + n)
2599       {
2600         /* initialize b_k */
2601         sincos_2pibyn<T0> tmp(2 * n);
2602         bk[0].Set(1, 0);
2603 
2604         size_t coeff = 0;
2605         for (size_t m = 1; m < n; ++m) {
2606           coeff += 2 * m - 1;
2607           if (coeff >= 2 * n)
2608             coeff -= 2 * n;
2609           bk[m] = tmp[coeff];
2610         }
2611 
2612         /* initialize the zero-padded, Fourier transformed b_k. Add
2613          * normalisation. */
2614         arr<cmplx<T0>> tbkf(n2);
2615         T0 xn2 = T0(1) / T0(n2);
2616         tbkf[0] = bk[0] * xn2;
2617         for (size_t m = 1; m < n; ++m)
2618           tbkf[m] = tbkf[n2 - m] = bk[m] * xn2;
2619         for (size_t m = n; m <= (n2 - n); ++m)
2620           tbkf[m].Set(0., 0.);
2621         plan.exec(tbkf.data(), 1., true);
2622         for (size_t i = 0; i < n2 / 2 + 1; ++i)
2623           bkf[i] = tbkf[i];
2624       }
2625 
2626       template <typename T>
exec(cmplx<T> c[],T0 fct,bool fwd) const2627       void exec(cmplx<T> c[], T0 fct, bool fwd) const
2628       {
2629         fwd ? fft<true>(c, fct) : fft<false>(c, fct);
2630       }
2631 
2632       template <typename T>
exec_r(T c[],T0 fct,bool fwd)2633       void exec_r(T c[], T0 fct, bool fwd)
2634       {
2635         arr<cmplx<T>> tmp(n);
2636         if (fwd) {
2637           auto zero = T0(0) * c[0];
2638           for (size_t m = 0; m < n; ++m)
2639             tmp[m].Set(c[m], zero);
2640           fft<true>(tmp.data(), fct);
2641           c[0] = tmp[0].r;
2642           memcpy(c + 1, tmp.data() + 1, (n - 1) * sizeof(T));
2643         } else {
2644           tmp[0].Set(c[0], c[0] * 0);
2645           memcpy(reinterpret_cast<void *>(tmp.data() + 1),
2646                  reinterpret_cast<void *>(c + 1), (n - 1) * sizeof(T));
2647           if ((n & 1) == 0)
2648             tmp[n / 2].i = T0(0) * c[0];
2649           for (size_t m = 1; 2 * m < n; ++m)
2650             tmp[n - m].Set(tmp[m].r, -tmp[m].i);
2651           fft<false>(tmp.data(), fct);
2652           for (size_t m = 0; m < n; ++m)
2653             c[m] = tmp[m].r;
2654         }
2655       }
2656     };
2657 
2658     //
2659     // flexible (FFTPACK/Bluestein) complex 1D transform
2660     //
2661 
2662     template <typename T0>
2663     class pocketfft_c
2664     {
2665     private:
2666       std::unique_ptr<cfftp<T0>> packplan;
2667       std::unique_ptr<fftblue<T0>> blueplan;
2668       size_t len;
2669 
2670     public:
pocketfft_c(size_t length)2671       POCKETFFT_NOINLINE pocketfft_c(size_t length) : len(length)
2672       {
2673         if (length == 0)
2674           throw std::runtime_error("zero-length FFT requested");
2675         size_t tmp = (length < 50) ? 0 : util::largest_prime_factor(length);
2676         if (tmp * tmp <= length) {
2677           packplan = std::unique_ptr<cfftp<T0>>(new cfftp<T0>(length));
2678           return;
2679         }
2680         double comp1 = util::cost_guess(length);
2681         double comp2 =
2682             2 * util::cost_guess(util::good_size_cmplx(2 * length - 1));
2683         comp2 *= 1.5;      /* fudge factor that appears to give good overall
2684                               performance */
2685         if (comp2 < comp1) // use Bluestein
2686           blueplan = std::unique_ptr<fftblue<T0>>(new fftblue<T0>(length));
2687         else
2688           packplan = std::unique_ptr<cfftp<T0>>(new cfftp<T0>(length));
2689       }
2690 
2691       template <typename T>
exec(cmplx<T> c[],T0 fct,bool fwd) const2692       POCKETFFT_NOINLINE void exec(cmplx<T> c[], T0 fct, bool fwd) const
2693       {
2694         packplan ? packplan->exec(c, fct, fwd) : blueplan->exec(c, fct, fwd);
2695       }
2696 
length() const2697       size_t length() const
2698       {
2699         return len;
2700       }
2701     };
2702 
2703     //
2704     // flexible (FFTPACK/Bluestein) real-valued 1D transform
2705     //
2706 
2707     template <typename T0>
2708     class pocketfft_r
2709     {
2710     private:
2711       std::unique_ptr<rfftp<T0>> packplan;
2712       std::unique_ptr<fftblue<T0>> blueplan;
2713       size_t len;
2714 
2715     public:
pocketfft_r(size_t length)2716       POCKETFFT_NOINLINE pocketfft_r(size_t length) : len(length)
2717       {
2718         if (length == 0)
2719           throw std::runtime_error("zero-length FFT requested");
2720         size_t tmp = (length < 50) ? 0 : util::largest_prime_factor(length);
2721         if (tmp * tmp <= length) {
2722           packplan = std::unique_ptr<rfftp<T0>>(new rfftp<T0>(length));
2723           return;
2724         }
2725         double comp1 = 0.5 * util::cost_guess(length);
2726         double comp2 =
2727             2 * util::cost_guess(util::good_size_cmplx(2 * length - 1));
2728         comp2 *= 1.5;      /* fudge factor that appears to give good overall
2729                               performance */
2730         if (comp2 < comp1) // use Bluestein
2731           blueplan = std::unique_ptr<fftblue<T0>>(new fftblue<T0>(length));
2732         else
2733           packplan = std::unique_ptr<rfftp<T0>>(new rfftp<T0>(length));
2734       }
2735 
2736       template <typename T>
exec(T c[],T0 fct,bool fwd) const2737       POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const
2738       {
2739         packplan ? packplan->exec(c, fct, fwd) : blueplan->exec_r(c, fct, fwd);
2740       }
2741 
length() const2742       size_t length() const
2743       {
2744         return len;
2745       }
2746     };
2747 
2748     //
2749     // sine/cosine transforms
2750     //
2751 
2752     template <typename T0>
2753     class T_dct1
2754     {
2755     private:
2756       pocketfft_r<T0> fftplan;
2757 
2758     public:
T_dct1(size_t length)2759       POCKETFFT_NOINLINE T_dct1(size_t length) : fftplan(2 * (length - 1))
2760       {
2761       }
2762 
2763       template <typename T>
exec(T c[],T0 fct,bool ortho,int,bool) const2764       POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, int /*type*/,
2765                                    bool /*cosine*/) const
2766       {
2767         constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L);
2768         size_t N = fftplan.length(), n = N / 2 + 1;
2769         if (ortho) {
2770           c[0] *= sqrt2;
2771           c[n - 1] *= sqrt2;
2772         }
2773         arr<T> tmp(N);
2774         tmp[0] = c[0];
2775         for (size_t i = 1; i < n; ++i)
2776           tmp[i] = tmp[N - i] = c[i];
2777         fftplan.exec(tmp.data(), fct, true);
2778         c[0] = tmp[0];
2779         for (size_t i = 1; i < n; ++i)
2780           c[i] = tmp[2 * i - 1];
2781         if (ortho) {
2782           c[0] *= sqrt2 * T0(0.5);
2783           c[n - 1] *= sqrt2 * T0(0.5);
2784         }
2785       }
2786 
length() const2787       size_t length() const
2788       {
2789         return fftplan.length() / 2 + 1;
2790       }
2791     };
2792 
2793     template <typename T0>
2794     class T_dst1
2795     {
2796     private:
2797       pocketfft_r<T0> fftplan;
2798 
2799     public:
T_dst1(size_t length)2800       POCKETFFT_NOINLINE T_dst1(size_t length) : fftplan(2 * (length + 1))
2801       {
2802       }
2803 
2804       template <typename T>
exec(T c[],T0 fct,bool,int,bool) const2805       POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/,
2806                                    bool /*cosine*/) const
2807       {
2808         size_t N = fftplan.length(), n = N / 2 - 1;
2809         arr<T> tmp(N);
2810         tmp[0] = tmp[n + 1] = c[0] * 0;
2811         for (size_t i = 0; i < n; ++i) {
2812           tmp[i + 1] = c[i];
2813           tmp[N - 1 - i] = -c[i];
2814         }
2815         fftplan.exec(tmp.data(), fct, true);
2816         for (size_t i = 0; i < n; ++i)
2817           c[i] = -tmp[2 * i + 2];
2818       }
2819 
length() const2820       size_t length() const
2821       {
2822         return fftplan.length() / 2 - 1;
2823       }
2824     };
2825 
2826     template <typename T0>
2827     class T_dcst23
2828     {
2829     private:
2830       pocketfft_r<T0> fftplan;
2831       std::vector<T0> twiddle;
2832 
2833     public:
T_dcst23(size_t length)2834       POCKETFFT_NOINLINE T_dcst23(size_t length)
2835           : fftplan(length), twiddle(length)
2836       {
2837         sincos_2pibyn<T0> tw(4 * length);
2838         for (size_t i = 0; i < length; ++i)
2839           twiddle[i] = tw[i + 1].r;
2840       }
2841 
2842       template <typename T>
exec(T c[],T0 fct,bool ortho,int type,bool cosine) const2843       POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, int type,
2844                                    bool cosine) const
2845       {
2846         constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L);
2847         size_t N = length();
2848         size_t NS2 = (N + 1) / 2;
2849         if (type == 2) {
2850           if (!cosine)
2851             for (size_t k = 1; k < N; k += 2)
2852               c[k] = -c[k];
2853           c[0] *= 2;
2854           if ((N & 1) == 0)
2855             c[N - 1] *= 2;
2856           for (size_t k = 1; k < N - 1; k += 2)
2857             MPINPLACE(c[k + 1], c[k]);
2858           fftplan.exec(c, fct, false);
2859           for (size_t k = 1, kc = N - 1; k < NS2; ++k, --kc) {
2860             T t1 = twiddle[k - 1] * c[kc] + twiddle[kc - 1] * c[k];
2861             T t2 = twiddle[k - 1] * c[k] - twiddle[kc - 1] * c[kc];
2862             c[k] = T0(0.5) * (t1 + t2);
2863             c[kc] = T0(0.5) * (t1 - t2);
2864           }
2865           if ((N & 1) == 0)
2866             c[NS2] *= twiddle[NS2 - 1];
2867           if (!cosine)
2868             for (size_t k = 0, kc = N - 1; k < kc; ++k, --kc)
2869               std::swap(c[k], c[kc]);
2870           if (ortho)
2871             c[0] *= sqrt2 * T0(0.5);
2872         } else {
2873           if (ortho)
2874             c[0] *= sqrt2;
2875           if (!cosine)
2876             for (size_t k = 0, kc = N - 1; k < NS2; ++k, --kc)
2877               std::swap(c[k], c[kc]);
2878           for (size_t k = 1, kc = N - 1; k < NS2; ++k, --kc) {
2879             T t1 = c[k] + c[kc], t2 = c[k] - c[kc];
2880             c[k] = twiddle[k - 1] * t2 + twiddle[kc - 1] * t1;
2881             c[kc] = twiddle[k - 1] * t1 - twiddle[kc - 1] * t2;
2882           }
2883           if ((N & 1) == 0)
2884             c[NS2] *= 2 * twiddle[NS2 - 1];
2885           fftplan.exec(c, fct, true);
2886           for (size_t k = 1; k < N - 1; k += 2)
2887             MPINPLACE(c[k], c[k + 1]);
2888           if (!cosine)
2889             for (size_t k = 1; k < N; k += 2)
2890               c[k] = -c[k];
2891         }
2892       }
2893 
length() const2894       size_t length() const
2895       {
2896         return fftplan.length();
2897       }
2898     };
2899 
2900     template <typename T0>
2901     class T_dcst4
2902     {
           // Plan object for the transform pair named by this class; the
           // `cosine` flag of exec() selects between the two variants.
           // Even N is handled through a complex FFT of length N / 2 together
           // with the C2 table, odd N through a real FFT of length N.
2903     private:
           // N    : transform length
           // fft  : complex plan of length N / 2, set only when N is even
           // rfft : real plan of length N, set only when N is odd
           // C2   : N / 2 factors for even N, empty for odd N
2904       size_t N;
2905       std::unique_ptr<pocketfft_c<T0>> fft;
2906       std::unique_ptr<pocketfft_r<T0>> rfft;
2907       arr<cmplx<T0>> C2;
2908 
2909     public:
T_dcst4(size_t length)2910       POCKETFFT_NOINLINE T_dcst4(size_t length)
2911           : N(length), fft((N & 1) ? nullptr : new pocketfft_c<T0>(N / 2)),
2912             rfft((N & 1) ? new pocketfft_r<T0>(N) : nullptr),
2913             C2((N & 1) ? 0 : N / 2)
2914       {
             // Even N only: C2[i] is the conjugate of entry 8 * i + 1 of a
             // sincos_2pibyn table built for 16 * N.
2915         if ((N & 1) == 0) {
2916           sincos_2pibyn<T0> tw(16 * N);
2917           for (size_t i = 0; i < N / 2; ++i)
2918             C2[i] = conj(tw[8 * i + 1]);
2919         }
2920       }
2921 
           // Transform the N values in c in place; fct is passed on to the
           // underlying FFT as its scale factor. The ortho and type arguments
           // are unused.
2922       template <typename T>
exec(T c[],T0 fct,bool,int,bool cosine) const2923       POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/,
2924                                    bool cosine) const
2925       {
2926         size_t n2 = N / 2;
             // Non-cosine variant: reverse the input here and negate the
             // odd-indexed outputs at the end of the function.
2927         if (!cosine)
2928           for (size_t k = 0, kc = N - 1; k < n2; ++k, --kc)
2929             std::swap(c[k], c[kc]);
2930         if (N & 1) {
2931           // The following code is derived from the FFTW3 function apply_re11()
2932           // and is released under the 3-clause BSD license with friendly
2933           // permission of Matteo Frigo and Steven G. Johnson.
2934 
               // Gather the input into y in permuted order. m starts at n2 and
               // advances by 4; which quarter of [0, 4N) it lies in decides
               // the source index and the sign, and the last loop continues
               // with m - 4N until y is full.
2935           arr<T> y(N);
2936           {
2937             size_t i = 0, m = n2;
2938             for (; m < N; ++i, m += 4)
2939               y[i] = c[m];
2940             for (; m < 2 * N; ++i, m += 4)
2941               y[i] = -c[2 * N - m - 1];
2942             for (; m < 3 * N; ++i, m += 4)
2943               y[i] = -c[m - 2 * N];
2944             for (; m < 4 * N; ++i, m += 4)
2945               y[i] = c[4 * N - m - 1];
2946             for (; i < N; ++i, m += 4)
2947               y[i] = c[m - 4 * N];
2948           }
2949           rfft->exec(y.data(), fct, true);
2950           {
                 // SGN(i) is -sqrt2 when bit 1 of i is set and +sqrt2 otherwise.
2951             auto SGN = [](size_t i) {
2952               constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L);
2953               return (i & 2) ? -sqrt2 : sqrt2;
2954             };
2955             c[n2] = y[0] * SGN(n2 + 1);
                 // Each iteration turns four consecutive FFT outputs
                 // y[2k - 1] .. y[2k + 2] into four result entries.
2956             size_t i = 0, i1 = 1, k = 1;
2957             for (; k < n2; ++i, ++i1, k += 2) {
2958               c[i] = y[2 * k - 1] * SGN(i1) + y[2 * k] * SGN(i);
2959               c[N - i1] = y[2 * k - 1] * SGN(N - i) - y[2 * k] * SGN(N - i1);
2960               c[n2 - i1] =
2961                   y[2 * k + 1] * SGN(n2 - i) - y[2 * k + 2] * SGN(n2 - i1);
2962               c[n2 + i1] =
2963                   y[2 * k + 1] * SGN(n2 + i + 2) + y[2 * k + 2] * SGN(n2 + i1);
2964             }
                 // One pair of FFT outputs is left when the loop stops exactly
                 // at k == n2.
2965             if (k == n2) {
2966               c[i] = y[2 * k - 1] * SGN(i + 1) + y[2 * k] * SGN(i);
2967               c[N - i1] = y[2 * k - 1] * SGN(i + 2) + y[2 * k] * SGN(i1);
2968             }
2969           }
2970 
2971           // FFTW-derived code ends here
2972         } else {
2973           // even length algorithm from
2974           // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
               // Pack (c[2i], c[N - 1 - 2i]) into complex values and
               // pre-multiply them by C2.
2975           arr<cmplx<T>> y(n2);
2976           for (size_t i = 0; i < n2; ++i) {
2977             y[i].Set(c[2 * i], c[N - 1 - 2 * i]);
2978             y[i] *= C2[i];
2979           }
2980           fft->exec(y.data(), fct, true);
               // Multiply by C2 again: even outputs take a real part from
               // entry i, odd outputs a negated imaginary part from the
               // mirrored entry ic.
2981           for (size_t i = 0, ic = n2 - 1; i < n2; ++i, --ic) {
2982             c[2 * i] = 2 * (y[i].r * C2[i].r - y[i].i * C2[i].i);
2983             c[2 * i + 1] = -2 * (y[ic].i * C2[ic].r + y[ic].r * C2[ic].i);
2984           }
2985         }
2986         if (!cosine)
2987           for (size_t k = 1; k < N; k += 2)
2988             c[k] = -c[k];
2989       }
2990 
           // Number of values handled by exec().
length() const2991       size_t length() const
2992       {
2993         return N;
2994       }
2995     };
2996 
2997     //
2998     // multi-D infrastructure
2999     //
3000 
// Return a plan of type T for the given length.
//
// With POCKETFFT_CACHE_SIZE == 0 a fresh plan is created on every call.
// Otherwise up to POCKETFFT_CACHE_SIZE plans per plan type are kept in a
// mutex-protected cache; when a new plan has to be stored, it replaces the
// slot with the smallest last-access stamp.
//
// T must be constructible from a size_t and, when caching is enabled,
// provide length().
template <typename T>
std::shared_ptr<T> get_plan(size_t length)
{
#if POCKETFFT_CACHE_SIZE == 0
  return std::make_shared<T>(length);
#else
  constexpr size_t nmax = POCKETFFT_CACHE_SIZE;
  static std::array<std::shared_ptr<T>, nmax> cache;
  static std::array<size_t, nmax> last_access{{0}};
  static size_t access_counter = 0;
  static std::mutex mut;

  // Stamp a slot as most recently used. If the counter wraps around to 0,
  // all stamps are reset so that stale large values cannot distort the
  // ordering. Used by both the lookup and the insertion path; the caller
  // must hold `mut`.
  auto touch = [&](size_t slot) {
    last_access[slot] = ++access_counter;
    // Guard against overflow
    if (access_counter == 0)
      last_access.fill(0);
  };

  // Look for a cached plan of the requested length; the caller must hold
  // `mut`.
  auto find_in_cache = [&]() -> std::shared_ptr<T> {
    for (size_t i = 0; i < nmax; ++i)
      if (cache[i] && (cache[i]->length() == length)) {
        // no need to update if this is already the most recent entry
        if (last_access[i] != access_counter)
          touch(i);
        return cache[i];
      }

    return nullptr;
  };

  {
    std::lock_guard<std::mutex> lock(mut);
    auto p = find_in_cache();
    if (p)
      return p;
  }

  // Build the plan without holding the lock; construction can be slow.
  auto plan = std::make_shared<T>(length);
  {
    std::lock_guard<std::mutex> lock(mut);
    // Another thread may have stored the same plan in the meantime.
    auto p = find_in_cache();
    if (p)
      return p;

    size_t lru = 0;
    for (size_t i = 1; i < nmax; ++i)
      if (last_access[i] < last_access[lru])
        lru = i;

    cache[lru] = plan;
    touch(lru);
  }
  return plan;
#endif
}
3053 
3054     class arr_info
3055     {
3056     protected:
3057       shape_t shp;
3058       stride_t str;
3059 
3060     public:
arr_info(const shape_t & shape_,const stride_t & stride_)3061       arr_info(const shape_t &shape_, const stride_t &stride_)
3062           : shp(shape_), str(stride_)
3063       {
3064       }
ndim() const3065       size_t ndim() const
3066       {
3067         return shp.size();
3068       }
size() const3069       size_t size() const
3070       {
3071         return util::prod(shp);
3072       }
shape() const3073       const shape_t &shape() const
3074       {
3075         return shp;
3076       }
shape(size_t i) const3077       size_t shape(size_t i) const
3078       {
3079         return shp[i];
3080       }
stride() const3081       const stride_t &stride() const
3082       {
3083         return str;
3084       }
stride(size_t i) const3085       const ptrdiff_t &stride(size_t i) const
3086       {
3087         return str[i];
3088       }
3089     };
3090 
3091     template <typename T>
3092     class cndarr : public arr_info
3093     {
3094     protected:
3095       const char *d;
3096 
3097     public:
cndarr(const void * data_,const shape_t & shape_,const stride_t & stride_)3098       cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)
3099           : arr_info(shape_, stride_), d(reinterpret_cast<const char *>(data_))
3100       {
3101       }
operator [](ptrdiff_t ofs) const3102       const T &operator[](ptrdiff_t ofs) const
3103       {
3104         return *reinterpret_cast<const T *>(d + ofs);
3105       }
3106     };
3107 
3108     template <typename T>
3109     class ndarr : public cndarr<T>
3110     {
3111     public:
ndarr(void * data_,const shape_t & shape_,const stride_t & stride_)3112       ndarr(void *data_, const shape_t &shape_, const stride_t &stride_)
3113           : cndarr<T>::cndarr(const_cast<const void *>(data_), shape_, stride_)
3114       {
3115       }
operator [](ptrdiff_t ofs)3116       T &operator[](ptrdiff_t ofs)
3117       {
3118         return *reinterpret_cast<T *>(const_cast<char *>(cndarr<T>::d + ofs));
3119       }
3120     };
3121 
3122     template <size_t N>
3123     class multi_iter
3124     {
3125     private:
3126       shape_t pos;
3127       const arr_info &iarr, &oarr;
3128       ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o;
3129       size_t idim, rem;
3130 
advance_i()3131       void advance_i()
3132       {
3133         for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) {
3134           auto i = size_t(i_);
3135           if (i == idim)
3136             continue;
3137           p_ii += iarr.stride(i);
3138           p_oi += oarr.stride(i);
3139           if (++pos[i] < iarr.shape(i))
3140             return;
3141           pos[i] = 0;
3142           p_ii -= ptrdiff_t(iarr.shape(i)) * iarr.stride(i);
3143           p_oi -= ptrdiff_t(oarr.shape(i)) * oarr.stride(i);
3144         }
3145       }
3146 
3147     public:
multi_iter(const arr_info & iarr_,const arr_info & oarr_,size_t idim_)3148       multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_)
3149           : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0),
3150             str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),
3151             idim(idim_), rem(iarr.size() / iarr.shape(idim))
3152       {
3153         auto nshares = threading::num_threads();
3154         if (nshares == 1)
3155           return;
3156         if (nshares == 0)
3157           throw std::runtime_error("can't run with zero threads");
3158         auto myshare = threading::thread_id();
3159         if (myshare >= nshares)
3160           throw std::runtime_error("impossible share requested");
3161         size_t nbase = rem / nshares;
3162         size_t additional = rem % nshares;
3163         size_t lo =
3164             myshare * nbase + ((myshare < additional) ? myshare : additional);
3165         size_t hi = lo + nbase + (myshare < additional);
3166         size_t todo = hi - lo;
3167 
3168         size_t chunk = rem;
3169         for (size_t i = 0; i < pos.size(); ++i) {
3170           if (i == idim)
3171             continue;
3172           chunk /= iarr.shape(i);
3173           size_t n_advance = lo / chunk;
3174           pos[i] += n_advance;
3175           p_ii += ptrdiff_t(n_advance) * iarr.stride(i);
3176           p_oi += ptrdiff_t(n_advance) * oarr.stride(i);
3177           lo -= n_advance * chunk;
3178         }
3179         rem = todo;
3180       }
advance(size_t n)3181       void advance(size_t n)
3182       {
3183         if (rem < n)
3184           throw std::runtime_error("underrun");
3185         for (size_t i = 0; i < n; ++i) {
3186           p_i[i] = p_ii;
3187           p_o[i] = p_oi;
3188           advance_i();
3189         }
3190         rem -= n;
3191       }
iofs(size_t i) const3192       ptrdiff_t iofs(size_t i) const
3193       {
3194         return p_i[0] + ptrdiff_t(i) * str_i;
3195       }
iofs(size_t j,size_t i) const3196       ptrdiff_t iofs(size_t j, size_t i) const
3197       {
3198         return p_i[j] + ptrdiff_t(i) * str_i;
3199       }
oofs(size_t i) const3200       ptrdiff_t oofs(size_t i) const
3201       {
3202         return p_o[0] + ptrdiff_t(i) * str_o;
3203       }
oofs(size_t j,size_t i) const3204       ptrdiff_t oofs(size_t j, size_t i) const
3205       {
3206         return p_o[j] + ptrdiff_t(i) * str_o;
3207       }
length_in() const3208       size_t length_in() const
3209       {
3210         return iarr.shape(idim);
3211       }
length_out() const3212       size_t length_out() const
3213       {
3214         return oarr.shape(idim);
3215       }
stride_in() const3216       ptrdiff_t stride_in() const
3217       {
3218         return str_i;
3219       }
stride_out() const3220       ptrdiff_t stride_out() const
3221       {
3222         return str_o;
3223       }
remaining() const3224       size_t remaining() const
3225       {
3226         return rem;
3227       }
3228     };
3229 
3230     class simple_iter
3231     {
3232     private:
3233       shape_t pos;
3234       const arr_info &arr;
3235       ptrdiff_t p;
3236       size_t rem;
3237 
3238     public:
simple_iter(const arr_info & arr_)3239       simple_iter(const arr_info &arr_)
3240           : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size())
3241       {
3242       }
advance()3243       void advance()
3244       {
3245         --rem;
3246         for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) {
3247           auto i = size_t(i_);
3248           p += arr.stride(i);
3249           if (++pos[i] < arr.shape(i))
3250             return;
3251           pos[i] = 0;
3252           p -= ptrdiff_t(arr.shape(i)) * arr.stride(i);
3253         }
3254       }
ofs() const3255       ptrdiff_t ofs() const
3256       {
3257         return p;
3258       }
remaining() const3259       size_t remaining() const
3260       {
3261         return rem;
3262       }
3263     };
3264 
3265     class rev_iter
3266     {
3267     private:
3268       shape_t pos;
3269       const arr_info &arr;
3270       std::vector<char> rev_axis;
3271       std::vector<char> rev_jump;
3272       size_t last_axis, last_size;
3273       shape_t shp;
3274       ptrdiff_t p, rp;
3275       size_t rem;
3276 
3277     public:
rev_iter(const arr_info & arr_,const shape_t & axes)3278       rev_iter(const arr_info &arr_, const shape_t &axes)
3279           : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0),
3280             rev_jump(arr_.ndim(), 1), p(0), rp(0)
3281       {
3282         for (auto ax : axes)
3283           rev_axis[ax] = 1;
3284         last_axis = axes.back();
3285         last_size = arr.shape(last_axis) / 2 + 1;
3286         shp = arr.shape();
3287         shp[last_axis] = last_size;
3288         rem = 1;
3289         for (auto i : shp)
3290           rem *= i;
3291       }
advance()3292       void advance()
3293       {
3294         --rem;
3295         for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) {
3296           auto i = size_t(i_);
3297           p += arr.stride(i);
3298           if (!rev_axis[i])
3299             rp += arr.stride(i);
3300           else {
3301             rp -= arr.stride(i);
3302             if (rev_jump[i]) {
3303               rp += ptrdiff_t(arr.shape(i)) * arr.stride(i);
3304               rev_jump[i] = 0;
3305             }
3306           }
3307           if (++pos[i] < shp[i])
3308             return;
3309           pos[i] = 0;
3310           p -= ptrdiff_t(shp[i]) * arr.stride(i);
3311           if (rev_axis[i]) {
3312             rp -= ptrdiff_t(arr.shape(i) - shp[i]) * arr.stride(i);
3313             rev_jump[i] = 1;
3314           } else
3315             rp -= ptrdiff_t(shp[i]) * arr.stride(i);
3316         }
3317       }
ofs() const3318       ptrdiff_t ofs() const
3319       {
3320         return p;
3321       }
rev_ofs() const3322       ptrdiff_t rev_ofs() const
3323       {
3324         return rp;
3325       }
remaining() const3326       size_t remaining() const
3327       {
3328         return rem;
3329       }
3330     };
3331 
3332     template <typename T>
3333     struct VTYPE {
3334     };
3335     template <typename T>
3336     using vtype_t = typename VTYPE<T>::type;
3337 
3338 #ifndef POCKETFFT_NO_VECTORS
3339     template <>
3340     struct VTYPE<float> {
3341       using type = float
3342           __attribute__((vector_size(VLEN<float>::val * sizeof(float))));
3343     };
3344     template <>
3345     struct VTYPE<double> {
3346       using type = double
3347           __attribute__((vector_size(VLEN<double>::val * sizeof(double))));
3348     };
3349     template <>
3350     struct VTYPE<long double> {
3351       using type = long double
3352           __attribute__((vector_size(VLEN<long double>::val *
3353                                      sizeof(long double))));
3354     };
3355 #endif
3356 
3357     template <typename T>
alloc_tmp(const shape_t & shape,size_t axsize,size_t elemsize)3358     arr<char> alloc_tmp(const shape_t &shape, size_t axsize, size_t elemsize)
3359     {
3360       auto othersize = util::prod(shape) / axsize;
3361       auto tmpsize = axsize * ((othersize >= VLEN<T>::val) ? VLEN<T>::val : 1);
3362       return arr<char>(tmpsize * elemsize);
3363     }
3364     template <typename T>
alloc_tmp(const shape_t & shape,const shape_t & axes,size_t elemsize)3365     arr<char> alloc_tmp(const shape_t &shape, const shape_t &axes,
3366                         size_t elemsize)
3367     {
3368       size_t fullsize = util::prod(shape);
3369       size_t tmpsize = 0;
3370       for (size_t i = 0; i < axes.size(); ++i) {
3371         auto axsize = shape[axes[i]];
3372         auto othersize = fullsize / axsize;
3373         auto sz = axsize * ((othersize >= VLEN<T>::val) ? VLEN<T>::val : 1);
3374         if (sz > tmpsize)
3375           tmpsize = sz;
3376       }
3377       return arr<char>(tmpsize * elemsize);
3378     }
3379 
3380     template <typename T, size_t vlen>
copy_input(const multi_iter<vlen> & it,const cndarr<cmplx<T>> & src,cmplx<vtype_t<T>> * POCKETFFT_RESTRICT dst)3381     void copy_input(const multi_iter<vlen> &it, const cndarr<cmplx<T>> &src,
3382                     cmplx<vtype_t<T>> *POCKETFFT_RESTRICT dst)
3383     {
3384       for (size_t i = 0; i < it.length_in(); ++i)
3385         for (size_t j = 0; j < vlen; ++j) {
3386           dst[i].r[j] = src[it.iofs(j, i)].r;
3387           dst[i].i[j] = src[it.iofs(j, i)].i;
3388         }
3389     }
3390 
3391     template <typename T, size_t vlen>
copy_input(const multi_iter<vlen> & it,const cndarr<T> & src,vtype_t<T> * POCKETFFT_RESTRICT dst)3392     void copy_input(const multi_iter<vlen> &it, const cndarr<T> &src,
3393                     vtype_t<T> *POCKETFFT_RESTRICT dst)
3394     {
3395       for (size_t i = 0; i < it.length_in(); ++i)
3396         for (size_t j = 0; j < vlen; ++j)
3397           dst[i][j] = src[it.iofs(j, i)];
3398     }
3399 
3400     template <typename T, size_t vlen>
copy_input(const multi_iter<vlen> & it,const cndarr<T> & src,T * POCKETFFT_RESTRICT dst)3401     void copy_input(const multi_iter<vlen> &it, const cndarr<T> &src,
3402                     T *POCKETFFT_RESTRICT dst)
3403     {
3404       if (dst == &src[it.iofs(0)])
3405         return; // in-place
3406       for (size_t i = 0; i < it.length_in(); ++i)
3407         dst[i] = src[it.iofs(i)];
3408     }
3409 
3410     template <typename T, size_t vlen>
copy_output(const multi_iter<vlen> & it,const cmplx<vtype_t<T>> * POCKETFFT_RESTRICT src,ndarr<cmplx<T>> & dst)3411     void copy_output(const multi_iter<vlen> &it,
3412                      const cmplx<vtype_t<T>> *POCKETFFT_RESTRICT src,
3413                      ndarr<cmplx<T>> &dst)
3414     {
3415       for (size_t i = 0; i < it.length_out(); ++i)
3416         for (size_t j = 0; j < vlen; ++j)
3417           dst[it.oofs(j, i)].Set(src[i].r[j], src[i].i[j]);
3418     }
3419 
3420     template <typename T, size_t vlen>
copy_output(const multi_iter<vlen> & it,const vtype_t<T> * POCKETFFT_RESTRICT src,ndarr<T> & dst)3421     void copy_output(const multi_iter<vlen> &it,
3422                      const vtype_t<T> *POCKETFFT_RESTRICT src, ndarr<T> &dst)
3423     {
3424       for (size_t i = 0; i < it.length_out(); ++i)
3425         for (size_t j = 0; j < vlen; ++j)
3426           dst[it.oofs(j, i)] = src[i][j];
3427     }
3428 
3429     template <typename T, size_t vlen>
copy_output(const multi_iter<vlen> & it,const T * POCKETFFT_RESTRICT src,ndarr<T> & dst)3430     void copy_output(const multi_iter<vlen> &it,
3431                      const T *POCKETFFT_RESTRICT src, ndarr<T> &dst)
3432     {
3433       if (src == &dst[it.oofs(0)])
3434         return; // in-place
3435       for (size_t i = 0; i < it.length_out(); ++i)
3436         dst[it.oofs(i)] = src[i];
3437     }
3438 
3439     template <typename T>
3440     struct add_vec {
3441       using type = vtype_t<T>;
3442     };
3443     template <typename T>
3444     struct add_vec<cmplx<T>> {
3445       using type = cmplx<vtype_t<T>>;
3446     };
3447     template <typename T>
3448     using add_vec_t = typename add_vec<T>::type;
3449 
3450     template <typename Tplan, typename T, typename T0, typename Exec>
general_nd(const cndarr<T> & in,ndarr<T> & out,const shape_t & axes,T0 fct,size_t nthreads,const Exec & exec,const bool allow_inplace=true)3451     POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
3452                                        const shape_t &axes, T0 fct,
3453                                        size_t nthreads, const Exec &exec,
3454                                        const bool allow_inplace = true)
3455     {
3456       std::shared_ptr<Tplan> plan;
3457 
3458       for (size_t iax = 0; iax < axes.size(); ++iax) {
3459         size_t len = in.shape(axes[iax]);
3460         if ((!plan) || (len != plan->length()))
3461           plan = get_plan<Tplan>(len);
3462 
3463         threading::thread_map(
3464             util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),
3465             [&] {
3466               constexpr auto vlen = VLEN<T0>::val;
3467               auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
3468               const auto &tin(iax == 0 ? in : out);
3469               multi_iter<vlen> it(tin, out, axes[iax]);
3470 #ifndef POCKETFFT_NO_VECTORS
3471               if (vlen > 1)
3472                 while (it.remaining() >= vlen) {
3473                   it.advance(vlen);
3474                   auto tdatav =
3475                       reinterpret_cast<add_vec_t<T> *>(storage.data());
3476                   exec(it, tin, out, tdatav, *plan, fct);
3477                 }
3478 #endif
3479               while (it.remaining() > 0) {
3480                 it.advance(1);
3481                 auto buf = allow_inplace && it.stride_out() == sizeof(T)
3482                                ? &out[it.oofs(0)]
3483                                : reinterpret_cast<T *>(storage.data());
3484                 exec(it, tin, out, buf, *plan, fct);
3485               }
3486             });      // end of parallel region
3487         fct = T0(1); // factor has been applied, use 1 for remaining axes
3488       }
3489     }
3490 
3491     struct ExecC2C {
3492       bool forward;
3493 
3494       template <typename T0, typename T, size_t vlen>
operator ()pocketfft::detail::ExecC2C3495       void operator()(const multi_iter<vlen> &it, const cndarr<cmplx<T0>> &in,
3496                       ndarr<cmplx<T0>> &out, T *buf,
3497                       const pocketfft_c<T0> &plan, T0 fct) const
3498       {
3499         copy_input(it, in, buf);
3500         plan.exec(buf, fct, forward);
3501         copy_output(it, buf, out);
3502       }
3503     };
3504 
3505     template <typename T, size_t vlen>
copy_hartley(const multi_iter<vlen> & it,const vtype_t<T> * POCKETFFT_RESTRICT src,ndarr<T> & dst)3506     void copy_hartley(const multi_iter<vlen> &it,
3507                       const vtype_t<T> *POCKETFFT_RESTRICT src, ndarr<T> &dst)
3508     {
3509       for (size_t j = 0; j < vlen; ++j)
3510         dst[it.oofs(j, 0)] = src[0][j];
3511       size_t i = 1, i1 = 1, i2 = it.length_out() - 1;
3512       for (i = 1; i < it.length_out() - 1; i += 2, ++i1, --i2)
3513         for (size_t j = 0; j < vlen; ++j) {
3514           dst[it.oofs(j, i1)] = src[i][j] + src[i + 1][j];
3515           dst[it.oofs(j, i2)] = src[i][j] - src[i + 1][j];
3516         }
3517       if (i < it.length_out())
3518         for (size_t j = 0; j < vlen; ++j)
3519           dst[it.oofs(j, i1)] = src[i][j];
3520     }
3521 
3522     template <typename T, size_t vlen>
copy_hartley(const multi_iter<vlen> & it,const T * POCKETFFT_RESTRICT src,ndarr<T> & dst)3523     void copy_hartley(const multi_iter<vlen> &it,
3524                       const T *POCKETFFT_RESTRICT src, ndarr<T> &dst)
3525     {
3526       dst[it.oofs(0)] = src[0];
3527       size_t i = 1, i1 = 1, i2 = it.length_out() - 1;
3528       for (i = 1; i < it.length_out() - 1; i += 2, ++i1, --i2) {
3529         dst[it.oofs(i1)] = src[i] + src[i + 1];
3530         dst[it.oofs(i2)] = src[i] - src[i + 1];
3531       }
3532       if (i < it.length_out())
3533         dst[it.oofs(i1)] = src[i];
3534     }
3535 
3536     struct ExecHartley {
3537       template <typename T0, typename T, size_t vlen>
operator ()pocketfft::detail::ExecHartley3538       void operator()(const multi_iter<vlen> &it, const cndarr<T0> &in,
3539                       ndarr<T0> &out, T *buf, const pocketfft_r<T0> &plan,
3540                       T0 fct) const
3541       {
3542         copy_input(it, in, buf);
3543         plan.exec(buf, fct, true);
3544         copy_hartley(it, buf, out);
3545       }
3546     };
3547 
3548     struct ExecDcst {
3549       bool ortho;
3550       int type;
3551       bool cosine;
3552 
3553       template <typename T0, typename T, typename Tplan, size_t vlen>
operator ()pocketfft::detail::ExecDcst3554       void operator()(const multi_iter<vlen> &it, const cndarr<T0> &in,
3555                       ndarr<T0> &out, T *buf, const Tplan &plan, T0 fct) const
3556       {
3557         copy_input(it, in, buf);
3558         plan.exec(buf, fct, ortho, type, cosine);
3559         copy_output(it, buf, out);
3560       }
3561     };
3562 
3563     template <typename T>
general_r2c(const cndarr<T> & in,ndarr<cmplx<T>> & out,size_t axis,bool forward,T fct,size_t nthreads)3564     POCKETFFT_NOINLINE void general_r2c(const cndarr<T> &in,
3565                                         ndarr<cmplx<T>> &out, size_t axis,
3566                                         bool forward, T fct, size_t nthreads)
3567     {
3568       auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
3569       size_t len = in.shape(axis);
3570       threading::thread_map(
3571           util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val), [&] {
3572             constexpr auto vlen = VLEN<T>::val;
3573             auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
3574             multi_iter<vlen> it(in, out, axis);
3575 #ifndef POCKETFFT_NO_VECTORS
3576             if (vlen > 1)
3577               while (it.remaining() >= vlen) {
3578                 it.advance(vlen);
3579                 auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());
3580                 copy_input(it, in, tdatav);
3581                 plan->exec(tdatav, fct, true);
3582                 for (size_t j = 0; j < vlen; ++j)
3583                   out[it.oofs(j, 0)].Set(tdatav[0][j]);
3584                 size_t i = 1, ii = 1;
3585                 if (forward)
3586                   for (; i < len - 1; i += 2, ++ii)
3587                     for (size_t j = 0; j < vlen; ++j)
3588                       out[it.oofs(j, ii)].Set(tdatav[i][j], tdatav[i + 1][j]);
3589                 else
3590                   for (; i < len - 1; i += 2, ++ii)
3591                     for (size_t j = 0; j < vlen; ++j)
3592                       out[it.oofs(j, ii)].Set(tdatav[i][j], -tdatav[i + 1][j]);
3593                 if (i < len)
3594                   for (size_t j = 0; j < vlen; ++j)
3595                     out[it.oofs(j, ii)].Set(tdatav[i][j]);
3596               }
3597 #endif
3598             while (it.remaining() > 0) {
3599               it.advance(1);
3600               auto tdata = reinterpret_cast<T *>(storage.data());
3601               copy_input(it, in, tdata);
3602               plan->exec(tdata, fct, true);
3603               out[it.oofs(0)].Set(tdata[0]);
3604               size_t i = 1, ii = 1;
3605               if (forward)
3606                 for (; i < len - 1; i += 2, ++ii)
3607                   out[it.oofs(ii)].Set(tdata[i], tdata[i + 1]);
3608               else
3609                 for (; i < len - 1; i += 2, ++ii)
3610                   out[it.oofs(ii)].Set(tdata[i], -tdata[i + 1]);
3611               if (i < len)
3612                 out[it.oofs(ii)].Set(tdata[i]);
3613             }
3614           }); // end of parallel region
3615     }
3616     template <typename T>
general_c2r(const cndarr<cmplx<T>> & in,ndarr<T> & out,size_t axis,bool forward,T fct,size_t nthreads)3617     POCKETFFT_NOINLINE void general_c2r(const cndarr<cmplx<T>> &in,
3618                                         ndarr<T> &out, size_t axis,
3619                                         bool forward, T fct, size_t nthreads)
3620     {
3621       auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
3622       size_t len = out.shape(axis);
3623       threading::thread_map(
3624           util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val), [&] {
3625             constexpr auto vlen = VLEN<T>::val;
3626             auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
3627             multi_iter<vlen> it(in, out, axis);
3628 #ifndef POCKETFFT_NO_VECTORS
3629             if (vlen > 1)
3630               while (it.remaining() >= vlen) {
3631                 it.advance(vlen);
3632                 auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());
3633                 for (size_t j = 0; j < vlen; ++j)
3634                   tdatav[0][j] = in[it.iofs(j, 0)].r;
3635                 {
3636                   size_t i = 1, ii = 1;
3637                   if (forward)
3638                     for (; i < len - 1; i += 2, ++ii)
3639                       for (size_t j = 0; j < vlen; ++j) {
3640                         tdatav[i][j] = in[it.iofs(j, ii)].r;
3641                         tdatav[i + 1][j] = -in[it.iofs(j, ii)].i;
3642                       }
3643                   else
3644                     for (; i < len - 1; i += 2, ++ii)
3645                       for (size_t j = 0; j < vlen; ++j) {
3646                         tdatav[i][j] = in[it.iofs(j, ii)].r;
3647                         tdatav[i + 1][j] = in[it.iofs(j, ii)].i;
3648                       }
3649                   if (i < len)
3650                     for (size_t j = 0; j < vlen; ++j)
3651                       tdatav[i][j] = in[it.iofs(j, ii)].r;
3652                 }
3653                 plan->exec(tdatav, fct, false);
3654                 copy_output(it, tdatav, out);
3655               }
3656 #endif
3657             while (it.remaining() > 0) {
3658               it.advance(1);
3659               auto tdata = reinterpret_cast<T *>(storage.data());
3660               tdata[0] = in[it.iofs(0)].r;
3661               {
3662                 size_t i = 1, ii = 1;
3663                 if (forward)
3664                   for (; i < len - 1; i += 2, ++ii) {
3665                     tdata[i] = in[it.iofs(ii)].r;
3666                     tdata[i + 1] = -in[it.iofs(ii)].i;
3667                   }
3668                 else
3669                   for (; i < len - 1; i += 2, ++ii) {
3670                     tdata[i] = in[it.iofs(ii)].r;
3671                     tdata[i + 1] = in[it.iofs(ii)].i;
3672                   }
3673                 if (i < len)
3674                   tdata[i] = in[it.iofs(ii)].r;
3675               }
3676               plan->exec(tdata, fct, false);
3677               copy_output(it, tdata, out);
3678             }
3679           }); // end of parallel region
3680     }
3681 
3682     struct ExecR2R {
3683       bool r2c, forward;
3684 
3685       template <typename T0, typename T, size_t vlen>
operator ()pocketfft::detail::ExecR2R3686       void operator()(const multi_iter<vlen> &it, const cndarr<T0> &in,
3687                       ndarr<T0> &out, T *buf, const pocketfft_r<T0> &plan,
3688                       T0 fct) const
3689       {
3690         copy_input(it, in, buf);
3691         if ((!r2c) && forward)
3692           for (size_t i = 2; i < it.length_out(); i += 2)
3693             buf[i] = -buf[i];
3694         plan.exec(buf, fct, forward);
3695         if (r2c && (!forward))
3696           for (size_t i = 2; i < it.length_out(); i += 2)
3697             buf[i] = -buf[i];
3698         copy_output(it, buf, out);
3699       }
3700     };
3701 
3702     template <typename T>
c2c(const shape_t & shape,const stride_t & stride_in,const stride_t & stride_out,const shape_t & axes,bool forward,const std::complex<T> * data_in,std::complex<T> * data_out,T fct,size_t nthreads=1)3703     void c2c(const shape_t &shape, const stride_t &stride_in,
3704              const stride_t &stride_out, const shape_t &axes, bool forward,
3705              const std::complex<T> *data_in, std::complex<T> *data_out, T fct,
3706              size_t nthreads = 1)
3707     {
3708       if (util::prod(shape) == 0)
3709         return;
3710       util::sanity_check(shape, stride_in, stride_out, data_in == data_out,
3711                          axes);
3712       cndarr<cmplx<T>> ain(data_in, shape, stride_in);
3713       ndarr<cmplx<T>> aout(data_out, shape, stride_out);
3714       general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads,
3715                                  ExecC2C{forward});
3716     }
3717 
3718     template <typename T>
dct(const shape_t & shape,const stride_t & stride_in,const stride_t & stride_out,const shape_t & axes,int type,const T * data_in,T * data_out,T fct,bool ortho,size_t nthreads=1)3719     void dct(const shape_t &shape, const stride_t &stride_in,
3720              const stride_t &stride_out, const shape_t &axes, int type,
3721              const T *data_in, T *data_out, T fct, bool ortho,
3722              size_t nthreads = 1)
3723     {
3724       if ((type < 1) || (type > 4))
3725         throw std::invalid_argument("invalid DCT type");
3726       if (util::prod(shape) == 0)
3727         return;
3728       util::sanity_check(shape, stride_in, stride_out, data_in == data_out,
3729                          axes);
3730       cndarr<T> ain(data_in, shape, stride_in);
3731       ndarr<T> aout(data_out, shape, stride_out);
3732       const ExecDcst exec{ortho, type, true};
3733       if (type == 1)
3734         general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec);
3735       else if (type == 4)
3736         general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);
3737       else
3738         general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);
3739     }
3740 
3741     template <typename T>
dst(const shape_t & shape,const stride_t & stride_in,const stride_t & stride_out,const shape_t & axes,int type,const T * data_in,T * data_out,T fct,bool ortho,size_t nthreads=1)3742     void dst(const shape_t &shape, const stride_t &stride_in,
3743              const stride_t &stride_out, const shape_t &axes, int type,
3744              const T *data_in, T *data_out, T fct, bool ortho,
3745              size_t nthreads = 1)
3746     {
3747       if ((type < 1) || (type > 4))
3748         throw std::invalid_argument("invalid DST type");
3749       if (util::prod(shape) == 0)
3750         return;
3751       util::sanity_check(shape, stride_in, stride_out, data_in == data_out,
3752                          axes);
3753       cndarr<T> ain(data_in, shape, stride_in);
3754       ndarr<T> aout(data_out, shape, stride_out);
3755       const ExecDcst exec{ortho, type, false};
3756       if (type == 1)
3757         general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec);
3758       else if (type == 4)
3759         general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);
3760       else
3761         general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);
3762     }
3763 
3764     template <typename T>
r2c(const shape_t & shape_in,const stride_t & stride_in,const stride_t & stride_out,size_t axis,bool forward,const T * data_in,std::complex<T> * data_out,T fct,size_t nthreads=1)3765     void r2c(const shape_t &shape_in, const stride_t &stride_in,
3766              const stride_t &stride_out, size_t axis, bool forward,
3767              const T *data_in, std::complex<T> *data_out, T fct,
3768              size_t nthreads = 1)
3769     {
3770       if (util::prod(shape_in) == 0)
3771         return;
3772       util::sanity_check(shape_in, stride_in, stride_out, false, axis);
3773       cndarr<T> ain(data_in, shape_in, stride_in);
3774       shape_t shape_out(shape_in);
3775       shape_out[axis] = shape_in[axis] / 2 + 1;
3776       ndarr<cmplx<T>> aout(data_out, shape_out, stride_out);
3777       general_r2c(ain, aout, axis, forward, fct, nthreads);
3778     }
3779 
3780     template <typename T>
r2c(const shape_t & shape_in,const stride_t & stride_in,const stride_t & stride_out,const shape_t & axes,bool forward,const T * data_in,std::complex<T> * data_out,T fct,size_t nthreads=1)3781     void r2c(const shape_t &shape_in, const stride_t &stride_in,
3782              const stride_t &stride_out, const shape_t &axes, bool forward,
3783              const T *data_in, std::complex<T> *data_out, T fct,
3784              size_t nthreads = 1)
3785     {
3786       if (util::prod(shape_in) == 0)
3787         return;
3788       util::sanity_check(shape_in, stride_in, stride_out, false, axes);
3789       r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in,
3790           data_out, fct, nthreads);
3791       if (axes.size() == 1)
3792         return;
3793 
3794       shape_t shape_out(shape_in);
3795       shape_out[axes.back()] = shape_in[axes.back()] / 2 + 1;
3796       auto newaxes = shape_t{axes.begin(), --axes.end()};
3797       c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out,
3798           data_out, T(1), nthreads);
3799     }
3800 
3801     template <typename T>
c2r(const shape_t & shape_out,const stride_t & stride_in,const stride_t & stride_out,size_t axis,bool forward,const std::complex<T> * data_in,T * data_out,T fct,size_t nthreads=1)3802     void c2r(const shape_t &shape_out, const stride_t &stride_in,
3803              const stride_t &stride_out, size_t axis, bool forward,
3804              const std::complex<T> *data_in, T *data_out, T fct,
3805              size_t nthreads = 1)
3806     {
3807       if (util::prod(shape_out) == 0)
3808         return;
3809       util::sanity_check(shape_out, stride_in, stride_out, false, axis);
3810       shape_t shape_in(shape_out);
3811       shape_in[axis] = shape_out[axis] / 2 + 1;
3812       cndarr<cmplx<T>> ain(data_in, shape_in, stride_in);
3813       ndarr<T> aout(data_out, shape_out, stride_out);
3814       general_c2r(ain, aout, axis, forward, fct, nthreads);
3815     }
3816 
3817     template <typename T>
c2r(const shape_t & shape_out,const stride_t & stride_in,const stride_t & stride_out,const shape_t & axes,bool forward,const std::complex<T> * data_in,T * data_out,T fct,size_t nthreads=1)3818     void c2r(const shape_t &shape_out, const stride_t &stride_in,
3819              const stride_t &stride_out, const shape_t &axes, bool forward,
3820              const std::complex<T> *data_in, T *data_out, T fct,
3821              size_t nthreads = 1)
3822     {
3823       if (util::prod(shape_out) == 0)
3824         return;
3825       if (axes.size() == 1)
3826         return c2r(shape_out, stride_in, stride_out, axes[0], forward, data_in,
3827                    data_out, fct, nthreads);
3828       util::sanity_check(shape_out, stride_in, stride_out, false, axes);
3829       auto shape_in = shape_out;
3830       shape_in[axes.back()] = shape_out[axes.back()] / 2 + 1;
3831       auto nval = util::prod(shape_in);
3832       stride_t stride_inter(shape_in.size());
3833       stride_inter.back() = sizeof(cmplx<T>);
3834       for (int i = int(shape_in.size()) - 2; i >= 0; --i)
3835         stride_inter[size_t(i)] =
3836             stride_inter[size_t(i + 1)] * ptrdiff_t(shape_in[size_t(i + 1)]);
3837       arr<std::complex<T>> tmp(nval);
3838       auto newaxes = shape_t{axes.begin(), --axes.end()};
3839       c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in,
3840           tmp.data(), T(1), nthreads);
3841       c2r(shape_out, stride_inter, stride_out, axes.back(), forward, tmp.data(),
3842           data_out, fct, nthreads);
3843     }
3844 
3845     template <typename T>
r2r_fftpack(const shape_t & shape,const stride_t & stride_in,const stride_t & stride_out,const shape_t & axes,bool real2hermitian,bool forward,const T * data_in,T * data_out,T fct,size_t nthreads=1)3846     void r2r_fftpack(const shape_t &shape, const stride_t &stride_in,
3847                      const stride_t &stride_out, const shape_t &axes,
3848                      bool real2hermitian, bool forward, const T *data_in,
3849                      T *data_out, T fct, size_t nthreads = 1)
3850     {
3851       if (util::prod(shape) == 0)
3852         return;
3853       util::sanity_check(shape, stride_in, stride_out, data_in == data_out,
3854                          axes);
3855       cndarr<T> ain(data_in, shape, stride_in);
3856       ndarr<T> aout(data_out, shape, stride_out);
3857       general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads,
3858                                  ExecR2R{real2hermitian, forward});
3859     }
3860 
3861     template <typename T>
r2r_separable_hartley(const shape_t & shape,const stride_t & stride_in,const stride_t & stride_out,const shape_t & axes,const T * data_in,T * data_out,T fct,size_t nthreads=1)3862     void r2r_separable_hartley(const shape_t &shape, const stride_t &stride_in,
3863                                const stride_t &stride_out, const shape_t &axes,
3864                                const T *data_in, T *data_out, T fct,
3865                                size_t nthreads = 1)
3866     {
3867       if (util::prod(shape) == 0)
3868         return;
3869       util::sanity_check(shape, stride_in, stride_out, data_in == data_out,
3870                          axes);
3871       cndarr<T> ain(data_in, shape, stride_in);
3872       ndarr<T> aout(data_out, shape, stride_out);
3873       general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{},
3874                                  false);
3875     }
3876 
3877     template <typename T>
r2r_genuine_hartley(const shape_t & shape,const stride_t & stride_in,const stride_t & stride_out,const shape_t & axes,const T * data_in,T * data_out,T fct,size_t nthreads=1)3878     void r2r_genuine_hartley(const shape_t &shape, const stride_t &stride_in,
3879                              const stride_t &stride_out, const shape_t &axes,
3880                              const T *data_in, T *data_out, T fct,
3881                              size_t nthreads = 1)
3882     {
3883       if (util::prod(shape) == 0)
3884         return;
3885       if (axes.size() == 1)
3886         return r2r_separable_hartley(shape, stride_in, stride_out, axes,
3887                                      data_in, data_out, fct, nthreads);
3888       util::sanity_check(shape, stride_in, stride_out, data_in == data_out,
3889                          axes);
3890       shape_t tshp(shape);
3891       tshp[axes.back()] = tshp[axes.back()] / 2 + 1;
3892       arr<std::complex<T>> tdata(util::prod(tshp));
3893       stride_t tstride(shape.size());
3894       tstride.back() = sizeof(std::complex<T>);
3895       for (size_t i = tstride.size() - 1; i > 0; --i)
3896         tstride[i - 1] = tstride[i] * ptrdiff_t(tshp[i]);
3897       r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct,
3898           nthreads);
3899       cndarr<cmplx<T>> atmp(tdata.data(), tshp, tstride);
3900       ndarr<T> aout(data_out, shape, stride_out);
3901       simple_iter iin(atmp);
3902       rev_iter iout(aout, axes);
3903       while (iin.remaining() > 0) {
3904         auto v = atmp[iin.ofs()];
3905         aout[iout.ofs()] = v.r + v.i;
3906         aout[iout.rev_ofs()] = v.r - v.i;
3907         iin.advance();
3908         iout.advance();
3909       }
3910     }
3911 
3912   } // namespace detail
3913 
3914   using detail::FORWARD;
3915   using detail::BACKWARD;
3916   using detail::shape_t;
3917   using detail::stride_t;
3918   using detail::c2c;
3919   using detail::c2r;
3920   using detail::r2c;
3921   using detail::r2r_fftpack;
3922   using detail::r2r_separable_hartley;
3923   using detail::r2r_genuine_hartley;
3924   using detail::dct;
3925   using detail::dst;
3926 
3927 } // namespace pocketfft
3928 
3929 #undef POCKETFFT_NOINLINE
3930 #undef POCKETFFT_RESTRICT
3931 
3932 #endif // POCKETFFT_HDRONLY_H
3933 #endif // PYTHONIC_INCLUDE_NUMPY_FFT_POCKETFFT_HPP
3934