1 #include <helib/PGFFT.h>
2 
3 #include <iostream>
4 #include <cstdlib>
5 #include <ctime>
6 
7 // these are just for the Fft stuff
8 #include <algorithm>
9 #include <cmath>
10 #include <cstddef>
11 #include <cstdint>
12 #include <stdexcept>
13 #include <limits>
14 #include <gtest/gtest.h>
15 
16 namespace {
17 
// RandomBnd(n) returns a pseudo-random number in [0..n).
// Requires n > 0; throws std::invalid_argument otherwise. Without the check,
// n == 0 is a modulo by zero and a negative n is reinterpreted as a huge
// unsigned modulus.
// FIXME: uses brain-dead rand() function
static long RandomBnd(long n)
{
  if (n <= 0)
    throw std::invalid_argument("RandomBnd: n must be positive");

  constexpr int kWordBits = std::numeric_limits<unsigned long>::digits;
  constexpr int kRotateBits = 7;
  constexpr int kRounds = 12;

  // Each rand() call only yields a limited number of random bits, so several
  // draws are folded into one word: rotate the accumulator left by
  // kRotateBits, then XOR in the next draw.
  unsigned long acc = 0;
  for (int round = 0; round < kRounds; ++round) {
    const unsigned long draw = static_cast<unsigned long>(std::rand());
    acc = (acc << kRotateBits) | (acc >> (kWordBits - kRotateBits));
    acc ^= draw;
  }

  return static_cast<long>(acc % static_cast<unsigned long>(n));
}
35 
// Seeds the C library rand() generator (the one RandomBnd draws from) with
// the current calendar time.
static void SetSeed()
{
  std::srand(static_cast<unsigned int>(std::time(nullptr)));
}
37 
38 //================== Fft ====================
39 
40 // I've modified this a bit from code I got here:
41 // https://www.nayuki.io/page/free-small-fft-in-multiple-languages
42 
43 // Specifically, I modified it to use long doubles instead of doubles
44 
45 // Here is the original copyright notice:
46 
47 /*
48  * Free FFT and convolution (C++)
49  *
50  * Copyright (c) 2017 Project Nayuki. (MIT License)
51  * https://www.nayuki.io/page/free-small-fft-in-multiple-languages
52  *
53  * Permission is hereby granted, free of charge, to any person obtaining a copy
54  * of this software and associated documentation files (the "Software"), to deal
55  * in the Software without restriction, including without limitation the rights
56  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
57  * copies of the Software, and to permit persons to whom the Software is
58  * furnished to do so, subject to the following conditions:
59  * - The above copyright notice and this permission notice shall be included in
60  *   all copies or substantial portions of the Software.
61  * - The Software is provided "as is", without warranty of any kind, express or
62  *   implied, including but not limited to the warranties of merchantability,
63  *   fitness for a particular purpose and noninfringement. In no event shall the
64  *   authors or copyright holders be liable for any claim, damages or other
65  *   liability, whether in an action of contract, tort or otherwise, arising
66  * from, out of or in connection with the Software or the use or other dealings
67  * in the Software.
68  */
69 
// Extended-precision scalar and complex types used by the reference FFT.
using ldbl = long double;
using lcx = std::complex<ldbl>;

namespace Fft {

/*
 * Forward DFT, computed in place. Any length is accepted: an empty vector is
 * left untouched, power-of-2 lengths go to transformRadix2(), and every other
 * length goes to transformBluestein().
 */
void transform(std::vector<lcx>& vec);

/*
 * Unscaled inverse DFT, computed in place, for vectors of any length. It is
 * implemented as conjugate / transform() / conjugate and does NOT divide by
 * the length, so it is not a true inverse on its own.
 */
void inverseTransform(std::vector<lcx>& vec);

/*
 * Forward DFT, computed in place, using the Cooley-Tukey decimation-in-time
 * radix-2 algorithm. Throws std::domain_error when the length is not a power
 * of 2.
 */
void transformRadix2(std::vector<lcx>& vec);

/*
 * Forward DFT, computed in place, for vectors of any length, using
 * Bluestein's chirp z-transform algorithm. Relies on convolve(), which in
 * turn relies on the radix-2 FFT. Throws std::length_error when the required
 * convolution length cannot be represented.
 */
void transformBluestein(std::vector<lcx>& vec);

/*
 * Circular convolution of xvec and yvec, written to outvec. All three
 * vectors must have the same length; otherwise std::domain_error is thrown.
 */
void convolve(const std::vector<lcx>& xvec,
              const std::vector<lcx>& yvec,
              std::vector<lcx>& outvec);

} // namespace Fft
114 
// Standard names that the code below uses unqualified.
using std::complex;
using std::size_t;
using std::vector;

// File-private helper, defined after the transforms: reverses the low n bits
// of x (used for the bit-reversed addressing in transformRadix2).
static size_t reverseBits(size_t x, int n);
121 
transform(vector<lcx> & vec)122 void Fft::transform(vector<lcx>& vec)
123 {
124   size_t n = vec.size();
125   if (n == 0)
126     return;
127   else if ((n & (n - 1)) == 0) // Is power of 2
128     transformRadix2(vec);
129   else // More complicated algorithm for arbitrary sizes
130     transformBluestein(vec);
131 }
132 
inverseTransform(vector<lcx> & vec)133 void Fft::inverseTransform(vector<lcx>& vec)
134 {
135   std::transform(vec.cbegin(),
136                  vec.cend(),
137                  vec.begin(),
138                  static_cast<lcx (*)(const lcx&)>(std::conj));
139   transform(vec);
140   std::transform(vec.cbegin(),
141                  vec.cend(),
142                  vec.begin(),
143                  static_cast<lcx (*)(const lcx&)>(std::conj));
144 }
145 
transformRadix2(vector<lcx> & vec)146 void Fft::transformRadix2(vector<lcx>& vec)
147 {
148   // Length variables
149   size_t n = vec.size();
150   int levels = 0; // Compute levels = floor(log2(n))
151   for (size_t temp = n; temp > 1U; temp >>= 1)
152     levels++;
153   if (static_cast<size_t>(1U) << levels != n)
154     throw std::domain_error("Length is not a power of 2");
155 
156   const ldbl pi = atan(ldbl(1)) * 4.0;
157 
158   // Trignometric table
159   vector<lcx> expTable(n / 2);
160   for (size_t i = 0; i < n / 2; i++) {
161     // expTable[i] = std::exp(lcx(0, -2 * M_PI * i / n));
162     ldbl angle = -2 * pi * i / n;
163     expTable[i] = lcx(std::cos(angle), std::sin(angle));
164   }
165 
166   // Bit-reversed addressing permutation
167   for (size_t i = 0; i < n; i++) {
168     size_t j = reverseBits(i, levels);
169     if (j > i)
170       std::swap(vec[i], vec[j]);
171   }
172 
173   // Cooley-Tukey decimation-in-time radix-2 FFT
174   for (size_t size = 2; size <= n; size *= 2) {
175     size_t halfsize = size / 2;
176     size_t tablestep = n / size;
177     for (size_t i = 0; i < n; i += size) {
178       for (size_t j = i, k = 0; j < i + halfsize; j++, k += tablestep) {
179         lcx temp = vec[j + halfsize] * expTable[k];
180         vec[j + halfsize] = vec[j] - temp;
181         vec[j] += temp;
182       }
183     }
184     if (size == n) // Prevent overflow in 'size *= 2'
185       break;
186   }
187 }
188 
transformBluestein(vector<lcx> & vec)189 void Fft::transformBluestein(vector<lcx>& vec)
190 {
191   // Find a power-of-2 convolution length m such that m >= n * 2 + 1
192   size_t n = vec.size();
193   size_t m = 1;
194   while (m / 2 <= n) {
195     if (m > SIZE_MAX / 2)
196       throw std::length_error("Vector too large");
197     m *= 2;
198   }
199 
200   const ldbl pi = atan(ldbl(1)) * 4.0;
201 
202   // Trignometric table
203   vector<lcx> expTable(n);
204   for (size_t i = 0; i < n; i++) {
205     unsigned long long temp = static_cast<unsigned long long>(i) * i;
206     temp %= static_cast<unsigned long long>(n) * 2;
207     ldbl angle = pi * temp / n;
208     // Less accurate alternative if long long is unavailable: double angle =
209     // M_PI * i * i / n;
210     expTable[i] = lcx(std::cos(-angle), std::sin(-angle));
211   }
212 
213   // Temporary vectors and preprocessing
214   vector<lcx> av(m);
215   for (size_t i = 0; i < n; i++)
216     av[i] = vec[i] * expTable[i];
217   vector<lcx> bv(m);
218   bv[0] = expTable[0];
219   for (size_t i = 1; i < n; i++)
220     bv[i] = bv[m - i] = std::conj(expTable[i]);
221 
222   // Convolution
223   vector<lcx> cv(m);
224   convolve(av, bv, cv);
225 
226   // Postprocessing
227   for (size_t i = 0; i < n; i++)
228     vec[i] = cv[i] * expTable[i];
229 }
230 
convolve(const vector<lcx> & xvec,const vector<lcx> & yvec,vector<lcx> & outvec)231 void Fft::convolve(const vector<lcx>& xvec,
232                    const vector<lcx>& yvec,
233                    vector<lcx>& outvec)
234 {
235 
236   size_t n = xvec.size();
237   if (n != yvec.size() || n != outvec.size())
238     throw std::domain_error("Mismatched lengths");
239   vector<lcx> xv = xvec;
240   vector<lcx> yv = yvec;
241   transform(xv);
242   transform(yv);
243   for (size_t i = 0; i < n; i++)
244     xv[i] *= yv[i];
245   inverseTransform(xv);
246   for (size_t i = 0; i < n;
247        i++) // Scaling (because this FFT implementation omits it)
248     outvec[i] = xv[i] / static_cast<ldbl>(n);
249 }
250 
// Returns the value formed by reversing the order of the n lowest bits of x.
// Bits of x above position n are ignored; n <= 0 yields 0.
static size_t reverseBits(size_t x, int n)
{
  size_t reversed = 0;
  int remaining = n;
  while (remaining > 0) {
    // Move the current lowest bit of x onto the low end of the result.
    reversed = (reversed << 1) | (x & 1U);
    x >>= 1;
    --remaining;
  }
  return reversed;
}
258 
259 //===========================================
260 
261 typedef complex<double> cmplx_t;
262 
TestIt(long n)263 static void TestIt(long n)
264 {
265   helib::PGFFT pgfft(n);
266 
267   for (long j = 0; j < 10; j++) {
268 
269     vector<cmplx_t> v(n);
270     for (int i = 0; i < n; i++) {
271       v[i] = RandomBnd(20) - 10;
272     }
273 
274     vector<cmplx_t> v0(v);
275 
276     pgfft.apply(v.data());
277 
278     vector<lcx> vv(n);
279     for (int i = 0; i < n; i++)
280       vv[i] = v0[i];
281 
282     Fft::transform(vv);
283 
284     ldbl vv_norm = 0;
285     for (int i = 0; i < n; i++) {
286       vv_norm += std::norm(vv[i]);
287     }
288 
289     vv_norm = sqrt(vv_norm);
290 
291     ldbl diff_norm = 0;
292     for (int i = 0; i < n; i++) {
293       lcx val = v[i];
294       lcx diff = val - vv[i];
295       diff_norm += std::norm(diff);
296     }
297     diff_norm = sqrt(diff_norm);
298 
299     if (vv_norm == 0) {
300       // vv has norm = 0. Cheching if the fft is correct looking only the
301       // enumerator.
302       EXPECT_DOUBLE_EQ(diff_norm, 0);
303     } else {
304       // Check if the fft relative error is smaller than the treshold.
305       ldbl rel_err = diff_norm / vv_norm;
306       EXPECT_LE(rel_err, 1e-9);
307     }
308   }
309 }
310 
// Exercises every transform length from 1 through 100 inclusive.
TEST(GTestPGFFT, PGFFTWorksInRange1to100Points)
{
  SetSeed();

  long n = 1;
  while (n <= 100) {
    TestIt(n);
    ++n;
  }
}
318 
// Exercises the power-of-2 lengths 2^8 = 256 through 2^15 = 32768.
TEST(GTestPGFFT, PGFFTWorksInRange256to32768PowerOfTwoPoints)
{
  SetSeed();

  for (int log2n = 8; log2n <= 15; ++log2n) {
    TestIt(1L << log2n);
  }
}
326 
// Exercises 100 randomly chosen lengths in [10000, 20000).
TEST(GTestPGFFT, PGFFTWorksInRange10000to20000Points)
{
  SetSeed();

  constexpr long kTrials = 100;
  for (long trial = 0; trial < kTrials; ++trial) {
    const long n = RandomBnd(10000) + 10000;
    TestIt(n);
  }
}
336 } // namespace
337