/* Schoenhage's fast multiplication modulo 2^N+1.

   Contributed by Paul Zimmermann.

   THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
   SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008
Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */


/* References:

   Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
   Strassen, Computing 7, p. 281-292, 1971.

   Asymptotically fast algorithms for the numerical multiplication and
   division of polynomials with complex coefficients, by Arnold Schoenhage,
   Computer Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.

   Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
   Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.

   TODO:

   Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa,
   and Zimmermann.

   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Cleanup and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp.h"
#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/addsub_n.c"
#define HAVE_NATIVE_mpn_addsub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal
__GMP_PROTO ((mp_ptr, mp_srcptr, mp_srcptr, mp_size_t, int, int, mp_ptr *,
	      mp_ptr *, mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t,
	      int **, mp_ptr, int));


/* Find the best k to use for an FFT mod 2^(m*GMP_NUMB_BITS)+1 with m >= n.
   sqr==0 for a multiply, sqr==1 for a square.
   Don't declare it static since it is needed by tuneup.
*/
#ifdef MUL_FFT_TABLE2

#if defined (MUL_FFT_TABLE2_SIZE) && defined (SQR_FFT_TABLE2_SIZE)
#if MUL_FFT_TABLE2_SIZE > SQR_FFT_TABLE2_SIZE
#define FFT_TABLE2_SIZE MUL_FFT_TABLE2_SIZE
#else
#define FFT_TABLE2_SIZE SQR_FFT_TABLE2_SIZE
#endif
#endif

#ifndef FFT_TABLE2_SIZE
#define FFT_TABLE2_SIZE 200
#endif

/* FIXME: The format of this should change to need less space.
   Perhaps put n and k in the same 32-bit word, with n shifted down
   (k-2) steps, and k using the 4-5 lowest bits.  That's possible since
   n-1 is highly divisible.
   Alternatively, separate n and k out into separate arrays.  */
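/* An illustrative sketch of the packed format suggested above (not what
   the current code uses, and assuming the divisibility claim means n's
   low k-2 bits are zero, so nothing is lost in the shift):

     word = ((n >> (k - 2)) << 5) | k;              (pack)
     k = word & 0x1f;  n = (word >> 5) << (k - 2);  (unpack)
*/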
struct nk {
  unsigned int n:27;
  unsigned int k:5;
};

static struct nk mpn_fft_table2[2][FFT_TABLE2_SIZE] =
{
  MUL_FFT_TABLE2,
  SQR_FFT_TABLE2
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  struct nk *tab;
  int last_k;

  last_k = 4;
  for (tab = mpn_fft_table2[sqr] + 1; ; tab++)
    {
      if (n < tab->n)
	break;
      last_k = tab->k;
    }
  return last_k;
}
#endif

#if !defined (MUL_FFT_TABLE2) || TUNE_PROGRAM_BUILD
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};
#endif

#if !defined (MUL_FFT_TABLE2)
int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/* Returns smallest possible number of limbs >= pl for an FFT of size 2^k,
   i.e. smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}


/* Initialize l[i][j] with bitrev(j) */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
	{
	  li[j] = 2 * l[i - 1][j];
	  li[K + j] = 1 + li[j];
	}
    }
}

/* Shift {up, n} cnt bits to the left, store the complemented result
   in {rp, n}, and return the bits shifted out (not complemented).
   Same as:
     cc = mpn_lshift (rp, up, n, cnt);
     mpn_com_n (rp, rp, n);
     return cc;

   Assumes n >= 1, 1 < cnt < GMP_NUMB_BITS, rp >= up.
*/
#ifndef HAVE_NATIVE_mpn_lshiftc
#undef mpn_lshiftc
static mp_limb_t
mpn_lshiftc (mp_ptr rp, mp_srcptr up, mp_size_t n, unsigned int cnt)
{
  mp_limb_t high_limb, low_limb;
  unsigned int tnc;
  mp_size_t i;
  mp_limb_t retval;

  up += n;
  rp += n;

  tnc = GMP_NUMB_BITS - cnt;
  low_limb = *--up;
  retval = low_limb >> tnc;
  high_limb = (low_limb << cnt);

  for (i = n - 1; i != 0; i--)
    {
      low_limb = *--up;
      *--rp = (~(high_limb | (low_limb >> tnc))) & GMP_NUMB_MASK;
      high_limb = low_limb << cnt;
    }
  *--rp = (~high_limb) & GMP_NUMB_MASK;

  return retval;
}
#endif
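/* Worked example for mpn_lshiftc (illustrative, using 8-bit limbs for
   brevity): up = {0xab}, n = 1, cnt = 4 stores
   rp[0] = ~(0xab << 4) & 0xff = 0x4f and returns 0xab >> 4 = 0x0a,
   i.e. the bits shifted out on the left. */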
/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
*/
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
{
  int sh, negate;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  d /= GMP_NUMB_BITS;
  negate = d >= n;
  if (negate)
    d -= n;

  if (negate)
    {
      /* r[0..d-1]  <-- lshift(a[n-d]..a[n-1], sh)
	 r[d..n-1]  <-- -lshift(a[0]..a[n-d-1], sh) */
      if (sh != 0)
	{
	  /* no out shift below since a[n] <= 1 */
	  mpn_lshift (r, a + n - d, d + 1, sh);
	  rd = r[d];
	  cc = mpn_lshiftc (r + d, a, n - d, sh);
	}
      else
	{
	  MPN_COPY (r, a + n - d, d);
	  rd = a[n];
	  mpn_com_n (r + d, a, n - d);
	  cc = 0;
	}

      /* add cc to r[0], and add rd to r[d] */

      /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + d + (rd == 0);
      mpn_incr_u (r, cc);

      return;
    }

  /* if negate=0,
     r[0..d-1]  <-- -lshift(a[n-d]..a[n-1], sh)
     r[d..n-1]  <-- lshift(a[0]..a[n-d-1], sh)
  */
  if (sh != 0)
    {
      /* no out bits below since a[n] <= 1 */
      mpn_lshiftc (r, a + n - d, d + 1, sh);
      rd = ~r[d];
      /* {r, d+1} = {a+n-d, d+1} << sh */
      cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
    }
  else
    {
      /* r[d] is not used below, but we save a test for d=0 */
      mpn_com_n (r, a + n - d, d + 1);
      rd = a[n];
      MPN_COPY (r + d, a, n - d);
      cc = 0;
    }

  /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */

  /* if d=0 we just have r[0]=a[n] << sh */
  if (d != 0)
    {
      /* now add 1 in r[0], subtract 1 in r[d] */
      if (cc-- == 0) /* then add 1 to r[0] */
	cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
      cc = mpn_sub_1 (r, r, d, cc) + 1;
      /* add 1 to cc instead of rd since rd might overflow */
    }

  /* now subtract cc and rd from r[d..n] */

  r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
  r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
  if (r[n] & GMP_LIMB_HIGHBIT)
    r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
}


/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}
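/* Trace of the branchless reduction above (illustrative): since
   2^(n*GMP_NUMB_BITS) == -1, a top carry of c can be lowered to c - x by
   also subtracting x from the low limbs (i.e. subtracting x times the
   modulus), for any x <= c:
     c = 0:  x = 0, r[n] = 0
     c = 1:  x = 0, r[n] = 1
     c = 2:  x = 1, r[n] = 1, MPN_DECR_U subtracts 1
     c = 3:  x = 2, r[n] = 1, MPN_DECR_U subtracts 2
   leaving the result semi-normalized (any borrow from the low limbs
   propagates into r[n]). */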
/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.
*/
static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_addsub_n
      cy = mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      int j;
      int *lk = *ll;

      mpn_fft_fft (Ap, K >> 1, ll - 1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap + inc, K >> 1, ll - 1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
	{
	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
	}
    }
}

/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then the decrement produces
   a borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
	{
	  /* This happens with very low probability; we have yet to trigger it,
	     and thereby make sure this code is correct.  */
	  MPN_ZERO (ap, n);
	  ap[n] = 1;
	}
      else
	ap[n] = 0;
    }
}
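/* Example (illustrative): with n = 1 and 64-bit limbs, residues mod
   2^64+1 occupy two limbs.  ap = {5, 1} represents 2^64 + 5, and
   normalizing subtracts the modulus once, leaving {4, 0}.  The rare
   case is ap = {0, 1} = 2^64, which is already reduced: the decrement
   borrows into ap[n], and the original value is restored. */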
/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int **_fft_l;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = 1 << k;
      ASSERT_ALWAYS ((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK */
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
	{
	  unsigned long K3;
	  for (;;)
	    {
	      K3 = 1L << mpn_fft_best_k (nprime2, sqr);
	      if ((nprime2 & (K3 - 1)) == 0)
		break;
	      nprime2 = (nprime2 + K3 - 1) & -K3;
	      Nprime2 = nprime2 * GMP_LIMB_BITS;
	      /* warning: since nprime2 changed, K3 may change too! */
	    }
	}
      ASSERT_ALWAYS (nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_ALLOC_MP_PTRS (K2);
      Bp = TMP_ALLOC_MP_PTRS (K2);
      A = TMP_ALLOC_LIMBS (2 * K2 * (nprime2 + 1));
      T = TMP_ALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + K2 * (nprime2 + 1);
      _fft_l = TMP_ALLOC_TYPE (k + 1, int *);
      for (i = 0; i <= k; i++)
	_fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
      mpn_fft_initl (_fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
		     n, K2, nprime2, nprime2, 2.0 * (double) n / nprime2 / K2));
      for (i = 0; i < K; i++, ap++, bp++)
	{
	  mpn_fft_normalize (*ap, n);
	  if (!sqr)
	    mpn_fft_normalize (*bp, n);
	  mpn_mul_fft_internal (*ap, *ap, *bp, n, k, K2, Ap, Bp, A, B, nprime2,
				l, Mp2, _fft_l, T, 1);
	}
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      int n2 = 2 * n;
      tp = TMP_ALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf ("  mpn_mul_n %d of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
	{
	  a = *ap++;
	  b = *bp++;
	  if (sqr)
	    mpn_sqr_n (tp, a, n);
	  else
	    mpn_mul_n (tp, b, a, n);
	  if (a[n] != 0)
	    cc = mpn_add_n (tpn, tpn, b, n);
	  else
	    cc = 0;
	  if (b[n] != 0)
	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
	  if (cc != 0)
	    {
	      /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
	      cc = mpn_add_1 (tp, tp, n2, cc);
	      ASSERT (cc == 0);
	    }
	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
	}
    }
  TMP_FREE;
}
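/* The basecase above relies on the following identity (illustrative):
   if a = a1*2^(n*GMP_NUMB_BITS) + a0 with 0 <= a0, a1 < 2^(n*GMP_NUMB_BITS),
   then a mod (2^(n*GMP_NUMB_BITS)+1) = a0 - a1, made non-negative by
   adding the modulus when a0 < a1.  That is what the final
   mpn_sub_n/mpn_add_1 pair computes from the 2n-limb product {tp, 2n}. */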
/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_addsub_n
      cy = mpn_addsub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      int j, K2 = K >> 1;

      mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
	{
	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
	}
    }
}


/* r <- a/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
{
  int i;

  ASSERT (r != a);
  i = 2 * n * GMP_NUMB_BITS;
  i = (i - k) % i; /* FIXME: This % looks superfluous */
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 where L = GMP_NUMB_BITS */
  /* normalize so that r < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}


/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   in which case {rp,n} = 0.
*/
static int
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l;
  long int m;
  mp_limb_t cc;
  int rpn;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m} */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}
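/* Illustrative check of the reduction above: writing {ap, 3n} as
   c2*B^2 + c1*B + c0 with B = 2^(n*GMP_NUMB_BITS), and using B == -1
   (mod B+1), the residue is c0 - c1 + c2, which is exactly the
   add-then-subtract sequence performed when m > 0. */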
/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
		       mp_size_t nl, int l, int Mp, mp_ptr T)
{
  int i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_ALLOC_LIMBS (Kl + 1);

      if (dif > Kl)
	{
	  int subp = 0;

	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
	  n += 2 * Kl;
	  dif -= Kl;

	  /* now dif > 0 */
	  while (dif > Kl)
	    {
	      if (subp)
		cy += mpn_sub_n (tmp, tmp, n, Kl);
	      else
		cy -= mpn_add_n (tmp, tmp, n, Kl);
	      subp ^= 1;
	      n += Kl;
	      dif -= Kl;
	    }
	  /* now dif <= Kl */
	  if (subp)
	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
	  else
	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
	  if (cy >= 0)
	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
	  else
	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
	}
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
	{
	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
	}
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
	{
	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
	  nl -= j;
	  MPN_COPY (T, n, j);
	  MPN_ZERO (T + j, nprime + 1 - j);
	  n += l;
	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
	}
      else
	MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}
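/* A sketch of the weighting performed above (illustrative): chunk i of
   l limbs is placed in its own (nprime+1)-limb residue and multiplied
   by theta^i with theta = 2^Mp.  Since theta^K = 2^Nprime == -1 mod
   2^Nprime+1, the wrap-around of the length-K cyclic convolution picks
   up exactly the sign needed for a product mod 2^(Kl*GMP_NUMB_BITS)+1
   (a negacyclic convolution). */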
/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   n and m have respectively nl and ml limbs
   op must have space for pl+1 limbs if rec=1 (and pl limbs if rec=0).
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.

   If rec=0, then store only the pl low limbs of the result, and return
   the carry out.
*/
static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_srcptr n, mp_srcptr m, mp_size_t pl,
		      int k, int K,
		      mp_ptr *Ap, mp_ptr *Bp,
		      mp_ptr A, mp_ptr B,
		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
		      int **_fft_l,
		      mp_ptr T, int rec)
{
  int i, sqr, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  sqr = n == m;

  TRACE (printf ("pl=%ld k=%d K=%d np=%ld l=%ld Mp=%ld rec=%d sqr=%d\n",
		 pl, k, K, nprime, l, Mp, rec, sqr));

  /* decomposition of inputs into arrays Ap[i] and Bp[i] */
  if (rec)
    {
      mpn_mul_fft_decompose (A, Ap, K, nprime, n, K * l + 1, l, Mp, T);
      if (!sqr)
	mpn_mul_fft_decompose (B, Bp, K, nprime, m, K * l + 1, l, Mp, T);
    }

  /* direct fft's */
  mpn_fft_fft (Ap, K, _fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, _fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(nprime+1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime, sh = l * i; i >= 0; i--, lo -= l, sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
			 pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
	{ /* subtract 2^N'+1 */
	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
	}
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
	{
	  /* p[pla-pl]...p[pla-1] are all zero */
	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
	}
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
	{
	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
	    ;
	}
      else
	{
	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
	  ASSERT (cc == 0);
	}
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  i = mpn_fft_norm_modF (op, pl, p, pla);
  if (rec) /* store the carry out */
    op[pl] = i;

  return i;
}

/* return the lcm of a and 2^k */
static unsigned long int
mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
{
  unsigned long int l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k--;
    }
  return a << l;
}
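/* For instance, mpn_mul_fft_lcm (64, 7) computes lcm (64, 2^7): the loop
   strips min(v2(64), 7) = 6 factors of two, leaving a = 1, and returns
   1 << 7 = 128, which is indeed lcm (64, 128). */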
mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
	     mp_srcptr n, mp_size_t nl,
	     mp_srcptr m, mp_size_t ml,
	     int k)
{
  int K, maxLK, i;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **_fft_l;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  _fft_l = TMP_ALLOC_TYPE (k + 1, int *);
  for (i = 0; i <= k; i++)
    _fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
  mpn_fft_initl (_fft_l, k);
  K = 1 << k;
  M = N >> k; /* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
		 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      unsigned long K2;
      for (;;)
	{
	  K2 = 1L << mpn_fft_best_k (nprime, sqr);
	  if ((nprime & (K2 - 1)) == 0)
	    break;
	  nprime = (nprime + K2 - 1) & -K2;
	  Nprime = nprime * GMP_LIMB_BITS;
	  /* warning: since nprime changed, K2 may change too! */
	}
      TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_ALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
		 pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));

  A = __GMP_ALLOCATE_FUNC_LIMBS (2 * K * (nprime + 1));
  B = A + K * (nprime + 1);
  Ap = TMP_ALLOC_MP_PTRS (K);
  Bp = TMP_ALLOC_MP_PTRS (K);

  /* special decomposition for main call */
  /* nl is the number of significant limbs in n */
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (n != m)
    mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);

  h = mpn_mul_fft_internal (op, n, m, pl, k, K, Ap, Bp, A, B, nprime, l, Mp,
			    _fft_l, T, 0);

  TMP_FREE;
  __GMP_FREE_FUNC_LIMBS (A, 2 * K * (nprime + 1));

  return h;
}
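/* Illustrative use of mpn_mul_fft (a sketch, not part of the library):
   to compute {np, nl} * {mp, ml} mod 2^(pl*GMP_NUMB_BITS)+1 one would do

     int k = mpn_fft_best_k (pl, 0);
     pl = mpn_fft_next_size (pl, k);    ... make pl a multiple of 2^k
     cc = mpn_mul_fft (op, pl, np, nl, mp, ml, k);

   with {op, pl} receiving the low limbs and cc the carry out, i.e. the
   residue is {op, pl} + cc*2^(pl*GMP_NUMB_BITS). */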
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
		  mp_srcptr n, mp_size_t nl,
		  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  int k2, k3;
  int sqr = (n == m && nl == ml);
  int cc, c2, oldcc;

  pl = nl + ml; /* total number of limbs of the result */

  /* perform one fft mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3.  Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2.  Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j >= 2.  Thus this scheme requires
     pl >= 6 * 2^FFT_FIRST_K.  */

  /* ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k2 >= FFT_FIRST_K = 4, pl2 is a multiple
			    of 2^4, thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
		 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS (pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
  ASSERT_ALWAYS (cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);   /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT_ALWAYS (0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT_ALWAYS (-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT_ALWAYS (0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_addsub_n
  c2 = mpn_addsub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  /* c2 & 1 is the borrow, c2 & 2 is the carry */
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1;   /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_ALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op, pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp, pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op[pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
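/* A sketch of the reconstruction above (illustrative, with
   B = GMP_NUMB_BITS): write xy = mu + (2^(pl3*B)+1)*t where
   mu = xy mod 2^(pl3*B)+1.  Reducing mod 2^(pl2*B)+1, where
   2^(pl3*B) == -2^(l*B) with l = pl3 - pl2, gives
   lambda - mu == t*(1 - 2^(l*B)), so t = (lambda-mu)/(1-2^(l*B)) as
   computed in {pad_op, pl2}; the final additions then assemble
   mu + t + t*2^(pl3*B) into {op, pl}. */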