/**
 * @file ropt_arith.c
 * Arithmetic helper routines used in ropt.
 */
5
6
7 #include "cado.h" // IWYU pragma: keep
8 #include <stdio.h> // fprintf stderr
9 #include <math.h> // ceil
10 #include <stdlib.h> // exit
11 #include <gmp.h>
12 #include "mod_ul.h"
13 #include "macros.h" // ASSERT
14 #include "ropt_arith.h"
15
16
17 /**
18 * Solve x in a + b*x = c (mod p)
19 */
20 unsigned long
solve_lineq(unsigned long a,unsigned long b,unsigned long c,unsigned long p)21 solve_lineq ( unsigned long a,
22 unsigned long b,
23 unsigned long c,
24 unsigned long p )
25 {
26 /* in general, we should know that gcd(b, p) = 1 */
27 if (b % p == 0) {
28 fprintf (stderr, "Error, impossible inverse in solve_lineq().\n");
29 exit (1);
30 }
31
32 unsigned long tmp;
33 modulusul_t mod;
34 residueul_t tmpr, ar, cr;
35 modul_initmod_ul (mod, p);
36 modul_init (tmpr, mod);
37 modul_init (cr, mod);
38 modul_init (ar, mod);
39 modul_set_ul (cr, c, mod);
40 modul_set_ul (ar, a, mod);
41 modul_sub (cr, cr, ar, mod);
42 modul_set_ul (tmpr, b, mod);
43 modul_inv (tmpr, tmpr, mod);
44 modul_mul (tmpr, tmpr, cr, mod);
45 tmp = modul_get_ul(tmpr, mod);
46 modul_clear (tmpr, mod);
47 modul_clear (cr, mod);
48 modul_clear (ar, mod);
49 modul_clearmod (mod);
50 return tmp;
51 }
52
53
54 /**
55 * Change coordinate from (a, b) to (u, v),
56 * where A + MOD*a = u.
57 */
58 void
ab2uv(mpz_t A,mpz_t MOD,long a,mpz_t u)59 ab2uv ( mpz_t A,
60 mpz_t MOD,
61 long a,
62 mpz_t u )
63 {
64 mpz_mul_si (u, MOD, a);
65 mpz_add (u, u, A);
66 }
67
68
/**
 * Change coordinate from (a, b) to the index of the sieving array.
 * The index is the offset of a from the lower bound: index = a - Amin
 * (Amin is expected to be negative).
 */
long
ab2ij ( long Amin,
        long a )
{
  const long index = a - Amin;
  return index;
}
80
81
/**
 * Change coordinate from (i, j) to (a, b).
 * Inverse of ab2ij(): a = i + Amin.
 */
long
ij2ab ( long Amin,
        long i )
{
  const long a = Amin + i;
  return a;
}
91
92
93 /**
94 * Change coordinate from (i, j) to (u, v).
95 */
96 void
ij2uv(mpz_t A,mpz_t MOD,long Amin,long i,mpz_t u)97 ij2uv ( mpz_t A,
98 mpz_t MOD,
99 long Amin,
100 long i,
101 mpz_t u )
102 {
103 ab2uv(A, MOD, ij2ab(Amin, i), u);
104 }
105
106
107 /**
108 * Find coordinate a such that
109 * A + MOD*a = u (mod p).
110 */
111 unsigned int
uv2ab_mod(mpz_t A,mpz_t MOD,unsigned int U,unsigned int p)112 uv2ab_mod ( mpz_t A,
113 mpz_t MOD,
114 unsigned int U,
115 unsigned int p )
116 {
117 unsigned long a = mpz_fdiv_ui (A, p);
118 unsigned long mod = mpz_fdiv_ui (MOD, p);
119 unsigned long u = U % p;
120 /* compute the A + MOD * a = u (mod p) */
121 return (unsigned int) solve_lineq(a, mod, u, p);
122 }
123
124
125 /**
126 * Same as above, but return the
127 * position of a in the array.
128 */
129 long
uv2ij_mod(mpz_t A,long Amin,mpz_t MOD,unsigned int U,unsigned int p)130 uv2ij_mod ( mpz_t A,
131 long Amin,
132 mpz_t MOD,
133 unsigned int U,
134 unsigned int p )
135 {
136 long i = (long) uv2ab_mod (A, MOD, U, p);
137
138 /* smallest k + p*i such that A0 < k + p*i, where A0 < 0,
139 hence i = ceil((A0-tmp)/p). Note, this should be negative. */
140 i = (long) ceil (((double) Amin - (double) i) / (double) p)
141 * (long) p + i;
142
143 /* compute the position of this (u, v) in the array. */
144 i = ab2ij (Amin, i);
145
146 return i;
147 }
148
149
150 /**
151 * Compute fuv = f+(u*x+v)*g,
152 * f(r) + u*r*g(r) + v*g(r) = 0
153 * The inputs for f and g are mpz.
154 */
155 void
compute_fuv_mp(mpz_t * fuv,mpz_t * f,mpz_t * g,int d,mpz_t u,mpz_t v)156 compute_fuv_mp ( mpz_t *fuv,
157 mpz_t *f,
158 mpz_t *g,
159 int d,
160 mpz_t u,
161 mpz_t v )
162 {
163 mpz_t tmp, tmp1;
164 mpz_init (tmp);
165 mpz_init (tmp1);
166 int i = 0;
167
168 for (i = 3; i <= d; i ++)
169 mpz_set (fuv[i], f[i]);
170
171 /* f + u*g1*x^2
172 + (g0*u* + v*g1)*x
173 + v*g0 */
174
175 /* Note, u, v are signed long! */
176 /* u*g1*x^2 */
177 mpz_mul (tmp, g[1], u);
178 mpz_add (fuv[2], f[2], tmp);
179
180 /* (g0*u* + v*g1)*x */
181 mpz_mul (tmp, g[0], u);
182 mpz_mul (tmp1, g[1], v);
183 mpz_add (tmp, tmp, tmp1);
184 mpz_add (fuv[1], f[1], tmp);
185
186 /* v*g0 */
187 mpz_mul (tmp, g[0], v);
188 mpz_add (fuv[0], f[0], tmp);
189
190 mpz_clear (tmp);
191 mpz_clear (tmp1);
192 }
193
194
195 /**
196 * Compute fuv = f+(u*x+v)*g,
197 * The inputs for f and g are unsigne long.
198 * Note, u, v are unsigned int.
199 * So they should be reduce (mod p) if necessary.
200 */
201 void
compute_fuv_ui(unsigned int * fuv_ui,unsigned int * f_ui,unsigned int * g_ui,int d,unsigned int u,unsigned int v,unsigned int p)202 compute_fuv_ui ( unsigned int *fuv_ui,
203 unsigned int *f_ui,
204 unsigned int *g_ui,
205 int d,
206 unsigned int u,
207 unsigned int v,
208 unsigned int p )
209 {
210 int i;
211 modulusul_t mod;
212 residueul_t tmp, tmp1, tmp2;
213 modul_initmod_ul (mod, p);
214 modul_init (tmp, mod);
215 modul_init (tmp1, mod);
216 modul_init (tmp2, mod);
217
218 for (i = 3; i <= d; i ++)
219 fuv_ui[i] = f_ui[i];
220
221 /* f + u*g1*x^2
222 + (g0*u* + v*g1)*x
223 + v*g0 */
224
225 /* u*g1*x^2 */
226 modul_set_ul (tmp, g_ui[1], mod);
227 modul_set_ul (tmp2, u, mod);
228 modul_mul (tmp, tmp, tmp2, mod);
229 modul_set_ul (tmp1, f_ui[2], mod);
230 modul_add (tmp, tmp, tmp1, mod);
231 fuv_ui[2] = (unsigned int) modul_get_ul(tmp, mod);
232
233 /* (g0*u* + v*g1)*x */
234 modul_set_ul (tmp, g_ui[1], mod);
235 modul_set_ul (tmp1, v, mod);
236 modul_mul (tmp, tmp, tmp1, mod);
237 modul_set_ul (tmp1, g_ui[0], mod);
238 // tmp2 = u as set above.
239 modul_mul (tmp1, tmp1, tmp2, mod);
240 modul_add (tmp, tmp, tmp1, mod);
241 modul_set_ul (tmp1, f_ui[1], mod);
242 modul_add (tmp, tmp, tmp1, mod);
243 fuv_ui[1] = (unsigned int) modul_get_ul(tmp, mod);
244
245 /* v*g0 */
246 modul_set_ul (tmp1, v, mod);
247 modul_set_ul (tmp2, g_ui[0], mod);
248 modul_mul (tmp1, tmp1, tmp2, mod);
249 modul_set_ul (tmp, f_ui[0], mod);
250 modul_add (tmp, tmp, tmp1, mod);
251 fuv_ui[0] = (unsigned int) modul_get_ul(tmp, mod);
252 }
253
254
255 /**
256 * Compute v (mod p) by
257 * f(r) + u*r*g(r) + v*g(r) = 0 (mod p).
258 * The inputs for f(r) and g(r) are unsigned int.
259 */
260 unsigned int
compute_v_ui(unsigned int fx,unsigned int gx,unsigned int r,unsigned int u,unsigned int p)261 compute_v_ui ( unsigned int fx,
262 unsigned int gx,
263 unsigned int r,
264 unsigned int u,
265 unsigned int p)
266 {
267 modulusul_t mod;
268 residueul_t tmp, tmp1;
269 unsigned long v;
270 modul_initmod_ul (mod, p);
271 modul_init (tmp, mod);
272 modul_init (tmp1, mod);
273
274 /* g(r)*r*u + f(r) */
275 modul_set_ul (tmp, gx, mod);
276 modul_set_ul (tmp1, r, mod);
277 modul_mul (tmp, tmp, tmp1, mod);
278 modul_set_ul (tmp1, u, mod);
279 modul_mul (tmp, tmp, tmp1, mod);
280 modul_set_ul (tmp1, fx, mod);
281 modul_add (tmp, tmp, tmp1, mod);
282 v = modul_get_ul(tmp, mod);
283
284 /* solve v in tmp2 + v*g(r) = 0 (mod p) */
285 v = solve_lineq (v, gx, 0, p);
286
287 modul_clear (tmp, mod);
288 modul_clear (tmp1, mod);
289 modul_clearmod (mod);
290 return (unsigned int) v;
291 }
292
293
294 /**
295 * Compute v = f(r) (mod pe), where f is of degree d.
296 * The input f should be unsigned int.
297 */
298 unsigned int
eval_poly_ui_mod(unsigned int * f,int d,unsigned int r,unsigned int pe)299 eval_poly_ui_mod ( unsigned int *f,
300 int d,
301 unsigned int r,
302 unsigned int pe )
303 {
304 int i;
305 modulusul_t mod;
306 residueul_t vtmp, rtmp, tmp;
307 unsigned int v;
308
309 modul_initmod_ul (mod, pe);
310 modul_init (vtmp, mod);
311 modul_init (rtmp, mod);
312 modul_init (tmp, mod);
313
314 /* set vtmp = f[d] (mod p) and rtmp = r (mod p) */
315 modul_set_ul (vtmp, f[d], mod);
316 modul_set_ul (rtmp, r, mod);
317
318 for (i = d - 1; i >= 0; i--) {
319 modul_mul (vtmp, vtmp, rtmp, mod);
320 modul_set_ul (tmp, f[i], mod);
321 modul_add (vtmp, tmp, vtmp, mod);
322 }
323
324 v = (unsigned int) modul_get_ul (vtmp, mod);
325 modul_clear (vtmp, mod);
326 modul_clear (rtmp, mod);
327 modul_clear (tmp, mod);
328 modul_clearmod (mod);
329
330 return v;
331 }
332
333
334 /**
335 * Reduce mpz_t *f to unsigned int *f_mod;
336 * Given modulus pe, return f (mod pe).
337 */
338 inline void
reduce_poly_ul(unsigned int * f_ui,mpz_t * f,int d,unsigned int pe)339 reduce_poly_ul ( unsigned int *f_ui,
340 mpz_t *f,
341 int d,
342 unsigned int pe )
343 {
344 int i;
345 for (i = 0; i <= d; i ++) {
346 f_ui[i] = (unsigned int) mpz_fdiv_ui (f[i], pe);
347 }
348 }
349
350
351 /**
352 * From polyselect.c
353 * Implements Lemma 2.1 from Kleinjung's paper.
354 * If a[d] is non-zero, it is assumed it is already set, otherwise it is
355 * determined as a[d] = N/m^d (mod p).
356 */
357 void
Lemma21(mpz_t * a,mpz_t N,int d,mpz_t p,mpz_t m,mpz_t res)358 Lemma21 ( mpz_t *a,
359 mpz_t N,
360 int d,
361 mpz_t p,
362 mpz_t m,
363 mpz_t res )
364 {
365 mpz_t r, mi, invp, l, ln;
366 int i;
367
368 mpz_init (r);
369 mpz_init_set_ui (l, 1);
370 mpz_init (ln);
371 mpz_init (mi);
372 mpz_init (invp);
373 mpz_pow_ui (mi, m, d);
374
375 if (mpz_cmp_ui (a[d], 0) < 0)
376 mpz_abs (a[d], a[d]);
377
378 if (mpz_cmp_ui (a[d], 0) == 0) {
379 mpz_invert (a[d], mi, p); /* 1/m^d mod p */
380 mpz_mul (a[d], a[d], N);
381 mpz_mod (a[d], a[d], p);
382 mpz_set_ui (l, 1);
383 }
384 /* multiplier l < m1 */
385 else {
386 mpz_invert (l, N, p);
387 mpz_mul (l, l, mi);
388 mpz_mul (l, l, a[d]);
389 mpz_mod (l, l, p);
390 }
391 mpz_mul (ln, N, l);
392 mpz_set (r, ln);
393 mpz_set (res, l);
394
395 for (i = d - 1; i >= 0; i--)
396 {
397 /* invariant: mi = m^(i+1) */
398 mpz_mul (a[i], a[i+1], mi);
399 mpz_sub (r, r, a[i]);
400 ASSERT (mpz_divisible_p (r, p));
401 mpz_divexact (r, r, p);
402 mpz_divexact (mi, mi, m); /* now mi = m^i */
403 if (i == d - 1)
404 {
405 mpz_invert (invp, p, mi); /* 1/p mod m^i */
406 mpz_sub (invp, mi, invp); /* -1/p mod m^i */
407 }
408 else
409 mpz_mod (invp, invp, mi);
410 mpz_mul (a[i], invp, r);
411 mpz_mod (a[i], a[i], mi); /* -r/p mod m^i */
412 /* round to nearest in [-m^i/2, m^i/2] */
413 mpz_mul_2exp (a[i], a[i], 1);
414 if (mpz_cmp (a[i], mi) >= 0)
415 {
416 mpz_div_2exp (a[i], a[i], 1);
417 mpz_sub (a[i], a[i], mi);
418 }
419 else
420 mpz_div_2exp (a[i], a[i], 1);
421 mpz_mul (a[i], a[i], p);
422 mpz_add (a[i], a[i], r);
423 ASSERT (mpz_divisible_p (a[i], mi));
424 mpz_divexact (a[i], a[i], mi);
425 }
426 mpz_clear (r);
427 mpz_clear (l);
428 mpz_clear (ln);
429 mpz_clear (mi);
430 mpz_clear (invp);
431 }
432