/* Schoenhage's fast multiplication modulo 2^N+1.

   Contributed by Paul Zimmermann.

   THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
   SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008,
2009, 2010 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */


/* References:

   Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
   Strassen, Computing 7, p. 281-292, 1971.

   Asymptotically fast algorithms for the numerical multiplication and division
   of polynomials with complex coefficients, by Arnold Schoenhage, Computer
   Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.

   Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
   Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.

   TODO:

   Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
   Zimmermann.

   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Cleanup and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp.h"
#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/add_n_sub_n.c"
#define HAVE_NATIVE_mpn_add_n_sub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal
__GMP_PROTO ((mp_ptr, mp_size_t, int, mp_ptr *, mp_ptr *,
              mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr, int));
static void mpn_mul_fft_decompose
__GMP_PROTO ((mp_ptr, mp_ptr *, int, int, mp_srcptr, mp_size_t, int, int, mp_ptr));


/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h is not updated.  */


/*****************************************************************************/

#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE   /* When tuning, this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  FFT_TABLE_ATTRS struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
        break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif

/*****************************************************************************/

#if ! defined (MPN_FFT_BEST_READY)
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/*****************************************************************************/


/* Returns the smallest possible number of limbs >= pl for an FFT of size 2^k,
   i.e. the smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}
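
/* For instance mpn_fft_next_size (2001, 4) returns 2016, the smallest
   multiple of 2^4 = 16 which is >= 2001.  */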


/* Initialize l[i][j] with bitrev(j) */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
        {
          li[j] = 2 * l[i - 1][j];
          li[K + j] = 1 + li[j];
        }
    }
}
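
/* For instance, with k = 2 this gives l[0] = {0}, l[1] = {0, 1} and
   l[2] = {0, 2, 1, 3}, i.e. l[i][j] is j with its i low bits reversed.  */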


/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
*/
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
{
  int sh;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  d /= GMP_NUMB_BITS;

  if (d >= n) /* negate */
    {
      /* r[0..d-1]  <--  lshift(a[n-d]..a[n-1], sh)
         r[d..n-1]  <-- -lshift(a[0]..a[n-d-1], sh) */

      d -= n;
      if (sh != 0)
        {
          /* no out shift below since a[n] <= 1 */
          mpn_lshift (r, a + n - d, d + 1, sh);
          rd = r[d];
          cc = mpn_lshiftc (r + d, a, n - d, sh);
        }
      else
        {
          MPN_COPY (r, a + n - d, d);
          rd = a[n];
          mpn_com (r + d, a, n - d);
          cc = 0;
        }

      /* add cc to r[0], and add rd to r[d] */

      /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + d + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..d-1]  <-- -lshift(a[n-d]..a[n-1], sh)
         r[d..n-1]  <--  lshift(a[0]..a[n-d-1], sh) */
      if (sh != 0)
        {
          /* no out bits below since a[n] <= 1 */
          mpn_lshiftc (r, a + n - d, d + 1, sh);
          rd = ~r[d];
          /* {r, d+1} = {a+n-d, d+1} << sh */
          cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
        }
      else
        {
          /* r[d] is not used below, but we save a test for d=0 */
          mpn_com (r, a + n - d, d + 1);
          rd = a[n];
          MPN_COPY (r + d, a, n - d);
          cc = 0;
        }

      /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */

      /* if d=0 we just have r[0]=a[n] << sh */
      if (d != 0)
        {
          /* now add 1 in r[0], subtract 1 in r[d] */
          if (cc-- == 0) /* then add 1 to r[0] */
            cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
          cc = mpn_sub_1 (r, r, d, cc) + 1;
          /* add 1 to cc instead of rd since rd might overflow */
        }

      /* now subtract cc and rd from r[d..n] */

      r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
      r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
      if (r[n] & GMP_LIMB_HIGHBIT)
        r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}
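
/* The rule used above: writing a*2^d = u*2^(n*GMP_NUMB_BITS) + v, the
   identity 2^(n*GMP_NUMB_BITS) = -1 mod 2^(n*GMP_NUMB_BITS)+1 gives
   a*2^d = v - u, i.e. limbs shifted out at the top wrap around to the
   bottom with their sign flipped.  */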


/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}
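
/* In the branch-free code above, x = 0 when c = 0 and x = c - 1 otherwise,
   so r[n] becomes 0 or 1, and in total x*(2^(n*GMP_NUMB_BITS)+1) -- a
   multiple of the modulus -- is subtracted, leaving r semi-normalized.  */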

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}
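
/* Similarly, x = -c exactly when the (signed) excess c is negative; adding x
   to {r, n+1} then changes the value by -c*(2^(n*GMP_NUMB_BITS)+1), again a
   multiple of the modulus, so r stays semi-normalized.  */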

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
          N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
             mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      int j;
      int *lk = *ll;

      mpn_fft_fft (Ap, K >> 1, ll - 1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap + inc, K >> 1, ll - 1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
         A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
        {
          /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
             Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
          mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
        }
    }
}
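
/* This is a radix-2 transform: each level splits the K residues into two
   interleaved halves (stride 2*inc), recurses on them with the squared root
   2^(2*omega), and then combines pairs with the butterfly
   (a, b) -> (a + w*b, a - w*b), where the twiddle factor w = 2^(lk[0]*omega)
   is a power of 2, so each multiplication is just mpn_fft_mul_2exp_modF.  */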

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
          N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.
*/


/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
        {
          /* This happens with very low probability; we have yet to trigger it,
             and thereby make sure this code is correct.  */
          MPN_ZERO (ap, n);
          ap[n] = 1;
        }
      else
        ap[n] = 0;
    }
}

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int **fft_l;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = 1 << k;
      ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK */
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
        {
          unsigned long K3;
          for (;;)
            {
              K3 = 1L << mpn_fft_best_k (nprime2, sqr);
              if ((nprime2 & (K3 - 1)) == 0)
                break;
              nprime2 = (nprime2 + K3 - 1) & -K3;
              Nprime2 = nprime2 * GMP_LIMB_BITS;
              /* warning: since nprime2 changed, K3 may change too! */
            }
        }
      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_ALLOC_MP_PTRS (K2);
      Bp = TMP_ALLOC_MP_PTRS (K2);
      A = TMP_ALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_ALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);
      fft_l = TMP_ALLOC_TYPE (k + 1, int *);
      for (i = 0; i <= k; i++)
        fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
                     n, K2, nprime2, nprime2, 2.0 * (double) n / nprime2 / K2));
      for (i = 0; i < K; i++, ap++, bp++)
        {
          mp_limb_t cy;
          mpn_fft_normalize (*ap, n);
          if (!sqr)
            mpn_fft_normalize (*bp, n);

          mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
          if (!sqr)
            mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

          cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
                                     l, Mp2, fft_l, T, sqr);
          (*ap)[n] = cy;
        }
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      int n2 = 2 * n;
      tp = TMP_ALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf (" mpn_mul_n %d of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
        {
          a = *ap++;
          b = *bp++;
          if (sqr)
            mpn_sqr (tp, a, n);
          else
            mpn_mul_n (tp, b, a, n);
          if (a[n] != 0)
            cc = mpn_add_n (tpn, tpn, b, n);
          else
            cc = 0;
          if (b[n] != 0)
            cc += mpn_add_n (tpn, tpn, a, n) + a[n];
          if (cc != 0)
            {
              /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
              cc = mpn_add_1 (tp, tp, n2, cc);
              ASSERT (cc == 0);
            }
          a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
        }
    }
  TMP_FREE;
}
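
/* In the plain-multiplication branch above, the 2n-limb product {tp, 2n} is
   folded back using 2^(n*GMP_NUMB_BITS) = -1: the result is low(tp) minus
   high(tp), and a borrow is compensated by adding 1, since the borrowed
   2^(n*GMP_NUMB_BITS) is congruent to -1.  */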


/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
        Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      int j, K2 = K >> 1;

      mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
         A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
        {
          /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
             Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
          mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
        }
    }
}


/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
{
  int i;

  ASSERT (r != a);
  i = 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}
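
/* Indeed, 2^(2*n*GMP_NUMB_BITS) = 1 mod 2^(n*GMP_NUMB_BITS)+1, so
   2^(2*n*GMP_NUMB_BITS - k) is the inverse of 2^k, and the division reduces
   to one more multiplication by a power of 2.  */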


/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   then {rp,n}=0.
*/
static int
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l;
  long int m;
  mp_limb_t cc;
  int rpn;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m} */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
                       mp_size_t nl, int l, int Mp, mp_ptr T)
{
  int i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_ALLOC_LIMBS (Kl + 1);

      if (dif > Kl)
        {
          int subp = 0;

          cy = mpn_sub_n (tmp, n, n + Kl, Kl);
          n += 2 * Kl;
          dif -= Kl;

          /* now dif > 0 */
          while (dif > Kl)
            {
              if (subp)
                cy += mpn_sub_n (tmp, tmp, n, Kl);
              else
                cy -= mpn_add_n (tmp, tmp, n, Kl);
              subp ^= 1;
              n += Kl;
              dif -= Kl;
            }
          /* now dif <= Kl */
          if (subp)
            cy += mpn_sub (tmp, tmp, Kl, n, dif);
          else
            cy -= mpn_add (tmp, tmp, Kl, n, dif);
          if (cy >= 0)
            cy = mpn_add_1 (tmp, tmp, Kl, cy);
          else
            cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
        }
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
        {
          cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
          cy = mpn_add_1 (tmp, tmp, Kl, cy);
        }
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
        {
          j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
          nl -= j;
          MPN_COPY (T, n, j);
          MPN_ZERO (T + j, nprime + 1 - j);
          n += l;
          mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
        }
      else
        MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}
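
/* Besides splitting {n, nl} into K pieces of l limbs, the loop above
   multiplies the i-th piece by 2^(i*Mp).  Together with the matching
   divisions performed after the inverse transform in mpn_mul_fft_internal,
   this weighting turns the length-K cyclic convolution into a negacyclic
   one, which is what a product mod 2^(K*l*GMP_NUMB_BITS)+1 requires.  */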

/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
                      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
                      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
                      int **fft_l, mp_ptr T, int sqr)
{
  int K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B;                          /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime, sh = l * i; i >= 0; i--, lo -= l, sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
        cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
                         pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
        { /* subtract 2^N'+1 */
          cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
          cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
        }
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
        {
          /* p[pla-pl]...p[pla-1] are all zero */
          mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
          mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
        }
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
        {
          while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
            ;
        }
      else
        {
          cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
          ASSERT (cc == 0);
        }
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}

/* return the lcm of a and 2^k */
static unsigned long int
mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
{
  unsigned long int l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k--;
    }
  return a << l;
}
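
/* For instance, on a build with GMP_NUMB_BITS = 64, mpn_mul_fft_lcm (64, 4)
   strips the four common factors of 2 and returns 4 << 4 = 64, which is
   indeed lcm (64, 2^4).  */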


mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
             mp_srcptr n, mp_size_t nl,
             mp_srcptr m, mp_size_t ml,
             int k)
{
  int K, maxLK, i;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  fft_l = TMP_ALLOC_TYPE (k + 1, int *);
  for (i = 0; i <= k; i++)
    fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
  mpn_fft_initl (fft_l, k);
  K = 1 << k;
  M = N >> k; /* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
                 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      unsigned long K2;
      for (;;)
        {
          K2 = 1L << mpn_fft_best_k (nprime, sqr);
          if ((nprime & (K2 - 1)) == 0)
            break;
          nprime = (nprime + K2 - 1) & -K2;
          Nprime = nprime * GMP_LIMB_BITS;
          /* warning: since nprime changed, K2 may change too! */
        }
      TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_ALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
                 pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
         printf (" temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_ALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_ALLOC_MP_PTRS (K);
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_ALLOC_LIMBS (pla);
      Bp = TMP_ALLOC_MP_PTRS (K);
    }
  else
    {
      B = TMP_ALLOC_LIMBS (K * (nprime + 1));
      Bp = TMP_ALLOC_MP_PTRS (K);
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}
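
/* Typical use (a sketch only -- see e.g. mpn_mul_fft_full below for the real
   call sites): pick k, round pl up so that it is a multiple of 2^k, and only
   then call mpn_mul_fft:

     k  = mpn_fft_best_k (pl, sqr);
     pl = mpn_fft_next_size (pl, k);
     cy = mpn_mul_fft (op, pl, n, nl, m, ml, k);

   after which {op, pl} plus cy*2^(pl*GMP_NUMB_BITS) is
   {n, nl} * {m, ml} mod 2^(pl*GMP_NUMB_BITS)+1.  */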

#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
                  mp_srcptr n, mp_size_t nl,
                  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  int k2, k3;
  int sqr = (n == m && nl == ml);
  int cc, c2, oldcc;

  pl = nl + ml; /* total number of limbs of the result */

  /* perform an fft mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3.  Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2.  Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j >= 2.  Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */

  /* ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
                            thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
                 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);   /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  /* c2 & 1 is the borrow, c2 & 2 is the carry */
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1;   /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_ALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op, pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp, pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op[pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif
