/*
 *  Copyright (c) 2003-2010, Mark Borgerding. All rights reserved.
 *  This file is part of KISS FFT - https://github.com/mborgerding/kissfft
 *
 *  SPDX-License-Identifier: BSD-3-Clause
 *  See COPYING file for more information.
 */

#ifndef KISSFFT_CLASS_HH
#define KISSFFT_CLASS_HH
#include <cmath>
#include <complex>
#include <cstddef>
#include <utility>
#include <vector>


/// Mixed-radix complex FFT of a fixed length.
///
/// The length is factorized once at construction time (4's first, then 2's,
/// then odd factors) and the twiddle factors are precomputed. Radix 2, 3, 4
/// and 5 stages use dedicated butterflies; every other radix goes through a
/// generic O(p^2) butterfly.
///
/// @tparam scalar_t  Scalar type of the real and imaginary parts.
template <typename scalar_t>
class kissfft
{
    public:

        using cpx_t = std::complex<scalar_t>;

        /// Builds the twiddle table and the stage factorization.
        ///
        /// @param nfft     Transform length (number of complex samples).
        /// @param inverse  @c false selects twiddles exp(-2*pi*i*k/nfft),
        ///                 @c true selects exp(+2*pi*i*k/nfft).
        kissfft( const std::size_t nfft,
                 const bool inverse )
            :_nfft(nfft)
            ,_inverse(inverse)
        {
            // Bring the std overloads into scope but still allow ADL to pick
            // up overloads for user-defined scalar types.
            using std::acos;
            using std::exp;

            // fill twiddle factors
            _twiddles.resize(_nfft);
            const scalar_t phinc = (_inverse?2:-2) * acos( static_cast<scalar_t>(-1) ) / _nfft;
            for (std::size_t i=0;i<_nfft;++i)
                _twiddles[i] = exp( cpx_t(0,i*phinc) );

            //factorize
            //start factoring out 4's, then 2's, then 3,5,7,9,...
            std::size_t n= _nfft;
            std::size_t p=4;
            do {
                while (n % p) {
                    switch (p) {
                        case 4: p = 2; break;
                        case 2: p = 3; break;
                        default: p += 2; break;
                    }
                    if (p*p>n)
                        p = n;// no more factors
                }
                n /= p;
                _stageRadix.push_back(p);
                _stageRemainder.push_back(n);
            }while(n>1);
        }


        /// Changes the FFT-length and/or the transform direction.
        ///
        /// @param nfft     New transform length.
        /// @param inverse  New transform direction.
        ///
        /// @post The @c kissfft object will be in the same state as if it
        /// had been newly constructed with the passed arguments.
        /// However, the implementation may be faster than constructing a
        /// new fft object.
        void assign( const std::size_t nfft,
                     const bool inverse )
        {
            if ( nfft != _nfft )
            {
                // Different length: rebuild everything (O(n)) and move it in.
                *this = kissfft( nfft, inverse );
            }
            else if ( inverse != _inverse )
            {
                // Same length, opposite direction: the factorization is
                // unchanged and the twiddles are just complex-conjugated.
                for ( cpx_t & tw : _twiddles )
                    tw = std::conj( tw );

                // The direction flag must follow the twiddles: kf_bfly4()
                // and transform_real() read it, and a later assign() with
                // the same arguments must be a no-op.
                _inverse = inverse;
            }
        }

        /// Calculates the complex Discrete Fourier Transform.
        ///
        /// The size of the passed arrays must be passed in the constructor.
        /// The sum of the squares of the absolute values in the @c dst
        /// array will be @c N times the sum of the squares of the absolute
        /// values in the @c src array, where @c N is the size of the array.
        /// In other words, the l_2 norm of the resulting array will be
        /// @c sqrt(N) times as big as the l_2 norm of the input array.
        /// This is also the case when the inverse flag is set in the
        /// constructor. Hence when applying the same transform twice, but with
        /// the inverse flag changed the second time, then the result will
        /// be equal to the original input times @c N.
        ///
        /// @param fft_in     Input samples.
        /// @param fft_out    Output buffer; must not alias @c fft_in because
        ///                   the output is written while the input is still
        ///                   being read with a stride.
        /// @param stage      Index of the factorization stage to execute;
        ///                   callers use the default, the recursion advances it.
        /// @param fstride    Twiddle/input decimation stride of this stage;
        ///                   callers use the default.
        /// @param in_stride  Additional stride between consecutive input
        ///                   samples.
        void transform(const cpx_t * fft_in, cpx_t * fft_out, const std::size_t stage = 0, const std::size_t fstride = 1, const std::size_t in_stride = 1) const
        {
            const std::size_t p = _stageRadix[stage];
            const std::size_t m = _stageRemainder[stage];
            cpx_t * const Fout_beg = fft_out;
            cpx_t * const Fout_end = fft_out + p*m;

            if (m==1) {
                // Last stage: just gather the decimated input samples.
                do{
                    *fft_out = *fft_in;
                    fft_in += fstride*in_stride;
                }while(++fft_out != Fout_end );
            }else{
                do{
                    // recursive call:
                    // DFT of size m*p performed by doing
                    // p instances of smaller DFTs of size m,
                    // each one takes a decimated version of the input
                    transform(fft_in, fft_out, stage+1, fstride*p,in_stride);
                    fft_in += fstride*in_stride;
                }while( (fft_out += m) != Fout_end );
            }

            fft_out=Fout_beg;

            // recombine the p smaller DFTs
            switch (p) {
                case 2: kf_bfly2(fft_out,fstride,m); break;
                case 3: kf_bfly3(fft_out,fstride,m); break;
                case 4: kf_bfly4(fft_out,fstride,m); break;
                case 5: kf_bfly5(fft_out,fstride,m); break;
                default: kf_bfly_generic(fft_out,fstride,m,p); break;
            }
        }

        /// Calculates the Discrete Fourier Transform (DFT) of a real input
        /// of size @c 2*N, where @c N is the length passed to the constructor.
        ///
        /// The 0-th and N-th value of the DFT are real numbers. These are
        /// stored in @c dst[0].real() and @c dst[0].imag() respectively.
        /// The remaining DFT values up to the index N-1 are stored in
        /// @c dst[1] to @c dst[N-1].
        /// The other half of the DFT values can be calculated from the
        /// symmetry relation
        /// @code
        ///     DFT(src)[2*N-k] == conj( DFT(src)[k] );
        /// @endcode
        /// The same scaling factors as in @c transform() apply.
        ///
        /// @param src  Real input samples, @c 2*N values.
        /// @param dst  Output buffer of @c N complex values.
        ///
        /// @note For this to work, the types @c scalar_t and @c cpx_t
        /// must fulfill the following requirements:
        ///
        /// For any object @c z of type @c cpx_t,
        /// @c reinterpret_cast<scalar_t(&)[2]>(z)[0] is the real part of @c z and
        /// @c reinterpret_cast<scalar_t(&)[2]>(z)[1] is the imaginary part of @c z.
        /// For any pointer to an element of an array of @c cpx_t named @c p
        /// and any valid array index @c i, @c reinterpret_cast<scalar_t*>(p)[2*i]
        /// is the real part of the complex number @c p[i], and
        /// @c reinterpret_cast<scalar_t*>(p)[2*i+1] is the imaginary part of the
        /// complex number @c p[i].
        ///
        /// Since C++11, these requirements are guaranteed to be satisfied for
        /// @c scalar_t being @c float, @c double or @c long @c double
        /// together with @c cpx_t being @c std::complex<scalar_t>.
        void transform_real( const scalar_t * const src,
                             cpx_t * const dst ) const
        {
            using std::acos;
            using std::exp;

            const std::size_t N = _nfft;
            if ( N == 0 )
                return;

            // perform complex FFT on the input viewed as N complex samples
            transform( reinterpret_cast<const cpx_t*>(src), dst );

            // post processing for k = 0 and k = N
            dst[0] = cpx_t( dst[0].real() + dst[0].imag(),
                            dst[0].real() - dst[0].imag() );

            // post processing for all the other k = 1, 2, ..., N-1
            const scalar_t pi = acos( static_cast<scalar_t>(-1) );
            const scalar_t half_phi_inc = ( _inverse ? pi : -pi ) / N;
            // The table only holds twiddles for the length-N transform; odd
            // indices of the length-2N transform need this half-step factor.
            const cpx_t twiddle_mul = exp( cpx_t(0, half_phi_inc) );
            for ( std::size_t k = 1; 2*k < N; ++k )
            {
                const cpx_t w = static_cast<scalar_t>(0.5) * cpx_t(
                     dst[k].real() + dst[N-k].real(),
                     dst[k].imag() - dst[N-k].imag() );
                const cpx_t z = static_cast<scalar_t>(0.5) * cpx_t(
                     dst[k].imag() + dst[N-k].imag(),
                    -dst[k].real() + dst[N-k].real() );
                const cpx_t twiddle =
                    k % 2 == 0 ?
                    _twiddles[k/2] :
                    _twiddles[k/2] * twiddle_mul;
                dst[  k] =       w + twiddle * z;
                dst[N-k] = std::conj( w - twiddle * z );
            }
            if ( N % 2 == 0 )
                dst[N/2] = std::conj( dst[N/2] );
        }

    private:

        /// Radix-2 butterfly over @c m pairs (Fout[k], Fout[m+k]).
        void kf_bfly2( cpx_t * Fout, const std::size_t fstride, const std::size_t m) const
        {
            for (std::size_t k=0;k<m;++k) {
                const cpx_t t = Fout[m+k] * _twiddles[k*fstride];
                Fout[m+k] = Fout[k] - t;
                Fout[k] += t;
            }
        }

        /// Radix-3 butterfly over @c m triples (Fout[k], Fout[m+k], Fout[2m+k]).
        void kf_bfly3( cpx_t * Fout, const std::size_t fstride, const std::size_t m) const
        {
            std::size_t k=m;
            const std::size_t m2 = 2*m;
            const cpx_t *tw1,*tw2;
            cpx_t scratch[4];
            // Twiddle at one third of the circle; only its imaginary part is used.
            const cpx_t epi3 = _twiddles[fstride*m];

            tw1=tw2=&_twiddles[0];

            do{
                scratch[1] = Fout[m]  * *tw1;
                scratch[2] = Fout[m2] * *tw2;

                scratch[3] = scratch[1] + scratch[2];
                scratch[0] = scratch[1] - scratch[2];
                tw1 += fstride;
                tw2 += fstride*2;

                Fout[m] = Fout[0] - scratch[3]*scalar_t(0.5);
                scratch[0] *= epi3.imag();

                Fout[0] += scratch[3];

                Fout[m2] = cpx_t(  Fout[m].real() + scratch[0].imag() , Fout[m].imag() - scratch[0].real() );

                Fout[m] += cpx_t( -scratch[0].imag(),scratch[0].real() );
                ++Fout;
            }while(--k);
        }

        /// Radix-4 butterfly over @c m quadruples.
        void kf_bfly4( cpx_t * const Fout, const std::size_t fstride, const std::size_t m) const
        {
            cpx_t scratch[7];
            // The multiplication by -i (forward) becomes +i for the inverse.
            const scalar_t negative_if_inverse = _inverse ? -1 : +1;
            for (std::size_t k=0;k<m;++k) {
                scratch[0] = Fout[k+  m] * _twiddles[k*fstride  ];
                scratch[1] = Fout[k+2*m] * _twiddles[k*fstride*2];
                scratch[2] = Fout[k+3*m] * _twiddles[k*fstride*3];
                scratch[5] = Fout[k] - scratch[1];

                Fout[k] += scratch[1];
                scratch[3] = scratch[0] + scratch[2];
                scratch[4] = scratch[0] - scratch[2];
                scratch[4] = cpx_t(  scratch[4].imag()*negative_if_inverse ,
                                    -scratch[4].real()*negative_if_inverse );

                Fout[k+2*m]  = Fout[k] - scratch[3];
                Fout[k    ] += scratch[3];
                Fout[k+  m]  = scratch[5] + scratch[4];
                Fout[k+3*m]  = scratch[5] - scratch[4];
            }
        }

        /// Radix-5 butterfly over @c m quintuples.
        void kf_bfly5( cpx_t * const Fout, const std::size_t fstride, const std::size_t m) const
        {
            cpx_t *Fout0,*Fout1,*Fout2,*Fout3,*Fout4;
            cpx_t scratch[13];
            // Twiddles at one fifth and two fifths of the circle.
            const cpx_t ya = _twiddles[fstride*m];
            const cpx_t yb = _twiddles[fstride*2*m];

            Fout0=Fout;
            Fout1=Fout0+m;
            Fout2=Fout0+2*m;
            Fout3=Fout0+3*m;
            Fout4=Fout0+4*m;

            for ( std::size_t u=0; u<m; ++u ) {
                scratch[0] = *Fout0;

                scratch[1] = *Fout1 * _twiddles[  u*fstride];
                scratch[2] = *Fout2 * _twiddles[2*u*fstride];
                scratch[3] = *Fout3 * _twiddles[3*u*fstride];
                scratch[4] = *Fout4 * _twiddles[4*u*fstride];

                scratch[7] = scratch[1] + scratch[4];
                scratch[10]= scratch[1] - scratch[4];
                scratch[8] = scratch[2] + scratch[3];
                scratch[9] = scratch[2] - scratch[3];

                *Fout0 += scratch[7];
                *Fout0 += scratch[8];

                scratch[5] = scratch[0] + cpx_t(
                        scratch[7].real()*ya.real() + scratch[8].real()*yb.real(),
                        scratch[7].imag()*ya.real() + scratch[8].imag()*yb.real()
                        );

                scratch[6] = cpx_t(
                         scratch[10].imag()*ya.imag() + scratch[9].imag()*yb.imag(),
                        -scratch[10].real()*ya.imag() - scratch[9].real()*yb.imag()
                        );

                *Fout1 = scratch[5] - scratch[6];
                *Fout4 = scratch[5] + scratch[6];

                scratch[11] = scratch[0] +
                    cpx_t(
                            scratch[7].real()*yb.real() + scratch[8].real()*ya.real(),
                            scratch[7].imag()*yb.real() + scratch[8].imag()*ya.real()
                         );

                scratch[12] = cpx_t(
                        -scratch[10].imag()*yb.imag() + scratch[9].imag()*ya.imag(),
                         scratch[10].real()*yb.imag() - scratch[9].real()*ya.imag()
                        );

                *Fout2 = scratch[11] + scratch[12];
                *Fout3 = scratch[11] - scratch[12];

                ++Fout0;
                ++Fout1;
                ++Fout2;
                ++Fout3;
                ++Fout4;
            }
        }

        /// Performs the butterfly for one stage of a mixed radix FFT with an
        /// arbitrary radix @c p (used for every radix other than 2, 3, 4, 5).
        ///
        /// @param Fout     Stage data, @c m*p values, updated in place.
        /// @param fstride  Twiddle stride of this stage.
        /// @param m        Length of each of the @c p sub-transforms.
        /// @param p        Radix of this stage.
        void kf_bfly_generic(
                cpx_t * const Fout,
                const std::size_t fstride,
                const std::size_t m,
                const std::size_t p
                ) const
        {
            const cpx_t * twiddles = &_twiddles[0];

            // Scratch space for the p inputs of one butterfly. A variable
            // length array is not standard C++, so small radices use a
            // fixed-size stack buffer and only larger ones fall back to the
            // heap. Keeping the buffer local (rather than a mutable member)
            // keeps this const method free of shared mutable state.
            constexpr std::size_t stack_radix = 32;
            cpx_t stackbuf[stack_radix];
            std::vector<cpx_t> heapbuf;
            cpx_t * scratchbuf = stackbuf;
            if ( p > stack_radix ) {
                heapbuf.resize(p);
                scratchbuf = &heapbuf[0];
            }

            for ( std::size_t u=0; u<m; ++u ) {
                std::size_t k = u;
                for ( std::size_t q1=0 ; q1<p ; ++q1 ) {
                    scratchbuf[q1] = Fout[ k  ];
                    k += m;
                }

                k=u;
                for ( std::size_t q1=0 ; q1<p ; ++q1 ) {
                    std::size_t twidx=0;
                    Fout[ k ] = scratchbuf[0];
                    for ( std::size_t q=1;q<p;++q ) {
                        // twidx and fstride*k are both below _nfft, so one
                        // conditional subtraction is a full modulo reduction.
                        twidx += fstride * k;
                        if (twidx>=_nfft)
                            twidx-=_nfft;
                        Fout[ k ] += scratchbuf[q] * twiddles[twidx];
                    }
                    k += m;
                }
            }
        }

        std::size_t _nfft;                        ///< transform length
        bool _inverse;                            ///< direction the twiddles were built for
        std::vector<cpx_t> _twiddles;             ///< exp(+-2*pi*i*k/_nfft), k = 0.._nfft-1
        std::vector<std::size_t> _stageRadix;     ///< radix p of each stage
        std::vector<std::size_t> _stageRemainder; ///< remaining length m after each stage
};
#endif