1 // clang-format off
2 /* -*- c++ -*- ----------------------------------------------------------
3    LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
4    https://www.lammps.org/, Sandia National Laboratories
5    Steve Plimpton, sjplimp@sandia.gov
6 
7    Copyright (2003) Sandia Corporation.  Under the terms of Contract
8    DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
9    certain rights in this software.  This software is distributed under
10    the GNU General Public License.
11 
12    See the README file in the top-level LAMMPS directory.
13 ------------------------------------------------------------------------- */
14 
15 /*
16    we use a stripped down KISS FFT as default FFT for LAMMPS
17    this code is adapted from kiss_fft_v1_2_9
18    homepage: http://kissfft.sf.net/
19 
20    changes 2008-2011 by Axel Kohlmeyer <akohlmey@gmail.com>
21 
22    KISS FFT ported to Kokkos by Stan Moore (SNL)
23 */
24 
25 /*
26   Copyright (c) 2003-2010, Mark Borgerding
27 
28   All rights reserved.
29 
30   Redistribution and use in source and binary forms, with or without
31   modification, are permitted provided that the following conditions are
32   met:
33 
34     * Redistributions of source code must retain the above copyright
35       notice, this list of conditions and the following disclaimer.
36 
37     * Redistributions in binary form must reproduce the above copyright
38       notice, this list of conditions and the following disclaimer in
39       the documentation and/or other materials provided with the
40       distribution.
41 
42     * Neither the author nor the names of any contributors may be used
43       to endorse or promote products derived from this software without
44       specific prior written permission.
45 
46   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
47   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
48   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
49   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
50   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
51   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
52   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
53   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
54   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
55   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
56   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
57 */
58 
59 
60 #ifndef LMP_KISSFFT_KOKKOS_H
61 #define LMP_KISSFFT_KOKKOS_H
62 
63 #include <stdlib.h>
64 #include <string.h>
65 #include <math.h>
66 #include "fftdata_kokkos.h"
67 
/* M_PI is not required by ISO C/C++ <math.h>; provide a fallback value
   for platforms whose headers do not define it. */
#ifndef M_PI
#define M_PI 3.141592653589793238462643383279502884197169399375105820974944
#endif
71 
72 /*
73   Explanation of macros dealing with complex math:
74 
75    C_MUL(m,a,b)         : m = a*b
76    C_FIXDIV( c , div )  : if a fixed point impl., c /= div. noop otherwise
77    C_SUB( res, a,b)     : res = a - b
78    C_SUBFROM( res , a)  : res -= a
79    C_ADDTO( res , a)    : res += a
80    C_EQ( res , a)       : res = a
81 */
82 
/* Product of two real scalars; both operands are parenthesized so that
   expressions such as S_MUL(x+y, z) multiply the full sums. */
#define S_MUL(lhs,rhs) ( (lhs)*(rhs) )

/* Complex product out = x(ix) * y(iy).
   x and y are accessed with operator() and must yield elements with .re and
   .im members; out is a two-element array that receives {real, imag}. */
#define C_MUL(out,x,ix,y,iy) \
    do { \
        (out)[0] = (x)(ix).re*(y)(iy).re - (x)(ix).im*(y)(iy).im; \
        (out)[1] = (x)(ix).re*(y)(iy).im + (x)(ix).im*(y)(iy).re; \
    } while (0)
88 
89 /*
90 #define C_FIXDIV(c,div) // NOOP
91 
92 #define C_MULBYSCALAR( c, s ) \
93     do{ (c)[0] *= (s);\
94         (c)[1] *= (s); }while(0)
95 
96 #define  C_ADD( res, a,b)\
97     do { \
98             (res)[0]=(a)[0]+(b)[0];  (res)[1]=(a)[1]+(b)[1]; \
99     }while(0)
100 
101 #define  C_SUB( res, a,b)\
102     do { \
103             (res)[0]=(a)[0]-(b)[0];  (res)[1]=(a)[1]-(b)[1]; \
104     }while(0)
105 
106 #define C_ADDTO( res , a)\
107     do { \
108             (res)[0] += (a)[0];  (res)[1] += (a)[1];\
109     }while(0)
110 
111 #define C_SUBFROM( res , a)\
112     do {\
113             (res)[0] -= (a)[0];  (res)[1] -= (a)[1]; \
114     }while(0)
115 
116 #define C_EQ(res, a)\
117     do {\
118             (res)[0] = (a)[0];  (res)[1] = (a)[1]; \
119     }while(0)
120 */
121 
/* Cosine / sine of a phase converted to FFT_SCALAR.  The whole expansion is
   parenthesized (and uses a named cast) so each macro behaves as a single
   operand wherever it is used. */
#define KISS_FFT_COS(phase) (static_cast<FFT_SCALAR>(cos(phase)))
#define KISS_FFT_SIN(phase) (static_cast<FFT_SCALAR>(sin(phase)))
/* Half of x; x is parenthesized so HALF_OF(a+b) halves the whole sum. */
#define HALF_OF(x) ((x)*.5)

/* Store the unit complex number exp(i*phase) into x(x_index):
   .re = cos(phase), .im = sin(phase), both as FFT_SCALAR. */
#define  kf_cexp(x,x_index,phase) \
        do{ \
                (x)(x_index).re = KISS_FFT_COS(phase);\
                (x)(x_index).im = KISS_FFT_SIN(phase);\
        }while(0)
131 
132 
133 namespace LAMMPS_NS {
134 
135 #define MAXFACTORS 32
136 /* e.g. an fft of length 128 has 4 factors
137  as far as kissfft is concerned: 4*4*4*2  */
138 template<class DeviceType>
139 struct kiss_fft_state_kokkos {
140   typedef DeviceType device_type;
141   typedef FFTArrayTypes<DeviceType> FFT_AT;
142   int nfft;
143   int inverse;
144   typename FFT_AT::t_int_64 d_factors;
145   typename FFT_AT::t_FFT_DATA_1d d_twiddles;
146   typename FFT_AT::t_FFT_DATA_1d d_scratch;
147 };
148 
149 template<class DeviceType>
150 class KissFFTKokkos {
151  public:
152   typedef DeviceType device_type;
153   typedef FFTArrayTypes<DeviceType> FFT_AT;
154 
155   KOKKOS_INLINE_FUNCTION
kf_bfly2(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const size_t fstride,const kiss_fft_state_kokkos<DeviceType> & st,int m,int Fout_count)156   static void kf_bfly2(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const size_t fstride,
157                        const kiss_fft_state_kokkos<DeviceType> &st, int m, int Fout_count)
158   {
159       typename FFT_AT::t_FFT_DATA_1d_um d_twiddles = st.d_twiddles;
160       FFT_SCALAR t[2];
161       int Fout2_count;
162       int tw1_count = 0;
163 
164       Fout2_count = Fout_count + m;
165       do {
166           //C_FIXDIV(d_Fout[Fout_count],2); C_FIXDIV(d_Fout[Fout2_count],2);
167 
168           C_MUL(t,d_Fout,Fout2_count,d_twiddles,tw1_count);
169           tw1_count += fstride;
170           //C_SUB(*Fout2,*Fout,t);
171           d_Fout(Fout2_count).re = d_Fout(Fout_count).re - t[0];
172           d_Fout(Fout2_count).im = d_Fout(Fout_count).im - t[1];
173           //C_ADDTO(d_Fout[Fout_count],t);
174           d_Fout(Fout_count).re += t[0];
175           d_Fout(Fout_count).im += t[1];
176           ++Fout2_count;
177           ++Fout_count;
178       } while(--m);
179   }
180 
181   KOKKOS_INLINE_FUNCTION
kf_bfly4(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const size_t fstride,const kiss_fft_state_kokkos<DeviceType> & st,const size_t m,int Fout_count)182   static void kf_bfly4(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const size_t fstride,
183                        const kiss_fft_state_kokkos<DeviceType> &st, const size_t m, int Fout_count)
184   {
185       typename FFT_AT::t_FFT_DATA_1d_um d_twiddles = st.d_twiddles;
186       FFT_SCALAR scratch[6][2];
187       size_t k=m;
188       const size_t m2=2*m;
189       const size_t m3=3*m;
190 
191       int tw3_count,tw2_count,tw1_count;
192       tw3_count = tw2_count = tw1_count = 0;
193 
194       do {
195           //C_FIXDIV(d_Fout[Fout_count],4); C_FIXDIV(d_Fout[m],4); C_FIXDIV(d_Fout[m2],4); C_FIXDIV(d_Fout[m3],4);
196 
197           C_MUL(scratch[0],d_Fout,Fout_count + m,d_twiddles,tw1_count);
198           C_MUL(scratch[1],d_Fout,Fout_count + m2,d_twiddles,tw2_count);
199           C_MUL(scratch[2],d_Fout,Fout_count + m3,d_twiddles,tw3_count);
200 
201           //C_SUB(scratch[5],d_Fout[Fout_count],scratch[1] );
202           scratch[5][0] = d_Fout(Fout_count).re - scratch[1][0];
203           scratch[5][1] = d_Fout(Fout_count).im - scratch[1][1];
204           //C_ADDTO(d_Fout[Fout_count], scratch[1]);
205           d_Fout(Fout_count).re += scratch[1][0];
206           d_Fout(Fout_count).im += scratch[1][1];
207           //C_ADD(scratch[3],scratch[0],scratch[2]);
208           scratch[3][0] = scratch[0][0] + scratch[2][0];
209           scratch[3][1] = scratch[0][1] + scratch[2][1];
210           //C_SUB( scratch[4] , scratch[0] , scratch[2] );
211           scratch[4][0] = scratch[0][0] - scratch[2][0];
212           scratch[4][1] = scratch[0][1] - scratch[2][1];
213           //C_SUB(d_Fout[m2],d_Fout[Fout_count],scratch[3]);
214           d_Fout(Fout_count + m2).re = d_Fout(Fout_count).re - scratch[3][0];
215           d_Fout(Fout_count + m2).im = d_Fout(Fout_count).im - scratch[3][1];
216 
217           tw1_count += fstride;
218           tw2_count += fstride*2;
219           tw3_count += fstride*3;
220           //C_ADDTO(d_Fout[Fout_count],scratch[3]);
221           d_Fout(Fout_count).re += scratch[3][0];
222           d_Fout(Fout_count).im += scratch[3][1];
223 
224           if (st.inverse) {
225               d_Fout(Fout_count + m).re = scratch[5][0] - scratch[4][1];
226               d_Fout(Fout_count + m).im = scratch[5][1] + scratch[4][0];
227               d_Fout(Fout_count + m3).re = scratch[5][0] + scratch[4][1];
228               d_Fout(Fout_count + m3).im = scratch[5][1] - scratch[4][0];
229           } else{
230               d_Fout(Fout_count + m).re = scratch[5][0] + scratch[4][1];
231               d_Fout(Fout_count + m).im = scratch[5][1] - scratch[4][0];
232               d_Fout(Fout_count + m3).re = scratch[5][0] - scratch[4][1];
233               d_Fout(Fout_count + m3).im = scratch[5][1] + scratch[4][0];
234           }
235           ++Fout_count;
236       } while(--k);
237   }
238 
239   KOKKOS_INLINE_FUNCTION
kf_bfly3(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const size_t fstride,const kiss_fft_state_kokkos<DeviceType> & st,size_t m,int Fout_count)240   static void kf_bfly3(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const size_t fstride,
241                        const kiss_fft_state_kokkos<DeviceType> &st, size_t m, int Fout_count)
242   {
243       size_t k=m;
244       const size_t m2 = 2*m;
245       typename FFT_AT::t_FFT_DATA_1d_um d_twiddles = st.d_twiddles;
246       FFT_SCALAR scratch[5][2];
247       FFT_SCALAR epi3[2];
248       //C_EQ(epi3,d_twiddles[fstride*m]);
249       epi3[0] = d_twiddles(fstride*m).re;
250       epi3[1] = d_twiddles(fstride*m).im;
251 
252       int tw1_count,tw2_count;
253       tw1_count = tw2_count = 0;
254 
255       do {
256           //C_FIXDIV(d_Fout[Fout_count],3); C_FIXDIV(d_Fout[m],3); C_FIXDIV(d_Fout[m2],3);
257 
258           C_MUL(scratch[1],d_Fout,Fout_count + m,d_twiddles,tw1_count);
259           C_MUL(scratch[2],d_Fout,Fout_count + m2,d_twiddles,tw2_count);
260 
261           //C_ADD(scratch[3],scratch[1],scratch[2]);
262           scratch[3][0] = scratch[1][0] + scratch[2][0];
263           scratch[3][1] = scratch[1][1] + scratch[2][1];
264           //C_SUB(scratch[0],scratch[1],scratch[2]);
265           scratch[0][0] = scratch[1][0] - scratch[2][0];
266           scratch[0][1] = scratch[1][1] - scratch[2][1];
267           tw1_count += fstride;
268           tw2_count += fstride*2;
269 
270           d_Fout(Fout_count + m).re = d_Fout(Fout_count).re - HALF_OF(scratch[3][0]);
271           d_Fout(Fout_count + m).im = d_Fout(Fout_count).im - HALF_OF(scratch[3][1]);
272 
273           //C_MULBYSCALAR(scratch[0],epi3[1]);
274           scratch[0][0] *= epi3[1];
275           scratch[0][1] *= epi3[1];
276 
277           //C_ADDTO(d_Fout[Fout_count],scratch[3]);
278           d_Fout(Fout_count).re += scratch[3][0];
279           d_Fout(Fout_count).im += scratch[3][1];
280 
281           d_Fout(Fout_count + m2).re = d_Fout(Fout_count + m).re + scratch[0][1];
282           d_Fout(Fout_count + m2).im = d_Fout(Fout_count + m).im - scratch[0][0];
283 
284           d_Fout(Fout_count + m).re -= scratch[0][1];
285           d_Fout(Fout_count + m).im += scratch[0][0];
286 
287           ++Fout_count;
288       } while(--k);
289   }
290 
291   KOKKOS_INLINE_FUNCTION
kf_bfly5(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const size_t fstride,const kiss_fft_state_kokkos<DeviceType> & st,int m,int Fout_count)292   static void kf_bfly5(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const size_t fstride,
293                        const kiss_fft_state_kokkos<DeviceType> &st, int m, int Fout_count)
294   {
295       int u;
296       FFT_SCALAR scratch[13][2];
297       typename FFT_AT::t_FFT_DATA_1d_um d_twiddles = st.d_twiddles;
298       FFT_SCALAR ya[2],yb[2];
299       //C_EQ(ya,d_twiddles[fstride*m]);
300       ya[1] = d_twiddles(fstride*m).im;
301       ya[0] = d_twiddles(fstride*m).re;
302       //C_EQ(yb,d_twiddles[fstride*2*m]);
303       yb[1] = d_twiddles(fstride*2*m).im;
304       yb[0] = d_twiddles(fstride*2*m).re;
305 
306       int Fout0_count=Fout_count;
307       int Fout1_count=Fout0_count+m;
308       int Fout2_count=Fout0_count+2*m;
309       int Fout3_count=Fout0_count+3*m;
310       int Fout4_count=Fout0_count+4*m;
311 
312       for ( u=0; u<m; ++u ) {
313           //C_FIXDIV( d_Fout[Fout0_count],5); C_FIXDIV( d_Fout[Fout1_count],5); C_FIXDIV( d_Fout[Fout2_count],5);
314           //C_FIXDIV( d_Fout[Fout3_count],5); C_FIXDIV( d_Fout[Fout4_count],5);
315           //C_EQ(scratch[0],d_Fout[Fout0_count]);
316           scratch[0][0] = d_Fout(Fout0_count).re;
317           scratch[0][1] = d_Fout(Fout0_count).im;
318 
319           C_MUL(scratch[1],d_Fout,Fout1_count,d_twiddles,u*fstride  );
320           C_MUL(scratch[2],d_Fout,Fout2_count,d_twiddles,2*u*fstride);
321           C_MUL(scratch[3],d_Fout,Fout3_count,d_twiddles,3*u*fstride);
322           C_MUL(scratch[4],d_Fout,Fout4_count,d_twiddles,4*u*fstride);
323 
324           //C_ADD(scratch[7],scratch[1],scratch[4]);
325           scratch[7][0] = scratch[1][0] + scratch[4][0];
326           scratch[7][1] = scratch[1][1] + scratch[4][1];
327           //C_SUB(scratch[10],scratch[1],scratch[4]);
328           scratch[10][0] = scratch[1][0] - scratch[4][0];
329           scratch[10][1] = scratch[1][1] - scratch[4][1];
330           //C_ADD(scratch[8],scratch[2],scratch[3]);
331           scratch[8][0] = scratch[2][0] + scratch[3][0];
332           scratch[8][1] = scratch[2][1] + scratch[3][1];
333           //C_SUB(scratch[9],scratch[2],scratch[3]);
334           scratch[9][0] = scratch[2][0] - scratch[3][0];
335           scratch[9][1] = scratch[2][1] - scratch[3][1];
336 
337           d_Fout(Fout0_count).re += scratch[7][0] + scratch[8][0];
338           d_Fout(Fout0_count).im += scratch[7][1] + scratch[8][1];
339 
340           scratch[5][0] = scratch[0][0] + S_MUL(scratch[7][0],ya[0]) + S_MUL(scratch[8][0],yb[0]);
341           scratch[5][1] = scratch[0][1] + S_MUL(scratch[7][1],ya[0]) + S_MUL(scratch[8][1],yb[0]);
342 
343           scratch[6][0] =  S_MUL(scratch[10][1],ya[1]) + S_MUL(scratch[9][1],yb[1]);
344           scratch[6][1] = -S_MUL(scratch[10][0],ya[1]) - S_MUL(scratch[9][0],yb[1]);
345 
346           //C_SUB(d_Fout[Fout1_count],scratch[5],scratch[6]);
347           d_Fout(Fout1_count).re = scratch[5][0] - scratch[6][0];
348           d_Fout(Fout1_count).im = scratch[5][1] - scratch[6][1];
349           //C_ADD(d_Fout[Fout4_count],scratch[5],scratch[6]);
350           d_Fout(Fout4_count).re = scratch[5][0] + scratch[6][0];
351           d_Fout(Fout4_count).im = scratch[5][1] + scratch[6][1];
352 
353           scratch[11][0] = scratch[0][0] + S_MUL(scratch[7][0],yb[0]) + S_MUL(scratch[8][0],ya[0]);
354           scratch[11][1] = scratch[0][1] + S_MUL(scratch[7][1],yb[0]) + S_MUL(scratch[8][1],ya[0]);
355           scratch[12][0] = - S_MUL(scratch[10][1],yb[1]) + S_MUL(scratch[9][1],ya[1]);
356           scratch[12][1] = S_MUL(scratch[10][0],yb[1]) - S_MUL(scratch[9][0],ya[1]);
357 
358           //C_ADD(d_Fout[Fout2_count],scratch[11],scratch[12]);
359           d_Fout(Fout2_count).re = scratch[11][0] + scratch[12][0];
360           d_Fout(Fout2_count).im = scratch[11][1] + scratch[12][1];
361           //C_SUB(d_Fout3[Fout3_count],scratch[11],scratch[12]);
362           d_Fout(Fout3_count).re = scratch[11][0] - scratch[12][0];
363           d_Fout(Fout3_count).im = scratch[11][1] - scratch[12][1];
364 
365           ++Fout0_count;++Fout1_count;++Fout2_count;++Fout3_count;++Fout4_count;
366       }
367   }
368 
369   /* perform the butterfly for one stage of a mixed radix FFT */
370 
371   KOKKOS_INLINE_FUNCTION
kf_bfly_generic(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const size_t fstride,const kiss_fft_state_kokkos<DeviceType> & st,int m,int p,int Fout_count)372   static void kf_bfly_generic(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const size_t fstride,
373                               const kiss_fft_state_kokkos<DeviceType> &st, int m, int p, int Fout_count)
374   {
375       int u,k,q1,q;
376       typename FFT_AT::t_FFT_DATA_1d_um d_twiddles = st.d_twiddles;
377       FFT_SCALAR t[2];
378       int Norig = st.nfft;
379 
380       typename FFT_AT::t_FFT_DATA_1d_um d_scratch = st.d_scratch;
381       for ( u=0; u<m; ++u ) {
382           k=u;
383           for ( q1=0 ; q1<p ; ++q1 ) {
384               //C_EQ(d_scratch[q1],d_Fout[k]);
385               d_scratch(q1).re = d_Fout(Fout_count + k).re;
386               d_scratch(q1).im = d_Fout(Fout_count + k).im;
387               //C_FIXDIV(d_scratch[q1],p);
388               k += m;
389           }
390 
391           k=u;
392           for ( q1=0 ; q1<p ; ++q1 ) {
393               int twidx=0;
394               //C_EQ(d_Fout[k],d_scratch[0]);
395               d_Fout(Fout_count + k).re = d_scratch(0).re;
396               d_Fout(Fout_count + k).im = d_scratch(0).im;
397               for (q=1;q<p;++q ) {
398                   twidx += fstride * k;
399                   if (twidx>=Norig) twidx-=Norig;
400                   C_MUL(t,d_scratch,q,d_twiddles,twidx);
401                   //C_ADDTO(d_Fout[k],t);
402                   d_Fout(Fout_count + k).re += t[0];
403                   d_Fout(Fout_count + k).im += t[1];
404               }
405               k += m;
406           }
407       }
408   }
409 
410   KOKKOS_INLINE_FUNCTION
kf_work(typename FFT_AT::t_FFT_DATA_1d_um & d_Fout,const typename FFT_AT::t_FFT_DATA_1d_um & d_f,const size_t fstride,int in_stride,const typename FFT_AT::t_int_64_um & d_factors,const kiss_fft_state_kokkos<DeviceType> & st,int Fout_count,int f_count,int factors_count)411   static void kf_work(typename FFT_AT::t_FFT_DATA_1d_um &d_Fout, const typename FFT_AT::t_FFT_DATA_1d_um &d_f,
412                       const size_t fstride, int in_stride,
413                       const typename FFT_AT::t_int_64_um &d_factors, const kiss_fft_state_kokkos<DeviceType> &st, int Fout_count, int f_count, int factors_count)
414   {
415       const int beg = Fout_count;
416       const int p = d_factors[factors_count++]; /* the radix  */
417       const int m = d_factors[factors_count++]; /* stage's fft length/p */
418       const int end = Fout_count + p*m;
419 
420       if (m == 1) {
421           do {
422               //C_EQ(d_Fout[Fout_count],d_f[f_count]);
423               d_Fout(Fout_count).re = d_f(f_count).re;
424               d_Fout(Fout_count).im = d_f(f_count).im;
425               f_count += fstride*in_stride;
426           } while (++Fout_count != end);
427       } else {
428           do {
429               /* recursive call:
430                  DFT of size m*p performed by doing
431                  p instances of smaller DFTs of size m,
432                  each one takes a decimated version of the input */
433               kf_work(d_Fout, d_f, fstride*p, in_stride, d_factors, st, Fout_count, f_count, factors_count);
434               f_count += fstride*in_stride;
435           } while( (Fout_count += m) != end);
436       }
437 
438       Fout_count=beg;
439 
440       /* recombine the p smaller DFTs */
441       switch (p) {
442         case 2: kf_bfly2(d_Fout,fstride,st,m,Fout_count); break;
443         case 3: kf_bfly3(d_Fout,fstride,st,m,Fout_count); break;
444         case 4: kf_bfly4(d_Fout,fstride,st,m,Fout_count); break;
445         case 5: kf_bfly5(d_Fout,fstride,st,m,Fout_count); break;
446         default: kf_bfly_generic(d_Fout,fstride,st,m,p,Fout_count); break;
447       }
448   }
449 
450   /*  facbuf is populated by p1,m1,p2,m2, ...
451       where
452       p[i] * m[i] = m[i-1]
453       m0 = n                  */
454 
kf_factor(int n,FFT_HAT::t_int_64 h_facbuf)455   static int kf_factor(int n, FFT_HAT::t_int_64 h_facbuf)
456   {
457       int p=4, nf=0;
458       double floor_sqrt;
459       floor_sqrt = floor( sqrt((double)n) );
460       int facbuf_count = 0;
461       int p_max = 0;
462 
463       /* factor out the remaining powers of 4, powers of 2,
464          and then any other remaining primes */
465       do {
466           if (nf == MAXFACTORS) p = n; /* make certain that we don't run out of space */
467           while (n % p) {
468               switch (p) {
469                 case 4: p = 2; break;
470                 case 2: p = 3; break;
471                 default: p += 2; break;
472               }
473               if (p > floor_sqrt)
474                   p = n;          /* no more factors, skip to end */
475           }
476           n /= p;
477           h_facbuf[facbuf_count++] = p;
478           h_facbuf[facbuf_count++] = n;
479           p_max = MAX(p,p_max);
480           ++nf;
481       } while (n > 1);
482       return p_max;
483   }
484 
485   /*
486    * User-callable function to allocate all necessary storage space for the fft.
487    *
488    * The return value is a contiguous block of memory, allocated with malloc.  As such,
489    * It can be freed with free(), rather than a kiss_fft-specific function.
490    */
491 
kiss_fft_alloc_kokkos(int nfft,int inverse_fft,void * mem,size_t * lenmem)492   static kiss_fft_state_kokkos<DeviceType> kiss_fft_alloc_kokkos(int nfft, int inverse_fft, void *mem, size_t *lenmem)
493   {
494       kiss_fft_state_kokkos<DeviceType> st;
495       int i;
496       st.nfft = nfft;
497       st.inverse = inverse_fft;
498 
499       typename FFT_AT::tdual_int_64 k_factors = typename FFT_AT::tdual_int_64();
500       typename FFT_AT::tdual_FFT_DATA_1d k_twiddles = typename FFT_AT::tdual_FFT_DATA_1d();
501 
502       if (nfft > 0) {
503           k_factors = typename FFT_AT::tdual_int_64("kissfft:factors",MAXFACTORS*2);
504           k_twiddles = typename FFT_AT::tdual_FFT_DATA_1d("kissfft:twiddles",nfft);
505 
506           for (i=0;i<nfft;++i) {
507               const double phase = (st.inverse ? 2.0*M_PI:-2.0*M_PI)*i / nfft;
508               kf_cexp(k_twiddles.h_view,i,phase );
509           }
510 
511           int p_max = kf_factor(nfft,k_factors.h_view);
512           st.d_scratch = typename FFT_AT::t_FFT_DATA_1d("kissfft:scratch",p_max);
513       }
514 
515       k_factors.template modify<LMPHostType>();
516       k_factors.template sync<LMPDeviceType>();
517       st.d_factors = k_factors.template view<DeviceType>();
518 
519       k_twiddles.template modify<LMPHostType>();
520       k_twiddles.template sync<LMPDeviceType>();
521       st.d_twiddles = k_twiddles.template view<DeviceType>();
522 
523       return st;
524   }
525 
526   KOKKOS_INLINE_FUNCTION
kiss_fft_stride(const kiss_fft_state_kokkos<DeviceType> & st,const typename FFT_AT::t_FFT_DATA_1d_um & d_fin,typename FFT_AT::t_FFT_DATA_1d_um & d_fout,int in_stride,int offset)527   static void kiss_fft_stride(const kiss_fft_state_kokkos<DeviceType> &st, const typename FFT_AT::t_FFT_DATA_1d_um &d_fin, typename FFT_AT::t_FFT_DATA_1d_um &d_fout, int in_stride, int offset)
528   {
529       //if (d_fin.data() == d_fout.data()) {
530       //    // NOTE: this is not really an in-place FFT algorithm.
531       //    // It just performs an out-of-place FFT into a temp buffer
532       //    typename FFT_AT::t_FFT_DATA_1d_um d_tmpbuf = typename FFT_AT::t_FFT_DATA_1d("kissfft:tmpbuf",d_fin.extent(1));
533       //    kf_work(d_tmpbuf,d_fin,1,in_stride,st.d_factors,st,offset,offset).re;
534       //    Kokkos::deep_copy(d_fout,d_tmpbuf);
535       //} else {
536         kf_work(d_fout,d_fin,1,in_stride,st.d_factors,st,offset,offset,0);
537       //}
538   }
539 
540   KOKKOS_INLINE_FUNCTION
kiss_fft_kokkos(const kiss_fft_state_kokkos<DeviceType> & cfg,const typename FFT_AT::t_FFT_DATA_1d_um d_fin,typename FFT_AT::t_FFT_DATA_1d_um d_fout,int offset)541   static void kiss_fft_kokkos(const kiss_fft_state_kokkos<DeviceType> &cfg, const typename FFT_AT::t_FFT_DATA_1d_um d_fin, typename FFT_AT::t_FFT_DATA_1d_um d_fout, int offset)
542   {
543       kiss_fft_stride(cfg,d_fin,d_fout,1,offset);
544   }
545 
546 };
547 
548 }
549 #endif
550