xref: /dragonfly/contrib/gmp/mpn/generic/mul_fft.c (revision 73e0051e)
1 /* Schoenhage's fast multiplication modulo 2^N+1.
2 
3    Contributed by Paul Zimmermann.
4 
5    THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
6    SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
7    GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.
8 
9 Copyright 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008
10 Free Software Foundation, Inc.
11 
12 This file is part of the GNU MP Library.
13 
14 The GNU MP Library is free software; you can redistribute it and/or modify
15 it under the terms of the GNU Lesser General Public License as published by
16 the Free Software Foundation; either version 3 of the License, or (at your
17 option) any later version.
18 
19 The GNU MP Library is distributed in the hope that it will be useful, but
20 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
21 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
22 License for more details.
23 
24 You should have received a copy of the GNU Lesser General Public License
25 along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */
26 
27 
28 /* References:
29 
30    Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
31    Strassen, Computing 7, p. 281-292, 1971.
32 
33    Asymptotically fast algorithms for the numerical multiplication and division
34    of polynomials with complex coefficients, by Arnold Schoenhage, Computer
35    Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.
36 
37    Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
38    Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.
39 
40    TODO:
41 
42    Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
43    Zimmermann.
44 
45    It might be possible to avoid a small number of MPN_COPYs by using a
46    rotating temporary or two.
47 
48    Cleanup and simplify the code!
49 */
50 
/* Debug tracing.  When TRACE is defined at compile time, TRACE(stmt)
   expands to stmt and <stdio.h> is included for the printf calls used
   inside it; otherwise TRACE(stmt) expands to nothing.  */
51 #ifdef TRACE
52 #undef TRACE
53 #define TRACE(x) x
54 #include <stdio.h>
55 #else
56 #define TRACE(x)
57 #endif
58 
59 #include "gmp.h"
60 #include "gmp-impl.h"
61 
62 #ifdef WANT_ADDSUB
63 #include "generic/addsub_n.c"
64 #define HAVE_NATIVE_mpn_addsub_n 1
65 #endif
66 
/* Forward declaration: mpn_fft_mul_modF_K calls this recursively before its
   definition.  Parameters in order: op, n, m, pl, k, K, Ap, Bp, A, B,
   nprime, l, Mp, _fft_l, T, rec (see the definition for their meaning).  */
67 static void mpn_mul_fft_internal
68 __GMP_PROTO ((mp_ptr, mp_srcptr, mp_srcptr, mp_size_t, int, int, mp_ptr *, mp_ptr *,
69 	      mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr,
70 	      int));
71 
72 
73 /* Find the best k to use for an FFT mod 2^(m*GMP_NUMB_BITS)+1 with m >= n.
74    sqr==0 selects the multiplication tuning, sqr==1 the squaring tuning.
75    Not declared static because tuneup needs it.
76 */
77 #ifdef MUL_FFT_TABLE2
78 #define MPN_FFT_TABLE2_SIZE 512
79 
80 /* FIXME: The format of this should change to need less space.
81    Perhaps put n and k in the same 32-bit word, with n shifted-down
82    (k-2) steps, and k using the 4-5 lowest bits.  That's possible since
83    n-1 is highly divisible.
84    Alternatively, separate n and k out into separate arrays.  */
/* One table entry: mpn_fft_best_k below picks k for every size >= n, up to
   the n of the following entry.  */
85 struct nk {
86   mp_size_t n;
87   unsigned char k;
88 };
89 
/* Row 0 (sqr == 0) is for multiplication, row 1 (sqr == 1) for squaring.  */
90 static struct nk mpn_fft_table2[2][MPN_FFT_TABLE2_SIZE] =
91 {
92   MUL_FFT_TABLE2,
93   SQR_FFT_TABLE2
94 };
95 
96 int
97 mpn_fft_best_k (mp_size_t n, int sqr)
98 {
99   struct nk *tab;
100   int last_k;
101 
102   last_k = 4;  /* used when n < mpn_fft_table2[sqr][1].n */
  /* Entry 0 is skipped; return the k of the last entry whose n is <= n.
     The loop has no explicit bound and stops only at an entry whose n
     exceeds the requested size.  NOTE(review): this relies on the
     MUL_FFT_TABLE2 / SQR_FFT_TABLE2 data ending with such a sentinel
     entry -- confirm where those tables are defined.  */
103   for (tab = mpn_fft_table2[sqr] + 1; ; tab++)
104     {
105       if (n < tab->n)
106 	break;
107       last_k = tab->k;
108     }
109   return last_k;
110 }
111 #endif
112 
/* Size thresholds for the table-driven mpn_fft_best_k that follows, indexed
   [sqr][i] (row 0: multiplication, row 1: squaring).  A 0 entry ends a row.
   The table is also built for the tuning program (TUNE_PROGRAM_BUILD).  */
113 #if !defined (MUL_FFT_TABLE2) || TUNE_PROGRAM_BUILD
114 FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
115 {
116   MUL_FFT_TABLE,
117   SQR_FFT_TABLE
118 };
119 #endif
120 
121 #if !defined (MUL_FFT_TABLE2)
122 int
123 mpn_fft_best_k (mp_size_t n, int sqr)
124 {
125   int i;
126 
127   for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
128     if (n < mpn_fft_table[sqr][i])
129       return i + FFT_FIRST_K;
130 
131   /* treat 4*last as one further entry */
132   if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
133     return i + FFT_FIRST_K;
134   else
135     return i + FFT_FIRST_K + 1;
136 }
137 #endif
138 
139 /* Returns smallest possible number of limbs >= pl for a fft of size 2^k,
140    i.e. smallest multiple of 2^k >= pl.
141 
142    Don't declare static: needed by tuneup.
143 */
144 
145 mp_size_t
146 mpn_fft_next_size (mp_size_t pl, int k)
147 {
148   pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
149   return pl << k;
150 }
151 
152 
/* Build the bit-reversal tables: for each level 0 <= lev <= k and each
   0 <= j < 2^lev, l[lev][j] is the lev-bit reversal of j.  Level lev is
   derived from level lev-1.  Doubling an entry makes room for a new low bit,
   and the upper half of the table is the lower half with that bit set.
   l[lev] must have room for 2^lev ints.  */
static void
mpn_fft_initl (int **l, int k)
{
  int lev;

  l[0][0] = 0;
  for (lev = 1; lev <= k; lev++)
    {
      int half = 1 << (lev - 1);
      const int *prev = l[lev - 1];
      int *cur = l[lev];
      int j;

      for (j = 0; j < half; j++)
        {
          cur[j] = prev[j] << 1;
          cur[half + j] = cur[j] + 1;
        }
    }
}
171 
172 /* Shift {up, n} of cnt bits to the left, store the complemented result
173    in {rp, n}, and output the shifted bits (not complemented).
174    Same as:
175      cc = mpn_lshift (rp, up, n, cnt);
176      mpn_com_n (rp, rp, n);
177      return cc;
178 
179    Assumes n >= 1, 1 < cnt < GMP_NUMB_BITS, rp >= up.
180 */
181 #ifndef HAVE_NATIVE_mpn_lshiftc
182 #undef mpn_lshiftc
183 static mp_limb_t
184 mpn_lshiftc (mp_ptr rp, mp_srcptr up, mp_size_t n, unsigned int cnt)
185 {
186   mp_limb_t high_limb, low_limb;
187   unsigned int tnc;
188   mp_size_t i;
189   mp_limb_t retval;
190 
191   up += n;
192   rp += n;
193 
194   tnc = GMP_NUMB_BITS - cnt;
195   low_limb = *--up;
196   retval = low_limb >> tnc;
197   high_limb = (low_limb << cnt);
198 
199   for (i = n - 1; i != 0; i--)
200     {
201       low_limb = *--up;
202       *--rp = (~(high_limb | (low_limb >> tnc))) & GMP_NUMB_MASK;
203       high_limb = low_limb << cnt;
204     }
205   *--rp = (~high_limb) & GMP_NUMB_MASK;
206 
207   return retval;
208 }
209 #endif
210 
211 /* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}, where d is a
   shift count in bits.
212    Assumes a is semi-normalized, i.e. a[n] <= 1.
213    r and a must have n+1 limbs, and not overlap.

   d is split into d/GMP_NUMB_BITS whole limbs and sh = d%GMP_NUMB_BITS
   bits.  Because 2^(n*GMP_NUMB_BITS) = -1 in this ring, a shift by n or
   more limbs is a negation followed by a shift by the remaining limbs (the
   "negate" branch).  The indexing a + n - d below therefore needs
   d < 2*n*GMP_NUMB_BITS on entry.  The result is passed to
   mpn_fft_add_modF / mpn_fft_sub_modF, which expect semi-normalized input.
214 */
215 static void
216 mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
217 {
218   int sh, negate;
219   mp_limb_t cc, rd;
220 
221   sh = d % GMP_NUMB_BITS;  /* leftover bit shift */
222   d /= GMP_NUMB_BITS;      /* whole-limb shift */
223   negate = d >= n;         /* shift crosses 2^(n*GMP_NUMB_BITS) = -1 */
224   if (negate)
225     d -= n;
226 
227   if (negate)
228     {
229       /* r[0..d-1]  <-- lshift(a[n-d]..a[n-1], sh)
230 	 r[d..n-1]  <-- -lshift(a[0]..a[n-d-1],  sh) */
231       if (sh != 0)
232 	{
233 	  /* no out shift below since a[n] <= 1 */
234 	  mpn_lshift (r, a + n - d, d + 1, sh);
235 	  rd = r[d];
236 	  cc = mpn_lshiftc (r + d, a, n - d, sh);
237 	}
238       else
239 	{
240 	  MPN_COPY (r, a + n - d, d);
241 	  rd = a[n];
242 	  mpn_com_n (r + d, a, n - d);
243 	  cc = 0;
244 	}
245 
246       /* add cc to r[0], and add rd to r[d] */
247 
248       /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */
249 
250       r[n] = 0;
251       /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
252       cc++;
253       mpn_incr_u (r, cc);
254 
255       rd ++;
256       /* rd might overflow when sh=GMP_NUMB_BITS-1 */
257       cc = (rd == 0) ? 1 : rd;
258       r = r + d + (rd == 0);
259       mpn_incr_u (r, cc);
260 
261       return;
262     }
263 
264   /* if negate=0,
265 	r[0..d-1]  <-- -lshift(a[n-d]..a[n-1], sh)
266 	r[d..n-1]  <-- lshift(a[0]..a[n-d-1],  sh)
267   */
268   if (sh != 0)
269     {
270       /* no out bits below since a[n] <= 1 */
271       mpn_lshiftc (r, a + n - d, d + 1, sh);
272       rd = ~r[d];
273       /* {r, d+1} = ~({a+n-d, d+1} << sh); rd is the uncomplemented r[d] */
274       cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
275     }
276   else
277     {
278       /* r[d] is not used below, but we save a test for d=0 */
279       mpn_com_n (r, a + n - d, d + 1);
280       rd = a[n];
281       MPN_COPY (r + d, a, n - d);
282       cc = 0;
283     }
284 
285   /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */
286 
287   /* if d=0 nothing is complemented and rd = a[n] << sh */
288   if (d != 0)
289     {
290       /* now add 1 in r[0], subtract 1 in r[d] */
291       if (cc-- == 0) /* then add 1 to r[0] */
292 	cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
293       cc = mpn_sub_1 (r, r, d, cc) + 1;
294       /* add 1 to cc instead of rd since rd might overflow */
295     }
296 
297   /* now subtract cc and rd from r[d..n] */
298 
299   r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
300   r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
301   /* a negative top limb means a net borrow; -B^n = +1, so add 1 back */
301   if (r[n] & GMP_LIMB_HIGHBIT)
302     r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
303 }
304 
305 
306 /* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
307    Assumes a and b are semi-normalized.
308 */
309 static inline void
310 mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
311 {
312   mp_limb_t c, x;
313 
314   c = a[n] + b[n] + mpn_add_n (r, a, b, n);
315   /* 0 <= c <= 3 */
316 
317 #if 1
318   /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
319      result is slower code, of course.  But the following outsmarts GCC.  */
320   x = (c - 1) & -(c != 0);
321   r[n] = c - x;
322   MPN_DECR_U (r, n + 1, x);
323 #endif
324 #if 0
325   if (c > 1)
326     {
327       r[n] = 1;                       /* r[n] - c = 1 */
328       MPN_DECR_U (r, n + 1, c - 1);
329     }
330   else
331     {
332       r[n] = c;
333     }
334 #endif
335 }
336 
337 /* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
338    Assumes a and b are semi-normalized.
339 */
340 static inline void
341 mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
342 {
343   mp_limb_t c, x;
344 
345   c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
346   /* -2 <= c <= 1 */
347 
348 #if 1
349   /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
350      result is slower code, of course.  But the following outsmarts GCC.  */
351   x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
352   r[n] = x + c;
353   MPN_INCR_U (r, n + 1, x);
354 #endif
355 #if 0
356   if ((c & GMP_LIMB_HIGHBIT) != 0)
357     {
358       r[n] = 0;
359       MPN_INCR_U (r, n + 1, -c);
360     }
361   else
362     {
363       r[n] = c;
364     }
365 #endif
366 }
367 
368 /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
369 	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
370    output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   Forward transform, in place and recursive.  Ap[0], Ap[inc], ... point at
   the K residues, each of n+1 semi-normalized limbs (top limb <= 1, as the
   mpn_fft_*_modF helpers require).  K must be a power of two >= 2 (the
   recursion halves it down to 2).  ll points at the bit-reversal table for
   this level, and each recursion steps to ll-1 and doubles omega and inc.
   tp is scratch space of at least n+1 limbs.  */
371 
372 static void
373 mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
374 	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
375 {
376   if (K == 2)
377     {
      /* Base butterfly without twiddle: Ap[0] <- Ap[0] + Ap[inc] and
         Ap[inc] <- Ap[0] - Ap[inc]; then fold each top limb back to 0 or 1
         using 2^N = -1.  */
378       mp_limb_t cy;
379 #if HAVE_NATIVE_mpn_addsub_n
380       cy = mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
381 #else
382       MPN_COPY (tp, Ap[0], n + 1);
383       mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
384       cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
385 #endif
386       if (Ap[0][n] > 1) /* can be 2 or 3 */
387 	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
388       if (cy) /* Ap[inc][n] can be -1 or -2 */
389 	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
390     }
391   else
392     {
393       int j;
394       int *lk = *ll;
395 
396       mpn_fft_fft (Ap,     K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
397       mpn_fft_fft (Ap+inc, K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
398       /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
399 	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
400       for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
401 	{
402 	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
403 	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega)
	     Both updates share t = Ap[inc] * 2^(lk[0] * omega).  The code
	     computes Ap[inc] <- Ap[0] - t, i.e. it relies on
	     2^(lk[1] * omega) = -2^(lk[0] * omega) mod 2^N+1.  */
404 	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
405 	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
406 	  mpn_fft_add_modF (Ap[0],   Ap[0], tp, n);
407 	}
408     }
409 }
410 
411 /* Stray copy of the mpn_fft_fft contract; it does not describe the
   function that follows.
   input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
412 	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
413    output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
414    tp must have space for 2*(n+1) limbs.
415 */
416 
417 
418 /* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
419    by subtracting that modulus if necessary.
420 
421    If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
422    borrow and the limbs must be zeroed out again.  This will occur very
423    infrequently.  */
424 
425 static inline void
426 mpn_fft_normalize (mp_ptr ap, mp_size_t n)
427 {
428   if (ap[n] != 0)
429     {
430       MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
431       if (ap[n] == 0)
432 	{
433 	  /* This happens with very low probability; we have yet to trigger it,
434 	     and thereby make sure this code is correct.  */
435 	  MPN_ZERO (ap, n);
436 	  ap[n] = 1;
437 	}
438       else
439 	ap[n] = 0;
440     }
441 }
442 
443 /* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K
   Pointwise products of the transformed operands.  ap[i] and bp[i] point at
   residues of n+1 limbs, and the results overwrite the ap[i].  Squaring is
   detected by ap == bp.  At or above the MODF threshold each product is a
   recursive FFT (mpn_mul_fft_internal with rec=1).  Below it, a plain
   n x n product is reduced mod 2^(n*GMP_NUMB_BITS)+1.  */
444 static void
445 mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
446 {
447   int i;
448   int sqr = (ap == bp);
449   TMP_DECL;
450 
451   TMP_MARK;
452 
453   if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
454     {
455       int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
456       int **_fft_l;
457       mp_ptr *Ap, *Bp, A, B, T;
458 
459       k = mpn_fft_best_k (n, sqr);
460       K2 = 1 << k;
461       ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      /* max (K2, GMP_NUMB_BITS) equals their lcm only when GMP_NUMB_BITS is
         a power of two.  NOTE(review): mpn_mul_fft uses mpn_mul_fft_lcm
         for this; confirm nail builds never reach this path.  */
462       maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
463       M2 = n * GMP_NUMB_BITS >> k;
464       l = n >> k;
465       Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
466       /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
467       nprime2 = Nprime2 / GMP_NUMB_BITS;
468 
469       /* we should ensure that nprime2 is a multiple of the next K */
470       if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
471 	{
472 	  unsigned long K3;
473 	  for (;;)
474 	    {
475 	      K3 = 1L << mpn_fft_best_k (nprime2, sqr);
476 	      if ((nprime2 & (K3 - 1)) == 0)
477 		break;
478 	      nprime2 = (nprime2 + K3 - 1) & -K3;
              /* NOTE(review): nprime2 was derived with GMP_NUMB_BITS above;
                 GMP_LIMB_BITS differs from it only in nail builds.  */
479 	      Nprime2 = nprime2 * GMP_LIMB_BITS;
480 	      /* warning: since nprime2 changed, K3 may change too! */
481 	    }
482 	}
483       ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */
484 
485       Mp2 = Nprime2 >> k;
486 
487       Ap = TMP_ALLOC_MP_PTRS (K2);
488       Bp = TMP_ALLOC_MP_PTRS (K2);
489       A = TMP_ALLOC_LIMBS (2 * K2 * (nprime2 + 1));
490       T = TMP_ALLOC_LIMBS (2 * (nprime2 + 1));
491       B = A + K2 * (nprime2 + 1);
492       _fft_l = TMP_ALLOC_TYPE (k + 1, int *);
493       for (i = 0; i <= k; i++)
494 	_fft_l[i] = TMP_ALLOC_TYPE (1<<i, int);
495       mpn_fft_initl (_fft_l, k);
496 
497       TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
498 		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
      /* The scratch buffers above are reused for all K products.  Passing
         *ap as both op and n lets internal detect squaring via *ap == *bp.  */
499       for (i = 0; i < K; i++, ap++, bp++)
500 	{
501 	  mpn_fft_normalize (*ap, n);
502 	  if (!sqr)
503 	    mpn_fft_normalize (*bp, n);
504 	  mpn_mul_fft_internal (*ap, *ap, *bp, n, k, K2, Ap, Bp, A, B, nprime2,
505 				l, Mp2, _fft_l, T, 1);
506 	}
507     }
508   else
509     {
510       mp_ptr a, b, tp, tpn;
511       mp_limb_t cc;
512       int n2 = 2 * n;
513       tp = TMP_ALLOC_LIMBS (n2);
514       tpn = tp + n;
515       TRACE (printf ("  mpn_mul_n %d of %ld limbs\n", K, n));
516       for (i = 0; i < K; i++)
517 	{
518 	  a = *ap++;
519 	  b = *bp++;
520 	  if (sqr)
521 	    mpn_sqr_n (tp, a, n);
522 	  else
523 	    mpn_mul_n (tp, b, a, n);
          /* With B = 2^(n*GMP_NUMB_BITS):
             (a_lo + a[n] B)(b_lo + b[n] B)
               = a_lo b_lo + (a[n] b_lo + b[n] a_lo) B + a[n] b[n] B^2.
             The middle terms are added into the high half {tpn, n}.  cc
             collects what lands at B^2, which is 1 mod B+1, so it is added
             back at tp[0].  */
524 	  if (a[n] != 0)
525 	    cc = mpn_add_n (tpn, tpn, b, n);
526 	  else
527 	    cc = 0;
528 	  if (b[n] != 0)
529 	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
530 	  if (cc != 0)
531 	    {
532 	      /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
533 	      cc = mpn_add_1 (tp, tp, n2, cc);
534 	      ASSERT (cc == 0);
535 	    }
          /* {tp, 2n} = lo + hi B == lo - hi.  A borrow means the stored value
             is lo - hi + B, which is corrected by adding 1.  If that carries,
             the result is B (i.e. -1), kept as a[n] = 1.  */
536 	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
537 	}
538     }
539   TMP_FREE;
540 }
541 
542 
543 /* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
544    output: K*A[0] K*A[K-1] ... K*A[1].
545    Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
546    This condition is also fulfilled at exit.
   Inverse transform, in place and recursive.  Unlike mpn_fft_fft, the K
   residues are addressed contiguously as Ap[0..K-1], and the recursion
   splits them into halves while doubling omega.  The butterfly computes
   Ap[K2] <- Ap[0] - t, relying on 2^(K2*omega) = -1 mod 2^N+1.  That holds
   for omega = 2*Mp as passed by mpn_mul_fft_internal, where K*Mp = N', so
   2^(K2*omega) = 2^N'.  tp is scratch space of at least n+1 limbs.
547 */
548 static void
549 mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
550 {
551   if (K == 2)
552     {
      /* Base butterfly: Ap[0] <- Ap[0] + Ap[1], Ap[1] <- Ap[0] - Ap[1],
         then fold each top limb back to 0 or 1.  */
553       mp_limb_t cy;
554 #if HAVE_NATIVE_mpn_addsub_n
555       cy = mpn_addsub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
556 #else
557       MPN_COPY (tp, Ap[0], n + 1);
558       mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
559       cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
560 #endif
561       if (Ap[0][n] > 1) /* can be 2 or 3 */
562 	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
563       if (cy) /* Ap[1][n] can be -1 or -2 */
564 	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
565     }
566   else
567     {
568       int j, K2 = K >> 1;
569 
570       mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
571       mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
572       /* A[j]     <- A[j] + omega^j A[j+K/2]
573 	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
574       for (j = 0; j < K2; j++, Ap++)
575 	{
576 	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
577 	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
578 	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
579 	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
580 	  mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
581 	}
582     }
583 }
584 
585 
586 /* A <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
587 static void
588 mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
589 {
590   int i;
591 
592   ASSERT (r != a);
593   i = 2 * n * GMP_NUMB_BITS;
594   i = (i - k) % i;		/* FIXME: This % looks superfluous */
595   mpn_fft_mul_2exp_modF (r, a, i, n);
596   /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
597   /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
598   mpn_fft_normalize (r, n);
599 }
600 
601 
602 /* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
603    Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
604    then {rp,n}=0.
   Method: with B = 2^(n*GMP_NUMB_BITS) = -1, {ap,an} = lo + mid*B + hi*B^2
   is congruent to lo - mid + hi, where lo = {ap,n}, mid = {ap+n,l} and
   hi = {ap+2n,m} (present only when an > 2n).
   NOTE(review): whenever an >= 2n, l == n and the mpn_sub_1 below is called
   with size n - l = 0.  Confirm that mpn_sub_1 accepts size 0 here, or that
   callers never pass an >= 2n.
605 */
606 static int
607 mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
608 {
609   mp_size_t l;
610   long int m;
611   mp_limb_t cc;
612   int rpn;
613 
614   ASSERT ((n <= an) && (an <= 3 * n));
615   m = an - 2 * n;
616   if (m > 0)
617     {
618       l = n;
619       /* add {ap, m} and {ap+2n, m} in {rp, m} */
620       cc = mpn_add_n (rp, ap, ap + 2 * n, m);
621       /* copy {ap+m, n-m} to {rp+m, n-m}, propagating the carry cc */
622       rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
623     }
624   else
625     {
626       l = an - n; /* l <= n */
627       MPN_COPY (rp, ap, n);
628       rpn = 0;
629     }
630 
631   /* remains to subtract {ap+n, l} from {rp, n+1} */
632   cc = mpn_sub_n (rp, rp, ap + n, l);
633   rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
634   if (rpn < 0) /* necessarily rpn = -1 */
635     rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
636   return rpn;
637 }
638 
639 /* store in A[0..nprime] the first M bits from {n, nl},
640    in A[nprime+1..] the following M bits, ...
641    Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
642    T must have space for at least (nprime + 1) limbs.
643    We must have nl <= 2*K*l.
644 */
645 static void
646 mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
647 		       mp_size_t nl, int l, int Mp, mp_ptr T)
648 {
649   int i, j;
650   mp_ptr tmp;
651   mp_size_t Kl = K * l;
652   TMP_DECL;
653   TMP_MARK;
654 
655   if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
656     {
657       mp_size_t dif = nl - Kl;
658       mp_limb_signed_t cy;
659 
660       tmp = TMP_ALLOC_LIMBS(Kl + 1);
661 
662       if (dif > Kl)
663 	{
664 	  int subp = 0;
665 
666 	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
667 	  n += 2 * Kl;
668 	  dif -= Kl;
669 
670 	  /* now dif > 0 */
671 	  while (dif > Kl)
672 	    {
673 	      if (subp)
674 		cy += mpn_sub_n (tmp, tmp, n, Kl);
675 	      else
676 		cy -= mpn_add_n (tmp, tmp, n, Kl);
677 	      subp ^= 1;
678 	      n += Kl;
679 	      dif -= Kl;
680 	    }
681 	  /* now dif <= Kl */
682 	  if (subp)
683 	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
684 	  else
685 	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
686 	  if (cy >= 0)
687 	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
688 	  else
689 	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
690 	}
691       else /* dif <= Kl, i.e. nl <= 2 * Kl */
692 	{
693 	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
694 	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
695 	}
696       tmp[Kl] = cy;
697       nl = Kl + 1;
698       n = tmp;
699     }
700   for (i = 0; i < K; i++)
701     {
702       Ap[i] = A;
703       /* store the next M bits of n into A[0..nprime] */
704       if (nl > 0) /* nl is the number of remaining limbs */
705 	{
706 	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
707 	  nl -= j;
708 	  MPN_COPY (T, n, j);
709 	  MPN_ZERO (T + j, nprime + 1 - j);
710 	  n += l;
711 	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
712 	}
713       else
714 	MPN_ZERO (A, nprime + 1);
715       A += nprime + 1;
716     }
717   ASSERT_ALWAYS (nl == 0);
718   TMP_FREE;
719 }
720 
721 /* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
722    If rec=1, n and m are decomposed here, each read as K*l+1 limbs.  If
   rec=0, the caller has already filled Ap (and Bp unless squaring); n and
   m are then only compared, and n == m selects squaring.
723    op must have space for pl+1 limbs if rec=1 (and pl limbs if rec=0).
724    One must have pl = mpn_fft_next_size (pl, k).
725    T must have space for 2 * (nprime + 1) limbs.
726 
727    If rec=0, only the pl low limbs of the result are stored; the carry out
728    from mpn_fft_norm_modF is discarded (there is no return value).  If
   rec=1, the carry is stored in op[pl].
729 */
730 
731 static void
732 mpn_mul_fft_internal (mp_ptr op, mp_srcptr n, mp_srcptr m, mp_size_t pl,
733 		      int k, int K,
734 		      mp_ptr *Ap, mp_ptr *Bp,
735 		      mp_ptr A, mp_ptr B,
736 		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
737 		      int **_fft_l,
738 		      mp_ptr T, int rec)
739 {
740   int i, sqr, pla, lo, sh, j;
741   mp_ptr p;
742   mp_limb_t cc;
743 
744   sqr = n == m;
745 
746   TRACE (printf ("pl=%ld k=%d K=%d np=%ld l=%ld Mp=%ld rec=%d sqr=%d\n",
747 		 pl,k,K,nprime,l,Mp,rec,sqr));
748 
749   /* decomposition of inputs into arrays Ap[i] and Bp[i] */
750   if (rec)
751     {
752       mpn_mul_fft_decompose (A, Ap, K, nprime, n, K * l + 1, l, Mp, T);
753       if (!sqr)
754 	mpn_mul_fft_decompose (B, Bp, K, nprime, m, K * l + 1, l, Mp, T);
755     }
756 
757   /* direct fft's */
758   mpn_fft_fft (Ap, K, _fft_l + k, 2 * Mp, nprime, 1, T);
759   if (!sqr)
760     mpn_fft_fft (Bp, K, _fft_l + k, 2 * Mp, nprime, 1, T);
761 
762   /* term to term multiplications */
763   mpn_fft_mul_modF_K (Ap, (sqr) ? Ap : Bp, nprime, K);
764 
765   /* inverse fft's */
766   mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);
767 
768   /* division of terms after inverse fft */
  /* Divide out the factor K (= 2^k) and the weight 2^(c*Mp) of coefficient
     c = (K - i) mod K, which sits in slot i.  The results go to Bp[], using
     the second half of T for Bp[0] and earlier Ap slots for the rest.  */
769   Bp[0] = T + nprime + 1;
770   mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
771   for (i = 1; i < K; i++)
772     {
773       Bp[i] = Ap[i - 1];
774       mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
775     }
776 
777   /* addition of terms in result p */
778   MPN_ZERO (T, nprime + 1);
779   pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
780   p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
781   MPN_ZERO (p, pla);
782   cc = 0; /* will accumulate the (signed) carry at p[pla] */
783   for (i = K - 1, lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l)
784     {
785       mp_ptr n = p + sh;  /* shadows the parameter n */
786 
787       j = (K - i) & (K - 1);  /* slot holding coefficient i (see fftinv) */
788 
789       if (mpn_add_n (n, n, Bp[j], nprime + 1))
790 	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
791 			  pla - sh - nprime - 1, CNST_LIMB(1));
792       T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
793       if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
794 	{ /* subtract 2^N'+1 */
795 	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
796 	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
797 	}
798     }
  /* cc is a limb used as a signed counter; -CNST_LIMB(1) stands for -1.  */
799   if (cc == -CNST_LIMB(1))
800     {
801       if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
802 	{
803 	  /* p[pla-pl]...p[pla-1] are all zero */
804 	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
805 	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
806 	}
807     }
808   else if (cc == 1)
809     {
810       if (pla >= 2 * pl)
811 	{
812 	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
813 	    ;
814 	}
815       else
816 	{
817 	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
818 	  ASSERT (cc == 0);
819 	}
820     }
821   else
822     ASSERT (cc == 0);
823 
824   /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
825      < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
826      < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
827   i = mpn_fft_norm_modF (op, pl, p, pla);
828   if (rec) /* store the carry out */
829     op[pl] = i;
830 }
831 
/* Return lcm (a, 2^k).  Writing a = 2^e * u with u odd, the result is
   2^max(e,k) * u: strip at most k factors of two from a, then shift it
   left by k.  (lcm (0, 2^k) is 0.)  */
static unsigned long int
mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
{
  unsigned int budget;

  for (budget = k; budget != 0 && (a & 1) == 0; budget--)
    a >>= 1;
  return a << k;
}
845 
846 void
847 mpn_mul_fft (mp_ptr op, mp_size_t pl,
848 	     mp_srcptr n, mp_size_t nl,
849 	     mp_srcptr m, mp_size_t ml,
850 	     int k)
851 {
852   int K, maxLK, i;
853   mp_size_t N, Nprime, nprime, M, Mp, l;
854   mp_ptr *Ap, *Bp, A, T, B;
855   int **_fft_l;
856   int sqr = (n == m && nl == ml);
857   TMP_DECL;
858 
859   TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
860   ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);
861 
862   TMP_MARK;
863   N = pl * GMP_NUMB_BITS;
864   _fft_l = TMP_ALLOC_TYPE (k + 1, int *);
865   for (i = 0; i <= k; i++)
866     _fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
867   mpn_fft_initl (_fft_l, k);
868   K = 1 << k;
869   M = N >> k;	/* N = 2^k M */
870   l = 1 + (M - 1) / GMP_NUMB_BITS;
871   maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */
872 
873   Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
874   /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
875   nprime = Nprime / GMP_NUMB_BITS;
876   TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
877 		 N, K, M, l, maxLK, Nprime, nprime));
878   /* we should ensure that recursively, nprime is a multiple of the next K */
879   if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
880     {
881       unsigned long K2;
882       for (;;)
883 	{
884 	  K2 = 1L << mpn_fft_best_k (nprime, sqr);
885 	  if ((nprime & (K2 - 1)) == 0)
886 	    break;
887 	  nprime = (nprime + K2 - 1) & -K2;
888 	  Nprime = nprime * GMP_LIMB_BITS;
889 	  /* warning: since nprime changed, K2 may change too! */
890 	}
891       TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
892     }
893   ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */
894 
895   T = TMP_ALLOC_LIMBS (2 * (nprime + 1));
896   Mp = Nprime >> k;
897 
898   TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
899 		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
900 	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));
901 
902   A = __GMP_ALLOCATE_FUNC_LIMBS (2 * K * (nprime + 1));
903   B = A + K * (nprime + 1);
904   Ap = TMP_ALLOC_MP_PTRS (K);
905   Bp = TMP_ALLOC_MP_PTRS (K);
906 
907   /* special decomposition for main call */
908   /* nl is the number of significant limbs in n */
909   mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
910   if (n != m)
911     mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
912 
913   mpn_mul_fft_internal (op, n, m, pl, k, K, Ap, Bp, A, B, nprime, l, Mp, _fft_l, T, 0);
914 
915   TMP_FREE;
916   __GMP_FREE_FUNC_LIMBS (A, 2 * K * (nprime + 1));
917 }
918 
919 /* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml}
   The full product is recombined from two wrapped products.  mu is the
   product mod 2^(pl3*GMP_NUMB_BITS)+1, stored directly in op.  lambda is
   the product mod 2^(pl2*GMP_NUMB_BITS)+1, stored in the temporary pad_op.
   Here pl3 = 3/2 * pl2 and pl2 + pl3 >= pl = nl + ml.
   op must have room for nl+ml limbs.
   NOTE(review): mpn_mul_fft discards its carry out, so a wrapped product
   equal to -1 would come back as zero.  Confirm that this cannot happen
   for the sizes chosen here.  */
920 void
921 mpn_mul_fft_full (mp_ptr op,
922 		  mp_srcptr n, mp_size_t nl,
923 		  mp_srcptr m, mp_size_t ml)
924 {
925   mp_ptr pad_op;
926   mp_size_t pl, pl2, pl3, l;
927   int k2, k3;
928   int sqr = (n == m && nl == ml);
929   int cc, c2, oldcc;
930 
931   pl = nl + ml; /* total number of limbs of the result */
932 
933   /* perform a fft mod 2^(2N)+1 and one mod 2^(3N)+1.
934      We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
935      pl3 a multiple of 2^k3. Since k3 >= k2, both are multiples of 2^k2,
936      and pl2 must be an even multiple of 2^k2. Thus (pl2,pl3) =
937      (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
938      We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
939      which requires j>=2. Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */
940 
941   /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */
942 
943   pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
944   do
945     {
946       pl2 ++;
947       k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
948       pl2 = mpn_fft_next_size (pl2, k2);
949       pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
950 			    thus pl2 / 2 is exact */
951       k3 = mpn_fft_best_k (pl3, sqr);
952     }
953   while (mpn_fft_next_size (pl3, k3) != pl3);
954 
955   TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
956 		 nl, ml, pl2, pl3, k2));
957 
958   ASSERT_ALWAYS(pl3 <= pl);
959   mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
960   pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
961   mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
962   cc = mpn_sub_n (pad_op, pad_op, op, pl2);    /* lambda - low(mu) */
963   /* 0 <= cc <= 1 */
964   l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
965   c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
966   cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc; /* -1 <= cc <= 1 */
967   if (cc < 0)
968     cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
969   /* 0 <= cc <= 1 */
970   /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
971   oldcc = cc;
972 #if HAVE_NATIVE_mpn_addsub_n
973   c2 = mpn_addsub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
974   /* c2 & 1 is the borrow, c2 & 2 is the carry */
975   cc += c2 >> 1; /* carry out from high <- low + high */
976   c2 = c2 & 1; /* borrow out from low <- low - high */
977 #else
978   {
979     mp_ptr tmp;
980     TMP_DECL;
981 
982     TMP_MARK;
983     tmp = TMP_ALLOC_LIMBS (l);
984     MPN_COPY (tmp, pad_op, l);
985     c2 = mpn_sub_n (pad_op,      pad_op, pad_op + l, l);
986     cc += mpn_add_n (pad_op + l, tmp,    pad_op + l, l);
987     TMP_FREE;
988   }
989 #endif
990   c2 += oldcc;
991   /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
992      at pad_op + l, cc is the carry at pad_op + pl2 */
993   /* 0 <= cc <= 2 */
994   cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
995   /* -1 <= cc <= 2 */
996   if (cc > 0)
997     cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
998   /* now -1 <= cc <= 0 */
999   if (cc < 0)
1000     cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
1001   /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
1002   if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
1003     cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
1004   /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
1005      out below */
1006   mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
1007   if (cc) /* then cc=1 */
1008     pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
1009   /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
1010      mod 2^(pl2*GMP_NUMB_BITS) + 1 */
1011   c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
1012   /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
1013   MPN_COPY (op + pl3, pad_op, pl - pl3);
1014   ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
1015   __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
1016   /* since the final result has at most pl limbs, no carry out below */
1017   mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
1018 }
1019