/* Schoenhage's fast multiplication modulo 2^N+1.

   Contributed by Paul Zimmermann.

   THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
   SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 1998-2010, 2012, 2013 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of either:

  * the GNU Lesser General Public License as published by the Free
    Software Foundation; either version 3 of the License, or (at your
    option) any later version.

or

  * the GNU General Public License as published by the Free Software
    Foundation; either version 2 of the License, or (at your option) any
    later version.

or both in parallel, as here.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received copies of the GNU General Public License and the
GNU Lesser General Public License along with the GNU MP Library.  If not,
see https://www.gnu.org/licenses/.  */


/* References:

   Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
   Strassen, Computing 7, p. 281-292, 1971.

   Asymptotically fast algorithms for the numerical multiplication and division
   of polynomials with complex coefficients, by Arnold Schoenhage, Computer
   Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.

   Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
   Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.

   TODO:

   Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
   Zimmermann.

   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Cleanup and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp.h"
#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/add_n_sub_n.c"
#define HAVE_NATIVE_mpn_add_n_sub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
                                       mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
                                       mp_size_t, mp_size_t, int **, mp_ptr, int);
static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
                                   mp_size_t, mp_size_t, mp_size_t, mp_ptr);


/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h files have not been updated.  */


/*****************************************************************************/

#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE         /* When tuning this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  const struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
        break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif

/*****************************************************************************/

#if ! defined (MPN_FFT_BEST_READY)
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/*****************************************************************************/


/* Returns the smallest possible number of limbs >= pl for an FFT of size 2^k,
   i.e. the smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}
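
/* For example, mpn_fft_next_size (2000, 4) = 2000, since 2000 is already a
   multiple of 2^4 = 16, whereas mpn_fft_next_size (2001, 4) = 126 * 16 = 2016.
   mpn_mul_fft below requires its pl argument to be rounded this way, as
   checked by its ASSERT_ALWAYS.  */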


/* Initialize l[i][j] with bitrev(j) */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
        {
          li[j] = 2 * l[i - 1][j];
          li[K + j] = 1 + li[j];
        }
    }
}
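
/* For instance, for k = 3 the tables computed above are
     l[0] = {0}
     l[1] = {0, 1}
     l[2] = {0, 2, 1, 3}
     l[3] = {0, 4, 2, 6, 1, 5, 3, 7}
   i.e. l[i][j] is j with its i low bits reversed.  */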


/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
*/
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
{
  unsigned int sh;
  mp_size_t m;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  m = d / GMP_NUMB_BITS;

  if (m >= n)                   /* negate */
    {
      /* r[0..m-1]  <-- lshift(a[n-m]..a[n-1], sh)
         r[m..n-1]  <-- -lshift(a[0]..a[n-m-1],  sh) */

      m -= n;
      if (sh != 0)
        {
          /* no out shift below since a[n] <= 1 */
          mpn_lshift (r, a + n - m, m + 1, sh);
          rd = r[m];
          cc = mpn_lshiftc (r + m, a, n - m, sh);
        }
      else
        {
          MPN_COPY (r, a + n - m, m);
          rd = a[n];
          mpn_com (r + m, a, n - m);
          cc = 0;
        }

      /* add cc to r[0], and add rd to r[m] */

      /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + m + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
         r[m..n-1]  <-- lshift(a[0]..a[n-m-1],  sh)  */
      if (sh != 0)
        {
          /* no out bits below since a[n] <= 1 */
          mpn_lshiftc (r, a + n - m, m + 1, sh);
          rd = ~r[m];
          /* {r, m+1} = {a+n-m, m+1} << sh */
          cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
        }
      else
        {
          /* r[m] is not used below, but we save a test for m=0 */
          mpn_com (r, a + n - m, m + 1);
          rd = a[n];
          MPN_COPY (r + m, a, n - m);
          cc = 0;
        }

      /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */

      /* if m=0 we just have r[0]=a[n] << sh */
      if (m != 0)
        {
          /* now add 1 in r[0], subtract 1 in r[m] */
          if (cc-- == 0) /* then add 1 to r[0] */
            cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
          cc = mpn_sub_1 (r, r, m, cc) + 1;
          /* add 1 to cc instead of rd since rd might overflow */
        }

      /* now subtract cc and rd from r[m..n] */

      r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
      r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
      if (r[n] & GMP_LIMB_HIGHBIT)
        r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}
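
/* The reduction above relies on 2^(n*GMP_NUMB_BITS) = -1 mod
   2^(n*GMP_NUMB_BITS)+1, so a shift by d >= n*GMP_NUMB_BITS is a shift by
   d - n*GMP_NUMB_BITS followed by a negation.  Toy example with a 4-bit
   limb, i.e. modulus 2^4+1 = 17: 3*2^5 = 96 = 11 mod 17, and indeed
   -(3*2^1) = -6 = 11 mod 17.  */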


/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}
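
/* Case check for the branchless normalization above: c = a[n] + b[n] + carry
   is in {0,1,2,3} and the code sets (r[n], x) to (0,0), (1,0), (1,1), (1,2)
   respectively; since 2^(n*GMP_NUMB_BITS) = -1, subtracting x while keeping
   r[n] = c - x leaves a value congruent to a + b and still semi-normalized.  */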

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
          N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.  */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
             mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;
      int *lk = *ll;

      mpn_fft_fft (Ap,     K2, ll-1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap+inc, K2, ll-1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
         A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
        {
          /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
             Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
          mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0],   Ap[0], tp, n);
        }
    }
}


/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
        {
          /* This happens with very low probability; we have yet to trigger it,
             and thereby make sure this code is correct.  */
          MPN_ZERO (ap, n);
          ap[n] = 1;
        }
      else
        ap[n] = 0;
    }
}

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int k;
      int **fft_l, *tmp;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = (mp_size_t) 1 << k;
      ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
        {
          mp_size_t K3;
          for (;;)
            {
              K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
              if ((nprime2 & (K3 - 1)) == 0)
                break;
              nprime2 = (nprime2 + K3 - 1) & -K3;
              Nprime2 = nprime2 * GMP_LIMB_BITS;
              /* warning: since nprime2 changed, K3 may change too! */
            }
        }
      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_BALLOC_MP_PTRS (K2);
      Bp = TMP_BALLOC_MP_PTRS (K2);
      A = TMP_BALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);
      fft_l = TMP_BALLOC_TYPE (k + 1, int *);
      tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
      for (i = 0; i <= k; i++)
        {
          fft_l[i] = tmp;
          tmp += (mp_size_t) 1 << i;
        }

      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
                    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
      for (i = 0; i < K; i++, ap++, bp++)
        {
          mp_limb_t cy;
          mpn_fft_normalize (*ap, n);
          if (!sqr)
            mpn_fft_normalize (*bp, n);

          mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
          if (!sqr)
            mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

          cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
                                     l, Mp2, fft_l, T, sqr);
          (*ap)[n] = cy;
        }
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      mp_size_t n2 = 2 * n;
      tp = TMP_BALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf ("  mpn_mul_n %ld of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
        {
          a = *ap++;
          b = *bp++;
          if (sqr)
            mpn_sqr (tp, a, n);
          else
            mpn_mul_n (tp, b, a, n);
          if (a[n] != 0)
            cc = mpn_add_n (tpn, tpn, b, n);
          else
            cc = 0;
          if (b[n] != 0)
            cc += mpn_add_n (tpn, tpn, a, n) + a[n];
          if (cc != 0)
            {
              /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
              cc = mpn_add_1 (tp, tp, n2, cc);
              ASSERT (cc == 0);
            }
          a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
        }
    }
  TMP_FREE;
}


/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
        Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;

      mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
         A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
        {
          /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
             Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
          mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
        }
    }
}


/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
{
  mp_bitcnt_t i;

  ASSERT (r != a);
  i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}
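
/* Why this works: 2^(n*GMP_NUMB_BITS) = -1 implies 2^(2*n*GMP_NUMB_BITS) = 1,
   hence 1/2^k = 2^(2*n*GMP_NUMB_BITS - k).  Toy check with modulus 2^4+1 = 17:
   1/2 = 2^(8-1) = 128 = 9 mod 17, and indeed 2*9 = 18 = 1 mod 17.  */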


/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns the carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   in which case {rp,n}=0.
*/
static mp_size_t
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l, m, rpn;
  mp_limb_t cc;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m}, adding in the carry cc */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}
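
/* The folding above uses 2^(n*GMP_NUMB_BITS) = -1: writing
   {ap,an} = lo + mid*2^(n*GMP_NUMB_BITS) + hi*2^(2*n*GMP_NUMB_BITS),
   with lo of n limbs, mid of l limbs and hi of an-2n limbs (possibly empty),
   the residue is lo + hi - mid, which is what the code computes.  */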

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
                       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
                       mp_ptr T)
{
  mp_size_t i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_BALLOC_LIMBS(Kl + 1);

      if (dif > Kl)
        {
          int subp = 0;

          cy = mpn_sub_n (tmp, n, n + Kl, Kl);
          n += 2 * Kl;
          dif -= Kl;

          /* now dif > 0 */
          while (dif > Kl)
            {
              if (subp)
                cy += mpn_sub_n (tmp, tmp, n, Kl);
              else
                cy -= mpn_add_n (tmp, tmp, n, Kl);
              subp ^= 1;
              n += Kl;
              dif -= Kl;
            }
          /* now dif <= Kl */
          if (subp)
            cy += mpn_sub (tmp, tmp, Kl, n, dif);
          else
            cy -= mpn_add (tmp, tmp, Kl, n, dif);
          if (cy >= 0)
            cy = mpn_add_1 (tmp, tmp, Kl, cy);
          else
            cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
        }
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
        {
          cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
          cy = mpn_add_1 (tmp, tmp, Kl, cy);
        }
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
        {
          j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
          nl -= j;
          MPN_COPY (T, n, j);
          MPN_ZERO (T + j, nprime + 1 - j);
          n += l;
          mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
        }
      else
        MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}
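
/* The multiplication of chunk i by 2^(i*Mp) above applies the weight theta^i,
   where theta = 2^Mp is a 2K-th root of unity mod 2^Nprime+1 (K*Mp = Nprime
   and 2^Nprime = -1).  This is the usual weighting that makes the pointwise
   products of the transformed chunks correspond to the negacyclic convolution
   needed for a product taken mod 2^(K*l*GMP_NUMB_BITS)+1.  */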

/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
                      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
                      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
                      int **fft_l, mp_ptr T, int sqr)
{
  mp_size_t K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = (mp_size_t) 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime, sh = l * i; i >= 0; i--, lo -= l, sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
        cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
                          pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
        { /* subtract 2^N'+1 */
          cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
          cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
        }
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
        {
          /* p[pla-pl]...p[pla-1] are all zero */
          mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
          mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
        }
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
        {
          while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
            ;
        }
      else
        {
          cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
          ASSERT (cc == 0);
        }
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}

/* return the lcm of a and 2^k */
static mp_bitcnt_t
mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
{
  mp_bitcnt_t l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k--;
    }
  return a << l;
}
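
/* Examples, assuming GMP_NUMB_BITS = 64: mpn_mul_fft_lcm (64, 3) = 8 << 3 = 64
   = lcm (64, 2^3), and mpn_mul_fft_lcm (64, 8) = 1 << 8 = 256 = lcm (64, 2^8).  */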


mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
             mp_srcptr n, mp_size_t nl,
             mp_srcptr m, mp_size_t ml,
             int k)
{
  int i;
  mp_size_t K, maxLK;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l, *tmp;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  fft_l = TMP_BALLOC_TYPE (k + 1, int *);
  tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
  for (i = 0; i <= k; i++)
    {
      fft_l[i] = tmp;
      tmp += (mp_size_t) 1 << i;
    }

  mpn_fft_initl (fft_l, k);
  K = (mp_size_t) 1 << k;
  M = N >> k;   /* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
                 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2;
      for (;;)
        {
          K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
          if ((nprime & (K2 - 1)) == 0)
            break;
          nprime = (nprime + K2 - 1) & -K2;
          Nprime = nprime * GMP_LIMB_BITS;
          /* warning: since nprime changed, K2 may change too! */
        }
      TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
                pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
         printf ("   temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_BALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_BALLOC_MP_PTRS (K);
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_BALLOC_LIMBS (pla);
      Bp = TMP_BALLOC_MP_PTRS (K);
    }
  else
    {
      B = TMP_BALLOC_LIMBS (K * (nprime + 1));
      Bp = TMP_BALLOC_MP_PTRS (K);
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}
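
/* A caller first rounds the wanted modulus size and picks the transform
   size, for example

     k  = mpn_fft_best_k (pl, 0);
     pl = mpn_fft_next_size (pl, k);
     cy = mpn_mul_fft (op, pl, n, nl, m, ml, k);

   (second argument of mpn_fft_best_k: 0 for a multiply, 1 for a square);
   then {op,pl} plus cy*2^(pl*GMP_NUMB_BITS) is {n,nl}*{m,ml} reduced mod
   2^(pl*GMP_NUMB_BITS)+1.  mpn_mul_fft_full below uses this pattern with
   two different moduli and recombines the two residues.  */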

#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
                  mp_srcptr n, mp_size_t nl,
                  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  mp_size_t cc, c2, oldcc;
  int k2, k3;
  int sqr = (n == m && nl == ml);

  pl = nl + ml; /* total number of limbs of the result */

  /* perform one FFT mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3.  Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2.  Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need consecutive intervals to overlap, i.e. 5*j >= 3*(j+1),
     which requires j >= 2.  Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */

  /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
                            thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
                 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);    /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  /* c2 & 1 is the borrow, c2 & 2 is the carry */
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1; /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_BALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op,      pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp,    pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif