1 /**
2  * @file ropt_arith.c
3  * Some arithmetics used in ropt.
4  */
5 
6 
7 #include "cado.h" // IWYU pragma: keep
8 #include <stdio.h>      // fprintf stderr
9 #include <math.h>       // ceil
10 #include <stdlib.h>       // exit
11 #include <gmp.h>
12 #include "mod_ul.h"
13 #include "macros.h"     // ASSERT
14 #include "ropt_arith.h"
15 
16 
17 /**
18  * Solve x in a + b*x = c (mod p)
19  */
20 unsigned long
solve_lineq(unsigned long a,unsigned long b,unsigned long c,unsigned long p)21 solve_lineq ( unsigned long a,
22               unsigned long b,
23               unsigned long c,
24               unsigned long p )
25 {
26   /* in general, we should know that gcd(b, p) = 1 */
27   if (b % p == 0) {
28     fprintf (stderr, "Error, impossible inverse in solve_lineq().\n");
29     exit (1);
30   }
31 
32   unsigned long tmp;
33   modulusul_t mod;
34   residueul_t tmpr, ar, cr;
35   modul_initmod_ul (mod, p);
36   modul_init (tmpr, mod);
37   modul_init (cr, mod);
38   modul_init (ar, mod);
39   modul_set_ul (cr, c, mod);
40   modul_set_ul (ar, a, mod);
41   modul_sub (cr, cr, ar, mod);
42   modul_set_ul (tmpr, b, mod);
43   modul_inv (tmpr, tmpr, mod);
44   modul_mul (tmpr, tmpr, cr, mod);
45   tmp = modul_get_ul(tmpr, mod);
46   modul_clear (tmpr, mod);
47   modul_clear (cr, mod);
48   modul_clear (ar, mod);
49   modul_clearmod (mod);
50   return tmp;
51 }
52 
53 
54 /**
55  * Change coordinate from (a, b) to (u, v),
56  * where A + MOD*a = u.
57  */
58 void
ab2uv(mpz_t A,mpz_t MOD,long a,mpz_t u)59 ab2uv ( mpz_t A,
60         mpz_t MOD,
61         long a,
62         mpz_t u )
63 {
64   mpz_mul_si (u, MOD, a);
65   mpz_add (u, u, A);
66 }
67 
68 
/**
 * Convert the coordinate a to its index in the sieving array:
 * index = a - Amin, where Amin (the smallest coordinate) is
 * expected to be negative.
 */
long
ab2ij ( long Amin,
        long a )
{
  long index = a;
  index -= Amin;
  return index;
}
80 
81 
/**
 * Convert an index i of the sieving array back to the
 * coordinate a = i + Amin (the inverse of ab2ij).
 */
long
ij2ab ( long Amin,
        long i )
{
  long a = Amin;
  a += i;
  return a;
}
91 
92 
93 /**
94  * Change coordinate from (i, j) to (u, v).
95  */
96 void
ij2uv(mpz_t A,mpz_t MOD,long Amin,long i,mpz_t u)97 ij2uv ( mpz_t A,
98         mpz_t MOD,
99         long Amin,
100         long i,
101         mpz_t u )
102 {
103   ab2uv(A, MOD, ij2ab(Amin, i), u);
104 }
105 
106 
107 /**
108  * Find coordinate a such that
109  * A + MOD*a = u (mod p).
110  */
111 unsigned int
uv2ab_mod(mpz_t A,mpz_t MOD,unsigned int U,unsigned int p)112 uv2ab_mod ( mpz_t A,
113             mpz_t MOD,
114             unsigned int U,
115             unsigned int p )
116 {
117   unsigned long a = mpz_fdiv_ui (A, p);
118   unsigned long mod = mpz_fdiv_ui (MOD, p);
119   unsigned long u = U % p;
120   /* compute the A + MOD * a = u (mod p) */
121   return (unsigned int) solve_lineq(a, mod, u, p);
122 }
123 
124 
125 /**
126  * Same as above, but return the
127  * position of a in the array.
128  */
129 long
uv2ij_mod(mpz_t A,long Amin,mpz_t MOD,unsigned int U,unsigned int p)130 uv2ij_mod ( mpz_t A,
131             long Amin,
132             mpz_t MOD,
133             unsigned int U,
134             unsigned int p )
135 {
136   long i = (long) uv2ab_mod (A, MOD, U, p);
137 
138   /* smallest k + p*i such that A0 < k + p*i, where A0 < 0,
139      hence i = ceil((A0-tmp)/p). Note, this should be negative. */
140   i = (long) ceil (((double) Amin - (double) i) / (double) p)
141     * (long) p + i;
142 
143   /* compute the position of this (u, v) in the array. */
144   i = ab2ij (Amin, i);
145 
146   return i;
147 }
148 
149 
150 /**
151  * Compute fuv = f+(u*x+v)*g,
152  * f(r) + u*r*g(r) + v*g(r) = 0
153  * The inputs for f and g are mpz.
154  */
155 void
compute_fuv_mp(mpz_t * fuv,mpz_t * f,mpz_t * g,int d,mpz_t u,mpz_t v)156 compute_fuv_mp ( mpz_t *fuv,
157                  mpz_t *f,
158                  mpz_t *g,
159                  int d,
160                  mpz_t u,
161                  mpz_t v )
162 {
163   mpz_t tmp, tmp1;
164   mpz_init (tmp);
165   mpz_init (tmp1);
166   int i = 0;
167 
168   for (i = 3; i <= d; i ++)
169     mpz_set (fuv[i], f[i]);
170 
171   /* f + u*g1*x^2
172      + (g0*u* + v*g1)*x
173      + v*g0 */
174 
175   /* Note, u, v are signed long! */
176   /* u*g1*x^2 */
177   mpz_mul (tmp, g[1], u);
178   mpz_add (fuv[2], f[2], tmp);
179 
180   /* (g0*u* + v*g1)*x */
181   mpz_mul (tmp, g[0], u);
182   mpz_mul (tmp1, g[1], v);
183   mpz_add (tmp, tmp, tmp1);
184   mpz_add (fuv[1], f[1], tmp);
185 
186   /* v*g0 */
187   mpz_mul (tmp, g[0], v);
188   mpz_add (fuv[0], f[0], tmp);
189 
190   mpz_clear (tmp);
191   mpz_clear (tmp1);
192 }
193 
194 
195 /**
196  * Compute fuv = f+(u*x+v)*g,
197  * The inputs for f and g are unsigne long.
198  * Note, u, v are unsigned int.
199  * So they should be reduce (mod p) if necessary.
200  */
201 void
compute_fuv_ui(unsigned int * fuv_ui,unsigned int * f_ui,unsigned int * g_ui,int d,unsigned int u,unsigned int v,unsigned int p)202 compute_fuv_ui ( unsigned int *fuv_ui,
203                  unsigned int *f_ui,
204                  unsigned int *g_ui,
205                  int d,
206                  unsigned int u,
207                  unsigned int v,
208                  unsigned int p )
209 {
210   int i;
211   modulusul_t mod;
212   residueul_t tmp, tmp1, tmp2;
213   modul_initmod_ul (mod, p);
214   modul_init (tmp, mod);
215   modul_init (tmp1, mod);
216   modul_init (tmp2, mod);
217 
218   for (i = 3; i <= d; i ++)
219     fuv_ui[i] = f_ui[i];
220 
221   /* f + u*g1*x^2
222      + (g0*u* + v*g1)*x
223      + v*g0 */
224 
225   /* u*g1*x^2 */
226   modul_set_ul (tmp, g_ui[1], mod);
227   modul_set_ul (tmp2, u, mod);
228   modul_mul (tmp, tmp, tmp2, mod);
229   modul_set_ul (tmp1, f_ui[2], mod);
230   modul_add (tmp, tmp, tmp1, mod);
231   fuv_ui[2] = (unsigned int) modul_get_ul(tmp, mod);
232 
233   /* (g0*u* + v*g1)*x */
234   modul_set_ul (tmp, g_ui[1], mod);
235   modul_set_ul (tmp1, v, mod);
236   modul_mul (tmp, tmp, tmp1, mod);
237   modul_set_ul (tmp1, g_ui[0], mod);
238   // tmp2 = u as set above.
239   modul_mul (tmp1, tmp1, tmp2, mod);
240   modul_add (tmp, tmp, tmp1, mod);
241   modul_set_ul (tmp1, f_ui[1], mod);
242   modul_add (tmp, tmp, tmp1, mod);
243   fuv_ui[1] = (unsigned int) modul_get_ul(tmp, mod);
244 
245   /* v*g0 */
246   modul_set_ul (tmp1, v, mod);
247   modul_set_ul (tmp2, g_ui[0], mod);
248   modul_mul (tmp1, tmp1, tmp2, mod);
249   modul_set_ul (tmp, f_ui[0], mod);
250   modul_add (tmp, tmp, tmp1, mod);
251   fuv_ui[0] = (unsigned int) modul_get_ul(tmp, mod);
252 }
253 
254 
255 /**
256  * Compute v (mod p) by
257  * f(r) + u*r*g(r) + v*g(r) = 0 (mod p).
258  * The inputs for f(r) and g(r) are unsigned int.
259  */
260 unsigned int
compute_v_ui(unsigned int fx,unsigned int gx,unsigned int r,unsigned int u,unsigned int p)261 compute_v_ui ( unsigned int fx,
262                unsigned int gx,
263                unsigned int r,
264                unsigned int u,
265                unsigned int p)
266 {
267   modulusul_t mod;
268   residueul_t tmp, tmp1;
269   unsigned long v;
270   modul_initmod_ul (mod, p);
271   modul_init (tmp, mod);
272   modul_init (tmp1, mod);
273 
274   /* g(r)*r*u + f(r) */
275   modul_set_ul (tmp, gx, mod);
276   modul_set_ul (tmp1, r, mod);
277   modul_mul (tmp, tmp, tmp1, mod);
278   modul_set_ul (tmp1, u, mod);
279   modul_mul (tmp, tmp, tmp1, mod);
280   modul_set_ul (tmp1, fx, mod);
281   modul_add (tmp, tmp, tmp1, mod);
282   v = modul_get_ul(tmp, mod);
283 
284   /* solve v in tmp2 + v*g(r) = 0 (mod p) */
285   v = solve_lineq (v, gx, 0, p);
286 
287   modul_clear (tmp, mod);
288   modul_clear (tmp1, mod);
289   modul_clearmod (mod);
290   return (unsigned int) v;
291 }
292 
293 
294 /**
295  * Compute v = f(r) (mod pe), where f is of degree d.
296  * The input f should be unsigned int.
297  */
298 unsigned int
eval_poly_ui_mod(unsigned int * f,int d,unsigned int r,unsigned int pe)299 eval_poly_ui_mod ( unsigned int *f,
300                    int d,
301                    unsigned int r,
302                    unsigned int pe )
303 {
304   int i;
305   modulusul_t mod;
306   residueul_t vtmp, rtmp, tmp;
307   unsigned int v;
308 
309   modul_initmod_ul (mod, pe);
310   modul_init (vtmp, mod);
311   modul_init (rtmp, mod);
312   modul_init (tmp, mod);
313 
314   /* set vtmp = f[d] (mod p) and rtmp = r (mod p) */
315   modul_set_ul (vtmp, f[d], mod);
316   modul_set_ul (rtmp, r, mod);
317 
318   for (i = d - 1; i >= 0; i--) {
319     modul_mul (vtmp, vtmp, rtmp, mod);
320     modul_set_ul (tmp, f[i], mod);
321     modul_add (vtmp, tmp, vtmp, mod);
322   }
323 
324   v = (unsigned int) modul_get_ul (vtmp, mod);
325   modul_clear (vtmp, mod);
326   modul_clear (rtmp, mod);
327   modul_clear (tmp, mod);
328   modul_clearmod (mod);
329 
330   return v;
331 }
332 
333 
334 /**
335  * Reduce mpz_t *f to unsigned int *f_mod;
336  * Given modulus pe, return f (mod pe).
337  */
338 inline void
reduce_poly_ul(unsigned int * f_ui,mpz_t * f,int d,unsigned int pe)339 reduce_poly_ul ( unsigned int *f_ui,
340                  mpz_t *f,
341                  int d,
342                  unsigned int pe )
343 {
344   int i;
345   for (i = 0; i <= d; i ++) {
346     f_ui[i] = (unsigned int) mpz_fdiv_ui (f[i], pe);
347   }
348 }
349 
350 
351 /**
352  * From polyselect.c
353  * Implements Lemma 2.1 from Kleinjung's paper.
354  * If a[d] is non-zero, it is assumed it is already set, otherwise it is
355  * determined as a[d] = N/m^d (mod p).
356  */
357 void
Lemma21(mpz_t * a,mpz_t N,int d,mpz_t p,mpz_t m,mpz_t res)358 Lemma21 ( mpz_t *a,
359           mpz_t N,
360           int d,
361           mpz_t p,
362           mpz_t m,
363           mpz_t res )
364 {
365   mpz_t r, mi, invp, l, ln;
366   int i;
367 
368   mpz_init (r);
369   mpz_init_set_ui (l, 1);
370   mpz_init (ln);
371   mpz_init (mi);
372   mpz_init (invp);
373   mpz_pow_ui (mi, m, d);
374 
375   if (mpz_cmp_ui (a[d], 0) < 0)
376     mpz_abs (a[d], a[d]);
377 
378   if (mpz_cmp_ui (a[d], 0) == 0) {
379     mpz_invert (a[d], mi, p); /* 1/m^d mod p */
380     mpz_mul (a[d], a[d], N);
381     mpz_mod (a[d], a[d], p);
382     mpz_set_ui (l, 1);
383   }
384   /* multiplier l < m1 */
385   else {
386     mpz_invert (l, N, p);
387     mpz_mul (l, l, mi);
388     mpz_mul (l, l, a[d]);
389     mpz_mod (l, l, p);
390   }
391   mpz_mul (ln, N, l);
392   mpz_set (r, ln);
393   mpz_set (res, l);
394 
395   for (i = d - 1; i >= 0; i--)
396   {
397     /* invariant: mi = m^(i+1) */
398     mpz_mul (a[i], a[i+1], mi);
399     mpz_sub (r, r, a[i]);
400     ASSERT (mpz_divisible_p (r, p));
401     mpz_divexact (r, r, p);
402     mpz_divexact (mi, mi, m); /* now mi = m^i */
403     if (i == d - 1)
404     {
405       mpz_invert (invp, p, mi); /* 1/p mod m^i */
406       mpz_sub (invp, mi, invp); /* -1/p mod m^i */
407     }
408     else
409       mpz_mod (invp, invp, mi);
410     mpz_mul (a[i], invp, r);
411     mpz_mod (a[i], a[i], mi); /* -r/p mod m^i */
412     /* round to nearest in [-m^i/2, m^i/2] */
413     mpz_mul_2exp (a[i], a[i], 1);
414     if (mpz_cmp (a[i], mi) >= 0)
415     {
416       mpz_div_2exp (a[i], a[i], 1);
417       mpz_sub (a[i], a[i], mi);
418     }
419     else
420       mpz_div_2exp (a[i], a[i], 1);
421     mpz_mul (a[i], a[i], p);
422     mpz_add (a[i], a[i], r);
423     ASSERT (mpz_divisible_p (a[i], mi));
424     mpz_divexact (a[i], a[i], mi);
425   }
426   mpz_clear (r);
427   mpz_clear (l);
428   mpz_clear (ln);
429   mpz_clear (mi);
430   mpz_clear (invp);
431 }
432