1 /* Arithmetic modulo Fermat numbers.
2 
3 Copyright 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2012 Alexander Kruppa,
4 Paul Zimmermann
5 
6 This file is part of the ECM Library.
7 
8 The ECM Library is free software; you can redistribute it and/or modify
9 it under the terms of the GNU Lesser General Public License as published by
10 the Free Software Foundation; either version 3 of the License, or (at your
11 option) any later version.
12 
13 The ECM Library is distributed in the hope that it will be useful, but
14 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
16 License for more details.
17 
18 You should have received a copy of the GNU Lesser General Public License
19 along with the ECM Library; see the file COPYING.LIB.  If not, see
20 http://www.gnu.org/licenses/ or write to the Free Software Foundation, Inc.,
21 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA. */
22 
23 #include <stdio.h>
24 #include <stdlib.h> /* for abs if assertions enabled */
25 #include "ecm-impl.h"
26 #include "ecm-gmp.h"
27 
28 #ifdef HAVE_LIMITS_H
29 # include <limits.h>
30 #else
31 # ifndef UINT_MAX
32 #  define UINT_MAX (~(unsigned int) 0)
33 # endif
34 #endif
35 
36 /*
37 #define DEBUG 1
38 #define CHECKSUM 1
39 */
40 
41 static mpz_t gt;
42 static int gt_inited = 0;
43 unsigned int Fermat;
44 
45 #define CACHESIZE 512U
46 
47 /* a' <- a+b, b' <- a-b. */
48 
49 #define ADDSUB_MOD(a, b) \
50   mpz_sub (gt, a, b); \
51   mpz_add (a, a, b);  \
52   F_mod_gt (b, n);    \
53   F_mod_1 (a, n);
54 
55 __GMP_DECLSPEC mp_limb_t __gmpn_mod_34lsub1 (mp_limb_t*, mp_size_t);
56 
57 /* compute remainder modulo 2^(GMP_LIMB_BITS*3/4)-1 */
58 #ifndef HAVE___GMPN_MOD_34LSUB1
59 mp_limb_t
__gmpn_mod_34lsub1(mp_limb_t * src,mp_size_t size)60 __gmpn_mod_34lsub1 (mp_limb_t *src, mp_size_t size)
61 {
62   mp_ptr tp;
63   mp_limb_t r, d;
64 
65   ASSERT(GMP_LIMB_BITS % 4 == 0);
66   tp = malloc (size * sizeof (mp_limb_t));
67   if (tp == NULL)
68     {
69       fprintf (stderr, "Cannot allocate memory in __gmpn_mod_34lsub1\n");
70       exit (1);
71     }
72   MPN_COPY (tp, src, size);
73   d = ((mp_limb_t) 1 << (3 * (GMP_LIMB_BITS / 4))) - (mp_limb_t) 1;
74   mpn_divmod_1 (&r, tp, size, d);
75   free (tp);
76   return r;
77 }
78 #endif
79 
80 /* RS -> RS (mod 2^n+1). If input |RS| < 2^(2*n), result |RS| < 2^(n+1) */
81 
82 static inline void
F_mod_1(mpz_t RS,unsigned int n)83 F_mod_1 (mpz_t RS, unsigned int n)
84 {
85   mp_size_t size;
86   mp_limb_t v;
87   int sgn;
88 
89   size = mpz_size (RS);
90 
91   ASSERT_ALWAYS(size <= (mp_size_t) n / GMP_NUMB_BITS + 1);
92   sgn = mpz_sgn (RS);          /* Remember original sign */
93   v = mpz_getlimbn (RS, n / GMP_NUMB_BITS);
94   mpz_tdiv_r_2exp (RS, RS, n); /* Just a truncate. RS < 2^n. Can make
95                                   RS zero and so change sgn(RS)! */
96   if (sgn == -1)
97     mpz_add_ui (RS, RS, v);
98   else
99     mpz_sub_ui (RS, RS, v);
100 }
101 
102 /* R = gt (mod 2^n+1) */
103 
104 static inline void
F_mod_gt(mpz_t R,unsigned int n)105 F_mod_gt (mpz_t R, unsigned int n)
106 {
107   mp_size_t size;
108   mp_limb_t v;
109 
110   size = mpz_size (gt);
111 
112   ASSERT(R != gt);
113 
114   if ((unsigned int) size == n / GMP_NUMB_BITS + 1)
115     {
116       int sgn;
117       sgn = mpz_sgn (gt);
118       v = mpz_getlimbn (gt, n / GMP_NUMB_BITS);
119       mpz_tdiv_r_2exp (gt, gt, n); /* Just a truncate */
120       if (sgn == -1)
121           mpz_add_ui (R, gt, v);
122       else
123           mpz_sub_ui (R, gt, v);
124     }
125   else if ((unsigned int) size > n / GMP_NUMB_BITS + 1)
126     {
127       mpz_tdiv_q_2exp (R, gt, n);
128       mpz_tdiv_r_2exp (gt, gt, n); /* Just a truncate */
129       mpz_sub (R, gt, R);
130     }
131   else
132     mpz_set (R, gt);
133 }
134 
135 
136 /* R = S1 * S2 (mod 2^n+1) where n is a power of 2
137    S1 == S2, S1 == R, S2 == R ok, but none may == gt.
138    Assume n >= GMP_NUMB_BITS, and GMP_NUMB_BITS is a power of two. */
139 static void
F_mulmod(mpz_t R,mpz_t S1,mpz_t S2,unsigned int n)140 F_mulmod (mpz_t R, mpz_t S1, mpz_t S2, unsigned int n)
141 {
142   int n2 = n / GMP_NUMB_BITS; /* type of _mp_size is int */
143 
144   F_mod_1 (S1, n);
145   F_mod_1 (S2, n);
146   ASSERT(mpz_size (S1) <= (unsigned) n2);
147   ASSERT(mpz_size (S2) <= (unsigned) n2);
148 
149   if (n >= 32768)
150     {
151       unsigned long k;
152 
153       _mpz_realloc (gt, n2 + 1);
154       /* in case the reallocation fails, _mpz_realloc sets the value to 0 */
155       ASSERT_ALWAYS (mpz_cmp_ui (gt, 0) != 0);
156       k = mpn_fft_best_k (n2, S1 == S2);
157       /* the following cannot be changed to use mpn_mulmod_bnm1 since we
158          are precisely multiplying modulo a Fermat number */
159       mpn_mul_fft (PTR(gt), n2, PTR(S1), ABSIZ(S1), PTR(S2), ABSIZ(S2), k);
160       MPN_NORMALIZE(PTR(gt), n2);
161       SIZ(gt) = ((SIZ(S1) ^ SIZ(S2)) >= 0) ? n2 : -n2;
162       F_mod_gt (R, n);
163       return;
164     }
165   mpz_mul (gt, S1, S2);
166   F_mod_gt (R, n);
167   return;
168 }
169 
170 /* R = S + sgn(S)*(2^e) */
171 
172 static void
mpz_absadd_2exp(mpz_t RS,unsigned int e)173 mpz_absadd_2exp (mpz_t RS, unsigned int e)
174 {
175   mp_size_t siz, limb_idx, bit_idx;
176   mp_limb_t cy;
177   int sgn;
178 
179   limb_idx = e / GMP_NUMB_BITS;
180   bit_idx = e % GMP_NUMB_BITS;
181   siz = mpz_size (RS);
182   sgn = (mpz_sgn (RS) >= 0) ? 1 : -1;
183 
184   if (limb_idx >= RS->_mp_alloc)
185     /* WARNING: mpz_realloc2 does not keep the value!!! */
186     mpz_realloc2 (RS, (limb_idx + 1) * GMP_NUMB_BITS);
187 
188   /* Now RS->_mp_alloc > limb_idx) */
189 
190   while (siz <= limb_idx)
191     {
192       RS->_mp_d[siz++] = 0;
193       RS->_mp_size += sgn;
194     }
195 
196   /* Now RS->_mp_alloc >= siz > limb_idx */
197 
198   cy = mpn_add_1 (RS->_mp_d + limb_idx, RS->_mp_d + limb_idx,
199                   siz - limb_idx, ((mp_limb_t)1) << bit_idx);
200   if (cy)
201     {
202       if (RS->_mp_alloc <= siz)
203         /* WARNING: mpz_realloc2 does not keep the value!!! */
204         mpz_realloc2 (RS, (siz + 1) * GMP_NUMB_BITS);
205 
206       RS->_mp_d[siz] = 1;
207       RS->_mp_size += sgn;
208     }
209 }
210 
211 /* R = S / 2 (mod 2^n + 1). S == gt is ok */
212 
213 static void
F_divby2(mpz_t R,mpz_t S,unsigned int n)214 F_divby2 (mpz_t R, mpz_t S, unsigned int n)
215 {
216   int odd, sgn;
217 
218   odd = mpz_odd_p (S);
219   sgn = mpz_sgn (S);
220   mpz_tdiv_q_2exp (R, S, 1);
221 
222   if (odd)
223     {
224       /* We shifted out a set bit at the bottom. With negative wrap-around,
225          that becomes -2^(n-1), so we add -2^(n-1) + 2^n+1 = 2^(n-1)+1.
226          If |S| < 2^(n+1), |R| < 2^n + 2^(n-1) + 1 < 2^(n+1) for n > 1. */
227 
228       mpz_absadd_2exp (R, n - 1);
229       if (sgn < 0)
230         mpz_sub_ui (R, R, 1);
231       else
232         mpz_add_ui (R, R, 1);
233     }
234 }
235 
236 
237 /* RS = RS / 3 (mod 2^n + 1). RS == gt is ok */
238 
239 static void
F_divby3_1(mpz_t RS,unsigned int n)240 F_divby3_1 (mpz_t RS, unsigned int n)
241 {
242   /* 2^2^m == 1 (mod 3) for m>0, thus F_m == 2 (mod 3) */
243   int mod, sgn;
244 
245   sgn = mpz_sgn (RS);
246   mod = __gmpn_mod_34lsub1 (RS->_mp_d, mpz_size (RS)) % 3;
247 
248   if (mod == 1)
249     {
250       /* Add F_m. If |RS| < 2^(n+1), |RS|+F_m < 3*2^n+1 */
251       mpz_absadd_2exp (RS, n);
252       if (sgn >= 0)
253         mpz_add_ui (RS, RS, 1);
254       else
255         mpz_sub_ui (RS, RS, 1);
256     }
257   else if (mod == 2)
258     {
259       /* Add 2 * F_m.  If |RS| < 2^(n+1), |RS|+2*F_m < 4*2^n+2 */
260       mpz_absadd_2exp (RS, n + 1);
261       if (sgn >= 0)
262         mpz_add_ui (RS, RS, 2);
263       else
264         mpz_sub_ui (RS, RS, 2);
265     }
266 
267   mpz_divby3_1op (RS); /* |RS| < (4*2^n+2)/3 < 2^(n+1) */
268 }
269 
270 static void
F_divby5_1(mpz_t RS,unsigned int n)271 F_divby5_1 (mpz_t RS, unsigned int n)
272 {
273   /* 2^2^m == 1 (mod 5) for m>1, thus F_m == 2 (mod 5) */
274   int mod, sgn;
275 
276   sgn = mpz_sgn (RS);
277   mod = __gmpn_mod_34lsub1 (RS->_mp_d, mpz_size (RS)) % 5;
278 
279   if (mod == 1)
280     {
281       /* Add 2 * F_m == 4 (mod 5) */
282       mpz_absadd_2exp (RS, n + 1);
283       if (sgn == 1)
284         mpz_add_ui (RS, RS, 2);
285       else
286         mpz_sub_ui (RS, RS, 2);
287     }
288   else if (mod == 2)
289     {
290       /* Add 4 * F_m == 3 (mod 5) */
291       mpz_absadd_2exp (RS, n + 2);
292       if (sgn == 1)
293         mpz_add_ui (RS, RS, 4);
294       else
295         mpz_sub_ui (RS, RS, 4);
296     }
297   else if (mod == 3)
298     {
299       /* Add F_m == 3 (mod 5) */
300       mpz_absadd_2exp (RS, n);
301       if (sgn == 1)
302         mpz_add_ui (RS, RS, 1);
303       else
304         mpz_sub_ui (RS, RS, 1);
305     }
306   else if (mod == 4)
307     {
308       /* Add 3 * F_m == 1 (mod 5) */
309       mpz_absadd_2exp (RS, n);
310       mpz_absadd_2exp (RS, n + 1);
311       if (sgn == 1)
312         mpz_add_ui (RS, RS, 3);
313       else
314         mpz_sub_ui (RS, RS, 3);
315     }
316 
317   ASSERT(mpz_divisible_ui_p (RS, 5));
318   mpz_divexact_ui (RS, RS, 5);
319 }
320 
321 
/* A 2^(m+2) length convolution is possible:
   (2^(3n/4) - 2^(n/4))^2 == 2 (mod 2^n+1)
   so we have an element of order 2^(m+2) of simple enough form
   to use it as a root of unity in the transform */
326 
327 /* Multiply by sqrt(2)^e (mod F_m).  n = 2^m */
328 /* R = (S * sqrt(2)^e) % (2^n+1) */
329 /* R == S is ok, but neither must be == gt */
330 /* Assumes 0 < e < 4*n, and e <> 2*n */
331 
332 static void
F_mul_sqrt2exp(mpz_t R,mpz_t S,int e,unsigned int n)333 F_mul_sqrt2exp (mpz_t R, mpz_t S, int e, unsigned int n)
334 {
335   int chgsgn = 0, odd;
336 
337   ASSERT(S != gt);
338   ASSERT(R != gt);
339   ASSERT(0 < e && (unsigned int) e < 4 * n && (unsigned int) e != 2 * n);
340 
341   /* 0 < e < 4*n */
342   if ((unsigned) e > 2 * n)     /* sqrt(2)^(2*n) == -1 (mod F_m), so */
343     {
344       e -= 2 * n;               /* sqrt(2)^e == -sqrt(2)^(e-2*n) (mod F_m) */
345       chgsgn = 1;
346     }				/* Now e < 2*n */
347 
348   odd = e & 1;
349   e >>= 1;
350 
351   if (odd)
352     {
353       /* Multiply by sqrt(2) == 2^(3n/4) - 2^(n/4) */
354       /* S * (2^(3n/4) - 2^(n/4)) == 2^(n/4) * (S*2^(n/2) - S) */
355       mpz_mul_2exp (gt, S, n / 2);
356       mpz_sub (gt, gt, S);
357       mpz_tdiv_q_2exp (R, gt, n / 4 * 3);
358       mpz_tdiv_r_2exp (gt, gt, n / 4 * 3);
359       mpz_mul_2exp (gt, gt, n / 4);
360       mpz_sub (R, gt, R);
361 
362       if (e != 0)
363         {
364           mpz_tdiv_q_2exp (gt, R, n-e);
365           mpz_tdiv_r_2exp (R, R, n-e);
366           mpz_mul_2exp (R, R, e);
367           mpz_sub (R, R, gt);
368         }
369     }
370   else /* necessarily e <> 0 */
371     {
372       ASSERT (e != 0);
373       /*  S     = a*2^(n-e) + b,   b < 2^(n-e)  */
374       /*  S*2^e = a*2^n + b*2^e = b*2^e - a */
375       /*  b*2^e < 2^(n-e)*2^e = 2^n */
376       mpz_tdiv_q_2exp (gt, S, n - e); /* upper e bits (=a) into gt */
377       mpz_tdiv_r_2exp (R, S, n - e);  /* lower n-e bits (=b) into R */
378                                       /* This is simply a truncate if S == R */
379       mpz_mul_2exp (R, R, e);         /* R < 2^n */
380       mpz_sub (R, R, gt);
381     }
382 
383   if (chgsgn)
384     mpz_neg (R, R);
385 }
386 
387 /* Same, but input may be gt. Input and output must not be identical.
388    Currently this routine is always called with e=n, with n a power of 2,
389    thus we assume e is even. Moreover we assume 0 < e < 2n. */
390 static void
F_mul_sqrt2exp_2(mpz_t R,mpz_t S,int e,unsigned int n)391 F_mul_sqrt2exp_2 (mpz_t R, mpz_t S, int e, unsigned int n)
392 {
393   ASSERT (S != R);
394   ASSERT (R != gt);
395   ASSERT (0 < e && (unsigned) e < 2 * n);
396   ASSERT ((e & 1) == 0);
397 
398   e >>= 1;
399 
400   mpz_tdiv_q_2exp (R, S, n - e); /* upper e bits into R */
401   mpz_tdiv_r_2exp (gt, S, n - e); /* lower n-e bits into gt */
402   mpz_mul_2exp (gt, gt, e);
403   mpz_sub (R, gt, R);
404 }
405 
/* Element accessors for the radix-4 butterflies in F_fft_dif() and
   F_fft_dit().  They expand in the caller's scope and rely on the names
   A (coefficient array), l (quarter of the current transform length),
   stride2 (log2 of the element stride) and, for the "is" forms, i
   (index inside the quarter).
   AKs  = first element of quarter K;  AKis = element i of quarter K. */
#define A0s A[0]
#define A1s A[l << stride2]
#define A2s A[2 * l << stride2]
#define A3s A[3 * l << stride2]
#define A0is A[i << stride2]
#define A1is A[(i + l) << stride2]
#define A2is A[(i + 2 * l) << stride2]
#define A3is A[(i + 3 * l) << stride2]
414 
415 /* Decimation-in-frequency FFT. Unscrambled input, scrambled output.
416    Elements are (mod 2^n+1), l and n must be powers of 2, l must be <= 4*n.
417    Performs forward transform.
418    Assumes l > 1. */
419 static void
F_fft_dif(mpz_t * A,int l,int stride2,int n)420 F_fft_dif (mpz_t *A, int l, int stride2, int n)
421 {
422   int i, omega = (4 * n) / l, iomega;
423 
424   ASSERT (l > 1);
425 
426   ASSERT((4 * n) % l == 0);
427 
428   if (l == 2)
429     {
430       ADDSUB_MOD(A[0], A[1<<stride2]);
431       return;
432     }
433 
434   l /= 4;
435 
436   mpz_sub (gt, A1s, A3s);            /* gt = a1 - a3 */
437   mpz_add (A1s, A1s, A3s);           /* A1 = a1 + a3 */
438   F_mul_sqrt2exp_2 (A3s, gt, n, n);  /* A3 = i * (a1 - a3) */
439 
440   mpz_sub (gt, A[0], A2s);           /* gt = a0 - a2 */
441   mpz_add (A[0], A[0], A2s);         /* A0 = a0 + a2 */
442 
443   mpz_sub (A2s, A[0], A1s);          /* A2 = a0 - a1 + a2 - a3 */
444   mpz_add (A[0], A[0], A1s);         /* A0 = a0 + a1 + a2 + a3 */
445   mpz_add (A1s, gt, A3s);            /* A1 = a0 - a2 + i * (a1 - a3) */
446   mpz_sub (A3s, gt, A3s);            /* A3 = a0 - a2 - i * (a1 - a3) */
447 
448   for (i = 1, iomega = omega; i < l; i++, iomega += omega)
449     {
450       mpz_sub (gt, A1is, A3is);
451       mpz_add (A1is, A1is, A3is);
452       F_mul_sqrt2exp_2 (A3is, gt, n, n);
453 
454       mpz_sub (gt, A0is, A2is);
455       mpz_add (A0is, A0is, A2is);
456 
457       mpz_sub (A2is, A0is, A1is);
458       mpz_add (A0is, A0is, A1is);
459       mpz_add (A1is, gt, A3is);
460       mpz_sub (A3is, gt, A3is);
461       /* iomega goes from 4n/l to n-4n/l (with original l) thus cannot
462          equal 0 nor 2n */
463       F_mul_sqrt2exp (A1is, A1is, iomega, n);
464       /* 2*iomega goes from 8n/l to 2n-8n/l (with original l) thus cannot
465          equal 0 nor 2n */
466       F_mul_sqrt2exp (A2is, A2is, 2 * iomega, n);
467       /* 3*iomega goes from 12n/l to 3n-12n/l (with original l) thus cannot
468          equal 0 nor 2n (because n is a power of 2) */
469       F_mul_sqrt2exp (A3is, A3is, 3 * iomega, n);
470     }
471 
472   if (l > 1)
473     {
474       F_fft_dif (A, l, stride2, n);
475       F_fft_dif (A + (l << stride2), l, stride2, n);
476       F_fft_dif (A + (2 * l << stride2), l, stride2, n);
477       F_fft_dif (A + (3 * l << stride2), l, stride2, n);
478     }
479 }
480 
481 /* Decimation-in-time inverse FFT. Scrambled input, unscrambled output.
482    Does not perform divide-by-length. l, and n as in F_fft_dif().
483    Assume l > 1. */
484 static void
F_fft_dit(mpz_t * A,int l,int stride2,int n)485 F_fft_dit (mpz_t *A, int l, int stride2, int n)
486 {
487   int i, omega = (4 * n) / l, iomega;
488 
489   ASSERT (l > 1);
490 
491   ASSERT((4 * n) % l == 0);
492 
493   if (l == 2)
494     {
495       ADDSUB_MOD(A[0], A[1<<stride2]);
496       return;
497     }
498 
499   l /= 4;
500 
501   if (l > 1)
502     {
503       F_fft_dit (A, l, stride2, n);
504       F_fft_dit (A + (l << stride2), l, stride2, n);
505       F_fft_dit (A + (2 * l << stride2), l, stride2, n);
506       F_fft_dit (A + (3 * l << stride2), l, stride2, n);
507     }
508 
509   mpz_sub (gt, A3s, A1s);            /* gt = -(a1 - a3) */
510   mpz_add (A1s, A1s, A3s);           /* A1 = a1 + a3 */
511   F_mul_sqrt2exp_2 (A3s, gt, n, n);  /* A3 = i * -(a1 - a3) */
512 
513   mpz_sub (gt, A[0], A2s);           /* gt = a0 - a2 */
514   mpz_add (A[0], A[0], A2s);         /* A0 = a0 + a2 */
515 
516   mpz_sub (A2s, A[0], A1s);          /* A2 = a0 - a1 + a2 - a3 */
517   mpz_add (A[0], A[0], A1s);         /* A0 = a0 + a1 + a2 + a3 */
518   mpz_add (A1s, gt, A3s);            /* A1 = a0 - a2 + i * -(a1 - a3) */
519   mpz_sub (A3s, gt, A3s);            /* A3 = a0 - a2 - i * -(a1 - a3) */
520 
521   for (i = 1, iomega = omega; i < l; i++, iomega += omega)
522     {
523       /* Divide by omega^i. Since sqrt(2)^(4*n) == 1 (mod 2^n+1),
524          this is like multiplying by omega^(4*n-i) */
525       /* iomega goes from 4n/l to n-4n/l (with original l) thus
526          3n < 4*n-iomega < 4n */
527       F_mul_sqrt2exp (A1is, A1is, 4 * n - iomega, n);
528       /* 2n < 4*n-2*iomega < 4n */
529       F_mul_sqrt2exp (A2is, A2is, 4 * n - 2 * iomega, n);
530       /* n < 4*n-3*iomega < 4n, and 4*n-3*iomega cannot equal 2n since
531          n is a power of 2 and 3*iomega is divisible by 3 */
532       F_mul_sqrt2exp (A3is, A3is, 4 * n - 3 * iomega, n);
533 
534       mpz_sub (gt, A3is, A1is);
535       mpz_add (A1is, A1is, A3is);
536       F_mul_sqrt2exp_2 (A3is, gt, n, n);
537 
538       mpz_sub (gt, A0is, A2is);
539       mpz_add (A0is, A0is, A2is);
540 
541       mpz_sub (A2is, A0is, A1is);
542       mpz_add (A0is, A0is, A1is);
543       mpz_add (A1is, gt, A3is);
544       mpz_sub (A3is, gt, A3is);
545 
546       F_mod_1 (A0is, n);
547       F_mod_1 (A1is, n);
548       F_mod_1 (A2is, n);
549       F_mod_1 (A3is, n);
550     }
551 }
552 
/* Index shorthands for F_toomcook4().  They expand in the caller's scope
   and rely on the local names i (index inside a piece) and l (length of
   one piece), plus the arrays A, B, C and t.  XK is element i of the K-th
   piece of length l of array X.
   Note: F_toomcook4() #undef's some of these (A1, A2, A3, B0, B1, B2) in
   the middle of its body; the others stay defined afterwards, so these
   identifiers cannot be used as ordinary names further down the file. */
#define A0 A[i]
#define A1 A[l+i]
#define A2 A[2*l+i]
#define A3 A[3*l+i]
#define B0 B[i]
#define B1 B[l+i]
#define B2 B[2*l+i]
#define B3 B[3*l+i]
#define C0 C[i]
#define C1 C[l+i]
#define C2 C[2*l+i]
#define C3 C[3*l+i]
#define C4 C[4*l+i]
#define C5 C[5*l+i]
#define C6 C[6*l+i]
#define C7 C[7*l+i]
#define t0 t[i]
#define t1 t[l+i]
#define t2 t[2*l+i]
#define t3 t[3*l+i]
#define t4 t[4*l+i]
#define t5 t[5*l+i]
575 
576 
/* Toom-Cook 4-way split of a polynomial product with coefficients
   modulo 2^n+1.

   A and B (len coefficients each, len a multiple of 4) are each cut into
   four pieces of l = len/4 coefficients.  The pieces are evaluated at
   0, 1, -1, 2, -2, 1/2 (scaled by 8) and infinity, the seven pointwise
   products are computed by F_mul() on length l, and the product
   coefficients are recovered by interpolation and accumulated into C.

   C : result array, pieces C0..C7 of length l are written.
       NOTE(review): the "may overwrite" remarks below suggest C is allowed
       to overlap A and B (layout A = C, B = C + 4*l) -- confirm with the
       callers.
   t : scratch; pieces t0..t5 (6*l entries) are used here and t + 6*l is
       handed to F_mul() as its scratch area.
   gt is used as a temporary throughout.

   Returns the sum of the return values of the seven F_mul() calls.

   Assume A <> B. There was some code for squaring (A=B) in revision <= 2788.
 */
static unsigned int
F_toomcook4 (mpz_t *C, mpz_t *A, mpz_t *B, unsigned int len, unsigned int n,
             mpz_t *t)
{
  unsigned int l, i, r;

  ASSERT(A != B);
  ASSERT(len % 4 == 0);

  l = len / 4;

  /* Evaluation.  The order of the statements matters: results are stored
     into C while pieces of A and B that may share storage with C are
     still needed.  Each #undef marks the point after which that piece
     must not be read any more. */
  for (i = 0; i < l; i++)
    {
      /*** Evaluate A(2), A(-2), 8*A(1/2) ***/
      mpz_mul_2exp (t0, A0, 1);
      mpz_add (t0, t0, A1);
      mpz_mul_2exp (t0, t0, 1);
      mpz_add (t0, t0, A2);
      mpz_mul_2exp (t0, t0, 1);
      mpz_add (t0, t0, A3);         /* t[0 .. l-1] = 8*A(1/2) < 15*N */
      F_mod_1 (t0, n);

      mpz_mul_2exp (t2, A3, 2);
      mpz_add (t2, t2, A1);
      mpz_mul_2exp (t2, t2, 1);     /* t[2l .. 3l-1] = 8*A_3 + 2*A_1 */

      mpz_mul_2exp (gt, A2, 2);
      mpz_add (gt, gt, A0);         /* gt = 4*A_2 + A0 */
      mpz_sub (t4, gt, t2);         /* t[4l .. 5l-1] = A(-2) */
      mpz_add (t2, t2, gt);         /* t[2l .. 3l-1] = A(2) */
      F_mod_1 (t4, n);
      F_mod_1 (t2, n);

      /*** Evaluate B(2), B(-2), 8*B(1/2) ***/
      mpz_mul_2exp (t1, B0, 1);
      mpz_add (t1, t1, B1);
      mpz_mul_2exp (t1, t1, 1);
      mpz_add (t1, t1, B2);
      mpz_mul_2exp (t1, t1, 1);
      mpz_add (t1, t1, B3);         /* t[l .. 2l-1] = 8*B(1/2) */
      F_mod_1 (t1, n);

      mpz_mul_2exp (t3, B3, 2);
      mpz_add (t3, t3, B1);
      mpz_mul_2exp (t3, t3, 1);     /* t[3l .. 4l-1] = 8*B_3 + 2*B_1 */

      mpz_mul_2exp (gt, B2, 2);
      mpz_add (gt, gt, B0);         /* gt = 4*B_2 + B0 */
      mpz_sub (t5, gt, t3);         /* t[5l .. 6l-1] = B(-2) */
      mpz_add (t3, t3, gt);         /* t[3l .. 4l-1] = B(2) */
      F_mod_1 (t5, n);
      F_mod_1 (t3, n);

      /* Evaluate A(1), A(-1) */
      mpz_add (C2, A0, A2);         /* May overwrite A2 */
#undef A2
      mpz_add (gt, A1, A3);
      mpz_set (C1, B0);             /* C1 = B(0) May overwrite A1 */
#undef A1
      mpz_sub (C4, C2, gt);         /* C4 = A(-1). May overwrite B0 */
#undef B0
      mpz_add (C2, C2, gt);         /* C2 = A(1) < 4*N */
      F_mod_1 (C2, n);
      F_mod_1 (C4, n);

      /* Evaluate B(1), B(-1) */
      mpz_add (gt, C1, B2);         /* B0 is in C1 */
      mpz_set (C6, A3);             /* C6 = A(inf) May overwrite B2 */
#undef B2
      mpz_add (C3, B1, B3);         /* May overwrite A3 */
#undef A3
      mpz_sub (C5, gt, C3);         /* C5 = B(-1). May overwrite B1 */
#undef B1
      mpz_add (C3, gt, C3);         /* C3 = B(1) */
      F_mod_1 (C3, n);
      F_mod_1 (C5, n);
    }

  /* Layout after the evaluation loop (first line: where the input pieces
     were if C overlaps A and B; second line: what C now holds): */
  /* A0 A1   A2   A3   B0    B1   B2 B3 */
  /* A0 B0  A(1) B(1) A(-1) B(-1) A3 B3 */
  /* C0 C1   C2   C3   C4    C5   C6 C7 */

  /* Seven pointwise products of length l.  Each product is written over
     its first operand and the piece following it. */
  r = F_mul (t, t, t + l, l, DEFAULT, n, t + 6 * l);
  /* t0 = 8*A(1/2) * 8*B(1/2) = 64*C(1/2) */
  r += F_mul (t + 2 * l, t + 2 * l, t + 3 * l, l, DEFAULT, n, t + 6 * l);
  /* t2 = A(2) * B(2) = C(2) */
  r += F_mul (t + 4 * l, t + 4 * l, t + 5 * l, l, DEFAULT, n, t + 6 * l);
  /* t4 = A(-2) * B(-2) = C(-2) */
  r += F_mul (C, A, C + l, l, DEFAULT, n, t + 6 * l);
  /* C0 = A(0)*B(0) = C(0) */
  r += F_mul (C + 2 * l, C + 2 * l, C + 3 * l, l, DEFAULT, n, t + 6 * l);
  /* C2 = A(1)*B(1) = C(1) */
  r += F_mul (C + 4 * l, C + 4 * l, C + 5 * l, l, DEFAULT, n, t + 6 * l);
  /* C4 = A(-1)*B(-1) = C(-1) */
  r += F_mul (C + 6 * l, C + 6 * l, B + 3 * l, l, DEFAULT, n, t + 6 * l);
  /* C6 = A(inf)*B(inf) = C(inf) */

/* C(0)   C(1)   C(-1)  C(inf)  64*C(1/2)  C(2)   C(-2) */
/* C0,C1  C2,C3  C4,C5  C6,C7   t0,t1      t2,t3  t4,t5 */

  /* Interpolation, one coefficient position at a time (each product has
     2*l-1 coefficients).  The number rows in the comments give the
     weights of C_0 .. C_6 in the value currently held. */
  for (i = 0; i < 2 * l - 1; i++)
    {
      mpz_add (t0, t0, t2);             /* t0 = 65 34 20 16 20 34 65 */

      mpz_sub (gt, C2, C4);             /* gt = 2*C_odd(1) = 0 2 0 2 0 2 0 */
      mpz_add (C2, C2, C4);             /* C2 = 2*C_even(1) = 2 0 2 0 2 0 2 */
      F_divby2 (C2, C2, n);             /* C2 = C_even(1) */

      mpz_add (C4, t2, t4);             /* C4 = 2*C_even(2) */
      F_divby2 (C4, C4, n);             /* C4 = C_even(2) */
      mpz_sub (t4, t2, t4);             /* t4 = 2*C_odd(2) */
      F_divby2 (t4, t4, n);
      F_divby2 (t4, t4, n);             /* t4 = C_odd(2)/2 = C_1 + 4*C_3 + 16*C_5 */
      F_divby2 (t2, gt, n);             /* t2 = C_odd(1) */

      mpz_sub (t0, t0, gt);             /* t0 = 65 32 20 14 20 32 65 */
      mpz_mul_2exp (gt, gt, 4);
      mpz_sub (t0, t0, gt);             /* t0 = 65 0 20 -18 20 0 65 */

      mpz_add (gt, C0, C6);             /* gt = C_0 + C_6 */
      mpz_sub (C2, C2, gt);             /* C2 = C_2 + C_4 */
      mpz_sub (t0, t0, gt);             /* t0 = 64 0 20 -18 20 0 64 */
      mpz_mul_2exp (gt, gt, 5);         /* gt = 32*C_0 + 32*C_6 */
      F_divby2 (t0, t0, n);             /* t0 = 32 0 10 -9 10 0 32 */
      mpz_sub (t0, t0, gt);             /* t0 = 0 0 10 -9 10 0 0 */
      mpz_sub (t0, t0, C2);             /* t0 = 0 0 9 -9 9 0 0 */
      F_divby3_1 (t0, n);
      F_divby3_1 (t0, n);               /* t0 = 0 0 1 -1 1 0 0 */
      mpz_sub (t0, C2, t0);             /* t0 = C_3 */
      mpz_sub (t2, t2, t0);             /* t2 = C_1 + C_5 */
      mpz_mul_2exp (gt, t0, 2);         /* gt = 4*C_3 */
      mpz_sub (t4, t4, gt);             /* t4 = C_1 + 16*C_5 */
      mpz_sub (t4, t4, t2);             /* t4 = 15*C_5 */
      F_divby3_1 (t4, n);
      F_divby5_1 (t4, n);               /* t4 = C_5 */
      mpz_sub (t2, t2, t4);             /* t2 = C_1 */

      mpz_sub (C4, C4, C0);             /* C4 = 4*C_2 + 16*C_4 + 64*C_6 */
      F_divby2 (C4, C4, n);
      F_divby2 (C4, C4, n);             /* C4 = C_2 + 4*C_4 + 16*C_6 */

      mpz_mul_2exp (gt, C6, 4);
      mpz_sub (C4, C4, gt);             /* C4 = C_2 + 4*C_4 */

      mpz_sub (C4, C4, C2);             /* C4 = 3*C_4 */
      F_divby3_1 (C4, n);               /* C4 = C_4 */
      mpz_sub (C2, C2, C4);             /* C2 = C_2 */
    }

  /* Recombination: the odd-index results held in t2 (C_1), t0 (C_3) and
     t4 (C_5) are added into C at offsets l, 3*l and 5*l.  In each group
     the first loop ends with i == l - 1; that single position is stored
     with mpz_set instead of being added, and the second loop handles
     i = l .. 2*l-2. */
  for (i = 0; i < l - 1; i++)
    {
      mpz_add (C1, C1, t2);
      F_mod_1 (C1, n);
    }
  mpz_set (C1, t2);
  F_mod_1 (C1, n);
  for (i = l; i < 2 * l - 1; i++)
    {
      mpz_add (C1, C1, t2);
      F_mod_1 (C1, n);
    }

  for (i = 0; i < l - 1; i++)
    {
      mpz_add (C3, C3, t0);
      F_mod_1 (C3, n);
    }
  mpz_set (C3, t0);
  F_mod_1 (C3, n);
  for (i = l; i < 2 * l - 1; i++)
    {
      mpz_add (C3, C3, t0);
      F_mod_1 (C3, n);
    }

  for (i = 0; i < l - 1; i++)
    {
      mpz_add (C5, C5, t4);
      F_mod_1 (C5, n);
    }
  mpz_set (C5, t4);
  F_mod_1 (C5, n);
  for (i = l; i < 2 * l - 1; i++)
    {
      mpz_add (C5, C5, t4);
      F_mod_1 (C5, n);
    }

  return r;
}
769 
770 
771 /* Karatsuba split. Calls F_mul() to multiply the three pieces.
772    Assume A <> B (there was code for squaring in revision <= 2788. */
773 static unsigned int
F_karatsuba(mpz_t * R,mpz_t * A,mpz_t * B,unsigned int len,unsigned int n,mpz_t * t)774 F_karatsuba (mpz_t *R, mpz_t *A, mpz_t *B, unsigned int len, unsigned int n,
775              mpz_t *t)
776 {
777   unsigned int i, r;
778 
779   ASSERT(len % 2 == 0);
780 
781   len /= 2;
782 
783   for (i = 0; i < len; i++)
784     {
785       mpz_add (t[i],       A[i], A[i + len]); /* t0 = A0 + A1 */
786       mpz_add (t[i + len], B[i], B[i + len]); /* t1 = B0 + B1 */
787     }
788 
789   r = F_mul (t, t, t + len, len, DEFAULT, n, t + 2 * len);
790   /* t[0...2*len-1] = (A0+A1) * (B0+B1) = A0*B0 + A0*B1 + A1*B0 + A1*B1 */
791 
792   if (R != A)
793     {
794       r += F_mul (R, A, B, len, DEFAULT, n, t + 2 * len);
795       /* R[0...2*len-1] = A0 * B0 */
796       r += F_mul (R + 2 * len, A + len, B + len, len, DEFAULT, n, t + 2 * len);
797       /* R[2*len...4*len-1] = A1 * B1, may overwrite B */
798     }
799   else if (R + 2 * len != B)
800     {
801       r += F_mul (R + 2 * len, A + len, B + len, len, DEFAULT, n, t + 2 * len);
802       /* R[2*len...4*len-1] = A1 * B1 */
803       r += F_mul (R, A, B, len, DEFAULT, n, t + 2 * len);
804       /* R[0...2*len-1] = A0 * B0, overwrites A */
805     }
806   else /* R == A && R + 2*len == B */
807     {
808       for (i = 0; i < len; i++)
809         { /* mpz_swap instead? Perhaps undo later? Or interface for F_mul
810              to specify separate result arrays for high/low half? */
811           mpz_set (gt, A[len + i]); /* Swap A1 and B0 */
812           mpz_set (A[len + i], B[i]);
813           mpz_set (B[i], gt);
814         }
815       r += F_mul (R, R, R + len, len, DEFAULT, n, t + 2 * len);
816       /* R[0...2*len-1] = A0 * B0, overwrites A */
817       r += F_mul (R + 2 * len, R + 2 * len, R + 3 * len, len, DEFAULT, n, t + 2 * len);
818       /* R[2*len...4*len-1] = A1 * B1, overwrites B */
819     }
820 
821   /* R[0...2*len-2] == A0*B0, R[2*len-1] == 0 */
822   /* R[2*len...3*len-2] == A1*B1, R[4*len-1] == 0 */
823   /* t[0...2*len-2] == (A0+A1)*(B0+B1), t[2*len-1] == 0 */
824 
825   /* We're doing indices i and i+len in one loop on the assumption
826      that 6 residues will probably fit into cache. After all,
827      Karatsuba is only called for smallish F_m. This way, the final
828      add R[i+len] += t[i] can be done inside the same loop and we need
829      only one pass over main memory. */
830 
831   for (i = 0; i < len - 1; i++)
832     {
833       mpz_sub (t[i], t[i], R[i]); /* t = A0*B1 + A1*B0 + A1*B1 */
834       mpz_sub (t[i], t[i], R[i + 2 * len]); /* t = A0*B1 + A1*B0 */
835       mpz_sub (t[i + len], t[i + len], R[i + len]);
836       mpz_sub (t[i + len], t[i + len ], R[i + 3 * len]);
837 
838       mpz_add (R[i + len], R[i + len], t[i]);
839       mpz_add (R[i + 2 * len], R[i + 2 * len], t[i + len]);
840     }
841   mpz_sub (t[len - 1], t[len - 1], R[len - 1]);
842   mpz_sub (R[2 * len - 1], t[len - 1], R[3 * len - 1]);
843 
844   return r;
845 }
846 
847 /* Multiply two polynomials with coefficients modulo 2^(2^m)+1.
848    len is length (=degree+1) of polynomials and must be a power of 2.
849    n=2^m
850    Return value: number of multiplies performed, or UINT_MAX in case of error.
851 */
852 unsigned int
F_mul(mpz_t * R,mpz_t * A,mpz_t * B,unsigned int len,int parameter,unsigned int n,mpz_t * t)853 F_mul (mpz_t *R, mpz_t *A, mpz_t *B, unsigned int len, int parameter,
854        unsigned int n, mpz_t *t)
855 {
856   unsigned int i, r=0;
857   unsigned int transformlen = (parameter == NOPAD) ? len : 2 * len;
858 #ifdef CHECKSUM
859   mpz_t chksum1, chksum_1, chksum0, chksuminf;
860 #endif
861 
862   /* Handle trivial cases */
863   if (len == 0)
864     return 0;
865 
866   if (!gt_inited)
867     {
868       mpz_init2 (gt, 2 * n);
869       gt_inited = 1;
870     }
871 
872   if (len == 1)
873     {
874       if (parameter == MONIC)
875         {
876           /* (x + a0)(x + b0) = x^2 + (a0 + b0)x + a0*b0 */
877           mpz_add (gt, A[0], B[0]);
878           F_mod_gt (t[0], n);
879           F_mulmod (R[0], A[0], B[0], n); /* May overwrite A[0] */
880           mpz_set (R[1], t[0]); /* May overwrite B[0] */
881           /* We don't store the leading 1 monomial in the result poly */
882         }
883       else
884         {
885           F_mulmod (R[0], A[0], B[0], n); /* May overwrite A[0] */
886           mpz_set_ui (R[1], 0); /* May overwrite B[0] */
887         }
888 
889       return 1;
890     }
891 
892 #ifdef CHECKSUM
893   mpz_init2 (chksum1, n+64);
894   mpz_init2 (chksum_1, n+64);
895   mpz_init2 (chksum0, n+64);
896   mpz_init2 (chksuminf, n+64);
897 
898   mpz_set_ui (gt, 0);
899   for (i = 0; i < len; i++)
900     {
901       /* Compute A(1) and B(1) */
902       mpz_add (chksum1, chksum1, A[i]);
903       mpz_add (gt, gt, B[i]);
904 
905       /* Compute A(-1) and B(-1) */
906       if (i % 2 == 0)
907         {
908           mpz_add (chksum_1, chksum_1, A[i]);
909           mpz_add (chksum0, chksum0, B[i]); /* chksum0 used temporarily here */
910         }
911       else
912         {
913           mpz_sub (chksum_1, chksum_1, A[i]);
914           mpz_sub (chksum0, chksum0, B[i]);
915         }
916     }
917 
918   if (parameter == MONIC)
919     {
920       mpz_add_ui (chksum1, chksum1, 1);
921       mpz_add_ui (gt, gt, 1);
922       mpz_add_ui (chksum_1, chksum_1, 1);
923       mpz_add_ui (chksum0, chksum0, 1);
924     }
925 
926   mpz_mul (gt, gt, chksum1);
927   F_mod_gt (chksum1, n);
928 
929   mpz_mul (gt, chksum0, chksum_1);
930   F_mod_gt (chksum_1, n);
931 
932   /* Compute A(0) * B(0) */
933   mpz_mul (gt, A[0], B[0]);
934   F_mod_gt (chksum0, n);
935 
936   /* Compute A(inf) * B(inf) */
937   mpz_mul (gt, A[len - 1], B[len - 1]);
938   F_mod_gt (chksuminf, n);
939   if (parameter == MONIC)
940     {
941       mpz_add (chksuminf, chksuminf, A[len - 2]);
942       mpz_add (chksuminf, chksuminf, B[len - 2]);
943     }
944 
945   r += 4;
946 #endif /* CHECKSUM */
947 
948   /* Don't do FFT if len <= 4 (Karatsuba or Toom-Cook are faster) unless we
949      do a transform without zero padding, or if transformlen > 4*n
950      (no suitable primitive roots of 1) */
951   if ((len > 4 || parameter == NOPAD) && transformlen <= 4 * n)
952     {
953       unsigned int len2;
954 
955       /* len2 = log_2(transformlen). Assumes transformlen > 0 */
956       for (i = transformlen, len2 = 0; (i&1) == 0; i >>= 1, len2++);
957 
958       if (i != 1)
959         {
960           outputf (OUTPUT_ERROR, "F_mul: polynomial length must be power of 2, "
961                            "but is %d\n", len);
962           return UINT_MAX;
963         }
964 
965       /* Are we performing a squaring or multiplication? */
966       if (A != B)
967         {
968           /* So it's a multiplication */
969 
970           /* Put transform of B into t */
971           for (i = 0; i < len; i++)
972             mpz_set (t[i], B[i]);
973           if (parameter == MONIC)
974             mpz_set_ui (t[i++], 1);
975           for (; i < transformlen; i++)
976             mpz_set_ui (t[i], 0);
977 
978           F_fft_dif (t, transformlen, 0, n);
979         } else
980           t = R; /* Do squaring */
981 
982       /* Put A into R */
983       for (i = 0; i < len; i++)
984         mpz_set (R[i], A[i]);
985       if (parameter == MONIC)
986         mpz_set_ui (R[i++], 1); /* May overwrite B[0] */
987       for (; i < transformlen; i++)
988         mpz_set_ui (R[i], 0); /* May overwrite B[i - len] */
989 
990       F_fft_dif (R, transformlen, 0, n);
991 
992       for (i = 0; i < transformlen; i++)
993         {
994           F_mulmod (R[i], R[i], t[i], n);
995           /* Do the div-by-length. Transform length was transformlen,
996              len2 = log_2 (transformlen), so divide by
997              2^(len2) = sqrt(2)^(2*len2) */
998 
999           /* since transformlen = 2^len2 <= 4*n then for n >= 8 we have
1000              2*len2 <= 2*log2(4*n) < 2n */
1001           F_mul_sqrt2exp (R[i], R[i], 4 * n - 2 * len2, n);
1002         }
1003 
1004       r += transformlen;
1005 
1006       F_fft_dit (R, transformlen, 0, n);
1007 
1008       if (parameter == MONIC)
1009         mpz_sub_ui (R[0], R[0], 1);
1010 
1011     } else { /* Karatsuba or Toom-Cook split */
1012 
1013       if (parameter == NOPAD)
1014         {
1015           outputf (OUTPUT_ERROR, "F_mul: cyclic/short products not supported "
1016                    "by Karatsuba/Toom-Cook\n");
1017           return UINT_MAX;
1018         }
1019 
1020       if (len / n == 4 || len == 2)
1021         r += F_karatsuba (R, A, B, len, n, t);
1022       else
1023         r += F_toomcook4 (R, A, B, len, n, t);
1024 
1025       if (parameter == MONIC) /* Handle the leading monomial the hard way */
1026         {
1027           /* This only works if A, B and R do not overlap */
1028           if (A == R || B == R + len)
1029             {
1030               outputf (OUTPUT_ERROR, "F_mul: monic polynomials with Karatsuba/"
1031                        "Toom-Cook and overlapping input/output not supported\n");
1032               return UINT_MAX;
1033             }
1034           for (i = 0; i < len; i++)
1035             {
1036               mpz_add (R[i + len], R[i + len], A[i]);
1037               mpz_add (R[i + len], R[i + len], B[i]);
1038               F_mod_1 (R[i + len], n);
1039             }
1040         }
1041     }
1042 
1043 #ifdef DEBUG
1044   if (parameter != MONIC && parameter != NOPAD)
1045     {
1046       F_mod_1 (R[transformlen - 1], n);
1047       if (mpz_sgn (R[transformlen - 1]) != 0)
1048         outputf (OUTPUT_ALWAYS, "F_mul, len %d: R[%d] == %Zd != 0\n",
1049 		     len, transformlen - 1, R[transformlen - 1]);
1050     }
1051 #endif
1052 
1053 #ifdef CHECKSUM
1054   /* Compute R(1) = (A*B)(1) and subtract from chksum1 */
1055 
1056   for (i = 0; i < transformlen; i++)
1057     mpz_sub (chksum1, chksum1, R[i]);
1058 
1059   if (parameter == MONIC)
1060     mpz_sub_ui (chksum1, chksum1, 1);
1061 
1062   while (mpz_sizeinbase (chksum1, 2) > n)
1063     F_mod_1 (chksum1, n);
1064 
1065   if (mpz_sgn (chksum1) != 0)
1066     outputf (OUTPUT_ALWAYS, "F_mul, len %d: A(1)*B(1) != R(1), difference %Zd\n",
1067                  len, chksum1);
1068 
1069   /* Compute R(-1) = (A*B)(-1) and subtract from chksum_1 */
1070 
1071   for (i = 0; i < transformlen; i++)
1072     if (i % 2 == 0)
1073       mpz_sub (chksum_1, chksum_1, R[i]);
1074     else
1075       mpz_add (chksum_1, chksum_1, R[i]);
1076 
1077   if (parameter == MONIC)
1078     mpz_sub_ui (chksum_1, chksum_1, 1);
1079 
1080   while (mpz_sizeinbase (chksum_1, 2) > n)
1081     F_mod_1 (chksum_1, n);
1082 
1083   if (mpz_sgn (chksum_1) != 0)
1084     outputf (OUTPUT_ALWAYS, "F_mul, len %d: A(-1)*B(-1) != R(-1), difference %Zd\n",
1085 		 len, chksum_1);
1086 
1087   if (parameter != NOPAD)
1088     {
1089       mpz_sub (chksum0, chksum0, R[0]);
1090       while (mpz_sizeinbase (chksum0, 2) > n)
1091         F_mod_1 (chksum0, n);
1092 
1093       if (mpz_sgn (chksum0) != 0)
1094         outputf (OUTPUT_ALWAYS, "F_mul, len %d: A(0)*B(0) != R(0), difference %Zd\n",
1095                    len, chksum0);
1096 
1097       mpz_sub (chksuminf, chksuminf, R[transformlen - 2]);
1098       while (mpz_sizeinbase (chksuminf, 2) > n)
1099         F_mod_1 (chksuminf, n);
1100 
1101       if (mpz_sgn (chksuminf) != 0)
1102         outputf (OUTPUT_ALWAYS, "F_mul, len %d: A(inf)*B(inf) != R(inf), difference %Zd\n",
1103                     len, chksuminf);
1104     }
1105 
1106   mpz_clear (chksum1);
1107   mpz_clear (chksum_1);
1108   mpz_clear (chksum0);
1109   mpz_clear (chksuminf);
1110 #endif /* CHECKSUM */
1111 
1112   return r;
1113 }
1114 
1115 /* Transposed multiply of two polynomials with coefficients
1116    modulo 2^(2^m)+1.
1117    lenB is the length of polynomial B and must be a power of 2,
1118    lenA is the length of polynomial A and must be lenB / 2 or lenB / 2 + 1.
1119    n=2^m
1120    t must have space for 2*lenB coefficients
1121    Only the product coefficients [lenA - 1 ... lenA + lenB/2 - 2] will go into
1122    R[0 ... lenB / 2 - 1]
1123    Return value: number of multiplies performed, UINT_MAX in error case. */
1124 
1125 unsigned int
F_mul_trans(mpz_t * R,mpz_t * A,mpz_t * B,unsigned int lenA,unsigned int lenB,unsigned int n,mpz_t * t)1126 F_mul_trans (mpz_t *R, mpz_t *A, mpz_t *B, unsigned int lenA,
1127              unsigned int lenB, unsigned int n, mpz_t *t)
1128 {
1129   unsigned int i, r = 0, len2;
1130 
1131   /* Handle trivial cases */
1132   if (lenB < 2)
1133     return 0;
1134 
1135   ASSERT(lenA == lenB / 2 || lenA == lenB / 2 + 1);
1136 
1137   if (!gt_inited)
1138     {
1139       mpz_init2 (gt, 2 * n);
1140       gt_inited = 1;
1141     }
1142 
1143   if (lenB == 2)
1144     {
1145       F_mulmod (R[0], A[0], B[0], n);
1146       return 1;
1147     }
1148 
1149   if (lenB <= 4 * n)
1150     {
1151       /* len2 = log_2(lenB) */
1152       for (i = lenB, len2 = 0; i > 1 && (i&1) == 0; i >>= 1, len2++);
1153 
1154       if (i != 1)
1155         {
1156           outputf (OUTPUT_ERROR, "F_mul_trans: polynomial length must be power of 2, "
1157                            "but is %d\n", lenB);
1158           return UINT_MAX;
1159         }
1160 
1161       /* Put transform of B into t */
1162       for (i = 0; i < lenB; i++)
1163         mpz_set (t[i], B[i]);
1164 
1165       F_fft_dif (t, lenB, 0, n);
1166 
1167       /* Put transform of reversed A into t + lenB */
1168       for (i = 0; i < lenA; i++)
1169         mpz_set (t[i + lenB], A[lenA - 1 - i]);
1170       for (i = lenA; i < lenB; i++)
1171         mpz_set_ui (t[i + lenB], 0);
1172 
1173       F_fft_dif (t + lenB, lenB, 0, n);
1174 
1175       for (i = 0; i < lenB; i++)
1176         {
1177           F_mulmod (t[i], t[i], t[i + lenB], n);
1178           /* Do the div-by-length. Transform length was len, so divide by
1179              2^len2 = sqrt(2)^(2*len2) */
1180           /* since len2 = log2(lenB) and lenB <= 4*n, for n >= 8 we have
1181              2*len2 < 2*n */
1182           F_mul_sqrt2exp (t[i], t[i], 4 * n - 2 * len2, n);
1183         }
1184 
1185       r += lenB;
1186 
1187       F_fft_dit (t, lenB, 0, n);
1188 
1189       for (i = 0; i < lenB / 2; i++)
1190         mpz_set (R[i], t[i + lenA - 1]);
1191 
1192     } else { /* Only Karatsuba, no Toom-Cook here */
1193       unsigned int h = lenB / 4;
1194       const unsigned int lenA0 = h, lenA1 = lenA - h;
1195 
1196       outputf (OUTPUT_DEVVERBOSE, "schoen_strass.c: Transposed Karatsuba, "
1197 	       "lenA = %lu, lenB = %lu\n", lenA, lenB);
1198 
1199       /* A = a1 * x^h + a0
1200          B = b3 * x^3h + b2 * x^2h + b1 * x^h + b0
1201          mul^T(A, B) = mul^T(a0,b3) * x^4h +
1202                       (mul^T(a1,b3) + mul^T(a0,b2)) * x^3h +
1203                       (mul^T(a1,b2) + mul^T(a0,b1)) * x^2h +
1204                       (mul^T(a1,b1) + mul^T(a0,b0)) * x +
1205                        mul^T(a1,b0)
1206          We only want the x^h, x^2h and x^3h coefficients,
1207          mul^T(a1,b1) + mul^T(a0,b0)
1208          mul^T(a1,b2) + mul^T(a0,b1)
1209          mul^T(a1,b3) + mul^T(a0,b2)
1210 
1211          Specifically, we want
1212 	 R[i] = \sum_{j=0}^{lenA} A[j] * B[j+i], 0 <= i < 2h
1213       */
1214 
1215       /* T */
1216       for (i = 0; i < h; i++)
1217         mpz_add (t[i], A[i], A[i + h]);
1218       if (lenA1 == h + 1)
1219 	mpz_set (t[h], A[2*h]);
1220       r = F_mul_trans (t, t, B + h, lenA1, 2 * h, n, t + lenA1);
1221       /* Uses t[h ... 5h-1] as temp */
1222 
1223       /* U */
1224       for (i = 0; i < 2 * h; i++)
1225         mpz_sub (t[i + h], B[i], B[h + i]);
1226       r += F_mul_trans (t + h, A, t + h, lenA0, 2 * h, n, t + 3 * h);
1227       /* Uses t[3h ... 7h-1] as temp */
1228 
1229       for (i = 0; i < h; i++)
1230         mpz_add (R[i], t[i], t[i + h]); /* R[0 ... h-1] = t + r */
1231 
1232       /* V */
1233       for (i = 0; i < 2 * h; i++)
1234         mpz_sub (t[i + h], B[i + 2 * h], B[i + h]);
1235       r += F_mul_trans (t + h, A + h, t + h, lenA1, 2 * h, n, t + 3 * h);
1236       /* Uses t[3h ... 7h - 1] as temp */
1237 
1238       for (i = 0; i < h; i++)
1239         mpz_add (R[i + h], t[i], t[i + h]);
1240     }
1241 
1242   return r;
1243 }
1244 
F_clear()1245 void F_clear ()
1246 {
1247   if (gt_inited)
1248     mpz_clear (gt);
1249   gt_inited = 0;
1250 }
1251