xref: /dragonfly/contrib/gmp/mpn/generic/mul_fft.c (revision 92fc8b5c)
1 /* Schoenhage's fast multiplication modulo 2^N+1.
2 
3    Contributed by Paul Zimmermann.
4 
5    THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
6    SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
7    GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.
8 
9 Copyright 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008
10 Free Software Foundation, Inc.
11 
12 This file is part of the GNU MP Library.
13 
14 The GNU MP Library is free software; you can redistribute it and/or modify
15 it under the terms of the GNU Lesser General Public License as published by
16 the Free Software Foundation; either version 3 of the License, or (at your
17 option) any later version.
18 
19 The GNU MP Library is distributed in the hope that it will be useful, but
20 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
21 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
22 License for more details.
23 
24 You should have received a copy of the GNU Lesser General Public License
25 along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */
26 
27 
28 /* References:
29 
30    Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
31    Strassen, Computing 7, p. 281-292, 1971.
32 
33    Asymptotically fast algorithms for the numerical multiplication and division
34    of polynomials with complex coefficients, by Arnold Schoenhage, Computer
35    Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.
36 
37    Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
38    Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.
39 
40    TODO:
41 
42    Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
43    Zimmermann.
44 
45    It might be possible to avoid a small number of MPN_COPYs by using a
46    rotating temporary or two.
47 
48    Cleanup and simplify the code!
49 */
50 
51 #ifdef TRACE
52 #undef TRACE
53 #define TRACE(x) x
54 #include <stdio.h>
55 #else
56 #define TRACE(x)
57 #endif
58 
59 #include "gmp.h"
60 #include "gmp-impl.h"
61 
62 #ifdef WANT_ADDSUB
63 #include "generic/addsub_n.c"
64 #define HAVE_NATIVE_mpn_addsub_n 1
65 #endif
66 
67 static mp_limb_t mpn_mul_fft_internal
68 __GMP_PROTO ((mp_ptr, mp_srcptr, mp_srcptr, mp_size_t, int, int, mp_ptr *, mp_ptr *,
69 	      mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr,
70 	      int));
71 
72 
73 /* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
74    sqr==0 for a multiply, sqr==1 for a square.
75    Don't declare it static since it is needed by tuneup.
76 */
77 #ifdef MUL_FFT_TABLE2
78 
79 #if defined (MUL_FFT_TABLE2_SIZE) && defined (SQR_FFT_TABLE2_SIZE)
80 #if MUL_FFT_TABLE2_SIZE > SQR_FFT_TABLE2_SIZE
81 #define FFT_TABLE2_SIZE MUL_FFT_TABLE2_SIZE
82 #else
83 #define FFT_TABLE2_SIZE SQR_FFT_TABLE2_SIZE
84 #endif
85 #endif
86 
87 #ifndef FFT_TABLE2_SIZE
88 #define FFT_TABLE2_SIZE 200
89 #endif
90 
91 /* FIXME: The format of this should change to need less space.
92    Perhaps put n and k in the same 32-bit word, with n shifted-down
93    (k-2) steps, and k using the 4-5 lowest bits.  That's possible since
94    n-1 is highly divisible.
95    Alternatively, separate n and k out into separate arrays.  */
96 struct nk {
97   unsigned int n:27;
98   unsigned int k:5;
99 };
100 
101 static struct nk mpn_fft_table2[2][FFT_TABLE2_SIZE] =
102 {
103   MUL_FFT_TABLE2,
104   SQR_FFT_TABLE2
105 };
106 
107 int
108 mpn_fft_best_k (mp_size_t n, int sqr)
109 {
110   struct nk *tab;
111   int last_k;
112 
113   last_k = 4;
114   for (tab = mpn_fft_table2[sqr] + 1; ; tab++)
115     {
116       if (n < tab->n)
117 	break;
118       last_k = tab->k;
119     }
120   return last_k;
121 }
122 #endif
123 
124 #if !defined (MUL_FFT_TABLE2) || TUNE_PROGRAM_BUILD
125 FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
126 {
127   MUL_FFT_TABLE,
128   SQR_FFT_TABLE
129 };
130 #endif
131 
132 #if !defined (MUL_FFT_TABLE2)
133 int
134 mpn_fft_best_k (mp_size_t n, int sqr)
135 {
136   int i;
137 
138   for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
139     if (n < mpn_fft_table[sqr][i])
140       return i + FFT_FIRST_K;
141 
142   /* treat 4*last as one further entry */
143   if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
144     return i + FFT_FIRST_K;
145   else
146     return i + FFT_FIRST_K + 1;
147 }
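
/* For illustration (with made-up table values): if FFT_FIRST_K is 4 and
   mpn_fft_table[0] were {400, 800, 1600, 3200, 0, ...}, then n = 1000 falls
   below the third entry (1600), so mpn_fft_best_k (1000, 0) would return
   2 + FFT_FIRST_K = 6, i.e. a transform of length 2^6 = 64.  */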
148 #endif
149 
150 /* Returns the smallest possible number of limbs >= pl for an FFT of size 2^k,
151    i.e. the smallest multiple of 2^k that is >= pl.
152 
153    Don't declare static: needed by tuneup.
154 */
155 
156 mp_size_t
157 mpn_fft_next_size (mp_size_t pl, int k)
158 {
159   pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
160   return pl << k;
161 }
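
/* For example, mpn_fft_next_size (13, 2) returns 16, the smallest multiple
   of 2^2 that is >= 13, while mpn_fft_next_size (16, 2) returns 16 again,
   so the rounding is idempotent.  */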
162 
163 
164 /* Initialize l[i][j] with bitrev(j) */
165 static void
166 mpn_fft_initl (int **l, int k)
167 {
168   int i, j, K;
169   int *li;
170 
171   l[0][0] = 0;
172   for (i = 1, K = 1; i <= k; i++, K *= 2)
173     {
174       li = l[i];
175       for (j = 0; j < K; j++)
176 	{
177 	  li[j] = 2 * l[i - 1][j];
178 	  li[K + j] = 1 + li[j];
179 	}
180     }
181 }
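
/* For example, with k = 3 this builds l[1] = {0, 1}, l[2] = {0, 2, 1, 3} and
   l[3] = {0, 4, 2, 6, 1, 5, 3, 7}, i.e. l[k][j] is the k-bit reversal of j.  */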
182 
183 /* Shift {up, n} left by cnt bits, store the complemented result
184    in {rp, n}, and return the shifted-out bits (not complemented).
185    Same as:
186      cc = mpn_lshift (rp, up, n, cnt);
187      mpn_com_n (rp, rp, n);
188      return cc;
189 
190    Assumes n >= 1, 1 < cnt < GMP_NUMB_BITS, rp >= up.
191 */
192 #ifndef HAVE_NATIVE_mpn_lshiftc
193 #undef mpn_lshiftc
194 static mp_limb_t
195 mpn_lshiftc (mp_ptr rp, mp_srcptr up, mp_size_t n, unsigned int cnt)
196 {
197   mp_limb_t high_limb, low_limb;
198   unsigned int tnc;
199   mp_size_t i;
200   mp_limb_t retval;
201 
202   up += n;
203   rp += n;
204 
205   tnc = GMP_NUMB_BITS - cnt;
206   low_limb = *--up;
207   retval = low_limb >> tnc;
208   high_limb = (low_limb << cnt);
209 
210   for (i = n - 1; i != 0; i--)
211     {
212       low_limb = *--up;
213       *--rp = (~(high_limb | (low_limb >> tnc))) & GMP_NUMB_MASK;
214       high_limb = low_limb << cnt;
215     }
216   *--rp = (~high_limb) & GMP_NUMB_MASK;
217 
218   return retval;
219 }
220 #endif
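
/* Small illustration (pretending GMP_NUMB_BITS were 8): for {up, 1} = {0x2f}
   and cnt = 4, the shifted value is 0xf0 with 0x02 shifted out, so rp[0] is
   set to ~0xf0 = 0x0f and 0x02 is returned.  */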
221 
222 /* r <- a*2^e mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
223    Assumes a is semi-normalized, i.e. a[n] <= 1.
224    r and a must have n+1 limbs, and not overlap.
225 */
226 static void
227 mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
228 {
229   int sh, negate;
230   mp_limb_t cc, rd;
231 
232   sh = d % GMP_NUMB_BITS;
233   d /= GMP_NUMB_BITS;
234   negate = d >= n;
235   if (negate)
236     d -= n;
237 
238   if (negate)
239     {
240       /* r[0..d-1]  <-- lshift(a[n-d]..a[n-1], sh)
241 	 r[d..n-1]  <-- -lshift(a[0]..a[n-d-1],  sh) */
242       if (sh != 0)
243 	{
244 	  /* no out shift below since a[n] <= 1 */
245 	  mpn_lshift (r, a + n - d, d + 1, sh);
246 	  rd = r[d];
247 	  cc = mpn_lshiftc (r + d, a, n - d, sh);
248 	}
249       else
250 	{
251 	  MPN_COPY (r, a + n - d, d);
252 	  rd = a[n];
253 	  mpn_com_n (r + d, a, n - d);
254 	  cc = 0;
255 	}
256 
257       /* add cc to r[0], and add rd to r[d] */
258 
259       /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */
260 
261       r[n] = 0;
262       /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
263       cc++;
264       mpn_incr_u (r, cc);
265 
266       rd ++;
267       /* rd might overflow when sh=GMP_NUMB_BITS-1 */
268       cc = (rd == 0) ? 1 : rd;
269       r = r + d + (rd == 0);
270       mpn_incr_u (r, cc);
271 
272       return;
273     }
274 
275   /* if negate=0,
276 	r[0..d-1]  <-- -lshift(a[n-d]..a[n-1], sh)
277 	r[d..n-1]  <-- lshift(a[0]..a[n-d-1],  sh)
278   */
279   if (sh != 0)
280     {
281       /* no out bits below since a[n] <= 1 */
282       mpn_lshiftc (r, a + n - d, d + 1, sh);
283       rd = ~r[d];
284       /* {r, d+1} = {a+n-d, d+1} << sh */
285       cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
286     }
287   else
288     {
289       /* r[d] is not used below, but we save a test for d=0 */
290       mpn_com_n (r, a + n - d, d + 1);
291       rd = a[n];
292       MPN_COPY (r + d, a, n - d);
293       cc = 0;
294     }
295 
296   /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */
297 
298   /* if d=0 we just have r[0]=a[n] << sh */
299   if (d != 0)
300     {
301       /* now add 1 in r[0], subtract 1 in r[d] */
302       if (cc-- == 0) /* then add 1 to r[0] */
303 	cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
304       cc = mpn_sub_1 (r, r, d, cc) + 1;
305       /* add 1 to cc instead of rd since rd might overflow */
306     }
307 
308   /* now subtract cc and rd from r[d..n] */
309 
310   r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
311   r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
312   if (r[n] & GMP_LIMB_HIGHBIT)
313     r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
314 }
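
/* The identity behind the code above: with N = n*GMP_NUMB_BITS we have
   2^N == -1 mod 2^N+1, so shifting a by d >= n whole limbs is the same as
   shifting by d - n limbs and negating; e.g. a * 2^N == -a == (2^N+1) - a
   mod 2^N+1.  This is why the negate case complements the shifted low part
   and then fixes it up with the small +1/-1 adjustments.  */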
315 
316 
317 /* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
318    Assumes a and b are semi-normalized.
319 */
320 static inline void
321 mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
322 {
323   mp_limb_t c, x;
324 
325   c = a[n] + b[n] + mpn_add_n (r, a, b, n);
326   /* 0 <= c <= 3 */
327 
328 #if 1
329   /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
330      result is slower code, of course.  But the following outsmarts GCC.  */
331   x = (c - 1) & -(c != 0);
332   r[n] = c - x;
333   MPN_DECR_U (r, n + 1, x);
334 #endif
335 #if 0
336   if (c > 1)
337     {
338       r[n] = 1;                       /* r[n] - c = 1 */
339       MPN_DECR_U (r, n + 1, c - 1);
340     }
341   else
342     {
343       r[n] = c;
344     }
345 #endif
346 }
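
/* Worked example of the branch-free reduction above: if c == 3 (both a[n]
   and b[n] were 1 and the low addition carried), then x == 2, r[n] is set
   to c - x == 1, and MPN_DECR_U subtracts 2 from {r, n+1}; this is correct
   since c*2^(n*GMP_NUMB_BITS) == -c mod 2^(n*GMP_NUMB_BITS)+1.  If c == 0,
   then x == 0 and nothing changes.  */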
347 
348 /* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
349    Assumes a and b are semi-normalized.
350 */
351 static inline void
352 mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
353 {
354   mp_limb_t c, x;
355 
356   c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
357   /* -2 <= c <= 1 */
358 
359 #if 1
360   /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
361      result is slower code, of course.  But the following outsmarts GCC.  */
362   x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
363   r[n] = x + c;
364   MPN_INCR_U (r, n + 1, x);
365 #endif
366 #if 0
367   if ((c & GMP_LIMB_HIGHBIT) != 0)
368     {
369       r[n] = 0;
370       MPN_INCR_U (r, n + 1, -c);
371     }
372   else
373     {
374       r[n] = c;
375     }
376 #endif
377 }
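
/* Similarly, in the subtraction above a double borrow (c == -2, i.e. the
   high bit of c is set) gives x == 2 and r[n] == 0, and MPN_INCR_U adds 2
   to {r, n+1}, using -2^(n*GMP_NUMB_BITS) == 1 mod 2^(n*GMP_NUMB_BITS)+1;
   for c == 0 or c == 1, x == 0 and r[n] is simply set to c.  */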
378 
379 /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
380 	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
381    output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */
382 
383 static void
384 mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
385 	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
386 {
387   if (K == 2)
388     {
389       mp_limb_t cy;
390 #if HAVE_NATIVE_mpn_addsub_n
391       cy = mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
392 #else
393       MPN_COPY (tp, Ap[0], n + 1);
394       mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
395       cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
396 #endif
397       if (Ap[0][n] > 1) /* can be 2 or 3 */
398 	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
399       if (cy) /* Ap[inc][n] can be -1 or -2 */
400 	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
401     }
402   else
403     {
404       int j;
405       int *lk = *ll;
406 
407       mpn_fft_fft (Ap,     K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
408       mpn_fft_fft (Ap+inc, K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
409       /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
410 	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
411       for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
412 	{
413 	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
414 	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
415 	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
416 	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
417 	  mpn_fft_add_modF (Ap[0],   Ap[0], tp, n);
418 	}
419     }
420 }
421 
428 
429 /* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
430    by subtracting that modulus if necessary.
431 
432    If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then the MPN_DECR_U below
433    produces a borrow out of ap[n], and the low limbs must be zeroed out
434    again (with ap[n] set back to 1).  This will occur very infrequently.  */
435 
436 static inline void
437 mpn_fft_normalize (mp_ptr ap, mp_size_t n)
438 {
439   if (ap[n] != 0)
440     {
441       MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
442       if (ap[n] == 0)
443 	{
444 	  /* This happens with very low probability; we have yet to trigger it,
445 	     and thereby make sure this code is correct.  */
446 	  MPN_ZERO (ap, n);
447 	  ap[n] = 1;
448 	}
449       else
450 	ap[n] = 0;
451     }
452 }
453 
454 /* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
455 static void
456 mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
457 {
458   int i;
459   int sqr = (ap == bp);
460   TMP_DECL;
461 
462   TMP_MARK;
463 
464   if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
465     {
466       int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
467       int **_fft_l;
468       mp_ptr *Ap, *Bp, A, B, T;
469 
470       k = mpn_fft_best_k (n, sqr);
471       K2 = 1 << k;
472       ASSERT_ALWAYS((n & (K2 - 1)) == 0);
473       maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
474       M2 = n * GMP_NUMB_BITS >> k;
475       l = n >> k;
476       Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
477       /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
478       nprime2 = Nprime2 / GMP_NUMB_BITS;
479 
480       /* we should ensure that nprime2 is a multiple of the next K */
481       if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
482 	{
483 	  unsigned long K3;
484 	  for (;;)
485 	    {
486 	      K3 = 1L << mpn_fft_best_k (nprime2, sqr);
487 	      if ((nprime2 & (K3 - 1)) == 0)
488 		break;
489 	      nprime2 = (nprime2 + K3 - 1) & -K3;
490 	      Nprime2 = nprime2 * GMP_LIMB_BITS;
491 	      /* warning: since nprime2 changed, K3 may change too! */
492 	    }
493 	}
494       ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */
495 
496       Mp2 = Nprime2 >> k;
497 
498       Ap = TMP_ALLOC_MP_PTRS (K2);
499       Bp = TMP_ALLOC_MP_PTRS (K2);
500       A = TMP_ALLOC_LIMBS (2 * K2 * (nprime2 + 1));
501       T = TMP_ALLOC_LIMBS (2 * (nprime2 + 1));
502       B = A + K2 * (nprime2 + 1);
503       _fft_l = TMP_ALLOC_TYPE (k + 1, int *);
504       for (i = 0; i <= k; i++)
505 	_fft_l[i] = TMP_ALLOC_TYPE (1<<i, int);
506       mpn_fft_initl (_fft_l, k);
507 
508       TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
509 		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
510       for (i = 0; i < K; i++, ap++, bp++)
511 	{
512 	  mpn_fft_normalize (*ap, n);
513 	  if (!sqr)
514 	    mpn_fft_normalize (*bp, n);
515 	  mpn_mul_fft_internal (*ap, *ap, *bp, n, k, K2, Ap, Bp, A, B, nprime2,
516 				l, Mp2, _fft_l, T, 1);
517 	}
518     }
519   else
520     {
521       mp_ptr a, b, tp, tpn;
522       mp_limb_t cc;
523       int n2 = 2 * n;
524       tp = TMP_ALLOC_LIMBS (n2);
525       tpn = tp + n;
526       TRACE (printf ("  mpn_mul_n %d of %ld limbs\n", K, n));
527       for (i = 0; i < K; i++)
528 	{
529 	  a = *ap++;
530 	  b = *bp++;
531 	  if (sqr)
532 	    mpn_sqr_n (tp, a, n);
533 	  else
534 	    mpn_mul_n (tp, b, a, n);
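          /* {tp, 2n} now holds the plain 2n-limb product of {a, n} and
             {b, n}.  Since 2^(n*GMP_NUMB_BITS) == -1 mod
             2^(n*GMP_NUMB_BITS)+1, the reduced product is
             {tp, n} - {tp+n, n}; the additions below fold the contributions
             of the top limbs a[n] and b[n] into that difference, and the
             final mpn_sub_n computes it.  */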
535 	  if (a[n] != 0)
536 	    cc = mpn_add_n (tpn, tpn, b, n);
537 	  else
538 	    cc = 0;
539 	  if (b[n] != 0)
540 	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
541 	  if (cc != 0)
542 	    {
543 	      /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
544 	      cc = mpn_add_1 (tp, tp, n2, cc);
545 	      ASSERT (cc == 0);
546 	    }
547 	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
548 	}
549     }
550   TMP_FREE;
551 }
552 
553 
554 /* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
555    output: K*A[0] K*A[K-1] ... K*A[1].
556    Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
557    This condition is also fulfilled at exit.
558 */
559 static void
560 mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
561 {
562   if (K == 2)
563     {
564       mp_limb_t cy;
565 #if HAVE_NATIVE_mpn_addsub_n
566       cy = mpn_addsub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
567 #else
568       MPN_COPY (tp, Ap[0], n + 1);
569       mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
570       cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
571 #endif
572       if (Ap[0][n] > 1) /* can be 2 or 3 */
573 	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
574       if (cy) /* Ap[1][n] can be -1 or -2 */
575 	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
576     }
577   else
578     {
579       int j, K2 = K >> 1;
580 
581       mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
582       mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
583       /* A[j]     <- A[j] + omega^j A[j+K/2]
584 	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
585       for (j = 0; j < K2; j++, Ap++)
586 	{
587 	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
588 	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
589 	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
590 	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
591 	  mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
592 	}
593     }
594 }
595 
596 
597 /* A <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
598 static void
599 mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
600 {
601   int i;
602 
603   ASSERT (r != a);
604   i = 2 * n * GMP_NUMB_BITS;
605   i = (i - k) % i;		/* FIXME: This % looks superfluous */
606   mpn_fft_mul_2exp_modF (r, a, i, n);
607   /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
608   /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
609   mpn_fft_normalize (r, n);
610 }
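
/* This works since 2^(2*n*GMP_NUMB_BITS) == (2^(n*GMP_NUMB_BITS))^2 == (-1)^2
   == 1 mod 2^(n*GMP_NUMB_BITS)+1, so multiplying by 2^(2*n*GMP_NUMB_BITS - k)
   is the same as dividing by 2^k.  */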
611 
612 
613 /* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
614    Returns the carry out, i.e. 1 iff {ap,an} == -1 mod 2^(n*GMP_NUMB_BITS)+1,
615    in which case {rp,n} is set to 0.
616 */
617 static int
618 mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
619 {
620   mp_size_t l;
621   long int m;
622   mp_limb_t cc;
623   int rpn;
624 
625   ASSERT ((n <= an) && (an <= 3 * n));
626   m = an - 2 * n;
627   if (m > 0)
628     {
629       l = n;
630       /* add {ap, m} and {ap+2n, m} in {rp, m} */
631       cc = mpn_add_n (rp, ap, ap + 2 * n, m);
632       /* copy {ap+m, n-m} to {rp+m, n-m}, propagating the carry cc */
633       rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
634     }
635   else
636     {
637       l = an - n; /* l <= n */
638       MPN_COPY (rp, ap, n);
639       rpn = 0;
640     }
641 
642   /* remains to subtract {ap+n, l} from {rp, n+1} */
643   cc = mpn_sub_n (rp, rp, ap + n, l);
644   rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
645   if (rpn < 0) /* necessarily rpn = -1 */
646     rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
647   return rpn;
648 }
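
/* The folding above relies on 2^(n*GMP_NUMB_BITS) == -1 and
   2^(2*n*GMP_NUMB_BITS) == 1 mod 2^(n*GMP_NUMB_BITS)+1: splitting
   {ap, an} = A0 + A1*2^(n*GMP_NUMB_BITS) + A2*2^(2*n*GMP_NUMB_BITS),
   with A0 the low n limbs, the residue is A0 - A1 + A2, which is what the
   additions and the subtraction above compute.  */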
649 
650 /* store in A[0..nprime] the first M bits from {n, nl},
651    in A[nprime+1..] the following M bits, ...
652    Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
653    T must have space for at least (nprime + 1) limbs.
654    We must have nl <= 2*K*l.
655 */
656 static void
657 mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
658 		       mp_size_t nl, int l, int Mp, mp_ptr T)
659 {
660   int i, j;
661   mp_ptr tmp;
662   mp_size_t Kl = K * l;
663   TMP_DECL;
664   TMP_MARK;
665 
666   if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
667     {
668       mp_size_t dif = nl - Kl;
669       mp_limb_signed_t cy;
670 
671       tmp = TMP_ALLOC_LIMBS(Kl + 1);
672 
673       if (dif > Kl)
674 	{
675 	  int subp = 0;
676 
677 	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
678 	  n += 2 * Kl;
679 	  dif -= Kl;
680 
681 	  /* now dif > 0 */
682 	  while (dif > Kl)
683 	    {
684 	      if (subp)
685 		cy += mpn_sub_n (tmp, tmp, n, Kl);
686 	      else
687 		cy -= mpn_add_n (tmp, tmp, n, Kl);
688 	      subp ^= 1;
689 	      n += Kl;
690 	      dif -= Kl;
691 	    }
692 	  /* now dif <= Kl */
693 	  if (subp)
694 	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
695 	  else
696 	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
697 	  if (cy >= 0)
698 	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
699 	  else
700 	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
701 	}
702       else /* dif <= Kl, i.e. nl <= 2 * Kl */
703 	{
704 	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
705 	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
706 	}
707       tmp[Kl] = cy;
708       nl = Kl + 1;
709       n = tmp;
710     }
711   for (i = 0; i < K; i++)
712     {
713       Ap[i] = A;
714       /* store the next M bits of n into A[0..nprime] */
715       if (nl > 0) /* nl is the number of remaining limbs */
716 	{
717 	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
718 	  nl -= j;
719 	  MPN_COPY (T, n, j);
720 	  MPN_ZERO (T + j, nprime + 1 - j);
721 	  n += l;
722 	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
723 	}
724       else
725 	MPN_ZERO (A, nprime + 1);
726       A += nprime + 1;
727     }
728   ASSERT_ALWAYS (nl == 0);
729   TMP_FREE;
730 }
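
/* Note on the 2^(i*Mp) factors applied above (an explanatory sketch): with
   theta = 2^Mp we have theta^K == 2^(K*Mp) == 2^(nprime*GMP_NUMB_BITS) == -1
   mod 2^(nprime*GMP_NUMB_BITS)+1, so weighting piece i by theta^i turns the
   cyclic convolution performed by the FFTs into the negacyclic one required
   for a product taken modulo 2^(K*l*GMP_NUMB_BITS)+1; the weights are
   compensated for again after the inverse transform in
   mpn_mul_fft_internal.  */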
731 
732 /* op <- n*m mod 2^N+1 with an FFT of size 2^k, where N=pl*GMP_NUMB_BITS.
733    n and m have nl and ml limbs, respectively.
734    op must have space for pl+1 limbs if rec=1 (and pl limbs if rec=0).
735    One must have pl = mpn_fft_next_size (pl, k).
736    T must have space for 2 * (nprime + 1) limbs.
737 
738    If rec=0, then store only the pl low limbs of the result, and return
739    the carry out.
740 */
741 
742 static mp_limb_t
743 mpn_mul_fft_internal (mp_ptr op, mp_srcptr n, mp_srcptr m, mp_size_t pl,
744 		      int k, int K,
745 		      mp_ptr *Ap, mp_ptr *Bp,
746 		      mp_ptr A, mp_ptr B,
747 		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
748 		      int **_fft_l,
749 		      mp_ptr T, int rec)
750 {
751   int i, sqr, pla, lo, sh, j;
752   mp_ptr p;
753   mp_limb_t cc;
754 
755   sqr = n == m;
756 
757   TRACE (printf ("pl=%ld k=%d K=%d np=%ld l=%ld Mp=%ld rec=%d sqr=%d\n",
758 		 pl,k,K,nprime,l,Mp,rec,sqr));
759 
760   /* decomposition of inputs into arrays Ap[i] and Bp[i] */
761   if (rec)
762     {
763       mpn_mul_fft_decompose (A, Ap, K, nprime, n, K * l + 1, l, Mp, T);
764       if (!sqr)
765 	mpn_mul_fft_decompose (B, Bp, K, nprime, m, K * l + 1, l, Mp, T);
766     }
767 
768   /* direct fft's */
769   mpn_fft_fft (Ap, K, _fft_l + k, 2 * Mp, nprime, 1, T);
770   if (!sqr)
771     mpn_fft_fft (Bp, K, _fft_l + k, 2 * Mp, nprime, 1, T);
772 
773   /* term to term multiplications */
774   mpn_fft_mul_modF_K (Ap, (sqr) ? Ap : Bp, nprime, K);
775 
776   /* inverse fft's */
777   mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);
778 
779   /* division of terms after inverse fft */
780   Bp[0] = T + nprime + 1;
781   mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
782   for (i = 1; i < K; i++)
783     {
784       Bp[i] = Ap[i - 1];
785       mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
786     }
787 
788   /* addition of terms in result p */
789   MPN_ZERO (T, nprime + 1);
790   pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
791   p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
792   MPN_ZERO (p, pla);
793   cc = 0; /* will accumulate the (signed) carry at p[pla] */
794   for (i = K - 1, lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l)
795     {
796       mp_ptr n = p + sh;
797 
798       j = (K - i) & (K - 1);
799 
800       if (mpn_add_n (n, n, Bp[j], nprime + 1))
801 	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
802 			  pla - sh - nprime - 1, CNST_LIMB(1));
803       T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
804       if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
805 	{ /* subtract 2^N'+1 */
806 	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
807 	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
808 	}
809     }
810   if (cc == -CNST_LIMB(1))
811     {
812       if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
813 	{
814 	  /* p[pla-pl]...p[pla-1] are all zero */
815 	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
816 	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
817 	}
818     }
819   else if (cc == 1)
820     {
821       if (pla >= 2 * pl)
822 	{
823 	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
824 	    ;
825 	}
826       else
827 	{
828 	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
829 	  ASSERT (cc == 0);
830 	}
831     }
832   else
833     ASSERT (cc == 0);
834 
835   /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
836      < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
837      < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
838   i = mpn_fft_norm_modF (op, pl, p, pla);
839   if (rec) /* store the carry out */
840     op[pl] = i;
841 
842   return i;
843 }
844 
845 /* return the lcm of a and 2^k */
846 static unsigned long int
847 mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
848 {
849   unsigned long int l = k;
850 
851   while (a % 2 == 0 && k > 0)
852     {
853       a >>= 1;
854       k --;
855     }
856   return a << l;
857 }
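
/* For example, with 64-bit limbs mpn_mul_fft_lcm (GMP_NUMB_BITS, 4) strips
   the four factors of two already covered by 2^4, leaving 64/16 = 4, and
   returns 4 << 4 = 64 = lcm (64, 2^4).  */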
858 
859 
860 mp_limb_t
861 mpn_mul_fft (mp_ptr op, mp_size_t pl,
862 	     mp_srcptr n, mp_size_t nl,
863 	     mp_srcptr m, mp_size_t ml,
864 	     int k)
865 {
866   int K, maxLK, i;
867   mp_size_t N, Nprime, nprime, M, Mp, l;
868   mp_ptr *Ap, *Bp, A, T, B;
869   int **_fft_l;
870   int sqr = (n == m && nl == ml);
871   mp_limb_t h;
872   TMP_DECL;
873 
874   TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
875   ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);
876 
877   TMP_MARK;
878   N = pl * GMP_NUMB_BITS;
879   _fft_l = TMP_ALLOC_TYPE (k + 1, int *);
880   for (i = 0; i <= k; i++)
881     _fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
882   mpn_fft_initl (_fft_l, k);
883   K = 1 << k;
884   M = N >> k;	/* N = 2^k M */
885   l = 1 + (M - 1) / GMP_NUMB_BITS;
886   maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */
887 
888   Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
889   /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
890   nprime = Nprime / GMP_NUMB_BITS;
891   TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
892 		 N, K, M, l, maxLK, Nprime, nprime));
893   /* we should ensure that recursively, nprime is a multiple of the next K */
894   if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
895     {
896       unsigned long K2;
897       for (;;)
898 	{
899 	  K2 = 1L << mpn_fft_best_k (nprime, sqr);
900 	  if ((nprime & (K2 - 1)) == 0)
901 	    break;
902 	  nprime = (nprime + K2 - 1) & -K2;
903 	  Nprime = nprime * GMP_LIMB_BITS;
904 	  /* warning: since nprime changed, K2 may change too! */
905 	}
906       TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
907     }
908   ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */
909 
910   T = TMP_ALLOC_LIMBS (2 * (nprime + 1));
911   Mp = Nprime >> k;
912 
913   TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
914 		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
915 	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));
916 
917   A = __GMP_ALLOCATE_FUNC_LIMBS (2 * K * (nprime + 1));
918   B = A + K * (nprime + 1);
919   Ap = TMP_ALLOC_MP_PTRS (K);
920   Bp = TMP_ALLOC_MP_PTRS (K);
921 
922   /* special decomposition for main call */
923   /* nl is the number of significant limbs in n */
924   mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
925   if (n != m)
926     mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
927 
928   h = mpn_mul_fft_internal (op, n, m, pl, k, K, Ap, Bp, A, B, nprime, l, Mp, _fft_l, T, 0);
929 
930   TMP_FREE;
931   __GMP_FREE_FUNC_LIMBS (A, 2 * K * (nprime + 1));
932 
933   return h;
934 }
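
/* Usage sketch (illustrative; essentially what mpn_mul_fft_full does below):
   given operands {np, nl} and {mp, ml} and a target size pl,

     k  = mpn_fft_best_k (pl, np == mp);
     pl = mpn_fft_next_size (pl, k);
     cy = mpn_mul_fft (op, pl, np, nl, mp, ml, k);

   leaves {op, pl} + cy*2^(pl*GMP_NUMB_BITS) congruent to {np, nl} * {mp, ml}
   modulo 2^(pl*GMP_NUMB_BITS)+1.  */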
935 
936 /* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
937 void
938 mpn_mul_fft_full (mp_ptr op,
939 		  mp_srcptr n, mp_size_t nl,
940 		  mp_srcptr m, mp_size_t ml)
941 {
942   mp_ptr pad_op;
943   mp_size_t pl, pl2, pl3, l;
944   int k2, k3;
945   int sqr = (n == m && nl == ml);
946   int cc, c2, oldcc;
947 
948   pl = nl + ml; /* total number of limbs of the result */
949 
950   /* perform one FFT mod 2^(pl2*GMP_NUMB_BITS)+1 and one mod 2^(pl3*GMP_NUMB_BITS)+1.
951      We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
952      pl3 a multiple of 2^k3. Since k3 >= k2, both are multiples of 2^k2,
953      and pl2 must be an even multiple of 2^k2. Thus (pl2,pl3) =
954      (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
955      We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
956      which requires j>=2. Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */
957 
958   /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */
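  /* Reconstruction sketch (informal): write B = GMP_NUMB_BITS,
     l = pl3 - pl2, mu = n*m mod 2^(pl3*B)+1 and lambda = n*m mod 2^(pl2*B)+1.
     Then n*m = mu + q*(2^(pl3*B)+1) for some 0 <= q < 2^((pl-pl3)*B)
     <= 2^(pl2*B), and since 2^(pl3*B) == -2^(l*B) mod 2^(pl2*B)+1 we get
     q == (lambda - mu) / (1 - 2^(l*B)) mod 2^(pl2*B)+1.  The code below
     computes q in {pad_op, pl2} and finally adds q + q*2^(pl3*B) to mu.  */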
959 
960   pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
961   do
962     {
963       pl2 ++;
964       k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
965       pl2 = mpn_fft_next_size (pl2, k2);
966       pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
967 			    thus pl2 / 2 is exact */
968       k3 = mpn_fft_best_k (pl3, sqr);
969     }
970   while (mpn_fft_next_size (pl3, k3) != pl3);
971 
972   TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
973 		 nl, ml, pl2, pl3, k2));
974 
975   ASSERT_ALWAYS(pl3 <= pl);
976   cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
977   ASSERT_ALWAYS(cc == 0);
978   pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
979   cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
980   cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);    /* lambda - low(mu) */
981   /* 0 <= cc <= 1 */
982   ASSERT_ALWAYS(0 <= cc && cc <= 1);
983   l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
984   c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
985   cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
986   ASSERT_ALWAYS(-1 <= cc && cc <= 1);
987   if (cc < 0)
988     cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
989   ASSERT_ALWAYS(0 <= cc && cc <= 1);
990   /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
991   oldcc = cc;
992 #if HAVE_NATIVE_mpn_addsub_n
993   c2 = mpn_addsub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
994   /* c2 & 1 is the borrow, c2 & 2 is the carry */
995   cc += c2 >> 1; /* carry out from high <- low + high */
996   c2 = c2 & 1; /* borrow out from low <- low - high */
997 #else
998   {
999     mp_ptr tmp;
1000     TMP_DECL;
1001 
1002     TMP_MARK;
1003     tmp = TMP_ALLOC_LIMBS (l);
1004     MPN_COPY (tmp, pad_op, l);
1005     c2 = mpn_sub_n (pad_op,      pad_op, pad_op + l, l);
1006     cc += mpn_add_n (pad_op + l, tmp,    pad_op + l, l);
1007     TMP_FREE;
1008   }
1009 #endif
1010   c2 += oldcc;
1011   /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
1012      at pad_op + l, cc is the carry at pad_op + pl2 */
1013   /* 0 <= cc <= 2 */
1014   cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
1015   /* -1 <= cc <= 2 */
1016   if (cc > 0)
1017     cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
1018   /* now -1 <= cc <= 0 */
1019   if (cc < 0)
1020     cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
1021   /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
1022   if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
1023     cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
1024   /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
1025      out below */
1026   mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
1027   if (cc) /* then cc=1 */
1028     pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
1029   /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
1030      mod 2^(pl2*GMP_NUMB_BITS) + 1 */
1031   c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
1032   /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
1033   MPN_COPY (op + pl3, pad_op, pl - pl3);
1034   ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
1035   __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
1036   /* since the final result has at most pl limbs, no carry out below */
1037   mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
1038 }
1039