1 #include "cado.h" // IWYU pragma: keep
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <limits.h>     // INT_MAX
5 #include <math.h> /* for sqrt and floor and log and ceil */
6 #include <pthread.h>
7 #include <gmp.h>
8 #include "cado_poly.h"
9 #include "getprime.h"   // getprime
10 #include "gzip.h"       // fopen_maybe_compressed
11 #include "mpz_poly.h"   // mpz_poly
12 #include "rootfinder.h"
13 #include "verbose.h"
14 #include "macros.h"
15 #include "params.h"
16 
17 /*
18  * Compute g(x) = f(a*x+b), with deg f = d, and a and b are longs.
19  * Must have a != 0 and d >= 1
20  */
mp_poly_linear_comp(mpz_t * g,mpz_t * f,int d,long a,long b)21 void mp_poly_linear_comp(mpz_t *g, mpz_t *f, int d, long a, long b) {
22     ASSERT (a != 0  &&  d >= 1);
23     // lazy: use the mpz_poly interface of utils/mpz_poly.h
24     mpz_poly aXpb, aXpbi, G, Aux;
25     mpz_poly_init(aXpb, 1);  // alloc sets to zero
26     mpz_poly_init(aXpbi, d);
27     mpz_poly_init(G, d);
28     mpz_poly_init(Aux, d);
29     {
30         mpz_t aux;
31         mpz_init(aux);
32         mpz_set_si(aux, a);
33         mpz_poly_setcoeff(aXpb, 1, aux);
34         mpz_set_si(aux, b);
35         mpz_poly_setcoeff(aXpb, 0, aux);
36         mpz_clear(aux);
37     }
38     mpz_poly_set(aXpbi, aXpb);
39     mpz_poly_setcoeff(G, 0, f[0]);
40     for (int i = 1; i <= d; ++i) {
41         mpz_poly_mul_mpz(Aux, aXpbi, f[i]);
42         mpz_poly_add(G, G, Aux);
43         if (i < d)
44             mpz_poly_mul(aXpbi, aXpbi, aXpb);
45     }
46     for (int i = 0; i <= d; ++i)
47         mpz_poly_getcoeff(g[i], i, G);
48     mpz_poly_clear(aXpb);
49     mpz_poly_clear(aXpbi);
50     mpz_poly_clear(G);
51     mpz_poly_clear(Aux);
52 }
53 
/*
 * Return the p-adic valuation of z, i.e. the largest v such that p^v | z.
 * Returns INT_MAX when z == 0.
 *
 * p must be >= 2.  The valuation is computed with GMP's mpz_remove, which
 * strips all factors p at once and returns how many were removed.  This
 * replaces a hand-written division loop that could not terminate for p == 1.
 */
int mpz_p_val(mpz_t z, unsigned long p) {
    ASSERT (p >= 2);
    if (mpz_sgn(z) == 0)
        return INT_MAX;
    mpz_t pp, cofactor;
    mpz_init_set_ui(pp, p);
    mpz_init(cofactor);
    mp_bitcnt_t v = mpz_remove(cofactor, z, pp);
    mpz_clear(cofactor);
    mpz_clear(pp);
    return (int) v;
}
70 
/*
 * p-adic valuation of the content of the polynomial f[0..d], that is
 * min_i v_p(f[i]).  Stops scanning (and returns 0) at the first coefficient
 * not divisible by p.  Returns INT_MAX when every coefficient is zero.
 */
int mp_poly_p_val_of_content(mpz_t *f, int d, unsigned long p) {
    int best = INT_MAX;
    for (int i = 0; i <= d && best > 0; ++i) {
        int vi = mpz_p_val(f[i], p);
        if (vi < best)
            best = vi;
    }
    return best;
}
81 
82 
83 void
mp_poly_eval(mpz_t r,mpz_t * poly,int deg,mpz_t a)84 mp_poly_eval (mpz_t r, mpz_t *poly, int deg, mpz_t a)
85 {
86   int i;
87 
88   mpz_set (r, poly[deg]);
89   for (i = deg - 1; i >= 0; i--)
90     {
91       mpz_mul (r, r, a);
92       mpz_add (r, r, poly[i]);
93     }
94 }
95 
96 // Evaluate the derivative of poly at a.
97 void
mp_poly_eval_diff(mpz_t r,mpz_t * poly,int deg,mpz_t a)98 mp_poly_eval_diff (mpz_t r, mpz_t *poly, int deg, mpz_t a)
99 {
100   int i;
101 
102   mpz_mul_ui (r, poly[deg], (unsigned long)deg);
103   for (i = deg - 1; i >= 1; i--)
104     {
105       mpz_mul (r, r, a);
106       mpz_addmul_ui (r, poly[i], (unsigned long)i);
107     }
108 }
109 
110 // Same function as above, with a slightly different interface.
111 unsigned long
lift_root_unramified(mpz_t * f,int d,unsigned long r,unsigned long p,int kmax)112 lift_root_unramified(mpz_t *f, int d, unsigned long r,
113         unsigned long p,int kmax) {
114     mpz_t aux, aux2, mp_p, mp_r;
115     int k = 1;
116     mpz_init(aux);
117     mpz_init(aux2);
118     mpz_init_set_ui(mp_r, r);
119     mpz_init_set_ui(mp_p, p);
120     while (k < kmax) {
121         if (2*k <= kmax)
122             mpz_mul(mp_p, mp_p, mp_p); // p^2k
123         else {
124             for (int i = k+1; i <= kmax; ++i)
125                 mpz_mul_ui(mp_p, mp_p, p);
126         }
127         mp_poly_eval(aux, f, d, mp_r);
128         mp_poly_eval_diff(aux2, f, d, mp_r);
129         if (!mpz_invert(aux2, aux2, mp_p)) {
130             fprintf(stderr, "Error in lift_root_unramified: multiple root mod %lu\n", p);
131             exit(EXIT_FAILURE);
132         }
133         mpz_mul(aux, aux, aux2);
134         mpz_sub(aux, mp_r, aux);
135         mpz_mod(mp_r, aux, mp_p);
136         k *= 2;
137     }
138     r = mpz_get_ui(mp_r);
139     mpz_clear(aux);
140     mpz_clear(aux2);
141     mpz_clear(mp_p);
142     mpz_clear(mp_r);
143     return r;
144 }
145 
/* One output item: a root r modulo q, together with the two exponents
 * n1, n0 printed as "q:n1,n0: r" (or "q: r" when n1 == 1 and n0 == 0).
 * all_roots and all_roots_affine fill q with powers of p. */
typedef struct {
    unsigned long q;
    unsigned long r;
    int n1;
    int n0;
} entry;

/* Growable array of entries: list[0..len-1] are valid, alloc is capacity. */
typedef struct {
    entry *list;
    int len;
    int alloc;
} entry_list;

enum { ENTRY_LIST_INIT_ALLOC = 10 };

/* Initialize L as an empty list with a small initial capacity.
 * Exits the program on allocation failure. */
void entry_list_init(entry_list *L) {
    L->list = malloc(ENTRY_LIST_INIT_ALLOC * sizeof(entry));
    if (L->list == NULL) {
        fprintf(stderr, "Error in entry_list_init: out of memory\n");
        exit(EXIT_FAILURE);
    }
    L->alloc = ENTRY_LIST_INIT_ALLOC;
    L->len = 0;
}

/* Release the storage of L and leave it as a valid empty list (no dangling
 * pointer; push_entry may be called on it again). */
void entry_list_clear(entry_list *L) {
    free(L->list);
    L->list = NULL;
    L->len = 0;
    L->alloc = 0;
}

/* Append E to L, growing the storage as needed.
 * Exits the program on allocation failure or capacity overflow. */
void push_entry(entry_list *L, entry E) {
    if (L->len == L->alloc) {
        /* Grow geometrically so n appends cost amortized O(1) each. */
        int new_alloc;
        if (L->alloc <= 0) {
            new_alloc = ENTRY_LIST_INIT_ALLOC;
        } else if (L->alloc > INT_MAX / 2) {
            fprintf(stderr, "Error in push_entry: list too large\n");
            exit(EXIT_FAILURE);
        } else {
            new_alloc = 2 * L->alloc;
        }
        /* keep the old block reachable until realloc is known to succeed */
        entry *tmp = realloc(L->list, (size_t) new_alloc * sizeof(entry));
        if (tmp == NULL) {
            fprintf(stderr, "Error in push_entry: out of memory\n");
            exit(EXIT_FAILURE);
        }
        L->list = tmp;
        L->alloc = new_alloc;
    }
    L->list[L->len++] = E;
}
176 
cmp_entry(const void * A,const void * B)177 int cmp_entry(const void *A, const void *B) {
178     entry a, b;
179     a = ((entry *)A)[0];
180     b = ((entry *)B)[0];
181     if (a.q < b.q)
182         return -1;
183     if (a.q > b.q)
184         return 1;
185     if (a.n1 < b.n1)
186         return -1;
187     if (a.n1 > b.n1)
188         return 1;
189     if (a.n0 < b.n0)
190         return -1;
191     if (a.n0 > b.n0)
192         return 1;
193     if (a.r < b.r)
194         return -1;
195     if (a.r > b.r)
196         return 1;
197     return 0;
198 }
199 
/* see makefb.sage for the meaning of this function and its parameters */
/*
 * Append to L the entries for the affine roots of f modulo powers of p,
 * recursing on roots that are multiple modulo p.
 *
 * L          : list receiving the entries (via push_entry).
 * f, d       : coefficients f[0..d]; f itself is not modified (the
 *              recursion works on a transformed copy ff).
 * p          : the prime.
 * kmax, k0   : nothing is done when k0 >= kmax; simple roots are lifted with
 *              lift_root_unramified(..., kmax - k0), and k0 is added to the
 *              n1/n0 exponents of the entries produced.
 * m          : stored moduli are p^(m+l), l = 1..kmax-k0 (simple roots), or
 *              p^(m+1) (multiple roots).
 * phi1, phi0 : a root r found here is stored as phi1*r + phi0, reduced
 *              modulo the stored modulus.
 * rstate     : random state passed to mpz_poly_roots_ulong.
 */
void all_roots_affine(entry_list *L, mpz_t *f, int d, unsigned long p,
        int kmax, int k0, int m, unsigned long phi1, unsigned long phi0, gmp_randstate_ptr rstate) {
    int nroots;
    unsigned long *roots;
    mpz_t aux;

    // View f as an mpz_poly without copying: only coeff and deg are set,
    // and F is never cleared here.
    // NOTE(review): the other mpz_poly fields are left uninitialized; this
    // presumably relies on mpz_poly_roots_ulong reading only coeff/deg --
    // TODO confirm.
    mpz_poly F;
    F->coeff = f;
    F->deg = d;

    if (k0 >= kmax) {
        return;
    }
    // Room for d roots; assumes mpz_poly_roots_ulong returns at most d.
    // NOTE(review): malloc result is not checked.
    roots = (unsigned long*) malloc(d * sizeof(unsigned long));
    mpz_init(aux);

    nroots = mpz_poly_roots_ulong (roots, F, p, rstate);
    for (int i = 0; i < nroots; ++i) {
        unsigned long r = roots[i];
        // aux = f'(r)
        {
            mpz_t mp_r;
            mpz_init_set_ui(mp_r, r);
            mp_poly_eval_diff(aux, f, d, mp_r);
            mpz_clear(mp_r);
        }
        unsigned long dfr = mpz_mod_ui(aux, aux, p);
        if (dfr != 0) {
            // Simple root mod p: lift it to a root rr mod p^(kmax-k0), then
            // emit one entry per modulus q = p^(m+l), l = 1..kmax-k0, with
            // root (phi1*rr + phi0) mod q and exponents (k0+l, k0+l-1).
            // NOTE(review): pml and phir are unsigned long and may wrap for
            // large p^(m+kmax-k0); nothing checks for that here.
            unsigned long rr = lift_root_unramified(f, d, r, p, kmax-k0);
            unsigned long phir = phi1 * rr + phi0;
            unsigned long pml = 1;
            for (int j = 0; j < m; ++j)
                pml *= p;
            for (int l = 1; l <= kmax-k0; ++l) {
                pml *= p;
                unsigned long phirr = phir % pml;
                entry E;
                E.q = pml;
                E.r = phirr;
                E.n1 = k0+l;
                E.n0 = k0+l-1;
                push_entry(L, E);
            }
        } else {
            // Multiple root mod p: substitute x -> p*x + r, i.e.
            // ff(x) = f(p*x + r), and take val = v_p(content of ff).
            // NOTE(review): p and r are converted to long for
            // mp_poly_linear_comp; if ff were identically zero, val would be
            // INT_MAX and the p^val loop below would not finish in practice.
            mpz_t *ff;
            ff = (mpz_t *) malloc((d+1)*sizeof(mpz_t));
            for (int j = 0; j <= d; ++j)
                mpz_init(ff[j]);
            mp_poly_linear_comp(ff, f, d, p, r);
            int val = mp_poly_p_val_of_content(ff, d, p);
            // One entry for this root modulo p^(m+1), exponents (k0+val, k0).
            unsigned long pmp1 = 1;
            for (int j = 0; j < m+1; ++j)
                pmp1 *= p;
            unsigned long phir = (phi1 * r + phi0) % pmp1;
            entry E;
            E.q = pmp1;
            E.r = phir;
            E.n1 = k0+val;
            E.n0 = k0;
            push_entry(L, E);
            // New map phi'(x) = phi(p*x + r) = (phi1*p)*x + (phi0 + phi1*r).
            unsigned long nphi1 = phi1*p;
            unsigned long nphi0 = phi0 + phi1*r;
            // Divide ff by p^val (its content's p-part) before recursing.
            {
                mpz_t pv;
                mpz_init(pv);
                mpz_set_ui(pv, 1);
                for (int j = 0; j < val; ++j)
                    mpz_mul_ui(pv, pv, p);
                for (int j = 0; j <= d; ++j)
                    mpz_tdiv_q(ff[j], ff[j], pv);
                mpz_clear(pv);
            }
            all_roots_affine(L, ff, d, p, kmax, k0+val, m+1, nphi1, nphi0, rstate);
            for (int j = 0; j <= d; ++j)
                mpz_clear(ff[j]);
            free(ff);
        }
    }
    free(roots);
    mpz_clear(aux);
}
281 
282 /* Compute roots mod powers of p, up to maxbits.
283  * TODO:
284  * Maybe the number of bits of the power is not really the right measure.
285  * We could take into account not only the cost of the updates (which is
286  * indeed in k*log(p)), but also the amount of information we gain for
287  * each update, which is in log(p). So what would be the right measure
288  * ???
289  */
290 
291 /*
292  * See makefb.sage for details on this function
293  */
294 
all_roots(mpz_t * f,int d,unsigned long p,int maxbits,gmp_randstate_ptr rstate)295 entry_list all_roots(mpz_t *f, int d, unsigned long p, int maxbits, gmp_randstate_ptr rstate) {
296     int kmax;
297     entry_list L;
298     entry_list_init(&L);
299     kmax = (int)floor(maxbits*log(2)/log(p));
300     if (kmax == 0)
301         kmax = 1;
302     { // handle projective roots first.
303         mpz_t *fh;
304         mpz_t pk;
305         fh = (mpz_t *)malloc ((d+1)*sizeof(mpz_t));
306         mpz_init(pk);
307         mpz_set_ui(pk, 1);
308         for (int i = 0; i <= d; ++i) {
309             mpz_init(fh[i]);
310             mpz_mul(fh[i], f[d-i], pk);
311             if (i < d)
312                 mpz_mul_ui(pk, pk, p);
313         }
314         int v = mp_poly_p_val_of_content(fh, d, p);
315         if (v > 0) { // We have projective roots only in that case
316             {
317                 entry E;
318                 E.q = p;
319                 E.r = p;
320                 E.n1 = v;
321                 E.n0 = 0;
322                 push_entry(&L, E);
323             }
324             mpz_set_ui(pk, p);
325             for (int i = 1; i < v; ++i)
326                 mpz_mul_ui(pk, pk, p);
327             for (int i = 0; i <= d; ++i)
328                 mpz_tdiv_q(fh[i], fh[i], pk);
329 
330             all_roots_affine(&L, fh, d, p, kmax-1, 0, 0, 1, 0, rstate);
331             // convert back the roots
332             for (int i = 1; i < L.len; ++i) {
333                 entry E = L.list[i];
334                 E.q *= p;
335                 E.n1 += v;
336                 E.n0 += v;
337                 E.r = E.r*p + E.q;
338                 L.list[i] = E;
339             }
340         }
341         for (int i = 0; i <= d; ++i)
342             mpz_clear(fh[i]);
343         mpz_clear(pk);
344         free(fh);
345     }
346     // affine roots are easier.
347     all_roots_affine(&L, f, d, p, kmax, 0, 0, 1, 0, rstate);
348 
349     qsort((void *)(&L.list[0]), L.len, sizeof(entry), cmp_entry);
350     return L;
351 }
352 
353 /* process 'GROUP' primes per thread, since the timings might differ a lot
354    between successive primes */
355 #define GROUP 1024
356 
357 /* thread structure */
358 typedef struct
359 {
360   unsigned long p[GROUP];
361   entry_list L[GROUP];
362   int n; /* number of primes to be processed, n <= GROUP */
363   mpz_t *f;
364   int d;
365   int thread;
366   int maxbits;
367 } __tab_struct;
368 typedef __tab_struct tab_t[1];
369 
370 void*
one_thread(void * args)371 one_thread (void* args)
372 {
373   int k;
374   tab_t *tab = (tab_t*) args;
375   gmp_randstate_t rstate;
376   gmp_randinit_default(rstate);
377   for (k = 0; k < tab[0]->n; k++)
378     tab[0]->L[k] = all_roots (tab[0]->f, tab[0]->d, tab[0]->p[k],
379                               tab[0]->maxbits, rstate);
380   gmp_randclear(rstate);
381   return NULL;
382 }
383 
makefb_with_powers(FILE * outfile,mpz_poly F,unsigned long lim,int maxbits,int nb_threads)384 void makefb_with_powers(FILE* outfile, mpz_poly F, unsigned long lim,
385                         int maxbits, int nb_threads)
386 {
387     mpz_t *f = F->coeff;
388     int d = F->deg, j, k, maxj;
389 
390     fprintf(outfile, "# Roots for polynomial ");
391     mpz_poly_fprintf(outfile, F);
392     fprintf(outfile, "# DEGREE: %d\n", d);
393     fprintf(outfile, "# lim = %lu\n", lim);
394     fprintf(outfile, "# maxbits = %d\n", maxbits);
395 
396     pthread_t *tid;
397     unsigned long p;
398     tab_t *T;
399     tid = (pthread_t*) malloc (nb_threads * sizeof (pthread_t));
400     T = (tab_t*) malloc (nb_threads * sizeof (tab_t));
401     for (j = 0; j < nb_threads; j++)
402       {
403         T[j]->f = f;
404         T[j]->d = d;
405         T[j]->maxbits = maxbits;
406       }
407     prime_info pi;
408     prime_info_init (pi);
409     for (p = 2; p <= lim;) {
410       for (j = 0; j < nb_threads && p <= lim; j++)
411         {
412           for (k = 0; k < GROUP && p <= lim; p = getprime_mt (pi), k++)
413             T[j]->p[k] = p;
414           T[j]->n = k;
415         }
416       maxj = j;
417       for (j = 0; j < maxj; j++)
418         pthread_create (&tid[j], NULL, one_thread, (void *) (T+j));
419       while (j > 0)
420         pthread_join (tid[--j], NULL);
421       for (j = 0; j < maxj; j++) {
422         for (k = 0; k < T[j]->n; k++) {
423           // print in a compactified way
424           int oldn0=-1, oldn1=-1;
425           unsigned long oldq = 0;
426           for (int i = 0; i < T[j]->L[k].len; ++i) {
427             unsigned long q = T[j]->L[k].list[i].q;
428             int n1 = T[j]->L[k].list[i].n1;
429             int n0 = T[j]->L[k].list[i].n0;
430             unsigned long r =  T[j]->L[k].list[i].r;
431             if (q == oldq && n1 == oldn1 && n0 == oldn0)
432               fprintf(outfile, ",%lu", r);
433             else {
434               if (i > 0)
435                 fprintf(outfile, "\n");
436               oldq = q; oldn1 = n1; oldn0 = n0;
437               if (n1 == 1 && n0 == 0)
438                 fprintf(outfile, "%lu: %lu", q, r);
439               else
440                 fprintf(outfile, "%lu:%d,%d: %lu", q, n1, n0, r);
441             }
442           }
443           if (T[j]->L[k].len > 0)
444             fprintf(outfile, "\n");
445           entry_list_clear(&(T[j]->L[k]));
446         }
447       }
448     }
449     free (tid);
450     free (T);
451     prime_info_clear (pi);
452 }
453 
/* Register the documentation of every command-line parameter of makefb,
   plus the generic verbosity parameters. */
static void declare_usage(param_list pl)
{
    static const char * const usage[][2] = {
        { "poly",    "polynomial file" },
        { "lim",     "factor base bound" },
        { "maxbits", "(optional) maximal number of bits of powers" },
        { "out",     "(optional) name of the output file" },
        { "side",    "(optional) create factor base for given side. "
                     "By default, use the unique algebraic side." },
        { "t",       "number of threads" },
    };

    for (size_t i = 0; i < sizeof(usage) / sizeof(usage[0]); i++)
        param_list_decl_usage(pl, usage[i][0], usage[i][1]);
    verbose_decl_usage(pl);
}
466 
467 int
main(int argc,char * argv[])468 main (int argc, char *argv[])
469 {
470   param_list pl;
471   cado_poly cpoly;
472   FILE * f, *outputfile;
473   const char *outfilename = NULL;
474   int maxbits = 1;  // disable powers by default
475   int side = -1;
476   unsigned long lim = ULONG_MAX;
477   char *argv0 = argv[0];
478   unsigned long nb_threads = 1;
479 
480   param_list_init(pl);
481   declare_usage(pl);
482   cado_poly_init(cpoly);
483 
484   argv++, argc--;
485   for( ; argc ; ) {
486       if (param_list_update_cmdline(pl, &argc, &argv)) { continue; }
487 
488       /* Could also be a file */
489       if ((f = fopen(argv[0], "r")) != NULL) {
490           param_list_read_stream(pl, f, 0);
491           fclose(f);
492           argv++,argc--;
493           continue;
494       }
495 
496       fprintf(stderr, "Unhandled parameter %s\n", argv[0]);
497       param_list_print_usage(pl, argv0, stderr);
498       exit (EXIT_FAILURE);
499   }
500   verbose_interpret_parameters(pl);
501   param_list_print_command_line(stdout, pl);
502 
503   const char * filename;
504   if ((filename = param_list_lookup_string(pl, "poly")) == NULL) {
505       fprintf(stderr, "Error: parameter -poly is mandatory\n");
506       param_list_print_usage(pl, argv0, stderr);
507       exit(EXIT_FAILURE);
508   }
509 
510   param_list_parse_ulong(pl, "t"   , &nb_threads);
511   ASSERT_ALWAYS(1 <= nb_threads);
512 
513   param_list_parse_ulong(pl, "lim", &lim);
514   if (lim == ULONG_MAX) {
515       fprintf(stderr, "Error: parameter -lim is mandatory\n");
516       param_list_print_usage(pl, argv0, stderr);
517       exit(EXIT_FAILURE);
518   }
519 
520   param_list_parse_int(pl, "maxbits", &maxbits);
521 
522   outfilename = param_list_lookup_string(pl, "out");
523   if (outfilename != NULL) {
524     outputfile = fopen_maybe_compressed(outfilename, "w");
525     if (!outputfile) {
526         fprintf(stderr, "Error: could not open output file: %s\n", outfilename);
527         exit(EXIT_FAILURE);
528     }
529   } else {
530     outputfile = stdout;
531   }
532   if (lim == 0) {
533     fclose_maybe_compressed(outputfile, outfilename);
534     param_list_clear(pl);
535     return 0;
536   }
537 
538 
539   if (!cado_poly_read(cpoly, filename))
540     {
541       fprintf (stderr, "Error reading polynomial file %s\n", filename);
542       exit (EXIT_FAILURE);
543     }
544 
545   param_list_parse_int(pl, "side", &side);
546   if (side >= (int)cpoly->nb_polys){
547       fprintf(stderr, "Error: side must be in [0..%d[\n", cpoly->nb_polys);
548       param_list_print_usage(pl, argv0, stderr);
549       exit(EXIT_FAILURE);
550   }
551 
552   // No side is given: choose the unique algebraic side.
553   if (side == -1) {
554       for (int i = 0; i < cpoly->nb_polys; ++i) {
555           if (cpoly->pols[i]->deg > 1) {
556               if (side == -1) {
557                   side = i;
558               } else {
559                   fprintf(stderr, "Error: there are more than one algebraic side;"
560                           " parameter -side is therefore mandatory\n");
561                   param_list_print_usage(pl, argv0, stderr);
562                   exit(EXIT_FAILURE);
563               }
564           }
565       }
566       if (side == -1) {
567           fprintf(stderr, "Error: there are no algebraic side;"
568                   " parameter -side is therefore mandatory\n");
569           param_list_print_usage(pl, argv0, stderr);
570           exit(EXIT_FAILURE);
571       }
572   }
573 
574   param_list_warn_unused(pl);
575 
576   makefb_with_powers (outputfile, cpoly->pols[side], lim, maxbits, nb_threads);
577 
578   cado_poly_clear (cpoly);
579   if (outfilename != NULL) {
580     fclose_maybe_compressed(outputfile, outfilename);
581   }
582 
583   param_list_clear(pl);
584 
585   return 0;
586 }
587