1 #include "cado.h" // IWYU pragma: keep
2 #include <stdio.h>
/* This file defines some functions that work more or less the same
   with mod_ul.h and modredc_ul.h. I.e. mod_div3() (and the other
   mod_div*() functions), mod_gcd() and the enabled mod_jacobi() work
   unchanged with plain and Montgomery representation, so we can work on
   the stored residue directly, whatever its representation is. For
   mod_jacobi() this holds because the Montgomery factor 2^w, with w even,
   does not change the Jacobi symbol; the disabled alternative
   mod_jacobi() converts to plain "unsigned long" first. The others use
   only mod_*() inline functions.
   Speed-critical functions need to be rewritten in assembly for REDC,
   but this is a start.
*/
12
13 #include "mod_common.c"
14
15 int
mod_div3(residue_t r,const residue_t a,const modulus_t m)16 mod_div3 (residue_t r, const residue_t a, const modulus_t m)
17 {
18 const unsigned long a3 = a[0] % 3UL;
19 unsigned long ml, m3;
20 residue_t t;
21
22 ASSERT_EXPENSIVE (a[0] < mod_getmod_ul (m));
23
24 ml = mod_getmod_ul (m);
25 m3 = ml % 3UL;
26 if (m3 == 0)
27 return 0;
28
29 mod_init_noset0 (t, m);
30
31 mod_set (t, a, m);
32 if (a3 != 0UL)
33 {
34 if (a3 + m3 == 3UL) /* Hence a3 == 1, m3 == 2 or a3 == 2, m3 == 1 */
35 t[0] = t[0] + ml;
36 else /* a3 == 1, m3 == 1 or a3 == 2, m3 == 2 */
37 t[0] = t[0] + 2UL * ml;
38 }
39
40 /* Now t[0] == a+k*m (mod 2^w) so that a+k*m is divisible by 3.
41 (a+k*m)/3 < 2^w, so doing a division (mod 2^w) produces the
42 correct result. */
43
44 #if LONG_BIT == 32
45 t[0] *= 0xaaaaaaabUL; /* 1/3 (mod 2^32) */
46 #elif LONG_BIT == 64
47 t[0] *= 0xaaaaaaaaaaaaaaabUL; /* 1/3 (mod 2^64) */
48 #else
49 #error LONG_BIT is neither 32 nor 64
50 #endif
51
52 #ifdef WANT_ASSERT_EXPENSIVE
53 mod_sub (r, a, t, m);
54 mod_sub (r, r, t, m);
55 mod_sub (r, r, t, m);
56 ASSERT_EXPENSIVE (mod_is0 (r, m));
57 #endif
58
59 mod_set (r, t, m);
60 mod_clear (t, m);
61
62 return 1;
63 }
64
65
66 int
mod_div5(residue_t r,const residue_t a,const modulus_t m)67 mod_div5 (residue_t r, const residue_t a, const modulus_t m)
68 {
69 unsigned long ml, m5, k;
70 residue_t t;
71 const unsigned long a5 = a[0] % 5UL;
72 const unsigned long inv5[5] = {0,4,2,3,1}; /* inv5[i] = -1/i (mod 5) */
73
74 ASSERT_EXPENSIVE (a[0] < mod_getmod_ul (m));
75
76 ml = mod_getmod_ul (m);
77 m5 = ml % 5UL;
78 if (m5 == 0)
79 return 0;
80
81 mod_init_noset0 (t, m);
82 mod_set (t, a, m);
83 if (a5 != 0UL)
84 {
85 /* We want a+km == 0 (mod 5), so k = -a*m^{-1} (mod 5) */
86 k = (a5 * inv5[m5]) % 5UL;
87 ASSERT_EXPENSIVE ((k*m5 + a5) % 5UL == 0UL);
88 t[0] = a[0] + k * ml;
89 }
90
91 /* Now t[0] == a+k*m (mod 2^w) so that a+k*m is divisible by 5.
92 (a+k*m)/5 < 2^w, so doing a division (mod 2^w) produces the
93 correct result. */
94
95 #if LONG_BIT == 32
96 t[0] *= 0xcccccccdUL; /* 1/5 (mod 2^32) */
97 #elif LONG_BIT == 64
98 t[0] *= 0xcccccccccccccccdUL; /* 1/5 (mod 2^64) */
99 #else
100 #error LONG_BIT is neither 32 nor 64
101 #endif
102
103 #ifdef WANT_ASSERT_EXPENSIVE
104 ASSERT_EXPENSIVE (t[0] < mod_getmod_ul (m));
105 mod_sub (r, a, t, m);
106 mod_sub (r, r, t, m);
107 mod_sub (r, r, t, m);
108 mod_sub (r, r, t, m);
109 mod_sub (r, r, t, m);
110 ASSERT_EXPENSIVE (mod_is0 (r, m));
111 #endif
112
113 mod_set (r, t, m);
114 mod_clear (t, m);
115
116 return 1;
117 }
118
119
120 int
mod_div7(residue_t r,const residue_t a,const modulus_t m)121 mod_div7 (residue_t r, const residue_t a, const modulus_t m)
122 {
123 unsigned long ml, m7, k;
124 residue_t t;
125 const unsigned long a7 = a[0] % 7UL;
126 const unsigned long inv7[7] = {0,6,3,2,5,4,1}; /* inv7[i] = -1/i (mod 7) */
127
128 ASSERT_EXPENSIVE (a[0] < mod_getmod_ul (m));
129
130 ml = mod_getmod_ul (m);
131 m7 = ml % 7UL;
132 if (m7 == 0)
133 return 0;
134
135 mod_init_noset0 (t, m);
136 mod_set (t, a, m);
137 if (a7 != 0UL)
138 {
139 /* We want a+km == 0 (mod 7), so k = -a*m^{-1} (mod 7) */
140 k = (a7 * inv7[m7]) % 7UL;
141 ASSERT_EXPENSIVE ((k*m7 + a7) % 7UL == 0UL);
142 t[0] = a[0] + k * ml;
143 }
144
145 /* Now t[0] == a+k*m (mod 2^w) so that a+k*m is divisible by 7.
146 (a+k*m)/7 < 2^w, so doing a division (mod 2^w) produces the
147 correct result. */
148
149 #if LONG_BIT == 32
150 t[0] *= 0xb6db6db7UL; /* 1/7 (mod 2^32) */
151 #elif LONG_BIT == 64
152 t[0] *= 0x6db6db6db6db6db7UL; /* 1/7 (mod 2^64) */
153 #else
154 #error LONG_BIT is neither 32 nor 64
155 #endif
156
157 #ifdef WANT_ASSERT_EXPENSIVE
158 ASSERT_EXPENSIVE (t[0] < mod_getmod_ul (m));
159 mod_sub (r, a, t, m);
160 mod_sub (r, r, t, m);
161 mod_sub (r, r, t, m);
162 mod_sub (r, r, t, m);
163 mod_sub (r, r, t, m);
164 mod_sub (r, r, t, m);
165 mod_sub (r, r, t, m);
166 ASSERT_EXPENSIVE (mod_is0 (r, m));
167 #endif
168
169 mod_set (r, t, m);
170 mod_clear (t, m);
171
172 return 1;
173 }
174
175
176 int
mod_div11(residue_t r,const residue_t a,const modulus_t m)177 mod_div11 (residue_t r, const residue_t a, const modulus_t m)
178 {
179 unsigned long ml, m11, k;
180 residue_t t;
181 const unsigned long a11 = a[0] % 11UL;
182 /* inv11[i] = -1/i (mod 11) */
183 const unsigned long inv11[11] = {0, 10, 5, 7, 8, 2, 9, 3, 4, 6, 1};
184
185 ASSERT_EXPENSIVE (a[0] < mod_getmod_ul (m));
186
187 ml = mod_getmod_ul (m);
188 m11 = ml % 11UL;
189 if (m11 == 0)
190 return 0;
191
192 mod_init_noset0 (t, m);
193 mod_set (t, a, m);
194 if (a11 != 0UL)
195 {
196 /* We want a+km == 0 (mod 11), so k = -a*m^{-1} (mod 11) */
197 k = (a11 * inv11[m11]) % 11UL;
198 ASSERT_EXPENSIVE ((k*m11 + a11) % 11UL == 0UL);
199 t[0] = a[0] + k * ml;
200 }
201
202 /* Now t[0] == a+k*m (mod 2^w) so that a+k*m is divisible by 11.
203 (a+k*m)/11 < 2^w, so doing a division (mod 2^w) produces the
204 correct result. */
205
206 #if LONG_BIT == 32
207 t[0] *= 0xba2e8ba3UL; /* 1/11 (mod 2^32) */
208 #elif LONG_BIT == 64
209 t[0] *= 0x2e8ba2e8ba2e8ba3UL; /* 1/11 (mod 2^64) */
210 #else
211 #error LONG_BIT is neither 32 nor 64
212 #endif
213
214 mod_set (r, t, m);
215 mod_clear (t, m);
216
217 return 1;
218 }
219
220
221 int
mod_div13(residue_t r,const residue_t a,const modulus_t m)222 mod_div13 (residue_t r, const residue_t a, const modulus_t m)
223 {
224 unsigned long ml, m13, k;
225 residue_t t;
226 const unsigned long a13 = a[0] % 13UL;
227 /* inv13[i] = -1/i (mod 13) */
228 const unsigned long inv13[13] = {0, 12, 6, 4, 3, 5, 2, 11, 8, 10, 9, 7, 1};
229
230 ASSERT_EXPENSIVE (a[0] < mod_getmod_ul (m));
231
232 ml = mod_getmod_ul (m);
233 m13 = ml % 13UL;
234 if (m13 == 0)
235 return 0;
236
237 mod_init_noset0 (t, m);
238 mod_set (t, a, m);
239 if (a13 != 0UL)
240 {
241 /* We want a+km == 0 (mod 13), so k = -a*m^{-1} (mod 13) */
242 k = (a13 * inv13[m13]) % 13UL;
243 ASSERT_EXPENSIVE ((k*m13 + a13) % 13UL == 0UL);
244 t[0] = a[0] + k * ml;
245 }
246
247 /* Now t[0] == a+k*m (mod 2^w) so that a+k*m is divisible by 13.
248 (a+k*m)/13 < 2^w, so doing a division (mod 2^w) produces the
249 correct result. */
250
251 #if LONG_BIT == 32
252 t[0] *= 0xc4ec4ec5UL; /* 1/13 (mod 2^32) */
253 #elif LONG_BIT == 64
254 t[0] *= 0x4ec4ec4ec4ec4ec5UL; /* 1/13 (mod 2^64) */
255 #else
256 #error LONG_BIT is neither 32 nor 64
257 #endif
258
259 mod_set (r, t, m);
260 mod_clear (t, m);
261
262 return 1;
263 }
264
265
266 void
mod_gcd(modint_t g,const residue_t r,const modulus_t m)267 mod_gcd (modint_t g, const residue_t r, const modulus_t m)
268 {
269 unsigned long a, b, t;
270
271 a = r[0]; /* This works the same for "a" in plain or Montgomery
272 representation */
273 b = mod_getmod_ul (m);
274 /* ASSERT (a < b); Should we require this? */
275 ASSERT (b > 0UL);
276
277 if (a >= b)
278 a %= b;
279
280 while (a > 0UL)
281 {
282 /* Here 0 < a < b */
283 t = b % a;
284 b = a;
285 a = t;
286 }
287
288 g[0] = b;
289 }
290
291
292 /* Compute r = 2^e. Here, e is an unsigned long */
293 void
mod_2pow_ul(residue_t r,const unsigned long e,const modulus_t m)294 mod_2pow_ul (residue_t r, const unsigned long e, const modulus_t m)
295 {
296 unsigned long mask;
297 residue_t t, u;
298
299 if (e == 0UL)
300 {
301 mod_set1 (r, m);
302 return;
303 }
304
305 mask = (1UL << (LONG_BIT - 1)) >> ularith_clz (e);
306
307 mod_init_noset0 (t, m);
308 mod_init_noset0 (u, m);
309 mod_set1 (t, m);
310 mod_add (t, t, t, m);
311 mask >>= 1;
312
313 while (mask > 0UL)
314 {
315 mod_sqr (t, t, m);
316 mod_add (u, t, t, m);
317 if (e & mask)
318 mod_set (t, u, m);
319 mask >>= 1;
320 }
321 mod_set (r, t, m);
322 mod_clear (t, m);
323 mod_clear (u, m);
324 }
325
326
327 /* Compute r = 3^e. Here, e is an unsigned long */
328 static void
mod_3pow_ul(residue_t r,const unsigned long e,const modulus_t m)329 mod_3pow_ul (residue_t r, const unsigned long e, const modulus_t m)
330 {
331 unsigned long mask;
332 residue_t t, u;
333
334 if (e == 0UL)
335 {
336 mod_set1 (r, m);
337 return;
338 }
339
340 mask = (1UL << (LONG_BIT - 1)) >> ularith_clz (e);
341
342 mod_init_noset0 (t, m);
343 mod_init_noset0 (u, m);
344 mod_set1 (u, m);
345 mod_add (t, u, u, m);
346 mod_add (t, t, u, m);
347 mask >>= 1;
348
349 while (mask > 0UL)
350 {
351 mod_sqr (t, t, m);
352 mod_add (u, t, t, m);
353 mod_add (u, u, t, m);
354 if (e & mask)
355 mod_set (t, u, m);
356 mask >>= 1;
357 }
358 mod_set (r, t, m);
359 mod_clear (t, m);
360 mod_clear (u, m);
361 }
362
363
364 /* Computes 2^e (mod m), where e is a multiple precision integer.
365 Requires e != 0. The value of 2 in Montgomery representation
366 (i.e. 2*2^w (mod m) must be passed. */
367
368 void
mod_2pow_mp(residue_t r,const unsigned long * e,const int e_nrwords,const modulus_t m)369 mod_2pow_mp (residue_t r, const unsigned long *e, const int e_nrwords,
370 const modulus_t m)
371 {
372 residue_t t, u;
373 unsigned long mask;
374 int i = e_nrwords - 1;
375
376 ASSERT (e_nrwords != 0 && e[i] != 0);
377
378 mask = (1UL << (LONG_BIT - 1)) >> ularith_clz (e[i]);
379
380 mod_init_noset0 (t, m);
381 mod_init_noset0 (u, m);
382 mod_set1 (t, m);
383 mod_add (t, t, t , m);
384 mask >>= 1;
385
386 for ( ; i >= 0; i--)
387 {
388 while (mask > 0UL)
389 {
390 mod_sqr (t, t, m);
391 mod_add (u, t, t, m);
392 if (e[i] & mask)
393 mod_set (t, u, m);
394 mask >>= 1;
395 }
396 mask = ~0UL - (~0UL >> 1);
397 }
398
399 mod_set (r, t, m);
400 mod_clear (t, m);
401 mod_clear (u, m);
402 }
403
404
405 /* Returns 1 if m is a strong probable prime wrt base b, 0 otherwise.
406 We assume m is odd. */
407 int
mod_sprp(const residue_t b,const modulus_t m)408 mod_sprp (const residue_t b, const modulus_t m)
409 {
410 residue_t r1, minusone;
411 int i = 0, po2 = 1;
412 unsigned long mm1;
413
414 mm1 = mod_getmod_ul (m);
415
416 /* Set mm1 to the odd part of m-1 */
417 mm1 = (mm1 - 1) >> 1;
418 while (mm1 % 2UL == 0UL)
419 {
420 po2++;
421 mm1 >>= 1;
422 }
423 /* Hence, m-1 = mm1 * 2^po2 */
424
425 mod_init_noset0 (r1, m);
426 mod_init_noset0 (minusone, m);
427 mod_set1 (minusone, m);
428 mod_neg (minusone, minusone, m);
429
430 /* Exponentiate */
431 mod_pow_ul (r1, b, mm1, m);
432
433 /* Now r1 == b^mm1 (mod m) */
434 #if defined(PARI)
435 printf ("(Mod(%lu,%lu) ^ %lu) == %lu /* PARI */\n",
436 mod_get_ul (b, m), mod_getmod_ul (m), mm1, mod_get_ul (r1, m));
437 #endif
438
439 i = find_minus1 (r1, minusone, po2, m);
440
441 mod_clear (r1, m);
442 mod_clear (minusone, m);
443 return i;
444 }
445
446
447 /* Returns 1 if m is a strong probable prime wrt base 2, 0 otherwise.
448 Assumes m > 1 and is odd.
449 */
450 int
mod_sprp2(const modulus_t m)451 mod_sprp2 (const modulus_t m)
452 {
453 residue_t r, minusone;
454 int i = 0, po2 = 1;
455 unsigned long mm1;
456
457 mm1 = mod_getmod_ul (m);
458
459 /* If m == 1,7 (mod 8), then 2 is a quadratic residue, and we must find
460 -1 with one less squaring. This does not reduce the number of
461 pseudo-primes because strong pseudo-primes are also Euler pseudo-primes,
462 but makes identifying composites a little faster on average. */
463 if (mm1 % 8 == 1 || mm1 % 8 == 7)
464 po2--;
465
466 /* Set mm1 to the odd part of m-1 */
467 mm1 = (mm1 - 1) >> 1;
468 while (mm1 % 2UL == 0UL)
469 {
470 po2++;
471 mm1 >>= 1;
472 }
473 /* Hence, m-1 = mm1 * 2^po2 */
474
475 mod_init_noset0 (r, m);
476 mod_init_noset0 (minusone, m);
477 mod_set1 (minusone, m);
478 mod_neg (minusone, minusone, m);
479
480 /* Exponentiate */
481 mod_2pow_ul (r, mm1, m);
482
483 /* Now r == b^mm1 (mod m) */
484 #if defined(PARI)
485 printf ("(Mod(2,%lu) ^ %lu) == %lu /* PARI */\n",
486 mod_getmod_ul (m), mm1, mod_get_ul (r, m));
487 #endif
488
489 i = find_minus1 (r, minusone, po2, m);
490
491 mod_clear (r, m);
492 mod_clear (minusone, m);
493 return i;
494 }
495
496
/* Primality test for the modulus n = mod_getmod_ul (m).

   Returns 1 if n is judged prime, 0 otherwise. n == 1 and even n are
   handled directly (only 2 is prime among the even n). For odd n, a
   strong probable-prime (SPRP) test to base 2 is done first, followed by
   further bases chosen by n mod 3:
   - n == 1 (mod 3): bases 7 and 61, then 5 on 64-bit builds;
   - n == 0, 2 (mod 3): bases 3 and 5.
   Known strong pseudoprimes to the base sets used are rejected by
   explicit comparison.

   NOTE(review): per the comments below, the explicit exception lists
   cover only n < 10^13. On 64-bit builds, a return value of 1 for
   n >= 10^13 therefore means "strong probable prime to the bases tried",
   not a proof of primality. */
int
mod_isprime (const modulus_t m)
{
  residue_t b, minusone, r1;
  const unsigned long n = mod_getmod_ul (m);
  unsigned long mm1;
  int r = 0, po2;

  /* 1 is not prime */
  if (n == 1UL)
    return 0;

  /* Even n: only 2 is prime */
  if (n % 2UL == 0UL)
    {
      r = (n == 2UL);
#if defined(PARI)
      printf ("isprime(%lu) == %d /* PARI */\n", n, r);
#endif
      return r;
    }

  /* Set mm1 to the odd part of m-1 */
  mm1 = n - 1UL;
  po2 = ularith_ctz (mm1);
  mm1 >>= po2;

  mod_init_noset0 (b, m);
  mod_init_noset0 (minusone, m);
  mod_init_noset0 (r1, m);
  mod_set1 (minusone, m);
  mod_neg (minusone, minusone, m);

  /* Do base 2 SPRP test */
  mod_2pow_ul (r1, mm1, m); /* r = 2^mm1 mod m */
  /* If n is prime and 1 or 7 (mod 8), then 2 is a square (mod n)
     and one less squaring must suffice. This does not strengthen the
     test but saves one squaring for composite input */
  if (n % 8 == 7)
    {
      /* n == 7 (mod 8) gives po2 == 1, i.e. mm1 == (n-1)/2; since 2 is a
         square mod a prime n, 2^mm1 must then be exactly 1 */
      if (!mod_is1 (r1, m))
        goto end;
    }
  else if (!find_minus1 (r1, minusone, po2 - ((n % 8 == 1) ? 1 : 0), m))
    goto end; /* Not prime */

  /* The code relies on 2047 being the smallest base-2 strong pseudoprime */
  if (n < 2047UL)
    {
      r = 1;
      goto end;
    }

  /* Base 3 is poor at identifying composites == 1 (mod 3), but good at
     identifying composites == 2 (mod 3). Thus we use it only for 2 (mod 3) */
  if (n % 3UL == 1UL)
    {
      mod_set1 (b, m);
      mod_add (b, b, b, m);
      mod_add (b, b, b, m);
      mod_add (b, b, b, m);
      mod_add (b, b, minusone, m); /* b = 7 */
      mod_pow_ul (r1, b, mm1, m); /* r = 7^mm1 mod m */
      if (!find_minus1 (r1, minusone, po2, m))
        goto end; /* Not prime */

      /* Bases 2 and 7 are taken to suffice below 2269093, except for
         the explicitly excluded 314821 */
      if (n < 2269093UL)
        {
          r = (n != 314821UL);
          goto end;
        }

      /* b is still 7 here */
      mod_add (b, b, b, m); /* 14 */
      mod_sub (b, b, minusone, m); /* 15 */
      mod_add (b, b, b, m); /* 30 */
      mod_add (b, b, b, m); /* 60 */
      mod_sub (b, b, minusone, m); /* 61 */
      mod_pow_ul (r1, b, mm1, m); /* r = 61^mm1 mod m */
      if (!find_minus1 (r1, minusone, po2, m))
        goto end; /* Not prime */

#if (ULONG_MAX > 4294967295UL)
      /* Below 11207066041, bases 2, 7, 61 are taken to suffice except for
         the two explicitly excluded values (presumably strong
         pseudoprimes to these bases) */
      if (n != 4759123141UL && n != 8411807377UL && n < 11207066041UL)
        {
          r = 1;
          goto end;
        }

      mod_set1 (b, m);
      mod_add (b, b, b, m);
      mod_add (b, b, b, m);
      mod_sub (b, b, minusone, m); /* b = 5 */
      mod_pow_ul (r1, b, mm1, m); /* r = 5^mm1 mod m */
      if (!find_minus1 (r1, minusone, po2, m))
        goto end; /* Not prime */

      /* These are the base 5,7,61 SPSP < 10^13 and n == 1 (mod 3) */
      r = (n != 30926647201UL && n != 45821738881UL &&
           n != 74359744201UL && n != 90528271681UL &&
           n != 110330267041UL && n != 373303331521UL &&
           n != 440478111067UL && n != 1436309367751UL &&
           n != 1437328758421UL && n != 1858903385041UL &&
           n != 4897239482521UL && n != 5026103290981UL &&
           n != 5219055617887UL && n != 5660137043641UL &&
           n != 6385803726241UL);
#else
      /* 32-bit unsigned long: bases 2, 7, 61 are taken to suffice */
      r = 1;
#endif
    }
  else
    {
      /* Case n % 3 == 0, 2 */
      /* (For n == 0 (mod 3), n > 3, 3^mm1 is divisible by 3 modulo n and
         can never be +-1, so base 3 rejects such n.) */

      mod_3pow_ul (r1, mm1, m); /* r = 3^mm1 mod m */
      if (!find_minus1 (r1, minusone, po2, m))
        goto end; /* Not prime */

      /* Bases 2 and 3 are taken to suffice below 102690677, except for
         the three explicitly excluded values */
      if (n < 102690677UL && n != 5173601UL && n != 16070429UL &&
          n != 54029741UL)
        {
          r = 1;
          goto end;
        }

      mod_set1 (b, m);
      mod_add (b, b, b, m);
      mod_add (b, b, b, m);
      mod_sub (b, b, minusone, m); /* b = 5 */
      mod_pow_ul (r1, b, mm1, m); /* r = 5^mm1 mod m */
      if (!find_minus1 (r1, minusone, po2, m))
        goto end; /* Not prime */

#if (ULONG_MAX > 4294967295UL)
      /* These are the base 3,5 SPSP < 10^13 with n == 2 (mod 3) */
      r = (n != 244970876021UL && n != 405439595861UL &&
           n != 1566655993781UL && n != 3857382025841UL &&
           n != 4074652846961UL && n != 5783688565841UL);
#else
      /* 32-bit unsigned long: bases 2, 3, 5 are taken to suffice */
      r = 1;
#endif
    }

 end:
#if defined(PARI)
  printf ("isprime(%lu) == %d /* PARI */\n", n, r);
#endif
  mod_clear (b, m);
  mod_clear (minusone, m);
  mod_clear (r1, m);
  return r;
}
646
647 #if 1
648
649 int
mod_jacobi(const residue_t a_par,const modulus_t m_par)650 mod_jacobi (const residue_t a_par, const modulus_t m_par)
651 {
652 unsigned long x, m;
653 unsigned int s, j;
654
655 /* Get residue in Montgomery form directly without converting */
656 x = a_par[0];
657 m = mod_getmod_ul (m_par);
658 ASSERT (x < m);
659 ASSERT(m % 2 == 1);
660
661 j = ularith_ctz(x);
662 x = x >> j;
663 /* If we divide by an odd power of 2, and 2 is a QNR, flip sign */
664 /* 2 is a QNR (mod m) iff m = 3,5 (mod 8)
665 m = 1 = 001b: 1
666 m = 3 = 011b: -1
667 m = 5 = 101b: -1
668 m = 7 = 111b: 1
669 Hence we can store in s the exponent of -1, i.e., s=0 for jacobi()=1
670 and s=1 for jacobi()=-1, and update s ^= (m>>1) & (m>>2) & 1.
671 We can do the &1 at the very end.
672 In fact, we store the exponent of -1 in the second bit of s.
673 The s ^= ((j<<1) & (m ^ (m>>1))) still needs 2 shift but one of them can
674 be done with LEA, and f = s ^ (x&m) needs no shift */
675
676 s = ((j<<1) & (m ^ (m>>1)));
677
678 while (x > 1) {
679 /* Here, x < m, x and m are odd */
680
681 /* Implicitly swap by reversing roles of x and m in next loop */
682 /* Flip sign if both are 3 (mod 4) */
683 s = s ^ (x&m);
684
685 /* Make m smaller by subtracting and shifting */
686 do {
687 m -= x; /* Difference is even */
688 if (m == 0)
689 break;
690 /* Make odd again */
691 j = ularith_ctz(m);
692 s ^= ((j<<1) & (x ^ (x>>1)));
693 m >>= j;
694 } while (m >= x);
695
696 if (m <= 1) {
697 x = m;
698 break;
699 }
700
701 /* Flip sign if both are 3 (mod 4) */
702 /* Implicitly swap again */
703 s = s ^ (x&m);
704
705 /* Make x<m by subtracting and shifting */
706 do {
707 x -= m; /* Difference is even */
708 if (x == 0)
709 break;
710 /* Make odd again */
711 j = ularith_ctz(x);
712 s ^= ((j<<1) & (m ^ (m>>1)));
713 x >>= j;
714 } while (x >= m);
715 }
716
717 if (x == 0)
718 return 0;
719 return ((s & 2) == 0) ? 1 : -1;
720 }
721
722 #else
723
724 /*
725 #!/usr/bin/env python3
726 # Python program to create mult.h
727
728 def rate(k,b):
729 # The 0.25 magic constant here tries to estimate the ratio m/x,
730 # to minimize (x+c*m)/2^b
731 r = (abs(k)*0.25 + 1.)/2**b
732 # print ("rate(" + str(k) + ", " + str(b) + ") = " + str(r))
733 return(r)
734
735 def bestb(k):
736 best_b = 0
737 best_r = 1
738 best_c = 0
739 for b in range(1, 8):
740 c = k % (2**b)
741 r = rate(c, b)
742 if r < best_r:
743 best_r=r
744 best_b = b
745 best_c = c
746 c = - (2**b - c)
747 r = rate(c, b)
748 if r < best_r:
749 best_r=r
750 best_b = b
751 best_c = c
752 # print ("bestb(" + str(k) + ") = " + str(best_b))
753 return([k, best_b, best_c, 1/best_r])
754
755
756 r = [str(-bestb(2*i+1)[2]) for i in range(0, 128) ]
print("static const signed char mult[128] = {" + ", ".join(r) + "};")
758 */
759
760 #include "mult.h"
761 #include "macros.h"
/* invmod8[i] = 1/i (mod 2^8) for odd i, 0 for even i (not invertible).
   Read-only lookup table. */
static const unsigned char invmod8[256] = {
  0, 1, 0, 171, 0, 205, 0, 183, 0, 57, 0, 163, 0, 197, 0, 239,
  0, 241, 0, 27, 0, 61, 0, 167, 0, 41, 0, 19, 0, 53, 0, 223,
  0, 225, 0, 139, 0, 173, 0, 151, 0, 25, 0, 131, 0, 165, 0, 207,
  0, 209, 0, 251, 0, 29, 0, 135, 0, 9, 0, 243, 0, 21, 0, 191,
  0, 193, 0, 107, 0, 141, 0, 119, 0, 249, 0, 99, 0, 133, 0, 175,
  0, 177, 0, 219, 0, 253, 0, 103, 0, 233, 0, 211, 0, 245, 0, 159,
  0, 161, 0, 75, 0, 109, 0, 87, 0, 217, 0, 67, 0, 101, 0, 143,
  0, 145, 0, 187, 0, 221, 0, 71, 0, 201, 0, 179, 0, 213, 0, 127,
  0, 129, 0, 43, 0, 77, 0, 55, 0, 185, 0, 35, 0, 69, 0, 111,
  0, 113, 0, 155, 0, 189, 0, 39, 0, 169, 0, 147, 0, 181, 0, 95,
  0, 97, 0, 11, 0, 45, 0, 23, 0, 153, 0, 3, 0, 37, 0, 79,
  0, 81, 0, 123, 0, 157, 0, 7, 0, 137, 0, 115, 0, 149, 0, 63,
  0, 65, 0, 235, 0, 13, 0, 247, 0, 121, 0, 227, 0, 5, 0, 47,
  0, 49, 0, 91, 0, 125, 0, 231, 0, 105, 0, 83, 0, 117, 0, 31,
  0, 33, 0, 203, 0, 237, 0, 215, 0, 89, 0, 195, 0, 229, 0, 15,
  0, 17, 0, 59, 0, 93, 0, 199, 0, 73, 0, 51, 0, 85, 0, 255
};
765
/* Convert the sign accumulator s into a Jacobi symbol value. Bit 1 of s
   holds the exponent of -1; all other bits are ignored. */
static inline int
s_val (unsigned int s)
{
  if (s & 2U)
    return -1;
  return 1;
}
770
771 static int
mod_jacobi1(const residue_t a_par,const modulus_t m_par)772 mod_jacobi1 (const residue_t a_par, const modulus_t m_par)
773 {
774 unsigned long x, m;
775 unsigned int s, j;
776
777 x = mod_get_ul (a_par, m_par);
778 m = mod_getmod_ul (m_par);
779 ASSERT (x < m);
780 ASSERT(m % 2 == 1);
781
782 j = ularith_ctz(x);
783 x = x >> j;
784
785 s = ((j<<1) & (m ^ (m>>1)));
786
787 while (x > 1) {
788 unsigned long t;
789 unsigned char inv;
790
791 // printf ("kronecker(%lu, %lu) == %d * kronecker(%lu, %lu)\n",
792 // mod_get_ul(a_par, m_par), mod_getmod_ul (m_par), s_val(s), x, m);
793 /* Here x < m. Swap to make x > m */
794 t = m;
795 m = x;
796 x = t;
797 s = s ^ (x&m);
798
799 /* Now x > m */
800 inv = invmod8[(unsigned char)m];
801 do {
802 /* Do a REDC-like step. We want a multiplier k such that the low
803 8 bits of x+k*m are all zero.
804 That is, we want k = -x/m (mod 2^8). */
805 unsigned char k;
806 unsigned long t1;
807 long int c, t2;
808 // const unsigned long old_x = x;
809
810 k = inv * (unsigned char)x;
811 // ASSERT_ALWAYS((k & 1) == 1);
812 c = mult[k / 2];
813
814 /* Compute x+cm */
815 long tmp = c >> 63;
816 ularith_mul_ul_ul_2ul (&t1, (unsigned long *)&t2, (c ^ tmp) - tmp, m);
817 t2 ^= tmp;
818 t1 ^= tmp;
819 ularith_add_ul_2ul (&t1, (unsigned long *)&t2, x-tmp);
820 tmp = ((long) t2) >> 63;
821
822 t2 ^= tmp;
823 t1 ^= tmp;
824 s ^= m & tmp;
825 ularith_add_ul_2ul (&t1, (unsigned long *)&t2, -tmp);
826 // ASSERT_ALWAYS(t2 >= 0);
827
828 if (t1 == 0) {
829 if (t2 == 0) {
830 x = 0;
831 break;
832 }
833 t1 = t2;
834 /* Divided by 2^64 which is square, so no adjustment to s */
835 t2 = 0;
836 }
837
838 j = ularith_ctz(t1);
839 ularith_shrd (&t1, t2, j);
840 // ASSERT_ALWAYS((t2 >> j) == 0);
841 x = t1;
842 s ^= ((j<<1) & (m ^ (m>>1)));
843 // ASSERT_ALWAYS(x < old_x);
844 // printf ("%f\n", (double)old_x / (double)x);
845 } while (x >= m);
846 }
847
848 if (x == 0)
849 return 0;
850 return s_val(s);
851 }
852
853 int
mod_jacobi(const residue_t a_par,const modulus_t m_par)854 mod_jacobi (const residue_t a_par, const modulus_t m_par)
855 {
856 unsigned long x, m;
857 unsigned int s, j;
858
859 x = mod_get_ul (a_par, m_par);
860 m = mod_getmod_ul (m_par);
861 ASSERT (x < m);
862 ASSERT(m % 2 == 1);
863
864 if ((LONG_MAX - x) / 50 < m)
865 return mod_jacobi1 (a_par, m_par);
866
867 j = ularith_ctz(x);
868 x = x >> j;
869
870 s = ((j<<1) & (m ^ (m>>1)));
871
872 while (x > 1) {
873 unsigned long t;
874 unsigned char inv;
875
876 // printf ("kronecker(%lu, %lu) == %d * kronecker(%lu, %lu)\n",
877 // mod_get_ul(a_par, m_par), mod_getmod_ul (m_par), s_val(s), x, m);
878 /* Here x < m. Swap to make x > m */
879 t = m;
880 m = x;
881 x = t;
882 s = s ^ (x&m);
883
884 /* Now x > m */
885 inv = invmod8[(unsigned char)m];
886 do {
887 /* Do a REDC-like step. We want a multiplier k such that the low
888 8 bits of x+k*m are all zero.
889 That is, we want k = -x/m (mod 2^8). */
890 unsigned char k;
891 long int c;
892 // const unsigned long old_x = x;
893
894 k = inv * x;
895 // ASSERT_ALWAYS((k & 1) == 1);
896 c = mult[k / 2];
897
898 c = x + c*m;
899 x = c;
900 c >>= 63;
901 x = (x ^ c) - c;
902
903 if (x == 0) {
904 break;
905 }
906 s ^= m & c;
907
908 j = ularith_ctz(x);
909 x >>= j;
910 s ^= ((j<<1) & (m ^ (m>>1)));
911 // printf ("%f\n", (double)old_x / (double)x);
912 } while (x >= m);
913 }
914
915 if (x == 0)
916 return 0;
917 return s_val(s);
918 }
919
920 #endif
921