xref: /dragonfly/contrib/gmp/mpn/generic/perfpow.c (revision 37de577a)
1 /* mpn_perfect_power_p -- mpn perfect power detection.
2 
3    Contributed to the GNU project by Martin Boij.
4 
5 Copyright 2009, 2010 Free Software Foundation, Inc.
6 
7 This file is part of the GNU MP Library.
8 
9 The GNU MP Library is free software; you can redistribute it and/or modify
10 it under the terms of the GNU Lesser General Public License as published by
11 the Free Software Foundation; either version 3 of the License, or (at your
12 option) any later version.
13 
14 The GNU MP Library is distributed in the hope that it will be useful, but
15 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
16 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
17 License for more details.
18 
19 You should have received a copy of the GNU Lesser General Public License
20 along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */
21 
22 #include "gmp.h"
23 #include "gmp-impl.h"
24 #include "longlong.h"
25 
/* Operand-size thresholds, in limbs.  mpn_perfect_power_p compares the
   size of the odd part of its operand against these to choose which
   entry of the nrtrial[] and logs[] tables to use.  */
#define SMALL    20
#define MEDIUM  100
28 
29 /*
30    Returns non-zero if {np,nn} == {xp,xn} ^ k.
31    Algorithm:
32        For s = 1, 2, 4, ..., s_max, compute the s least significant
33        limbs of {xp,xn}^k. Stop if they don't match the s least
34        significant limbs of {np,nn}.
35 */
36 static int
37 pow_equals (mp_srcptr np, mp_size_t nn,
38 	    mp_srcptr xp,mp_size_t xn,
39 	    mp_limb_t k, mp_bitcnt_t f,
40 	    mp_ptr tp)
41 {
42   mp_limb_t *tp2;
43   mp_bitcnt_t y, z, count;
44   mp_size_t i, bn;
45   int ans;
46   mp_limb_t h, l;
47   TMP_DECL;
48 
49   ASSERT (nn > 1 || (nn == 1 && np[0] > 1));
50   ASSERT (np[nn - 1] > 0);
51   ASSERT (xn > 0);
52 
53   if (xn == 1 && xp[0] == 1)
54     return 0;
55 
56   z = 1 + (nn >> 1);
57   for (bn = 1; bn < z; bn <<= 1)
58     {
59       mpn_powlo (tp, xp, &k, 1, bn, tp + bn);
60       if (mpn_cmp (tp, np, bn) != 0)
61 	return 0;
62     }
63 
64   TMP_MARK;
65 
66   /* Final check. Estimate the size of {xp,xn}^k before computing
67      the power with full precision.
68      Optimization: It might pay off to make a more accurate estimation of
69      the logarithm of {xp,xn}, rather than using the index of the MSB.
70   */
71 
72   count_leading_zeros (count, xp[xn - 1]);
73   y = xn * GMP_LIMB_BITS - count - 1;  /* msb_index (xp, xn) */
74 
75   umul_ppmm (h, l, k, y);
76   h -= l == 0;  l--;	/* two-limb decrement */
77 
78   z = f - 1; /* msb_index (np, nn) */
79   if (h == 0 && l <= z)
80     {
81       mp_limb_t size;
82       size = l + k;
83       ASSERT_ALWAYS (size >= k);
84 
85       y = 2 + size / GMP_LIMB_BITS;
86       tp2 = TMP_ALLOC_LIMBS (y);
87 
88       i = mpn_pow_1 (tp, xp, xn, k, tp2);
89       if (i == nn && mpn_cmp (tp, np, nn) == 0)
90 	ans = 1;
91       else
92 	ans = 0;
93     }
94   else
95     {
96       ans = 0;
97     }
98 
99   TMP_FREE;
100   return ans;
101 }
102 
103 /*
104    Computes rp such that rp^k * yp = 1 (mod 2^b).
105    Algorithm:
106        Apply Hensel lifting repeatedly, each time
107        doubling (approx.) the number of known bits in rp.
108 */
109 static void
110 binv_root (mp_ptr rp, mp_srcptr yp,
111 	   mp_limb_t k, mp_size_t bn,
112 	   mp_bitcnt_t b, mp_ptr tp)
113 {
114   mp_limb_t *tp2 = tp + bn, *tp3 = tp + 2 * bn, di, k2 = k + 1;
115   mp_bitcnt_t order[GMP_LIMB_BITS * 2];
116   int i, d = 0;
117 
118   ASSERT (bn > 0);
119   ASSERT (b > 0);
120   ASSERT ((k & 1) != 0);
121 
122   binvert_limb (di, k);
123 
124   rp[0] = 1;
125   for (; b != 1; b = (b + 1) >> 1)
126     order[d++] = b;
127 
128   for (i = d - 1; i >= 0; i--)
129     {
130       b = order[i];
131       bn = 1 + (b - 1) / GMP_LIMB_BITS;
132 
133       mpn_mul_1 (tp, rp, bn, k2);
134 
135       mpn_powlo (tp2, rp, &k2, 1, bn, tp3);
136       mpn_mullo_n (rp, yp, tp2, bn);
137 
138       mpn_sub_n (tp2, tp, rp, bn);
139       mpn_pi1_bdiv_q_1 (rp, tp2, bn, k, di, 0);
140       if ((b % GMP_LIMB_BITS) != 0)
141 	rp[(b - 1) / GMP_LIMB_BITS] &= (((mp_limb_t) 1) << (b % GMP_LIMB_BITS)) - 1;
142     }
143   return;
144 }
145 
146 /*
147    Computes rp such that rp^2 * yp = 1 (mod 2^{b+1}).
148    Returns non-zero if such an integer rp exists.
149 */
150 static int
151 binv_sqroot (mp_ptr rp, mp_srcptr yp,
152 	     mp_size_t bn, mp_bitcnt_t b,
153 	     mp_ptr tp)
154 {
155   mp_limb_t k = 3, *tp2 = tp + bn, *tp3 = tp + (bn << 1);
156   mp_bitcnt_t order[GMP_LIMB_BITS * 2];
157   int i, d = 0;
158 
159   ASSERT (bn > 0);
160   ASSERT (b > 0);
161 
162   rp[0] = 1;
163   if (b == 1)
164     {
165       if ((yp[0] & 3) != 1)
166 	return 0;
167     }
168   else
169     {
170       if ((yp[0] & 7) != 1)
171 	return 0;
172 
173       for (; b != 2; b = (b + 2) >> 1)
174 	order[d++] = b;
175 
176       for (i = d - 1; i >= 0; i--)
177 	{
178 	  b = order[i];
179 	  bn = 1 + b / GMP_LIMB_BITS;
180 
181 	  mpn_mul_1 (tp, rp, bn, k);
182 
183 	  mpn_powlo (tp2, rp, &k, 1, bn, tp3);
184 	  mpn_mullo_n (rp, yp, tp2, bn);
185 
186 #if HAVE_NATIVE_mpn_rsh1sub_n
187 	  mpn_rsh1sub_n (rp, tp, rp, bn);
188 #else
189 	  mpn_sub_n (tp2, tp, rp, bn);
190 	  mpn_rshift (rp, tp2, bn, 1);
191 #endif
192 	  rp[b / GMP_LIMB_BITS] &= (((mp_limb_t) 1) << (b % GMP_LIMB_BITS)) - 1;
193 	}
194     }
195   return 1;
196 }
197 
198 /*
199    Returns non-zero if {np,nn} is a kth power.
200 */
201 static int
202 is_kth_power (mp_ptr rp, mp_srcptr np,
203 	      mp_limb_t k, mp_srcptr yp,
204 	      mp_size_t nn, mp_bitcnt_t f,
205 	      mp_ptr tp)
206 {
207   mp_limb_t x, c;
208   mp_bitcnt_t b;
209   mp_size_t i, rn, xn;
210 
211   ASSERT (nn > 0);
212   ASSERT (((k & 1) != 0) || (k == 2));
213   ASSERT ((np[0] & 1) != 0);
214 
215   if (k == 2)
216     {
217       b = (f + 1) >> 1;
218       rn = 1 + b / GMP_LIMB_BITS;
219       if (binv_sqroot (rp, yp, rn, b, tp) != 0)
220 	{
221 	  xn = rn;
222 	  MPN_NORMALIZE (rp, xn);
223 	  if (pow_equals (np, nn, rp, xn, k, f, tp) != 0)
224 	    return 1;
225 
226 	  /* Check if (2^b - rp)^2 == np */
227 	  c = 0;
228 	  for (i = 0; i < rn; i++)
229 	    {
230 	      x = rp[i];
231 	      rp[i] = -x - c;
232 	      c |= (x != 0);
233 	    }
234 	  rp[rn - 1] &= (((mp_limb_t) 1) << (b % GMP_LIMB_BITS)) - 1;
235 	  MPN_NORMALIZE (rp, rn);
236 	  if (pow_equals (np, nn, rp, rn, k, f, tp) != 0)
237 	    return 1;
238 	}
239     }
240   else
241     {
242       b = 1 + (f - 1) / k;
243       rn = 1 + (b - 1) / GMP_LIMB_BITS;
244       binv_root (rp, yp, k, rn, b, tp);
245       MPN_NORMALIZE (rp, rn);
246       if (pow_equals (np, nn, rp, rn, k, f, tp) != 0)
247 	return 1;
248     }
249   MPN_ZERO (rp, rn); /* Untrash rp */
250   return 0;
251 }
252 
253 static int
254 perfpow (mp_srcptr np, mp_size_t nn,
255 	 mp_limb_t ub, mp_limb_t g,
256 	 mp_bitcnt_t f, int neg)
257 {
258   mp_limb_t *yp, *tp, k = 0, *rp1;
259   int ans = 0;
260   mp_bitcnt_t b;
261   gmp_primesieve_t ps;
262   TMP_DECL;
263 
264   ASSERT (nn > 0);
265   ASSERT ((np[0] & 1) != 0);
266   ASSERT (ub > 0);
267 
268   TMP_MARK;
269   gmp_init_primesieve (&ps);
270   b = (f + 3) >> 1;
271 
272   yp = TMP_ALLOC_LIMBS (nn);
273   rp1 = TMP_ALLOC_LIMBS (nn);
274   tp = TMP_ALLOC_LIMBS (5 * nn);	/* FIXME */
275   MPN_ZERO (rp1, nn);
276 
277   mpn_binvert (yp, np, 1 + (b - 1) / GMP_LIMB_BITS, tp);
278   if (b % GMP_LIMB_BITS)
279     yp[(b - 1) / GMP_LIMB_BITS] &= (((mp_limb_t) 1) << (b % GMP_LIMB_BITS)) - 1;
280 
281   if (neg)
282     gmp_nextprime (&ps);
283 
284   if (g > 0)
285     {
286       ub = MIN (ub, g + 1);
287       while ((k = gmp_nextprime (&ps)) < ub)
288 	{
289 	  if ((g % k) == 0)
290 	    {
291 	      if (is_kth_power (rp1, np, k, yp, nn, f, tp) != 0)
292 		{
293 		  ans = 1;
294 		  goto ret;
295 		}
296 	    }
297 	}
298     }
299   else
300     {
301       while ((k = gmp_nextprime (&ps)) < ub)
302 	{
303 	  if (is_kth_power (rp1, np, k, yp, nn, f, tp) != 0)
304 	    {
305 	      ans = 1;
306 	      goto ret;
307 	    }
308 	}
309     }
310  ret:
311   TMP_FREE;
312   return ans;
313 }
314 
/* Number of primes passed to mpn_trialdiv, indexed by the operand-size
   class (0, 1 or 2) selected in mpn_perfect_power_p.  */
static const unsigned short nrtrial[] =
{
  100,
  500,
  1000
};

/* Table of (log_{p_i} 2) values, where p_i is
   the (nrtrial[i] + 1)'th prime number.
*/
static const double logs[] =
{
  0.1099457228193620,	/* p = 547, the 101st prime */
  0.0847016403115322,	/* p = 3581, the 501st prime */
  0.0772048195144415	/* p = 7927, the 1001st prime */
};
321 
322 int
323 mpn_perfect_power_p (mp_srcptr np, mp_size_t nn)
324 {
325   mp_size_t ncn, s, pn, xn;
326   mp_limb_t *nc, factor, g = 0;
327   mp_limb_t exp, *prev, *next, d, l, r, c, *tp, cry;
328   mp_bitcnt_t twos = 0, count;
329   int ans, where = 0, neg = 0, trial;
330   TMP_DECL;
331 
332   nc = (mp_ptr) np;
333 
334   if (nn < 0)
335     {
336       neg = 1;
337       nn = -nn;
338     }
339 
340   if (nn == 0 || (nn == 1 && np[0] == 1))
341     return 1;
342 
343   TMP_MARK;
344 
345   ncn = nn;
346   twos = mpn_scan1 (np, 0);
347   if (twos > 0)
348     {
349       if (twos == 1)
350 	{
351 	  ans = 0;
352 	  goto ret;
353 	}
354       s = twos / GMP_LIMB_BITS;
355       if (s + 1 == nn && POW2_P (np[s]))
356 	{
357 	  ans = ! (neg && POW2_P (twos));
358 	  goto ret;
359 	}
360       count = twos % GMP_LIMB_BITS;
361       ncn = nn - s;
362       nc = TMP_ALLOC_LIMBS (ncn);
363       if (count > 0)
364 	{
365 	  mpn_rshift (nc, np + s, ncn, count);
366 	  ncn -= (nc[ncn - 1] == 0);
367 	}
368       else
369 	{
370 	  MPN_COPY (nc, np + s, ncn);
371 	}
372       g = twos;
373     }
374 
375   if (ncn <= SMALL)
376     trial = 0;
377   else if (ncn <= MEDIUM)
378     trial = 1;
379   else
380     trial = 2;
381 
382   factor = mpn_trialdiv (nc, ncn, nrtrial[trial], &where);
383 
384   if (factor != 0)
385     {
386       if (twos == 0)
387 	{
388 	  nc = TMP_ALLOC_LIMBS (ncn);
389 	  MPN_COPY (nc, np, ncn);
390 	}
391 
392       /* Remove factors found by trialdiv.
393 	 Optimization: Perhaps better to use
394 	 the strategy in mpz_remove ().
395       */
396       prev = TMP_ALLOC_LIMBS (ncn + 2);
397       next = TMP_ALLOC_LIMBS (ncn + 2);
398       tp = TMP_ALLOC_LIMBS (4 * ncn);
399 
400       do
401 	{
402 	  binvert_limb (d, factor);
403 	  prev[0] = d;
404 	  pn = 1;
405 	  exp = 1;
406 	  while (2 * pn - 1 <= ncn)
407 	    {
408 	      mpn_sqr (next, prev, pn);
409 	      xn = 2 * pn;
410 	      xn -= (next[xn - 1] == 0);
411 
412 	      if (mpn_divisible_p (nc, ncn, next, xn) == 0)
413 		break;
414 
415 	      exp <<= 1;
416 	      pn = xn;
417 	      MP_PTR_SWAP (next, prev);
418 	    }
419 
420 	  /* Binary search for the exponent */
421 	  l = exp + 1;
422 	  r = 2 * exp - 1;
423 	  while (l <= r)
424 	    {
425 	      c = (l + r) >> 1;
426 	      if (c - exp > 1)
427 		{
428 		  xn = mpn_pow_1 (tp, &d, 1, c - exp, next);
429 		  if (pn + xn - 1 > ncn)
430 		    {
431 		      r = c - 1;
432 		      continue;
433 		    }
434 		  mpn_mul (next, prev, pn, tp, xn);
435 		  xn += pn;
436 		  xn -= (next[xn - 1] == 0);
437 		}
438 	      else
439 		{
440 		  cry = mpn_mul_1 (next, prev, pn, d);
441 		  next[pn] = cry;
442 		  xn = pn + (cry != 0);
443 		}
444 
445 	      if (mpn_divisible_p (nc, ncn, next, xn) == 0)
446 		{
447 		  r = c - 1;
448 		}
449 	      else
450 		{
451 		  exp = c;
452 		  l = c + 1;
453 		  MP_PTR_SWAP (next, prev);
454 		  pn = xn;
455 		}
456 	    }
457 
458 	  if (g == 0)
459 	    g = exp;
460 	  else
461 	    g = mpn_gcd_1 (&g, 1, exp);
462 
463 	  if (g == 1)
464 	    {
465 	      ans = 0;
466 	      goto ret;
467 	    }
468 
469 	  mpn_divexact (next, nc, ncn, prev, pn);
470 	  ncn = ncn - pn;
471 	  ncn += next[ncn] != 0;
472 	  MPN_COPY (nc, next, ncn);
473 
474 	  if (ncn == 1 && nc[0] == 1)
475 	    {
476 	      ans = ! (neg && POW2_P (g));
477 	      goto ret;
478 	    }
479 
480 	  factor = mpn_trialdiv (nc, ncn, nrtrial[trial], &where);
481 	}
482       while (factor != 0);
483     }
484 
485   count_leading_zeros (count, nc[ncn-1]);
486   count = GMP_LIMB_BITS * ncn - count;   /* log (nc) + 1 */
487   d = (mp_limb_t) (count * logs[trial] + 1e-9) + 1;
488   ans = perfpow (nc, ncn, d, g, count, neg);
489 
490  ret:
491   TMP_FREE;
492   return ans;
493 }
494