1 #include "cado.h" // IWYU pragma: keep
2 #include "modredc_2ul2.h"
3 #include "modredc_2ul2_default.h"
4 #include "modredc_2ul_common.c"
5 
/* Debug tracing for modredc2ul2_inv().  Set PARI to 1 to make it print
   PARI/GP-syntax statements with the modulus, the input and the loop
   variables at each step.  Every macro expands to exactly one expression,
   so "MACRO;" is a single statement in both configurations (safe inside an
   unbraced if/else); with tracing disabled they are no-ops. */
/* NOTE(review): modredc2ul2_inv() starts from a = x/2^LONG_BIT with
   t = -LONG_BIT, for which u*X == a*2^(t+LONG_BIT) (X as printed); the
   "a << t" form printed below looks off by 2^LONG_BIT -- confirm before
   relying on the trace. */
#define PARI 0
#if PARI
#include <stdio.h>      // IWYU pragma: keep
#include <stdlib.h>
#include "macros.h"
#define MODINV_PRINT_PARI_M \
    printf ("m = (%lu << %d) + %lu; /* PARI %d */\n", m[0].m[1], LONG_BIT, m[0].m[0], __LINE__)
#define MODINV_PRINT_PARI_x \
    printf ("x = (%lu << %d) + %lu; /* PARI %d */\n", a[1], LONG_BIT, a[0], __LINE__)
#define MODINV_PRINT_PARI_X \
    printf ("X = (%lu << %d) + %lu; /* PARI %d */\n", a[1], LONG_BIT, a[0], __LINE__)
#define MODINV_PRINT_PARI_INVARIANT_A \
    printf ("a = %lu *2^%d + %lu; u = %lu *2^%d + %lu; Mod(u, m) * X == a << %d /* PARIC %d */\n", a[1], LONG_BIT, a[0], u[1], LONG_BIT, u[0], t, __LINE__)
#define MODINV_PRINT_PARI_INVARIANT_B \
    printf ("b = %lu *2^%d + %lu; v = %lu *2^%d + %lu; -Mod(v, m) * X == b << %d /* PARIC %d */\n", b[1], LONG_BIT, b[0], v[1], LONG_BIT, v[0], t, __LINE__)
#else
#define MODINV_PRINT_PARI_M ((void) 0)
#define MODINV_PRINT_PARI_x ((void) 0)
#define MODINV_PRINT_PARI_X ((void) 0)
#define MODINV_PRINT_PARI_INVARIANT_A ((void) 0)
#define MODINV_PRINT_PARI_INVARIANT_B ((void) 0)
#endif
28 
29 int
modredc2ul2_inv(residueredc2ul2_t r,const residueredc2ul2_t A,const modulusredc2ul2_t m)30 modredc2ul2_inv (residueredc2ul2_t r, const residueredc2ul2_t A,
31 		 const modulusredc2ul2_t m)
32 {
33   modintredc2ul2_t a, b, u, v;
34   int t, lsh;
35 #ifdef WANT_ASSERT_EXPENSIVE
36   residueredc2ul2_t tmp;
37 
38   modredc2ul2_init_noset0 (tmp, m);
39   modredc2ul2_set (tmp, A, m);
40 #endif
41 
42   ASSERT_EXPENSIVE (modredc2ul2_intcmp (A, m[0].m) < 0);
43   ASSERT_EXPENSIVE (m[0].m[0] & 1UL);
44 
45   MODINV_PRINT_PARI_M;
46 
47   if (A[0] == 0UL && A[1] == 0UL)
48     return 0;
49 
50   modredc2ul2_getmod_int (b, m);
51 
52   /* Let A = x*2^{2w}, so we want the Montgomery representation of 1/x,
53      which is 2^{2w}/x. We start by getting a = x */
54   modredc2ul2_get_int (a, A, m);
55   MODINV_PRINT_PARI_x;
56 
57   /* We simply set a = x/2^{2w} and t=0. The result before correction
58      will be 2^(2w+t)/x so we have to divide by t, which may be >64,
59      so we may have to do one or more full and a variable width REDC. */
60   /* TODO: If b[1] > 1, we could skip one of the two REDC */
61   modredc2ul2_redc1 (a, a, m);
62   /* Now a = x/2^w */
63   MODINV_PRINT_PARI_X;
64   t = -LONG_BIT;
65 
66   modredc2ul2_intset_ul (u, 1UL);
67   modredc2ul2_intset_ul (v, 0UL);
68 
69   MODINV_PRINT_PARI_INVARIANT_A;
70   MODINV_PRINT_PARI_INVARIANT_B;
71 
72   /* make a odd */
73   if (a[0] == 0UL)
74     {
75       /* x86 bsf gives undefined result for zero input */
76       a[0] = a[1];
77       a[1] = 0UL;
78       t += LONG_BIT;
79     }
80   ASSERT_EXPENSIVE (a[0] != 0UL);
81   lsh = ularith_ctz (a[0]);
82   modredc2ul2_intshr (a, a, lsh);
83   t += lsh;
84 
85   // Here a and b are odd, and a < b
86   do {
87     /* Here, a and b are odd, 0 < a < b, u is odd and v is even */
88     ASSERT_EXPENSIVE (modredc2ul2_intcmp (a, b) < 0);
89     ASSERT_EXPENSIVE ((a[0] & 1UL) == 1UL);
90     ASSERT_EXPENSIVE ((b[0] & 1UL) == 1UL);
91     ASSERT_EXPENSIVE ((u[0] & 1UL) == 1UL);
92     ASSERT_EXPENSIVE ((v[0] & 1UL) == 0UL);
93 
94     MODINV_PRINT_PARI_INVARIANT_A;
95     MODINV_PRINT_PARI_INVARIANT_B;
96 
97     do {
98       modredc2ul2_intsub (b, b, a);
99       modredc2ul2_intadd (v, v, u);
100 
101       MODINV_PRINT_PARI_INVARIANT_A;
102       MODINV_PRINT_PARI_INVARIANT_B;
103 
104       if (b[0] == 0UL)
105 	{
106 	  b[0] = b[1]; /* b[0] can be odd now, so lsh might be 0 below! */
107 	  b[1] = 0UL;
108 	  ASSERT_EXPENSIVE (u[1] == 0UL);
109 	  u[1] = u[0]; /* Shift left u by LONG_BIT */
110 	  u[0] = 0UL;
111 	  t += LONG_BIT;
112 	}
113       else
114         {
115 	  ASSERT_EXPENSIVE (ularith_ctz (b[0]) > 0);
116         }
117       lsh = ularith_ctz (b[0]);
118       ASSERT_EXPENSIVE ((b[0] & ((1UL << lsh) - 1UL)) == 0UL);
119       modredc2ul2_intshr (b, b, lsh);
120       t += lsh;
121       modredc2ul2_intshl (u, u, lsh);
122       MODINV_PRINT_PARI_INVARIANT_A;
123       MODINV_PRINT_PARI_INVARIANT_B;
124     } while (modredc2ul2_intlt (a, b)); /* ~50% branch taken :( */
125 
126     /* Here, a and b are odd, 0 < b =< a, u is even and v is odd */
127     ASSERT_EXPENSIVE ((a[0] & 1UL) == 1UL);
128     ASSERT_EXPENSIVE ((b[0] & 1UL) == 1UL);
129     ASSERT_EXPENSIVE ((u[0] & 1UL) == 0UL);
130     ASSERT_EXPENSIVE ((v[0] & 1UL) == 1UL);
131 
132     if (modredc2ul2_intequal (a, b))
133       break;
134     ASSERT_EXPENSIVE (modredc2ul2_intcmp (a, b) > 0);
135 
136     /* Here, a and b are odd, 0 < b < a, u is even and v is odd */
137     do {
138       modredc2ul2_intsub (a, a, b);
139       modredc2ul2_intadd (u, u, v);
140       MODINV_PRINT_PARI_INVARIANT_A;
141       MODINV_PRINT_PARI_INVARIANT_B;
142 
143       if (a[0] == 0UL)
144 	{
145 	  a[0] = a[1];
146 	  a[1] = 0UL;
147 	  v[1] = v[0]; /* Shift left v by LONG_BIT */
148 	  v[0] = 0UL;
149 	  t += LONG_BIT;
150 	}
151       else
152         {
153 	  ASSERT_EXPENSIVE (ularith_ctz (a[0]) > 0);
154         }
155 	lsh = ularith_ctz (a[0]);
156         ASSERT_EXPENSIVE ((a[0] & ((1UL << lsh) - 1UL)) == 0UL);
157 	modredc2ul2_intshr (a, a, lsh);
158 	t += lsh;
159 	modredc2ul2_intshl (v, v, lsh);
160 	MODINV_PRINT_PARI_INVARIANT_A;
161 	MODINV_PRINT_PARI_INVARIANT_B;
162     } while (modredc2ul2_intlt (b, a)); /* about 50% branch taken :( */
163     /* Here, a and b are odd, 0 < a =< b, u is odd and v is even */
164   } while (!modredc2ul2_intequal (a, b));
165 
166   if (modredc2ul2_intcmp_ul (a, 1UL) != 0) /* Non-trivial GCD */
167     return 0;
168 
169   ASSERT (t >= 0);
170 
171   /* Here, the inverse of a is u/2^t mod m. To do the division by t,
172      we use a variable-width REDC. We want to add a multiple of m to u
173      so that the low t bits of the sum are 0 and we can right-shift by t
174      with impunity. */
175   while (t >= LONG_BIT)
176     {
177       modredc2ul2_redc1 (u, u, m);
178       t -= LONG_BIT;
179     }
180 
181   if (t > 0)
182     {
183       unsigned long s[5], k;
184       k = ((u[0] * m[0].invm) & ((1UL << t) - 1UL)); /* tlow <= 2^t-1 */
185       ularith_mul_ul_ul_2ul (&(s[0]), &(s[1]), k, m[0].m[0]);
186       /* s[1]:s[0] <= (2^w-1)*(2^t-1) <= (2^w-1)*(2^(w-1)-1) */
187       ularith_add_2ul_2ul (&(s[0]), &(s[1]), u[0], u[1]);
188       /* s[1]:s[0] <= (2^w-1)*(2^(w-1)-1) + (m-1) < 2^(2w) */
189       /* s[0] == 0 (mod 2^t) */
190       ASSERT_EXPENSIVE ((s[0] & ((1UL << t) - 1UL)) == 0UL);
191       s[2] = 0;
192       ularith_mul_ul_ul_2ul (&(s[3]), &(s[4]), k, m[0].m[1]);
193       ularith_add_2ul_2ul (&(s[1]), &(s[2]), s[3], s[4]);
194 
195       /* Now shift s[2]:s[1]:s[0] right by t */
196       ularith_shrd (&(s[0]), s[1], s[0], t);
197       ularith_shrd (&(s[1]), s[2], s[1], t);
198 
199       u[0] = s[0];
200       u[1] = s[1];
201       t = 0;
202       MODINV_PRINT_PARI_INVARIANT_A;
203     }
204 
205 #ifdef WANT_ASSERT_EXPENSIVE
206   modredc2ul2_mul (tmp, tmp, u, m);
207   if (!modredc2ul2_is1 (tmp, m))
208     {
209       modintredc2ul2_t tmpi;
210       modredc2ul2_get_int (tmpi, tmp, m);
211       fprintf (stderr, "Error, Mod(1/(%lu + 2^%d * %lu), %lu + 2^%d * %lu) == "
212                "%lu + 2^%d * %lu\n",
213                A[0], LONG_BIT, A[1], m[0].m[0], LONG_BIT, m[0].m[1],
214                tmpi[0], LONG_BIT, tmpi[1]);
215       ASSERT_EXPENSIVE (modredc2ul2_intcmp_ul (tmpi, 1UL) == 0);
216     }
217   modredc2ul2_clear (tmp, m);
218 #endif
219 
220   r[0] = u[0];
221   r[1] = u[1];
222   return 1;
223 }
224 
225 
226 int
modredc2ul2_batchinv_ul(residue_t * r,const unsigned long * a,const size_t n,const residue_t c,const modulus_t m)227 modredc2ul2_batchinv_ul (residue_t *r, const unsigned long *a, const size_t n,
228                          const residue_t c, const modulus_t m)
229 {
230   residue_t R;
231 
232   if (n == 0)
233     return 1;
234 
235   mod_intset_ul(r[0], a[0]);
236 
237   /* beta' = 2^64, beta = 2^128 */
238   for (size_t i = 1; i < n; i++) {
239     _modredc2ul2_mul_ul(r[i], r[i-1], a[i], m);
240     /* r[i] = beta'^{-i} \prod_{0 <= j <= i} a[j] */
241   }
242 
243   mod_init_noset0(R, m);
244   /* Computes R = beta^2/r[n-1] */
245   if (!modredc2ul2_inv(R, r[n - 1], m))
246     return 0;
247   /* R = beta^2 beta'^{n-1} \prod_{0 <= j < n} a[j]^{-1} */
248 
249   if (c != NULL) {
250     mod_mul(R, R, c, m);
251   } else {
252     modredc2ul2_redc1(R, R, m); /* Assume c=1 */
253     modredc2ul2_redc1(R, R, m);
254   }
255   /* R = beta beta'^{n-1} c \prod_{0 <= j < n} a[j]^{-1} */
256 
257   modredc2ul2_redc1(R, R, m);
258   /* R = beta beta'^{n-2} c \prod_{0 <= j < n} a[j]^{-1} */
259 
260   for (size_t i = n-1; i > 0; i--) {
261     /* Invariant: R = beta beta'^{i-1} c \prod_{0 <= j <= i} a[j]^{-1} */
262 
263     mod_mul(r[i], R, r[i-1], m);
264     /* r[i] := R * r[i-1] / beta
265             = (beta beta'^{i-1} c \prod_{0 <= j <= i} a[j]^{-1}) * (1/beta'^{i-1} \prod_{0 <= j <= i-1} a[j]) / beta
266             = c a[i]^{-1} */
267 
268     _modredc2ul2_mul_ul(R, R, a[i], m);
269     /* R := R * a[i] / beta'
270          = (beta beta'^{i-1} c \prod_{0 <= j <= i} a[j]^{-1}) * a[i] / beta'
271          = beta beta'^{i-2} c \prod_{0 <= j < i} a[j]^{-1},
272        thus satisfying the invariant for i := i - 1 */
273   }
274   /* Here have R = beta * beta'^{-1} / a[0]. We need to convert the factor
275      beta to a factor of beta', so that the beta' cancel. */
276   modredc2ul2_redc1(R, R, m); /* R := beta * beta'^{-1} / a[0] / beta',
277                                  with beta = beta'^2, this is 1/a[0] */
278   mod_set(r[0], R, m);
279   mod_clear(R, m);
280   return 1;
281 }
282 
283 modredc2ul2_batch_Q_to_Fp_context_t *
modredc2ul2_batch_Q_to_Fp_init(const modint_t num,const modint_t den)284 modredc2ul2_batch_Q_to_Fp_init (const modint_t num, const modint_t den)
285 {
286   modredc2ul2_batch_Q_to_Fp_context_t *context;
287   modint_t ratio, remainder;
288 
289   context = (modredc2ul2_batch_Q_to_Fp_context_t *) malloc(sizeof(modredc2ul2_batch_Q_to_Fp_context_t));
290   if (context == NULL)
291     return NULL;
292 
293   mod_initmod_int(context->m, den);
294   mod_intinit(context->c); /* c = 0 */
295 
296   /* Compute ratio = floor(num / den), remainder = num % den. We assume that
297     ratio fits into unsigned long, and abort if it does not. We need only the
298     low word of remainder. */
299   mod_intinit(remainder);
300   mod_intinit(ratio);
301   mod_intmod(remainder, num, den);
302   mod_intsub(ratio, num, remainder);
303   mod_intdivexact(ratio, ratio, den);
304   ASSERT_ALWAYS(modredc2ul2_intfits_ul(ratio));
305   context->ratio_ul = modredc2ul2_intget_ul(ratio);
306   context->rem_ul = modredc2ul2_intget_ul(remainder);
307   if (!mod_intequal_ul(remainder, 0))
308     mod_intsub(context->c, den, remainder); /* c = -remainder (mod den) */
309   mod_intclear(ratio);
310   mod_intclear(remainder);
311 
312   context->den_inv = ularith_invmod(modredc2ul2_intget_ul(den));
313 
314   return context;
315 }
316 
317 
318 void
modredc2ul2_batch_Q_to_Fp_clear(modredc2ul2_batch_Q_to_Fp_context_t * context)319 modredc2ul2_batch_Q_to_Fp_clear (modredc2ul2_batch_Q_to_Fp_context_t * context)
320 {
321   mod_clearmod(context->m);
322   mod_intclear(context->c);
323   free(context);
324 }
325 
326 
327 int
modredc2ul2_batch_Q_to_Fp(unsigned long * r,const modredc2ul2_batch_Q_to_Fp_context_t * context,const unsigned long k,const int neg,const unsigned long * p,const size_t n)328 modredc2ul2_batch_Q_to_Fp (unsigned long *r,
329                            const modredc2ul2_batch_Q_to_Fp_context_t *context,
330                            const unsigned long k, const int neg,
331                            const unsigned long *p, const size_t n)
332 {
333   residue_t *tr;
334   int rc = 1;
335 
336   tr = (residue_t *) malloc(n * sizeof(residue_t));
337   for (size_t i = 0; i < n; i++) {
338     mod_init_noset0(tr[i], context->m);
339   }
340 
341   if (!modredc2ul2_batchinv_ul(tr, p, n, context->c, context->m)) {
342     rc = 0;
343     goto clear_and_exit;
344   }
345 
346   for (size_t i = 0; i < n; i++) {
347     unsigned long t;
348     t = ularith_post_process_inverse(mod_intget_ul(tr[i]), p[i],
349                                      context->rem_ul, context->den_inv,
350                                      context->ratio_ul, k);
351     if (neg && t != 0)
352       t = p[i] - t;
353     r[i] = t;
354   }
355 
356 clear_and_exit:
357   for (size_t i = 0; i < n; i++) {
358     mod_clear(tr[i], context->m);
359   }
360   free(tr);
361   return rc;
362 }
363