1 #ifndef KISSFFT_CLASS_HH
2 #define KISSFFT_CLASS_HH
#include <cmath>
#include <complex>
#include <cstddef>
#include <utility>
#include <vector>
6 
7 
/// Mixed-radix complex Fast Fourier Transform of a fixed length.
///
/// The length and the direction (forward / inverse) are chosen at
/// construction time (or later through assign()). The constructor
/// precomputes the twiddle factors and the factorization of the length;
/// the transform functions themselves are @c const and only read that
/// precomputed state.
///
/// @tparam T_Scalar   real number type (e.g. @c float or @c double).
/// @tparam T_Complex  complex number type built on @c T_Scalar; defaults
///                    to @c std::complex<T_Scalar>.
template <typename T_Scalar,
          typename T_Complex=std::complex<T_Scalar>
          >
class kissfft
{
    public:
        typedef T_Scalar scalar_type;
        typedef T_Complex cpx_type;

        /// Prepares a transform of length @p nfft.
        ///
        /// @param nfft     number of complex samples per transform.
        /// @param inverse  @c false selects the negative exponent sign
        ///                 (forward DFT), @c true the positive one.
        kissfft( std::size_t nfft,
                 bool inverse )
            :_nfft(nfft)
            ,_inverse(inverse)
        {
            // A zero-length transform needs neither twiddles nor stages
            // (and computing the phase increment would divide by zero).
            // transform() and transform_real() return early in that case.
            if ( _nfft == 0 )
                return;

            // Bring the std overloads into scope but keep the calls
            // unqualified so that user-provided scalar/complex types can
            // still be found through argument-dependent lookup.
            using std::acos;
            using std::exp;

            // fill twiddle factors: _twiddles[i] = exp(+-2*pi*I*i/nfft)
            _twiddles.resize(_nfft);
            const scalar_type phinc =  (_inverse?2:-2)* acos( scalar_type(-1) )  / _nfft;
            for (std::size_t i=0;i<_nfft;++i)
                _twiddles[i] = exp( cpx_type(0,i*phinc) );

            // factorize:
            // start factoring out 4's, then 2's, then 3,5,7,9,...
            // For every stage the radix p and the remaining length
            // n (the product of all later radices) are recorded.
            std::size_t n= _nfft;
            std::size_t p=4;
            do {
                while (n % p) {
                    switch (p) {
                        case 4: p = 2; break;
                        case 2: p = 3; break;
                        default: p += 2; break;
                    }
                    if (p*p>n)
                        p = n;// no more factors
                }
                n /= p;
                _stageRadix.push_back(p);
                _stageRemainder.push_back(n);
            }while(n>1);
        }


        /// Changes the FFT-length and/or the transform direction.
        ///
        /// @post The @c kissfft object will be in the same state as if it
        /// had been newly constructed with the passed arguments.
        /// However, the implementation may be faster than constructing a
        /// new fft object.
        void assign( std::size_t nfft,
                     bool inverse )
        {
            if ( nfft != _nfft )
            {
                kissfft tmp( nfft, inverse ); // O(n) time.
                std::swap( tmp, *this ); // this is O(1) in C++11, O(n) otherwise.
            }
            else if ( inverse != _inverse )
            {
                // Same length, opposite direction: the new twiddle factors
                // are the complex conjugates of the current ones.
                for ( typename std::vector<cpx_type>::iterator it = _twiddles.begin();
                      it != _twiddles.end(); ++it )
                    it->imag( -it->imag() );

                // Remember the new direction. kf_bfly4() and
                // transform_real() read this flag, and a later assign()
                // with the same arguments must not conjugate again.
                _inverse = inverse;
            }
        }

        /// Calculates the complex Discrete Fourier Transform.
        ///
        /// The size of the passed arrays must be passed in the constructor.
        /// The sum of the squares of the absolute values in the @c dst
        /// array will be @c N times the sum of the squares of the absolute
        /// values in the @c src array, where @c N is the size of the array.
        /// In other words, the l_2 norm of the resulting array will be
        /// @c sqrt(N) times as big as the l_2 norm of the input array.
        /// This is also the case when the inverse flag is set in the
        /// constructor. Hence when applying the same transform twice, but with
        /// the inverse flag changed the second time, then the result will
        /// be equal to the original input times @c N.
        ///
        /// If the transform length is zero, nothing is read or written.
        ///
        /// @param src  input samples.
        /// @param dst  output array receiving the transform.
        void transform( const cpx_type * src,
                        cpx_type * dst ) const
        {
            if ( _nfft == 0 )
                return;

            kf_work(0, dst, src, 1,1);
        }

        /// Calculates the Discrete Fourier Transform (DFT) of a real input
        /// of size @c 2*N.
        ///
        /// The 0-th and N-th value of the DFT are real numbers. These are
        /// stored in @c dst[0].real() and @c dst[0].imag() respectively.
        /// The remaining DFT values up to the index N-1 are stored in
        /// @c dst[1] to @c dst[N-1].
        /// The other half of the DFT values can be calculated from the
        /// symmetry relation
        ///     @code
        ///         DFT(src)[2*N-k] == conj( DFT(src)[k] );
        ///     @endcode
        /// The same scaling factors as in @c transform() apply.
        ///
        /// @note For this to work, the types @c scalar_type and @c cpx_type
        /// must fulfill the following requirements:
        ///
        /// For any object @c z of type @c cpx_type,
        /// @c reinterpret_cast<scalar_type(&)[2]>(z)[0] is the real part of @c z and
        /// @c reinterpret_cast<scalar_type(&)[2]>(z)[1] is the imaginary part of @c z.
        /// For any pointer to an element of an array of @c cpx_type named @c p
        /// and any valid array index @c i,
        /// @c reinterpret_cast<scalar_type*>(p)[2*i]
        /// is the real part of the complex number @c p[i], and
        /// @c reinterpret_cast<scalar_type*>(p)[2*i+1] is the imaginary part
        /// of the complex number @c p[i].
        ///
        /// Since C++11, these requirements are guaranteed to be satisfied for
        /// @c scalar_types being @c float, @c double or @c long @c double
        /// together with @c cpx_type being @c std::complex<scalar_type>.
        ///
        /// @param src  @c 2*N real input samples, where @c N is the length
        ///             passed to the constructor.
        /// @param dst  output array of @c N complex values.
        void transform_real( const scalar_type * src,
                             cpx_type * dst ) const
        {
            const std::size_t N = _nfft;
            if ( N == 0 )
                return;

            // See the constructor for why the calls stay unqualified.
            using std::acos;
            using std::exp;
            using std::conj;

            // perform complex FFT on the input reinterpreted as N complex
            // numbers (even samples = real parts, odd samples = imaginary
            // parts)
            transform( reinterpret_cast<const cpx_type*>(src), dst );

            // post processing for k = 0 and k = N
            dst[0] = cpx_type( dst[0].real() + dst[0].imag(),
                               dst[0].real() - dst[0].imag() );

            // post processing for all the other k = 1, 2, ..., N-1
            const scalar_type pi = acos( scalar_type(-1) );
            const scalar_type half_phi_inc = ( _inverse ? pi : -pi ) / N;
            // Rotation by half a twiddle step: the stored table has N
            // entries, but this pass needs 2*N-th roots of unity.
            const cpx_type twiddle_mul = exp( cpx_type(0, half_phi_inc) );
            for ( std::size_t k = 1; 2*k < N; ++k )
            {
                const cpx_type w = scalar_type(0.5) * cpx_type(
                     dst[k].real() + dst[N-k].real(),
                     dst[k].imag() - dst[N-k].imag() );
                const cpx_type z = scalar_type(0.5) * cpx_type(
                     dst[k].imag() + dst[N-k].imag(),
                    -dst[k].real() + dst[N-k].real() );
                const cpx_type twiddle =
                    k % 2 == 0 ?
                    _twiddles[k/2] :
                    _twiddles[k/2] * twiddle_mul;
                dst[  k] =       w + twiddle * z;
                dst[N-k] = conj( w - twiddle * z );
            }
            if ( N % 2 == 0 )
                dst[N/2] = conj( dst[N/2] );
        }

    private:
        /// Recursive worker: computes the DFT of length
        /// @c _stageRadix[stage]*_stageRemainder[stage] into @p Fout.
        ///
        /// @param stage      index into the stage tables.
        /// @param Fout       output of this sub-transform.
        /// @param f          first input sample of this sub-transform.
        /// @param fstride    input decimation factor accumulated so far.
        /// @param in_stride  additional stride multiplier for the input.
        void kf_work( std::size_t stage,
                      cpx_type * Fout,
                      const cpx_type * f,
                      std::size_t fstride,
                      std::size_t in_stride) const
        {
            const std::size_t p = _stageRadix[stage];
            const std::size_t m = _stageRemainder[stage];
            cpx_type * const Fout_beg = Fout;
            cpx_type * const Fout_end = Fout + p*m;

            if (m==1) {
                // last stage: just gather the decimated input samples
                do{
                    *Fout = *f;
                    f += fstride*in_stride;
                }while(++Fout != Fout_end );
            }else{
                do{
                    // recursive call:
                    // DFT of size m*p performed by doing
                    // p instances of smaller DFTs of size m,
                    // each one takes a decimated version of the input
                    kf_work(stage+1, Fout , f, fstride*p,in_stride);
                    f += fstride*in_stride;
                }while( (Fout += m) != Fout_end );
            }

            Fout=Fout_beg;

            // recombine the p smaller DFTs
            switch (p) {
                case 2: kf_bfly2(Fout,fstride,m); break;
                case 3: kf_bfly3(Fout,fstride,m); break;
                case 4: kf_bfly4(Fout,fstride,m); break;
                case 5: kf_bfly5(Fout,fstride,m); break;
                default: kf_bfly_generic(Fout,fstride,m,p); break;
            }
        }

        /// Radix-2 butterfly combining two length-@p m sub-transforms.
        void kf_bfly2( cpx_type * Fout, const std::size_t fstride, std::size_t m) const
        {
            for (std::size_t k=0;k<m;++k) {
                const cpx_type t = Fout[m+k] * _twiddles[k*fstride];
                Fout[m+k] = Fout[k] - t;
                Fout[k] += t;
            }
        }

        /// Radix-4 butterfly combining four length-@p m sub-transforms.
        void kf_bfly4( cpx_type * Fout, const std::size_t fstride, const std::size_t m) const
        {
            cpx_type scratch[7];
            // The multiplication by -/+ I below changes sign with the
            // transform direction.
            const scalar_type negative_if_inverse = _inverse ? -1 : +1;
            for (std::size_t k=0;k<m;++k) {
                scratch[0] = Fout[k+  m] * _twiddles[k*fstride  ];
                scratch[1] = Fout[k+2*m] * _twiddles[k*fstride*2];
                scratch[2] = Fout[k+3*m] * _twiddles[k*fstride*3];
                scratch[5] = Fout[k] - scratch[1];

                Fout[k] += scratch[1];
                scratch[3] = scratch[0] + scratch[2];
                scratch[4] = scratch[0] - scratch[2];
                scratch[4] = cpx_type( scratch[4].imag()*negative_if_inverse ,
                                      -scratch[4].real()*negative_if_inverse );

                Fout[k+2*m]  = Fout[k] - scratch[3];
                Fout[k    ]+= scratch[3];
                Fout[k+  m] = scratch[5] + scratch[4];
                Fout[k+3*m] = scratch[5] - scratch[4];
            }
        }

        /// Radix-3 butterfly combining three length-@p m sub-transforms.
        void kf_bfly3( cpx_type * Fout, const std::size_t fstride, const std::size_t m) const
        {
            std::size_t k=m;
            const std::size_t m2 = 2*m;
            const cpx_type *tw1,*tw2;
            cpx_type scratch[5];
            // twiddle one third of the way through this stage's cycle;
            // only its imaginary part is used below
            const cpx_type epi3 = _twiddles[fstride*m];

            tw1=tw2=&_twiddles[0];

            do{
                scratch[1] = Fout[m]  * *tw1;
                scratch[2] = Fout[m2] * *tw2;

                scratch[3] = scratch[1] + scratch[2];
                scratch[0] = scratch[1] - scratch[2];
                tw1 += fstride;
                tw2 += fstride*2;

                Fout[m] = Fout[0] - scratch[3]*scalar_type(0.5);
                scratch[0] *= epi3.imag();

                Fout[0] += scratch[3];

                Fout[m2] = cpx_type(  Fout[m].real() + scratch[0].imag() , Fout[m].imag() - scratch[0].real() );

                Fout[m] += cpx_type( -scratch[0].imag(),scratch[0].real() );
                ++Fout;
            }while(--k);
        }

        /// Radix-5 butterfly combining five length-@p m sub-transforms.
        void kf_bfly5( cpx_type * Fout, const std::size_t fstride, const std::size_t m) const
        {
            cpx_type *Fout0,*Fout1,*Fout2,*Fout3,*Fout4;
            cpx_type scratch[13];
            // twiddles one fifth and two fifths of the way through this
            // stage's cycle
            const cpx_type ya = _twiddles[fstride*m];
            const cpx_type yb = _twiddles[fstride*2*m];

            Fout0=Fout;
            Fout1=Fout0+m;
            Fout2=Fout0+2*m;
            Fout3=Fout0+3*m;
            Fout4=Fout0+4*m;

            for ( std::size_t u=0; u<m; ++u ) {
                scratch[0] = *Fout0;

                scratch[1] = *Fout1 * _twiddles[  u*fstride];
                scratch[2] = *Fout2 * _twiddles[2*u*fstride];
                scratch[3] = *Fout3 * _twiddles[3*u*fstride];
                scratch[4] = *Fout4 * _twiddles[4*u*fstride];

                scratch[7] = scratch[1] + scratch[4];
                scratch[10]= scratch[1] - scratch[4];
                scratch[8] = scratch[2] + scratch[3];
                scratch[9] = scratch[2] - scratch[3];

                *Fout0 += scratch[7];
                *Fout0 += scratch[8];

                scratch[5] = scratch[0] + cpx_type(
                        scratch[7].real()*ya.real() + scratch[8].real()*yb.real(),
                        scratch[7].imag()*ya.real() + scratch[8].imag()*yb.real()
                        );

                scratch[6] =  cpx_type(
                         scratch[10].imag()*ya.imag() + scratch[9].imag()*yb.imag(),
                        -scratch[10].real()*ya.imag() - scratch[9].real()*yb.imag()
                        );

                *Fout1 = scratch[5] - scratch[6];
                *Fout4 = scratch[5] + scratch[6];

                scratch[11] = scratch[0] +
                    cpx_type(
                            scratch[7].real()*yb.real() + scratch[8].real()*ya.real(),
                            scratch[7].imag()*yb.real() + scratch[8].imag()*ya.real()
                            );

                scratch[12] = cpx_type(
                        -scratch[10].imag()*yb.imag() + scratch[9].imag()*ya.imag(),
                         scratch[10].real()*yb.imag() - scratch[9].real()*ya.imag()
                        );

                *Fout2 = scratch[11] + scratch[12];
                *Fout3 = scratch[11] - scratch[12];

                ++Fout0;
                ++Fout1;
                ++Fout2;
                ++Fout3;
                ++Fout4;
            }
        }

        /* perform the butterfly for one stage of a mixed radix FFT */
        /// Generic radix-@p p butterfly, used for every radix without a
        /// specialised routine. Performs a direct length-@p p DFT for each
        /// of the @p m output groups.
        void kf_bfly_generic(
                cpx_type * Fout,
                const std::size_t fstride,
                std::size_t m,
                std::size_t p
                ) const
        {
            const cpx_type * twiddles = &_twiddles[0];
            // Heap-backed scratch space: a runtime-sized built-in array is
            // not standard C++ and could exhaust the stack for a large
            // radix. It is a local (not a member) so that this const
            // function does not modify the object.
            std::vector<cpx_type> scratchbuf(p);

            for ( std::size_t u=0; u<m; ++u ) {
                // save the p inputs of this group before overwriting them
                std::size_t k = u;
                for ( std::size_t q1=0 ; q1<p ; ++q1 ) {
                    scratchbuf[q1] = Fout[ k  ];
                    k += m;
                }

                k=u;
                for ( std::size_t q1=0 ; q1<p ; ++q1 ) {
                    std::size_t twidx=0;
                    Fout[ k ] = scratchbuf[0];
                    for ( std::size_t q=1;q<p;++q ) {
                        // step through the twiddle table modulo _nfft
                        twidx += fstride * k;
                        if (twidx>=_nfft)
                          twidx-=_nfft;
                        Fout[ k ] += scratchbuf[q] * twiddles[twidx];
                    }
                    k += m;
                }
            }
        }

        std::size_t _nfft;                        ///< transform length
        bool _inverse;                            ///< direction flag
        std::vector<cpx_type> _twiddles;          ///< _nfft roots of unity
        std::vector<std::size_t> _stageRadix;     ///< radix of each stage
        std::vector<std::size_t> _stageRemainder; ///< length left after each stage
};
362 #endif
363