/* mpn_toom32_mul -- Multiply {ap,an} and {bp,bn} where an is nominally 1.5
   times as large as bn.  Or more accurately, bn < an < 3bn.

   Contributed to the GNU project by Torbjorn Granlund.
   Improvements by Marco Bodrato and Niels Möller.

   The idea of applying toom to unbalanced multiplication is due to Marco
   Bodrato and Alberto Zanoni.

   THE FUNCTION IN THIS FILE IS INTERNAL WITH A MUTABLE INTERFACE.  IT IS ONLY
   SAFE TO REACH IT THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT IT WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

   Copyright 2006, 2007, 2008, 2009, 2010 Free Software Foundation, Inc.

   This file is part of the GNU MP Library.

   The GNU MP Library is free software; you can redistribute it and/or modify
   it under the terms of the GNU Lesser General Public License as published by
   the Free Software Foundation; either version 3 of the License, or (at your
   option) any later version.

   The GNU MP Library is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
   or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
   License for more details.

   You should have received a copy of the GNU Lesser General Public License
   along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.
*/

#include "gmp.h"
#include "gmp-impl.h"

/* Evaluate in: -1, 0, +1, +inf

  <-s-><--n--><--n-->
   ___ ______ ______
  |a2_|___a1_|___a0_|
	|_b1_|___b0_|
	<-t--><--n-->

  v0  =  a0         * b0      #   A(0)*B(0)
  v1  = (a0+ a1+ a2)*(b0+ b1) #   A(1)*B(1)      ah  <= 2  bh <= 1
  vm1 = (a0- a1+ a2)*(b0- b1) #  A(-1)*B(-1)    |ah| <= 1  bh = 0
  vinf=          a2 *     b1  # A(inf)*B(inf)
*/

/* Recursive multiplication of two equal-length operands; currently a plain
   mpn_mul_n, the ws (workspace) argument is unused.  */
#define TOOM32_MUL_N_REC(p, a, b, n, ws)				\
  do {									\
    mpn_mul_n (p, a, b, n);						\
  } while (0)

/* Multiply {ap,an} by {bp,bn}, writing the an+bn limb product to {pp}.
   Requires bn + 2 <= an and an + 6 <= 3*bn (see the ASSERT below), i.e.
   the operands are unbalanced by a factor of roughly 1.5.  Needs 2*n + 1
   limbs of scratch (n defined below), per the "Scratch need" note.  */
void
mpn_toom32_mul (mp_ptr pp,
		mp_srcptr ap, mp_size_t an,
		mp_srcptr bp, mp_size_t bn,
		mp_ptr scratch)
{
  mp_size_t n, s, t;
  int vm1_neg;			/* Nonzero iff vm1 = A(-1)*B(-1) is negative;
				   vm1 itself holds the absolute value.  */
  mp_limb_t cy;
  int hi;
  mp_limb_t ap1_hi, bp1_hi;	/* High limbs of A(1) resp. B(1).  */

#define a0  ap
#define a1  (ap + n)
#define a2  (ap + 2 * n)
#define b0  bp
#define b1  (bp + n)

  /* Required, to ensure that s + t >= n. */
  ASSERT (bn + 2 <= an && an + 6 <= 3*bn);

  /* n = ceil(an/3) when 2*an >= 3*bn, else ceil(bn/2): the larger of the
     two natural block sizes.  */
  n = 1 + (2 * an >= 3 * bn ? (an - 1) / (size_t) 3 : (bn - 1) >> 1);

  s = an - 2 * n;		/* Size of the high a-block a2.  */
  t = bn - n;			/* Size of the high b-block b1.  */

  ASSERT (0 < s && s <= n);
  ASSERT (0 < t && t <= n);
  ASSERT (s + t >= n);

  /* Product area of size an + bn = 3*n + s + t >= 4*n + 2. */
#define ap1 (pp)		/* n, most significant limb in ap1_hi */
#define bp1 (pp + n)		/* n, most significant bit in bp1_hi */
#define am1 (pp + 2*n)		/* n, most significant bit in hi */
#define bm1 (pp + 3*n)		/* n */
#define v1 (scratch)		/* 2n + 1 */
#define vm1 (pp)		/* 2n + 1 */
#define scratch_out (scratch + 2*n + 1) /* Currently unused. */

  /* Scratch need: 2*n + 1 + scratch for the recursive multiplications. */

  /* FIXME: Keep v1[2*n] and vm1[2*n] in scalar variables? */

  /* Compute ap1 = a0 + a1 + a2, am1 = a0 - a1 + a2 */
  ap1_hi = mpn_add (ap1, a0, n, a2, s);
#if HAVE_NATIVE_mpn_add_n_sub_n
  if (ap1_hi == 0 && mpn_cmp (ap1, a1, n) < 0)
    {
      /* a0 + a2 < a1: negate the difference and record the sign flip.  */
      ap1_hi = mpn_add_n_sub_n (ap1, am1, a1, ap1, n) >> 1;
      hi = 0;
      vm1_neg = 1;
    }
  else
    {
      cy = mpn_add_n_sub_n (ap1, am1, ap1, a1, n);
      /* Low bit of cy is the subtraction borrow, next bit the addition
	 carry (mpn_add_n_sub_n packs both).  */
      hi = ap1_hi - (cy & 1);
      ap1_hi += (cy >> 1);
      vm1_neg = 0;
    }
#else
  if (ap1_hi == 0 && mpn_cmp (ap1, a1, n) < 0)
    {
      /* a0 + a2 < a1: compute a1 - (a0 + a2) and record the sign flip.  */
      ASSERT_NOCARRY (mpn_sub_n (am1, a1, ap1, n));
      hi = 0;
      vm1_neg = 1;
    }
  else
    {
      hi = ap1_hi - mpn_sub_n (am1, ap1, a1, n);
      vm1_neg = 0;
    }
  /* Note: must happen after the subtraction above, which reads ap1.  */
  ap1_hi += mpn_add_n (ap1, ap1, a1, n);
#endif

  /* Compute bp1 = b0 + b1 and bm1 = b0 - b1. */
  if (t == n)
    {
#if HAVE_NATIVE_mpn_add_n_sub_n
      if (mpn_cmp (b0, b1, n) < 0)
	{
	  cy = mpn_add_n_sub_n (bp1, bm1, b1, b0, n);
	  vm1_neg ^= 1;		/* B(-1) negative flips the product sign.  */
	}
      else
	{
	  cy = mpn_add_n_sub_n (bp1, bm1, b0, b1, n);
	}
      bp1_hi = cy >> 1;
#else
      bp1_hi = mpn_add_n (bp1, b0, b1, n);

      if (mpn_cmp (b0, b1, n) < 0)
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b1, b0, n));
	  vm1_neg ^= 1;
	}
      else
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b0, b1, n));
	}
#endif
    }
  else
    {
      /* t < n, so b1 is shorter than b0.  */
      /* FIXME: Should still use mpn_add_n_sub_n for the main part. */
      bp1_hi = mpn_add (bp1, b0, n, b1, t);

      /* b0 < b1 is possible only when b0's high n-t limbs are zero.  */
      if (mpn_zero_p (b0 + t, n - t) && mpn_cmp (b0, b1, t) < 0)
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b1, b0, t));
	  MPN_ZERO (bm1 + t, n - t);
	  vm1_neg ^= 1;
	}
      else
	{
	  ASSERT_NOCARRY (mpn_sub (bm1, b0, n, b1, t));
	}
    }

  /* v1 = A(1) * B(1); fold in the high limbs ap1_hi (0..2) and bp1_hi
     (0..1) via the cross terms ap1_hi*bp1, bp1_hi*ap1, ap1_hi*bp1_hi.  */
  TOOM32_MUL_N_REC (v1, ap1, bp1, n, scratch_out);
  if (ap1_hi == 1)
    {
      cy = bp1_hi + mpn_add_n (v1 + n, v1 + n, bp1, n);
    }
  else if (ap1_hi == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n
      cy = 2 * bp1_hi + mpn_addlsh1_n (v1 + n, v1 + n, bp1, n);
#else
      cy = 2 * bp1_hi + mpn_addmul_1 (v1 + n, bp1, n, CNST_LIMB(2));
#endif
    }
  else
    cy = 0;
  if (bp1_hi != 0)
    cy += mpn_add_n (v1 + n, v1 + n, ap1, n);
  v1[2 * n] = cy;

  /* vm1 = |A(-1)| * |B(-1)|; hi is the high limb of |A(-1)| (0 or 1),
     and |B(-1)| has no high limb (bh = 0, see evaluation comment).  */
  TOOM32_MUL_N_REC (vm1, am1, bm1, n, scratch_out);
  if (hi)
    hi = mpn_add_n (vm1+n, vm1+n, bm1, n);

  vm1[2*n] = hi;

  /* v1 <-- (v1 + vm1) / 2 = x0 + x2 */
  if (vm1_neg)
    {
#if HAVE_NATIVE_mpn_rsh1sub_n
      mpn_rsh1sub_n (v1, v1, vm1, 2*n+1);
#else
      mpn_sub_n (v1, v1, vm1, 2*n+1);
      ASSERT_NOCARRY (mpn_rshift (v1, v1, 2*n+1, 1));
#endif
    }
  else
    {
#if HAVE_NATIVE_mpn_rsh1add_n
      mpn_rsh1add_n (v1, v1, vm1, 2*n+1);
#else
      mpn_add_n (v1, v1, vm1, 2*n+1);
      ASSERT_NOCARRY (mpn_rshift (v1, v1, 2*n+1, 1));
#endif
    }

  /* We get x1 + x3 = (x0 + x2) - (x0 - x1 + x2 - x3), and hence

     y = x1 + x3 + (x0 + x2) * B
       = (x0 + x2) * B + (x0 + x2) - vm1.

     y is 3*n + 1 limbs, y = y0 + y1 B + y2 B^2. We store them as
     follows: y0 at scratch, y1 at pp + 2*n, and y2 at scratch + n
     (already in place, except for carry propagation).

     We thus add

   B^3  B^2   B    1
    |    |    |    |
   +-----+----+
 + |  x0 + x2 |
   +----+-----+----+
 +      |  x0 + x2 |
	+----------+
 -      |  vm1     |
 --+----++----+----+-
   | y2  | y1 | y0 |
   +-----+----+----+

  Since we store y0 at the same location as the low half of x0 + x2, we
  need to do the middle sum first. */

  hi = vm1[2*n];
  cy = mpn_add_n (pp + 2*n, v1, v1 + n, n);
  MPN_INCR_U (v1 + n, n + 1, cy + v1[2*n]);

  /* FIXME: Can we get rid of this second vm1_neg conditional by
     swapping the location of +1 and -1 values? */
  if (vm1_neg)
    {
      cy = mpn_add_n (v1, v1, vm1, n);
      hi += mpn_add_nc (pp + 2*n, pp + 2*n, vm1 + n, n, cy);
      MPN_INCR_U (v1 + n, n+1, hi);
    }
  else
    {
      cy = mpn_sub_n (v1, v1, vm1, n);
      hi += mpn_sub_nc (pp + 2*n, pp + 2*n, vm1 + n, n, cy);
      MPN_DECR_U (v1 + n, n+1, hi);
    }

  TOOM32_MUL_N_REC (pp, a0, b0, n, scratch_out);
  /* vinf, s+t limbs.  Use mpn_mul for now, to handle unbalanced operands */
  if (s > t)  mpn_mul (pp+3*n, a2, s, b1, t);
  else        mpn_mul (pp+3*n, b1, t, a2, s);

  /* Remaining interpolation.

     y * B + x0 + x3 B^3 - x0 B^2 - x3 B
     = (x1 + x3) B + (x0 + x2) B^2 + x0 + x3 B^3 - x0 B^2 - x3 B
     = y0 B + y1 B^2 + y2 B^3 + Lx0 + H x0 B
       + L x3 B^3 + H x3 B^4 - Lx0 B^2 - H x0 B^3 - L x3 B - H x3 B^2
     = L x0 + (y0 + H x0 - L x3) B + (y1 - L x0 - H x3) B^2
       + (y2 - (H x0 - L x3)) B^3 + H x3 B^4

	  B^4       B^3       B^2        B         1
 |         |         |         |         |         |
   +-------+                   +---------+---------+
   |  Hx3  |                   | Hx0-Lx3 |    Lx0  |
   +------+----------+---------+---------+---------+
	  |    y2    |  y1     |   y0    |
	  ++---------+---------+---------+
	  -| Hx0-Lx3 |   - Lx0 |
	   +---------+---------+
		      | - Hx3  |
		      +--------+

    We must take into account the carry from Hx0 - Lx3.
  */

  cy = mpn_sub_n (pp + n, pp + n, pp+3*n, n); /* Hx0-Lx3, B to B^2 */
  hi = scratch[2*n] + cy;

  cy = mpn_sub_nc (pp + 2*n, pp + 2*n, pp, n, cy); /* Lx0-Lx3, B^2 to B^3 */
  hi -= mpn_sub_nc (pp + 3*n, scratch + n, pp + n, n, cy);

  hi += mpn_add (pp + n, pp + n, 3*n, scratch, n);

  /* FIXME: Is support for s + t == n needed? */
  if (LIKELY (s + t > n))
    {
      hi -= mpn_sub (pp + 2*n, pp + 2*n, 2*n, pp + 4*n, s+t-n);

      /* hi may be negative here; apply it with the matching sign.  */
      if (hi < 0)
	MPN_DECR_U (pp + 4*n, s+t-n, -hi);
      else
	MPN_INCR_U (pp + 4*n, s+t-n, hi);
    }
  else
    ASSERT (hi == 0);
}