/* mpn_toom53_mul -- Multiply {ap,an} and {bp,bn} where an is nominally 5/3
   times as large as bn.  Or more accurately, (4/3)bn < an < (5/2)bn.

   Contributed to the GNU project by Torbjorn Granlund and Marco Bodrato.

   The idea of applying toom to unbalanced multiplication is due to Marco
   Bodrato and Alberto Zanoni.

   THE FUNCTION IN THIS FILE IS INTERNAL WITH A MUTABLE INTERFACE.  IT IS ONLY
   SAFE TO REACH IT THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT IT WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

   Copyright 2006, 2007, 2008 Free Software Foundation, Inc.

   This file is part of the GNU MP Library.

   The GNU MP Library is free software; you can redistribute it and/or modify
   it under the terms of the GNU Lesser General Public License as published by
   the Free Software Foundation; either version 3 of the License, or (at your
   option) any later version.

   The GNU MP Library is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
   or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
   License for more details.

   You should have received a copy of the GNU Lesser General Public License
   along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */


/*
  Things to work on:

  1. Trim allocation.  The allocations for as1, asm1, bs1, and bsm1 could be
     avoided by instead reusing the pp area and the scratch allocation.
*/

#include "gmp.h"
#include "gmp-impl.h"

/* Evaluate in: -1, -1/2, 0, +1/2, +1, +2, +inf

   <-s-><--n--><--n--><--n--><--n-->
    ___ ______ ______ ______ ______
   |a4_|___a3_|___a2_|___a1_|___a0_|
	       |__b2|___b1_|___b0_|
	       <-t--><--n--><--n-->

   v0  =    a0                  *  b0           #   A(0)*B(0)
   v1  = (  a0+ a1+ a2+ a3+ a4)*( b0+ b1+ b2)   #   A(1)*B(1)      ah <= 4   bh <= 2
   vm1 = (  a0- a1+ a2- a3+ a4)*( b0- b1+ b2)   #  A(-1)*B(-1)    |ah| <= 2  bh <= 1
   v2  = (  a0+2a1+4a2+8a3+16a4)*( b0+2b1+4b2)  #   A(2)*B(2)      ah <= 30  bh <= 6
   vh  = (16a0+8a1+4a2+2a3+ a4)*(4b0+2b1+ b2)   # A(1/2)*B(1/2)    ah <= 30  bh <= 6
   vmh = (16a0-8a1+4a2-2a3+ a4)*(4b0-2b1+ b2)   # A(-1/2)*B(-1/2)  -9<=ah<=20  -1<=bh<=4
   vinf= a4 * b2                                # A(inf)*B(inf)

   (The evaluations at +1/2 and -1/2 are scaled by 16 resp. 4 so that all
   values are integers; the "ah"/"bh" bounds above are the resulting top
   (n-th) limb bounds asserted further down.)  */

/* Multiply {ap,an} by {bp,bn} and write the product to pp, splitting A into
   five and B into three coefficients of n limbs each (plus the short top
   parts s and t).  scratch is caller-supplied working space; this function
   stores the off-area evaluation products there (layout documented at the
   #defines below) and hands scratch + 8*n + 8 to the interpolation routine
   as its own scratch.  */
void
mpn_toom53_mul (mp_ptr pp,
		mp_srcptr ap, mp_size_t an,
		mp_srcptr bp, mp_size_t bn,
		mp_ptr scratch)
{
  mp_size_t n, s, t;
  int vm1_neg, vmh_neg;	     /* nonzero iff vm1 resp. vmh is negative */
  mp_limb_t cy;
  mp_ptr gp, hp;	     /* temporaries for even/odd coefficient sums */
  mp_ptr as1, asm1, as2, ash, asmh;  /* |A| evaluated at 1, -1, 2, 1/2, -1/2 */
  mp_ptr bs1, bsm1, bs2, bsh, bsmh;  /* |B| evaluated at 1, -1, 2, 1/2, -1/2 */
  enum toom4_flags flags;
  TMP_DECL;

  /* Coefficient views into the operands; a4 has only s limbs valid and b2
     only t limbs valid.  */
#define a0  ap
#define a1  (ap + n)
#define a2  (ap + 2*n)
#define a3  (ap + 3*n)
#define a4  (ap + 4*n)
#define b0  bp
#define b1  (bp + n)
#define b2  (bp + 2*n)

  /* n = ceil(max(an/5, bn/3)): the common coefficient size.  */
  n = 1 + (3 * an >= 5 * bn ? (an - 1) / (size_t) 5 : (bn - 1) / (size_t) 3);

  s = an - 4 * n;	/* limbs in the top coefficient a4 */
  t = bn - 2 * n;	/* limbs in the top coefficient b2 */

  ASSERT (0 < s && s <= n);
  ASSERT (0 < t && t <= n);

  TMP_MARK;

  /* Each evaluation needs n limbs plus one high limb for the carried-out
     coefficient sum.  */
  as1  = TMP_SALLOC_LIMBS (n + 1);
  asm1 = TMP_SALLOC_LIMBS (n + 1);
  as2  = TMP_SALLOC_LIMBS (n + 1);
  ash  = TMP_SALLOC_LIMBS (n + 1);
  asmh = TMP_SALLOC_LIMBS (n + 1);

  bs1  = TMP_SALLOC_LIMBS (n + 1);
  bsm1 = TMP_SALLOC_LIMBS (n + 1);
  bs2  = TMP_SALLOC_LIMBS (n + 1);
  bsh  = TMP_SALLOC_LIMBS (n + 1);
  bsmh = TMP_SALLOC_LIMBS (n + 1);

  /* gp and hp borrow the low end of the product area; their last use is in
     the bsh/bsmh computation below, before anything is written to pp.  */
  gp = pp;
  hp = pp + n + 1;

  /* Compute as1 and asm1: gp = a0 + a2 + a4 (even part), hp = a1 + a3 (odd
     part), then as1 = gp + hp, asm1 = |gp - hp| with the sign recorded in
     vm1_neg.  */
  gp[n] = mpn_add_n (gp, a0, a2, n);
  gp[n] += mpn_add (gp, gp, n, a4, s);
  hp[n] = mpn_add_n (hp, a1, a3, n);
#if HAVE_NATIVE_mpn_addsub_n
  if (mpn_cmp (gp, hp, n + 1) < 0)
    {
      mpn_addsub_n (as1, asm1, hp, gp, n + 1);
      vm1_neg = 1;
    }
  else
    {
      mpn_addsub_n (as1, asm1, gp, hp, n + 1);
      vm1_neg = 0;
    }
#else
  mpn_add_n (as1, gp, hp, n + 1);
  if (mpn_cmp (gp, hp, n + 1) < 0)
    {
      mpn_sub_n (asm1, hp, gp, n + 1);
      vm1_neg = 1;
    }
  else
    {
      mpn_sub_n (asm1, gp, hp, n + 1);
      vm1_neg = 0;
    }
#endif

  /* Compute as2 = a0 + 2a1 + 4a2 + 8a3 + 16a4.  */
#if !HAVE_NATIVE_mpn_addlsh_n
  /* 4a2, also reused by the non-native ash/asmh path below.
     NOTE(review): this guard assumes HAVE_NATIVE_mpn_addlsh_n implies
     HAVE_NATIVE_mpn_addlsh1_n; otherwise the #else branch below would read
     ash uninitialized — confirm against the native-function matrix.  */
  ash[n] = mpn_lshift (ash, a2, n, 2);
#endif
#if HAVE_NATIVE_mpn_addlsh1_n
  /* Horner evaluation at 2: ((((2*a4 + a3)*2 + a2)*2 + a1)*2 + a0).  */
  cy = mpn_addlsh1_n (as2, a3, a4, s);
  if (s != n)
    cy = mpn_add_1 (as2 + s, a3 + s, n - s, cy);
  cy = 2 * cy + mpn_addlsh1_n (as2, a2, as2, n);
  cy = 2 * cy + mpn_addlsh1_n (as2, a1, as2, n);
  as2[n] = 2 * cy + mpn_addlsh1_n (as2, a0, as2, n);
#else
  /* Build (a0 + 2a1 + 8a3 + 16a4) with shifts/adds, then fold in the
     precomputed 4a2.  */
  cy = mpn_lshift (as2, a4, s, 1);		/* 2a4 */
  cy += mpn_add_n (as2, a3, as2, s);		/* 2a4 + a3 */
  if (s != n)
    cy = mpn_add_1 (as2 + s, a3 + s, n - s, cy);
  cy = 4 * cy + mpn_lshift (as2, as2, n, 2);	/* 8a4 + 4a3 */
  cy += mpn_add_n (as2, a1, as2, n);		/* 8a4 + 4a3 + a1 */
  cy = 2 * cy + mpn_lshift (as2, as2, n, 1);	/* 16a4 + 8a3 + 2a1 */
  as2[n] = cy + mpn_add_n (as2, a0, as2, n);	/* ... + a0 */
  mpn_add_n (as2, ash, as2, n + 1);		/* ... + 4a2 */
#endif

  /* Compute ash and asmh: gp = 16a0 + 4a2 + a4 (even part), hp = 8a1 + 2a3
     (odd part), then ash = gp + hp, asmh = |gp - hp| with the sign in
     vmh_neg.  */
#if HAVE_NATIVE_mpn_addlsh_n
  cy = mpn_addlsh_n (gp, a2, a0, n, 2);		/* 4a0 + a2 */
  cy = 4 * cy + mpn_addlsh_n (gp, a4, gp, n, 2);  /* 16a0 + 4a2 + a4 */ /* FIXME s */
  gp[n] = cy;
  cy = mpn_addlsh_n (hp, a3, a1, n, 2);		/* 4a1 + a3 */
  cy = 2 * cy + mpn_lshift (hp, hp, n, 1);	/* 8a1 + 2a3 */
  hp[n] = cy;
#else
  gp[n] = mpn_lshift (gp, a0, n, 4);		/* 16a0 */
  mpn_add (gp, gp, n + 1, a4, s);		/* 16a0 + a4 */
  mpn_add_n (gp, ash, gp, n+1);			/* 16a0 + 4a2 + a4 */
  cy = mpn_lshift (hp, a1, n, 3);		/* 8a1 */
  cy += mpn_lshift (ash, a3, n, 1);		/* 2a3 (ash is dead, reused) */
  cy += mpn_add_n (hp, ash, hp, n);		/* 8a1 + 2a3 */
  hp[n] = cy;
#endif
#if HAVE_NATIVE_mpn_addsub_n
  if (mpn_cmp (gp, hp, n + 1) < 0)
    {
      mpn_addsub_n (ash, asmh, hp, gp, n + 1);
      vmh_neg = 1;
    }
  else
    {
      mpn_addsub_n (ash, asmh, gp, hp, n + 1);
      vmh_neg = 0;
    }
#else
  mpn_add_n (ash, gp, hp, n + 1);
  if (mpn_cmp (gp, hp, n + 1) < 0)
    {
      mpn_sub_n (asmh, hp, gp, n + 1);
      vmh_neg = 1;
    }
  else
    {
      mpn_sub_n (asmh, gp, hp, n + 1);
      vmh_neg = 0;
    }
#endif

  /* Compute bs1 and bsm1: bs1 = b0 + b1 + b2, bsm1 = |b0 - b1 + b2|, with
     the sign of (b0 - b1 + b2) xor-ed into vm1_neg so vm1_neg reflects the
     sign of the full product vm1.  */
  bs1[n] = mpn_add (bs1, b0, n, b2, t);		/* b0 + b2 */
#if HAVE_NATIVE_mpn_addsub_n
  if (bs1[n] == 0 && mpn_cmp (bs1, b1, n) < 0)
    {
      bs1[n] = mpn_addsub_n (bs1, bsm1, b1, bs1, n) >> 1;
      bsm1[n] = 0;
      vm1_neg ^= 1;
    }
  else
    {
      cy = mpn_addsub_n (bs1, bsm1, bs1, b1, n);
      bsm1[n] = bs1[n] - (cy & 1);	/* low bit of cy is the subtract borrow */
      bs1[n] += (cy >> 1);		/* high bit of cy is the add carry */
    }
#else
  if (bs1[n] == 0 && mpn_cmp (bs1, b1, n) < 0)
    {
      mpn_sub_n (bsm1, b1, bs1, n);
      bsm1[n] = 0;
      vm1_neg ^= 1;
    }
  else
    {
      bsm1[n] = bs1[n] - mpn_sub_n (bsm1, bs1, b1, n);
    }
  bs1[n] += mpn_add_n (bs1, bs1, b1, n);	/* b0 + b1 + b2 */
#endif

  /* Compute bs2 = b0 + 2b1 + 4b2.  hp (2b1) is also reused by the bsh/bsmh
     computation below.  */
  hp[n] = mpn_lshift (hp, b1, n, 1);		/* 2b1 */

  /* NOTE(review): #ifdef here vs #if everywhere else for the
     HAVE_NATIVE_* macros — inconsistent; harmless only if the macro is
     never defined to 0.  Confirm against gmp-impl.h conventions.  */
#ifdef HAVE_NATIVE_mpn_addlsh1_n
  cy = mpn_addlsh1_n (bs2, b1, b2, t);		/* b1 + 2b2 */
  if (t != n)
    cy = mpn_add_1 (bs2 + t, b1 + t, n - t, cy);
  bs2[n] = 2 * cy + mpn_addlsh1_n (bs2, b0, bs2, n);  /* b0 + 2b1 + 4b2 */
#else
  bs2[t] = mpn_lshift (bs2, b2, t, 2);		/* 4b2 */
  mpn_add (bs2, hp, n + 1, bs2, t + 1);		/* 2b1 + 4b2 */
  bs2[n] += mpn_add_n (bs2, bs2, b0, n);	/* b0 + 2b1 + 4b2 */
#endif

  /* Compute bsh and bsmh: gp = 4b0 + b2, hp = 2b1 (from above), then
     bsh = gp + hp, bsmh = |gp - hp| with the sign xor-ed into vmh_neg.  */
#if HAVE_NATIVE_mpn_addlsh_n
  /* 4b0 + b2.  NOTE(review): reads n limbs of b2 but only t are valid —
     same over-read issue as the FIXME in the ash computation; confirm.  */
  gp[n] = mpn_addlsh_n (gp, b2, b0, n, 2);
#else
  cy = mpn_lshift (gp, b0, n, 2);		/* 4b0 */
  gp[n] = cy + mpn_add (gp, gp, n, b2, t);	/* 4b0 + b2 */
#endif
#if HAVE_NATIVE_mpn_addsub_n
  if (mpn_cmp (gp, hp, n + 1) < 0)
    {
      mpn_addsub_n (bsh, bsmh, hp, gp, n + 1);
      vmh_neg ^= 1;
    }
  else
    mpn_addsub_n (bsh, bsmh, gp, hp, n + 1);
#else
  mpn_add_n (bsh, gp, hp, n + 1);		/* 4b0 + 2b1 + b2 */
  if (mpn_cmp (gp, hp, n + 1) < 0)
    {
      mpn_sub_n (bsmh, hp, gp, n + 1);
      vmh_neg ^= 1;
    }
  else
    {
      mpn_sub_n (bsmh, gp, hp, n + 1);
    }
#endif

  /* Top-limb bounds from the evaluation-point table at the top of the
     file.  */
  ASSERT (as1[n] <= 4);
  ASSERT (bs1[n] <= 2);
  ASSERT (asm1[n] <= 2);
  ASSERT (bsm1[n] <= 1);
  ASSERT (as2[n] <= 30);
  ASSERT (bs2[n] <= 6);
  ASSERT (ash[n] <= 30);
  ASSERT (bsh[n] <= 6);
  ASSERT (asmh[n] <= 20);
  ASSERT (bsmh[n] <= 4);

  /* Product placement: v0, vh and vinf live in pp (at offsets 0, 2n and 6n);
     vm1, v2, vmh and v1 live in scratch (at offsets 0, 2n+2, 4n+4 and 6n+6),
     with interpolation scratch starting at 8n+8.  */
#define v0    pp				/* 2n */
#define v1    (scratch + 6 * n + 6)		/* 2n+1 */
#define vm1   scratch				/* 2n+1 */
#define v2    (scratch + 2 * n + 2)		/* 2n+1 */
#define vinf  (pp + 6 * n)			/* s+t */
#define vh    (pp + 2 * n)			/* 2n+1 */
#define vmh   (scratch + 4 * n + 4)

  /* vm1, 2n+1 limbs */
#ifdef SMALLER_RECURSION
  /* Multiply only the n-limb parts, then patch in the cross products with
     the small top limbs by hand.  */
  mpn_mul_n (vm1, asm1, bsm1, n);
  if (asm1[n] == 1)
    {
      cy = bsm1[n] + mpn_add_n (vm1 + n, vm1 + n, bsm1, n);
    }
  else if (asm1[n] == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n
      cy = 2 * bsm1[n] + mpn_addlsh1_n (vm1 + n, vm1 + n, bsm1, n);
#else
      cy = 2 * bsm1[n] + mpn_addmul_1 (vm1 + n, bsm1, n, CNST_LIMB(2));
#endif
    }
  else
    cy = 0;
  if (bsm1[n] != 0)
    cy += mpn_add_n (vm1 + n, vm1 + n, asm1, n);
  vm1[2 * n] = cy;
#else /* SMALLER_RECURSION */
  /* Multiply n+1 limbs only when a top limb is actually nonzero.  */
  vm1[2 * n] = 0;
  mpn_mul_n (vm1, asm1, bsm1, n + ((asm1[n] | bsm1[n]) != 0));
#endif /* SMALLER_RECURSION */

  mpn_mul_n (v2, as2, bs2, n + 1);		/* v2, 2n+1 limbs */

  /* vinf, s+t limbs; mpn_mul requires the larger operand first.  */
  if (s > t)  mpn_mul (vinf, a4, s, b2, t);
  else        mpn_mul (vinf, b2, t, a4, s);

  /* v1, 2n+1 limbs */
#ifdef SMALLER_RECURSION
  /* Same top-limb patching scheme as for vm1, but as1[n] can be up to 4
     and bs1[n] up to 2.  */
  mpn_mul_n (v1, as1, bs1, n);
  if (as1[n] == 1)
    {
      cy = bs1[n] + mpn_add_n (v1 + n, v1 + n, bs1, n);
    }
  else if (as1[n] == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n
      cy = 2 * bs1[n] + mpn_addlsh1_n (v1 + n, v1 + n, bs1, n);
#else
      cy = 2 * bs1[n] + mpn_addmul_1 (v1 + n, bs1, n, CNST_LIMB(2));
#endif
    }
  else if (as1[n] != 0)
    {
      cy = as1[n] * bs1[n] + mpn_addmul_1 (v1 + n, bs1, n, as1[n]);
    }
  else
    cy = 0;
  if (bs1[n] == 1)
    {
      cy += mpn_add_n (v1 + n, v1 + n, as1, n);
    }
  else if (bs1[n] == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n
      cy += mpn_addlsh1_n (v1 + n, v1 + n, as1, n);
#else
      cy += mpn_addmul_1 (v1 + n, as1, n, CNST_LIMB(2));
#endif
    }
  v1[2 * n] = cy;
#else /* SMALLER_RECURSION */
  v1[2 * n] = 0;
  mpn_mul_n (v1, as1, bs1, n + ((as1[n] | bs1[n]) != 0));
#endif /* SMALLER_RECURSION */

  mpn_mul_n (vh, ash, bsh, n + 1);

  mpn_mul_n (vmh, asmh, bsmh, n + 1);

  /* v0 last: this overwrites the gp/hp scratch area, which is dead by
     now.  */
  mpn_mul_n (v0, ap, bp, n);			/* v0, 2n limbs */

  /* The interpolation routine takes the signs of the odd-point values
     through flags (w3 corresponds to vm1, w1 to vmh).  */
  flags = vm1_neg ? toom4_w3_neg : 0;
  flags |= vmh_neg ? toom4_w1_neg : 0;

  mpn_toom_interpolate_7pts (pp, n, flags, vmh, vm1, v1, v2, s + t, scratch + 8 * n + 8);

  TMP_FREE;
}