1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <math.h>
5 #include <float.h>
6
7 /* The AKS primality algorithm for native integers.
8 *
9 * There are three versions here:
10 * V6 The v6 algorithm from the latest AKS paper.
11 * https://www.cse.iitk.ac.in/users/manindra/algebra/primality_v6.pdf
12 * BORNEMANN Improvements from Bernstein, Voloch, and a clever r/s
13 * selection from Folkmar Bornemann. Similar to Bornemann's
14 * 2003 Pari/GP implementation:
15 * https://homepage.univie.ac.at/Dietrich.Burde/pari/aks.gp
16 * BERN41 My implementation of theorem 4.1 from Bernstein's 2003 paper.
17 * https://cr.yp.to/papers/aks.pdf
18 *
19 * Each one is orders of magnitude faster than the previous, and by default
20 * we use Bernstein 4.1 as it is by far the fastest.
21 *
22 * Note that AKS is very, very slow compared to other methods. It is, however,
23 * polynomial in log(N), and log-log performance graphs show nice straight
24 * lines for both implementations. However APR-CL and ECPP both start out
25 * much faster and the slope will be less for any sizes of N that we're
26 * interested in.
27 *
28 * For native 64-bit integers this is purely a coding exercise, as BPSW is
29 * a million times faster and gives proven results.
30 *
31 *
32 * When n < 2^(wordbits/2)-1, we can do a straightforward intermediate:
33 * r = (r + a * b) % n
34 * If n is larger, then these are replaced with:
35 * r = addmod( r, mulmod(a, b, n), n)
36 * which is a lot more work, but keeps us correct.
37 *
38 * Software that does polynomial convolutions followed by a modulo can be
39 * very fast, but will fail when n >= (2^wordbits)/r.
40 *
41 * This is all much easier in GMP.
42 *
43 * Copyright 2012-2016, Dana Jacobsen.
44 */
45
46 #define SQRTN_SHORTCUT 1
47
48 #define IMPL_V6 0 /* From the primality_v6 paper */
49 #define IMPL_BORNEMANN 0 /* From Bornemann's 2002 implementation */
50 #define IMPL_BERN41 1 /* From Bernstein's early 2003 paper */
51
52 #include "ptypes.h"
53 #include "aks.h"
54 #define FUNC_isqrt 1
55 #define FUNC_gcd_ui 1
56 #include "util.h"
57 #include "cache.h"
58 #include "mulmod.h"
59 #include "factor.h"
60
61 #if IMPL_BORNEMANN || IMPL_BERN41
62 /* We could use lgamma, but it isn't in MSVC and not in pre-C99. The only
63 * sure way to find if it is available is test compilation (ala autoconf).
64 * Instead, we'll just use our own implementation.
65 * See http://mrob.com/pub/ries/lanczos-gamma.html for alternates. */
log_gamma(double x)66 static double log_gamma(double x)
67 {
68 static const double log_sqrt_two_pi = 0.91893853320467274178;
69 static const double lanczos_coef[8+1] =
70 { 0.99999999999980993, 676.5203681218851, -1259.1392167224028,
71 771.32342877765313, -176.61502916214059, 12.507343278686905,
72 -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7 };
73 double base = x + 7.5, sum = 0;
74 int i;
75 for (i = 8; i >= 1; i--)
76 sum += lanczos_coef[i] / (x + (double)i);
77 sum += lanczos_coef[0];
78 sum = log_sqrt_two_pi + log(sum/x) + ( (x+0.5)*log(base) - base );
79 return sum;
80 }
81
82 /* Note: For lgammal we need logl in the above.
83 * Max error drops from 2.688466e-09 to 1.818989e-12. */
84 #undef lgamma
85 #define lgamma(x) log_gamma(x)
86 #endif
87
88 #if IMPL_BERN41
/* Natural log of the binomial coefficient C(n,k), computed as
 *   lgamma(n+1) - lgamma(k+1) - lgamma(n-k+1).
 * The caller must ensure k <= n; n-k is unsigned arithmetic and would wrap.
 */
static double log_binomial(UV n, UV k)
{
  const double lg_n_fact   = log_gamma(n+1);
  const double lg_k_fact   = log_gamma(k+1);
  const double lg_nmk_fact = log_gamma(n-k+1);

  return lg_n_fact - lg_k_fact - lg_nmk_fact;
}
/* Natural log of the product of four binomial coefficients
 *   C(2s,i) * C(d,i) * C(2s-i,j) * C(r-2-d,j)
 * returned as the sum of the four individual logs.
 * The caller must keep i <= d, i <= 2s, j <= 2s-i and j <= r-2-d so that
 * none of the unsigned subtractions wrap.
 */
static double log_bern41_binomial(UV r, UV d, UV i, UV j, UV s)
{
  double total;

  total  = log_binomial(2*s, i);
  total += log_binomial(d, i);
  total += log_binomial(2*s-i, j);
  total += log_binomial(r-2-d, j);
  return total;
}
/* Decide whether the pair (r,s) is large enough for input n.
 *
 * Picks d = floor(0.5*(r-1)) and i = j = floor(0.475*(r-1)), clamps them so
 * that d <= r-2, i <= d and j <= r-2-d, and then returns non-zero when
 *   log_bern41_binomial(r,d,i,j,s) >= ceil(sqrt((r-1)/3)) * log(n).
 * Note that j is clamped starting from the unclamped 0.475*(r-1) value,
 * not from the already-clamped i.
 */
static int bern41_acceptable(UV n, UV r, UV s)
{
  const UV rm1 = r-1;
  const UV rm2 = r-2;
  const double threshold = ceil(sqrt( rm1/3.0 )) * log(n);
  const UV near_half = (UV) (0.475 * rm1);
  UV d = (UV) (0.5 * rm1);
  UV i, j;

  if (d > rm2)
    d = rm2;
  i = (near_half > d)       ? d       : near_half;
  j = (near_half > rm2 - d) ? rm2 - d : near_half;

  return (log_bern41_binomial(r, d, i, j, s) >= threshold);
}
111 #endif
112
113 #if 0
114 /* Naive znorder. Works well if limit is small. Note arguments. */
115 static UV order(UV r, UV n, UV limit) {
116 UV j;
117 UV t = 1;
118 for (j = 1; j <= limit; j++) {
119 t = mulmod(t, n, r);
120 if (t == 1)
121 break;
122 }
123 return j;
124 }
125 static void poly_print(UV* poly, UV r)
126 {
127 int i;
128 for (i = r-1; i >= 1; i--) {
129 if (poly[i] != 0)
130 printf("%lux^%d + ", poly[i], i);
131 }
132 if (poly[0] != 0) printf("%lu", poly[0]);
133 printf("\n");
134 }
135 #endif
136
/* Multiply px by py as polynomials with exponents wrapped modulo r (i.e.
 * modulo x^r - 1) and coefficients reduced modulo mod.
 *
 *   px, py  arrays of r coefficients, index = exponent.  Coefficients are
 *           assumed to already be reduced (< mod).
 *   res     scratch array of r entries; it must not alias px or py.
 *   r       number of coefficients (callers use r >= 2).
 *   mod     coefficient modulus.
 *
 * The product is left in BOTH res and px (px is overwritten); py is only
 * read.
 */
static void poly_mod_mul(UV* px, UV* py, UV* res, UV r, UV mod)
{
  UV degpx, degpy;
  UV i, j, pxi, pyj, rindex;
  UV maxprod, maxterms;

  /* Determine max degree of px and py */
  for (degpx = r-1; degpx > 0 && !px[degpx]; degpx--) ; /* */
  for (degpy = r-1; degpy > 0 && !py[degpy]; degpy--) ; /* */

  /* How many maximal products (mod-1)^2 fit in a UV without overflow.
   * Zero when a single product may already overflow (mod >= HALF_WORD) and
   * when mod == 1, where the divisor would be zero. */
  maxprod  = (mod-1)*(mod-1);   /* may wrap for large mod; unused then */
  maxterms = (mod >= HALF_WORD || maxprod == 0) ? 0 : (UV_MAX / maxprod);

  /* Each output coefficient sums at most min(degpx,degpy)+1 products, so the
   * unreduced fast path is safe only when maxterms is STRICTLY greater than
   * one of the degrees. */
  if (maxterms > degpx || maxterms > degpy) {
    /* res will be written completely, so no need to set */
    for (rindex = 0; rindex < r; rindex++) {
      UV sum = 0;
      j = rindex;                       /* j tracks (rindex - i) mod r */
      for (i = 0; i <= degpx; i++) {
        if (j <= degpy)
          sum += px[i] * py[j];
        j = (j == 0) ? r-1 : j-1;
      }
      res[rindex] = sum % mod;
    }
  } else {
    memset(res, 0, r * sizeof(UV));     /* Zero result accumulator */
    for (i = 0; i <= degpx; i++) {
      pxi = px[i];
      if (pxi == 0) continue;
      if (mod < HALF_WORD) {
        /* One product plus one reduced coefficient cannot overflow here. */
        for (j = 0; j <= degpy; j++) {
          pyj = py[j];
          rindex = i+j;  if (rindex >= r) rindex -= r;
          res[rindex] = (res[rindex] + (pxi*pyj) ) % mod;
        }
      } else {
        /* muladdmod comes from mulmod.h; presumably (a*b + c) mod m
         * computed without overflow -- confirm against that header. */
        for (j = 0; j <= degpy; j++) {
          pyj = py[j];
          rindex = i+j;  if (rindex >= r) rindex -= r;
          res[rindex] = muladdmod(pxi, pyj, res[rindex], mod);
        }
      }
    }
  }
  memcpy(px, res, r * sizeof(UV));      /* put result in px */
}
/* Square px as a polynomial with exponents wrapped modulo r (i.e. modulo
 * x^r - 1) and coefficients reduced modulo mod.
 *
 *   px   array of r coefficients, index = exponent; overwritten with the
 *        result.  Coefficients are assumed to already be reduced (< mod).
 *   res  scratch array of r entries (must not alias px); also holds the
 *        result on return.
 *
 * For every product degree d in [0, 2(r-1)] the symmetric pairs
 * px[s]*px[d-s] with s < d-s are counted twice and the middle term (when d
 * is even) once.  Three arithmetic strategies are used:
 *   - plain UV arithmetic when 2*r*mod^2 cannot overflow a UV,
 *   - 128-bit accumulation with lazy reduction when HAVE_UINT128 is set,
 *   - mulmod/addmod/sqrmod otherwise.
 *
 * Elements are addressed by index rather than through px + d - s_beg, so no
 * pointer beyond the end of the array is ever formed.
 */
static void poly_mod_sqr(UV* px, UV* res, UV r, UV mod)
{
  UV c, d, s, sum, rindex, maxpx;
  UV degree = r-1;
  int native_sqr = (mod > isqrt(UV_MAX/(2*r))) ? 0 : 1;

  memset(res, 0, r * sizeof(UV));       /* zero out sums */
  /* Discover index of last non-zero value in px */
  for (s = degree; s > 0; s--)
    if (px[s] != 0)
      break;
  maxpx = s;
  /* 1D convolution */
  for (d = 0; d <= 2*degree; d++) {
    UV lo, hi;
    UV s_beg = (d <= degree) ? 0 : d-degree;
    UV s_end = ((d/2) <= maxpx) ? d/2 : maxpx;
    if (s_end < s_beg) continue;
    sum = 0;
    lo = s_beg;                         /* lower index of the pair  */
    hi = d - s_beg;                     /* upper index, always <= degree */
    rindex = (d < r) ? d : d-r;         /* d % r */
    if (native_sqr) {
      for ( ; lo < s_end; lo++, hi--)
        sum += 2 * px[lo] * px[hi];
      /* Special treatment for last point */
      c = px[s_end];
      sum += (s_end*2 == d) ? c*c : 2*c*px[d-s_end];
      res[rindex] = (res[rindex] + sum) % mod;
    } else {
#if HAVE_UINT128
      /* Reduce only when a value has grown past 2^127-1, so that the next
       * doubling or addition cannot overflow 128 bits. */
      uint128_t max = ((uint128_t)1 << 127) - 1;
      uint128_t c128, sum128 = 0;

      for ( ; lo < s_end; lo++, hi--) {
        c128 = ((uint128_t)px[lo]) * ((uint128_t)px[hi]);
        if (c128 > max) c128 %= mod;
        c128 <<= 1;
        if (c128 > max) c128 %= mod;
        sum128 += c128;
        if (sum128 > max) sum128 %= mod;
      }
      c128 = px[s_end];
      if (s_end*2 == d) {
        c128 *= c128;
      } else {
        c128 *= px[d-s_end];
        if (c128 > max) c128 %= mod;
        c128 <<= 1;
      }
      if (c128 > max) c128 %= mod;
      sum128 += c128;
      if (sum128 > max) sum128 %= mod;
      res[rindex] = ((uint128_t)res[rindex] + sum128) % mod;
#else
      for ( ; lo < s_end; lo++, hi--) {
        UV p1 = px[lo];
        UV p2 = px[hi];
        sum = addmod(sum, mulmod(2, mulmod(p1, p2, mod), mod), mod);
      }
      c = px[s_end];
      if (s_end*2 == d)
        sum = addmod(sum, sqrmod(c, mod), mod);
      else
        sum = addmod(sum, mulmod(2, mulmod(c, px[d-s_end], mod), mod), mod);
      res[rindex] = addmod(res[rindex], sum, mod);
#endif
    }
  }
  memcpy(px, res, r * sizeof(UV));      /* put result in px */
}
257
/* Raise the polynomial pn to the given power by binary exponentiation, with
 * exponents wrapped modulo r and coefficients modulo mod.
 *
 * pn (r coefficients) is used as the running square and is therefore
 * destroyed.  The result is a newly allocated array of r coefficients that
 * the caller must release with Safefree.  power == 0 yields the constant
 * polynomial 1.
 */
static UV* poly_mod_pow(UV* pn, UV power, UV r, UV mod)
{
  UV *acc, *scratch;
  UV e;

  Newz(0, acc, r, UV);
  New(0, scratch, r, UV);
  acc[0] = 1;

  for (e = power; e != 0; ) {
    if (e & 1)
      poly_mod_mul(acc, pn, scratch, r, mod);
    e >>= 1;
    if (e == 0)
      break;                  /* skip the final, unneeded squaring */
    poly_mod_sqr(pn, scratch, r, mod);
  }

  Safefree(scratch);
  return acc;
}
274
/* Test the AKS congruence
 *     (x + a)^n == x^n + a     mod (x^r - 1, n).
 * Returns 1 if it holds and 0 otherwise.  Requires r >= 2 and n >= 2.
 *
 * The witness a is a coefficient, so it is reduced modulo n (the coefficient
 * modulus), not modulo r: reducing by r would make witnesses that differ by
 * a multiple of r indistinguishable and would turn multiples of r into the
 * vacuous test x^n == x^n.
 */
static int test_anr(UV a, UV n, UV r)
{
  UV* pn;
  UV* res;
  UV i, nmodr;
  int retval = 1;

  Newz(0, pn, r, UV);
  a %= n;
  pn[0] = a;
  pn[1] = 1;                            /* pn = x + a */
  res = poly_mod_pow(pn, n, r, n);

  /* Subtract x^(n mod r) + a; a zero polynomial then means the congruence
   * holds. */
  nmodr = n % r;
  res[nmodr] = addmod(res[nmodr], n - 1, n);
  if (a != 0)
    res[0] = addmod(res[0], n - a, n);

  for (i = 0; i < r; i++) {
    if (res[i] != 0) {
      retval = 0;
      break;                            /* one non-zero term is enough */
    }
  }
  Safefree(res);
  Safefree(pn);
  return retval;
}
297
298 /*
299 * Avanzi and Mihǎilescu, 2007
300 * http://www.uni-math.gwdg.de/preda/mihailescu-papers/ouraks3.pdf
301 * "As a consequence, one cannot expect the present variants of AKS to
302 * compete with the earlier primality proving methods like ECPP and
303 * cyclotomy." - conclusion regarding memory consumption
304 */
is_aks_prime(UV n)305 int is_aks_prime(UV n)
306 {
307 UV r, s, a, starta = 1;
308
309 if (n < 2)
310 return 0;
311 if (n == 2)
312 return 1;
313
314 if (is_power(n, 0))
315 return 0;
316
317 if (n > 11 && ( !(n%2) || !(n%3) || !(n%5) || !(n%7) || !(n%11) )) return 0;
318 /* if (!is_prob_prime(n)) return 0; */
319
320 #if IMPL_V6
321 {
322 UV sqrtn = isqrt(n);
323 double log2n = log(n) / log(2); /* C99 has a log2() function */
324 UV limit = (UV) floor(log2n * log2n);
325
326 MPUverbose(1, "# aks limit is %lu\n", (unsigned long) limit);
327
328 for (r = 2; r < n; r++) {
329 if ((n % r) == 0)
330 return 0;
331 #if SQRTN_SHORTCUT
332 if (r > sqrtn)
333 return 1;
334 #endif
335 if (znorder(n, r) > limit)
336 break;
337 }
338
339 if (r >= n)
340 return 1;
341
342 s = (UV) floor(sqrt(r-1) * log2n);
343 }
344 #endif
345 #if IMPL_BORNEMANN
346 {
347 UV fac[MPU_MAX_FACTORS+1];
348 UV slim;
349 double c1, c2, x;
350 double const t = 48;
351 double const t1 = (1.0/((t+1)*log(t+1)-t*log(t)));
352 double const dlogn = log(n);
353 r = next_prime( (UV) (t1*t1 * dlogn*dlogn) );
354 while (!is_primitive_root(n,r,1))
355 r = next_prime(r);
356
357 slim = (UV) (2*t*(r-1));
358 c1 = lgamma(r-1);
359 c2 = dlogn * floor(sqrt(r));
360 { /* Binary search for first s in [1,slim] where x >= 0 */
361 UV i = 1;
362 UV j = slim;
363 while (i < j) {
364 s = i + (j-i)/2;
365 x = (lgamma(r-1+s) - c1 - lgamma(s+1)) / c2 - 1.0;
366 if (x < 0) i = s+1;
367 else j = s;
368 }
369 s = i-1;
370 }
371 s = (s+3) >> 1;
372 /* Bornemann checks factors up to (s-1)^2, we check to max(r,s) */
373 /* slim = (s-1)*(s-1); */
374 slim = (r > s) ? r : s;
375 MPUverbose(2, "# aks trial to %lu\n", slim);
376 if (trial_factor(n, fac, 2, slim) > 1)
377 return 0;
378 if (slim >= HALF_WORD || (slim*slim) >= n)
379 return 1;
380 }
381 #endif
382 #if IMPL_BERN41
383 {
384 UV slim, fac[MPU_MAX_FACTORS+1];
385 double const log2n = log(n) / log(2);
386 /* Tuning: Initial 'r' selection. Search limit for 's'. */
387 double const r0 = ((log2n > 32) ? 0.010 : 0.003) * log2n * log2n;
388 UV const rmult = (log2n > 32) ? 6 : 30;
389
390 r = next_prime(r0 < 2 ? 2 : (UV)r0); /* r must be at least 3 */
391 while ( !is_primitive_root(n,r,1) || !bern41_acceptable(n,r,rmult*(r-1)) )
392 r = next_prime(r);
393
394 { /* Binary search for first s in [1,slim] where conditions met */
395 UV bi = 1;
396 UV bj = rmult * (r-1);
397 while (bi < bj) {
398 s = bi + (bj-bi)/2;
399 if (!bern41_acceptable(n, r, s)) bi = s+1;
400 else bj = s;
401 }
402 s = bj;
403 if (!bern41_acceptable(n, r, s)) croak("AKS: bad s selected");
404 /* S goes from 2 to s+1 */
405 starta = 2;
406 s = s+1;
407 }
408 /* Check divisibility to s * (s-1) to cover both gcd conditions */
409 slim = s * (s-1);
410 MPUverbose(2, "# aks trial to %lu\n", (unsigned long)slim);
411 if (trial_factor(n, fac, 2, slim) > 1)
412 return 0;
413 if (slim >= HALF_WORD || (slim*slim) >= n)
414 return 1;
415 /* Check b^(n-1) = 1 mod n for b in [2..s] */
416 for (a = 2; a <= s; a++) {
417 if (powmod(a, n-1, n) != 1)
418 return 0;
419 }
420 }
421 #endif
422
423 MPUverbose(1, "# aks r = %lu s = %lu\n", (unsigned long) r, (unsigned long) s);
424
425 /* Almost every composite will get recognized by the first test.
426 * However, we need to run 's' tests to have the result proven for all n
427 * based on the theorems we have available at this time. */
428 for (a = starta; a <= s; a++) {
429 if (! test_anr(a, n, r) )
430 return 0;
431 MPUverbose(2, ".");
432 }
433 MPUverbose(2, "\n");
434 return 1;
435 }
436