1 // clang-format off 2 /* -*- c++ -*- ---------------------------------------------------------- 3 LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator 4 https://www.lammps.org/, Sandia National Laboratories 5 Steve Plimpton, sjplimp@sandia.gov 6 7 Copyright (2003) Sandia Corporation. Under the terms of Contract 8 DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains 9 certain rights in this software. This software is distributed under 10 the GNU General Public License. 11 12 See the README file in the top-level LAMMPS directory. 13 ------------------------------------------------------------------------- */ 14 15 /* 16 we use a stripped down KISS FFT as default FFT for LAMMPS 17 this code is adapted from kiss_fft_v1_2_9 18 homepage: http://kissfft.sf.net/ 19 20 changes 2008-2011 by Axel Kohlmeyer <akohlmey@gmail.com> 21 22 KISS FFT ported to Kokkos by Stan Moore (SNL) 23 */ 24 25 /* 26 Copyright (c) 2003-2010, Mark Borgerding 27 28 All rights reserved. 29 30 Redistribution and use in source and binary forms, with or without 31 modification, are permitted provided that the following conditions are 32 met: 33 34 * Redistributions of source code must retain the above copyright 35 notice, this list of conditions and the following disclaimer. 36 37 * Redistributions in binary form must reproduce the above copyright 38 notice, this list of conditions and the following disclaimer in 39 the documentation and/or other materials provided with the 40 distribution. 41 42 * Neither the author nor the names of any contributors may be used 43 to endorse or promote products derived from this software without 44 specific prior written permission. 45 46 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 47 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 48 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 49 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 50 OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 51 SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 52 LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 53 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 54 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 55 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 56 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 57 */ 58 59 60 #ifndef LMP_KISSFFT_KOKKOS_H 61 #define LMP_KISSFFT_KOKKOS_H 62 63 #include <stdlib.h> 64 #include <string.h> 65 #include <math.h> 66 #include "fftdata_kokkos.h" 67 68 #ifndef M_PI 69 #define M_PI 3.141592653589793238462643383279502884197169399375105820974944 70 #endif 71 72 /* 73 Explanation of macros dealing with complex math: 74 75 C_MUL(m,a,b) : m = a*b 76 C_FIXDIV( c , div ) : if a fixed point impl., c /= div. noop otherwise 77 C_SUB( res, a,b) : res = a - b 78 C_SUBFROM( res , a) : res -= a 79 C_ADDTO( res , a) : res += a 80 C_EQ( res , a) : res = a 81 */ 82 83 #define S_MUL(a,b) ( (a)*(b) ) 84 85 #define C_MUL(m,a,a_index,b,b_index) \ 86 do{ (m)[0] = (a)(a_index).re*(b)(b_index).re - (a)(a_index).im*(b)(b_index).im;\ 87 (m)[1] = (a)(a_index).re*(b)(b_index).im + (a)(a_index).im*(b)(b_index).re; }while(0) 88 89 /* 90 #define C_FIXDIV(c,div) // NOOP 91 92 #define C_MULBYSCALAR( c, s ) \ 93 do{ (c)[0] *= (s);\ 94 (c)[1] *= (s); }while(0) 95 96 #define C_ADD( res, a,b)\ 97 do { \ 98 (res)[0]=(a)[0]+(b)[0]; (res)[1]=(a)[1]+(b)[1]; \ 99 }while(0) 100 101 #define C_SUB( res, a,b)\ 102 do { \ 103 (res)[0]=(a)[0]-(b)[0]; (res)[1]=(a)[1]-(b)[1]; \ 104 }while(0) 105 106 #define C_ADDTO( res , a)\ 107 do { \ 108 (res)[0] += (a)[0]; (res)[1] += (a)[1];\ 109 }while(0) 110 111 #define C_SUBFROM( res , a)\ 112 do {\ 113 (res)[0] -= (a)[0]; (res)[1] -= (a)[1]; \ 114 }while(0) 115 116 #define C_EQ(res, a)\ 117 do {\ 118 (res)[0] = (a)[0]; (res)[1] = (a)[1]; \ 119 }while(0) 120 */ 121 122 #define KISS_FFT_COS(phase) (FFT_SCALAR) cos(phase) 123 #define KISS_FFT_SIN(phase) (FFT_SCALAR) sin(phase) 124 #define HALF_OF(x) ((x)*.5) 125 126 #define kf_cexp(x,x_index,phase) \ 127 do{ \ 128 (x)(x_index).re = KISS_FFT_COS(phase);\ 129 (x)(x_index).im = KISS_FFT_SIN(phase);\ 130 }while(0) 131 132 133 namespace LAMMPS_NS { 134 135 #define MAXFACTORS 32 136 /* e.g. an fft of length 128 has 4 factors 137 as far as kissfft is concerned: 4*4*4*2 */ 138 template<class DeviceType> 139 struct kiss_fft_state_kokkos { 140 typedef DeviceType device_type; 141 typedef FFTArrayTypes<DeviceType> FFT_AT; 142 int nfft; 143 int inverse; 144 typename FFT_AT::t_int_64 d_factors; 145 typename FFT_AT::t_FFT_DATA_1d d_twiddles; 146 typename FFT_AT::t_FFT_DATA_1d d_scratch; 147 }; 148 149 template<class DeviceType> 150 class KissFFTKokkos { 151 public: 152 typedef DeviceType device_type; 153 typedef FFTArrayTypes<DeviceType> FFT_AT; 154 155 KOKKOS_INLINE_FUNCTION kf_bfly2(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const size_t fstride,const kiss_fft_state_kokkos<DeviceType> & st,int m,int Fout_count)156 static void kf_bfly2(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const size_t fstride, 157 const kiss_fft_state_kokkos<DeviceType> &st, int m, int Fout_count) 158 { 159 typename FFT_AT::t_FFT_DATA_1d_um d_twiddles = st.d_twiddles; 160 FFT_SCALAR t[2]; 161 int Fout2_count; 162 int tw1_count = 0; 163 164 Fout2_count = Fout_count + m; 165 do { 166 //C_FIXDIV(d_Fout[Fout_count],2); C_FIXDIV(d_Fout[Fout2_count],2); 167 168 C_MUL(t,d_Fout,Fout2_count,d_twiddles,tw1_count); 169 tw1_count += fstride; 170 //C_SUB(*Fout2,*Fout,t); 171 d_Fout(Fout2_count).re = d_Fout(Fout_count).re - t[0]; 172 d_Fout(Fout2_count).im = d_Fout(Fout_count).im - t[1]; 173 //C_ADDTO(d_Fout[Fout_count],t); 174 d_Fout(Fout_count).re += t[0]; 175 d_Fout(Fout_count).im += t[1]; 176 ++Fout2_count; 177 ++Fout_count; 178 } while(--m); 179 } 180 181 KOKKOS_INLINE_FUNCTION kf_bfly4(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const size_t fstride,const kiss_fft_state_kokkos<DeviceType> & st,const size_t m,int Fout_count)182 static void kf_bfly4(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const size_t fstride, 183 const kiss_fft_state_kokkos<DeviceType> &st, const size_t m, int Fout_count) 184 { 185 typename FFT_AT::t_FFT_DATA_1d_um d_twiddles = st.d_twiddles; 186 FFT_SCALAR scratch[6][2]; 187 size_t k=m; 188 const size_t m2=2*m; 189 const size_t m3=3*m; 190 191 int tw3_count,tw2_count,tw1_count; 192 tw3_count = tw2_count = tw1_count = 0; 193 194 do { 195 //C_FIXDIV(d_Fout[Fout_count],4); C_FIXDIV(d_Fout[m],4); C_FIXDIV(d_Fout[m2],4); C_FIXDIV(d_Fout[m3],4); 196 197 C_MUL(scratch[0],d_Fout,Fout_count + m,d_twiddles,tw1_count); 198 C_MUL(scratch[1],d_Fout,Fout_count + m2,d_twiddles,tw2_count); 199 C_MUL(scratch[2],d_Fout,Fout_count + m3,d_twiddles,tw3_count); 200 201 //C_SUB(scratch[5],d_Fout[Fout_count],scratch[1] ); 202 scratch[5][0] = d_Fout(Fout_count).re - scratch[1][0]; 203 scratch[5][1] = d_Fout(Fout_count).im - scratch[1][1]; 204 //C_ADDTO(d_Fout[Fout_count], scratch[1]); 205 d_Fout(Fout_count).re += scratch[1][0]; 206 d_Fout(Fout_count).im += scratch[1][1]; 207 //C_ADD(scratch[3],scratch[0],scratch[2]); 208 scratch[3][0] = scratch[0][0] + scratch[2][0]; 209 scratch[3][1] = scratch[0][1] + scratch[2][1]; 210 //C_SUB( scratch[4] , scratch[0] , scratch[2] ); 211 scratch[4][0] = scratch[0][0] - scratch[2][0]; 212 scratch[4][1] = scratch[0][1] - scratch[2][1]; 213 //C_SUB(d_Fout[m2],d_Fout[Fout_count],scratch[3]); 214 d_Fout(Fout_count + m2).re = d_Fout(Fout_count).re - scratch[3][0]; 215 d_Fout(Fout_count + m2).im = d_Fout(Fout_count).im - scratch[3][1]; 216 217 tw1_count += fstride; 218 tw2_count += fstride*2; 219 tw3_count += fstride*3; 220 //C_ADDTO(d_Fout[Fout_count],scratch[3]); 221 d_Fout(Fout_count).re += scratch[3][0]; 222 d_Fout(Fout_count).im += scratch[3][1]; 223 224 if (st.inverse) { 225 d_Fout(Fout_count + m).re = scratch[5][0] - scratch[4][1]; 226 d_Fout(Fout_count + m).im = scratch[5][1] + scratch[4][0]; 227 d_Fout(Fout_count + m3).re = scratch[5][0] + scratch[4][1]; 228 d_Fout(Fout_count + m3).im = scratch[5][1] - scratch[4][0]; 229 } else{ 230 d_Fout(Fout_count + m).re = scratch[5][0] + scratch[4][1]; 231 d_Fout(Fout_count + m).im = scratch[5][1] - scratch[4][0]; 232 d_Fout(Fout_count + m3).re = scratch[5][0] - scratch[4][1]; 233 d_Fout(Fout_count + m3).im = scratch[5][1] + scratch[4][0]; 234 } 235 ++Fout_count; 236 } while(--k); 237 } 238 239 KOKKOS_INLINE_FUNCTION kf_bfly3(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const size_t fstride,const kiss_fft_state_kokkos<DeviceType> & st,size_t m,int Fout_count)240 static void kf_bfly3(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const size_t fstride, 241 const kiss_fft_state_kokkos<DeviceType> &st, size_t m, int Fout_count) 242 { 243 size_t k=m; 244 const size_t m2 = 2*m; 245 typename FFT_AT::t_FFT_DATA_1d_um d_twiddles = st.d_twiddles; 246 FFT_SCALAR scratch[5][2]; 247 FFT_SCALAR epi3[2]; 248 //C_EQ(epi3,d_twiddles[fstride*m]); 249 epi3[0] = d_twiddles(fstride*m).re; 250 epi3[1] = d_twiddles(fstride*m).im; 251 252 int tw1_count,tw2_count; 253 tw1_count = tw2_count = 0; 254 255 do { 256 //C_FIXDIV(d_Fout[Fout_count],3); C_FIXDIV(d_Fout[m],3); C_FIXDIV(d_Fout[m2],3); 257 258 C_MUL(scratch[1],d_Fout,Fout_count + m,d_twiddles,tw1_count); 259 C_MUL(scratch[2],d_Fout,Fout_count + m2,d_twiddles,tw2_count); 260 261 //C_ADD(scratch[3],scratch[1],scratch[2]); 262 scratch[3][0] = scratch[1][0] + scratch[2][0]; 263 scratch[3][1] = scratch[1][1] + scratch[2][1]; 264 //C_SUB(scratch[0],scratch[1],scratch[2]); 265 scratch[0][0] = scratch[1][0] - scratch[2][0]; 266 scratch[0][1] = scratch[1][1] - scratch[2][1]; 267 tw1_count += fstride; 268 tw2_count += fstride*2; 269 270 d_Fout(Fout_count + m).re = d_Fout(Fout_count).re - HALF_OF(scratch[3][0]); 271 d_Fout(Fout_count + m).im = d_Fout(Fout_count).im - HALF_OF(scratch[3][1]); 272 273 //C_MULBYSCALAR(scratch[0],epi3[1]); 274 scratch[0][0] *= epi3[1]; 275 scratch[0][1] *= epi3[1]; 276 277 //C_ADDTO(d_Fout[Fout_count],scratch[3]); 278 d_Fout(Fout_count).re += scratch[3][0]; 279 d_Fout(Fout_count).im += scratch[3][1]; 280 281 d_Fout(Fout_count + m2).re = d_Fout(Fout_count + m).re + scratch[0][1]; 282 d_Fout(Fout_count + m2).im = d_Fout(Fout_count + m).im - scratch[0][0]; 283 284 d_Fout(Fout_count + m).re -= scratch[0][1]; 285 d_Fout(Fout_count + m).im += scratch[0][0]; 286 287 ++Fout_count; 288 } while(--k); 289 } 290 291 KOKKOS_INLINE_FUNCTION kf_bfly5(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const size_t fstride,const kiss_fft_state_kokkos<DeviceType> & st,int m,int Fout_count)292 static void kf_bfly5(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const size_t fstride, 293 const kiss_fft_state_kokkos<DeviceType> &st, int m, int Fout_count) 294 { 295 int u; 296 FFT_SCALAR scratch[13][2]; 297 typename FFT_AT::t_FFT_DATA_1d_um d_twiddles = st.d_twiddles; 298 FFT_SCALAR ya[2],yb[2]; 299 //C_EQ(ya,d_twiddles[fstride*m]); 300 ya[1] = d_twiddles(fstride*m).im; 301 ya[0] = d_twiddles(fstride*m).re; 302 //C_EQ(yb,d_twiddles[fstride*2*m]); 303 yb[1] = d_twiddles(fstride*2*m).im; 304 yb[0] = d_twiddles(fstride*2*m).re; 305 306 int Fout0_count=Fout_count; 307 int Fout1_count=Fout0_count+m; 308 int Fout2_count=Fout0_count+2*m; 309 int Fout3_count=Fout0_count+3*m; 310 int Fout4_count=Fout0_count+4*m; 311 312 for ( u=0; u<m; ++u ) { 313 //C_FIXDIV( d_Fout[Fout0_count],5); C_FIXDIV( d_Fout[Fout1_count],5); C_FIXDIV( d_Fout[Fout2_count],5); 314 //C_FIXDIV( d_Fout[Fout3_count],5); C_FIXDIV( d_Fout[Fout4_count],5); 315 //C_EQ(scratch[0],d_Fout[Fout0_count]); 316 scratch[0][0] = d_Fout(Fout0_count).re; 317 scratch[0][1] = d_Fout(Fout0_count).im; 318 319 C_MUL(scratch[1],d_Fout,Fout1_count,d_twiddles,u*fstride ); 320 C_MUL(scratch[2],d_Fout,Fout2_count,d_twiddles,2*u*fstride); 321 C_MUL(scratch[3],d_Fout,Fout3_count,d_twiddles,3*u*fstride); 322 C_MUL(scratch[4],d_Fout,Fout4_count,d_twiddles,4*u*fstride); 323 324 //C_ADD(scratch[7],scratch[1],scratch[4]); 325 scratch[7][0] = scratch[1][0] + scratch[4][0]; 326 scratch[7][1] = scratch[1][1] + scratch[4][1]; 327 //C_SUB(scratch[10],scratch[1],scratch[4]); 328 scratch[10][0] = scratch[1][0] - scratch[4][0]; 329 scratch[10][1] = scratch[1][1] - scratch[4][1]; 330 //C_ADD(scratch[8],scratch[2],scratch[3]); 331 scratch[8][0] = scratch[2][0] + scratch[3][0]; 332 scratch[8][1] = scratch[2][1] + scratch[3][1]; 333 //C_SUB(scratch[9],scratch[2],scratch[3]); 334 scratch[9][0] = scratch[2][0] - scratch[3][0]; 335 scratch[9][1] = scratch[2][1] - scratch[3][1]; 336 337 d_Fout(Fout0_count).re += scratch[7][0] + scratch[8][0]; 338 d_Fout(Fout0_count).im += scratch[7][1] + scratch[8][1]; 339 340 scratch[5][0] = scratch[0][0] + S_MUL(scratch[7][0],ya[0]) + S_MUL(scratch[8][0],yb[0]); 341 scratch[5][1] = scratch[0][1] + S_MUL(scratch[7][1],ya[0]) + S_MUL(scratch[8][1],yb[0]); 342 343 scratch[6][0] = S_MUL(scratch[10][1],ya[1]) + S_MUL(scratch[9][1],yb[1]); 344 scratch[6][1] = -S_MUL(scratch[10][0],ya[1]) - S_MUL(scratch[9][0],yb[1]); 345 346 //C_SUB(d_Fout[Fout1_count],scratch[5],scratch[6]); 347 d_Fout(Fout1_count).re = scratch[5][0] - scratch[6][0]; 348 d_Fout(Fout1_count).im = scratch[5][1] - scratch[6][1]; 349 //C_ADD(d_Fout[Fout4_count],scratch[5],scratch[6]); 350 d_Fout(Fout4_count).re = scratch[5][0] + scratch[6][0]; 351 d_Fout(Fout4_count).im = scratch[5][1] + scratch[6][1]; 352 353 scratch[11][0] = scratch[0][0] + S_MUL(scratch[7][0],yb[0]) + S_MUL(scratch[8][0],ya[0]); 354 scratch[11][1] = scratch[0][1] + S_MUL(scratch[7][1],yb[0]) + S_MUL(scratch[8][1],ya[0]); 355 scratch[12][0] = - S_MUL(scratch[10][1],yb[1]) + S_MUL(scratch[9][1],ya[1]); 356 scratch[12][1] = S_MUL(scratch[10][0],yb[1]) - S_MUL(scratch[9][0],ya[1]); 357 358 //C_ADD(d_Fout[Fout2_count],scratch[11],scratch[12]); 359 d_Fout(Fout2_count).re = scratch[11][0] + scratch[12][0]; 360 d_Fout(Fout2_count).im = scratch[11][1] + scratch[12][1]; 361 //C_SUB(d_Fout3[Fout3_count],scratch[11],scratch[12]); 362 d_Fout(Fout3_count).re = scratch[11][0] - scratch[12][0]; 363 d_Fout(Fout3_count).im = scratch[11][1] - scratch[12][1]; 364 365 ++Fout0_count;++Fout1_count;++Fout2_count;++Fout3_count;++Fout4_count; 366 } 367 } 368 369 /* perform the butterfly for one stage of a mixed radix FFT */ 370 371 KOKKOS_INLINE_FUNCTION kf_bfly_generic(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const size_t fstride,const kiss_fft_state_kokkos<DeviceType> & st,int m,int p,int Fout_count)372 static void kf_bfly_generic(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const size_t fstride, 373 const kiss_fft_state_kokkos<DeviceType> &st, int m, int p, int Fout_count) 374 { 375 int u,k,q1,q; 376 typename FFT_AT::t_FFT_DATA_1d_um d_twiddles = st.d_twiddles; 377 FFT_SCALAR t[2]; 378 int Norig = st.nfft; 379 380 typename FFT_AT::t_FFT_DATA_1d_um d_scratch = st.d_scratch; 381 for ( u=0; u<m; ++u ) { 382 k=u; 383 for ( q1=0 ; q1<p ; ++q1 ) { 384 //C_EQ(d_scratch[q1],d_Fout[k]); 385 d_scratch(q1).re = d_Fout(Fout_count + k).re; 386 d_scratch(q1).im = d_Fout(Fout_count + k).im; 387 //C_FIXDIV(d_scratch[q1],p); 388 k += m; 389 } 390 391 k=u; 392 for ( q1=0 ; q1<p ; ++q1 ) { 393 int twidx=0; 394 //C_EQ(d_Fout[k],d_scratch[0]); 395 d_Fout(Fout_count + k).re = d_scratch(0).re; 396 d_Fout(Fout_count + k).im = d_scratch(0).im; 397 for (q=1;q<p;++q ) { 398 twidx += fstride * k; 399 if (twidx>=Norig) twidx-=Norig; 400 C_MUL(t,d_scratch,q,d_twiddles,twidx); 401 //C_ADDTO(d_Fout[k],t); 402 d_Fout(Fout_count + k).re += t[0]; 403 d_Fout(Fout_count + k).im += t[1]; 404 } 405 k += m; 406 } 407 } 408 } 409 410 KOKKOS_INLINE_FUNCTION kf_work(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const typename FFT_AT::t_FFT_DATA_1d_um & d_f,const size_t fstride,int in_stride,const typename FFT_AT::t_int_64_um & d_factors,const kiss_fft_state_kokkos<DeviceType> & st,int Fout_count,int f_count,int factors_count)411 static void kf_work(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const typename FFT_AT::t_FFT_DATA_1d_um &d_f, 412 const size_t fstride, int in_stride, 413 const typename FFT_AT::t_int_64_um &d_factors, const kiss_fft_state_kokkos<DeviceType> &st, int Fout_count, int f_count, int factors_count) 414 { 415 const int beg = Fout_count; 416 const int p = d_factors[factors_count++]; /* the radix */ 417 const int m = d_factors[factors_count++]; /* stage's fft length/p */ 418 const int end = Fout_count + p*m; 419 420 if (m == 1) { 421 do { 422 //C_EQ(d_Fout[Fout_count],d_f[f_count]); 423 d_Fout(Fout_count).re = d_f(f_count).re; 424 d_Fout(Fout_count).im = d_f(f_count).im; 425 f_count += fstride*in_stride; 426 } while (++Fout_count != end); 427 } else { 428 do { 429 /* recursive call: 430 DFT of size m*p performed by doing 431 p instances of smaller DFTs of size m, 432 each one takes a decimated version of the input */ 433 kf_work(d_Fout, d_f, fstride*p, in_stride, d_factors, st, Fout_count, f_count, factors_count); 434 f_count += fstride*in_stride; 435 } while( (Fout_count += m) != end); 436 } 437 438 Fout_count=beg; 439 440 /* recombine the p smaller DFTs */ 441 switch (p) { 442 case 2: kf_bfly2(d_Fout,fstride,st,m,Fout_count); break; 443 case 3: kf_bfly3(d_Fout,fstride,st,m,Fout_count); break; 444 case 4: kf_bfly4(d_Fout,fstride,st,m,Fout_count); break; 445 case 5: kf_bfly5(d_Fout,fstride,st,m,Fout_count); break; 446 default: kf_bfly_generic(d_Fout,fstride,st,m,p,Fout_count); break; 447 } 448 } 449 450 /* facbuf is populated by p1,m1,p2,m2, ... 451 where 452 p[i] * m[i] = m[i-1] 453 m0 = n */ 454 kf_factor(int n,FFT_HAT::t_int_64 h_facbuf)455 static int kf_factor(int n, FFT_HAT::t_int_64 h_facbuf) 456 { 457 int p=4, nf=0; 458 double floor_sqrt; 459 floor_sqrt = floor( sqrt((double)n) ); 460 int facbuf_count = 0; 461 int p_max = 0; 462 463 /* factor out the remaining powers of 4, powers of 2, 464 and then any other remaining primes */ 465 do { 466 if (nf == MAXFACTORS) p = n; /* make certain that we don't run out of space */ 467 while (n % p) { 468 switch (p) { 469 case 4: p = 2; break; 470 case 2: p = 3; break; 471 default: p += 2; break; 472 } 473 if (p > floor_sqrt) 474 p = n; /* no more factors, skip to end */ 475 } 476 n /= p; 477 h_facbuf[facbuf_count++] = p; 478 h_facbuf[facbuf_count++] = n; 479 p_max = MAX(p,p_max); 480 ++nf; 481 } while (n > 1); 482 return p_max; 483 } 484 485 /* 486 * User-callable function to allocate all necessary storage space for the fft. 487 * 488 * The return value is a contiguous block of memory, allocated with malloc. As such, 489 * It can be freed with free(), rather than a kiss_fft-specific function. 490 */ 491 kiss_fft_alloc_kokkos(int nfft,int inverse_fft,void * mem,size_t * lenmem)492 static kiss_fft_state_kokkos<DeviceType> kiss_fft_alloc_kokkos(int nfft, int inverse_fft, void *mem, size_t *lenmem) 493 { 494 kiss_fft_state_kokkos<DeviceType> st; 495 int i; 496 st.nfft = nfft; 497 st.inverse = inverse_fft; 498 499 typename FFT_AT::tdual_int_64 k_factors = typename FFT_AT::tdual_int_64(); 500 typename FFT_AT::tdual_FFT_DATA_1d k_twiddles = typename FFT_AT::tdual_FFT_DATA_1d(); 501 502 if (nfft > 0) { 503 k_factors = typename FFT_AT::tdual_int_64("kissfft:factors",MAXFACTORS*2); 504 k_twiddles = typename FFT_AT::tdual_FFT_DATA_1d("kissfft:twiddles",nfft); 505 506 for (i=0;i<nfft;++i) { 507 const double phase = (st.inverse ? 2.0*M_PI:-2.0*M_PI)*i / nfft; 508 kf_cexp(k_twiddles.h_view,i,phase ); 509 } 510 511 int p_max = kf_factor(nfft,k_factors.h_view); 512 st.d_scratch = typename FFT_AT::t_FFT_DATA_1d("kissfft:scratch",p_max); 513 } 514 515 k_factors.template modify<LMPHostType>(); 516 k_factors.template sync<LMPDeviceType>(); 517 st.d_factors = k_factors.template view<DeviceType>(); 518 519 k_twiddles.template modify<LMPHostType>(); 520 k_twiddles.template sync<LMPDeviceType>(); 521 st.d_twiddles = k_twiddles.template view<DeviceType>(); 522 523 return st; 524 } 525 526 KOKKOS_INLINE_FUNCTION kiss_fft_stride(const kiss_fft_state_kokkos<DeviceType> & st,const typename FFT_AT::t_FFT_DATA_1d_um & d_fin,typename FFT_AT::t_FFT_DATA_1d_um & d_fout,int in_stride,int offset)527 static void kiss_fft_stride(const kiss_fft_state_kokkos<DeviceType> &st, const typename FFT_AT::t_FFT_DATA_1d_um &d_fin, typename FFT_AT::t_FFT_DATA_1d_um &d_fout, int in_stride, int offset) 528 { 529 //if (d_fin.data() == d_fout.data()) { 530 // // NOTE: this is not really an in-place FFT algorithm. 531 // // It just performs an out-of-place FFT into a temp buffer 532 // typename FFT_AT::t_FFT_DATA_1d_um d_tmpbuf = typename FFT_AT::t_FFT_DATA_1d("kissfft:tmpbuf",d_fin.extent(1)); 533 // kf_work(d_tmpbuf,d_fin,1,in_stride,st.d_factors,st,offset,offset).re; 534 // Kokkos::deep_copy(d_fout,d_tmpbuf); 535 //} else { 536 kf_work(d_fout,d_fin,1,in_stride,st.d_factors,st,offset,offset,0); 537 //} 538 } 539 540 KOKKOS_INLINE_FUNCTION kiss_fft_kokkos(const kiss_fft_state_kokkos<DeviceType> & cfg,const typename FFT_AT::t_FFT_DATA_1d_um d_fin,typename FFT_AT::t_FFT_DATA_1d_um d_fout,int offset)541 static void kiss_fft_kokkos(const kiss_fft_state_kokkos<DeviceType> &cfg, const typename FFT_AT::t_FFT_DATA_1d_um d_fin, typename FFT_AT::t_FFT_DATA_1d_um d_fout, int offset) 542 { 543 kiss_fft_stride(cfg,d_fin,d_fout,1,offset); 544 } 545 546 }; 547 548 } 549 #endif 550