1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <math.h>
5 #include <float.h>
6 
7 /* The AKS primality algorithm for native integers.
8  *
9  * There are three versions here:
10  *   V6         The v6 algorithm from the latest AKS paper.
11  *              https://www.cse.iitk.ac.in/users/manindra/algebra/primality_v6.pdf
12  *   BORNEMANN  Improvements from Bernstein, Voloch, and a clever r/s
13  *              selection from Folkmar Bornemann.  Similar to Bornemann's
14  *              2003 Pari/GP implementation:
15  *              https://homepage.univie.ac.at/Dietrich.Burde/pari/aks.gp
16  *   BERN41     My implementation of theorem 4.1 from Bernstein's 2003 paper.
17  *              https://cr.yp.to/papers/aks.pdf
18  *
19  * Each one is orders of magnitude faster than the previous, and by default
20  * we use Bernstein 4.1 as it is by far the fastest.
21  *
22  * Note that AKS is very, very slow compared to other methods.  It is, however,
23  * polynomial in log(N), and log-log performance graphs show nice straight
 * lines for all three implementations.  However APR-CL and ECPP both start out
25  * much faster and the slope will be less for any sizes of N that we're
26  * interested in.
27  *
28  * For native 64-bit integers this is purely a coding exercise, as BPSW is
29  * a million times faster and gives proven results.
30  *
31  *
32  * When n < 2^(wordbits/2)-1, we can do a straightforward intermediate:
33  *      r = (r + a * b) % n
34  * If n is larger, then these are replaced with:
35  *      r = addmod( r, mulmod(a, b, n), n)
36  * which is a lot more work, but keeps us correct.
37  *
38  * Software that does polynomial convolutions followed by a modulo can be
39  * very fast, but will fail when n >= (2^wordbits)/r.
40  *
41  * This is all much easier in GMP.
42  *
43  * Copyright 2012-2016, Dana Jacobsen.
44  */
45 
46 #define SQRTN_SHORTCUT 1
47 
48 #define IMPL_V6        0    /* From the primality_v6 paper */
49 #define IMPL_BORNEMANN 0    /* From Bornemann's 2002 implementation */
50 #define IMPL_BERN41    1    /* From Bernstein's early 2003 paper */
51 
52 #include "ptypes.h"
53 #include "aks.h"
54 #define FUNC_isqrt 1
55 #define FUNC_gcd_ui 1
56 #include "util.h"
57 #include "cache.h"
58 #include "mulmod.h"
59 #include "factor.h"
60 
61 #if IMPL_BORNEMANN || IMPL_BERN41
62 /* We could use lgamma, but it isn't in MSVC and not in pre-C99.  The only
63  * sure way to find if it is available is test compilation (ala autoconf).
64  * Instead, we'll just use our own implementation.
65  * See http://mrob.com/pub/ries/lanczos-gamma.html for alternates. */
log_gamma(double x)66 static double log_gamma(double x)
67 {
68   static const double log_sqrt_two_pi =  0.91893853320467274178;
69   static const double lanczos_coef[8+1] =
70     { 0.99999999999980993, 676.5203681218851, -1259.1392167224028,
71       771.32342877765313, -176.61502916214059, 12.507343278686905,
72       -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7 };
73   double base = x + 7.5, sum = 0;
74   int i;
75   for (i = 8; i >= 1; i--)
76     sum += lanczos_coef[i] / (x + (double)i);
77   sum += lanczos_coef[0];
78   sum = log_sqrt_two_pi + log(sum/x) + ( (x+0.5)*log(base) - base );
79   return sum;
80 }
81 
82 /* Note: For lgammal we need logl in the above.
83  * Max error drops from 2.688466e-09 to 1.818989e-12. */
84 #undef lgamma
85 #define lgamma(x) log_gamma(x)
86 #endif
87 
88 #if IMPL_BERN41
log_binomial(UV n,UV k)89 static double log_binomial(UV n, UV k)
90 {
91   return log_gamma(n+1) - log_gamma(k+1) - log_gamma(n-k+1);
92 }
log_bern41_binomial(UV r,UV d,UV i,UV j,UV s)93 static double log_bern41_binomial(UV r, UV d, UV i, UV j, UV s)
94 {
95   return   log_binomial( 2*s,   i)
96          + log_binomial( d,     i)
97          + log_binomial( 2*s-i, j)
98          + log_binomial( r-2-d, j);
99 }
bern41_acceptable(UV n,UV r,UV s)100 static int bern41_acceptable(UV n, UV r, UV s)
101 {
102   double scmp = ceil(sqrt( (r-1)/3.0 )) * log(n);
103   UV d = (UV) (0.5 * (r-1));
104   UV i = (UV) (0.475 * (r-1));
105   UV j = i;
106   if (d > r-2)     d = r-2;
107   if (i > d)       i = d;
108   if (j > (r-2-d)) j = r-2-d;
109   return (log_bern41_binomial(r,d,i,j,s) >= scmp);
110 }
111 #endif
112 
113 #if 0
114 /* Naive znorder.  Works well if limit is small.  Note arguments.  */
115 static UV order(UV r, UV n, UV limit) {
116   UV j;
117   UV t = 1;
118   for (j = 1; j <= limit; j++) {
119     t = mulmod(t, n, r);
120     if (t == 1)
121       break;
122   }
123   return j;
124 }
125 static void poly_print(UV* poly, UV r)
126 {
127   int i;
128   for (i = r-1; i >= 1; i--) {
129     if (poly[i] != 0)
130       printf("%lux^%d + ", poly[i], i);
131   }
132   if (poly[0] != 0) printf("%lu", poly[0]);
133   printf("\n");
134 }
135 #endif
136 
poly_mod_mul(UV * px,UV * py,UV * res,UV r,UV mod)137 static void poly_mod_mul(UV* px, UV* py, UV* res, UV r, UV mod)
138 {
139   UV degpx, degpy;
140   UV i, j, pxi, pyj, rindex;
141 
142   /* Determine max degree of px and py */
143   for (degpx = r-1; degpx > 0 && !px[degpx]; degpx--) ; /* */
144   for (degpy = r-1; degpy > 0 && !py[degpy]; degpy--) ; /* */
145   /* We can sum at least j values at once */
146   j = (mod >= HALF_WORD) ? 0 : (UV_MAX / ((mod-1)*(mod-1)));
147 
148   if (j >= degpx || j >= degpy) {
149     /* res will be written completely, so no need to set */
150     for (rindex = 0; rindex < r; rindex++) {
151       UV sum = 0;
152       j = rindex;
153       for (i = 0; i <= degpx; i++) {
154         if (j <= degpy)
155           sum += px[i] * py[j];
156         j = (j == 0) ? r-1 : j-1;
157       }
158       res[rindex] = sum % mod;
159     }
160   } else {
161     memset(res, 0, r * sizeof(UV));  /* Zero result accumulator */
162     for (i = 0; i <= degpx; i++) {
163       pxi = px[i];
164       if (pxi == 0)  continue;
165       if (mod < HALF_WORD) {
166         for (j = 0; j <= degpy; j++) {
167           pyj = py[j];
168           rindex = i+j;   if (rindex >= r)  rindex -= r;
169           res[rindex] = (res[rindex] + (pxi*pyj) ) % mod;
170         }
171       } else {
172         for (j = 0; j <= degpy; j++) {
173           pyj = py[j];
174           rindex = i+j;   if (rindex >= r)  rindex -= r;
175           res[rindex] = muladdmod(pxi, pyj, res[rindex], mod);
176         }
177       }
178     }
179   }
180   memcpy(px, res, r * sizeof(UV)); /* put result in px */
181 }
poly_mod_sqr(UV * px,UV * res,UV r,UV mod)182 static void poly_mod_sqr(UV* px, UV* res, UV r, UV mod)
183 {
184   UV c, d, s, sum, rindex, maxpx;
185   UV degree = r-1;
186   int native_sqr = (mod > isqrt(UV_MAX/(2*r))) ? 0 : 1;
187 
188   memset(res, 0, r * sizeof(UV)); /* zero out sums */
189   /* Discover index of last non-zero value in px */
190   for (s = degree; s > 0; s--)
191     if (px[s] != 0)
192       break;
193   maxpx = s;
194   /* 1D convolution */
195   for (d = 0; d <= 2*degree; d++) {
196     UV *pp1, *pp2, *ppend;
197     UV s_beg = (d <= degree) ? 0 : d-degree;
198     UV s_end = ((d/2) <= maxpx) ? d/2 : maxpx;
199     if (s_end < s_beg) continue;
200     sum = 0;
201     pp1 = px + s_beg;
202     pp2 = px + d - s_beg;
203     ppend = px + s_end;
204     if (native_sqr) {
205       while (pp1 < ppend)
206         sum += 2 * *pp1++  *  *pp2--;
207       /* Special treatment for last point */
208       c = px[s_end];
209       sum += (s_end*2 == d)  ?  c*c  :  2*c*px[d-s_end];
210       rindex = (d < r) ? d : d-r;  /* d % r */
211       res[rindex] = (res[rindex] + sum) % mod;
212 #if HAVE_UINT128
213     } else {
214       uint128_t max = ((uint128_t)1 << 127) - 1;
215       uint128_t c128, sum128 = 0;
216 
217       while (pp1 < ppend) {
218         c128 = ((uint128_t)*pp1++)  *  ((uint128_t)*pp2--);
219         if (c128 > max) c128 %= mod;
220         c128 <<= 1;
221         if (c128 > max) c128 %= mod;
222         sum128 += c128;
223         if (sum128 > max) sum128 %= mod;
224       }
225       c128 = px[s_end];
226       if (s_end*2 == d) {
227         c128 *= c128;
228       } else {
229         c128 *= px[d-s_end];
230         if (c128 > max) c128 %= mod;
231         c128 <<= 1;
232       }
233       if (c128 > max) c128 %= mod;
234       sum128 += c128;
235       if (sum128 > max) sum128 %= mod;
236       rindex = (d < r) ? d : d-r;  /* d % r */
237       res[rindex] = ((uint128_t)res[rindex] + sum128) % mod;
238 #else
239     } else {
240       while (pp1 < ppend) {
241         UV p1 = *pp1++;
242         UV p2 = *pp2--;
243         sum = addmod(sum, mulmod(2, mulmod(p1, p2, mod), mod), mod);
244       }
245       c = px[s_end];
246       if (s_end*2 == d)
247         sum = addmod(sum, sqrmod(c, mod), mod);
248       else
249         sum = addmod(sum, mulmod(2, mulmod(c, px[d-s_end], mod), mod), mod);
250       rindex = (d < r) ? d : d-r;  /* d % r */
251       res[rindex] = addmod(res[rindex], sum, mod);
252 #endif
253     }
254   }
255   memcpy(px, res, r * sizeof(UV)); /* put result in px */
256 }
257 
poly_mod_pow(UV * pn,UV power,UV r,UV mod)258 static UV* poly_mod_pow(UV* pn, UV power, UV r, UV mod)
259 {
260   UV *res, *temp;
261 
262   Newz(0, res, r, UV);
263   New(0, temp, r, UV);
264   res[0] = 1;
265 
266   while (power) {
267     if (power & 1)  poly_mod_mul(res, pn, temp, r, mod);
268     power >>= 1;
269     if (power)      poly_mod_sqr(pn, temp, r, mod);
270   }
271   Safefree(temp);
272   return res;
273 }
274 
test_anr(UV a,UV n,UV r)275 static int test_anr(UV a, UV n, UV r)
276 {
277   UV* pn;
278   UV* res;
279   UV i;
280   int retval = 1;
281 
282   Newz(0, pn, r, UV);
283   a %= r;
284   pn[0] = a;
285   pn[1] = 1;
286   res = poly_mod_pow(pn, n, r, n);
287   res[n % r] = addmod(res[n % r], n - 1, n);
288   res[0]     = addmod(res[0],     n - a, n);
289 
290   for (i = 0; i < r; i++)
291     if (res[i] != 0)
292       retval = 0;
293   Safefree(res);
294   Safefree(pn);
295   return retval;
296 }
297 
298 /*
299  * Avanzi and Mihǎilescu, 2007
300  * http://www.uni-math.gwdg.de/preda/mihailescu-papers/ouraks3.pdf
301  * "As a consequence, one cannot expect the present variants of AKS to
302  *  compete with the earlier primality proving methods like ECPP and
303  *  cyclotomy." - conclusion regarding memory consumption
304  */
is_aks_prime(UV n)305 int is_aks_prime(UV n)
306 {
307   UV r, s, a, starta = 1;
308 
309   if (n < 2)
310     return 0;
311   if (n == 2)
312     return 1;
313 
314   if (is_power(n, 0))
315     return 0;
316 
317   if (n > 11 && ( !(n%2) || !(n%3) || !(n%5) || !(n%7) || !(n%11) )) return 0;
318   /* if (!is_prob_prime(n)) return 0; */
319 
320 #if IMPL_V6
321   {
322     UV sqrtn = isqrt(n);
323     double log2n = log(n) / log(2);   /* C99 has a log2() function */
324     UV limit = (UV) floor(log2n * log2n);
325 
326     MPUverbose(1, "# aks limit is %lu\n", (unsigned long) limit);
327 
328     for (r = 2; r < n; r++) {
329       if ((n % r) == 0)
330         return 0;
331 #if SQRTN_SHORTCUT
332       if (r > sqrtn)
333         return 1;
334 #endif
335       if (znorder(n, r) > limit)
336         break;
337     }
338 
339     if (r >= n)
340       return 1;
341 
342     s = (UV) floor(sqrt(r-1) * log2n);
343   }
344 #endif
345 #if IMPL_BORNEMANN
346   {
347     UV fac[MPU_MAX_FACTORS+1];
348     UV slim;
349     double c1, c2, x;
350     double const t = 48;
351     double const t1 = (1.0/((t+1)*log(t+1)-t*log(t)));
352     double const dlogn = log(n);
353     r = next_prime( (UV) (t1*t1 * dlogn*dlogn) );
354     while (!is_primitive_root(n,r,1))
355       r = next_prime(r);
356 
357     slim = (UV) (2*t*(r-1));
358     c1 = lgamma(r-1);
359     c2 = dlogn * floor(sqrt(r));
360     { /* Binary search for first s in [1,slim] where x >= 0 */
361       UV i = 1;
362       UV j = slim;
363       while (i < j) {
364         s = i + (j-i)/2;
365         x = (lgamma(r-1+s) - c1 - lgamma(s+1)) / c2 - 1.0;
366         if (x < 0)  i = s+1;
367         else        j = s;
368       }
369       s = i-1;
370     }
371     s = (s+3) >> 1;
372     /* Bornemann checks factors up to (s-1)^2, we check to max(r,s) */
373     /* slim = (s-1)*(s-1); */
374     slim = (r > s) ? r : s;
375     MPUverbose(2, "# aks trial to %lu\n", slim);
376     if (trial_factor(n, fac, 2, slim) > 1)
377       return 0;
378     if (slim >= HALF_WORD || (slim*slim) >= n)
379       return 1;
380   }
381 #endif
382 #if IMPL_BERN41
383   {
384     UV slim, fac[MPU_MAX_FACTORS+1];
385     double const log2n = log(n) / log(2);
386     /* Tuning: Initial 'r' selection.  Search limit for 's'. */
387     double const r0 = ((log2n > 32) ? 0.010 : 0.003) * log2n * log2n;
388     UV const rmult  =  (log2n > 32) ? 6    : 30;
389 
390     r = next_prime(r0 < 2 ? 2 : (UV)r0);  /* r must be at least 3 */
391     while ( !is_primitive_root(n,r,1) || !bern41_acceptable(n,r,rmult*(r-1)) )
392       r = next_prime(r);
393 
394     { /* Binary search for first s in [1,slim] where conditions met */
395       UV bi = 1;
396       UV bj = rmult * (r-1);
397       while (bi < bj) {
398         s = bi + (bj-bi)/2;
399         if (!bern41_acceptable(n, r, s))  bi = s+1;
400         else                              bj = s;
401       }
402       s = bj;
403       if (!bern41_acceptable(n, r, s)) croak("AKS: bad s selected");
404       /* S goes from 2 to s+1 */
405       starta = 2;
406       s = s+1;
407     }
408     /* Check divisibility to s * (s-1) to cover both gcd conditions */
409     slim = s * (s-1);
410     MPUverbose(2, "# aks trial to %lu\n", (unsigned long)slim);
411     if (trial_factor(n, fac, 2, slim) > 1)
412       return 0;
413     if (slim >= HALF_WORD || (slim*slim) >= n)
414       return 1;
415     /* Check b^(n-1) = 1 mod n for b in [2..s] */
416     for (a = 2; a <= s; a++) {
417       if (powmod(a, n-1, n) != 1)
418         return 0;
419     }
420   }
421 #endif
422 
423   MPUverbose(1, "# aks r = %lu  s = %lu\n", (unsigned long) r, (unsigned long) s);
424 
425   /* Almost every composite will get recognized by the first test.
426    * However, we need to run 's' tests to have the result proven for all n
427    * based on the theorems we have available at this time. */
428   for (a = starta; a <= s; a++) {
429     if (! test_anr(a, n, r) )
430       return 0;
431     MPUverbose(2, ".");
432   }
433   MPUverbose(2, "\n");
434   return 1;
435 }
436