1 /* Schoenhage's fast multiplication modulo 2^N+1.
2
3 Contributed by Paul Zimmermann.
4
5 THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES. IT IS ONLY
6 SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES. IN FACT, IT IS ALMOST
7 GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.
8
9 Copyright 1998-2010, 2012, 2013, 2018, 2020 Free Software Foundation, Inc.
10
11 This file is part of the GNU MP Library.
12
13 The GNU MP Library is free software; you can redistribute it and/or modify
14 it under the terms of either:
15
16 * the GNU Lesser General Public License as published by the Free
17 Software Foundation; either version 3 of the License, or (at your
18 option) any later version.
19
20 or
21
22 * the GNU General Public License as published by the Free Software
23 Foundation; either version 2 of the License, or (at your option) any
24 later version.
25
26 or both in parallel, as here.
27
28 The GNU MP Library is distributed in the hope that it will be useful, but
29 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
30 or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
31 for more details.
32
33 You should have received copies of the GNU General Public License and the
34 GNU Lesser General Public License along with the GNU MP Library. If not,
35 see https://www.gnu.org/licenses/. */
36
37
38 /* References:
39
40 Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
41 Strassen, Computing 7, p. 281-292, 1971.
42
43 Asymptotically fast algorithms for the numerical multiplication and division
44 of polynomials with complex coefficients, by Arnold Schoenhage, Computer
45 Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.
46
47 Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
48 Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.
49
50 TODO:
51
52 Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
53 Zimmermann.
54
55 It might be possible to avoid a small number of MPN_COPYs by using a
56 rotating temporary or two.
57
58 Cleanup and simplify the code!
59 */
60
61 #ifdef TRACE
62 #undef TRACE
63 #define TRACE(x) x
64 #include <stdio.h>
65 #else
66 #define TRACE(x)
67 #endif
68
69 #include "gmp-impl.h"
70
71 #ifdef WANT_ADDSUB
72 #include "generic/add_n_sub_n.c"
73 #define HAVE_NATIVE_mpn_add_n_sub_n 1
74 #endif
75
76 static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
77 mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
78 mp_size_t, mp_size_t, int **, mp_ptr, int);
79 static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
80 mp_size_t, mp_size_t, mp_size_t, mp_ptr);
81
82
/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h is not updated. */
87
88
89 /*****************************************************************************/
90
91 #if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))
92
93 #ifndef FFT_TABLE3_SIZE /* When tuning this is defined in gmp-impl.h */
94 #if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
95 #if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
96 #define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
97 #else
98 #define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
99 #endif
100 #endif
101 #endif
102
103 #ifndef FFT_TABLE3_SIZE
104 #define FFT_TABLE3_SIZE 200
105 #endif
106
107 FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
108 {
109 MUL_FFT_TABLE3,
110 SQR_FFT_TABLE3
111 };
112
113 int
mpn_fft_best_k(mp_size_t n,int sqr)114 mpn_fft_best_k (mp_size_t n, int sqr)
115 {
116 const struct fft_table_nk *fft_tab, *tab;
117 mp_size_t tab_n, thres;
118 int last_k;
119
120 fft_tab = mpn_fft_table3[sqr];
121 last_k = fft_tab->k;
122 for (tab = fft_tab + 1; ; tab++)
123 {
124 tab_n = tab->n;
125 thres = tab_n << last_k;
126 if (n <= thres)
127 break;
128 last_k = tab->k;
129 }
130 return last_k;
131 }
132
133 #define MPN_FFT_BEST_READY 1
134 #endif
135
136 /*****************************************************************************/
137
138 #if ! defined (MPN_FFT_BEST_READY)
139 FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
140 {
141 MUL_FFT_TABLE,
142 SQR_FFT_TABLE
143 };
144
145 int
mpn_fft_best_k(mp_size_t n,int sqr)146 mpn_fft_best_k (mp_size_t n, int sqr)
147 {
148 int i;
149
150 for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
151 if (n < mpn_fft_table[sqr][i])
152 return i + FFT_FIRST_K;
153
154 /* treat 4*last as one further entry */
155 if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
156 return i + FFT_FIRST_K;
157 else
158 return i + FFT_FIRST_K + 1;
159 }
160 #endif
161
162 /*****************************************************************************/
163
164
165 /* Returns smallest possible number of limbs >= pl for a fft of size 2^k,
166 i.e. smallest multiple of 2^k >= pl.
167
168 Don't declare static: needed by tuneup.
169 */
170
171 mp_size_t
mpn_fft_next_size(mp_size_t pl,int k)172 mpn_fft_next_size (mp_size_t pl, int k)
173 {
174 pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
175 return pl << k;
176 }
177
178
/* Initialize l[i][j] with bitrev(j), the i-bit reversal of j,
   for 0 <= i <= k and 0 <= j < 2^i. */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j;

  l[0][0] = 0;
  for (i = 1; i <= k; i++)
    {
      int half = 1 << (i - 1);
      const int *prev = l[i - 1];
      int *cur = l[i];

      /* Each row is built from the previous one: the first half gets the
	 doubled values (appending a 0 bit), the second half the doubled
	 values plus one (appending a 1 bit).  */
      for (j = 0; j < half; j++)
	{
	  cur[j] = prev[j] << 1;
	  cur[j + half] = cur[j] | 1;
	}
    }
}
197
198
/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.

   Write d = m*GMP_NUMB_BITS + sh.  Since 2^(n*GMP_NUMB_BITS) = -1 mod
   2^(n*GMP_NUMB_BITS)+1, limbs shifted past position n wrap around
   negated at the bottom; when m >= n the roles of the two parts are
   exchanged and the whole wrapped part changes sign.  */
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
{
  unsigned int sh;
  mp_size_t m;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;	/* bit shift within a limb */
  m = d / GMP_NUMB_BITS;	/* whole-limb shift */

  if (m >= n)			/* negate */
    {
      /* r[0..m-1]  <-- lshift(a[n-m]..a[n-1], sh)
	 r[m..n-1]  <-- -lshift(a[0]..a[n-m-1], sh) */

      m -= n;
      if (sh != 0)
	{
	  /* no out shift below since a[n] <= 1 */
	  mpn_lshift (r, a + n - m, m + 1, sh);
	  rd = r[m];
	  /* mpn_lshiftc shifts and complements, i.e. produces the one's
	     complement of the negated part; the +1 adjustments below
	     convert it to a two's complement negation */
	  cc = mpn_lshiftc (r + m, a, n - m, sh);
	}
      else
	{
	  MPN_COPY (r, a + n - m, m);
	  rd = a[n];
	  mpn_com (r + m, a, n - m);
	  cc = 0;
	}

      /* add cc to r[0], and add rd to r[m] */

      /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + m + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
	 r[m..n-1]  <-- lshift(a[0]..a[n-m-1], sh)  */
      if (sh != 0)
	{
	  /* no out bits below since a[n] <= 1 */
	  mpn_lshiftc (r, a + n - m, m + 1, sh);
	  rd = ~r[m];
	  /* {r, m+1} = {a+n-m, m+1} << sh */
	  cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
	}
      else
	{
	  /* r[m] is not used below, but we save a test for m=0 */
	  mpn_com (r, a + n - m, m + 1);
	  rd = a[n];
	  MPN_COPY (r + m, a, n - m);
	  cc = 0;
	}

      /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */

      /* if m=0 we just have r[0]=a[n] << sh */
      if (m != 0)
	{
	  /* now add 1 in r[0], subtract 1 in r[m] */
	  if (cc-- == 0) /* then add 1 to r[0] */
	    cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
	  cc = mpn_sub_1 (r, r, m, cc) + 1;
	  /* add 1 to cc instead of rd since rd might overflow */
	}

      /* now subtract cc and rd from r[m..n] */

      /* the borrows collect (negatively) in r[n]; if the result went
	 negative, add back 1 at the bottom (i.e. add the modulus) */
      r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
      r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
      if (r[n] & GMP_LIMB_HIGHBIT)
	r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}
290
291 #if HAVE_NATIVE_mpn_add_n_sub_n
/* Combined butterfly: with T = old contents of A0,
     A0 <- T + tp  and  Ai <- T - tp,
   both reduced mod 2^(n*GMP_NUMB_BITS)+1, using one fused
   mpn_add_n_sub_n pass over the low n limbs.
   Assumes the operands are semi-normalized; so are the results.
   In cyas, bit 0 is the borrow from the subtraction and bit 1 the
   carry from the addition (cf. the non-native code path below).  */
static inline void
mpn_fft_add_sub_modF (mp_ptr A0, mp_ptr Ai, mp_srcptr tp, mp_size_t n)
{
  mp_limb_t cyas, c, x;

  cyas = mpn_add_n_sub_n (A0, Ai, A0, tp, n);

  /* Difference top limb: c may be "negative" (high bit set); then add
     x = -c at the bottom, since -1 = 2^(n*GMP_NUMB_BITS) mod the
     modulus, leaving Ai[n] = 0.  */
  c = A0[n] - tp[n] - (cyas & 1);
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  Ai[n] = x + c;
  MPN_INCR_U (Ai, n + 1, x);

  /* Sum top limb: 0 <= c <= 3; when c > 1, subtract the excess
     x = c - 1 from the bottom (subtracting x*(2^(n*GMP_NUMB_BITS)+1)
     changes nothing mod the modulus).  */
  c = A0[n] + tp[n] + (cyas >> 1);
  x = (c - 1) & -(c != 0);
  A0[n] = c - x;
  MPN_DECR_U (A0, n + 1, x);
}
309
310 #else /* ! HAVE_NATIVE_mpn_add_n_sub_n */
311
/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized (top limb at most 1).
   The result is semi-normalized.  */
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  /* x = c - 1 when c > 0, else 0.  Setting r[n] = c - x while also
     subtracting x at the bottom subtracts x*(2^(n*GMP_NUMB_BITS)+1),
     which is 0 mod the modulus -- branch-free normalization.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}
342
/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
   The result is semi-normalized.  */
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  /* x = -c when c < 0 (high bit set), else 0.  Setting r[n] = x + c while
     also adding x at the bottom adds x*(2^(n*GMP_NUMB_BITS)+1), which is
     0 mod the modulus -- branch-free normalization.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}
373 #endif /* HAVE_NATIVE_mpn_add_n_sub_n */
374
/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
   N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1

   Recursive decimation-in-time transform; ll points to the bit-reversal
   row for the current size K.  tp is scratch of at least n+1 limbs.  */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      /* Size-2 butterfly: (A0, A1) <- (A0 + A1, A0 - A1), then bring the
	 top limbs back into {0, 1}.  */
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	/* subtract Ap[0][n]-1 times the modulus 2^N+1 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
	/* add -Ap[inc][n] times the modulus 2^N+1 */
	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;
      int *lk = *ll;

      /* Transform even-indexed and odd-indexed entries separately,
	 with doubled root and stride.  */
      mpn_fft_fft (Ap, K2, ll-1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap+inc, K2, ll-1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
	{
	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega)
	     By construction of the bit-reverse tables lk[1] = lk[0] + K2,
	     and 2^(K2*omega) = -1 mod 2^N+1, so one shifted value serves
	     both lines (the second becomes a subtraction).  */
	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
#if HAVE_NATIVE_mpn_add_n_sub_n
	  mpn_fft_add_sub_modF (Ap[0], Ap[inc], tp, n);
#else
	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
#endif
	}
    }
}
421
422 /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
423 N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
424 output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
425 tp must have space for 2*(n+1) limbs.
426 */
427
428
/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      /* decrementing the low part and clearing ap[n] together subtract
	 the modulus 2^(n*GMP_NUMB_BITS)+1 */
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
	{
	  /* This happens with very low probability; we have yet to trigger it,
	     and thereby make sure this code is correct.  */
	  /* value was exactly 2^(n*GMP_NUMB_BITS) = -1 mod the modulus;
	     restore its canonical representation ap[n]=1, rest zero */
	  MPN_ZERO (ap, n);
	  ap[n] = 1;
	}
      else
	ap[n] = 0;
    }
}
453
/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K.
   ap == bp indicates a squaring.  Above the relevant MODF threshold each
   product is computed by a recursive FFT of size 2^k; otherwise by a
   plain mpn_mul_n/mpn_sqr followed by reduction mod the modulus.  */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int k;
      int **fft_l, *tmp;
      mp_ptr *Ap, *Bp, A, B, T;

      /* Set up the recursive FFT parameters, mirroring mpn_mul_fft.  */
      k = mpn_fft_best_k (n, sqr);
      K2 = (mp_size_t) 1 << k;
      ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
	{
	  mp_size_t K3;
	  for (;;)
	    {
	      K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
	      if ((nprime2 & (K3 - 1)) == 0)
		break;
	      nprime2 = (nprime2 + K3 - 1) & -K3;
	      Nprime2 = nprime2 * GMP_LIMB_BITS;
	      /* warning: since nprime2 changed, K3 may change too! */
	    }
	}
      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      /* B shares A's allocation: A gets the first (nprime2+1)<<k limbs,
	 B the second half.  */
      Ap = TMP_BALLOC_MP_PTRS (K2);
      Bp = TMP_BALLOC_MP_PTRS (K2);
      A = TMP_BALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);
      fft_l = TMP_BALLOC_TYPE (k + 1, int *);
      tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
      for (i = 0; i <= k; i++)
	{
	  fft_l[i] = tmp;
	  tmp += (mp_size_t) 1 << i;
	}

      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
      for (i = 0; i < K; i++, ap++, bp++)
	{
	  mp_limb_t cy;
	  mpn_fft_normalize (*ap, n);
	  if (!sqr)
	    mpn_fft_normalize (*bp, n);

	  mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
	  if (!sqr)
	    mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

	  cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
				     l, Mp2, fft_l, T, sqr);
	  (*ap)[n] = cy;
	}
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      mp_size_t n2 = 2 * n;
      tp = TMP_BALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf (" mpn_mul_n %ld of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
	{
	  a = *ap++;
	  b = *bp++;
	  /* full 2n-limb product of the low n limbs; the top limbs a[n]
	     and b[n] (each 0 or 1) are folded in separately below */
	  if (sqr)
	    mpn_sqr (tp, a, n);
	  else
	    mpn_mul_n (tp, b, a, n);
	  /* (A + a[n]*2^(nL)) * (B + b[n]*2^(nL))
	       = A*B + (a[n]*B + b[n]*A)*2^(nL) + a[n]*b[n]*2^(2nL)
	     with L = GMP_NUMB_BITS; the cross terms are added at tpn and
	     the a[n]*b[n] term goes into cc.  */
	  if (a[n] != 0)
	    cc = mpn_add_n (tpn, tpn, b, n);
	  else
	    cc = 0;
	  if (b[n] != 0)
	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
	  if (cc != 0)
	    {
	      /* 2^(2nL) = 1 mod 2^(nL)+1, so the overflow cc may be
		 folded back in at the bottom */
	      cc = mpn_add_1 (tp, tp, n2, cc);
	      /* If mpn_add_1 give a carry (cc != 0),
		 the result (tp) is at most GMP_NUMB_MAX - 1,
		 so the following addition can't overflow.
	      */
	      tp[0] += cc;
	    }
	  /* {a, n+1} <- {tp, n} - {tpn, n} mod 2^(nL)+1; on borrow add
	     back the modulus (add 1; a[n] records the possible carry) */
	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
	}
    }
  TMP_FREE;
}
568
569
/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.

   Inverse transform (up to the factor K, divided out by the caller);
   tp is scratch of at least n+1 limbs.  */
static void
mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      /* Size-2 butterfly: (A0, A1) <- (A0 + A1, A0 - A1), then bring the
	 top limbs back into {0, 1}.  */
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	/* subtract Ap[0][n]-1 times the modulus 2^N+1 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
	/* add -Ap[1][n] times the modulus 2^N+1 */
	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;

      /* Recurse on the two contiguous halves, with doubled root.  */
      mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
	{
	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega)
	     Since 2^(K2*omega) = -1 mod 2^N+1, one shifted value serves
	     both lines (the first becomes a subtraction).  */
	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
#if HAVE_NATIVE_mpn_add_n_sub_n
	  mpn_fft_add_sub_modF (Ap[0], Ap[K2], tp, n);
#else
	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
#endif
	}
    }
}
615
616
/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1.
   Division by a power of 2 is a multiplication by the complementary
   power, since 2 is invertible mod the (odd) modulus.  */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
{
  mp_bitcnt_t i;

  ASSERT (r != a);
  i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}
630
631
/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   then {rp,n}=0.
   Uses 2^(n*GMP_NUMB_BITS) = -1 mod the modulus: the middle n limbs are
   subtracted from, and the top limbs added to, the low n limbs.  */
static mp_size_t
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l, m, rpn;
  mp_limb_t cc;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* propagate the carry through {ap+m, n-m} into {rp+m, n-m};
	 rpn holds the carry out of the low n limbs */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  /* rpn can go negative here (it collects the signed net carry/borrow) */
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    /* add back the modulus 2^(n*GMP_NUMB_BITS)+1; the returned value is
       then the carry out of the +1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}
666
/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.

   Slice i is additionally multiplied by the weight 2^(i*Mp) mod
   2^(nprime*GMP_NUMB_BITS)+1, and Ap[i] is set to point at its
   (nprime+1)-limb storage inside A.  */
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
		       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
		       mp_ptr T)
{
  mp_size_t i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_BALLOC_LIMBS(Kl + 1);

      if (dif > Kl)
	{
	  /* Successive Kl-limb chunks of n carry alternating signs mod
	     2^(Kl*GMP_NUMB_BITS)+1; subp tracks the sign with which the
	     next chunk must be folded in.  */
	  int subp = 0;

	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
	  n += 2 * Kl;
	  dif -= Kl;

	  /* now dif > 0 */
	  while (dif > Kl)
	    {
	      if (subp)
		cy += mpn_sub_n (tmp, tmp, n, Kl);
	      else
		cy -= mpn_add_n (tmp, tmp, n, Kl);
	      subp ^= 1;
	      n += Kl;
	      dif -= Kl;
	    }
	  /* now dif <= Kl */
	  if (subp)
	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
	  else
	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
	  /* fold the accumulated signed carry back in at the bottom */
	  if (cy >= 0)
	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
	  else
	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
	}
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
	{
	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
	}
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
	{
	  /* the last slice takes whatever limbs remain */
	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
	  nl -= j;
	  MPN_COPY (T, n, j);
	  MPN_ZERO (T + j, nprime + 1 - j);
	  n += l;
	  /* apply the weight 2^(i*Mp) */
	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
	}
      else
	MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}
749
/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.

   Ap and Bp point at the K = 2^k decomposed coefficients (each nprime+1
   limbs, stored in A resp. B) produced by mpn_mul_fft_decompose; Bp is
   ignored as input when sqr != 0.  l is the limb count per slice,
   Mp = Nprime >> k, and fft_l holds the bit-reversal tables.  */

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
		      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
		      int **fft_l, mp_ptr T, int sqr)
{
  mp_size_t K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = (mp_size_t) 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  /* Bp is recycled to hold the scaled coefficients; Bp[0] reuses the
     second half of T, Bp[i] reuses Ap[i-1]'s storage.  */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      /* coefficient i is divided by 2^k and by its weight 2^((K-i)*Mp) */
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B;			/* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0;			/* will accumulate the (signed) carry at p[pla] */
  /* Coefficient j is added at limb offset sh = l*j (slices overlap by
     nprime + 1 - l limbs); lo = sh + nprime marks where a subtracted
     modulus's top limb lands.  */
  for (i = K - 1, lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
			 pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
	{ /* subtract 2^N'+1 */
	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
	}
    }
  if (cc == -CNST_LIMB(1))
    {
      /* net borrow out of p[pla]: compensate by adding/subtracting the
	 modulus at the top of p */
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
	{
	  /* p[pla-pl]...p[pla-1] are all zero */
	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
	}
    }
  else if (cc == 1)
    {
      /* net carry out of p[pla]: fold it back in, 2*pl limbs at a time
	 when possible (2^(2*pl*GMP_NUMB_BITS) = 1 mod 2^N+1) */
      if (pla >= 2 * pl)
	{
	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
	    ;
	}
      else
	{
	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
	  ASSERT (cc == 0);
	}
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}
840
841 /* return the lcm of a and 2^k */
842 static mp_bitcnt_t
mpn_mul_fft_lcm(mp_bitcnt_t a,int k)843 mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
844 {
845 mp_bitcnt_t l = k;
846
847 while (a % 2 == 0 && k > 0)
848 {
849 a >>= 1;
850 k --;
851 }
852 return a << l;
853 }
854
855
/* {op, pl} <- {n, nl} * {m, ml} mod 2^N+1 with N = pl*GMP_NUMB_BITS,
   using an FFT of size 2^k; the high (carry) limb of the result is
   returned.  One must have pl = mpn_fft_next_size (pl, k).  A squaring
   is detected by n == m && nl == ml.  */
mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
	     mp_srcptr n, mp_size_t nl,
	     mp_srcptr m, mp_size_t ml,
	     int k)
{
  int i;
  mp_size_t K, maxLK;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l, *tmp;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  /* bit-reversal tables: fft_l[i] has 2^i entries, packed in tmp */
  fft_l = TMP_BALLOC_TYPE (k + 1, int *);
  tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
  for (i = 0; i <= k; i++)
    {
      fft_l[i] = tmp;
      tmp += (mp_size_t) 1 << i;
    }

  mpn_fft_initl (fft_l, k);
  K = (mp_size_t) 1 << k;
  M = N >> k;	/* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;	/* limbs per slice: ceil(M/GMP_NUMB_BITS) */
  maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k);	/* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
		 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2;
      for (;;)
	{
	  K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
	  if ((nprime & (K2 - 1)) == 0)
	    break;
	  nprime = (nprime + K2 - 1) & -K2;
	  Nprime = nprime * GMP_LIMB_BITS;
	  /* warning: since nprime changed, K2 may change too! */
	}
      TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
	 printf (" temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_BALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_BALLOC_MP_PTRS (K);
  /* split {n, nl} into K weighted slices of nprime+1 limbs each */
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      /* for a square B is only needed as accumulation space inside
	 mpn_mul_fft_internal, so a smaller allocation suffices */
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_BALLOC_LIMBS (pla);
      Bp = TMP_BALLOC_MP_PTRS (K);
    }
  else
    {
      B = TMP_BALLOC_LIMBS (K * (nprime + 1));
      Bp = TMP_BALLOC_MP_PTRS (K);
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}
940
941 #if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml}.
   The full product is reconstructed from two wrapped products:
   lambda = product mod 2^(pl2*GMP_NUMB_BITS)+1 and
   mu     = product mod 2^(pl3*GMP_NUMB_BITS)+1, combined below.  */
void
mpn_mul_fft_full (mp_ptr op,
		  mp_srcptr n, mp_size_t nl,
		  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  mp_size_t cc, c2, oldcc;
  int k2, k3;
  int sqr = (n == m && nl == ml);

  pl = nl + ml; /* total number of limbs of the result */

  /* perform a fft mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3. Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2. Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j>=2. Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */

  /* ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
			    thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
		 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3); /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2); /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  /* low half <- low - high, high half <- low + high, in one pass */
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1; /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_BALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op, pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp, pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessary the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
1044 #endif
1045