1 /* Schoenhage's fast multiplication modulo 2^N+1. 2 3 Contributed by Paul Zimmermann. 4 5 THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES. IT IS ONLY 6 SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES. IN FACT, IT IS ALMOST 7 GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE. 8 9 Copyright 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 10 2009, 2010 Free Software Foundation, Inc. 11 12 This file is part of the GNU MP Library. 13 14 The GNU MP Library is free software; you can redistribute it and/or modify 15 it under the terms of the GNU Lesser General Public License as published by 16 the Free Software Foundation; either version 3 of the License, or (at your 17 option) any later version. 18 19 The GNU MP Library is distributed in the hope that it will be useful, but 20 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 21 or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public 22 License for more details. 23 24 You should have received a copy of the GNU Lesser General Public License 25 along with the GNU MP Library. If not, see http://www.gnu.org/licenses/. */ 26 27 28 /* References: 29 30 Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker 31 Strassen, Computing 7, p. 281-292, 1971. 32 33 Asymptotically fast algorithms for the numerical multiplication and division 34 of polynomials with complex coefficients, by Arnold Schoenhage, Computer 35 Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982. 36 37 Tapes versus Pointers, a study in implementing fast algorithms, by Arnold 38 Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986. 39 40 TODO: 41 42 Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and 43 Zimmermann. 44 45 It might be possible to avoid a small number of MPN_COPYs by using a 46 rotating temporary or two. 47 48 Cleanup and simplify the code! 
*/

/* TRACE(x) expands to x (plus stdio) when the file is compiled with -DTRACE,
   and to nothing otherwise; used for ad-hoc debugging printouts below.  */
#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp.h"
#include "gmp-impl.h"

/* Optionally pull in the generic combined add/sub primitive and advertise it
   as a native mpn_add_n_sub_n, which the butterflies below can exploit.  */
#ifdef WANT_ADDSUB
#include "generic/add_n_sub_n.c"
#define HAVE_NATIVE_mpn_add_n_sub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal
__GMP_PROTO ((mp_ptr, mp_size_t, int, mp_ptr *, mp_ptr *,
	      mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr, int));
static void mpn_mul_fft_decompose
__GMP_PROTO ((mp_ptr, mp_ptr *, int, int, mp_srcptr, mp_size_t, int, int, mp_ptr));


/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h is not updated.  */


/*****************************************************************************/

#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE		/* When tuning, this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

/* Generation-3 tuning tables.  Row 0 is for multiplication, row 1 for
   squaring; each entry pairs a size threshold n with the k to use from
   that threshold on (see struct fft_table_nk in gmp-impl.h).  */
FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

/* Return the best FFT size parameter k for an n-limb operation, scanning the
   generation-3 table: keep the k recorded before the first entry whose
   threshold (tab->n scaled by 2^last_k) exceeds n.  */
int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  FFT_TABLE_ATTRS struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
	break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif

/*****************************************************************************/

/* Fallback (older-generation) tuning path, used when no generation-3 tables
   are available for this machine.  */
#if ! defined (MPN_FFT_BEST_READY)
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};

/* Return the best FFT size parameter k for an n-limb operation, from the
   old-style threshold tables: the i-th (zero-terminated) entry is the size
   from which k = i + FFT_FIRST_K should be used.  */
int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/*****************************************************************************/


/* Returns smallest possible number of limbs >= pl for a fft of size 2^k,
   i.e. smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup. */

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}


/* Initialize l[i][j] with bitrev(j) for 0 <= j < 2^i, 0 <= i <= k.
   Row i is built from row i-1: bitrev over i bits of j is twice the
   (i-1)-bit bitrev of j's low bits, plus the top bit moved to position 0.  */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
	{
	  li[j] = 2 * l[i - 1][j];
	  li[K + j] = 1 + li[j];
	}
    }
}


/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
   Multiplication by a power of two is a limb/bit rotation with negation of
   the wrapped-around part, since 2^(n*GMP_NUMB_BITS) = -1 mod 2^(n*GMP_NUMB_BITS)+1.  */
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
{
  int sh;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;	/* bit part of the shift */
  d /= GMP_NUMB_BITS;		/* limb part of the shift */

  if (d >= n)			/* negate */
    {
      /* r[0..d-1]  <--  lshift(a[n-d]..a[n-1], sh)
	 r[d..n-1]  <-- -lshift(a[0]..a[n-d-1], sh) */

      d -= n;
      if (sh != 0)
	{
	  /* no out shift below since a[n] <= 1 */
	  mpn_lshift (r, a + n - d, d + 1, sh);
	  rd = r[d];
	  cc = mpn_lshiftc (r + d, a, n - d, sh);
	}
      else
	{
	  MPN_COPY (r, a + n - d, d);
	  rd = a[n];
	  mpn_com (r + d, a, n - d);
	  cc = 0;
	}

      /* add cc to r[0], and add rd to r[d] */

      /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + d + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..d-1]  <-- -lshift(a[n-d]..a[n-1], sh)
	 r[d..n-1]  <--  lshift(a[0]..a[n-d-1], sh) */
      if (sh != 0)
	{
	  /* no out bits below since a[n] <= 1 */
	  mpn_lshiftc (r, a + n - d, d + 1, sh);
	  rd = ~r[d];
	  /* {r, d+1} = {a+n-d, d+1} << sh */
	  cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
	}
      else
	{
	  /* r[d] is not used below, but we save a test for d=0 */
	  mpn_com (r, a + n - d, d + 1);
	  rd = a[n];
	  MPN_COPY (r + d, a, n - d);
	  cc = 0;
	}

      /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */

      /* if d=0 we just have r[0]=a[n] << sh */
      if (d != 0)
	{
	  /* now add 1 in r[0], subtract 1 in r[d] */
	  if (cc-- == 0) /* then add 1 to r[0] */
	    cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
	  cc = mpn_sub_1 (r, r, d, cc) + 1;
	  /* add 1 to cc instead of rd since rd might overflow */
	}

      /* now subtract cc and rd from r[d..n] */

      r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
      r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
      if (r[n] & GMP_LIMB_HIGHBIT)
	r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}


/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
   The result is semi-normalized: any carry beyond limb n is folded back in
   by subtracting it at the low end (2^(n*GMP_NUMB_BITS) = -1 mod F).  */
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);	/* x = max (c - 1, 0), branch-free */
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
   A borrow beyond limb n is folded back in by adding it at the low end.  */
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);	/* x = max (-c, 0), branch-free */
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
   N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1

   Recursive radix-2 decimation-in-frequency transform; the K == 2 base
   case is a single butterfly with explicit renormalization of the two
   top limbs afterwards.  */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      int j;
      int *lk = *ll;

      mpn_fft_fft (Ap, K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap+inc, K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
	{
	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
	}
    }
}

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
   N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.
*/


/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.
   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
	{
	  /* This happens with very low probability; we have yet to trigger it,
	     and thereby make sure this code is correct.  */
	  MPN_ZERO (ap, n);
	  ap[n] = 1;
	}
      else
	ap[n] = 0;
    }
}

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K.
   When n is at or above the relevant FFT_MODF threshold the point-wise
   products themselves recurse into a smaller mod-2^N'+1 FFT; otherwise
   plain mpn_mul_n/mpn_sqr is used and the 2n-limb product is reduced
   mod 2^(n*GMP_NUMB_BITS)+1 by subtracting its high half from its low half.  */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int **fft_l;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = 1 << k;
      ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
	{
	  unsigned long K3;
	  for (;;)
	    {
	      K3 = 1L << mpn_fft_best_k (nprime2, sqr);
	      if ((nprime2 & (K3 - 1)) == 0)
		break;
	      nprime2 = (nprime2 + K3 - 1) & -K3;
	      Nprime2 = nprime2 * GMP_LIMB_BITS;
	      /* warning: since nprime2 changed, K3 may change too! */
	    }
	}
      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_ALLOC_MP_PTRS (K2);
      Bp = TMP_ALLOC_MP_PTRS (K2);
      A = TMP_ALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_ALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);	/* B shares the same allocation as A */
      fft_l = TMP_ALLOC_TYPE (k + 1, int *);
      for (i = 0; i <= k; i++)
	fft_l[i] = TMP_ALLOC_TYPE (1<<i, int);
      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
      for (i = 0; i < K; i++, ap++, bp++)
	{
	  mp_limb_t cy;
	  mpn_fft_normalize (*ap, n);
	  if (!sqr)
	    mpn_fft_normalize (*bp, n);

	  mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
	  if (!sqr)
	    mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

	  cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
				     l, Mp2, fft_l, T, sqr);
	  (*ap)[n] = cy;
	}
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      int n2 = 2 * n;
      tp = TMP_ALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf ("  mpn_mul_n %d of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
	{
	  a = *ap++;
	  b = *bp++;
	  if (sqr)
	    mpn_sqr (tp, a, n);
	  else
	    mpn_mul_n (tp, b, a, n);
	  /* account for the top limbs a[n], b[n] (each 0 or 1) of the
	     semi-normalized inputs: a[n]*b + b[n]*a + a[n]*b[n] */
	  if (a[n] != 0)
	    cc = mpn_add_n (tpn, tpn, b, n);
	  else
	    cc = 0;
	  if (b[n] != 0)
	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
	  if (cc != 0)
	    {
	      /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
	      cc = mpn_add_1 (tp, tp, n2, cc);
	      ASSERT (cc == 0);
	    }
	  /* reduce: {tp,n} - {tpn,n} mod 2^(n*GMP_NUMB_BITS)+1; on borrow
	     add 1, and a[n] records a possible final carry */
	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
	}
    }
  TMP_FREE;
}


/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  /* Inverse transform, structurally the mirror of mpn_fft_fft but without
     the bit-reversal tables (the output order stated above absorbs it).  */
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      int j, K2 = K >> 1;

      mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
	{
	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
	}
    }
}


/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
{
  int i;

  ASSERT (r != a);
  i = 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}


/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   then {rp,n}=0.
   Reduction uses 2^(n*GMP_NUMB_BITS) = -1: the middle n-limb chunk of ap
   is subtracted and the top chunk (if any) added back.  */
static int
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l;
  long int m;
  mp_limb_t cc;
  int rpn;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m} */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
		       mp_size_t nl, int l, int Mp, mp_ptr T)
{
  int i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;	/* signed: accumulates carries and borrows */

      tmp = TMP_ALLOC_LIMBS(Kl + 1);

      if (dif > Kl)
	{
	  /* Fold the input in Kl-limb chunks, alternately subtracting and
	     adding, since consecutive chunks have alternating signs mod
	     2^(Kl*GMP_NUMB_BITS)+1.  */
	  int subp = 0;

	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
	  n += 2 * Kl;
	  dif -= Kl;

	  /* now dif > 0 */
	  while (dif > Kl)
	    {
	      if (subp)
		cy += mpn_sub_n (tmp, tmp, n, Kl);
	      else
		cy -= mpn_add_n (tmp, tmp, n, Kl);
	      subp ^= 1;
	      n += Kl;
	      dif -= Kl;
	    }
	  /* now dif <= Kl */
	  if (subp)
	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
	  else
	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
	  if (cy >= 0)
	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
	  else
	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
	}
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
	{
	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
	}
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;	/* continue below with the reduced (Kl+1)-limb value */
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
	{
	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
	  nl -= j;
	  MPN_COPY (T, n, j);
	  MPN_ZERO (T + j, nprime + 1 - j);
	  n += l;
	  /* apply the weight 2^(i*Mp): this is the twist that turns the
	     cyclic convolution into a negacyclic one */
	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
	}
      else
	MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}

/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
		      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
		      int **fft_l, mp_ptr T, int sqr)
{
  int K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft: divide by K = 2^k and undo the
     decompose-time twist of 2^(i*Mp); Bp[] is reused to hold the results,
     each shifted one slot so Bp[i] can alias Ap[i-1] */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
			  pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
	{ /* subtract 2^N'+1 */
	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
	}
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
	{
	  /* p[pla-pl]...p[pla-1] are all zero */
	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
	}
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
	{
	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
	    ;
	}
      else
	{
	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
	  ASSERT (cc == 0);
	}
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}

/* return the lcm of a and 2^k */
static unsigned long int
mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
{
  unsigned long int l = k;	/* remember the original k */

  /* strip up to k factors of 2 from a; lcm(a, 2^k) = (a >> min(v2(a),k)) << k */
  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k --;
    }
  return a << l;
}


/* Put in {op, pl} the product {n, nl} * {m, ml} mod 2^N+1, N=pl*GMP_NUMB_BITS,
   using an FFT of size 2^k.  Returns the carry (the high bit of the result).
   pl must satisfy pl = mpn_fft_next_size (pl, k).  */
mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
	     mp_srcptr n, mp_size_t nl,
	     mp_srcptr m, mp_size_t ml,
	     int k)
{
  int K, maxLK, i;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  fft_l = TMP_ALLOC_TYPE (k + 1, int *);
  for (i = 0; i <= k; i++)
    fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
  mpn_fft_initl (fft_l, k);
  K = 1 << k;
  M = N >> k;	/* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
		 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      unsigned long K2;
      for (;;)
	{
	  K2 = 1L << mpn_fft_best_k (nprime, sqr);
	  if ((nprime & (K2 - 1)) == 0)
	    break;
	  nprime = (nprime + K2 - 1) & -K2;
	  Nprime = nprime * GMP_LIMB_BITS;
	  /* warning: since nprime changed, K2 may change too! */
	}
      TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_ALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_ALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_ALLOC_MP_PTRS (K);
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      /* in the squaring case B is only used as scratch inside
	 mpn_mul_fft_internal, so pla limbs suffice */
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_ALLOC_LIMBS (pla);
      Bp = TMP_ALLOC_MP_PTRS (K);
    }
  else
    {
      B = TMP_ALLOC_LIMBS (K * (nprime + 1));
      Bp = TMP_ALLOC_MP_PTRS (K);
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}

#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
		  mp_srcptr n, mp_size_t nl,
		  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  int k2, k3;
  int sqr = (n == m && nl == ml);
  int cc, c2, oldcc;

  pl = nl + ml; /* total number of limbs of the result */

  /* perform a fft mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3.  Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2.  Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j>=2.  Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */

  /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
			    thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
		 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);   /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  /* c2 & 1 is the borrow, c2 & 2 is the carry */
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1; /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_ALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op, pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp, pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif