1 /* mpn_toom62_mul -- Multiply {ap,an} and {bp,bn} where an is nominally 3 times
2    as large as bn.  Or more accurately, (5/2)bn < an < 6bn.
3 
4    Contributed to the GNU project by Torbjorn Granlund and Marco Bodrato.
5 
6    The idea of applying toom to unbalanced multiplication is due to Marco
7    Bodrato and Alberto Zanoni.
8 
9    THE FUNCTION IN THIS FILE IS INTERNAL WITH A MUTABLE INTERFACE.  IT IS ONLY
10    SAFE TO REACH IT THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
11    GUARANTEED THAT IT WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.
12 
13 Copyright 2006-2008, 2012 Free Software Foundation, Inc.
14 
15 This file is part of the GNU MP Library.
16 
17 The GNU MP Library is free software; you can redistribute it and/or modify
18 it under the terms of either:
19 
20   * the GNU Lesser General Public License as published by the Free
21     Software Foundation; either version 3 of the License, or (at your
22     option) any later version.
23 
24 or
25 
26   * the GNU General Public License as published by the Free Software
27     Foundation; either version 2 of the License, or (at your option) any
28     later version.
29 
30 or both in parallel, as here.
31 
32 The GNU MP Library is distributed in the hope that it will be useful, but
33 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
34 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
35 for more details.
36 
37 You should have received copies of the GNU General Public License and the
38 GNU Lesser General Public License along with the GNU MP Library.  If not,
39 see https://www.gnu.org/licenses/.  */
40 
41 
42 #include "gmp.h"
43 #include "gmp-impl.h"
44 
45 /* Evaluate in:
46    0, +1, -1, +2, -2, 1/2, +inf
47 
48   <-s-><--n--><--n--><--n--><--n--><--n-->
49    ___ ______ ______ ______ ______ ______
50   |a5_|___a4_|___a3_|___a2_|___a1_|___a0_|
51 			     |_b1_|___b0_|
52 			     <-t--><--n-->
53 
54   v0  =    a0                       *   b0      #    A(0)*B(0)
55   v1  = (  a0+  a1+ a2+ a3+  a4+  a5)*( b0+ b1) #    A(1)*B(1)      ah  <= 5   bh <= 1
56   vm1 = (  a0-  a1+ a2- a3+  a4-  a5)*( b0- b1) #   A(-1)*B(-1)    |ah| <= 2   bh  = 0
57   v2  = (  a0+ 2a1+4a2+8a3+16a4+32a5)*( b0+2b1) #    A(2)*B(2)      ah  <= 62  bh <= 2
58   vm2 = (  a0- 2a1+4a2-8a3+16a4-32a5)*( b0-2b1) #   A(-2)*B(-2)    -41<=ah<=20 -1<=bh<=0
59   vh  = (32a0+16a1+8a2+4a3+ 2a4+  a5)*(2b0+ b1) #  A(1/2)*B(1/2)    ah  <= 62  bh <= 2
60   vinf=                           a5 *      b1  #  A(inf)*B(inf)
61 */
62 
void
mpn_toom62_mul (mp_ptr pp,
		mp_srcptr ap, mp_size_t an,
		mp_srcptr bp, mp_size_t bn,
		mp_ptr scratch)
{
  /* Compute {pp, an+bn} = {ap, an} * {bp, bn} by Toom-6.5x2: split A into
     six pieces and B into two pieces of n limbs each (last pieces short:
     s resp. t limbs), evaluate both at the seven points listed above,
     do seven pointwise multiplications, and interpolate.

     pp      -- result area, an + bn limbs.
     scratch -- caller-provided workspace; 10*n+5 limbs are used (see the
		scratch layout macros below).  */
  mp_size_t n, s, t;
  mp_limb_t cy;
  mp_ptr as1, asm1, as2, asm2, ash;	/* A evaluated at 1, -1, 2, -2, 1/2 */
  mp_ptr bs1, bsm1, bs2, bsm2, bsh;	/* B evaluated at 1, -1, 2, -2, 1/2 */
  mp_ptr gp;
  enum toom7_flags aflags, bflags;	/* sign bits for the odd-point products */
  TMP_DECL;

#define a0  ap
#define a1  (ap + n)
#define a2  (ap + 2*n)
#define a3  (ap + 3*n)
#define a4  (ap + 4*n)
#define a5  (ap + 5*n)
#define b0  bp
#define b1  (bp + n)

  /* Block size: large enough that A splits into <= 6 and B into <= 2
     pieces.  Which operand constrains n depends on the an/bn ratio.  */
  n = 1 + (an >= 3 * bn ? (an - 1) / (size_t) 6 : (bn - 1) >> 1);

  s = an - 5 * n;	/* size of top piece a5, 0 < s <= n */
  t = bn - n;		/* size of top piece b1, 0 < t <= n */

  ASSERT (0 < s && s <= n);
  ASSERT (0 < t && t <= n);

  TMP_MARK;

  /* Evaluation values need n+1 limbs: n limbs plus one carry limb.  */
  as1 = TMP_SALLOC_LIMBS (n + 1);
  asm1 = TMP_SALLOC_LIMBS (n + 1);
  as2 = TMP_SALLOC_LIMBS (n + 1);
  asm2 = TMP_SALLOC_LIMBS (n + 1);
  ash = TMP_SALLOC_LIMBS (n + 1);

  bs1 = TMP_SALLOC_LIMBS (n + 1);
  bsm1 = TMP_SALLOC_LIMBS (n);	/* |b0 - b1| fits n limbs (bh = 0 above) */
  bs2 = TMP_SALLOC_LIMBS (n + 1);
  bsm2 = TMP_SALLOC_LIMBS (n + 1);
  bsh = TMP_SALLOC_LIMBS (n + 1);

  gp = pp;	/* pp is free at this point; reuse it as eval scratch */

  /* Compute as1 and asm1.  The eval function's return value is masked down
     to the toom7_w3_neg bit, recording the sign of A(-1).  */
  aflags = (enum toom7_flags) (toom7_w3_neg & mpn_toom_eval_pm1 (as1, asm1, 5, ap, n, s, gp));

  /* Compute as2 and asm2.  Note & binds tighter than |, so this ORs in
     (toom7_w1_neg & eval-result), i.e. the sign of A(-2) -- intentional.  */
  aflags = (enum toom7_flags) (aflags | toom7_w1_neg & mpn_toom_eval_pm2 (as2, asm2, 5, ap, n, s, gp));

  /* Compute ash = 32 a0 + 16 a1 + 8 a2 + 4 a3 + 2 a4 + a5
     = 2*(2*(2*(2*(2*a0 + a1) + a2) + a3) + a4) + a5  */

#if HAVE_NATIVE_mpn_addlsh1_n
  /* Horner evaluation using fused shift-and-add; cy accumulates the high
     bits, doubled at each step like the main value.  */
  cy = mpn_addlsh1_n (ash, a1, a0, n);
  cy = 2*cy + mpn_addlsh1_n (ash, a2, ash, n);
  cy = 2*cy + mpn_addlsh1_n (ash, a3, ash, n);
  cy = 2*cy + mpn_addlsh1_n (ash, a4, ash, n);
  if (s < n)
    {
      /* Last piece is short: combine the s-limb addlsh1 with a plain
	 shift of the remaining n-s limbs, then propagate cy2.  */
      mp_limb_t cy2;
      cy2 = mpn_addlsh1_n (ash, a5, ash, s);
      ash[n] = 2*cy + mpn_lshift (ash + s, ash + s, n - s, 1);
      MPN_INCR_U (ash + s, n+1-s, cy2);
    }
  else
    ash[n] = 2*cy + mpn_addlsh1_n (ash, a5, ash, n);
#else
  /* Same Horner scheme via separate shift and add.  */
  cy = mpn_lshift (ash, a0, n, 1);
  cy += mpn_add_n (ash, ash, a1, n);
  cy = 2*cy + mpn_lshift (ash, ash, n, 1);
  cy += mpn_add_n (ash, ash, a2, n);
  cy = 2*cy + mpn_lshift (ash, ash, n, 1);
  cy += mpn_add_n (ash, ash, a3, n);
  cy = 2*cy + mpn_lshift (ash, ash, n, 1);
  cy += mpn_add_n (ash, ash, a4, n);
  cy = 2*cy + mpn_lshift (ash, ash, n, 1);
  ash[n] = cy + mpn_add (ash, ash, n, a5, s);
#endif

  /* Compute bs1 = b0 + b1 and bsm1 = |b0 - b1|, recording the sign of
     b0 - b1 in bflags (toom7_w3_neg).  */
  if (t == n)
    {
#if HAVE_NATIVE_mpn_add_n_sub_n
      if (mpn_cmp (b0, b1, n) < 0)
	{
	  cy = mpn_add_n_sub_n (bs1, bsm1, b1, b0, n);
	  bflags = toom7_w3_neg;
	}
      else
	{
	  cy = mpn_add_n_sub_n (bs1, bsm1, b0, b1, n);
	  bflags = (enum toom7_flags) 0;
	}
      bs1[n] = cy >> 1;	/* add carry is in the high bit of the return */
#else
      bs1[n] = mpn_add_n (bs1, b0, b1, n);
      if (mpn_cmp (b0, b1, n) < 0)
	{
	  mpn_sub_n (bsm1, b1, b0, n);
	  bflags = toom7_w3_neg;
	}
      else
	{
	  mpn_sub_n (bsm1, b0, b1, n);
	  bflags = (enum toom7_flags) 0;
	}
#endif
    }
  else
    {
      /* t < n: b1 is shorter, so b0 < b1 only if b0's top n-t limbs are
	 all zero.  */
      bs1[n] = mpn_add (bs1, b0, n, b1, t);
      if (mpn_zero_p (b0 + t, n - t) && mpn_cmp (b0, b1, t) < 0)
	{
	  mpn_sub_n (bsm1, b1, b0, t);
	  MPN_ZERO (bsm1 + t, n - t);
	  bflags = toom7_w3_neg;
	}
      else
	{
	  mpn_sub (bsm1, b0, n, b1, t);
	  bflags = (enum toom7_flags) 0;
	}
    }

  /* Compute bs2 and bsm2. Recycling bs1 and bsm1; bs2=bs1+b1, bsm2 =
     bsm1 - b1 */
  mpn_add (bs2, bs1, n + 1, b1, t);
  if (bflags & toom7_w3_neg)
    {
      /* bsm1 = b1 - b0, so bsm2 = bsm1 + b1 = 2 b1 - b0, also negative:
	 set toom7_w1_neg as well.  */
      bsm2[n] = mpn_add (bsm2, bsm1, n, b1, t);
      bflags = (enum toom7_flags) (bflags | toom7_w1_neg);
    }
  else
    {
      /* FIXME: Simplify this logic? */
      if (t < n)
	{
	  if (mpn_zero_p (bsm1 + t, n - t) && mpn_cmp (bsm1, b1, t) < 0)
	    {
	      /* b0 - b1 < b1, so b0 - 2 b1 is negative.  */
	      ASSERT_NOCARRY (mpn_sub_n (bsm2, b1, bsm1, t));
	      MPN_ZERO (bsm2 + t, n + 1 - t);
	      bflags = (enum toom7_flags) (bflags | toom7_w1_neg);
	    }
	  else
	    {
	      ASSERT_NOCARRY (mpn_sub (bsm2, bsm1, n, b1, t));
	      bsm2[n] = 0;
	    }
	}
      else
	{
	  if (mpn_cmp (bsm1, b1, n) < 0)
	    {
	      ASSERT_NOCARRY (mpn_sub_n (bsm2, b1, bsm1, n));
	      bflags = (enum toom7_flags) (bflags | toom7_w1_neg);
	    }
	  else
	    {
	      ASSERT_NOCARRY (mpn_sub_n (bsm2, bsm1, b1, n));
	    }
	  bsm2[n] = 0;
	}
    }

  /* Compute bsh, recycling bs1. bsh=bs1+b0;  */
  bsh[n] = bs1[n] + mpn_add_n (bsh, bs1, b0, n);

  /* High-limb bounds from the evaluation table in the header comment.  */
  ASSERT (as1[n] <= 5);
  ASSERT (bs1[n] <= 1);
  ASSERT (asm1[n] <= 2);
  ASSERT (as2[n] <= 62);
  ASSERT (bs2[n] <= 2);
  ASSERT (asm2[n] <= 41);
  ASSERT (bsm2[n] <= 1);
  ASSERT (ash[n] <= 62);
  ASSERT (bsh[n] <= 2);

  /* Scratch/product layout.  v0, v1, vinf go straight into pp; the rest
     live in scratch, each region 2n+1 limbs.  */
#define v0    pp				/* 2n */
#define v1    (pp + 2 * n)			/* 2n+1 */
#define vinf  (pp + 6 * n)			/* s+t */
#define v2    scratch				/* 2n+1 */
#define vm2   (scratch + 2 * n + 1)		/* 2n+1 */
#define vh    (scratch + 4 * n + 2)		/* 2n+1 */
#define vm1   (scratch + 6 * n + 3)		/* 2n+1 */
#define scratch_out (scratch + 8 * n + 4)		/* 2n+1 */
  /* Total scratch need: 10*n+5 */

  /* Must be in allocation order, as they overwrite one limb beyond
   * 2n+1. */
  mpn_mul_n (v2, as2, bs2, n + 1);		/* v2, 2n+1 limbs */
  mpn_mul_n (vm2, asm2, bsm2, n + 1);		/* vm2, 2n+1 limbs */
  mpn_mul_n (vh, ash, bsh, n + 1);		/* vh, 2n+1 limbs */

  /* vm1, 2n+1 limbs.  Multiply the low n limbs, then add in the cross
     term asm1[n] * bsm1 by hand (asm1[n] <= 2, bsm1[n] does not exist).  */
  mpn_mul_n (vm1, asm1, bsm1, n);
  cy = 0;
  if (asm1[n] == 1)
    {
      cy = mpn_add_n (vm1 + n, vm1 + n, bsm1, n);
    }
  else if (asm1[n] == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n
      cy = mpn_addlsh1_n (vm1 + n, vm1 + n, bsm1, n);
#else
      cy = mpn_addmul_1 (vm1 + n, bsm1, n, CNST_LIMB(2));
#endif
    }
  vm1[2 * n] = cy;

  /* v1, 2n+1 limbs.  Same idea, but both high limbs may be nonzero
     (as1[n] <= 5, bs1[n] <= 1), so add both cross terms and the
     high-limb product.  */
  mpn_mul_n (v1, as1, bs1, n);
  if (as1[n] == 1)
    {
      cy = bs1[n] + mpn_add_n (v1 + n, v1 + n, bs1, n);
    }
  else if (as1[n] == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n
      cy = 2 * bs1[n] + mpn_addlsh1_n (v1 + n, v1 + n, bs1, n);
#else
      cy = 2 * bs1[n] + mpn_addmul_1 (v1 + n, bs1, n, CNST_LIMB(2));
#endif
    }
  else if (as1[n] != 0)
    {
      cy = as1[n] * bs1[n] + mpn_addmul_1 (v1 + n, bs1, n, as1[n]);
    }
  else
    cy = 0;
  if (bs1[n] != 0)
    cy += mpn_add_n (v1 + n, v1 + n, as1, n);
  v1[2 * n] = cy;

  mpn_mul_n (v0, a0, b0, n);			/* v0, 2n limbs */

  /* vinf, s+t limbs.  mpn_mul wants the larger operand first.  */
  if (s > t)  mpn_mul (vinf, a5, s, b1, t);
  else        mpn_mul (vinf, b1, t, a5, s);

  /* Interpolate; the sign of each odd-point product is the XOR of the
     signs of its two factors.  */
  mpn_toom_interpolate_7pts (pp, n, (enum toom7_flags) (aflags ^ bflags),
			     vm2, vm1, v2, vh, s + t, scratch_out);

  TMP_FREE;
}
312