/* Schoenhage's fast multiplication modulo 2^N+1.

   Contributed by Paul Zimmermann.

   THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
   SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008
Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */


/* References:

   Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
   Strassen, Computing 7, p. 281-292, 1971.

   Asymptotically fast algorithms for the numerical multiplication and division
   of polynomials with complex coefficients, by Arnold Schoenhage, Computer
   Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.

   Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
   Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.

   TODO:

   Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
   Zimmermann.

   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Cleanup and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp.h"
#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/addsub_n.c"
#define HAVE_NATIVE_mpn_addsub_n 1
#endif

static void mpn_mul_fft_internal
__GMP_PROTO ((mp_ptr, mp_srcptr, mp_srcptr, mp_size_t, int, int, mp_ptr *, mp_ptr *,
              mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr,
              int));


/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   sqr==0 if for a multiply, sqr==1 for a square.
   Don't declare it static since it is needed by tuneup.
*/
#ifdef MUL_FFT_TABLE2
#define MPN_FFT_TABLE2_SIZE 512

/* FIXME: The format of this should change to need less space.
   Perhaps put n and k in the same 32-bit word, with n shifted-down
   (k-2) steps, and k using the 4-5 lowest bits.  That's possible since
   n-1 is highly divisible.
   Alternatively, separate n and k out into separate arrays.  */
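/* One possible packed form (illustration only, not used in this file),
   assuming k fits in 5 bits and the stored size value is divisible by
   2^(k-2):

     word = ((unsigned) (size >> (k - 2)) << 5) | (unsigned) k;
     k    = word & 0x1f;
     size = (mp_size_t) (word >> 5) << (k - 2);

   Whether the shift should apply to n or to n-1 depends on how table entries
   are rounded, so this is only a sketch of the idea described above.  */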
struct nk {
  mp_size_t n;
  unsigned char k;
};

static struct nk mpn_fft_table2[2][MPN_FFT_TABLE2_SIZE] =
{
  MUL_FFT_TABLE2,
  SQR_FFT_TABLE2
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  struct nk *tab;
  int last_k;

  last_k = 4;
  for (tab = mpn_fft_table2[sqr] + 1; ; tab++)
    {
      if (n < tab->n)
        break;
      last_k = tab->k;
    }
  return last_k;
}
#endif

#if !defined (MUL_FFT_TABLE2) || TUNE_PROGRAM_BUILD
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};
#endif

#if !defined (MUL_FFT_TABLE2)
int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/* Returns smallest possible number of limbs >= pl for an fft of size 2^k,
   i.e. smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}


/* Initialize l[i][j] with bitrev(j) */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
        {
          li[j] = 2 * l[i - 1][j];
          li[K + j] = 1 + li[j];
        }
    }
}

/* Shift {up, n} cnt bits to the left, store the complemented result
   in {rp, n}, and output the shifted bits (not complemented).
   Same as:
     cc = mpn_lshift (rp, up, n, cnt);
     mpn_com_n (rp, rp, n);
     return cc;

   Assumes n >= 1, 1 < cnt < GMP_NUMB_BITS, rp >= up.
*/
#ifndef HAVE_NATIVE_mpn_lshiftc
#undef mpn_lshiftc
static mp_limb_t
mpn_lshiftc (mp_ptr rp, mp_srcptr up, mp_size_t n, unsigned int cnt)
{
  mp_limb_t high_limb, low_limb;
  unsigned int tnc;
  mp_size_t i;
  mp_limb_t retval;

  up += n;
  rp += n;

  tnc = GMP_NUMB_BITS - cnt;
  low_limb = *--up;
  retval = low_limb >> tnc;
  high_limb = (low_limb << cnt);

  for (i = n - 1; i != 0; i--)
    {
      low_limb = *--up;
      *--rp = (~(high_limb | (low_limb >> tnc))) & GMP_NUMB_MASK;
      high_limb = low_limb << cnt;
    }
  *--rp = (~high_limb) & GMP_NUMB_MASK;

  return retval;
}
#endif

/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.  */
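/* Since 2^(n*GMP_NUMB_BITS) = -1 mod 2^(n*GMP_NUMB_BITS)+1, a shift by
   d >= n*GMP_NUMB_BITS bits amounts to a shift by d - n*GMP_NUMB_BITS bits
   followed by a negation; the "negate" branch below builds the negated value
   directly from the complemented shift (mpn_lshiftc) plus increments.  */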
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
{
  int sh, negate;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  d /= GMP_NUMB_BITS;
  negate = d >= n;
  if (negate)
    d -= n;

  if (negate)
    {
      /* r[0..d-1]  <-- lshift(a[n-d]..a[n-1], sh)
         r[d..n-1]  <-- -lshift(a[0]..a[n-d-1], sh) */
      if (sh != 0)
        {
          /* no out shift below since a[n] <= 1 */
          mpn_lshift (r, a + n - d, d + 1, sh);
          rd = r[d];
          cc = mpn_lshiftc (r + d, a, n - d, sh);
        }
      else
        {
          MPN_COPY (r, a + n - d, d);
          rd = a[n];
          mpn_com_n (r + d, a, n - d);
          cc = 0;
        }

      /* add cc to r[0], and add rd to r[d] */

      /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + d + (rd == 0);
      mpn_incr_u (r, cc);

      return;
    }

  /* if negate=0,
     r[0..d-1]  <-- -lshift(a[n-d]..a[n-1], sh)
     r[d..n-1]  <-- lshift(a[0]..a[n-d-1], sh)
  */
  if (sh != 0)
    {
      /* no out bits below since a[n] <= 1 */
      mpn_lshiftc (r, a + n - d, d + 1, sh);
      rd = ~r[d];
      /* {r, d+1} = {a+n-d, d+1} << sh */
      cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
    }
  else
    {
      /* r[d] is not used below, but we save a test for d=0 */
      mpn_com_n (r, a + n - d, d + 1);
      rd = a[n];
      MPN_COPY (r + d, a, n - d);
      cc = 0;
    }

  /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */

  /* if d=0 we just have r[0]=a[n] << sh */
  if (d != 0)
    {
      /* now add 1 in r[0], subtract 1 in r[d] */
      if (cc-- == 0) /* then add 1 to r[0] */
        cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
      cc = mpn_sub_1 (r, r, d, cc) + 1;
      /* add 1 to cc instead of rd since rd might overflow */
    }

  /* now subtract cc and rd from r[d..n] */

  r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
  r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
  if (r[n] & GMP_LIMB_HIGHBIT)
    r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
}


/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
          N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.  */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
             mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_addsub_n
      cy = mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      int j;
      int *lk = *ll;

      mpn_fft_fft (Ap, K >> 1, ll - 1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap + inc, K >> 1, ll - 1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
         A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
        {
          /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
             Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
          mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
        }
    }
}

/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
        {
          /* This happens with very low probability; we have yet to trigger it,
             and thereby make sure this code is correct.  */
          MPN_ZERO (ap, n);
          ap[n] = 1;
        }
      else
        ap[n] = 0;
    }
}

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int **_fft_l;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = 1 << k;
      ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
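      /* Rounding Nprime2 up to a multiple of maxLK = max(2^k, GMP_NUMB_BITS)
         keeps it divisible both by 2^k (so each of the 2^k transform pieces
         gets a whole number of bits) and by GMP_NUMB_BITS (so nprime2 is a
         whole number of limbs), assuming GMP_NUMB_BITS is a power of two
         here.  */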
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK */
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
        {
          unsigned long K3;
          for (;;)
            {
              K3 = 1L << mpn_fft_best_k (nprime2, sqr);
              if ((nprime2 & (K3 - 1)) == 0)
                break;
              nprime2 = (nprime2 + K3 - 1) & -K3;
              Nprime2 = nprime2 * GMP_LIMB_BITS;
              /* warning: since nprime2 changed, K3 may change too! */
            }
        }
      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_ALLOC_MP_PTRS (K2);
      Bp = TMP_ALLOC_MP_PTRS (K2);
      A = TMP_ALLOC_LIMBS (2 * K2 * (nprime2 + 1));
      T = TMP_ALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + K2 * (nprime2 + 1);
      _fft_l = TMP_ALLOC_TYPE (k + 1, int *);
      for (i = 0; i <= k; i++)
        _fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
      mpn_fft_initl (_fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
                     n, K2, nprime2, nprime2, 2.0 * (double) n / nprime2 / K2));
      for (i = 0; i < K; i++, ap++, bp++)
        {
          mpn_fft_normalize (*ap, n);
          if (!sqr)
            mpn_fft_normalize (*bp, n);
          mpn_mul_fft_internal (*ap, *ap, *bp, n, k, K2, Ap, Bp, A, B, nprime2,
                                l, Mp2, _fft_l, T, 1);
        }
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      int n2 = 2 * n;
      tp = TMP_ALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf (" mpn_mul_n %d of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
        {
          a = *ap++;
          b = *bp++;
          if (sqr)
            mpn_sqr_n (tp, a, n);
          else
            mpn_mul_n (tp, b, a, n);
          if (a[n] != 0)
            cc = mpn_add_n (tpn, tpn, b, n);
          else
            cc = 0;
          if (b[n] != 0)
            cc += mpn_add_n (tpn, tpn, a, n) + a[n];
          if (cc != 0)
            {
              /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
              cc = mpn_add_1 (tp, tp, n2, cc);
              ASSERT (cc == 0);
            }
          a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
        }
    }
  TMP_FREE;
}


/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.  */
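/* (The factor K = 2^k introduced by the inverse transform is removed later
   by mpn_fft_div_2exp_modF: dividing by 2^k modulo 2^N+1 is just a
   multiplication by 2^(2*N-k).)  */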
static void
mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_addsub_n
      cy = mpn_addsub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
        Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      int j, K2 = K >> 1;

      mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
         A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
        {
          /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
             Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
          mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
        }
    }
}


/* A <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
{
  int i;

  ASSERT (r != a);
  i = 2 * n * GMP_NUMB_BITS;
  i = (i - k) % i; /* FIXME: This % looks superfluous */
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1, where L = GMP_NUMB_BITS */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}


/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   then {rp,n}=0.
*/
static int
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l;
  long int m;
  mp_limb_t cc;
  int rpn;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m} */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.  */
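/* For instance (illustrative numbers only), with K = 4 pieces of l = 2 limbs
   each, an 8-limb input is cut into 4 chunks of 2 limbs; chunk i is zero
   padded to nprime+1 limbs and multiplied by 2^(i*Mp), the weight that turns
   the length-K cyclic transform into the negacyclic product needed modulo
   2^N+1.  */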
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
                       mp_size_t nl, int l, int Mp, mp_ptr T)
{
  int i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_ALLOC_LIMBS(Kl + 1);

      if (dif > Kl)
        {
          int subp = 0;

          cy = mpn_sub_n (tmp, n, n + Kl, Kl);
          n += 2 * Kl;
          dif -= Kl;

          /* now dif > 0 */
          while (dif > Kl)
            {
              if (subp)
                cy += mpn_sub_n (tmp, tmp, n, Kl);
              else
                cy -= mpn_add_n (tmp, tmp, n, Kl);
              subp ^= 1;
              n += Kl;
              dif -= Kl;
            }
          /* now dif <= Kl */
          if (subp)
            cy += mpn_sub (tmp, tmp, Kl, n, dif);
          else
            cy -= mpn_add (tmp, tmp, Kl, n, dif);
          if (cy >= 0)
            cy = mpn_add_1 (tmp, tmp, Kl, cy);
          else
            cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
        }
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
        {
          cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
          cy = mpn_add_1 (tmp, tmp, Kl, cy);
        }
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
        {
          j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
          nl -= j;
          MPN_COPY (T, n, j);
          MPN_ZERO (T + j, nprime + 1 - j);
          n += l;
          mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
        }
      else
        MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}

/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   n and m have respectively nl and ml limbs
   op must have space for pl+1 limbs if rec=1 (and pl limbs if rec=0).
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.

   If rec=0, then only the pl low limbs of the result are stored.
*/

static void
mpn_mul_fft_internal (mp_ptr op, mp_srcptr n, mp_srcptr m, mp_size_t pl,
                      int k, int K,
                      mp_ptr *Ap, mp_ptr *Bp,
                      mp_ptr A, mp_ptr B,
                      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
                      int **_fft_l,
                      mp_ptr T, int rec)
{
  int i, sqr, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  sqr = n == m;

  TRACE (printf ("pl=%ld k=%d K=%d np=%ld l=%ld Mp=%ld rec=%d sqr=%d\n",
                 pl, k, K, nprime, l, Mp, rec, sqr));

  /* decomposition of inputs into arrays Ap[i] and Bp[i] */
  if (rec)
    {
      mpn_mul_fft_decompose (A, Ap, K, nprime, n, K * l + 1, l, Mp, T);
      if (!sqr)
        mpn_mul_fft_decompose (B, Bp, K, nprime, m, K * l + 1, l, Mp, T);
    }

  /* direct fft's */
  mpn_fft_fft (Ap, K, _fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, _fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, (sqr) ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime, sh = l * i; i >= 0; i--, lo -= l, sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
        cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
                         pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
        { /* subtract 2^N'+1 */
          cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
          cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
        }
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
        {
          /* p[pla-pl]...p[pla-1] are all zero */
          mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
          mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
        }
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
        {
          while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
            ;
        }
      else
        {
          cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
          ASSERT (cc == 0);
        }
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  i = mpn_fft_norm_modF (op, pl, p, pla);
  if (rec) /* store the carry out */
    op[pl] = i;
}

/* return the lcm of a and 2^k */
static unsigned long int
mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
{
  unsigned long int l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k--;
    }
  return a << l;
}

void
mpn_mul_fft (mp_ptr op, mp_size_t pl,
             mp_srcptr n, mp_size_t nl,
             mp_srcptr m, mp_size_t ml,
             int k)
{
  int K, maxLK, i;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **_fft_l;
  int sqr = (n == m && nl == ml);
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  _fft_l = TMP_ALLOC_TYPE (k + 1, int *);
  for (i = 0; i <= k; i++)
    _fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
  mpn_fft_initl (_fft_l, k);
  K = 1 << k;
  M = N >> k; /* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
                 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      unsigned long K2;
      for (;;)
        {
          K2 = 1L << mpn_fft_best_k (nprime, sqr);
          if ((nprime & (K2 - 1)) == 0)
            break;
          nprime = (nprime + K2 - 1) & -K2;
          Nprime = nprime * GMP_LIMB_BITS;
          /* warning: since nprime changed, K2 may change too! */
        }
      TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_ALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
                 pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
         printf (" temp space %ld\n", 2 * K * (nprime + 1)));

  A = __GMP_ALLOCATE_FUNC_LIMBS (2 * K * (nprime + 1));
  B = A + K * (nprime + 1);
  Ap = TMP_ALLOC_MP_PTRS (K);
  Bp = TMP_ALLOC_MP_PTRS (K);

  /* special decomposition for main call */
  /* nl is the number of significant limbs in n */
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (n != m)
    mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);

  mpn_mul_fft_internal (op, n, m, pl, k, K, Ap, Bp, A, B, nprime, l, Mp, _fft_l, T, 0);

  TMP_FREE;
  __GMP_FREE_FUNC_LIMBS (A, 2 * K * (nprime + 1));
}

/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
                  mp_srcptr n, mp_size_t nl,
                  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  int k2, k3;
  int sqr = (n == m && nl == ml);
  int cc, c2, oldcc;

  pl = nl + ml; /* total number of limbs of the result */

  /* perform one fft mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3.  Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2.  Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j>=2.  Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K.  */
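  /* The recombination further below follows the CRT.  Let
     F2 = 2^(pl2*GMP_NUMB_BITS)+1 and F3 = 2^(pl3*GMP_NUMB_BITS)+1; the exact
     product P satisfies P = lambda mod F2 and P = mu mod F3, and
     P < 2^(pl*GMP_NUMB_BITS) <= 2^((pl2+pl3)*GMP_NUMB_BITS) < F2*F3, so P is
     determined: writing P = mu + t*F3 with t = (lambda-mu)/F3 mod F2, and
     noting that 2^(pl3*GMP_NUMB_BITS) = -2^(l*GMP_NUMB_BITS) mod F2 (with
     l = pl3-pl2), we get t = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS)) mod F2.
     The code computes t in {pad_op, pl2} and then adds t*F3 to mu.  */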

  /* ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
                            thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
                 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  mpn_mul_fft (op, pl3, n, nl, m, ml, k3); /* mu */
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = mpn_sub_n (pad_op, pad_op, op, pl2); /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc; /* -1 <= cc <= 1 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* 0 <= cc <= 1 */
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_addsub_n
  c2 = mpn_addsub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  /* c2 & 1 is the borrow, c2 & 2 is the carry */
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1; /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_ALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op, pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp, pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op[pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
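
/* Illustrative usage (a sketch, not part of GMP's documented interface;
   callers normally reach this code through higher-level mpn multiplication
   routines).  To use mpn_mul_fft directly, pl must first be rounded to a
   valid transform length:

     k  = mpn_fft_best_k (pl, np == mp);
     pl = mpn_fft_next_size (pl, k);
     mpn_mul_fft (op, pl, np, nl, mp, ml, k);
     -- {op, pl} now holds {np, nl} * {mp, ml} mod 2^(pl*GMP_NUMB_BITS)+1

   while mpn_mul_fft_full (op, np, nl, mp, ml) stores the plain product in
   {op, nl+ml}.  */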