1 #include <helib/PGFFT.h>
2
3 #include <iostream>
4 #include <cstdlib>
5 #include <ctime>
6
7 // these are just for the Fft stuff
8 #include <algorithm>
9 #include <cmath>
10 #include <cstddef>
11 #include <cstdint>
12 #include <stdexcept>
13 #include <limits>
14 #include <gtest/gtest.h>
15
16 namespace {
17
// RandomBnd(n): pseudo-random value in the half-open range [0, n).
// Precondition: n > 0 (the final reduction divides by n).
// FIXME: still built on top of the low-quality std::rand() generator.
static long RandomBnd(long n)
{
  constexpr int kWordBits = std::numeric_limits<unsigned long>::digits;
  constexpr int kRotateBy = 7;
  constexpr int kRounds = 12;

  // Mix several rand() outputs into one word: rotate the accumulator left
  // and xor in a fresh sample on every round.
  unsigned long acc = 0;
  int roundsLeft = kRounds;
  while (roundsLeft-- > 0) {
    const unsigned long sample = static_cast<unsigned long>(std::rand());
    const unsigned long rotated =
        (acc << kRotateBy) | (acc >> (kWordBits - kRotateBy));
    acc = rotated ^ sample;
  }

  return static_cast<long>(acc % static_cast<unsigned long>(n));
}
35
// Seeds std::rand() from the current wall-clock time.
static void SetSeed()
{
  const std::time_t now = std::time(nullptr);
  std::srand(static_cast<unsigned int>(now));
}
37
38 //================== Fft ====================
39
40 // I've modified this a bit from code I got here:
41 // https://www.nayuki.io/page/free-small-fft-in-multiple-languages
42
43 // Specifically, I modified it to use long doubles instead of doubles
44
45 // Here is the original copyright notice:
46
47 /*
48 * Free FFT and convolution (C++)
49 *
50 * Copyright (c) 2017 Project Nayuki. (MIT License)
51 * https://www.nayuki.io/page/free-small-fft-in-multiple-languages
52 *
53 * Permission is hereby granted, free of charge, to any person obtaining a copy
54 * of this software and associated documentation files (the "Software"), to deal
55 * in the Software without restriction, including without limitation the rights
56 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
57 * copies of the Software, and to permit persons to whom the Software is
58 * furnished to do so, subject to the following conditions:
59 * - The above copyright notice and this permission notice shall be included in
60 * all copies or substantial portions of the Software.
61 * - The Software is provided "as is", without warranty of any kind, express or
62 * implied, including but not limited to the warranties of merchantability,
63 * fitness for a particular purpose and noninfringement. In no event shall the
64 * authors or copyright holders be liable for any claim, damages or other
65 * liability, whether in an action of contract, tort or otherwise, arising
66 * from, out of or in connection with the Software or the use or other dealings
67 * in the Software.
68 */
69
// Extended-precision scalar and complex types used by the reference FFT.
using ldbl = long double;
using lcx = std::complex<ldbl>;
72
73 namespace Fft {
74
75 /*
76 * Computes the discrete Fourier transform (DFT) of the given complex vector,
77 * storing the result back into the vector. The vector can have any length. This
78 * is a wrapper function.
79 */
80 void transform(std::vector<lcx>& vec);
81
82 /*
83 * Computes the inverse discrete Fourier transform (IDFT) of the given complex
84 * vector, storing the result back into the vector. The vector can have any
85 * length. This is a wrapper function. This transform does not perform scaling,
86 * so the inverse is not a true inverse.
87 */
88 void inverseTransform(std::vector<lcx>& vec);
89
90 /*
91 * Computes the discrete Fourier transform (DFT) of the given complex vector,
92 * storing the result back into the vector. The vector's length must be a power
93 * of 2. Uses the Cooley-Tukey decimation-in-time radix-2 algorithm.
94 */
95 void transformRadix2(std::vector<lcx>& vec);
96
97 /*
98 * Computes the discrete Fourier transform (DFT) of the given complex vector,
99 * storing the result back into the vector. The vector can have any length. This
100 * requires the convolution function, which in turn requires the radix-2 FFT
101 * function. Uses Bluestein's chirp z-transform algorithm.
102 */
103 void transformBluestein(std::vector<lcx>& vec);
104
105 /*
106 * Computes the circular convolution of the given complex vectors. Each vector's
107 * length must be the same.
108 */
109 void convolve(const std::vector<lcx>& vecx,
110 const std::vector<lcx>& vecy,
111 std::vector<lcx>& vecout);
112
113 } // namespace Fft
114
115 using std::complex;
116 using std::size_t;
117 using std::vector;
118
119 // Private function prototypes
120 static size_t reverseBits(size_t x, int n);
121
transform(vector<lcx> & vec)122 void Fft::transform(vector<lcx>& vec)
123 {
124 size_t n = vec.size();
125 if (n == 0)
126 return;
127 else if ((n & (n - 1)) == 0) // Is power of 2
128 transformRadix2(vec);
129 else // More complicated algorithm for arbitrary sizes
130 transformBluestein(vec);
131 }
132
inverseTransform(vector<lcx> & vec)133 void Fft::inverseTransform(vector<lcx>& vec)
134 {
135 std::transform(vec.cbegin(),
136 vec.cend(),
137 vec.begin(),
138 static_cast<lcx (*)(const lcx&)>(std::conj));
139 transform(vec);
140 std::transform(vec.cbegin(),
141 vec.cend(),
142 vec.begin(),
143 static_cast<lcx (*)(const lcx&)>(std::conj));
144 }
145
transformRadix2(vector<lcx> & vec)146 void Fft::transformRadix2(vector<lcx>& vec)
147 {
148 // Length variables
149 size_t n = vec.size();
150 int levels = 0; // Compute levels = floor(log2(n))
151 for (size_t temp = n; temp > 1U; temp >>= 1)
152 levels++;
153 if (static_cast<size_t>(1U) << levels != n)
154 throw std::domain_error("Length is not a power of 2");
155
156 const ldbl pi = atan(ldbl(1)) * 4.0;
157
158 // Trignometric table
159 vector<lcx> expTable(n / 2);
160 for (size_t i = 0; i < n / 2; i++) {
161 // expTable[i] = std::exp(lcx(0, -2 * M_PI * i / n));
162 ldbl angle = -2 * pi * i / n;
163 expTable[i] = lcx(std::cos(angle), std::sin(angle));
164 }
165
166 // Bit-reversed addressing permutation
167 for (size_t i = 0; i < n; i++) {
168 size_t j = reverseBits(i, levels);
169 if (j > i)
170 std::swap(vec[i], vec[j]);
171 }
172
173 // Cooley-Tukey decimation-in-time radix-2 FFT
174 for (size_t size = 2; size <= n; size *= 2) {
175 size_t halfsize = size / 2;
176 size_t tablestep = n / size;
177 for (size_t i = 0; i < n; i += size) {
178 for (size_t j = i, k = 0; j < i + halfsize; j++, k += tablestep) {
179 lcx temp = vec[j + halfsize] * expTable[k];
180 vec[j + halfsize] = vec[j] - temp;
181 vec[j] += temp;
182 }
183 }
184 if (size == n) // Prevent overflow in 'size *= 2'
185 break;
186 }
187 }
188
transformBluestein(vector<lcx> & vec)189 void Fft::transformBluestein(vector<lcx>& vec)
190 {
191 // Find a power-of-2 convolution length m such that m >= n * 2 + 1
192 size_t n = vec.size();
193 size_t m = 1;
194 while (m / 2 <= n) {
195 if (m > SIZE_MAX / 2)
196 throw std::length_error("Vector too large");
197 m *= 2;
198 }
199
200 const ldbl pi = atan(ldbl(1)) * 4.0;
201
202 // Trignometric table
203 vector<lcx> expTable(n);
204 for (size_t i = 0; i < n; i++) {
205 unsigned long long temp = static_cast<unsigned long long>(i) * i;
206 temp %= static_cast<unsigned long long>(n) * 2;
207 ldbl angle = pi * temp / n;
208 // Less accurate alternative if long long is unavailable: double angle =
209 // M_PI * i * i / n;
210 expTable[i] = lcx(std::cos(-angle), std::sin(-angle));
211 }
212
213 // Temporary vectors and preprocessing
214 vector<lcx> av(m);
215 for (size_t i = 0; i < n; i++)
216 av[i] = vec[i] * expTable[i];
217 vector<lcx> bv(m);
218 bv[0] = expTable[0];
219 for (size_t i = 1; i < n; i++)
220 bv[i] = bv[m - i] = std::conj(expTable[i]);
221
222 // Convolution
223 vector<lcx> cv(m);
224 convolve(av, bv, cv);
225
226 // Postprocessing
227 for (size_t i = 0; i < n; i++)
228 vec[i] = cv[i] * expTable[i];
229 }
230
convolve(const vector<lcx> & xvec,const vector<lcx> & yvec,vector<lcx> & outvec)231 void Fft::convolve(const vector<lcx>& xvec,
232 const vector<lcx>& yvec,
233 vector<lcx>& outvec)
234 {
235
236 size_t n = xvec.size();
237 if (n != yvec.size() || n != outvec.size())
238 throw std::domain_error("Mismatched lengths");
239 vector<lcx> xv = xvec;
240 vector<lcx> yv = yvec;
241 transform(xv);
242 transform(yv);
243 for (size_t i = 0; i < n; i++)
244 xv[i] *= yv[i];
245 inverseTransform(xv);
246 for (size_t i = 0; i < n;
247 i++) // Scaling (because this FFT implementation omits it)
248 outvec[i] = xv[i] / static_cast<ldbl>(n);
249 }
250
// Returns the lowest n bits of x in reversed order; higher bits of x are
// ignored. Yields 0 when n <= 0.
static size_t reverseBits(size_t x, int n)
{
  size_t reversed = 0;
  for (int bitsLeft = n; bitsLeft > 0; --bitsLeft) {
    // Move the current lowest bit of x onto the low end of the result.
    reversed <<= 1;
    reversed |= (x & 1U);
    x >>= 1;
  }
  return reversed;
}
258
259 //===========================================
260
// Double-precision complex element type of the vectors handed to PGFFT in
// TestIt.
using cmplx_t = std::complex<double>;
262
TestIt(long n)263 static void TestIt(long n)
264 {
265 helib::PGFFT pgfft(n);
266
267 for (long j = 0; j < 10; j++) {
268
269 vector<cmplx_t> v(n);
270 for (int i = 0; i < n; i++) {
271 v[i] = RandomBnd(20) - 10;
272 }
273
274 vector<cmplx_t> v0(v);
275
276 pgfft.apply(v.data());
277
278 vector<lcx> vv(n);
279 for (int i = 0; i < n; i++)
280 vv[i] = v0[i];
281
282 Fft::transform(vv);
283
284 ldbl vv_norm = 0;
285 for (int i = 0; i < n; i++) {
286 vv_norm += std::norm(vv[i]);
287 }
288
289 vv_norm = sqrt(vv_norm);
290
291 ldbl diff_norm = 0;
292 for (int i = 0; i < n; i++) {
293 lcx val = v[i];
294 lcx diff = val - vv[i];
295 diff_norm += std::norm(diff);
296 }
297 diff_norm = sqrt(diff_norm);
298
299 if (vv_norm == 0) {
300 // vv has norm = 0. Cheching if the fft is correct looking only the
301 // enumerator.
302 EXPECT_DOUBLE_EQ(diff_norm, 0);
303 } else {
304 // Check if the fft relative error is smaller than the treshold.
305 ldbl rel_err = diff_norm / vv_norm;
306 EXPECT_LE(rel_err, 1e-9);
307 }
308 }
309 }
310
// Exercises every transform length from 1 up to and including 100.
TEST(GTestPGFFT, PGFFTWorksInRange1to100Points)
{
  SetSeed();

  long n = 1;
  while (n <= 100) {
    TestIt(n);
    ++n;
  }
}
318
// Exercises the power-of-two lengths 2^8 = 256 through 2^15 = 32768.
TEST(GTestPGFFT, PGFFTWorksInRange256to32768PowerOfTwoPoints)
{
  SetSeed();

  for (int log2n = 8; log2n <= 15; ++log2n) {
    const long n = 1L << log2n;
    TestIt(n);
  }
}
326
// Exercises 100 randomly chosen lengths in [10000, 20000).
TEST(GTestPGFFT, PGFFTWorksInRange10000to20000Points)
{
  SetSeed();

  int remaining = 100;
  while (remaining-- > 0) {
    const long n = 10000 + RandomBnd(10000);
    TestIt(n);
  }
}
336 } // namespace
337