xref: /dragonfly/contrib/gmp/mpn/generic/mul_fft.c (revision 86d7f5d3)
1 /* Schoenhage's fast multiplication modulo 2^N+1.
2 
3    Contributed by Paul Zimmermann.
4 
5    THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
6    SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
7    GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.
8 
9 Copyright 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008,
10 2009, 2010 Free Software Foundation, Inc.
11 
12 This file is part of the GNU MP Library.
13 
14 The GNU MP Library is free software; you can redistribute it and/or modify
15 it under the terms of the GNU Lesser General Public License as published by
16 the Free Software Foundation; either version 3 of the License, or (at your
17 option) any later version.
18 
19 The GNU MP Library is distributed in the hope that it will be useful, but
20 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
21 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
22 License for more details.
23 
24 You should have received a copy of the GNU Lesser General Public License
25 along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */
26 
27 
28 /* References:
29 
30    Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
31    Strassen, Computing 7, p. 281-292, 1971.
32 
33    Asymptotically fast algorithms for the numerical multiplication and division
34    of polynomials with complex coefficients, by Arnold Schoenhage, Computer
35    Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.
36 
37    Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
38    Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.
39 
40    TODO:
41 
42    Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
43    Zimmermann.
44 
45    It might be possible to avoid a small number of MPN_COPYs by using a
46    rotating temporary or two.
47 
48    Cleanup and simplify the code!
49 */
50 
/* Debug tracing.  Building with TRACE defined turns every TRACE (...)
   statement in this file into live code (and pulls in <stdio.h> for the
   printf calls inside them); otherwise TRACE (...) expands to nothing.  */
#ifndef TRACE
#define TRACE(x)
#else
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#endif
58 
59 #include "gmp.h"
60 #include "gmp-impl.h"
61 
62 #ifdef WANT_ADDSUB
63 #include "generic/add_n_sub_n.c"
64 #define HAVE_NATIVE_mpn_add_n_sub_n 1
65 #endif
66 
67 static mp_limb_t mpn_mul_fft_internal
68 __GMP_PROTO ((mp_ptr, mp_size_t, int, mp_ptr *, mp_ptr *,
69 	      mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr, int));
70 static void mpn_mul_fft_decompose
71 __GMP_PROTO ((mp_ptr, mp_ptr *, int, int, mp_srcptr, mp_size_t, int, int, mp_ptr));
72 
73 
/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   sqr is 0 for a multiplication and 1 for a squaring.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h files have not been updated.  */
78 
79 
80 /*****************************************************************************/
81 
#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE		/* When tuning, this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

/* Row 0 is consulted for multiplication (sqr = 0), row 1 for squaring
   (sqr = 1).  */
FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

/* Return the best FFT depth k for an operand of n limbs.

   The first entry of the table row gives the initial k.  Each following
   entry {n, k} switches to its k once the operand size exceeds
   entry.n << (previous k); the scan stops at the first threshold that is
   not exceeded.  The loop has no bound check, so the generated table is
   presumably terminated by an entry large enough to stop it -- confirm in
   gmp-mparam.h.  */
int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  /* Plain const pointers on purpose.  FFT_TABLE_ATTRS qualifies the table
     definition above and may include a storage class (e.g. static).  Used
     on these locals, it would make every call -- in every thread -- share
     and overwrite the same two scan pointers.  */
  const struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
	break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif
126 
127 /*****************************************************************************/
128 
129 #if ! defined (MPN_FFT_BEST_READY)
130 FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
131 {
132   MUL_FFT_TABLE,
133   SQR_FFT_TABLE
134 };
135 
136 int
mpn_fft_best_k(mp_size_t n,int sqr)137 mpn_fft_best_k (mp_size_t n, int sqr)
138 {
139   int i;
140 
141   for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
142     if (n < mpn_fft_table[sqr][i])
143       return i + FFT_FIRST_K;
144 
145   /* treat 4*last as one further entry */
146   if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
147     return i + FFT_FIRST_K;
148   else
149     return i + FFT_FIRST_K + 1;
150 }
151 #endif
152 
153 /*****************************************************************************/
154 
155 
156 /* Returns smallest possible number of limbs >= pl for a fft of size 2^k,
157    i.e. smallest multiple of 2^k >= pl.
158 
159    Don't declare static: needed by tuneup.
160 */
161 
162 mp_size_t
mpn_fft_next_size(mp_size_t pl,int k)163 mpn_fft_next_size (mp_size_t pl, int k)
164 {
165   pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
166   return pl << k;
167 }
168 
169 
/* Fill the bit-reversal tables: for 0 <= i <= k and 0 <= j < 2^i,
   l[i][j] is j with its i low bits reversed.  Row i is derived from row
   i-1.  Reversing one more bit doubles each previous entry, and the upper
   half of the row is the lower half plus one.  l[i] must have room for
   2^i ints.  */
static void
mpn_fft_initl (int **l, int k)
{
  int row, half, idx;

  l[0][0] = 0;
  for (row = 1, half = 1; row <= k; row++, half <<= 1)
    {
      const int *prev = l[row - 1];
      int *cur = l[row];

      for (idx = 0; idx < half; idx++)
	{
	  cur[idx] = prev[idx] << 1;
	  cur[idx + half] = cur[idx] + 1;
	}
    }
}
188 
189 
/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.

   With N = n*GMP_NUMB_BITS, 2^N == -1 mod 2^N+1.  A shift by d limbs is
   therefore a rotation of the n low limbs in which the limbs that wrap
   around are negated.  A shift by d >= n limbs is handled as a shift by
   d-n limbs followed by a negation.
   NOTE(review): the d >= n branch reads a + n - (d - n), so callers appear
   to need d < 2*N bits -- confirm at the call sites.  */
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
{
  int sh;
  mp_limb_t cc, rd;

  /* Split the shift into a bit count sh < GMP_NUMB_BITS and a limb count d.  */
  sh = d % GMP_NUMB_BITS;
  d /= GMP_NUMB_BITS;

  if (d >= n)			/* negate */
    {
      /* r[0..d-1]  <-- lshift(a[n-d]..a[n-1], sh)
	 r[d..n-1]  <-- -lshift(a[0]..a[n-d-1],  sh) */

      d -= n;
      if (sh != 0)
	{
	  /* no out shift below since a[n] <= 1 */
	  mpn_lshift (r, a + n - d, d + 1, sh);
	  rd = r[d];
	  /* mpn_lshiftc stores the one's complement, i.e. -x-1; the missing
	     +1 is supplied by the increments below.  */
	  cc = mpn_lshiftc (r + d, a, n - d, sh);
	}
      else
	{
	  MPN_COPY (r, a + n - d, d);
	  rd = a[n];
	  mpn_com (r + d, a, n - d);
	  cc = 0;
	}

      /* add cc to r[0], and add rd to r[d] */

      /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      /* If rd wrapped to 0, the +1 belongs one limb higher instead.  */
      cc = (rd == 0) ? 1 : rd;
      r = r + d + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..d-1]  <-- -lshift(a[n-d]..a[n-1], sh)
	 r[d..n-1]  <-- lshift(a[0]..a[n-d-1],  sh)  */
      if (sh != 0)
	{
	  /* no out bits below since a[n] <= 1 */
	  mpn_lshiftc (r, a + n - d, d + 1, sh);
	  rd = ~r[d];
	  /* {r, d+1} = {a+n-d, d+1} << sh */
	  cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
	}
      else
	{
	  /* r[d] is not used below, but we save a test for d=0 */
	  mpn_com (r, a + n - d, d + 1);
	  rd = a[n];
	  MPN_COPY (r + d, a, n - d);
	  cc = 0;
	}

      /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */

      /* if d=0 we just have r[0]=a[n] << sh */
      if (d != 0)
	{
	  /* now add 1 in r[0], subtract 1 in r[d] */
	  if (cc-- == 0) /* then add 1 to r[0] */
	    cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
	  cc = mpn_sub_1 (r, r, d, cc) + 1;
	  /* add 1 to cc instead of rd since rd might overflow */
	}

      /* now subtract cc and rd from r[d..n] */

      r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
      r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
      /* A negative high limb (as a signed value) means the result went
	 below zero; adding 1 compensates because -2^N == +1.  */
      if (r[n] & GMP_LIMB_HIGHBIT)
	r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}
280 
281 
282 /* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
283    Assumes a and b are semi-normalized.
284 */
285 static inline void
mpn_fft_add_modF(mp_ptr r,mp_srcptr a,mp_srcptr b,int n)286 mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
287 {
288   mp_limb_t c, x;
289 
290   c = a[n] + b[n] + mpn_add_n (r, a, b, n);
291   /* 0 <= c <= 3 */
292 
293 #if 1
294   /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
295      result is slower code, of course.  But the following outsmarts GCC.  */
296   x = (c - 1) & -(c != 0);
297   r[n] = c - x;
298   MPN_DECR_U (r, n + 1, x);
299 #endif
300 #if 0
301   if (c > 1)
302     {
303       r[n] = 1;                       /* r[n] - c = 1 */
304       MPN_DECR_U (r, n + 1, c - 1);
305     }
306   else
307     {
308       r[n] = c;
309     }
310 #endif
311 }
312 
313 /* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
314    Assumes a and b are semi-normalized.
315 */
316 static inline void
mpn_fft_sub_modF(mp_ptr r,mp_srcptr a,mp_srcptr b,int n)317 mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
318 {
319   mp_limb_t c, x;
320 
321   c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
322   /* -2 <= c <= 1 */
323 
324 #if 1
325   /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
326      result is slower code, of course.  But the following outsmarts GCC.  */
327   x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
328   r[n] = x + c;
329   MPN_INCR_U (r, n + 1, x);
330 #endif
331 #if 0
332   if ((c & GMP_LIMB_HIGHBIT) != 0)
333     {
334       r[n] = 0;
335       MPN_INCR_U (r, n + 1, -c);
336     }
337   else
338     {
339       r[n] = c;
340     }
341 #endif
342 }
343 
/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1

   Recursive radix-2 transform, in place, over the pointers Ap[0], Ap[inc],
   ..., Ap[inc*(K-1)], each addressing n+1 limbs.  ll points at the
   bit-reversal row for this K; ll[-1] is the row for K/2 (see
   mpn_fft_initl).  tp is scratch for at least n+1 limbs.  */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      /* Base case: a single butterfly Ap[0] +/- Ap[inc].  */
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      /* Bit 0 of the return value is the borrow of the subtraction.  */
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      /* Save Ap[0]: the sum overwrites it before the difference is formed.  */
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      /* Bring both high limbs back to 0 or 1, using 2^N == -1.  */
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      /* ~Ap[inc][n] + 1 is the magnitude of the negative high limb.  */
      if (cy) /* Ap[inc][n] can be -1 or -2 */
	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      int j;
      int *lk = *ll;

      /* Transform the even- and odd-indexed halves with the squared root.  */
      mpn_fft_fft (Ap,     K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap+inc, K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
	{
	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
	  /* Only lk[0] is needed.  In the bit-reversal table
	     lk[1] = lk[0] + K/2, and 2^(omega*K/2) == -1 for a primitive K-th
	     root, so the first update is Ap[0] - tp.  */
	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0],   Ap[0], tp, n);
	}
    }
}
386 
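/* NOTE: this comment is not attached to any function.  Its input/output
   lines repeat the contract of mpn_fft_fft above, so it looks like a
   leftover from an earlier variant of that routine.
   input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.
*/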
392 
393 
394 /* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
395    by subtracting that modulus if necessary.
396 
397    If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
398    borrow and the limbs must be zeroed out again.  This will occur very
399    infrequently.  */
400 
401 static inline void
mpn_fft_normalize(mp_ptr ap,mp_size_t n)402 mpn_fft_normalize (mp_ptr ap, mp_size_t n)
403 {
404   if (ap[n] != 0)
405     {
406       MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
407       if (ap[n] == 0)
408 	{
409 	  /* This happens with very low probability; we have yet to trigger it,
410 	     and thereby make sure this code is correct.  */
411 	  MPN_ZERO (ap, n);
412 	  ap[n] = 1;
413 	}
414       else
415 	ap[n] = 0;
416     }
417 }
418 
419 /* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
420 static void
mpn_fft_mul_modF_K(mp_ptr * ap,mp_ptr * bp,mp_size_t n,int K)421 mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
422 {
423   int i;
424   int sqr = (ap == bp);
425   TMP_DECL;
426 
427   TMP_MARK;
428 
429   if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
430     {
431       int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
432       int **fft_l;
433       mp_ptr *Ap, *Bp, A, B, T;
434 
435       k = mpn_fft_best_k (n, sqr);
436       K2 = 1 << k;
437       ASSERT_ALWAYS((n & (K2 - 1)) == 0);
438       maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
439       M2 = n * GMP_NUMB_BITS >> k;
440       l = n >> k;
441       Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
442       /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
443       nprime2 = Nprime2 / GMP_NUMB_BITS;
444 
445       /* we should ensure that nprime2 is a multiple of the next K */
446       if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
447 	{
448 	  unsigned long K3;
449 	  for (;;)
450 	    {
451 	      K3 = 1L << mpn_fft_best_k (nprime2, sqr);
452 	      if ((nprime2 & (K3 - 1)) == 0)
453 		break;
454 	      nprime2 = (nprime2 + K3 - 1) & -K3;
455 	      Nprime2 = nprime2 * GMP_LIMB_BITS;
456 	      /* warning: since nprime2 changed, K3 may change too! */
457 	    }
458 	}
459       ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */
460 
461       Mp2 = Nprime2 >> k;
462 
463       Ap = TMP_ALLOC_MP_PTRS (K2);
464       Bp = TMP_ALLOC_MP_PTRS (K2);
465       A = TMP_ALLOC_LIMBS (2 * (nprime2 + 1) << k);
466       T = TMP_ALLOC_LIMBS (2 * (nprime2 + 1));
467       B = A + ((nprime2 + 1) << k);
468       fft_l = TMP_ALLOC_TYPE (k + 1, int *);
469       for (i = 0; i <= k; i++)
470 	fft_l[i] = TMP_ALLOC_TYPE (1<<i, int);
471       mpn_fft_initl (fft_l, k);
472 
473       TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
474 		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
475       for (i = 0; i < K; i++, ap++, bp++)
476 	{
477 	  mp_limb_t cy;
478 	  mpn_fft_normalize (*ap, n);
479 	  if (!sqr)
480 	    mpn_fft_normalize (*bp, n);
481 
482 	  mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
483 	  if (!sqr)
484 	    mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);
485 
486 	  cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
487 				     l, Mp2, fft_l, T, sqr);
488 	  (*ap)[n] = cy;
489 	}
490     }
491   else
492     {
493       mp_ptr a, b, tp, tpn;
494       mp_limb_t cc;
495       int n2 = 2 * n;
496       tp = TMP_ALLOC_LIMBS (n2);
497       tpn = tp + n;
498       TRACE (printf ("  mpn_mul_n %d of %ld limbs\n", K, n));
499       for (i = 0; i < K; i++)
500 	{
501 	  a = *ap++;
502 	  b = *bp++;
503 	  if (sqr)
504 	    mpn_sqr (tp, a, n);
505 	  else
506 	    mpn_mul_n (tp, b, a, n);
507 	  if (a[n] != 0)
508 	    cc = mpn_add_n (tpn, tpn, b, n);
509 	  else
510 	    cc = 0;
511 	  if (b[n] != 0)
512 	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
513 	  if (cc != 0)
514 	    {
515 	      /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
516 	      cc = mpn_add_1 (tp, tp, n2, cc);
517 	      ASSERT (cc == 0);
518 	    }
519 	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
520 	}
521     }
522   TMP_FREE;
523 }
524 
525 
/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.

   Works in place on K consecutive pointers Ap[0..K-1], each addressing
   n+1 limbs.  tp is scratch for at least n+1 limbs.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      /* Base case: a single butterfly Ap[0] +/- Ap[1].  */
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      /* Bit 0 of the return value is the borrow of the subtraction.  */
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      /* Save Ap[0]: the sum overwrites it before the difference is formed.  */
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      /* Bring both high limbs back to 0 or 1, using 2^N == -1.  */
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      /* ~Ap[1][n] + 1 is the magnitude of the negative high limb.  */
      if (cy) /* Ap[1][n] can be -1 or -2 */
	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      int j, K2 = K >> 1;

      /* Transform the lower and upper halves with the squared root.  */
      mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
	{
	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
	  /* 2^(K2 * omega) == -1 for a primitive K-th root, so the first
	     update is Ap[0] - tp.  */
	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
	}
    }
}
567 
568 
569 /* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
570 static void
mpn_fft_div_2exp_modF(mp_ptr r,mp_srcptr a,int k,mp_size_t n)571 mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
572 {
573   int i;
574 
575   ASSERT (r != a);
576   i = 2 * n * GMP_NUMB_BITS - k;
577   mpn_fft_mul_2exp_modF (r, a, i, n);
578   /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
579   /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
580   mpn_fft_normalize (r, n);
581 }
582 
583 
584 /* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
585    Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
586    then {rp,n}=0.
587 */
588 static int
mpn_fft_norm_modF(mp_ptr rp,mp_size_t n,mp_ptr ap,mp_size_t an)589 mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
590 {
591   mp_size_t l;
592   long int m;
593   mp_limb_t cc;
594   int rpn;
595 
596   ASSERT ((n <= an) && (an <= 3 * n));
597   m = an - 2 * n;
598   if (m > 0)
599     {
600       l = n;
601       /* add {ap, m} and {ap+2n, m} in {rp, m} */
602       cc = mpn_add_n (rp, ap, ap + 2 * n, m);
603       /* copy {ap+m, n-m} to {rp+m, n-m} */
604       rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
605     }
606   else
607     {
608       l = an - n; /* l <= n */
609       MPN_COPY (rp, ap, n);
610       rpn = 0;
611     }
612 
613   /* remains to subtract {ap+n, l} from {rp, n+1} */
614   cc = mpn_sub_n (rp, rp, ap + n, l);
615   rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
616   if (rpn < 0) /* necessarily rpn = -1 */
617     rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
618   return rpn;
619 }
620 
621 /* store in A[0..nprime] the first M bits from {n, nl},
622    in A[nprime+1..] the following M bits, ...
623    Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
624    T must have space for at least (nprime + 1) limbs.
625    We must have nl <= 2*K*l.
626 */
627 static void
mpn_mul_fft_decompose(mp_ptr A,mp_ptr * Ap,int K,int nprime,mp_srcptr n,mp_size_t nl,int l,int Mp,mp_ptr T)628 mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
629 		       mp_size_t nl, int l, int Mp, mp_ptr T)
630 {
631   int i, j;
632   mp_ptr tmp;
633   mp_size_t Kl = K * l;
634   TMP_DECL;
635   TMP_MARK;
636 
637   if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
638     {
639       mp_size_t dif = nl - Kl;
640       mp_limb_signed_t cy;
641 
642       tmp = TMP_ALLOC_LIMBS(Kl + 1);
643 
644       if (dif > Kl)
645 	{
646 	  int subp = 0;
647 
648 	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
649 	  n += 2 * Kl;
650 	  dif -= Kl;
651 
652 	  /* now dif > 0 */
653 	  while (dif > Kl)
654 	    {
655 	      if (subp)
656 		cy += mpn_sub_n (tmp, tmp, n, Kl);
657 	      else
658 		cy -= mpn_add_n (tmp, tmp, n, Kl);
659 	      subp ^= 1;
660 	      n += Kl;
661 	      dif -= Kl;
662 	    }
663 	  /* now dif <= Kl */
664 	  if (subp)
665 	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
666 	  else
667 	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
668 	  if (cy >= 0)
669 	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
670 	  else
671 	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
672 	}
673       else /* dif <= Kl, i.e. nl <= 2 * Kl */
674 	{
675 	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
676 	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
677 	}
678       tmp[Kl] = cy;
679       nl = Kl + 1;
680       n = tmp;
681     }
682   for (i = 0; i < K; i++)
683     {
684       Ap[i] = A;
685       /* store the next M bits of n into A[0..nprime] */
686       if (nl > 0) /* nl is the number of remaining limbs */
687 	{
688 	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
689 	  nl -= j;
690 	  MPN_COPY (T, n, j);
691 	  MPN_ZERO (T + j, nprime + 1 - j);
692 	  n += l;
693 	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
694 	}
695       else
696 	MPN_ZERO (A, nprime + 1);
697       A += nprime + 1;
698     }
699   ASSERT_ALWAYS (nl == 0);
700   TMP_FREE;
701 }
702 
/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.

   Ap[] (and Bp[] unless sqr) must already hold the pieces produced by
   mpn_mul_fft_decompose.  B is reused as the accumulator p, so it needs at
   least l*(K-1)+nprime+1 limbs.  The contents of Ap[], Bp[], A, B and T
   are clobbered.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
		      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
		      int **fft_l, mp_ptr T, int sqr)
{
  int K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  /* The k in each exponent removes the factor K = 2^k left by the inverse
     transform.  The (K - i) * Mp part presumably undoes the 2^(i*Mp)
     weights applied by mpn_mul_fft_decompose, combined with the reversed
     output order of mpn_fft_fftinv.  Bp[i] reuses the buffer of Ap[i-1],
     whose value has already been consumed; Bp[0] uses the upper half of T.  */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
			  pla - sh - nprime - 1, CNST_LIMB(1));
      /* A coefficient larger than (i+1)*2^(2M) cannot be the true
	 (non-negative) convolution term; it stands for a negative value,
	 so 2^N'+1 is subtracted at offsets sh and lo.  */
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
	{ /* subtract 2^N'+1 */
	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
	}
    }
  /* cc is unsigned: -CNST_LIMB(1) tests for a net borrow of one at p[pla].  */
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
	{
	  /* p[pla-pl]...p[pla-1] are all zero */
	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
	}
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
	{
	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
	    ;
	}
      else
	{
	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
	  ASSERT (cc == 0);
	}
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}
793 
/* Return lcm (a, 2^k).  Strip up to k factors of two from a, then scale
   by 2^k: that is a * 2^k / gcd (a, 2^k).  */
static unsigned long int
mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
{
  unsigned int shift = k;
  unsigned int left;

  for (left = k; left != 0 && (a & 1) == 0; left--)
    a >>= 1;

  return a << shift;
}
807 
808 
809 mp_limb_t
mpn_mul_fft(mp_ptr op,mp_size_t pl,mp_srcptr n,mp_size_t nl,mp_srcptr m,mp_size_t ml,int k)810 mpn_mul_fft (mp_ptr op, mp_size_t pl,
811 	     mp_srcptr n, mp_size_t nl,
812 	     mp_srcptr m, mp_size_t ml,
813 	     int k)
814 {
815   int K, maxLK, i;
816   mp_size_t N, Nprime, nprime, M, Mp, l;
817   mp_ptr *Ap, *Bp, A, T, B;
818   int **fft_l;
819   int sqr = (n == m && nl == ml);
820   mp_limb_t h;
821   TMP_DECL;
822 
823   TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
824   ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);
825 
826   TMP_MARK;
827   N = pl * GMP_NUMB_BITS;
828   fft_l = TMP_ALLOC_TYPE (k + 1, int *);
829   for (i = 0; i <= k; i++)
830     fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
831   mpn_fft_initl (fft_l, k);
832   K = 1 << k;
833   M = N >> k;	/* N = 2^k M */
834   l = 1 + (M - 1) / GMP_NUMB_BITS;
835   maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */
836 
837   Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
838   /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
839   nprime = Nprime / GMP_NUMB_BITS;
840   TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
841 		 N, K, M, l, maxLK, Nprime, nprime));
842   /* we should ensure that recursively, nprime is a multiple of the next K */
843   if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
844     {
845       unsigned long K2;
846       for (;;)
847 	{
848 	  K2 = 1L << mpn_fft_best_k (nprime, sqr);
849 	  if ((nprime & (K2 - 1)) == 0)
850 	    break;
851 	  nprime = (nprime + K2 - 1) & -K2;
852 	  Nprime = nprime * GMP_LIMB_BITS;
853 	  /* warning: since nprime changed, K2 may change too! */
854 	}
855       TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
856     }
857   ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */
858 
859   T = TMP_ALLOC_LIMBS (2 * (nprime + 1));
860   Mp = Nprime >> k;
861 
862   TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
863 		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
864 	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));
865 
866   A = TMP_ALLOC_LIMBS (K * (nprime + 1));
867   Ap = TMP_ALLOC_MP_PTRS (K);
868   mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
869   if (sqr)
870     {
871       mp_size_t pla;
872       pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
873       B = TMP_ALLOC_LIMBS (pla);
874       Bp = TMP_ALLOC_MP_PTRS (K);
875     }
876   else
877     {
878       B = TMP_ALLOC_LIMBS (K * (nprime + 1));
879       Bp = TMP_ALLOC_MP_PTRS (K);
880       mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
881     }
882   h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);
883 
884   TMP_FREE;
885   return h;
886 }
887 
#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
/* Only compiled when WANT_OLD_FFT_FULL is set.  The exact product is
   reconstructed from two wrapped products: mu, computed mod
   2^(pl3*GMP_NUMB_BITS)+1 into op, and lambda, computed mod
   2^(pl2*GMP_NUMB_BITS)+1 into a temporary pad_op, with pl3 = 3/2 * pl2.
   pad_op comes from __GMP_ALLOCATE_FUNC_LIMBS and is freed before return.  */
void
mpn_mul_fft_full (mp_ptr op,
		  mp_srcptr n, mp_size_t nl,
		  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  int k2, k3;
  int sqr = (n == m && nl == ml);
  int cc, c2, oldcc;

  pl = nl + ml; /* total number of limbs of the result */

  /* perform a fft mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3. Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2. Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j>=2. Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */

  /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  /* Search upward from about 2pl/5 for a pl2 whose matching pl3 is a
     valid FFT size.  */
  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
			    thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
		 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);    /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  /* c2 & 1 is the borrow, c2 & 2 is the carry */
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1; /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    /* Save the low half: the subtraction overwrites it before the
       addition reads it.  */
    tmp = TMP_ALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op,      pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp,    pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessary the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif
993