/* mpn_toom32_mul -- Multiply {ap,an} and {bp,bn} where an is nominally 1.5
   times as large as bn.  Or more accurately, bn < an < 3bn.

   Contributed to the GNU project by Torbjorn Granlund.
   Improvements by Marco Bodrato and Niels Möller.

   The idea of applying Toom to unbalanced multiplication is due to Marco
   Bodrato and Alberto Zanoni.

   THE FUNCTION IN THIS FILE IS INTERNAL WITH A MUTABLE INTERFACE.  IT IS ONLY
   SAFE TO REACH IT THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT IT WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

   Copyright 2006, 2007, 2008, 2009, 2010 Free Software Foundation, Inc.

   This file is part of the GNU MP Library.

   The GNU MP Library is free software; you can redistribute it and/or modify
   it under the terms of the GNU Lesser General Public License as published by
   the Free Software Foundation; either version 3 of the License, or (at your
   option) any later version.

   The GNU MP Library is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
   or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
   License for more details.

   You should have received a copy of the GNU Lesser General Public License
   along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */

#include "gmp.h"
#include "gmp-impl.h"

/* Evaluate in: -1, 0, +1, +inf

  <-s-><--n--><--n-->
   ___ ______ ______
  |a2_|___a1_|___a0_|
        |_b1_|___b0_|
        <-t--><--n-->

  v0   =  a0          * b0       #   A(0)*B(0)
  v1   = (a0+ a1+ a2) * (b0+ b1) #   A(1)*B(1)      ah  <= 2  bh <= 1
  vm1  = (a0- a1+ a2) * (b0- b1) # A(-1)*B(-1)     |ah| <= 1  bh  = 0
  vinf =          a2  *      b1  # A(inf)*B(inf)
*/
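
/* Writing the full product as x0 + x1 B + x2 B^2 + x3 B^3, with
   B = 2^(GMP_NUMB_BITS * n), we have v0 = x0, vinf = x3,
   v1 = x0 + x1 + x2 + x3 and vm1 = x0 - x1 + x2 - x3.  The interpolation
   below recovers x0 + x2 = (v1 + vm1) / 2 and then x1 + x3 = (x0 + x2) - vm1,
   taking vm1 with its sign.  */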

#define TOOM32_MUL_N_REC(p, a, b, n, ws)	\
  do {						\
    mpn_mul_n (p, a, b, n);			\
  } while (0)
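
/* Note: the ws argument is currently unused; mpn_mul_n itself picks the
   underlying algorithm (basecase, Toom, ...) based on the usual size
   thresholds.  The wrapper presumably remains as a hook in case the
   recursive products ever need special handling.  */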

void
mpn_toom32_mul (mp_ptr pp,
                mp_srcptr ap, mp_size_t an,
                mp_srcptr bp, mp_size_t bn,
                mp_ptr scratch)
{
  mp_size_t n, s, t;
  int vm1_neg;
  mp_limb_t cy;
  int hi;
  mp_limb_t ap1_hi, bp1_hi;

#define a0  ap
#define a1  (ap + n)
#define a2  (ap + 2 * n)
#define b0  bp
#define b1  (bp + n)

  /* Required, to ensure that s + t >= n. */
  ASSERT (bn + 2 <= an && an + 6 <= 3*bn);

  n = 1 + (2 * an >= 3 * bn ? (an - 1) / (size_t) 3 : (bn - 1) >> 1);

  s = an - 2 * n;
  t = bn - n;

  ASSERT (0 < s && s <= n);
  ASSERT (0 < t && t <= n);
  ASSERT (s + t >= n);

  /* Product area of size an + bn = 3*n + s + t >= 4*n + 2. */
#define ap1 (pp)			/* n, most significant limb in ap1_hi */
#define bp1 (pp + n)			/* n, most significant bit in bp1_hi */
#define am1 (pp + 2*n)			/* n, most significant bit in hi */
#define bm1 (pp + 3*n)			/* n */
#define v1 (scratch)			/* 2n + 1 */
#define vm1 (pp)			/* 2n + 1 */
#define scratch_out (scratch + 2*n + 1)	/* Currently unused. */

  /* Scratch need: 2*n + 1 + scratch for the recursive multiplications. */

  /* FIXME: Keep v1[2*n] and vm1[2*n] in scalar variables? */

  /* Compute ap1 = a0 + a1 + a2, am1 = a0 - a1 + a2.  */
  ap1_hi = mpn_add (ap1, a0, n, a2, s);
#if HAVE_NATIVE_mpn_add_n_sub_n
  if (ap1_hi == 0 && mpn_cmp (ap1, a1, n) < 0)
    {
      ap1_hi = mpn_add_n_sub_n (ap1, am1, a1, ap1, n) >> 1;
      hi = 0;
      vm1_neg = 1;
    }
  else
    {
      cy = mpn_add_n_sub_n (ap1, am1, ap1, a1, n);
      hi = ap1_hi - (cy & 1);
      ap1_hi += (cy >> 1);
      vm1_neg = 0;
    }
#else
  if (ap1_hi == 0 && mpn_cmp (ap1, a1, n) < 0)
    {
      ASSERT_NOCARRY (mpn_sub_n (am1, a1, ap1, n));
      hi = 0;
      vm1_neg = 1;
    }
  else
    {
      hi = ap1_hi - mpn_sub_n (am1, ap1, a1, n);
      vm1_neg = 0;
    }
  ap1_hi += mpn_add_n (ap1, ap1, a1, n);
#endif
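
  /* At this point ap1 + ap1_hi * B^n = a0 + a1 + a2 (with ap1_hi <= 2), and
     am1 + hi * B^n = |a0 - a1 + a2| (with hi <= 1); vm1_neg records the
     sign.  */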

  /* Compute bp1 = b0 + b1 and bm1 = b0 - b1. */
  if (t == n)
    {
#if HAVE_NATIVE_mpn_add_n_sub_n
      if (mpn_cmp (b0, b1, n) < 0)
        {
          cy = mpn_add_n_sub_n (bp1, bm1, b1, b0, n);
          vm1_neg ^= 1;
        }
      else
        {
          cy = mpn_add_n_sub_n (bp1, bm1, b0, b1, n);
        }
      bp1_hi = cy >> 1;
#else
      bp1_hi = mpn_add_n (bp1, b0, b1, n);

      if (mpn_cmp (b0, b1, n) < 0)
        {
          ASSERT_NOCARRY (mpn_sub_n (bm1, b1, b0, n));
          vm1_neg ^= 1;
        }
      else
        {
          ASSERT_NOCARRY (mpn_sub_n (bm1, b0, b1, n));
        }
#endif
    }
  else
    {
      /* FIXME: Should still use mpn_add_n_sub_n for the main part. */
      bp1_hi = mpn_add (bp1, b0, n, b1, t);

      if (mpn_zero_p (b0 + t, n - t) && mpn_cmp (b0, b1, t) < 0)
        {
          ASSERT_NOCARRY (mpn_sub_n (bm1, b1, b0, t));
          MPN_ZERO (bm1 + t, n - t);
          vm1_neg ^= 1;
        }
      else
        {
          ASSERT_NOCARRY (mpn_sub (bm1, b0, n, b1, t));
        }
    }
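
  /* Similarly, bp1 + bp1_hi * B^n = b0 + b1 (bp1_hi <= 1), and
     bm1 = |b0 - b1| fits in n limbs (bh = 0 above); the combined sign of
     vm1 is accumulated in vm1_neg.  */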

  TOOM32_MUL_N_REC (v1, ap1, bp1, n, scratch_out);
  if (ap1_hi == 1)
    {
      cy = bp1_hi + mpn_add_n (v1 + n, v1 + n, bp1, n);
    }
  else if (ap1_hi == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n
      cy = 2 * bp1_hi + mpn_addlsh1_n (v1 + n, v1 + n, bp1, n);
#else
      cy = 2 * bp1_hi + mpn_addmul_1 (v1 + n, bp1, n, CNST_LIMB(2));
#endif
    }
  else
    cy = 0;
  if (bp1_hi != 0)
    cy += mpn_add_n (v1 + n, v1 + n, ap1, n);
  v1[2 * n] = cy;
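
  /* The corrections above add ap1_hi * bp1 + bp1_hi * ap1 at position n and
     ap1_hi * bp1_hi at position 2*n, so v1 = A(1) * B(1) now occupies
     2*n + 1 limbs.  */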

  TOOM32_MUL_N_REC (vm1, am1, bm1, n, scratch_out);
  if (hi)
    hi = mpn_add_n (vm1+n, vm1+n, bm1, n);

  vm1[2*n] = hi;
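
  /* Likewise vm1 = |A(-1)| * |B(-1)| in 2*n + 1 limbs, with its sign in
     vm1_neg; since bm1 has no high limb, adding hi * bm1 at position n was
     the only correction needed.  */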

  /* v1 <-- (v1 + vm1) / 2 = x0 + x2 */
  if (vm1_neg)
    {
#if HAVE_NATIVE_mpn_rsh1sub_n
      mpn_rsh1sub_n (v1, v1, vm1, 2*n+1);
#else
      mpn_sub_n (v1, v1, vm1, 2*n+1);
      ASSERT_NOCARRY (mpn_rshift (v1, v1, 2*n+1, 1));
#endif
    }
  else
    {
#if HAVE_NATIVE_mpn_rsh1add_n
      mpn_rsh1add_n (v1, v1, vm1, 2*n+1);
#else
      mpn_add_n (v1, v1, vm1, 2*n+1);
      ASSERT_NOCARRY (mpn_rshift (v1, v1, 2*n+1, 1));
#endif
    }

  /* We get x1 + x3 = (x0 + x2) - (x0 - x1 + x2 - x3), and hence

     y = x1 + x3 + (x0 + x2) * B
       = (x0 + x2) * B + (x0 + x2) - vm1.

     y is 3*n + 1 limbs, y = y0 + y1 B + y2 B^2.  We store them as
     follows: y0 at scratch, y1 at pp + 2*n, and y2 at scratch + n
     (already in place, except for carry propagation).

     We thus add

      B^3  B^2   B    1
       |    |    |    |
      +----------+
    + | x0 + x2  |
      +----+-----+----+
    +      | x0 + x2  |
           +----------+
    -      |   vm1    |
    --+----++----+----+-
      | y2  | y1 | y0 |
      +-----+----+----+

     Since we store y0 at the same location as the low half of x0 + x2, we
     need to do the middle sum first.  */

  hi = vm1[2*n];
  cy = mpn_add_n (pp + 2*n, v1, v1 + n, n);
  MPN_INCR_U (v1 + n, n + 1, cy + v1[2*n]);
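
  /* Middle sum: pp + 2*n gets low (x0 + x2) + high (x0 + x2); the carry,
     together with the top limb of the unshifted copy of x0 + x2, is folded
     into the B^2 part of y at scratch + n.  */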

  /* FIXME: Can we get rid of this second vm1_neg conditional by
     swapping the location of +1 and -1 values? */
  if (vm1_neg)
    {
      cy = mpn_add_n (v1, v1, vm1, n);
      hi += mpn_add_nc (pp + 2*n, pp + 2*n, vm1 + n, n, cy);
      MPN_INCR_U (v1 + n, n+1, hi);
    }
  else
    {
      cy = mpn_sub_n (v1, v1, vm1, n);
      hi += mpn_sub_nc (pp + 2*n, pp + 2*n, vm1 + n, n, cy);
      MPN_DECR_U (v1 + n, n+1, hi);
    }
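
  /* y is now complete: y0 at scratch, y1 at pp + 2*n, y2 (n+1 limbs) at
     scratch + n, as promised above.  */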

  TOOM32_MUL_N_REC (pp, a0, b0, n, scratch_out);
  /* vinf, s+t limbs.  Use mpn_mul for now, to handle unbalanced operands */
  if (s > t)  mpn_mul (pp+3*n, a2, s, b1, t);
  else        mpn_mul (pp+3*n, b1, t, a2, s);
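
  /* mpn_mul requires its first operand to be at least as long as the second,
     hence the branch on s > t.  */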

  /* Remaining interpolation.

     y * B + x0 + x3 B^3 - x0 B^2 - x3 B
     = (x1 + x3) B + (x0 + x2) B^2 + x0 + x3 B^3 - x0 B^2 - x3 B
     = y0 B + y1 B^2 + y2 B^3 + L x0 + H x0 B
       + L x3 B^3 + H x3 B^4 - L x0 B^2 - H x0 B^3 - L x3 B - H x3 B^2
     = L x0 + (y0 + H x0 - L x3) B + (y1 - L x0 - H x3) B^2
       + (y2 - (H x0 - L x3)) B^3 + H x3 B^4

           B^4       B^3       B^2        B         1
  |         |         |         |         |         |
    +-------+                   +---------+---------+
    |  Hx3  |                   | Hx0-Lx3 |   Lx0   |
    +------+----------+---------+---------+---------+
           |    y2    |    y1   |   y0    |
           ++---------+---------+---------+
           -| Hx0-Lx3 |  - Lx0  |
            +---------+---------+
                      |  - Hx3  |
                      +---------+

     We must take into account the carry from Hx0 - Lx3.
  */

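  /* In terms of the diagram: the first subtraction puts Hx0 - Lx3 at pp + n
     (the B term, still lacking y0), with its borrow and y2's top limb
     collected in hi for the B^4 position.  The next two subtractions form
     y1 - Lx0 at pp + 2*n (B^2 term) and y2 - (Hx0 - Lx3) at pp + 3*n
     (B^3 term).  Then y0 is added in at position B, with carries rippling
     through 3*n limbs; finally Hx3 is subtracted at B^2 and the accumulated
     hi is applied at B^4.  */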
  cy = mpn_sub_n (pp + n, pp + n, pp+3*n, n);
  hi = scratch[2*n] + cy;

  cy = mpn_sub_nc (pp + 2*n, pp + 2*n, pp, n, cy);
  hi -= mpn_sub_nc (pp + 3*n, scratch + n, pp + n, n, cy);

  hi += mpn_add (pp + n, pp + n, 3*n, scratch, n);

  /* FIXME: Is support for s + t == n needed? */
  if (LIKELY (s + t > n))
    {
      hi -= mpn_sub (pp + 2*n, pp + 2*n, 2*n, pp + 4*n, s+t-n);

      if (hi < 0)
        MPN_DECR_U (pp + 4*n, s+t-n, -hi);
      else
        MPN_INCR_U (pp + 4*n, s+t-n, hi);
    }
  else
    ASSERT (hi == 0);
}