1 #include "cado.h" // IWYU pragma: keep
2 #include "modredc_2ul2.h"
3 #include "modredc_2ul2_default.h"
4 #include "modredc_2ul_common.c"
5
/* Debug tracing for modredc2ul2_inv().

   Set PARI to 1 to make the inversion routine print its intermediate values
   on stdout as Pari/GP statements, tagged with the source line number.
   With PARI == 0 all trace macros expand to nothing.

   The macros are not self-contained: they read the local variables of the
   function they are expanded in (m, a, b, u, v, t).  None of them ends in a
   semicolon; the call site supplies it. */
#define PARI 0
#if PARI
#include <stdio.h> // IWYU pragma: keep
#include <stdlib.h>
#include "macros.h"
/* Print the two-word modulus m. */
#define MODINV_PRINT_PARI_M \
  printf ("m = (%lu << %d) + %lu; /* PARI %d */\n", m[0].m[1], LONG_BIT, m[0].m[0], __LINE__)
/* Print the current value of a under the name x. */
#define MODINV_PRINT_PARI_x \
  printf ("x = (%lu << %d) + %lu; /* PARI %d */\n", a[1], LONG_BIT, a[0], __LINE__)
/* Print the current value of a under the name X. */
#define MODINV_PRINT_PARI_X \
  printf ("X = (%lu << %d) + %lu; /* PARI %d */\n", a[1], LONG_BIT, a[0], __LINE__)
/* Print a, u and the relation between u, X, a and t that is to be checked. */
#define MODINV_PRINT_PARI_INVARIANT_A \
  printf ("a = %lu *2^%d + %lu; u = %lu *2^%d + %lu; Mod(u, m) * X == a << %d /* PARIC %d */\n", a[1], LONG_BIT, a[0], u[1], LONG_BIT, u[0], t, __LINE__)
/* Print b, v and the relation between v, X, b and t that is to be checked. */
#define MODINV_PRINT_PARI_INVARIANT_B \
  printf ("b = %lu *2^%d + %lu; v = %lu *2^%d + %lu; -Mod(v, m) * X == b << %d /* PARIC %d */\n", b[1], LONG_BIT, b[0], v[1], LONG_BIT, v[0], t, __LINE__)
#else
#define MODINV_PRINT_PARI_M
#define MODINV_PRINT_PARI_x
#define MODINV_PRINT_PARI_X
#define MODINV_PRINT_PARI_INVARIANT_A
#define MODINV_PRINT_PARI_INVARIANT_B
#endif
28
29 int
modredc2ul2_inv(residueredc2ul2_t r,const residueredc2ul2_t A,const modulusredc2ul2_t m)30 modredc2ul2_inv (residueredc2ul2_t r, const residueredc2ul2_t A,
31 const modulusredc2ul2_t m)
32 {
33 modintredc2ul2_t a, b, u, v;
34 int t, lsh;
35 #ifdef WANT_ASSERT_EXPENSIVE
36 residueredc2ul2_t tmp;
37
38 modredc2ul2_init_noset0 (tmp, m);
39 modredc2ul2_set (tmp, A, m);
40 #endif
41
42 ASSERT_EXPENSIVE (modredc2ul2_intcmp (A, m[0].m) < 0);
43 ASSERT_EXPENSIVE (m[0].m[0] & 1UL);
44
45 MODINV_PRINT_PARI_M;
46
47 if (A[0] == 0UL && A[1] == 0UL)
48 return 0;
49
50 modredc2ul2_getmod_int (b, m);
51
52 /* Let A = x*2^{2w}, so we want the Montgomery representation of 1/x,
53 which is 2^{2w}/x. We start by getting a = x */
54 modredc2ul2_get_int (a, A, m);
55 MODINV_PRINT_PARI_x;
56
57 /* We simply set a = x/2^{2w} and t=0. The result before correction
58 will be 2^(2w+t)/x so we have to divide by t, which may be >64,
59 so we may have to do one or more full and a variable width REDC. */
60 /* TODO: If b[1] > 1, we could skip one of the two REDC */
61 modredc2ul2_redc1 (a, a, m);
62 /* Now a = x/2^w */
63 MODINV_PRINT_PARI_X;
64 t = -LONG_BIT;
65
66 modredc2ul2_intset_ul (u, 1UL);
67 modredc2ul2_intset_ul (v, 0UL);
68
69 MODINV_PRINT_PARI_INVARIANT_A;
70 MODINV_PRINT_PARI_INVARIANT_B;
71
72 /* make a odd */
73 if (a[0] == 0UL)
74 {
75 /* x86 bsf gives undefined result for zero input */
76 a[0] = a[1];
77 a[1] = 0UL;
78 t += LONG_BIT;
79 }
80 ASSERT_EXPENSIVE (a[0] != 0UL);
81 lsh = ularith_ctz (a[0]);
82 modredc2ul2_intshr (a, a, lsh);
83 t += lsh;
84
85 // Here a and b are odd, and a < b
86 do {
87 /* Here, a and b are odd, 0 < a < b, u is odd and v is even */
88 ASSERT_EXPENSIVE (modredc2ul2_intcmp (a, b) < 0);
89 ASSERT_EXPENSIVE ((a[0] & 1UL) == 1UL);
90 ASSERT_EXPENSIVE ((b[0] & 1UL) == 1UL);
91 ASSERT_EXPENSIVE ((u[0] & 1UL) == 1UL);
92 ASSERT_EXPENSIVE ((v[0] & 1UL) == 0UL);
93
94 MODINV_PRINT_PARI_INVARIANT_A;
95 MODINV_PRINT_PARI_INVARIANT_B;
96
97 do {
98 modredc2ul2_intsub (b, b, a);
99 modredc2ul2_intadd (v, v, u);
100
101 MODINV_PRINT_PARI_INVARIANT_A;
102 MODINV_PRINT_PARI_INVARIANT_B;
103
104 if (b[0] == 0UL)
105 {
106 b[0] = b[1]; /* b[0] can be odd now, so lsh might be 0 below! */
107 b[1] = 0UL;
108 ASSERT_EXPENSIVE (u[1] == 0UL);
109 u[1] = u[0]; /* Shift left u by LONG_BIT */
110 u[0] = 0UL;
111 t += LONG_BIT;
112 }
113 else
114 {
115 ASSERT_EXPENSIVE (ularith_ctz (b[0]) > 0);
116 }
117 lsh = ularith_ctz (b[0]);
118 ASSERT_EXPENSIVE ((b[0] & ((1UL << lsh) - 1UL)) == 0UL);
119 modredc2ul2_intshr (b, b, lsh);
120 t += lsh;
121 modredc2ul2_intshl (u, u, lsh);
122 MODINV_PRINT_PARI_INVARIANT_A;
123 MODINV_PRINT_PARI_INVARIANT_B;
124 } while (modredc2ul2_intlt (a, b)); /* ~50% branch taken :( */
125
126 /* Here, a and b are odd, 0 < b =< a, u is even and v is odd */
127 ASSERT_EXPENSIVE ((a[0] & 1UL) == 1UL);
128 ASSERT_EXPENSIVE ((b[0] & 1UL) == 1UL);
129 ASSERT_EXPENSIVE ((u[0] & 1UL) == 0UL);
130 ASSERT_EXPENSIVE ((v[0] & 1UL) == 1UL);
131
132 if (modredc2ul2_intequal (a, b))
133 break;
134 ASSERT_EXPENSIVE (modredc2ul2_intcmp (a, b) > 0);
135
136 /* Here, a and b are odd, 0 < b < a, u is even and v is odd */
137 do {
138 modredc2ul2_intsub (a, a, b);
139 modredc2ul2_intadd (u, u, v);
140 MODINV_PRINT_PARI_INVARIANT_A;
141 MODINV_PRINT_PARI_INVARIANT_B;
142
143 if (a[0] == 0UL)
144 {
145 a[0] = a[1];
146 a[1] = 0UL;
147 v[1] = v[0]; /* Shift left v by LONG_BIT */
148 v[0] = 0UL;
149 t += LONG_BIT;
150 }
151 else
152 {
153 ASSERT_EXPENSIVE (ularith_ctz (a[0]) > 0);
154 }
155 lsh = ularith_ctz (a[0]);
156 ASSERT_EXPENSIVE ((a[0] & ((1UL << lsh) - 1UL)) == 0UL);
157 modredc2ul2_intshr (a, a, lsh);
158 t += lsh;
159 modredc2ul2_intshl (v, v, lsh);
160 MODINV_PRINT_PARI_INVARIANT_A;
161 MODINV_PRINT_PARI_INVARIANT_B;
162 } while (modredc2ul2_intlt (b, a)); /* about 50% branch taken :( */
163 /* Here, a and b are odd, 0 < a =< b, u is odd and v is even */
164 } while (!modredc2ul2_intequal (a, b));
165
166 if (modredc2ul2_intcmp_ul (a, 1UL) != 0) /* Non-trivial GCD */
167 return 0;
168
169 ASSERT (t >= 0);
170
171 /* Here, the inverse of a is u/2^t mod m. To do the division by t,
172 we use a variable-width REDC. We want to add a multiple of m to u
173 so that the low t bits of the sum are 0 and we can right-shift by t
174 with impunity. */
175 while (t >= LONG_BIT)
176 {
177 modredc2ul2_redc1 (u, u, m);
178 t -= LONG_BIT;
179 }
180
181 if (t > 0)
182 {
183 unsigned long s[5], k;
184 k = ((u[0] * m[0].invm) & ((1UL << t) - 1UL)); /* tlow <= 2^t-1 */
185 ularith_mul_ul_ul_2ul (&(s[0]), &(s[1]), k, m[0].m[0]);
186 /* s[1]:s[0] <= (2^w-1)*(2^t-1) <= (2^w-1)*(2^(w-1)-1) */
187 ularith_add_2ul_2ul (&(s[0]), &(s[1]), u[0], u[1]);
188 /* s[1]:s[0] <= (2^w-1)*(2^(w-1)-1) + (m-1) < 2^(2w) */
189 /* s[0] == 0 (mod 2^t) */
190 ASSERT_EXPENSIVE ((s[0] & ((1UL << t) - 1UL)) == 0UL);
191 s[2] = 0;
192 ularith_mul_ul_ul_2ul (&(s[3]), &(s[4]), k, m[0].m[1]);
193 ularith_add_2ul_2ul (&(s[1]), &(s[2]), s[3], s[4]);
194
195 /* Now shift s[2]:s[1]:s[0] right by t */
196 ularith_shrd (&(s[0]), s[1], s[0], t);
197 ularith_shrd (&(s[1]), s[2], s[1], t);
198
199 u[0] = s[0];
200 u[1] = s[1];
201 t = 0;
202 MODINV_PRINT_PARI_INVARIANT_A;
203 }
204
205 #ifdef WANT_ASSERT_EXPENSIVE
206 modredc2ul2_mul (tmp, tmp, u, m);
207 if (!modredc2ul2_is1 (tmp, m))
208 {
209 modintredc2ul2_t tmpi;
210 modredc2ul2_get_int (tmpi, tmp, m);
211 fprintf (stderr, "Error, Mod(1/(%lu + 2^%d * %lu), %lu + 2^%d * %lu) == "
212 "%lu + 2^%d * %lu\n",
213 A[0], LONG_BIT, A[1], m[0].m[0], LONG_BIT, m[0].m[1],
214 tmpi[0], LONG_BIT, tmpi[1]);
215 ASSERT_EXPENSIVE (modredc2ul2_intcmp_ul (tmpi, 1UL) == 0);
216 }
217 modredc2ul2_clear (tmp, m);
218 #endif
219
220 r[0] = u[0];
221 r[1] = u[1];
222 return 1;
223 }
224
225
226 int
modredc2ul2_batchinv_ul(residue_t * r,const unsigned long * a,const size_t n,const residue_t c,const modulus_t m)227 modredc2ul2_batchinv_ul (residue_t *r, const unsigned long *a, const size_t n,
228 const residue_t c, const modulus_t m)
229 {
230 residue_t R;
231
232 if (n == 0)
233 return 1;
234
235 mod_intset_ul(r[0], a[0]);
236
237 /* beta' = 2^64, beta = 2^128 */
238 for (size_t i = 1; i < n; i++) {
239 _modredc2ul2_mul_ul(r[i], r[i-1], a[i], m);
240 /* r[i] = beta'^{-i} \prod_{0 <= j <= i} a[j] */
241 }
242
243 mod_init_noset0(R, m);
244 /* Computes R = beta^2/r[n-1] */
245 if (!modredc2ul2_inv(R, r[n - 1], m))
246 return 0;
247 /* R = beta^2 beta'^{n-1} \prod_{0 <= j < n} a[j]^{-1} */
248
249 if (c != NULL) {
250 mod_mul(R, R, c, m);
251 } else {
252 modredc2ul2_redc1(R, R, m); /* Assume c=1 */
253 modredc2ul2_redc1(R, R, m);
254 }
255 /* R = beta beta'^{n-1} c \prod_{0 <= j < n} a[j]^{-1} */
256
257 modredc2ul2_redc1(R, R, m);
258 /* R = beta beta'^{n-2} c \prod_{0 <= j < n} a[j]^{-1} */
259
260 for (size_t i = n-1; i > 0; i--) {
261 /* Invariant: R = beta beta'^{i-1} c \prod_{0 <= j <= i} a[j]^{-1} */
262
263 mod_mul(r[i], R, r[i-1], m);
264 /* r[i] := R * r[i-1] / beta
265 = (beta beta'^{i-1} c \prod_{0 <= j <= i} a[j]^{-1}) * (1/beta'^{i-1} \prod_{0 <= j <= i-1} a[j]) / beta
266 = c a[i]^{-1} */
267
268 _modredc2ul2_mul_ul(R, R, a[i], m);
269 /* R := R * a[i] / beta'
270 = (beta beta'^{i-1} c \prod_{0 <= j <= i} a[j]^{-1}) * a[i] / beta'
271 = beta beta'^{i-2} c \prod_{0 <= j < i} a[j]^{-1},
272 thus satisfying the invariant for i := i - 1 */
273 }
274 /* Here have R = beta * beta'^{-1} / a[0]. We need to convert the factor
275 beta to a factor of beta', so that the beta' cancel. */
276 modredc2ul2_redc1(R, R, m); /* R := beta * beta'^{-1} / a[0] / beta',
277 with beta = beta'^2, this is 1/a[0] */
278 mod_set(r[0], R, m);
279 mod_clear(R, m);
280 return 1;
281 }
282
283 modredc2ul2_batch_Q_to_Fp_context_t *
modredc2ul2_batch_Q_to_Fp_init(const modint_t num,const modint_t den)284 modredc2ul2_batch_Q_to_Fp_init (const modint_t num, const modint_t den)
285 {
286 modredc2ul2_batch_Q_to_Fp_context_t *context;
287 modint_t ratio, remainder;
288
289 context = (modredc2ul2_batch_Q_to_Fp_context_t *) malloc(sizeof(modredc2ul2_batch_Q_to_Fp_context_t));
290 if (context == NULL)
291 return NULL;
292
293 mod_initmod_int(context->m, den);
294 mod_intinit(context->c); /* c = 0 */
295
296 /* Compute ratio = floor(num / den), remainder = num % den. We assume that
297 ratio fits into unsigned long, and abort if it does not. We need only the
298 low word of remainder. */
299 mod_intinit(remainder);
300 mod_intinit(ratio);
301 mod_intmod(remainder, num, den);
302 mod_intsub(ratio, num, remainder);
303 mod_intdivexact(ratio, ratio, den);
304 ASSERT_ALWAYS(modredc2ul2_intfits_ul(ratio));
305 context->ratio_ul = modredc2ul2_intget_ul(ratio);
306 context->rem_ul = modredc2ul2_intget_ul(remainder);
307 if (!mod_intequal_ul(remainder, 0))
308 mod_intsub(context->c, den, remainder); /* c = -remainder (mod den) */
309 mod_intclear(ratio);
310 mod_intclear(remainder);
311
312 context->den_inv = ularith_invmod(modredc2ul2_intget_ul(den));
313
314 return context;
315 }
316
317
318 void
modredc2ul2_batch_Q_to_Fp_clear(modredc2ul2_batch_Q_to_Fp_context_t * context)319 modredc2ul2_batch_Q_to_Fp_clear (modredc2ul2_batch_Q_to_Fp_context_t * context)
320 {
321 mod_clearmod(context->m);
322 mod_intclear(context->c);
323 free(context);
324 }
325
326
327 int
modredc2ul2_batch_Q_to_Fp(unsigned long * r,const modredc2ul2_batch_Q_to_Fp_context_t * context,const unsigned long k,const int neg,const unsigned long * p,const size_t n)328 modredc2ul2_batch_Q_to_Fp (unsigned long *r,
329 const modredc2ul2_batch_Q_to_Fp_context_t *context,
330 const unsigned long k, const int neg,
331 const unsigned long *p, const size_t n)
332 {
333 residue_t *tr;
334 int rc = 1;
335
336 tr = (residue_t *) malloc(n * sizeof(residue_t));
337 for (size_t i = 0; i < n; i++) {
338 mod_init_noset0(tr[i], context->m);
339 }
340
341 if (!modredc2ul2_batchinv_ul(tr, p, n, context->c, context->m)) {
342 rc = 0;
343 goto clear_and_exit;
344 }
345
346 for (size_t i = 0; i < n; i++) {
347 unsigned long t;
348 t = ularith_post_process_inverse(mod_intget_ul(tr[i]), p[i],
349 context->rem_ul, context->den_inv,
350 context->ratio_ul, k);
351 if (neg && t != 0)
352 t = p[i] - t;
353 r[i] = t;
354 }
355
356 clear_and_exit:
357 for (size_t i = 0; i < n; i++) {
358 mod_clear(tr[i], context->m);
359 }
360 free(tr);
361 return rc;
362 }
363