1 /*
2  *  Copyright (c) 2003-2010, Mark Borgerding. All rights reserved.
3  *  This file is part of KISS FFT - https://github.com/mborgerding/kissfft
4  *
5  *  SPDX-License-Identifier: BSD-3-Clause
6  *  See COPYING file for more information.
7  */
8 
9 #ifndef KISSFFT_CLASS_HH
10 #define KISSFFT_CLASS_HH
#include <cmath>
#include <complex>
#include <cstddef>
#include <utility>
#include <vector>
14 
15 
/// Mixed-radix complex FFT of a fixed length.
///
/// The constructor precomputes a table of twiddle factors and splits the
/// transform length into a sequence of radices (4 first, then 2, then the
/// odd numbers 3, 5, 7, ...). @c transform() evaluates the DFT recursively
/// over those stages, using dedicated butterflies for the radices 2, 3, 4
/// and 5 and a generic O(p^2) butterfly for every other radix.
///
/// @tparam scalar_t real number type; the complex type used throughout is
///                  @c std::complex<scalar_t>.
template <typename scalar_t>
class kissfft
{
    public:

        using cpx_t = std::complex<scalar_t>;

        /// Prepares a transform of length @p nfft.
        ///
        /// @param nfft    number of complex samples per transform.
        /// @param inverse selects the sign of the exponent: @c false uses
        ///                exp(-2*pi*i*k/nfft), @c true the conjugate.
        ///                No 1/N scaling is applied in either direction.
        kissfft( const std::size_t nfft,
                 const bool inverse )
            :_nfft(nfft)
            ,_inverse(inverse)
        {
            // A zero-length transform needs neither twiddles nor stages;
            // returning here also avoids the division by zero below.
            if ( _nfft == 0 )
                return;

            // fill twiddle factors
            // The using-declaration makes the std overloads for every
            // floating-point type visible (an unqualified call could
            // otherwise pick the C double-only function) while still
            // permitting argument-dependent lookup.
            using std::acos;
            _twiddles.resize(_nfft);
            const scalar_t phinc = (_inverse ? 2 : -2) * acos( static_cast<scalar_t>(-1) ) / _nfft;
            for ( std::size_t i = 0; i < _nfft; ++i )
                _twiddles[i] = std::exp( cpx_t( 0, i*phinc ) );

            //factorize
            //start factoring out 4's, then 2's, then 3,5,7,9,...
            std::size_t n = _nfft;
            std::size_t p = 4;
            do {
                while ( n % p ) {
                    switch ( p ) {
                        case 4: p = 2; break;
                        case 2: p = 3; break;
                        default: p += 2; break;
                    }
                    if ( p*p > n )
                        p = n; // no more factors
                }
                n /= p;
                _stageRadix.push_back(p);      // radix of this stage
                _stageRemainder.push_back(n);  // length of each sub-transform
            } while ( n > 1 );
        }


        /// Changes the FFT-length and/or the transform direction.
        ///
        /// @post The @c kissfft object will be in the same state as if it
        /// had been newly constructed with the passed arguments.
        /// However, the implementation may be faster than constructing a
        /// new fft object.
        void assign( const std::size_t nfft,
                     const bool inverse )
        {
            if ( nfft != _nfft )
            {
                // Different length: rebuild all tables (O(n)) and move
                // the freshly built state into this object (O(1)).
                *this = kissfft( nfft, inverse );
            }
            else if ( inverse != _inverse )
            {
                // Same length, opposite direction: the twiddle factors of
                // the two directions are complex conjugates of each other.
                for ( cpx_t & tw : _twiddles )
                    tw = std::conj( tw );
                // The flag must follow the table: kf_bfly4() and
                // transform_real() read it, and a later assign() compares
                // against it.
                _inverse = inverse;
            }
        }

        /// Calculates the complex Discrete Fourier Transform.
        ///
        /// The size of the passed arrays must be passed in the constructor.
        /// The sum of the squares of the absolute values in the @c dst
        /// array will be @c N times the sum of the squares of the absolute
        /// values in the @c src array, where @c N is the size of the array.
        /// In other words, the l_2 norm of the resulting array will be
        /// @c sqrt(N) times as big as the l_2 norm of the input array.
        /// This is also the case when the inverse flag is set in the
        /// constructor. Hence when applying the same transform twice, but with
        /// the inverse flag changed the second time, then the result will
        /// be equal to the original input times @c N.
        ///
        /// @param fft_in    input samples. The output is written while the
        ///                  input is still being read and no temporary copy
        ///                  is made, so @p fft_in and @p fft_out are expected
        ///                  to be distinct buffers.
        /// @param fft_out   output buffer.
        /// @param stage     index of the factorization stage to process;
        ///                  the recursion increments it. Callers normally
        ///                  keep the default.
        /// @param fstride   stride into the twiddle table for this stage;
        ///                  the recursion multiplies it by the radix.
        /// @param in_stride additional multiplier applied to every step
        ///                  through @p fft_in.
        ///
        /// Does nothing when the object was built with length 0.
        void transform(const cpx_t * fft_in, cpx_t * fft_out, const std::size_t stage = 0, const std::size_t fstride = 1, const std::size_t in_stride = 1) const
        {
            // No stages exist for an empty transform.
            if ( _nfft == 0 )
                return;

            const std::size_t p = _stageRadix[stage];
            const std::size_t m = _stageRemainder[stage];
            cpx_t * const Fout_beg = fft_out;
            cpx_t * const Fout_end = fft_out + p*m;

            if ( m == 1 ) {
                // Last stage: gather the decimated input samples.
                do {
                    *fft_out = *fft_in;
                    fft_in += fstride*in_stride;
                } while ( ++fft_out != Fout_end );
            } else {
                do {
                    // recursive call:
                    // DFT of size m*p performed by doing
                    // p instances of smaller DFTs of size m,
                    // each one takes a decimated version of the input
                    transform( fft_in, fft_out, stage+1, fstride*p, in_stride );
                    fft_in += fstride*in_stride;
                } while ( (fft_out += m) != Fout_end );
            }

            fft_out = Fout_beg;

            // recombine the p smaller DFTs
            switch ( p ) {
                case 2: kf_bfly2( fft_out, fstride, m ); break;
                case 3: kf_bfly3( fft_out, fstride, m ); break;
                case 4: kf_bfly4( fft_out, fstride, m ); break;
                case 5: kf_bfly5( fft_out, fstride, m ); break;
                default: kf_bfly_generic( fft_out, fstride, m, p ); break;
            }
        }

        /// Calculates the Discrete Fourier Transform (DFT) of a real input
        /// of size @c 2*N.
        ///
        /// The 0-th and N-th value of the DFT are real numbers. These are
        /// stored in @c dst[0].real() and @c dst[0].imag() respectively.
        /// The remaining DFT values up to the index N-1 are stored in
        /// @c dst[1] to @c dst[N-1].
        /// The other half of the DFT values can be calculated from the
        /// symmetry relation
        ///     @code
        ///         DFT(src)[2*N-k] == conj( DFT(src)[k] );
        ///     @endcode
        /// The same scaling factors as in @c transform() apply.
        ///
        /// @note For this to work, the types @c scalar_t and @c cpx_t
        /// must fulfill the following requirements:
        ///
        /// For any object @c z of type @c cpx_t,
        /// @c reinterpret_cast<scalar_t(&)[2]>(z)[0] is the real part of @c z and
        /// @c reinterpret_cast<scalar_t(&)[2]>(z)[1] is the imaginary part of @c z.
        /// For any pointer to an element of an array of @c cpx_t named @c p
        /// and any valid array index @c i, @c reinterpret_cast<scalar_t*>(p)[2*i]
        /// is the real part of the complex number @c p[i], and
        /// @c reinterpret_cast<scalar_t*>(p)[2*i+1] is the imaginary part of the
        /// complex number @c p[i].
        ///
        /// Since C++11, these requirements are guaranteed to be satisfied for
        /// @c scalar_ts being @c float, @c double or @c long @c double
        /// together with @c cpx_t being @c std::complex<scalar_t>.
        ///
        /// @param src @c 2*N real input samples, where @c N is the length
        ///            passed to the constructor.
        /// @param dst @c N complex output values, laid out as described above.
        void transform_real( const scalar_t * const src,
                             cpx_t * const dst ) const
        {
            const std::size_t N = _nfft;
            if ( N == 0 )
                return;

            // perform complex FFT on the input viewed as N complex samples
            // (even-indexed reals become real parts, odd-indexed reals
            // become imaginary parts)
            transform( reinterpret_cast<const cpx_t*>(src), dst );

            // post processing for k = 0 and k = N
            dst[0] = cpx_t( dst[0].real() + dst[0].imag(),
                            dst[0].real() - dst[0].imag() );

            // post processing for all the other k = 1, 2, ..., N-1
            using std::acos;
            const scalar_t pi = acos( static_cast<scalar_t>(-1) );
            const scalar_t half_phi_inc = ( _inverse ? pi : -pi ) / N;
            // Half-step rotation: the table only holds the twiddles of the
            // length-N transform, the odd k of the length-2N one lie in
            // between two neighbouring entries.
            const cpx_t twiddle_mul = std::exp( cpx_t( 0, half_phi_inc ) );
            for ( std::size_t k = 1; 2*k < N; ++k )
            {
                // w and z are built from dst[k] and conj(dst[N-k]): their
                // half-sum is w, their half-difference divided by i is z.
                const cpx_t w = scalar_t(0.5) * cpx_t(
                     dst[k].real() + dst[N-k].real(),
                     dst[k].imag() - dst[N-k].imag() );
                const cpx_t z = scalar_t(0.5) * cpx_t(
                     dst[k].imag() + dst[N-k].imag(),
                    -dst[k].real() + dst[N-k].real() );
                const cpx_t twiddle =
                    k % 2 == 0 ?
                    _twiddles[k/2] :
                    _twiddles[k/2] * twiddle_mul;
                dst[  k] =            w + twiddle * z;
                dst[N-k] = std::conj( w - twiddle * z );
            }
            if ( N % 2 == 0 )
                dst[N/2] = std::conj( dst[N/2] );
        }

    private:

        /// Radix-2 butterfly: combines two interleaved length-@p m DFTs
        /// stored in @p Fout[0..m) and @p Fout[m..2m) in place.
        void kf_bfly2( cpx_t * Fout, const std::size_t fstride, const std::size_t m) const
        {
            for ( std::size_t k = 0; k < m; ++k ) {
                const cpx_t t = Fout[m+k] * _twiddles[k*fstride];
                Fout[m+k] = Fout[k] - t;
                Fout[k] += t;
            }
        }

        /// Radix-3 butterfly: combines three length-@p m DFTs stored
        /// back to back in @p Fout in place.
        void kf_bfly3( cpx_t * Fout, const std::size_t fstride, const std::size_t m) const
        {
            std::size_t k = m;
            const std::size_t m2 = 2*m;
            cpx_t scratch[5];
            // Table entry at index fstride*m; only its imaginary part is
            // used below.
            const cpx_t epi3 = _twiddles[fstride*m];

            // tw1 walks the table with step fstride, tw2 with twice that.
            const cpx_t * tw1 = &_twiddles[0];
            const cpx_t * tw2 = tw1;

            do {
                scratch[1] = Fout[m]  * *tw1;
                scratch[2] = Fout[m2] * *tw2;

                scratch[3] = scratch[1] + scratch[2];
                scratch[0] = scratch[1] - scratch[2];
                tw1 += fstride;
                tw2 += fstride*2;

                Fout[m] = Fout[0] - scratch[3]*scalar_t(0.5);
                scratch[0] *= epi3.imag();

                Fout[0] += scratch[3];

                Fout[m2] = cpx_t( Fout[m].real() + scratch[0].imag(),
                                  Fout[m].imag() - scratch[0].real() );

                Fout[m] += cpx_t( -scratch[0].imag(), scratch[0].real() );
                ++Fout;
            } while ( --k );
        }

        /// Radix-4 butterfly: combines four length-@p m DFTs stored back
        /// to back in @p Fout in place.
        void kf_bfly4( cpx_t * const Fout, const std::size_t fstride, const std::size_t m) const
        {
            cpx_t scratch[7];
            // The multiplication by -i (forward) or +i (inverse) is done by
            // swapping the parts and applying this sign.
            const scalar_t negative_if_inverse = _inverse ? -1 : +1;
            for ( std::size_t k = 0; k < m; ++k ) {
                scratch[0] = Fout[k+  m] * _twiddles[k*fstride  ];
                scratch[1] = Fout[k+2*m] * _twiddles[k*fstride*2];
                scratch[2] = Fout[k+3*m] * _twiddles[k*fstride*3];
                scratch[5] = Fout[k] - scratch[1];

                Fout[k] += scratch[1];
                scratch[3] = scratch[0] + scratch[2];
                scratch[4] = scratch[0] - scratch[2];
                scratch[4] = cpx_t(  scratch[4].imag()*negative_if_inverse,
                                    -scratch[4].real()*negative_if_inverse );

                Fout[k+2*m]  = Fout[k] - scratch[3];
                Fout[k    ] += scratch[3];
                Fout[k+  m]  = scratch[5] + scratch[4];
                Fout[k+3*m]  = scratch[5] - scratch[4];
            }
        }

        /// Radix-5 butterfly: combines five length-@p m DFTs stored back
        /// to back in @p Fout in place.
        void kf_bfly5( cpx_t * const Fout, const std::size_t fstride, const std::size_t m) const
        {
            cpx_t scratch[13];
            // Table entries at indices fstride*m and 2*fstride*m.
            const cpx_t ya = _twiddles[fstride*m];
            const cpx_t yb = _twiddles[fstride*2*m];

            // One cursor per sub-transform.
            cpx_t * Fout0 = Fout;
            cpx_t * Fout1 = Fout0 + m;
            cpx_t * Fout2 = Fout0 + 2*m;
            cpx_t * Fout3 = Fout0 + 3*m;
            cpx_t * Fout4 = Fout0 + 4*m;

            for ( std::size_t u = 0; u < m; ++u ) {
                scratch[0] = *Fout0;

                scratch[1] = *Fout1 * _twiddles[  u*fstride];
                scratch[2] = *Fout2 * _twiddles[2*u*fstride];
                scratch[3] = *Fout3 * _twiddles[3*u*fstride];
                scratch[4] = *Fout4 * _twiddles[4*u*fstride];

                scratch[7] = scratch[1] + scratch[4];
                scratch[10]= scratch[1] - scratch[4];
                scratch[8] = scratch[2] + scratch[3];
                scratch[9] = scratch[2] - scratch[3];

                *Fout0 += scratch[7];
                *Fout0 += scratch[8];

                scratch[5] = scratch[0] + cpx_t(
                        scratch[7].real()*ya.real() + scratch[8].real()*yb.real(),
                        scratch[7].imag()*ya.real() + scratch[8].imag()*yb.real()
                        );

                scratch[6] =  cpx_t(
                         scratch[10].imag()*ya.imag() + scratch[9].imag()*yb.imag(),
                        -scratch[10].real()*ya.imag() - scratch[9].real()*yb.imag()
                        );

                *Fout1 = scratch[5] - scratch[6];
                *Fout4 = scratch[5] + scratch[6];

                scratch[11] = scratch[0] +
                    cpx_t(
                            scratch[7].real()*yb.real() + scratch[8].real()*ya.real(),
                            scratch[7].imag()*yb.real() + scratch[8].imag()*ya.real()
                            );

                scratch[12] = cpx_t(
                        -scratch[10].imag()*yb.imag() + scratch[9].imag()*ya.imag(),
                         scratch[10].real()*yb.imag() - scratch[9].real()*ya.imag()
                        );

                *Fout2 = scratch[11] + scratch[12];
                *Fout3 = scratch[11] - scratch[12];

                ++Fout0;
                ++Fout1;
                ++Fout2;
                ++Fout3;
                ++Fout4;
            }
        }

        /* perform the butterfly for one stage of a mixed radix FFT */
        /// Generic radix-@p p butterfly, used for every radix without a
        /// dedicated routine. Costs O(p^2) operations per output group.
        void kf_bfly_generic(
                cpx_t * const Fout,
                const std::size_t fstride,
                const std::size_t m,
                const std::size_t p
                ) const
        {
            // Heap-backed scratch space: a runtime-sized built-in array is
            // not standard C++. Kept local (not a member) so that this
            // const method shares no mutable state between calls.
            std::vector<cpx_t> scratchbuf( p );

            for ( std::size_t u = 0; u < m; ++u ) {
                // Snapshot the p inputs of this group before overwriting.
                std::size_t k = u;
                for ( std::size_t q1 = 0; q1 < p; ++q1 ) {
                    scratchbuf[q1] = Fout[k];
                    k += m;
                }

                k = u;
                for ( std::size_t q1 = 0; q1 < p; ++q1 ) {
                    std::size_t twidx = 0;
                    Fout[k] = scratchbuf[0];
                    for ( std::size_t q = 1; q < p; ++q ) {
                        // Advance the table index modulo _nfft without
                        // using a division.
                        twidx += fstride * k;
                        if ( twidx >= _nfft )
                            twidx -= _nfft;
                        Fout[k] += scratchbuf[q] * _twiddles[twidx];
                    }
                    k += m;
                }
            }
        }

        std::size_t _nfft;                        ///< transform length
        bool _inverse;                            ///< direction the twiddle table was built for
        std::vector<cpx_t> _twiddles;             ///< _nfft points on the unit circle
        std::vector<std::size_t> _stageRadix;     ///< radix of each stage
        std::vector<std::size_t> _stageRemainder; ///< sub-transform length of each stage
};
359 #endif
360