1 // Fast integer multiplication using FFT in a modular ring.
2 // Bruno Haible 5.5.1996, 30.6.1996, 20.8.1996
3
4 // FFT in the complex domain has the drawback that it needs careful round-off
5 // error analysis. So here we choose another field of characteristic 0: Q_p.
6 // Since Q_p contains exactly the (p-1)th roots of unity, we choose
7 // p == 1 mod N and have the Nth roots of unity (N = 2^n) in Q_p and
8 // even in Z_p. Actually, we compute in Z/(p^m Z).
9
10 // All the operations the FFT algorithm needs are addition, subtraction,
11 // multiplication, multiplication by the Nth root of unity and division
12 // by N. Hence we can use the domain Z/(p^m Z) even if p is not a prime!
13
14 // We want to compute the convolution of N 32-bit words. The resulting
15 // words are < (2^32)^2 * N. To avoid computing with numbers greater than
16 // 32 bits, we compute in Z/pZ for three different primes p in parallel,
17 // i.e. we compute in the ring (Z / p1 Z) x (Z / p2 Z) x (Z / p3 Z). We choose
18 // p1 = 3*2^30+1, p2 = 15*2^27+1, p3 = 7*2^26+1.
19 // Because of p1*p2*p3 >= 2^91 >= (2^32)^2 * N, the chinese remainder theorem
20 // will faithfully combine 3 32-bit words to a word < (2^32)^2 * N.
21
22 // Furthermore we use Montgomery's modular multiplication trick
23 // [Peter L. Montgomery: Modular multiplication without trial division,
24 // Mathematics of Computation 44 (1985), 519-521.]
25 //
26 // Assume we want to compute modulo M, M odd. V and N will be chosen
27 // so that V*N==1 mod M and that (a,b) --> a*b*V mod M can be more easily
28 // computed than (a,b) --> a*b mod M. Then, we have a ring isomorphism
29 // (Z/MZ, +, * mod M) \isomorph (Z/MZ, +, (a,b) --> a*b*V mod M)
30 // x mod M --------> x*N mod M
31 // It is thus preferable to use x*N mod M as a "representation" of x mod M,
32 // especially for computations which involve at least several multiplications.
33 //
34 // The precise algorithm to compute a*b*V mod M, given a and b, and the choice
35 // of N and V depend on M and on the hardware. The general idea is this:
36 // Choose N = 2^n, so that division by N is easy. Recall that V == N^-1 mod M.
37 // 1. Given a and b as m-bit numbers (M <= 2^m), compute a*b in full
38 // precision.
39 // 2. Write a*b = c*N+d, i.e. split it into components c and d.
40 // 3. Now a*b*V = c*N*V+d*V == c+d*V mod M.
41 // 4. Instead of computing d*V mod M
42 // a. by full multiplication and then division mod M, or
43 // b. by left shifts: repeated application of
44 // x := 2*x+(0 or 1); if (x >= M) { x := x-M; }
45 // we compute
46 // c. by right shifts (recall that d*V == d*2^-n mod M): repeated application
47 // of if (x odd) { x := (x+M)/2; } else { x := x/2; }
48 // Usually one will choose N = 2^m, so that c and d have both m bits.
49 // Several variations are possible: In step 4 one can implement the right
50 // shifts in hardware. Or (for example when N = 2^160 and working on a
51 // 32-bit machine) one can do 32 shift steps at the same time:
52 // Choose M' == M^-1 mod 2^32 and compute n/32 times
53 // x := (x - ((x mod 2^32) * M' mod 2^32) * M) / 2^32.
54 //
55 // Here, we deal with moduli M = p_i = j*2^k+1. This form of primes comes
56 // in because we need 2^n-th roots of unity mod M. But it also comes in handy
57 // for Montgomery multiplication: Instead of choosing N = 2^32 (which makes
58 // for very easy splitting in step 2) and V = j^2*2^(2*k-32), we better
59 // choose N = 2^k and V = -j. The algorithm now goes like this (recall that
60 // M is an m-bit number and j is an (m-k)-bit number):
61 // 1. Compute a*b in full precision, as a 2*m <= 64 bit number.
62 // 2. Split a*b = c*N+d, with c an (2m-k)-bit number and d an k-bit number.
63 // 3. a*b*V == c+d*V mod M.
64 // 4. Compute c mod M by splitting off the leading (m-k+1) bits of c and
65 // using table lookup; the remainder (c mod 2^(m-1)) is already reduced
66 // mod M.
67 // Compute d*|V| the standard way; |V| has only few bits. d*|V| is
68 // already reduced mod M, because d*|V| < j*2^k < M.
69
70 // In order to get best performance, we carefully choose the primes so that
71 // a. the table of size 2^(m-k+1) doesn't get too large,
72 // b. multiplication by V is easy.
73 // Here is a list of the interesting primes < 2^32:
74 //
75 // U*M+V*N = 1
76 // prime bits N=2^n V U 2m<=n+32 ?
77 // M m n
78 //
79 // 3*2^30+1 32 30 -3 1 n (*)
80 //
81 // 13*2^28+1 32 28 -13 1 n
82 //
83 // 15*2^27+1 31 27 -15 1 n (*)
84 // 17*2^27+1 32 27 -17 1 n
85 // 29*2^27+1 32 27 -29 1 n
86 //
87 // 7*2^26+1 29 26 -7 1 y (*)
88 // 27*2^26+1 31 26 -27 1 n
89 // 37*2^26+1 32 26 -37 1 n
90 // 43*2^26+1 32 26 -43 1 n
91 //
92 // 5*2^25+1 28 25 -5 1 y
93 // 33*2^25+1 31 25 -33 1 n
94 // 51*2^25+1 31 25 -51 1 n
95 // 63*2^25+1 31 25 -63 1 n
96 // 81*2^25+1 32 25 -81 1 n
97 // 125*2^25+1 32 25 -125 1 n
98 //
99 // 45*2^24+1 30 24 -45 1 n
100 // 73*2^24+1 31 24 -73 1 n
101 // 127*2^24+1 31 24 -127 1 n
102 // 151*2^24+1 32 24 -151 1 n
103 // 157*2^24+1 32 24 -157 1 n
104 // 171*2^24+1 32 24 -171 1 n
105 // 193*2^24+1 32 24 -193 1 n
106 // 235*2^24+1 32 24 -235 1 n
107 // 243*2^24+1 32 24 -243 1 n
108 //
109 // 45*2^23+1 29 23 -45 1 n
110 // ...
111 //
112 // The inequality 2m<=n+32 would mean that c fits in a 32-bit word, but that's
113 // actually irrelevant because we can fetch the most significant bits of c
114 // before actually computing c.
115 // We choose the primes marked with an asterisk.
116
117
118 #if !(intDsize==32)
119 #error "fft mod p implemented only for intDsize==32"
120 #endif
121
122 // Avoid clash with fftp3
123 #define p1 fftp3m_p1
124 #define p2 fftp3m_p2
125 #define p3 fftp3m_p3
126 #define n1 fftp3m_n1
127 #define n2 fftp3m_n2
128 #define n3 fftp3m_n3
129
130 static const uint32 p1 = 1+(3<<30); // = 3221225473
131 static const uint32 p2 = 1+(15<<27); // = 2013265921
132 static const uint32 p3 = 1+(7<<26); // = 469762049
133 static const uint32 n1 = 30; // Montgomery: represent x mod p1 as x*2^n1 mod p1
134 static const uint32 n2 = 27; // Montgomery: represent x mod p2 as x*2^n2 mod p2
135 static const uint32 n3 = 26; // Montgomery: represent x mod p3 as x*2^n3 mod p3
136
137 typedef struct {
138 uint32 w1; // remainder mod p1
139 uint32 w2; // remainder mod p2
140 uint32 w3; // remainder mod p3
141 } fftp3m_word;
142
// Table of roots of unity, indexed by n = 0..26 (26+1 entries). Each entry
// holds one residue per prime: { mod p1, mod p2, mod p3 }.
// The "#if 0" branch lists the values in standard representation and is kept
// for reference only. The compiled "#else" branch holds the same roots in
// Montgomery representation, i.e. multiplied by 2^n1, 2^n2, 2^n3 respectively:
// entry 0 is { 2^30, 2^27, 2^26 } (the representation of 1) and entry 1 is
// { p1-2^30, p2-2^27, p3-2^26 } (the representation of -1).
143 static const fftp3m_word fftp3m_roots_of_1 [26+1] =
144 // roots_of_1[n] is a (2^n)th root of unity in our ring.
145 // (Also roots_of_1[n-1] = roots_of_1[n]^2, but we don't need this.)
146 {
147 #if 0 // in standard representation
148 { 1, 1, 1 },
149 { 3221225472, 2013265920, 469762048 },
150 { 1013946479, 284861408, 19610091 },
151 { 1031213943, 211723194, 26623616 },
152 { 694614138, 78945800, 111570435 },
153 { 347220834, 772607190, 135956445 },
154 { 680684264, 288289890, 181505383 },
155 { 1109768284, 112574482, 145518049 },
156 { 602134989, 928726468, 109721424 },
157 { 1080308101, 875419223, 2847903 },
158 { 381653707, 510575142, 110273149 },
159 { 902453688, 193023072, 65701394 },
160 { 1559299664, 313561437, 181642641 },
161 { 254499731, 121307056, 82315502 },
162 { 1376063215, 20899142, 142137197 },
163 { 1284040478, 956809618, 207661045 },
164 { 336664489, 317295870, 194405005 },
165 { 894491787, 785393806, 2821902 },
166 { 795860341, 738526384, 230963948 },
167 { 23880336, 956561758, 59211404 },
168 { 790585193, 352904935, 95374542 },
169 { 877386874, 836313293, 153165757 },
170 { 1510644826, 971592443, 74027009 },
171 { 353060343, 692611595, 24417505 },
172 { 716717815, 791167605, 26032760 },
173 { 1020271667, 751686895, 150976424 },
174 { 139914905, 477826617, 71902965 }
175 #else // in Montgomery representation
176 { 1073741824, 134217728, 67108864 },
177 { 2147483649, 1879048193, 402653185 },
178 { 1809501489, 1054751064, 265634015 },
179 { 2877487492, 1193844673, 331740947 },
180 { 2989687427, 665825587, 252496823 },
181 { 3105485195, 1961758775, 114795379 },
182 { 1920588894, 1994046595, 175397252 },
183 { 703819063, 932019131, 314756028 },
184 { 3020513810, 1682915367, 51434375 },
185 { 713639124, 1015380543, 133810885 },
186 { 946523922, 1576574394, 454008742 },
187 { 2920407577, 1597744532, 191940679 },
188 { 1627717094, 1589708641, 309595372 },
189 { 2062650405, 126130591, 189567235 },
190 { 615054086, 267042180, 382347871 },
191 { 1719470156, 1681043157, 238769593 },
192 { 961520328, 1992112863, 240663313 },
193 { 2923061544, 81858141, 402250056 },
194 { 808455044, 487635820, 302549471 },
195 { 3213265361, 1681059681, 461303277 },
196 { 1883955251, 1318650285, 254810522 },
197 { 781279533, 1017987605, 179445770 },
198 { 570193549, 1008968995, 459186762 },
199 { 3103538692, 624914534, 466273834 },
200 { 834835886, 1960521414, 331825355 },
201 { 1807393093, 1292064821, 246867396 },
202 { 2100845347, 1578757629, 56837012 }
203 #endif
204 };
205
206 // Define this for (cheap) consistency checks.
// (Enables the root-of-unity check in fftp3m_convolution and the checks on
// the convolution result in mulu_fft_modp3m; it also makes those functions
// allocate the full table of N roots and a separate result vector.)
207 //#define DEBUG_FFTP3M
208
209 // Define this for extensive consistency checks.
// (Enables check_fftp3m_word, the remainder check in combinep3m and the
// mulp3m_doublecheck wrapper.)
210 //#define DEBUG_FFTP3M_OPERATIONS
211
212 // Define the algorithm of the backward FFT:
213 // Either FORWARD (a normal FFT followed by a permutation)
214 // or RECIPROOT (an FFT with reciprocal root of unity)
215 // or CLEVER (an FFT with reciprocal root of unity but clever computation
216 // of the reciprocals).
217 // Drawback of FORWARD: the permutation pass.
218 // Drawback of RECIPROOT: need all the powers of the root, not only half of them.
// The numeric values below are arbitrary distinct tags; they are only ever
// compared with FFTP3M_BACKWARD in #if conditions.
219 #define FORWARD 42
220 #define RECIPROOT 43
221 #define CLEVER 44
222 #define FFTP3M_BACKWARD CLEVER
223
// check_fftp3m_word(x): with DEBUG_FFTP3M_OPERATIONS defined, throws
// runtime_exception() unless every component of x is fully reduced
// (x.w1 < p1, x.w2 < p2, x.w3 < p3). Otherwise it expands to nothing.
// The checking variant is wrapped in do { } while (0) so that it behaves as a
// single statement (no dangling-else surprises), and its argument is
// parenthesized so that any expression can be passed.
#ifdef DEBUG_FFTP3M_OPERATIONS
#define check_fftp3m_word(x)  \
  do {									\
    if (((x).w1 >= p1) || ((x).w2 >= p2) || ((x).w3 >= p3))		\
      throw runtime_exception();					\
  } while (0)
#else
#define check_fftp3m_word(x)
#endif
229
230 // r := 0 mod p
zerop3m(fftp3m_word & r)231 static inline void zerop3m (fftp3m_word& r)
232 {
233 r.w1 = 0;
234 r.w2 = 0;
235 r.w3 = 0;
236 }
237
238 // r := x mod p
setp3m(uint32 x,fftp3m_word & r)239 static inline void setp3m (uint32 x, fftp3m_word& r)
240 {
241 var uint32 hi;
242 var uint32 lo;
243 hi = x >> (32-n1); lo = x << n1; divu_6432_3232(hi,lo,p1, ,r.w1=);
244 hi = x >> (32-n2); lo = x << n2; divu_6432_3232(hi,lo,p2, ,r.w2=);
245 hi = x >> (32-n3); lo = x << n3; divu_6432_3232(hi,lo,p3, ,r.w3=);
246 }
247
248 // Chinese remainder theorem:
249 // (Z / p1 Z) x (Z / p2 Z) x (Z / p3 Z) == Z / p1*p2*p3 Z = Z / P Z.
250 // Return r as an integer >= 0, < p1*p2*p3, as 3-digit-sequence res.
251 // This routine also does the "de-Montgomerizing".
// Parameters:
//   r          the three residues, in Montgomery representation.
//   resLSDptr  destination of the 3-digit result; it is addressed as
//              "resLSDptr lspop 3" in the final shift below.
// With DEBUG_FFTP3M_OPERATIONS defined, throws runtime_exception() if r is
// not reduced or if the computed remainder is not smaller than the modulus.
combinep3m(const fftp3m_word & r,uintD * resLSDptr)252 static void combinep3m (const fftp3m_word& r, uintD* resLSDptr)
253 {
254 check_fftp3m_word(r);
255 // Compute e1 * v1 * r.w1 + e2 * v2 * r.w2 + e3 * v3 * r.w3 where
256 // vi == 2^-ni mod pi, and the idempotents ei are found as:
257 // xgcd(pi,p/pi) = 1 = ui*pi + vi*P/pi, ei = 1 - ui*pi.
258 // e1 = 1709008312966733882383995583
259 // e2 = 2781580629833601225216537109
260 // e3 = 1602397205945693664242711343
261 // e1*v1 = 965961209845827124691257285
262 // e2*v2 = 927193593718183024654651603
263 // e3*v3 = 969191855872201893987508667
264 // We will have 0 <= e1*v1 * r.w1 + e2*v2 * r.w2 + e3*v3 * r.w3 <
265 // < e1*v1 * p1 + e2*v2 * p2 + e3*v3 * p3 < 3 * 2^32 * p1*p2*p3 < 2^128.
266 // The sum of the products fits in 4 digits, we divide by p1*p2*p3
267 // as a 3-digit sequence, thus getting the remainder.
// The "#if 0" variant below holds the unshifted constants and is disabled;
// the compiled variant holds the same constants pre-shifted left by 4 bits.
268 #if 0
269 #if CL_DS_BIG_ENDIAN_P
270 var const uintD p123 [3] = { 0x09D80000, 0x7C200001, 0x54000001 };
271 var const uintD e1v1 [3] = { 0x031F063E, 0x1CD1F37E, 0x20E0C7C5 };
272 var const uintD e2v2 [3] = { 0x02FEF4E1, 0x6E62C875, 0x788590D3 };
273 var const uintD e3v3 [3] = { 0x0321B25B, 0xC8DB371B, 0xF0E861BB };
274 #else
275 var const uintD p123 [3] = { 0x54000001, 0x7C200001, 0x09D80000 };
276 var const uintD e1v1 [3] = { 0x20E0C7C5, 0x1CD1F37E, 0x031F063E };
277 var const uintD e2v2 [3] = { 0x788590D3, 0x6E62C875, 0x02FEF4E1 };
278 var const uintD e3v3 [3] = { 0xF0E861BB, 0xC8DB371B, 0x0321B25B };
279 #endif
280 #else
281 // The final division step requires a shift left by 4 bits in order
282 // to normalize p1*p2*p3. We combine this shift left with the
283 // multiplications. Note that since e1v1 + e2v2 + e3v3 < p1*p2*p3,
284 // there is no risk of overflow.
285 #if CL_DS_BIG_ENDIAN_P
286 var const uintD p123 [3] = { 0x9D800007, 0xC2000015, 0x40000010 };
287 var const uintD e1v1 [3] = { 0x31F063E1, 0xCD1F37E2, 0x0E0C7C50 };
288 var const uintD e2v2 [3] = { 0x2FEF4E16, 0xE62C8757, 0x88590D30 };
289 var const uintD e3v3 [3] = { 0x321B25BC, 0x8DB371BF, 0x0E861BB0 };
290 #else
291 var const uintD p123 [3] = { 0x40000010, 0xC2000015, 0x9D800007 };
292 var const uintD e1v1 [3] = { 0x0E0C7C50, 0xCD1F37E2, 0x31F063E1 };
293 var const uintD e2v2 [3] = { 0x88590D30, 0xE62C8757, 0x2FEF4E16 };
294 var const uintD e3v3 [3] = { 0x0E861BB0, 0x8DB371BF, 0x321B25BC };
295 #endif
296 #endif
// sum := e1v1*r.w1 + e2v2*r.w2 + e3v3*r.w3 as a 4-digit number. The value
// returned by each multiply-add over the low 3 digits is added into the
// top digit.
// NOTE(review): mulu_loop_lsp presumably stores the 3-digit product plus its
// carry digit, i.e. all 4 digits of sum - confirm against its definition.
297 var uintD sum [4];
298 var uintD* const sumLSDptr = arrayLSDptr(sum,4);
299 mulu_loop_lsp(r.w1,arrayLSDptr(e1v1,3), sumLSDptr,3);
300 lspref(sumLSDptr,3) += muluadd_loop_lsp(r.w2,arrayLSDptr(e2v2,3), sumLSDptr,3);
301 lspref(sumLSDptr,3) += muluadd_loop_lsp(r.w3,arrayLSDptr(e3v3,3), sumLSDptr,3);
// Disabled reference implementation using the general division routine.
302 #if 0
303 {CL_ALLOCA_STACK;
304 var DS q;
305 var DS r;
306 UDS_divide(arrayMSDptr(sum,4),4,arrayLSDptr(sum,4),
307 arrayMSDptr(p123,3),3,arrayLSDptr(p123,3),
308 &q,&r
309 );
310 ASSERT(q.len <= 1)
311 ASSERT(r.len <= 3)
312 copy_loop_lsp(r.LSDptr,arrayLSDptr(sum,4),r.len);
313 DS_clear_loop(arrayMSDptr(sum,4) mspop 1,3-r.len,arrayLSDptr(sum,4) lspop r.len);
314 }
315 #else
316 // Division as in UDS_divide with a_len=4, b_len=3.
317 {
// q_stern ("q star") is the trial value for the single quotient digit;
// c1 receives the remainder of the leading two-digit-by-one-digit division.
318 var uintD q_stern;
319 var uintD c1;
320 #if HAVE_DD
321 divuD(highlowDD(lspref(sumLSDptr,3),lspref(sumLSDptr,2)),lspref(arrayLSDptr(p123,3),2), q_stern=,c1=);
// Refine the trial digit using the second digit of the divisor: it is
// decremented once or twice while it is still too large.
322 { var uintDD c2 = highlowDD(c1,lspref(sumLSDptr,1));
323 var uintDD c3 = muluD(lspref(arrayLSDptr(p123,3),1),q_stern);
324 if (c3 > c2)
325 { q_stern = q_stern-1;
326 if (c3-c2 > highlowDD(lspref(arrayLSDptr(p123,3),2),lspref(arrayLSDptr(p123,3),1)))
327 { q_stern = q_stern-1; }
328 } }
329 #else
// Same refinement, carried out on (high,low) single-digit pairs because no
// double-digit type is available.
330 divuD(lspref(sumLSDptr,3),lspref(sumLSDptr,2),lspref(arrayLSDptr(p123,3),2), q_stern=,c1=);
331 { var uintD c2lo = lspref(sumLSDptr,1);
332 var uintD c3hi;
333 var uintD c3lo;
334 muluD(lspref(arrayLSDptr(p123,3),1),q_stern, c3hi=,c3lo=);
335 if ((c3hi > c1) || ((c3hi == c1) && (c3lo > c2lo)))
336 { q_stern = q_stern-1;
337 c3hi -= c1; if (c3lo < c2lo) { c3hi--; }; c3lo -= c2lo;
338 if ((c3hi > lspref(arrayLSDptr(p123,3),2)) || ((c3hi == lspref(arrayLSDptr(p123,3),2)) && (c3lo > lspref(arrayLSDptr(p123,3),1))))
339 { q_stern = q_stern-1; }
340 } }
341 #endif
// Subtract q_stern * p123 from the low 3 digits of sum. If the returned
// borrow exceeds the top digit, the trial digit was still one too large,
// so p123 is added back once. Only the remainder is used afterwards.
342 if (!(q_stern==0))
343 { var uintD carry = mulusub_loop_lsp(q_stern,arrayLSDptr(p123,3),sumLSDptr,3);
344 if (carry > lspref(sumLSDptr,3))
345 { q_stern = q_stern-1;
346 addto_loop_lsp(arrayLSDptr(p123,3),sumLSDptr,3);
347 } }
348 }
349 #endif
350 #ifdef DEBUG_FFTP3M_OPERATIONS
// The remainder must be smaller than the (shifted) modulus.
351 if (compare_loop_msp(sumLSDptr lspop 3,arrayMSDptr(p123,3),3) >= 0)
352 throw runtime_exception();
353 #endif
354 // Renormalize the division's remainder: shift right by 4 bits.
355 shiftrightcopy_loop_msp(sumLSDptr lspop 3,resLSDptr lspop 3,3,4,0);
356 }
357
358 // r := (a + b) mod p
addp3m(const fftp3m_word & a,const fftp3m_word & b,fftp3m_word & r)359 static inline void addp3m (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& r)
360 {
361 var uint32 x;
362
363 check_fftp3m_word(a); check_fftp3m_word(b);
364 // Add single 32-bit words mod pi.
365 if (((x = (a.w1 + b.w1)) < b.w1) || (x >= p1))
366 x -= p1;
367 r.w1 = x;
368 if ((x = (a.w2 + b.w2)) >= p2) // x doesn't overflow since p2 <= 2^31
369 x -= p2;
370 r.w2 = x;
371 if ((x = (a.w3 + b.w3)) >= p3) // x doesn't overflow since p3 <= 2^31
372 x -= p3;
373 r.w3 = x;
374 check_fftp3m_word(r);
375 }
376
377 // r := (a - b) mod p
subp3m(const fftp3m_word & a,const fftp3m_word & b,fftp3m_word & r)378 static inline void subp3m (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& r)
379 {
380 check_fftp3m_word(a); check_fftp3m_word(b);
381 // Subtract single 32-bit words mod pi.
382 r.w1 = (a.w1 < b.w1 ? a.w1-b.w1+p1 : a.w1-b.w1);
383 r.w2 = (a.w2 < b.w2 ? a.w2-b.w2+p2 : a.w2-b.w2);
384 r.w3 = (a.w3 < b.w3 ? a.w3-b.w3+p3 : a.w3-b.w3);
385 check_fftp3m_word(r);
386 }
387
388 // r := (a * b) mod p
mulp3m(const fftp3m_word & a,const fftp3m_word & b,fftp3m_word & res)389 static void mulp3m (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& res)
390 {
391 check_fftp3m_word(a); check_fftp3m_word(b);
392 // Multiplication à la Montgomery:
393 #define mul_mod_p(aw,bw,result_zuweisung,p,m,n,j,js,table) \
394 { /* table[i] == i*2^(m-1) mod p for 0 <= i < 2^(m-n+1) */\
395 var uint32 hi; \
396 var uint32 lo; \
397 mulu32(aw,bw, hi=,lo=); \
398 /* hi has 2m-32 bits */ \
399 var const int l = (m-1)-(32-n); \
400 var uint32 r = table[hi>>l]; \
401 hi = ((hi << (32-l)) >> (n-l)) | (lo >> n); \
402 /* hi = c mod 2^(m-1), has m-1 bits */ \
403 lo = lo & (bit(n)-1); \
404 /* lo = d, has n bits */ \
405 lo = (lo << js) - lo; \
406 /* lo = d*|V|, has m bits */ \
407 /* Finally compute (r + hi - lo) mod p. */ \
408 if (m < 32) { \
409 r += hi; \
410 if (r >= p) \
411 { r = r - p; } \
412 } else { \
413 if (((r += hi) < hi) || (r >= p)) \
414 { r = r - p; } \
415 } \
416 r = (r < lo ? r-lo+p : r-lo); \
417 /* ifdef DEBUG_FFTP3M_OPERATIONS * \
418 var uint32 tmp; \
419 mulu32(aw,bw, hi=,lo=); \
420 divu_6432_3232(hi,lo,p, ,tmp=); \
421 mulu32(tmp,j, hi=, lo=); \
422 divu_6432_3232(hi,lo,p, ,tmp=); \
423 if (tmp != 0) { tmp = p-tmp; } \
424 if (tmp != r) \
425 throw runtime_exception(); \
426 * endif DEBUG_FFTP3M_OPERATIONS */ \
427 result_zuweisung r; \
428 }
429 // p1 = 3*2^30+1, n1 = 30, j1 = 3 = 2^2-1
430 static uint32 table1 [8] =
431 { 0, 2147483648, 1073741823, 3221225471,
432 2147483646, 1073741821, 3221225469, 2147483644
433 };
434 mul_mod_p(a.w1,b.w1,res.w1=,p1,32,30,3,2,table1);
435 // p2 = 15*2^27+1, n2 = 27, j2 = 15 = 2^4-1
436 static uint32 table2 [32] =
437 { 0, 1073741824, 134217727, 1207959551,
438 268435454, 1342177278, 402653181, 1476395005,
439 536870908, 1610612732, 671088635, 1744830459,
440 805306362, 1879048186, 939524089, 2013265913,
441 1073741816, 134217719, 1207959543, 268435446,
442 1342177270, 402653173, 1476394997, 536870900,
443 1610612724, 671088627, 1744830451, 805306354,
444 1879048178, 939524081, 2013265905, 1073741808
445 };
446 mul_mod_p(a.w2,b.w2,res.w2=,p2,31,27,15,4,table2);
447 // p3 = 7*2^26+1, n3 = 26, j3 = 7 = 2^3-1
448 static uint32 table3 [16] =
449 { 0, 268435456, 67108863, 335544319,
450 134217726, 402653182, 201326589, 469762045,
451 268435452, 67108859, 335544315, 134217722,
452 402653178, 201326585, 469762041, 268435448
453 };
454 mul_mod_p(a.w3,b.w3,res.w3=,p3,29,26,7,3,table3);
455 #undef mul_mod_p
456 check_fftp3m_word(res);
457 }
#ifdef DEBUG_FFTP3M_OPERATIONS
// Debugging wrapper around mulp3m: computes r := a*b and cross-checks it
// against (-a)*(-b), which must give the same value. Throws
// runtime_exception() if the two products differ.
// The second product lives in a local named "other"; the previous name "or"
// is an alternative token for "||" in C++ and cannot be used as an identifier.
static void mulp3m_doublecheck (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& r)
{
	fftp3m_word zero, ma, mb, other;
	zerop3m(zero);
	subp3m(zero,a, ma);	// ma := -a
	subp3m(zero,b, mb);	// mb := -b
	// The reference product is computed first, so that r may alias a or b.
	mulp3m(ma,mb, other);
	mulp3m(a,b, r);
	if (!((r.w1 == other.w1) && (r.w2 == other.w2) && (r.w3 == other.w3)))
		throw runtime_exception();
}
// From here on, every call to mulp3m goes through the checking wrapper.
#define mulp3m mulp3m_doublecheck
#endif /* DEBUG_FFTP3M_OPERATIONS */
472
473 // b := (a / 2) mod p
shiftp3m(const fftp3m_word & a,fftp3m_word & b)474 static inline void shiftp3m (const fftp3m_word& a, fftp3m_word& b)
475 {
476 check_fftp3m_word(a);
477 b.w1 = (a.w1 & 1 ? (a.w1 >> 1) + (p1 >> 1) + 1 : (a.w1 >> 1));
478 b.w2 = (a.w2 & 1 ? (a.w2 >> 1) + (p2 >> 1) + 1 : (a.w2 >> 1));
479 b.w3 = (a.w3 & 1 ? (a.w3 >> 1) + (p3 >> 1) + 1 : (a.w3 >> 1));
480 check_fftp3m_word(b);
481 }
482
483 #ifndef _BIT_REVERSE
484 #define _BIT_REVERSE
485 // Reverse an n-bit number x. n>0.
bit_reverse(uintL n,uintC x)486 static uintC bit_reverse (uintL n, uintC x)
487 {
488 var uintC y = 0;
489 do {
490 y <<= 1;
491 y |= (x & 1);
492 x >>= 1;
493 } while (!(--n == 0));
494 return y;
495 }
496 #endif
497
498 // Compute an convolution mod p using FFT: z[0..N-1] := x[0..N-1] * y[0..N-1].
fftp3m_convolution(const uintL n,const uintC N,fftp3m_word * x,fftp3m_word * y,fftp3m_word * z)499 static void fftp3m_convolution (const uintL n, const uintC N, // N = 2^n
500 fftp3m_word * x, // N words
501 fftp3m_word * y, // N words
502 fftp3m_word * z // N words result
503 )
504 {
505 CL_ALLOCA_STACK;
506 #if (FFTP3M_BACKWARD == RECIPROOT) || defined(DEBUG_FFTP3M)
507 var fftp3m_word* const w = cl_alloc_array(fftp3m_word,N);
508 #else
509 var fftp3m_word* const w = cl_alloc_array(fftp3m_word,(N>>1)+1);
510 #endif
511 var uintC i;
512 // Initialize w[i] to w^i, w a primitive N-th root of unity.
513 w[0] = fftp3m_roots_of_1[0];
514 w[1] = fftp3m_roots_of_1[n];
515 #if (FFTP3M_BACKWARD == RECIPROOT) || defined(DEBUG_FFTP3M)
516 for (i = 2; i < N; i++)
517 mulp3m(w[i-1],fftp3m_roots_of_1[n], w[i]);
518 #else // need only half of the roots
519 for (i = 2; i < N>>1; i++)
520 mulp3m(w[i-1],fftp3m_roots_of_1[n], w[i]);
521 #endif
522 #ifdef DEBUG_FFTP3M
523 // Check that w is really a primitive N-th root of unity.
524 {
525 var fftp3m_word w_N;
526 mulp3m(w[N-1],fftp3m_roots_of_1[n], w_N);
527 if (!( w_N.w1 == (uint32)1<<n1
528 && w_N.w2 == (uint32)1<<n2
529 && w_N.w3 == (uint32)1<<n3))
530 throw runtime_exception();
531 w_N = w[N>>1];
532 if (!( w_N.w1 == p1-((uint32)1<<n1)
533 && w_N.w2 == p2-((uint32)1<<n2)
534 && w_N.w3 == p3-((uint32)1<<n3)))
535 throw runtime_exception();
536 }
537 #endif
538 var bool squaring = (x == y);
539 // Do an FFT of length N on x.
540 {
541 var sintL l;
542 /* l = n-1 */ {
543 var const uintC tmax = N>>1; // tmax = 2^(n-1)
544 for (var uintC t = 0; t < tmax; t++) {
545 var uintC i1 = t;
546 var uintC i2 = i1 + tmax;
547 // Butterfly: replace (x(i1),x(i2)) by
548 // (x(i1) + x(i2), x(i1) - x(i2)).
549 var fftp3m_word tmp;
550 tmp = x[i2];
551 subp3m(x[i1],tmp, x[i2]);
552 addp3m(x[i1],tmp, x[i1]);
553 }
554 }
555 for (l = n-2; l>=0; l--) {
556 var const uintC smax = (uintC)1 << (n-1-l);
557 var const uintC tmax = (uintC)1 << l;
558 for (var uintC s = 0; s < smax; s++) {
559 var uintC exp = bit_reverse(n-1-l,s) << l;
560 for (var uintC t = 0; t < tmax; t++) {
561 var uintC i1 = (s << (l+1)) + t;
562 var uintC i2 = i1 + tmax;
563 // Butterfly: replace (x(i1),x(i2)) by
564 // (x(i1) + w^exp*x(i2), x(i1) - w^exp*x(i2)).
565 var fftp3m_word tmp;
566 mulp3m(x[i2],w[exp], tmp);
567 subp3m(x[i1],tmp, x[i2]);
568 addp3m(x[i1],tmp, x[i1]);
569 }
570 }
571 }
572 }
573 // Do an FFT of length N on y.
574 if (!squaring) {
575 var sintL l;
576 /* l = n-1 */ {
577 var uintC const tmax = N>>1; // tmax = 2^(n-1)
578 for (var uintC t = 0; t < tmax; t++) {
579 var uintC i1 = t;
580 var uintC i2 = i1 + tmax;
581 // Butterfly: replace (y(i1),y(i2)) by
582 // (y(i1) + y(i2), y(i1) - y(i2)).
583 var fftp3m_word tmp;
584 tmp = y[i2];
585 subp3m(y[i1],tmp, y[i2]);
586 addp3m(y[i1],tmp, y[i1]);
587 }
588 }
589 for (l = n-2; l>=0; l--) {
590 var const uintC smax = (uintC)1 << (n-1-l);
591 var const uintC tmax = (uintC)1 << l;
592 for (var uintC s = 0; s < smax; s++) {
593 var uintC exp = bit_reverse(n-1-l,s) << l;
594 for (var uintC t = 0; t < tmax; t++) {
595 var uintC i1 = (s << (l+1)) + t;
596 var uintC i2 = i1 + tmax;
597 // Butterfly: replace (y(i1),y(i2)) by
598 // (y(i1) + w^exp*y(i2), y(i1) - w^exp*y(i2)).
599 var fftp3m_word tmp;
600 mulp3m(y[i2],w[exp], tmp);
601 subp3m(y[i1],tmp, y[i2]);
602 addp3m(y[i1],tmp, y[i1]);
603 }
604 }
605 }
606 }
607 // Multiply the transformed vectors into z.
608 for (i = 0; i < N; i++)
609 mulp3m(x[i],y[i], z[i]);
610 // Undo an FFT of length N on z.
611 {
612 var uintL l;
613 for (l = 0; l < n-1; l++) {
614 var const uintC smax = (uintC)1 << (n-1-l);
615 var const uintC tmax = (uintC)1 << l;
616 #if FFTP3M_BACKWARD != CLEVER
617 for (var uintC s = 0; s < smax; s++) {
618 var uintC exp = bit_reverse(n-1-l,s) << l;
619 #if FFTP3M_BACKWARD == RECIPROOT
620 if (exp > 0)
621 exp = N - exp; // negate exp (use w^-1 instead of w)
622 #endif
623 for (var uintC t = 0; t < tmax; t++) {
624 var uintC i1 = (s << (l+1)) + t;
625 var uintC i2 = i1 + tmax;
626 // Inverse Butterfly: replace (z(i1),z(i2)) by
627 // ((z(i1)+z(i2))/2, (z(i1)-z(i2))/(2*w^exp)).
628 var fftp3m_word sum;
629 var fftp3m_word diff;
630 addp3m(z[i1],z[i2], sum);
631 subp3m(z[i1],z[i2], diff);
632 shiftp3m(sum, z[i1]);
633 mulp3m(diff,w[exp], diff); shiftp3m(diff, z[i2]);
634 }
635 }
636 #else // FFTP3M_BACKWARD == CLEVER: clever handling of negative exponents
637 /* s = 0, exp = 0 */ {
638 for (var uintC t = 0; t < tmax; t++) {
639 var uintC i1 = t;
640 var uintC i2 = i1 + tmax;
641 // Inverse Butterfly: replace (z(i1),z(i2)) by
642 // ((z(i1)+z(i2))/2, (z(i1)-z(i2))/(2*w^exp)),
643 // with exp <-- 0.
644 var fftp3m_word sum;
645 var fftp3m_word diff;
646 addp3m(z[i1],z[i2], sum);
647 subp3m(z[i1],z[i2], diff);
648 shiftp3m(sum, z[i1]);
649 shiftp3m(diff, z[i2]);
650 }
651 }
652 for (var uintC s = 1; s < smax; s++) {
653 var uintC exp = bit_reverse(n-1-l,s) << l;
654 exp = (N>>1) - exp; // negate exp (use w^-1 instead of w)
655 for (var uintC t = 0; t < tmax; t++) {
656 var uintC i1 = (s << (l+1)) + t;
657 var uintC i2 = i1 + tmax;
658 // Inverse Butterfly: replace (z(i1),z(i2)) by
659 // ((z(i1)+z(i2))/2, (z(i1)-z(i2))/(2*w^exp)),
660 // with exp <-- (N/2 - exp).
661 var fftp3m_word sum;
662 var fftp3m_word diff;
663 addp3m(z[i1],z[i2], sum);
664 subp3m(z[i2],z[i1], diff); // note that w^(N/2) = -1
665 shiftp3m(sum, z[i1]);
666 mulp3m(diff,w[exp], diff); shiftp3m(diff, z[i2]);
667 }
668 }
669 #endif
670 }
671 /* l = n-1 */ {
672 var const uintC tmax = N>>1; // tmax = 2^(n-1)
673 for (var uintC t = 0; t < tmax; t++) {
674 var uintC i1 = t;
675 var uintC i2 = i1 + tmax;
676 // Inverse Butterfly: replace (z(i1),z(i2)) by
677 // ((z(i1)+z(i2))/2, (z(i1)-z(i2))/2).
678 var fftp3m_word sum;
679 var fftp3m_word diff;
680 addp3m(z[i1],z[i2], sum);
681 subp3m(z[i1],z[i2], diff);
682 shiftp3m(sum, z[i1]);
683 shiftp3m(diff, z[i2]);
684 }
685 }
686 }
687 #if FFTP3M_BACKWARD == FORWARD
688 // Swap z[i] and z[N-i] for 0 < i < N/2.
689 for (i = (N>>1)-1; i > 0; i--) {
690 var fftp3m_word tmp = z[i];
691 z[i] = z[N-i];
692 z[N-i] = tmp;
693 }
694 #endif
695 }
696
// Multiplies the unsigned digit sequence source1 (len1 digits) by source2
// (len2 digits) and stores the product, len1+len2 digits, at destptr.
// source2 is processed in pieces; each piece is multiplied by source1 with
// one FFT convolution and the partial product is added into the destination.
// Throws runtime_exception() if an internal consistency check fails (a carry
// that would leave the destination area).
mulu_fft_modp3m(const uintD * sourceptr1,uintC len1,const uintD * sourceptr2,uintC len2,uintD * destptr)697 static void mulu_fft_modp3m (const uintD* sourceptr1, uintC len1,
698 const uintD* sourceptr2, uintC len2,
699 uintD* destptr)
700 // Requires 2 <= len1 <= len2.
701 {
702 // Method:
703 // source1 is one piece of length N1, source2 is one or more pieces
704 // of length N2, with N1+N2 <= N, where N is a power of two.
705 // sum(i=0..N-1, x_i b^i) * sum(i=0..N-1, y_i b^i) is computed
706 // by multiplying the two polynomials
707 // sum(i=0..N-1, x_i T^i), sum(i=0..N-1, y_i T^i)
708 // by means of the Fourier transform (see above).
709 var uint32 n;
710 integerlengthC(len1-1, n=); // 2^(n-1) < len1 <= 2^n
711 var uintC len = (uintC)1 << n; // smallest power of two >= len1
712 // Choosing N = len costs ceiling(len2/(len-len1+1)) * FFT(len).
713 // Choosing N = 2*len costs ceiling(len2/(2*len-len1+1)) * FFT(2*len).
714 // We choose the cheaper of the two:
715 // If ceiling(len2/(len-len1+1)) <= 2 * ceiling(len2/(2*len-len1+1))
716 // we take N = len, if ....... > ........ we take N = 2*len instead.
717 // (Choosing N = 4*len or more pays off only in extreme cases.)
718 if (len2 > 2 * (len-len1+1) * (len2 <= (2*len-len1+1) ? 1 : ceiling(len2,(2*len-len1+1)))) {
719 n = n+1;
720 len = len << 1;
721 }
722 var const uintC N = len; // N = 2^n
723 CL_ALLOCA_STACK;
724 var fftp3m_word* const x = cl_alloc_array(fftp3m_word,N);
725 var fftp3m_word* const y = cl_alloc_array(fftp3m_word,N);
726 #ifdef DEBUG_FFTP3M
727 var fftp3m_word* const z = cl_alloc_array(fftp3m_word,N);
728 #else
729 var fftp3m_word* const z = x; // put z in place of x - saves memory
730 #endif
// Scratch area for the "cheap case" below (a piece of length 1).
731 var uintD* const tmpprod = cl_alloc_array(uintD,len1+1);
732 var uintP i;
733 var uintC destlen = len1+len2;
// The destination is cleared first because the partial products are
// accumulated into it.
734 clear_loop_lsp(destptr,destlen);
735 do {
736 var uintC len2p; // length of a piece of source2
737 len2p = N - len1 + 1;
738 if (len2p > len2)
739 len2p = len2;
740 // len2p = min(N-len1+1,len2).
741 if (len2p == 1) {
742 // cheap case
// A single digit of source2: multiply directly and add the (len1+1)-digit
// product into the destination, propagating the carry upwards.
743 var uintD* tmpptr = arrayLSDptr(tmpprod,len1+1);
744 mulu_loop_lsp(lspref(sourceptr2,0),sourceptr1,tmpptr,len1);
745 if (addto_loop_lsp(tmpptr,destptr,len1+1))
746 if (inc_loop_lsp(destptr lspop (len1+1),destlen-(len1+1)))
747 throw runtime_exception();
748 } else {
749 var uintC destlenp = len1 + len2p - 1;
750 // destlenp = min(N,destlen-1).
// Squaring is detected when both factors are the same digits; the second
// fill and the second forward transform are then skipped.
751 var bool squaring = ((sourceptr1 == sourceptr2) && (len1 == len2p));
752 // Fill factor x.
753 {
754 for (i = 0; i < len1; i++)
755 setp3m(lspref(sourceptr1,i), x[i]);
756 for (i = len1; i < N; i++)
757 zerop3m(x[i]);
758 }
759 // Fill factor y.
760 if (!squaring) {
761 for (i = 0; i < len2p; i++)
762 setp3m(lspref(sourceptr2,i), y[i]);
763 for (i = len2p; i < N; i++)
764 zerop3m(y[i]);
765 }
766 // Multiply.
767 if (!squaring)
768 fftp3m_convolution(n,N, &x[0], &y[0], &z[0]);
769 else
770 fftp3m_convolution(n,N, &x[0], &x[0], &z[0]);
771 // Add result to destptr[-destlen..-1]:
772 {
773 var uintD* ptr = destptr;
774 // ac2|ac1|ac0 are an accumulator.
775 var uint32 ac0 = 0;
776 var uint32 ac1 = 0;
777 var uint32 ac2 = 0;
778 var uint32 tmp;
779 for (i = 0; i < destlenp; i++) {
780 // Convert z[i] to a 3-digit number.
781 var uintD z_i[3];
782 combinep3m(z[i],arrayLSDptr(z_i,3));
783 #ifdef DEBUG_FFTP3M
784 if (!(arrayLSref(z_i,3,2) < N))
785 throw runtime_exception();
786 #endif
787 // Add z[i] to the accumulator.
// Each addition detects a 32-bit wrap-around by comparing the new value
// with the addend and carries into the next accumulator word.
788 tmp = arrayLSref(z_i,3,0);
789 if ((ac0 += tmp) < tmp) {
790 if (++ac1 == 0)
791 ++ac2;
792 }
793 tmp = arrayLSref(z_i,3,1);
794 if ((ac1 += tmp) < tmp)
795 ++ac2;
796 tmp = arrayLSref(z_i,3,2);
797 ac2 += tmp;
798 // Add the accumulator's least significant word to destptr:
799 tmp = lspref(ptr,0);
800 if ((ac0 += tmp) < tmp) {
801 if (++ac1 == 0)
802 ++ac2;
803 }
804 lspref(ptr,0) = ac0;
805 lsshrink(ptr);
// Shift the accumulator down by one digit.
806 ac0 = ac1;
807 ac1 = ac2;
808 ac2 = 0;
809 }
810 // ac2 = 0.
// Flush what is left in the accumulator: two more digits if ac1 is nonzero,
// one more digit if only ac0 is nonzero. In both cases the digits must still
// lie inside the destination, and a final carry is propagated upwards.
811 if (ac1 > 0) {
812 if (!((i += 2) <= destlen))
813 throw runtime_exception();
814 tmp = lspref(ptr,0);
815 if ((ac0 += tmp) < tmp)
816 ++ac1;
817 lspref(ptr,0) = ac0;
818 lsshrink(ptr);
819 tmp = lspref(ptr,0);
820 ac1 += tmp;
821 lspref(ptr,0) = ac1;
822 lsshrink(ptr);
823 if (ac1 < tmp)
824 if (inc_loop_lsp(ptr,destlen-i))
825 throw runtime_exception();
826 } else if (ac0 > 0) {
827 if (!((i += 1) <= destlen))
828 throw runtime_exception();
829 tmp = lspref(ptr,0);
830 ac0 += tmp;
831 lspref(ptr,0) = ac0;
832 lsshrink(ptr);
833 if (ac0 < tmp)
834 if (inc_loop_lsp(ptr,destlen-i))
835 throw runtime_exception();
836 }
837 }
838 #ifdef DEBUG_FFTP3M
839 // If destlenp < N, check that the remaining z[i] are 0.
840 for (i = destlenp; i < N; i++)
841 if (z[i].w1 > 0 || z[i].w2 > 0 || z[i].w3 > 0)
842 throw runtime_exception();
843 #endif
844 }
845 // Decrement len2.
// Advance past the piece of source2 just handled; the destination window
// moves up by the same number of digits.
846 destptr = destptr lspop len2p;
847 destlen -= len2p;
848 sourceptr2 = sourceptr2 lspop len2p;
849 len2 -= len2p;
850 } while (len2 > 0);
851 }
852
// Remove the short-name aliases (p1..p3, n1..n3 -> fftp3m_*) that were
// defined near the top of this file, so that they do not leak into code
// that follows.
853 #undef n3
854 #undef n2
855 #undef n1
856 #undef p3
857 #undef p2
858 #undef p1
859