1 #include "cado.h" // IWYU pragma: keep
2 #include <stdio.h>
/* This file defines some functions that work more or less the same
   with mod_ul.h and modredc_ul.h. I.e. mod_div3() to mod_div13(),
   mod_gcd() and mod_jacobi() work unchanged with plain and Montgomery
   representation (so we can work on the stored residue directly, whatever
   its representation is; for mod_jacobi() this holds because the Montgomery
   factor 2^w is a square and thus does not change the Jacobi symbol).
   The disabled alternative mod_jacobi() in the #else branch at the end of
   this file converts to plain "unsigned long" first. The others use
   only mod_*() inline functions.
   Speed-critical functions need to be rewritten in assembly for REDC,
   but this is a start.
*/
12 
13 #include "mod_common.c"
14 
15 int
mod_div3(residue_t r,const residue_t a,const modulus_t m)16 mod_div3 (residue_t r, const residue_t a, const modulus_t m)
17 {
18   const unsigned long a3 = a[0] % 3UL;
19   unsigned long ml, m3;
20   residue_t t;
21 
22   ASSERT_EXPENSIVE (a[0] < mod_getmod_ul (m));
23 
24   ml = mod_getmod_ul (m);
25   m3 = ml % 3UL;
26   if (m3 == 0)
27     return 0;
28 
29   mod_init_noset0 (t, m);
30 
31   mod_set (t, a, m);
32   if (a3 != 0UL)
33     {
34       if (a3 + m3 == 3UL) /* Hence a3 == 1, m3 == 2 or a3 == 2, m3 == 1 */
35 	t[0] = t[0] + ml;
36       else /* a3 == 1, m3 == 1 or a3 == 2, m3 == 2 */
37 	t[0] = t[0] + 2UL * ml;
38     }
39 
40   /* Now t[0] == a+k*m (mod 2^w) so that a+k*m is divisible by 3.
41      (a+k*m)/3 < 2^w, so doing a division (mod 2^w) produces the
42      correct result. */
43 
44 #if LONG_BIT == 32
45     t[0] *= 0xaaaaaaabUL; /* 1/3 (mod 2^32) */
46 #elif LONG_BIT == 64
47     t[0] *= 0xaaaaaaaaaaaaaaabUL; /* 1/3 (mod 2^64) */
48 #else
49 #error LONG_BIT is neither 32 nor 64
50 #endif
51 
52 #ifdef WANT_ASSERT_EXPENSIVE
53   mod_sub (r, a, t, m);
54   mod_sub (r, r, t, m);
55   mod_sub (r, r, t, m);
56   ASSERT_EXPENSIVE (mod_is0 (r, m));
57 #endif
58 
59   mod_set (r, t, m);
60   mod_clear (t, m);
61 
62   return 1;
63 }
64 
65 
66 int
mod_div5(residue_t r,const residue_t a,const modulus_t m)67 mod_div5 (residue_t r, const residue_t a, const modulus_t m)
68 {
69   unsigned long ml, m5, k;
70   residue_t t;
71   const unsigned long a5 = a[0] % 5UL;
72   const unsigned long inv5[5] = {0,4,2,3,1}; /* inv5[i] = -1/i (mod 5) */
73 
74   ASSERT_EXPENSIVE (a[0] < mod_getmod_ul (m));
75 
76   ml = mod_getmod_ul (m);
77   m5 = ml % 5UL;
78   if (m5 == 0)
79     return 0;
80 
81   mod_init_noset0 (t, m);
82   mod_set (t, a, m);
83   if (a5 != 0UL)
84     {
85       /* We want a+km == 0 (mod 5), so k = -a*m^{-1} (mod 5) */
86       k = (a5 * inv5[m5]) % 5UL;
87       ASSERT_EXPENSIVE ((k*m5 + a5) % 5UL == 0UL);
88       t[0] = a[0] + k * ml;
89     }
90 
91   /* Now t[0] == a+k*m (mod 2^w) so that a+k*m is divisible by 5.
92      (a+k*m)/5 < 2^w, so doing a division (mod 2^w) produces the
93      correct result. */
94 
95 #if LONG_BIT == 32
96     t[0] *= 0xcccccccdUL; /* 1/5 (mod 2^32) */
97 #elif LONG_BIT == 64
98     t[0] *= 0xcccccccccccccccdUL; /* 1/5 (mod 2^64) */
99 #else
100 #error LONG_BIT is neither 32 nor 64
101 #endif
102 
103 #ifdef WANT_ASSERT_EXPENSIVE
104   ASSERT_EXPENSIVE (t[0] < mod_getmod_ul (m));
105   mod_sub (r, a, t, m);
106   mod_sub (r, r, t, m);
107   mod_sub (r, r, t, m);
108   mod_sub (r, r, t, m);
109   mod_sub (r, r, t, m);
110   ASSERT_EXPENSIVE (mod_is0 (r, m));
111 #endif
112 
113   mod_set (r, t, m);
114   mod_clear (t, m);
115 
116   return 1;
117 }
118 
119 
120 int
mod_div7(residue_t r,const residue_t a,const modulus_t m)121 mod_div7 (residue_t r, const residue_t a, const modulus_t m)
122 {
123   unsigned long ml, m7, k;
124   residue_t t;
125   const unsigned long a7 = a[0] % 7UL;
126   const unsigned long inv7[7] = {0,6,3,2,5,4,1}; /* inv7[i] = -1/i (mod 7) */
127 
128   ASSERT_EXPENSIVE (a[0] < mod_getmod_ul (m));
129 
130   ml = mod_getmod_ul (m);
131   m7 = ml % 7UL;
132   if (m7 == 0)
133     return 0;
134 
135   mod_init_noset0 (t, m);
136   mod_set (t, a, m);
137   if (a7 != 0UL)
138     {
139       /* We want a+km == 0 (mod 7), so k = -a*m^{-1} (mod 7) */
140       k = (a7 * inv7[m7]) % 7UL;
141       ASSERT_EXPENSIVE ((k*m7 + a7) % 7UL == 0UL);
142       t[0] = a[0] + k * ml;
143     }
144 
145   /* Now t[0] == a+k*m (mod 2^w) so that a+k*m is divisible by 7.
146      (a+k*m)/7 < 2^w, so doing a division (mod 2^w) produces the
147      correct result. */
148 
149 #if LONG_BIT == 32
150     t[0] *= 0xb6db6db7UL; /* 1/7 (mod 2^32) */
151 #elif LONG_BIT == 64
152     t[0] *= 0x6db6db6db6db6db7UL; /* 1/7 (mod 2^64) */
153 #else
154 #error LONG_BIT is neither 32 nor 64
155 #endif
156 
157 #ifdef WANT_ASSERT_EXPENSIVE
158   ASSERT_EXPENSIVE (t[0] < mod_getmod_ul (m));
159   mod_sub (r, a, t, m);
160   mod_sub (r, r, t, m);
161   mod_sub (r, r, t, m);
162   mod_sub (r, r, t, m);
163   mod_sub (r, r, t, m);
164   mod_sub (r, r, t, m);
165   mod_sub (r, r, t, m);
166   ASSERT_EXPENSIVE (mod_is0 (r, m));
167 #endif
168 
169   mod_set (r, t, m);
170   mod_clear (t, m);
171 
172   return 1;
173 }
174 
175 
176 int
mod_div11(residue_t r,const residue_t a,const modulus_t m)177 mod_div11 (residue_t r, const residue_t a, const modulus_t m)
178 {
179   unsigned long ml, m11, k;
180   residue_t t;
181   const unsigned long a11 = a[0] % 11UL;
182   /* inv11[i] = -1/i (mod 11) */
183   const unsigned long inv11[11] = {0, 10, 5, 7, 8, 2, 9, 3, 4, 6, 1};
184 
185   ASSERT_EXPENSIVE (a[0] < mod_getmod_ul (m));
186 
187   ml = mod_getmod_ul (m);
188   m11 = ml % 11UL;
189   if (m11 == 0)
190     return 0;
191 
192   mod_init_noset0 (t, m);
193   mod_set (t, a, m);
194   if (a11 != 0UL)
195     {
196       /* We want a+km == 0 (mod 11), so k = -a*m^{-1} (mod 11) */
197       k = (a11 * inv11[m11]) % 11UL;
198       ASSERT_EXPENSIVE ((k*m11 + a11) % 11UL == 0UL);
199       t[0] = a[0] + k * ml;
200     }
201 
202   /* Now t[0] == a+k*m (mod 2^w) so that a+k*m is divisible by 11.
203      (a+k*m)/11 < 2^w, so doing a division (mod 2^w) produces the
204      correct result. */
205 
206 #if LONG_BIT == 32
207     t[0] *= 0xba2e8ba3UL; /* 1/11 (mod 2^32) */
208 #elif LONG_BIT == 64
209     t[0] *= 0x2e8ba2e8ba2e8ba3UL; /* 1/11 (mod 2^64) */
210 #else
211 #error LONG_BIT is neither 32 nor 64
212 #endif
213 
214   mod_set (r, t, m);
215   mod_clear (t, m);
216 
217   return 1;
218 }
219 
220 
221 int
mod_div13(residue_t r,const residue_t a,const modulus_t m)222 mod_div13 (residue_t r, const residue_t a, const modulus_t m)
223 {
224   unsigned long ml, m13, k;
225   residue_t t;
226   const unsigned long a13 = a[0] % 13UL;
227   /* inv13[i] = -1/i (mod 13) */
228   const unsigned long inv13[13] = {0, 12, 6, 4, 3, 5, 2, 11, 8, 10, 9, 7, 1};
229 
230   ASSERT_EXPENSIVE (a[0] < mod_getmod_ul (m));
231 
232   ml = mod_getmod_ul (m);
233   m13 = ml % 13UL;
234   if (m13 == 0)
235     return 0;
236 
237   mod_init_noset0 (t, m);
238   mod_set (t, a, m);
239   if (a13 != 0UL)
240     {
241       /* We want a+km == 0 (mod 13), so k = -a*m^{-1} (mod 13) */
242       k = (a13 * inv13[m13]) % 13UL;
243       ASSERT_EXPENSIVE ((k*m13 + a13) % 13UL == 0UL);
244       t[0] = a[0] + k * ml;
245     }
246 
247   /* Now t[0] == a+k*m (mod 2^w) so that a+k*m is divisible by 13.
248      (a+k*m)/13 < 2^w, so doing a division (mod 2^w) produces the
249      correct result. */
250 
251 #if LONG_BIT == 32
252     t[0] *= 0xc4ec4ec5UL; /* 1/13 (mod 2^32) */
253 #elif LONG_BIT == 64
254     t[0] *= 0x4ec4ec4ec4ec4ec5UL; /* 1/13 (mod 2^64) */
255 #else
256 #error LONG_BIT is neither 32 nor 64
257 #endif
258 
259   mod_set (r, t, m);
260   mod_clear (t, m);
261 
262   return 1;
263 }
264 
265 
266 void
mod_gcd(modint_t g,const residue_t r,const modulus_t m)267 mod_gcd (modint_t g, const residue_t r, const modulus_t m)
268 {
269   unsigned long a, b, t;
270 
271   a = r[0]; /* This works the same for "a" in plain or Montgomery
272                representation */
273   b = mod_getmod_ul (m);
274   /* ASSERT (a < b); Should we require this? */
275   ASSERT (b > 0UL);
276 
277   if (a >= b)
278     a %= b;
279 
280   while (a > 0UL)
281     {
282       /* Here 0 < a < b */
283       t = b % a;
284       b = a;
285       a = t;
286     }
287 
288   g[0] = b;
289 }
290 
291 
292 /* Compute r = 2^e. Here, e is an unsigned long */
293 void
mod_2pow_ul(residue_t r,const unsigned long e,const modulus_t m)294 mod_2pow_ul (residue_t r, const unsigned long e, const modulus_t m)
295 {
296   unsigned long mask;
297   residue_t t, u;
298 
299   if (e == 0UL)
300     {
301       mod_set1 (r, m);
302       return;
303     }
304 
305   mask = (1UL << (LONG_BIT - 1)) >> ularith_clz (e);
306 
307   mod_init_noset0 (t, m);
308   mod_init_noset0 (u, m);
309   mod_set1 (t, m);
310   mod_add (t, t, t, m);
311   mask >>= 1;
312 
313   while (mask > 0UL)
314     {
315       mod_sqr (t, t, m);
316       mod_add (u, t, t, m);
317       if (e & mask)
318         mod_set (t, u, m);
319       mask >>= 1;
320     }
321   mod_set (r, t, m);
322   mod_clear (t, m);
323   mod_clear (u, m);
324 }
325 
326 
327 /* Compute r = 3^e. Here, e is an unsigned long */
328 static void
mod_3pow_ul(residue_t r,const unsigned long e,const modulus_t m)329 mod_3pow_ul (residue_t r, const unsigned long e, const modulus_t m)
330 {
331   unsigned long mask;
332   residue_t t, u;
333 
334   if (e == 0UL)
335     {
336       mod_set1 (r, m);
337       return;
338     }
339 
340   mask = (1UL << (LONG_BIT - 1)) >> ularith_clz (e);
341 
342   mod_init_noset0 (t, m);
343   mod_init_noset0 (u, m);
344   mod_set1 (u, m);
345   mod_add (t, u, u, m);
346   mod_add (t, t, u, m);
347   mask >>= 1;
348 
349   while (mask > 0UL)
350     {
351       mod_sqr (t, t, m);
352       mod_add (u, t, t, m);
353       mod_add (u, u, t, m);
354       if (e & mask)
355         mod_set (t, u, m);
356       mask >>= 1;
357     }
358   mod_set (r, t, m);
359   mod_clear (t, m);
360   mod_clear (u, m);
361 }
362 
363 
364 /* Computes 2^e (mod m), where e is a multiple precision integer.
365    Requires e != 0. The value of 2 in Montgomery representation
366    (i.e. 2*2^w (mod m) must be passed. */
367 
368 void
mod_2pow_mp(residue_t r,const unsigned long * e,const int e_nrwords,const modulus_t m)369 mod_2pow_mp (residue_t r, const unsigned long *e, const int e_nrwords,
370              const modulus_t m)
371 {
372   residue_t t, u;
373   unsigned long mask;
374   int i = e_nrwords - 1;
375 
376   ASSERT (e_nrwords != 0 && e[i] != 0);
377 
378   mask = (1UL << (LONG_BIT - 1)) >> ularith_clz (e[i]);
379 
380   mod_init_noset0 (t, m);
381   mod_init_noset0 (u, m);
382   mod_set1 (t, m);
383   mod_add (t, t, t , m);
384   mask >>= 1;
385 
386   for ( ; i >= 0; i--)
387     {
388       while (mask > 0UL)
389         {
390           mod_sqr (t, t, m);
391           mod_add (u, t, t, m);
392           if (e[i] & mask)
393             mod_set (t, u, m);
394           mask >>= 1;
395         }
396       mask = ~0UL - (~0UL >> 1);
397     }
398 
399   mod_set (r, t, m);
400   mod_clear (t, m);
401   mod_clear (u, m);
402 }
403 
404 
405 /* Returns 1 if m is a strong probable prime wrt base b, 0 otherwise.
406    We assume m is odd. */
407 int
mod_sprp(const residue_t b,const modulus_t m)408 mod_sprp (const residue_t b, const modulus_t m)
409 {
410   residue_t r1, minusone;
411   int i = 0, po2 = 1;
412   unsigned long mm1;
413 
414   mm1 = mod_getmod_ul (m);
415 
416   /* Set mm1 to the odd part of m-1 */
417   mm1 = (mm1 - 1) >> 1;
418   while (mm1 % 2UL == 0UL)
419     {
420       po2++;
421       mm1 >>= 1;
422     }
423   /* Hence, m-1 = mm1 * 2^po2 */
424 
425   mod_init_noset0 (r1, m);
426   mod_init_noset0 (minusone, m);
427   mod_set1 (minusone, m);
428   mod_neg (minusone, minusone, m);
429 
430   /* Exponentiate */
431   mod_pow_ul (r1, b, mm1, m);
432 
433   /* Now r1 == b^mm1 (mod m) */
434 #if defined(PARI)
435   printf ("(Mod(%lu,%lu) ^ %lu) == %lu /* PARI */\n",
436 	  mod_get_ul (b, m), mod_getmod_ul (m), mm1, mod_get_ul (r1, m));
437 #endif
438 
439   i = find_minus1 (r1, minusone, po2, m);
440 
441   mod_clear (r1, m);
442   mod_clear (minusone, m);
443   return i;
444 }
445 
446 
447 /* Returns 1 if m is a strong probable prime wrt base 2, 0 otherwise.
448    Assumes m > 1 and is odd.
449  */
450 int
mod_sprp2(const modulus_t m)451 mod_sprp2 (const modulus_t m)
452 {
453   residue_t r, minusone;
454   int i = 0, po2 = 1;
455   unsigned long mm1;
456 
457   mm1 = mod_getmod_ul (m);
458 
459   /* If m == 1,7 (mod 8), then 2 is a quadratic residue, and we must find
460      -1 with one less squaring. This does not reduce the number of
461      pseudo-primes because strong pseudo-primes are also Euler pseudo-primes,
462      but makes identifying composites a little faster on average. */
463   if (mm1 % 8 == 1 || mm1 % 8 == 7)
464     po2--;
465 
466   /* Set mm1 to the odd part of m-1 */
467   mm1 = (mm1 - 1) >> 1;
468   while (mm1 % 2UL == 0UL)
469     {
470       po2++;
471       mm1 >>= 1;
472     }
473   /* Hence, m-1 = mm1 * 2^po2 */
474 
475   mod_init_noset0 (r, m);
476   mod_init_noset0 (minusone, m);
477   mod_set1 (minusone, m);
478   mod_neg (minusone, minusone, m);
479 
480   /* Exponentiate */
481   mod_2pow_ul (r, mm1, m);
482 
483   /* Now r == b^mm1 (mod m) */
484 #if defined(PARI)
485   printf ("(Mod(2,%lu) ^ %lu) == %lu /* PARI */\n",
486 	  mod_getmod_ul (m), mm1, mod_get_ul (r, m));
487 #endif
488 
489   i = find_minus1 (r, minusone, po2, m);
490 
491   mod_clear (r, m);
492   mod_clear (minusone, m);
493   return i;
494 }
495 
496 
497 int
mod_isprime(const modulus_t m)498 mod_isprime (const modulus_t m)
499 {
500   residue_t b, minusone, r1;
501   const unsigned long n = mod_getmod_ul (m);
502   unsigned long mm1;
503   int r = 0, po2;
504 
505   if (n == 1UL)
506     return 0;
507 
508   if (n % 2UL == 0UL)
509     {
510       r = (n == 2UL);
511 #if defined(PARI)
512       printf ("isprime(%lu) == %d /* PARI */\n", n, r);
513 #endif
514       return r;
515     }
516 
517   /* Set mm1 to the odd part of m-1 */
518   mm1 = n - 1UL;
519   po2 = ularith_ctz (mm1);
520   mm1 >>= po2;
521 
522   mod_init_noset0 (b, m);
523   mod_init_noset0 (minusone, m);
524   mod_init_noset0 (r1, m);
525   mod_set1 (minusone, m);
526   mod_neg (minusone, minusone, m);
527 
528   /* Do base 2 SPRP test */
529   mod_2pow_ul (r1, mm1, m);   /* r = 2^mm1 mod m */
530   /* If n is prime and 1 or 7 (mod 8), then 2 is a square (mod n)
531      and one less squaring must suffice. This does not strengthen the
532      test but saves one squaring for composite input */
533   if (n % 8 == 7)
534     {
535       if (!mod_is1 (r1, m))
536         goto end;
537     }
538   else if (!find_minus1 (r1, minusone, po2 - ((n % 8 == 1) ? 1 : 0), m))
539     goto end; /* Not prime */
540 
541   if (n < 2047UL)
542     {
543       r = 1;
544       goto end;
545     }
546 
547   /* Base 3 is poor at identifying composites == 1 (mod 3), but good at
548      identifying composites == 2 (mod 3). Thus we use it only for 2 (mod 3) */
549   if (n % 3UL == 1UL)
550     {
551       mod_set1 (b, m);
552       mod_add (b, b, b, m);
553       mod_add (b, b, b, m);
554       mod_add (b, b, b, m);
555       mod_add (b, b, minusone, m);  /* b = 7 */
556       mod_pow_ul (r1, b, mm1, m);   /* r = 7^mm1 mod m */
557       if (!find_minus1 (r1, minusone, po2, m))
558 	goto end; /* Not prime */
559 
560       if (n < 2269093UL)
561         {
562 	  r = (n != 314821UL);
563 	  goto end;
564         }
565 
566       /* b is still 7 here */
567       mod_add (b, b, b, m); /* 14 */
568       mod_sub (b, b, minusone, m); /* 15 */
569       mod_add (b, b, b, m); /* 30 */
570       mod_add (b, b, b, m); /* 60 */
571       mod_sub (b, b, minusone, m); /* 61 */
572       mod_pow_ul (r1, b, mm1, m);   /* r = 61^mm1 mod m */
573       if (!find_minus1 (r1, minusone, po2, m))
574 	goto end; /* Not prime */
575 
576 #if (ULONG_MAX > 4294967295UL)
577       if (n != 4759123141UL && n != 8411807377UL && n < 11207066041UL)
578         {
579 	  r = 1;
580 	  goto end;
581         }
582 
583       mod_set1 (b, m);
584       mod_add (b, b, b, m);
585       mod_add (b, b, b, m);
586       mod_sub (b, b, minusone, m);    /* b = 5 */
587       mod_pow_ul (r1, b, mm1, m);   /* r = 5^mm1 mod m */
588       if (!find_minus1 (r1, minusone, po2, m))
589 	goto end; /* Not prime */
590 
591           /* These are the base 5,7,61 SPSP < 10^13 and n == 1 (mod 3) */
592 	  r = (n != 30926647201UL && n != 45821738881UL &&
593 	       n != 74359744201UL && n != 90528271681UL &&
594 	       n != 110330267041UL && n != 373303331521UL &&
595 	       n != 440478111067UL && n != 1436309367751UL &&
596 	       n != 1437328758421UL && n != 1858903385041UL &&
597 	       n != 4897239482521UL && n != 5026103290981UL &&
598 	       n != 5219055617887UL && n != 5660137043641UL &&
599 	       n != 6385803726241UL);
600 #else
601 	      r = 1;
602 #endif
603     }
604   else
605     {
606       /* Case n % 3 == 0, 2 */
607 
608       mod_3pow_ul (r1, mm1, m);   /* r = 3^mm1 mod m */
609       if (!find_minus1 (r1, minusone, po2, m))
610 	goto end; /* Not prime */
611 
612       if (n < 102690677UL && n != 5173601UL && n != 16070429UL &&
613           n != 54029741UL)
614         {
615 	  r = 1;
616 	  goto end;
617 	}
618 
619       mod_set1 (b, m);
620       mod_add (b, b, b, m);
621       mod_add (b, b, b, m);
622       mod_sub (b, b, minusone, m);    /* b = 5 */
623       mod_pow_ul (r1, b, mm1, m);   /* r = 5^mm1 mod m */
624       if (!find_minus1 (r1, minusone, po2, m))
625 	goto end; /* Not prime */
626 
627 #if (ULONG_MAX > 4294967295UL)
628       /* These are the base 3,5 SPSP < 10^13 with n == 2 (mod 3) */
629       r = (n != 244970876021UL && n != 405439595861UL &&
630 	   n != 1566655993781UL && n != 3857382025841UL &&
631 	   n != 4074652846961UL && n != 5783688565841UL);
632 #else
633       r = 1;
634 #endif
635     }
636 
637  end:
638 #if defined(PARI)
639   printf ("isprime(%lu) == %d /* PARI */\n", n, r);
640 #endif
641   mod_clear (b, m);
642   mod_clear (minusone, m);
643   mod_clear (r1, m);
644   return r;
645 }
646 
647 #if 1
648 
649 int
mod_jacobi(const residue_t a_par,const modulus_t m_par)650 mod_jacobi (const residue_t a_par, const modulus_t m_par)
651 {
652   unsigned long x, m;
653   unsigned int s, j;
654 
655   /* Get residue in Montgomery form directly without converting */
656   x = a_par[0];
657   m = mod_getmod_ul (m_par);
658   ASSERT (x < m);
659   ASSERT(m % 2 == 1);
660 
661   j = ularith_ctz(x);
662   x = x >> j;
663   /* If we divide by an odd power of 2, and 2 is a QNR, flip sign */
664   /* 2 is a QNR (mod m) iff m = 3,5 (mod 8)
665      m = 1 = 001b:   1
666      m = 3 = 011b:  -1
667      m = 5 = 101b:  -1
668      m = 7 = 111b:   1
669      Hence we can store in s the exponent of -1, i.e., s=0 for jacobi()=1
670      and s=1 for jacobi()=-1, and update s ^= (m>>1) & (m>>2) & 1.
671      We can do the &1 at the very end.
672      In fact, we store the exponent of -1 in the second bit of s.
673      The s ^= ((j<<1) & (m ^ (m>>1))) still needs 2 shift but one of them can
674      be done with LEA, and f = s ^ (x&m) needs no shift */
675 
676   s = ((j<<1) & (m ^ (m>>1)));
677 
678   while (x > 1) {
679     /* Here, x < m, x and m are odd */
680 
681     /* Implicitly swap by reversing roles of x and m in next loop */
682     /* Flip sign if both are 3 (mod 4) */
683     s = s ^ (x&m);
684 
685     /* Make m smaller by subtracting and shifting */
686     do {
687       m -= x; /* Difference is even */
688       if (m == 0)
689         break;
690       /* Make odd again */
691       j = ularith_ctz(m);
692       s ^= ((j<<1) & (x ^ (x>>1)));
693       m >>= j;
694     } while (m >= x);
695 
696     if (m <= 1) {
697       x = m;
698       break;
699     }
700 
701     /* Flip sign if both are 3 (mod 4) */
702     /* Implicitly swap again */
703     s = s ^ (x&m);
704 
705     /* Make x<m  by subtracting and shifting */
706     do {
707       x -= m; /* Difference is even */
708       if (x == 0)
709         break;
710       /* Make odd again */
711       j = ularith_ctz(x);
712       s ^= ((j<<1) & (m ^ (m>>1)));
713       x >>= j;
714     } while (x >= m);
715   }
716 
717   if (x == 0)
718     return 0;
719   return ((s & 2) == 0) ? 1 : -1;
720 }
721 
722 #else
723 
724 /*
725 #!/usr/bin/env python3
726 # Python program to create mult.h
727 
728 def rate(k,b):
729   # The 0.25 magic constant here tries to estimate the ratio m/x,
730   # to minimize (x+c*m)/2^b
731   r = (abs(k)*0.25 + 1.)/2**b
732   # print ("rate(" + str(k) + ", " + str(b) + ") = " + str(r))
733   return(r)
734 
735 def bestb(k):
736   best_b = 0
737   best_r = 1
738   best_c = 0
739   for b in range(1, 8):
740     c = k % (2**b)
741     r = rate(c, b)
742     if r < best_r:
743       best_r=r
744       best_b = b
745       best_c = c
746     c = - (2**b - c)
747     r = rate(c, b)
748     if r < best_r:
749       best_r=r
750       best_b = b
751       best_c = c
752   # print ("bestb(" + str(k) + ") = " + str(best_b))
753   return([k, best_b, best_c, 1/best_r])
754 
755 
756 r = [str(-bestb(2*i+1)[2]) for i in range(0, 128) ]
757 print("static char mult[128] = {" + ", ".join(r) + "};")
758 */
759 
760 #include "mult.h"
761 #include "macros.h"
/* invmod8[i] == 1/i (mod 2^8) for odd i (e.g. 3 * 171 == 513 == 2*256 + 1);
   entries for even i are 0, as even numbers have no inverse mod 2^8.
   Used to compute the REDC-like multiplier in mod_jacobi1() and in the
   alternative mod_jacobi() below; only compiled when the "#if 1" above is
   switched off. */
static unsigned char invmod8[256] = {
0, 1, 0, 171, 0, 205, 0, 183, 0, 57, 0, 163, 0, 197, 0, 239, 0, 241, 0, 27, 0, 61, 0, 167, 0, 41, 0, 19, 0, 53, 0, 223, 0, 225, 0, 139, 0, 173, 0, 151, 0, 25, 0, 131, 0, 165, 0, 207, 0, 209, 0, 251, 0, 29, 0, 135, 0, 9, 0, 243, 0, 21, 0, 191, 0, 193, 0, 107, 0, 141, 0, 119, 0, 249, 0, 99, 0, 133, 0, 175, 0, 177, 0, 219, 0, 253, 0, 103, 0, 233, 0, 211, 0, 245, 0, 159, 0, 161, 0, 75, 0, 109, 0, 87, 0, 217, 0, 67, 0, 101, 0, 143, 0, 145, 0, 187, 0, 221, 0, 71, 0, 201, 0, 179, 0, 213, 0, 127, 0, 129, 0, 43, 0, 77, 0, 55, 0, 185, 0, 35, 0, 69, 0, 111, 0, 113, 0, 155, 0, 189, 0, 39, 0, 169, 0, 147, 0, 181, 0, 95, 0, 97, 0, 11, 0, 45, 0, 23, 0, 153, 0, 3, 0, 37, 0, 79, 0, 81, 0, 123, 0, 157, 0, 7, 0, 137, 0, 115, 0, 149, 0, 63, 0, 65, 0, 235, 0, 13, 0, 247, 0, 121, 0, 227, 0, 5, 0, 47, 0, 49, 0, 91, 0, 125, 0, 231, 0, 105, 0, 83, 0, 117, 0, 31, 0, 33, 0, 203, 0, 237, 0, 215, 0, 89, 0, 195, 0, 229, 0, 15, 0, 17, 0, 59, 0, 93, 0, 199, 0, 73, 0, 51, 0, 85, 0, 255
};
765 
/* Decode the sign accumulator of the Jacobi routines. Bit 1 of s holds the
   exponent of -1; all other bits are ignored. Returns 1 if that bit is clear
   and -1 if it is set. */
static inline int
s_val(unsigned int s)
{
  const unsigned int minus_one_exponent = (s >> 1) & 1U;
  return minus_one_exponent ? -1 : 1;
}
770 
/* Jacobi symbol (a / m) for the odd modulus m_par, alternative algorithm
   (only compiled when the "#if 1" above is switched off). Each inner step
   replaces x by |x + c*m| / 2^j, with a small signed multiplier c taken from
   mult[] (mult.h, generated by the Python script above). The intermediate
   x + c*m is computed in two words, so there is no overflow constraint
   between x and m; mod_jacobi() below falls back to this version when its
   one-word arithmetic could overflow.
   Returns 0 when gcd(a, m) > 1, else +-1.
   NOTE(review): uses c >> 63 as a sign mask, so this presumably assumes a
   64-bit long with arithmetic right shift of negative values. As in the
   active version, x == 0 reaches ularith_ctz(0). */
static int
mod_jacobi1 (const residue_t a_par, const modulus_t m_par)
{
  unsigned long x, m;
  unsigned int s, j;

  /* Unlike the active mod_jacobi(), convert to plain representation */
  x = mod_get_ul (a_par, m_par);
  m = mod_getmod_ul (m_par);
  ASSERT (x < m);
  ASSERT(m % 2 == 1);

  j = ularith_ctz(x);
  x = x >> j;

  /* Bit 1 of s is the exponent of -1; flip it for (2/m)^j, where
     (2/m) == -1 iff m == 3,5 (mod 8) */
  s = ((j<<1) & (m ^ (m>>1)));

  while (x > 1) {
    unsigned long t;
    unsigned char inv;

    // printf ("kronecker(%lu, %lu) == %d * kronecker(%lu, %lu)\n",
    //        mod_get_ul(a_par, m_par), mod_getmod_ul (m_par), s_val(s), x, m);
    /* Here x < m. Swap to make x > m */
    t = m;
    m = x;
    x = t;
    /* Quadratic reciprocity: flip sign if both are 3 (mod 4) */
    s = s ^ (x&m);

    /* Now x > m */
    inv = invmod8[(unsigned char)m];
    do {
      /* Do a REDC-like step. We want a multiplier c such that the low
         bits of x+c*m are zero: k = inv * x == x/m (mod 2^8), and
         mult[k/2] (k is odd) is a small c with c == -k (mod 2^b) for the
         b <= 7 chosen by the generator script, so 2^b divides x+c*m. */
      unsigned char k;
      unsigned long t1;
      long int c, t2;
      // const unsigned long old_x = x;

      k = inv * (unsigned char)x;
      // ASSERT_ALWAYS((k & 1) == 1);
      c = mult[k / 2];

      /* Compute x+cm */
      /* tmp is 0 if c >= 0, all ones if c < 0; (c ^ tmp) - tmp == |c| */
      long tmp = c >> 63;
      /* (t2:t1) = |c| * m as a two-word product, presumably t2 being the
         high word */
      ularith_mul_ul_ul_2ul (&t1, (unsigned long *)&t2, (c ^ tmp) - tmp, m);
      /* For c < 0: one's complement here, plus the +1 folded into x - tmp
         below, negates the product; the sum is then x + c*m */
      t2 ^= tmp;
      t1 ^= tmp;
      ularith_add_ul_2ul (&t1, (unsigned long *)&t2, x-tmp);
      tmp = ((long) t2) >> 63;

      /* Take the absolute value; (-1/m) == -1 iff m == 3 (mod 4), which is
         bit 1 of m */
      t2 ^= tmp;
      t1 ^= tmp;
      s ^= m & tmp;
      ularith_add_ul_2ul (&t1, (unsigned long *)&t2, -tmp);
      // ASSERT_ALWAYS(t2 >= 0);

      if (t1 == 0) {
        if (t2 == 0) {
          /* x+c*m == 0: x was a multiple of m, so gcd > 1 */
          x = 0;
          break;
        }
        t1 = t2;
        /* Divided by 2^64 which is square, so no adjustment to s */
        t2 = 0;
      }

      /* Remove the factors of 2, adjusting s by (2/m)^j */
      j = ularith_ctz(t1);
      ularith_shrd (&t1, t2, j);
      // ASSERT_ALWAYS((t2 >> j) == 0);
      x = t1;
      s ^= ((j<<1) & (m ^ (m>>1)));
      // ASSERT_ALWAYS(x < old_x);
      // printf ("%f\n", (double)old_x / (double)x);
    } while (x >= m);
  }

  if (x == 0)
    return 0;
  return s_val(s);
}
852 
/* Alternative Jacobi symbol (a / m) implementation (only compiled when the
   "#if 1" above is switched off). Same reduction as mod_jacobi1(), but
   x + c*m is computed in a single signed long. Inputs for which that might
   overflow are delegated to mod_jacobi1(); the factor 50 presumably bounds
   |mult[]| (mult.h is not shown here, TODO confirm).
   Returns 0 when gcd(a, m) > 1, else +-1.
   NOTE(review): relies on c >> 63 (64-bit long, arithmetic shift) and on
   the implementation-defined conversion of x + c*m to long. As in the
   active version, x == 0 reaches ularith_ctz(0). */
int
mod_jacobi (const residue_t a_par, const modulus_t m_par)
{
  unsigned long x, m;
  unsigned int s, j;

  /* Convert to plain representation */
  x = mod_get_ul (a_par, m_par);
  m = mod_getmod_ul (m_par);
  ASSERT (x < m);
  ASSERT(m % 2 == 1);

  /* Fall back to the two-word version if x + 50*m may exceed LONG_MAX */
  if ((LONG_MAX - x) / 50 < m)
    return mod_jacobi1 (a_par, m_par);

  j = ularith_ctz(x);
  x = x >> j;

  /* Bit 1 of s is the exponent of -1; flip it for (2/m)^j */
  s = ((j<<1) & (m ^ (m>>1)));

  while (x > 1) {
    unsigned long t;
    unsigned char inv;

    // printf ("kronecker(%lu, %lu) == %d * kronecker(%lu, %lu)\n",
    //        mod_get_ul(a_par, m_par), mod_getmod_ul (m_par), s_val(s), x, m);
    /* Here x < m. Swap to make x > m */
    t = m;
    m = x;
    x = t;
    /* Quadratic reciprocity: flip sign if both are 3 (mod 4) */
    s = s ^ (x&m);

    /* Now x > m */
    inv = invmod8[(unsigned char)m];
    do {
      /* Do a REDC-like step. We want a multiplier c such that the low
         bits of x+c*m are zero: k = inv * x == x/m (mod 2^8), and
         mult[k/2] (k is odd) is a small c with c == -k (mod 2^b) for the
         b <= 7 chosen by the generator script. */
      unsigned char k;
      long int c;
      // const unsigned long old_x = x;

      k = inv * x;
      // ASSERT_ALWAYS((k & 1) == 1);
      c = mult[k / 2];

      /* x = |x + c*m|; c becomes the sign mask (0 or all ones) */
      c = x + c*m;
      x = c;
      c >>= 63;
      x = (x ^ c) - c;

      if (x == 0) {
        /* x was a multiple of m: gcd > 1 */
        break;
      }
      /* (-1/m) == -1 iff m == 3 (mod 4) */
      s ^= m & c;

      /* Remove the factors of 2, adjusting s by (2/m)^j */
      j = ularith_ctz(x);
      x >>= j;
      s ^= ((j<<1) & (m ^ (m>>1)));
      // printf ("%f\n", (double)old_x / (double)x);
    } while (x >= m);
  }

  if (x == 0)
    return 0;
  return s_val(s);
}
919 
920 #endif
921