1 // Fast integer multiplication using FFT in a modular ring.
2 // Bruno Haible 5.5.1996, 30.6.1996, 20.8.1996
3 
4 // FFT in the complex domain has the drawback that it needs careful round-off
5 // error analysis. So here we choose another field of characteristic 0: Q_p.
6 // Since Q_p contains exactly the (p-1)th roots of unity, we choose
7 // p == 1 mod N and have the Nth roots of unity (N = 2^n) in Q_p and
8 // even in Z_p. Actually, we compute in Z/(p^m Z).
9 
// All the operations the FFT algorithm needs are addition, subtraction,
11 // multiplication, multiplication by the Nth root of unity and division
12 // by N. Hence we can use the domain Z/(p^m Z) even if p is not a prime!
13 
14 // We want to compute the convolution of N 32-bit words. The resulting
15 // words are < (2^32)^2 * N. To avoid computing with numbers greater than
16 // 32 bits, we compute in Z/pZ for three different primes p in parallel,
17 // i.e. we compute in the ring (Z / p1 Z) x (Z / p2 Z) x (Z / p3 Z). We choose
18 // p1 = 3*2^30+1, p2 = 15*2^27+1, p3 = 7*2^26+1.
// Because p1*p2*p3 >= 2^91 >= (2^32)^2 * N (the second inequality holds for
// N <= 2^27), the Chinese remainder theorem will faithfully combine 3 32-bit
// words to a word < (2^32)^2 * N.
21 
22 // Furthermore we use Montgomery's modular multiplication trick
23 // [Peter L. Montgomery: Modular multiplication without trial division,
24 //  Mathematics of Computation 44 (1985), 519-521.]
25 //
26 // Assume we want to compute modulo M, M odd. V and N will be chosen
27 // so that V*N==1 mod M and that (a,b) --> a*b*V mod M can be more easily
28 // computed than (a,b) --> a*b mod M. Then, we have a ring isomorphism
29 //   (Z/MZ, +, * mod M)  \isomorph  (Z/MZ, +, (a,b) --> a*b*V mod M)
30 //   x mod M             -------->  x*N mod M
// It is thus preferable to use x*N mod M as a "representation" of x mod M,
32 // especially for computations which involve at least several multiplications.
33 //
34 // The precise algorithm to compute a*b*V mod M, given a and b, and the choice
35 // of N and V depend on M and on the hardware. The general idea is this:
36 // Choose N = 2^n, so that division by N is easy. Recall that V == N^-1 mod M.
37 // 1. Given a and b as m-bit numbers (M <= 2^m), compute a*b in full
38 //    precision.
39 // 2. Write a*b = c*N+d, i.e. split it into components c and d.
40 // 3. Now a*b*V = c*N*V+d*V == c+d*V mod M.
41 // 4. Instead of computing d*V mod M
42 //    a. by full multiplication and then division mod M, or
43 //    b. by left shifts: repeated application of
44 //          x := 2*x+(0 or 1); if (x >= M) { x := x-M; }
45 //    we compute
46 //    c. by right shifts (recall that d*V == d*2^-n mod M): repeated application
47 //       of   if (x odd) { x := (x+M)/2; } else { x := x/2; }
48 // Usually one will choose N = 2^m, so that c and d have both m bits.
49 // Several variations are possible: In step 4 one can implement the right
50 // shifts in hardware. Or (for example when N = 2^160 and working on a
51 // 32-bit machine) one can do 32 shift steps at the same time:
52 // Choose M' == M^-1 mod 2^32 and compute n/32 times
53 //       x := (x - ((x mod 2^32) * M' mod 2^32) * M) / 2^32.
54 //
// Here, we deal with moduli M = p_i = j*2^k+1. This form of prime comes
// in because we need 2^n-th roots of unity mod M. But it also comes in handy
57 // for Montgomery multiplication: Instead of choosing N = 2^32 (which makes
58 // up for very easy splitting in step 1) and V = j^2*2^(2*k-32), we better
59 // choose N = 2^k and V = -j. The algorithm now goes like this (recall that
60 // M is an m-bit number and j is an (m-k)-bit number):
61 // 1. Compute a*b in full precision, as a 2*m <= 64 bit number.
// 2. Split a*b = c*N+d, with c a (2m-k)-bit number and d a k-bit number.
63 // 3. a*b*V == c+d*V mod M.
64 // 4. Compute c mod M by splitting off the leading (m-k+1) bits of c and
65 //    using table lookup; the remainder (c mod 2^(m-1)) is already reduced
66 //    mod M.
67 //    Compute d*|V| the standard way; |V| has only few bits. d*|V| is
68 //    already reduced mod M, because d*|V| < j*2^k < M.
69 
70 // In order to get best performance, we carefully choose the primes so that
71 // a. the table of size 2^(m-k+1) doesn't get too large,
72 // b. multiplication by V is easy.
73 // Here is a list of the interesting primes < 2^32:
74 //
75 //                       U*M+V*N = 1
76 //   prime       bits    N=2^n  V    U    2m<=n+32 ?
77 //     M           m       n
78 //
79 //   3*2^30+1     32      30    -3   1        n       (*)
80 //
81 //  13*2^28+1     32      28   -13   1        n
82 //
83 //  15*2^27+1     31      27   -15   1        n       (*)
84 //  17*2^27+1     32      27   -17   1        n
85 //  29*2^27+1     32      27   -29   1        n
86 //
87 //   7*2^26+1     29      26    -7   1        y       (*)
88 //  27*2^26+1     31      26   -27   1        n
89 //  37*2^26+1     32      26   -37   1        n
90 //  43*2^26+1     32      26   -43   1        n
91 //
92 //   5*2^25+1     28      25    -5   1        y
93 //  33*2^25+1     31      25   -33   1        n
94 //  51*2^25+1     31      25   -51   1        n
95 //  63*2^25+1     31      25   -63   1        n
96 //  81*2^25+1     32      25   -81   1        n
97 // 125*2^25+1     32      25  -125   1        n
98 //
99 //  45*2^24+1     30      24   -45   1        n
100 //  73*2^24+1     31      24   -73   1        n
101 // 127*2^24+1     31      24  -127   1        n
102 // 151*2^24+1     32      24  -151   1        n
103 // 157*2^24+1     32      24  -157   1        n
104 // 171*2^24+1     32      24  -171   1        n
105 // 193*2^24+1     32      24  -193   1        n
106 // 235*2^24+1     32      24  -235   1        n
107 // 243*2^24+1     32      24  -243   1        n
108 //
109 //  45*2^23+1     29      23   -45   1        n
110 // ...
111 //
112 // The inequality 2m<=n+32 would mean that c fits in a 32-bit word, but that's
113 // actually irrelevant because we can fetch the most significant bits of c
114 // before actually computing c.
115 // We choose the primes marked with an asterisk.
116 
117 
118 #if !(intDsize==32)
119 #error "fft mod p implemented only for intDsize==32"
120 #endif
121 
122 // Avoid clash with fftp3
123 #define p1 fftp3m_p1
124 #define p2 fftp3m_p2
125 #define p3 fftp3m_p3
126 #define n1 fftp3m_n1
127 #define n2 fftp3m_n2
128 #define n3 fftp3m_n3
129 
130 static const uint32 p1 = 1+(3<<30); // = 3221225473
131 static const uint32 p2 = 1+(15<<27); // = 2013265921
132 static const uint32 p3 = 1+(7<<26); // = 469762049
133 static const uint32 n1 = 30; // Montgomery: represent x mod p1 as x*2^n1 mod p1
134 static const uint32 n2 = 27; // Montgomery: represent x mod p2 as x*2^n2 mod p2
135 static const uint32 n3 = 26; // Montgomery: represent x mod p3 as x*2^n3 mod p3
136 
137 typedef struct {
138 	uint32 w1; // remainder mod p1
139 	uint32 w2; // remainder mod p2
140 	uint32 w3; // remainder mod p3
141 } fftp3m_word;
142 
static const fftp3m_word fftp3m_roots_of_1 [26+1] =
  // roots_of_1[n] is a (2^n)th root of unity in our ring.
  // (Also roots_of_1[n-1] = roots_of_1[n]^2, but we don't need this.)
  // The table has 26+1 rows, so transforms of length up to 2^26 are supported.
  // The active (#else) rows are the standard-representation rows of the
  // #if 0 branch converted to Montgomery form, component i multiplied by
  // 2^ni mod pi: e.g. row 0 is {2^30, 2^27, 2^26} (= 1) and row 1 is
  // {p1-2^30, p2-2^27, p3-2^26} (= -1).
  {
    #if 0 // in standard representation
    {          1,          1,          1 },
    { 3221225472, 2013265920,  469762048 },
    { 1013946479,  284861408,   19610091 },
    { 1031213943,  211723194,   26623616 },
    {  694614138,   78945800,  111570435 },
    {  347220834,  772607190,  135956445 },
    {  680684264,  288289890,  181505383 },
    { 1109768284,  112574482,  145518049 },
    {  602134989,  928726468,  109721424 },
    { 1080308101,  875419223,    2847903 },
    {  381653707,  510575142,  110273149 },
    {  902453688,  193023072,   65701394 },
    { 1559299664,  313561437,  181642641 },
    {  254499731,  121307056,   82315502 },
    { 1376063215,   20899142,  142137197 },
    { 1284040478,  956809618,  207661045 },
    {  336664489,  317295870,  194405005 },
    {  894491787,  785393806,    2821902 },
    {  795860341,  738526384,  230963948 },
    {   23880336,  956561758,   59211404 },
    {  790585193,  352904935,   95374542 },
    {  877386874,  836313293,  153165757 },
    { 1510644826,  971592443,   74027009 },
    {  353060343,  692611595,   24417505 },
    {  716717815,  791167605,   26032760 },
    { 1020271667,  751686895,  150976424 },
    {  139914905,  477826617,   71902965 }
    #else // in Montgomery representation
    { 1073741824,  134217728,   67108864 },
    { 2147483649, 1879048193,  402653185 },
    { 1809501489, 1054751064,  265634015 },
    { 2877487492, 1193844673,  331740947 },
    { 2989687427,  665825587,  252496823 },
    { 3105485195, 1961758775,  114795379 },
    { 1920588894, 1994046595,  175397252 },
    {  703819063,  932019131,  314756028 },
    { 3020513810, 1682915367,   51434375 },
    {  713639124, 1015380543,  133810885 },
    {  946523922, 1576574394,  454008742 },
    { 2920407577, 1597744532,  191940679 },
    { 1627717094, 1589708641,  309595372 },
    { 2062650405,  126130591,  189567235 },
    {  615054086,  267042180,  382347871 },
    { 1719470156, 1681043157,  238769593 },
    {  961520328, 1992112863,  240663313 },
    { 2923061544,   81858141,  402250056 },
    {  808455044,  487635820,  302549471 },
    { 3213265361, 1681059681,  461303277 },
    { 1883955251, 1318650285,  254810522 },
    {  781279533, 1017987605,  179445770 },
    {  570193549, 1008968995,  459186762 },
    { 3103538692,  624914534,  466273834 },
    {  834835886, 1960521414,  331825355 },
    { 1807393093, 1292064821,  246867396 },
    { 2100845347, 1578757629,   56837012 }
    #endif
  };
205 
// Define this for (cheap) consistency checks.
//#define DEBUG_FFTP3M

// Define this for extensive consistency checks.
//#define DEBUG_FFTP3M_OPERATIONS

// Define the algorithm of the backward FFT:
// Either FORWARD (a normal FFT followed by a permutation)
// or     RECIPROOT (an FFT with reciprocal root of unity)
// or     CLEVER (an FFT with reciprocal root of unity but clever computation
//                of the reciprocals).
// Drawback of FORWARD: the permutation pass.
// Drawback of RECIPROOT: need all the powers of the root, not only half of them.
#define FORWARD   42
#define RECIPROOT 43
#define CLEVER    44
#define FFTP3M_BACKWARD CLEVER

// check_fftp3m_word(x): in DEBUG_FFTP3M_OPERATIONS builds, throw
// runtime_exception unless every residue of x is fully reduced (< pi).
// Both variants expand to a single statement (do/while(0)), so the macro
// is safe inside an unbraced if/else. The argument is parenthesized.
#ifdef DEBUG_FFTP3M_OPERATIONS
#define check_fftp3m_word(x)  \
	do { if (((x).w1 >= p1) || ((x).w2 >= p2) || ((x).w3 >= p3)) throw runtime_exception(); } while (0)
#else
#define check_fftp3m_word(x)  do { } while (0)
#endif
229 
230 // r := 0 mod p
zerop3m(fftp3m_word & r)231 static inline void zerop3m (fftp3m_word& r)
232 {
233 	r.w1 = 0;
234 	r.w2 = 0;
235 	r.w3 = 0;
236 }
237 
238 // r := x mod p
setp3m(uint32 x,fftp3m_word & r)239 static inline void setp3m (uint32 x, fftp3m_word& r)
240 {
241 	var uint32 hi;
242 	var uint32 lo;
243 	hi = x >> (32-n1); lo = x << n1; divu_6432_3232(hi,lo,p1, ,r.w1=);
244 	hi = x >> (32-n2); lo = x << n2; divu_6432_3232(hi,lo,p2, ,r.w2=);
245 	hi = x >> (32-n3); lo = x << n3; divu_6432_3232(hi,lo,p3, ,r.w3=);
246 }
247 
// Chinese remainder theorem:
// (Z / p1 Z) x (Z / p2 Z) x (Z / p3 Z) == Z / p1*p2*p3 Z = Z / P Z.
// Convert the ring element r (given in Montgomery representation) to the
// integer >= 0, < p1*p2*p3 that it represents, and store it as a 3-digit
// sequence whose least significant digit is at resLSDptr.
// This routine also does the "de-Montgomerizing" (via the factors vi below).
static void combinep3m (const fftp3m_word& r, uintD* resLSDptr)
{
	check_fftp3m_word(r);
	// Compute e1 * v1 * r.w1 + e2 * v2 * r.w2 + e3 * v3 * r.w3 where
	// vi == 2^-ni mod pi, and the idempotents ei are found as:
	// xgcd(pi,p/pi) = 1 = ui*pi + vi*P/pi, ei = 1 - ui*pi.
	// e1 = 1709008312966733882383995583
	// e2 = 2781580629833601225216537109
	// e3 = 1602397205945693664242711343
	// e1*v1 = 965961209845827124691257285
	// e2*v2 = 927193593718183024654651603
	// e3*v3 = 969191855872201893987508667
	// We will have 0 <= e1*v1 * r.w1 + e2*v2 * r.w2 + e3*v3 * r.w3 <
	// < e1*v1 * p1 + e2*v2 * p2 + e3*v3 * p3 < 3 * 2^32 * p1*p2*p3 < 2^128.
	// The sum of the products fits in 4 digits, we divide by p1*p2*p3
	// as a 3-digit sequence, thus getting the remainder.
	#if 0
	#if CL_DS_BIG_ENDIAN_P
	var const uintD p123 [3] = { 0x09D80000, 0x7C200001, 0x54000001 };
	var const uintD e1v1 [3] = { 0x031F063E, 0x1CD1F37E, 0x20E0C7C5 };
	var const uintD e2v2 [3] = { 0x02FEF4E1, 0x6E62C875, 0x788590D3 };
	var const uintD e3v3 [3] = { 0x0321B25B, 0xC8DB371B, 0xF0E861BB };
	#else
	var const uintD p123 [3] = { 0x54000001, 0x7C200001, 0x09D80000 };
	var const uintD e1v1 [3] = { 0x20E0C7C5, 0x1CD1F37E, 0x031F063E };
	var const uintD e2v2 [3] = { 0x788590D3, 0x6E62C875, 0x02FEF4E1 };
	var const uintD e3v3 [3] = { 0xF0E861BB, 0xC8DB371B, 0x0321B25B };
	#endif
	#else
	// The final division step requires a shift left by 4 bits in order
	// to normalize p1*p2*p3 (so that its most significant digit has its
	// top bit set). We combine this shift left with the
	// multiplications. Note that since e1v1 + e2v2 + e3v3 < p1*p2*p3,
	// there is no risk of overflow: p1*p2*p3 < 2^92, so each constant
	// times 16 still fits in 3 digits.
	#if CL_DS_BIG_ENDIAN_P
	var const uintD p123 [3] = { 0x9D800007, 0xC2000015, 0x40000010 };
	var const uintD e1v1 [3] = { 0x31F063E1, 0xCD1F37E2, 0x0E0C7C50 };
	var const uintD e2v2 [3] = { 0x2FEF4E16, 0xE62C8757, 0x88590D30 };
	var const uintD e3v3 [3] = { 0x321B25BC, 0x8DB371BF, 0x0E861BB0 };
	#else
	var const uintD p123 [3] = { 0x40000010, 0xC2000015, 0x9D800007 };
	var const uintD e1v1 [3] = { 0x0E0C7C50, 0xCD1F37E2, 0x31F063E1 };
	var const uintD e2v2 [3] = { 0x88590D30, 0xE62C8757, 0x2FEF4E16 };
	var const uintD e3v3 [3] = { 0x0E861BB0, 0x8DB371BF, 0x321B25BC };
	#endif
	#endif
	// sum (4 digits) := r.w1*e1v1 + r.w2*e2v2 + r.w3*e3v3 (all pre-shifted by 4 bits).
	var uintD sum [4];
	var uintD* const sumLSDptr = arrayLSDptr(sum,4);
	mulu_loop_lsp(r.w1,arrayLSDptr(e1v1,3), sumLSDptr,3);
	lspref(sumLSDptr,3) += muluadd_loop_lsp(r.w2,arrayLSDptr(e2v2,3), sumLSDptr,3);
	lspref(sumLSDptr,3) += muluadd_loop_lsp(r.w3,arrayLSDptr(e3v3,3), sumLSDptr,3);
	#if 0
	{CL_ALLOCA_STACK;
	 var DS q;
	 var DS r;
	 UDS_divide(arrayMSDptr(sum,4),4,arrayLSDptr(sum,4),
	            arrayMSDptr(p123,3),3,arrayLSDptr(p123,3),
	            &q,&r
	           );
	 ASSERT(q.len <= 1)
	 ASSERT(r.len <= 3)
	 copy_loop_lsp(r.LSDptr,arrayLSDptr(sum,4),r.len);
	 DS_clear_loop(arrayMSDptr(sum,4) mspop 1,3-r.len,arrayLSDptr(sum,4) lspop r.len);
	}
	#else
	// Division like UDS_divide with a_len=4, b_len=3: estimate the single
	// quotient digit q_stern from the top digits, correct it by at most 2,
	// then subtract q_stern*p123 so that the low 3 digits of sum hold the remainder.
	{
		var uintD q_stern;
		var uintD c1;
		#if HAVE_DD
		  divuD(highlowDD(lspref(sumLSDptr,3),lspref(sumLSDptr,2)),lspref(arrayLSDptr(p123,3),2), q_stern=,c1=);
		  { var uintDD c2 = highlowDD(c1,lspref(sumLSDptr,1));
		    var uintDD c3 = muluD(lspref(arrayLSDptr(p123,3),1),q_stern);
		    if (c3 > c2)
		      { q_stern = q_stern-1;
		        if (c3-c2 > highlowDD(lspref(arrayLSDptr(p123,3),2),lspref(arrayLSDptr(p123,3),1)))
		          { q_stern = q_stern-1; }
		  }   }
		#else
		  divuD(lspref(sumLSDptr,3),lspref(sumLSDptr,2),lspref(arrayLSDptr(p123,3),2), q_stern=,c1=);
		  { var uintD c2lo = lspref(sumLSDptr,1);
		    var uintD c3hi;
		    var uintD c3lo;
		    muluD(lspref(arrayLSDptr(p123,3),1),q_stern, c3hi=,c3lo=);
		    if ((c3hi > c1) || ((c3hi == c1) && (c3lo > c2lo)))
		      { q_stern = q_stern-1;
		        c3hi -= c1; if (c3lo < c2lo) { c3hi--; }; c3lo -= c2lo;
		        if ((c3hi > lspref(arrayLSDptr(p123,3),2)) || ((c3hi == lspref(arrayLSDptr(p123,3),2)) && (c3lo > lspref(arrayLSDptr(p123,3),1))))
		          { q_stern = q_stern-1; }
		      }   }
		#endif
		if (!(q_stern==0))
		  { var uintD carry = mulusub_loop_lsp(q_stern,arrayLSDptr(p123,3),sumLSDptr,3);
		    // A borrow larger than the top digit means q_stern was one too
		    // large: add the divisor back once.
		    if (carry > lspref(sumLSDptr,3))
		      { q_stern = q_stern-1;
		        addto_loop_lsp(arrayLSDptr(p123,3),sumLSDptr,3);
		  }   }
	}
	#endif
	#ifdef DEBUG_FFTP3M_OPERATIONS
	if (compare_loop_msp(sumLSDptr lspop 3,arrayMSDptr(p123,3),3) >= 0)
		throw runtime_exception();
	#endif
	// Renormalize the division's remainder: shift right by 4 bits.
	shiftrightcopy_loop_msp(sumLSDptr lspop 3,resLSDptr lspop 3,3,4,0);
}
357 
358 // r := (a + b) mod p
addp3m(const fftp3m_word & a,const fftp3m_word & b,fftp3m_word & r)359 static inline void addp3m (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& r)
360 {
361 	var uint32 x;
362 
363 	check_fftp3m_word(a); check_fftp3m_word(b);
364 	// Add single 32-bit words mod pi.
365 	if (((x = (a.w1 + b.w1)) < b.w1) || (x >= p1))
366 		x -= p1;
367 	r.w1 = x;
368 	if ((x = (a.w2 + b.w2)) >= p2) // x doesn't overflow since p2 <= 2^31
369 		x -= p2;
370 	r.w2 = x;
371 	if ((x = (a.w3 + b.w3)) >= p3) // x doesn't overflow since p3 <= 2^31
372 		x -= p3;
373 	r.w3 = x;
374 	check_fftp3m_word(r);
375 }
376 
377 // r := (a - b) mod p
subp3m(const fftp3m_word & a,const fftp3m_word & b,fftp3m_word & r)378 static inline void subp3m (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& r)
379 {
380 	check_fftp3m_word(a); check_fftp3m_word(b);
381 	// Subtract single 32-bit words mod pi.
382 	r.w1 = (a.w1 < b.w1 ? a.w1-b.w1+p1 : a.w1-b.w1);
383 	r.w2 = (a.w2 < b.w2 ? a.w2-b.w2+p2 : a.w2-b.w2);
384 	r.w3 = (a.w3 < b.w3 ? a.w3-b.w3+p3 : a.w3-b.w3);
385 	check_fftp3m_word(r);
386 }
387 
388 // r := (a * b) mod p
mulp3m(const fftp3m_word & a,const fftp3m_word & b,fftp3m_word & res)389 static void mulp3m (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& res)
390 {
391 	check_fftp3m_word(a); check_fftp3m_word(b);
392 	// Multiplication à la Montgomery:
393 	#define mul_mod_p(aw,bw,result_zuweisung,p,m,n,j,js,table)  \
394 	{	/* table[i] == i*2^(m-1) mod p for 0 <= i < 2^(m-n+1) */\
395 		var uint32 hi;						\
396 		var uint32 lo;						\
397 		mulu32(aw,bw, hi=,lo=);					\
398 		/* hi has 2m-32 bits */					\
399 		var const int l = (m-1)-(32-n);				\
400 		var uint32 r = table[hi>>l];				\
401 		hi = ((hi << (32-l)) >> (n-l)) | (lo >> n);		\
402 		/* hi = c mod 2^(m-1), has m-1 bits */			\
403 		lo = lo & (bit(n)-1);					\
404 		/* lo = d, has n bits */				\
405 		lo = (lo << js) - lo;					\
406 		/* lo = d*|V|, has m bits */				\
407 		/* Finally compute (r + hi - lo) mod p. */		\
408 		if (m < 32) {						\
409 			r += hi;					\
410 			if (r >= p)					\
411 				{ r = r - p; }				\
412 		} else {						\
413 			if (((r += hi) < hi) || (r >= p))		\
414 				{ r = r - p; }				\
415 		}							\
416 		r = (r < lo ? r-lo+p : r-lo);				\
417 		/* ifdef DEBUG_FFTP3M_OPERATIONS *			\
418 		var uint32 tmp;						\
419 		mulu32(aw,bw, hi=,lo=);					\
420 		divu_6432_3232(hi,lo,p, ,tmp=);				\
421 		mulu32(tmp,j, hi=, lo=);				\
422 		divu_6432_3232(hi,lo,p, ,tmp=);				\
423 		if (tmp != 0) { tmp = p-tmp; }				\
424 		if (tmp != r)						\
425 			throw runtime_exception();					\
426 		 * endif DEBUG_FFTP3M_OPERATIONS */			\
427 		result_zuweisung r;					\
428 	}
429 	// p1 = 3*2^30+1, n1 = 30, j1 = 3 = 2^2-1
430 	static uint32 table1 [8] =
431 	  {          0, 2147483648, 1073741823, 3221225471,
432 	    2147483646, 1073741821, 3221225469, 2147483644
433 	  };
434 	mul_mod_p(a.w1,b.w1,res.w1=,p1,32,30,3,2,table1);
435 	// p2 = 15*2^27+1, n2 = 27, j2 = 15 = 2^4-1
436 	static uint32 table2 [32] =
437 	  {          0, 1073741824,  134217727, 1207959551,
438 	     268435454, 1342177278,  402653181, 1476395005,
439 	     536870908, 1610612732,  671088635, 1744830459,
440 	     805306362, 1879048186,  939524089, 2013265913,
441 	    1073741816,  134217719, 1207959543,  268435446,
442 	    1342177270,  402653173, 1476394997,  536870900,
443 	    1610612724,  671088627, 1744830451,  805306354,
444 	    1879048178,  939524081, 2013265905, 1073741808
445 	  };
446 	mul_mod_p(a.w2,b.w2,res.w2=,p2,31,27,15,4,table2);
447 	// p3 = 7*2^26+1, n3 = 26, j3 = 7 = 2^3-1
448 	static uint32 table3 [16] =
449 	  {          0,  268435456,   67108863,  335544319,
450 	     134217726,  402653182,  201326589,  469762045,
451 	     268435452,   67108859,  335544315,  134217722,
452 	     402653178,  201326585,  469762041,  268435448
453 	  };
454 	mul_mod_p(a.w3,b.w3,res.w3=,p3,29,26,7,3,table3);
455 	#undef mul_mod_p
456 	check_fftp3m_word(res);
457 }
#ifdef DEBUG_FFTP3M_OPERATIONS
// Debug variant of mulp3m: computes r := a*b and cross-checks it against
// (-a)*(-b), which must yield the same value; throws runtime_exception on
// mismatch. The #define below routes all later mulp3m calls through here.
static void mulp3m_doublecheck (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& r)
{
	// The product of the negations was previously held in a variable named
	// "or", which is a reserved alternative token (for ||) in C++ and made
	// this debug build fail to compile.
	fftp3m_word zero, ma, mb, prod_of_negs;
	zerop3m(zero);
	subp3m(zero,a, ma);
	subp3m(zero,b, mb);
	mulp3m(ma,mb, prod_of_negs);
	mulp3m(a,b, r);
	if (!((r.w1 == prod_of_negs.w1) && (r.w2 == prod_of_negs.w2) && (r.w3 == prod_of_negs.w3)))
		throw runtime_exception();
}
#define mulp3m mulp3m_doublecheck
#endif /* DEBUG_FFTP3M_OPERATIONS */
472 
473 // b := (a / 2) mod p
shiftp3m(const fftp3m_word & a,fftp3m_word & b)474 static inline void shiftp3m (const fftp3m_word& a, fftp3m_word& b)
475 {
476 	check_fftp3m_word(a);
477 	b.w1 = (a.w1 & 1 ? (a.w1 >> 1) + (p1 >> 1) + 1 : (a.w1 >> 1));
478 	b.w2 = (a.w2 & 1 ? (a.w2 >> 1) + (p2 >> 1) + 1 : (a.w2 >> 1));
479 	b.w3 = (a.w3 & 1 ? (a.w3 >> 1) + (p3 >> 1) + 1 : (a.w3 >> 1));
480 	check_fftp3m_word(b);
481 }
482 
483 #ifndef _BIT_REVERSE
484 #define _BIT_REVERSE
485 // Reverse an n-bit number x. n>0.
bit_reverse(uintL n,uintC x)486 static uintC bit_reverse (uintL n, uintC x)
487 {
488 	var uintC y = 0;
489 	do {
490 		y <<= 1;
491 		y |= (x & 1);
492 		x >>= 1;
493 	} while (!(--n == 0));
494 	return y;
495 }
496 #endif
497 
498 // Compute an convolution mod p using FFT: z[0..N-1] := x[0..N-1] * y[0..N-1].
fftp3m_convolution(const uintL n,const uintC N,fftp3m_word * x,fftp3m_word * y,fftp3m_word * z)499 static void fftp3m_convolution (const uintL n, const uintC N, // N = 2^n
500                                 fftp3m_word * x, // N words
501                                 fftp3m_word * y, // N words
502                                 fftp3m_word * z  // N words result
503                                )
504 {
505 	CL_ALLOCA_STACK;
506 	#if (FFTP3M_BACKWARD == RECIPROOT) || defined(DEBUG_FFTP3M)
507 	var fftp3m_word* const w = cl_alloc_array(fftp3m_word,N);
508 	#else
509 	var fftp3m_word* const w = cl_alloc_array(fftp3m_word,(N>>1)+1);
510 	#endif
511 	var uintC i;
512 	// Initialize w[i] to w^i, w a primitive N-th root of unity.
513 	w[0] = fftp3m_roots_of_1[0];
514 	w[1] = fftp3m_roots_of_1[n];
515 	#if (FFTP3M_BACKWARD == RECIPROOT) || defined(DEBUG_FFTP3M)
516 	for (i = 2; i < N; i++)
517 		mulp3m(w[i-1],fftp3m_roots_of_1[n], w[i]);
518 	#else // need only half of the roots
519 	for (i = 2; i < N>>1; i++)
520 		mulp3m(w[i-1],fftp3m_roots_of_1[n], w[i]);
521 	#endif
522 	#ifdef DEBUG_FFTP3M
523 	// Check that w is really a primitive N-th root of unity.
524 	{
525 		var fftp3m_word w_N;
526 		mulp3m(w[N-1],fftp3m_roots_of_1[n], w_N);
527 		if (!(   w_N.w1 == (uint32)1<<n1
528 		      && w_N.w2 == (uint32)1<<n2
529 		      && w_N.w3 == (uint32)1<<n3))
530 			throw runtime_exception();
531 		w_N = w[N>>1];
532 		if (!(   w_N.w1 == p1-((uint32)1<<n1)
533 		      && w_N.w2 == p2-((uint32)1<<n2)
534 		      && w_N.w3 == p3-((uint32)1<<n3)))
535 			throw runtime_exception();
536 	}
537 	#endif
538 	var bool squaring = (x == y);
539 	// Do an FFT of length N on x.
540 	{
541 		var sintL l;
542 		/* l = n-1 */ {
543 			var const uintC tmax = N>>1; // tmax = 2^(n-1)
544 			for (var uintC t = 0; t < tmax; t++) {
545 				var uintC i1 = t;
546 				var uintC i2 = i1 + tmax;
547 				// Butterfly: replace (x(i1),x(i2)) by
548 				// (x(i1) + x(i2), x(i1) - x(i2)).
549 				var fftp3m_word tmp;
550 				tmp = x[i2];
551 				subp3m(x[i1],tmp, x[i2]);
552 				addp3m(x[i1],tmp, x[i1]);
553 			}
554 		}
555 		for (l = n-2; l>=0; l--) {
556 			var const uintC smax = (uintC)1 << (n-1-l);
557 			var const uintC tmax = (uintC)1 << l;
558 			for (var uintC s = 0; s < smax; s++) {
559 				var uintC exp = bit_reverse(n-1-l,s) << l;
560 				for (var uintC t = 0; t < tmax; t++) {
561 					var uintC i1 = (s << (l+1)) + t;
562 					var uintC i2 = i1 + tmax;
563 					// Butterfly: replace (x(i1),x(i2)) by
564 					// (x(i1) + w^exp*x(i2), x(i1) - w^exp*x(i2)).
565 					var fftp3m_word tmp;
566 					mulp3m(x[i2],w[exp], tmp);
567 					subp3m(x[i1],tmp, x[i2]);
568 					addp3m(x[i1],tmp, x[i1]);
569 				}
570 			}
571 		}
572 	}
573 	// Do an FFT of length N on y.
574 	if (!squaring) {
575 		var sintL l;
576 		/* l = n-1 */ {
577 			var uintC const tmax = N>>1; // tmax = 2^(n-1)
578 			for (var uintC t = 0; t < tmax; t++) {
579 				var uintC i1 = t;
580 				var uintC i2 = i1 + tmax;
581 				// Butterfly: replace (y(i1),y(i2)) by
582 				// (y(i1) + y(i2), y(i1) - y(i2)).
583 				var fftp3m_word tmp;
584 				tmp = y[i2];
585 				subp3m(y[i1],tmp, y[i2]);
586 				addp3m(y[i1],tmp, y[i1]);
587 			}
588 		}
589 		for (l = n-2; l>=0; l--) {
590 			var const uintC smax = (uintC)1 << (n-1-l);
591 			var const uintC tmax = (uintC)1 << l;
592 			for (var uintC s = 0; s < smax; s++) {
593 				var uintC exp = bit_reverse(n-1-l,s) << l;
594 				for (var uintC t = 0; t < tmax; t++) {
595 					var uintC i1 = (s << (l+1)) + t;
596 					var uintC i2 = i1 + tmax;
597 					// Butterfly: replace (y(i1),y(i2)) by
598 					// (y(i1) + w^exp*y(i2), y(i1) - w^exp*y(i2)).
599 					var fftp3m_word tmp;
600 					mulp3m(y[i2],w[exp], tmp);
601 					subp3m(y[i1],tmp, y[i2]);
602 					addp3m(y[i1],tmp, y[i1]);
603 				}
604 			}
605 		}
606 	}
607 	// Multiply the transformed vectors into z.
608 	for (i = 0; i < N; i++)
609 		mulp3m(x[i],y[i], z[i]);
610 	// Undo an FFT of length N on z.
611 	{
612 		var uintL l;
613 		for (l = 0; l < n-1; l++) {
614 			var const uintC smax = (uintC)1 << (n-1-l);
615 			var const uintC tmax = (uintC)1 << l;
616 			#if FFTP3M_BACKWARD != CLEVER
617 			for (var uintC s = 0; s < smax; s++) {
618 				var uintC exp = bit_reverse(n-1-l,s) << l;
619 				#if FFTP3M_BACKWARD == RECIPROOT
620 				if (exp > 0)
621 					exp = N - exp; // negate exp (use w^-1 instead of w)
622 				#endif
623 				for (var uintC t = 0; t < tmax; t++) {
624 					var uintC i1 = (s << (l+1)) + t;
625 					var uintC i2 = i1 + tmax;
626 					// Inverse Butterfly: replace (z(i1),z(i2)) by
627 					// ((z(i1)+z(i2))/2, (z(i1)-z(i2))/(2*w^exp)).
628 					var fftp3m_word sum;
629 					var fftp3m_word diff;
630 					addp3m(z[i1],z[i2], sum);
631 					subp3m(z[i1],z[i2], diff);
632 					shiftp3m(sum, z[i1]);
633 					mulp3m(diff,w[exp], diff); shiftp3m(diff, z[i2]);
634 				}
635 			}
636 			#else // FFTP3M_BACKWARD == CLEVER: clever handling of negative exponents
637 			/* s = 0, exp = 0 */ {
638 				for (var uintC t = 0; t < tmax; t++) {
639 					var uintC i1 = t;
640 					var uintC i2 = i1 + tmax;
641 					// Inverse Butterfly: replace (z(i1),z(i2)) by
642 					// ((z(i1)+z(i2))/2, (z(i1)-z(i2))/(2*w^exp)),
643 					// with exp <-- 0.
644 					var fftp3m_word sum;
645 					var fftp3m_word diff;
646 					addp3m(z[i1],z[i2], sum);
647 					subp3m(z[i1],z[i2], diff);
648 					shiftp3m(sum, z[i1]);
649 					shiftp3m(diff, z[i2]);
650 				}
651 			}
652 			for (var uintC s = 1; s < smax; s++) {
653 				var uintC exp = bit_reverse(n-1-l,s) << l;
654 				exp = (N>>1) - exp; // negate exp (use w^-1 instead of w)
655 				for (var uintC t = 0; t < tmax; t++) {
656 					var uintC i1 = (s << (l+1)) + t;
657 					var uintC i2 = i1 + tmax;
658 					// Inverse Butterfly: replace (z(i1),z(i2)) by
659 					// ((z(i1)+z(i2))/2, (z(i1)-z(i2))/(2*w^exp)),
660 					// with exp <-- (N/2 - exp).
661 					var fftp3m_word sum;
662 					var fftp3m_word diff;
663 					addp3m(z[i1],z[i2], sum);
664 					subp3m(z[i2],z[i1], diff); // note that w^(N/2) = -1
665 					shiftp3m(sum, z[i1]);
666 					mulp3m(diff,w[exp], diff); shiftp3m(diff, z[i2]);
667 				}
668 			}
669 			#endif
670 		}
671 		/* l = n-1 */ {
672 			var const uintC tmax = N>>1; // tmax = 2^(n-1)
673 			for (var uintC t = 0; t < tmax; t++) {
674 				var uintC i1 = t;
675 				var uintC i2 = i1 + tmax;
676 				// Inverse Butterfly: replace (z(i1),z(i2)) by
677 				// ((z(i1)+z(i2))/2, (z(i1)-z(i2))/2).
678 				var fftp3m_word sum;
679 				var fftp3m_word diff;
680 				addp3m(z[i1],z[i2], sum);
681 				subp3m(z[i1],z[i2], diff);
682 				shiftp3m(sum, z[i1]);
683 				shiftp3m(diff, z[i2]);
684 			}
685 		}
686 	}
687 	#if FFTP3M_BACKWARD == FORWARD
688 	// Swap z[i] and z[N-i] for 0 < i < N/2.
689 	for (i = (N>>1)-1; i > 0; i--) {
690 		var fftp3m_word tmp = z[i];
691 		z[i] = z[N-i];
692 		z[N-i] = tmp;
693 	}
694 	#endif
695 }
696 
mulu_fft_modp3m(const uintD * sourceptr1,uintC len1,const uintD * sourceptr2,uintC len2,uintD * destptr)697 static void mulu_fft_modp3m (const uintD* sourceptr1, uintC len1,
698                              const uintD* sourceptr2, uintC len2,
699                              uintD* destptr)
700 // Es ist 2 <= len1 <= len2.
701 {
702 	// Methode:
703 	// source1 ist ein Stück der Länge N1, source2 ein oder mehrere Stücke
704 	// der Länge N2, mit N1+N2 <= N, wobei N Zweierpotenz ist.
705 	// sum(i=0..N-1, x_i b^i) * sum(i=0..N-1, y_i b^i) wird errechnet,
706 	// indem man die beiden Polynome
707 	// sum(i=0..N-1, x_i T^i), sum(i=0..N-1, y_i T^i)
708 	// multipliziert, und zwar durch Fourier-Transformation (s.o.).
709 	var uint32 n;
710 	integerlengthC(len1-1, n=); // 2^(n-1) < len1 <= 2^n
711 	var uintC len = (uintC)1 << n; // kleinste Zweierpotenz >= len1
712 	// Wählt man N = len, so hat man ceiling(len2/(len-len1+1)) * FFT(len).
713 	// Wählt man N = 2*len, so hat man ceiling(len2/(2*len-len1+1)) * FFT(2*len).
714 	// Wir wählen das billigere von beiden:
715 	// Bei ceiling(len2/(len-len1+1)) <= 2 * ceiling(len2/(2*len-len1+1))
716 	// nimmt man N = len, bei ....... > ........ dagegen N = 2*len.
717 	// (Wahl von N = 4*len oder mehr bringt nur in Extremfällen etwas.)
718 	if (len2 > 2 * (len-len1+1) * (len2 <= (2*len-len1+1) ? 1 : ceiling(len2,(2*len-len1+1)))) {
719 		n = n+1;
720 		len = len << 1;
721 	}
722 	var const uintC N = len; // N = 2^n
723 	CL_ALLOCA_STACK;
724 	var fftp3m_word* const x = cl_alloc_array(fftp3m_word,N);
725 	var fftp3m_word* const y = cl_alloc_array(fftp3m_word,N);
726 	#ifdef DEBUG_FFTP3M
727 	var fftp3m_word* const z = cl_alloc_array(fftp3m_word,N);
728 	#else
729 	var fftp3m_word* const z = x; // put z in place of x - saves memory
730 	#endif
731 	var uintD* const tmpprod = cl_alloc_array(uintD,len1+1);
732 	var uintP i;
733 	var uintC destlen = len1+len2;
734 	clear_loop_lsp(destptr,destlen);
735 	do {
736 		var uintC len2p; // length of a piece of source2
737 		len2p = N - len1 + 1;
738 		if (len2p > len2)
739 			len2p = len2;
740 		// len2p = min(N-len1+1,len2).
741 		if (len2p == 1) {
742 			// cheap case
743 			var uintD* tmpptr = arrayLSDptr(tmpprod,len1+1);
744 			mulu_loop_lsp(lspref(sourceptr2,0),sourceptr1,tmpptr,len1);
745 			if (addto_loop_lsp(tmpptr,destptr,len1+1))
746 				if (inc_loop_lsp(destptr lspop (len1+1),destlen-(len1+1)))
747 					throw runtime_exception();
748 		} else {
749 			var uintC destlenp = len1 + len2p - 1;
750 			// destlenp = min(N,destlen-1).
751 			var bool squaring = ((sourceptr1 == sourceptr2) && (len1 == len2p));
752 			// Fill factor x.
753 			{
754 				for (i = 0; i < len1; i++)
755 					setp3m(lspref(sourceptr1,i), x[i]);
756 				for (i = len1; i < N; i++)
757 					zerop3m(x[i]);
758 			}
759 			// Fill factor y.
760 			if (!squaring) {
761 				for (i = 0; i < len2p; i++)
762 					setp3m(lspref(sourceptr2,i), y[i]);
763 				for (i = len2p; i < N; i++)
764 					zerop3m(y[i]);
765 			}
766 			// Multiply.
767 			if (!squaring)
768 				fftp3m_convolution(n,N, &x[0], &y[0], &z[0]);
769 			else
770 				fftp3m_convolution(n,N, &x[0], &x[0], &z[0]);
771 			// Add result to destptr[-destlen..-1]:
772 			{
773 				var uintD* ptr = destptr;
774 				// ac2|ac1|ac0 are an accumulator.
775 				var uint32 ac0 = 0;
776 				var uint32 ac1 = 0;
777 				var uint32 ac2 = 0;
778 				var uint32 tmp;
779 				for (i = 0; i < destlenp; i++) {
780 					// Convert z[i] to a 3-digit number.
781 					var uintD z_i[3];
782 					combinep3m(z[i],arrayLSDptr(z_i,3));
783 					#ifdef DEBUG_FFTP3M
784 					if (!(arrayLSref(z_i,3,2) < N))
785 						throw runtime_exception();
786 					#endif
787 					// Add z[i] to the accumulator.
788 					tmp = arrayLSref(z_i,3,0);
789 					if ((ac0 += tmp) < tmp) {
790 						if (++ac1 == 0)
791 							++ac2;
792 					}
793 					tmp = arrayLSref(z_i,3,1);
794 					if ((ac1 += tmp) < tmp)
795 						++ac2;
796 					tmp = arrayLSref(z_i,3,2);
797 					ac2 += tmp;
798 					// Add the accumulator's least significant word to destptr:
799 					tmp = lspref(ptr,0);
800 					if ((ac0 += tmp) < tmp) {
801 						if (++ac1 == 0)
802 							++ac2;
803 					}
804 					lspref(ptr,0) = ac0;
805 					lsshrink(ptr);
806 					ac0 = ac1;
807 					ac1 = ac2;
808 					ac2 = 0;
809 				}
810 				// ac2 = 0.
811 				if (ac1 > 0) {
812 					if (!((i += 2) <= destlen))
813 						throw runtime_exception();
814 					tmp = lspref(ptr,0);
815 					if ((ac0 += tmp) < tmp)
816 						++ac1;
817 					lspref(ptr,0) = ac0;
818 					lsshrink(ptr);
819 					tmp = lspref(ptr,0);
820 					ac1 += tmp;
821 					lspref(ptr,0) = ac1;
822 					lsshrink(ptr);
823 					if (ac1 < tmp)
824 						if (inc_loop_lsp(ptr,destlen-i))
825 							throw runtime_exception();
826 				} else if (ac0 > 0) {
827 					if (!((i += 1) <= destlen))
828 						throw runtime_exception();
829 					tmp = lspref(ptr,0);
830 					ac0 += tmp;
831 					lspref(ptr,0) = ac0;
832 					lsshrink(ptr);
833 					if (ac0 < tmp)
834 						if (inc_loop_lsp(ptr,destlen-i))
835 							throw runtime_exception();
836 				}
837 			}
838 			#ifdef DEBUG_FFTP3M
839 			// If destlenp < N, check that the remaining z[i] are 0.
840 			for (i = destlenp; i < N; i++)
841 				if (z[i].w1 > 0 || z[i].w2 > 0 || z[i].w3 > 0)
842 					throw runtime_exception();
843 			#endif
844 		}
845 		// Decrement len2.
846 		destptr = destptr lspop len2p;
847 		destlen -= len2p;
848 		sourceptr2 = sourceptr2 lspop len2p;
849 		len2 -= len2p;
850 	} while (len2 > 0);
851 }
852 
// Drop the private aliases for the prime constants (defined near the top of
// this file) so they do not leak into the rest of the translation unit.
#undef p1
#undef p2
#undef p3
#undef n1
#undef n2
#undef n3
859