1 /*
2 Copyright (C) 2018 Daniel Schultz
3
4 This file is part of FLINT.
5
6 FLINT is free software: you can redistribute it and/or modify it under
7 the terms of the GNU Lesser General Public License (LGPL) as published
8 by the Free Software Foundation; either version 2.1 of the License, or
9 (at your option) any later version. See <https://www.gnu.org/licenses/>.
10 */
11
12 #include "nmod_poly.h"
13
/*
    Put P into its empty state: no instruction storage, no instructions,
    one local slot, both temporary locations at slot 0, and flagged as not
    usable (good == 0).
*/
void nmod_poly_multi_crt_init(nmod_poly_multi_crt_t P)
{
    /* nothing has been compiled yet */
    P->good = 0;

    /* local slot bookkeeping */
    P->localsize = 1;
    P->temp1loc = 0;
    P->temp2loc = 0;

    /* instruction array */
    P->length = 0;
    P->alloc = 0;
    P->prog = NULL;
}
24
/*
    Make sure P->prog has room for at least max(1, k) instructions.
    The array is allocated on first use and only ever grown afterwards;
    P->length is not touched.
*/
static void _nmod_poly_multi_crt_fit_length(nmod_poly_multi_crt_t P, slong k)
{
    const slong want = FLINT_MAX(WORD(1), k);

    /* already allocated and big enough: nothing to do */
    if (P->alloc != 0 && want <= P->alloc)
    {
        return;
    }

    if (P->alloc == 0)
    {
        /* first allocation */
        FLINT_ASSERT(P->prog == NULL);
        P->prog = (_nmod_poly_multi_crt_prog_instr *) flint_malloc(
                            want*sizeof(_nmod_poly_multi_crt_prog_instr));
    }
    else
    {
        /* grow the existing array */
        FLINT_ASSERT(P->prog != NULL);
        P->prog = (_nmod_poly_multi_crt_prog_instr *) flint_realloc(P->prog,
                            want*sizeof(_nmod_poly_multi_crt_prog_instr));
    }

    P->alloc = want;
}
44
/*
    Shrink the program to k instructions, clearing the modulus and idem
    polynomials of every instruction in [k, P->length).
    k is expected not to exceed the current length.
*/
static void _nmod_poly_multi_crt_set_length(nmod_poly_multi_crt_t P, slong k)
{
    slong j = k;

    FLINT_ASSERT(k <= P->length);

    while (j < P->length)
    {
        nmod_poly_clear(P->prog[j].modulus);
        nmod_poly_clear(P->prog[j].idem);
        j++;
    }

    P->length = k;
}
58
/*
    Release everything owned by P: first the polynomials held by each
    instruction, then the instruction array itself (if one was allocated).
*/
void nmod_poly_multi_crt_clear(nmod_poly_multi_crt_t P)
{
    _nmod_poly_multi_crt_set_length(P, 0);

    if (P->alloc <= 0)
    {
        return;
    }

    flint_free(P->prog);
}
68
69 typedef struct {
70 slong idx;
71 slong degree;
72 } index_deg_pair;
73
74 /*
75 combine all moduli in [start, stop)
76 return index of instruction that computes the result
77 */
_push_prog(nmod_poly_multi_crt_t P,const nmod_poly_struct * const * moduli,const index_deg_pair * perm,slong ret_idx,slong start,slong stop)78 static slong _push_prog(
79 nmod_poly_multi_crt_t P,
80 const nmod_poly_struct * const * moduli,
81 const index_deg_pair * perm,
82 slong ret_idx,
83 slong start,
84 slong stop)
85 {
86 slong i, mid;
87 slong b_idx, c_idx;
88 slong lefttot, righttot;
89 slong leftret, rightret;
90 nmod_poly_struct * leftmodulus, * rightmodulus;
91
92 /* we should have at least 2 moduli */
93 FLINT_ASSERT(start + 1 < stop);
94
95 mid = start + (stop - start)/2;
96
97 FLINT_ASSERT(start < mid);
98 FLINT_ASSERT(mid < stop);
99
100 lefttot = 0;
101 for (i = start; i < mid; i++)
102 {
103 lefttot += perm[i].degree;
104 }
105
106 righttot = 0;
107 for (i = mid; i < stop; i++)
108 {
109 righttot += perm[i].degree;
110 }
111
112 /* try to balance the total degree on left and right */
113 while (lefttot < righttot
114 && mid + 1 < stop
115 && perm[mid].degree < righttot - lefttot)
116 {
117 lefttot += perm[mid].degree;
118 righttot -= perm[mid].degree;
119 mid++;
120 }
121
122 P->localsize = FLINT_MAX(P->localsize, 1 + ret_idx);
123
124 /* compile left [start, mid) */
125 if (start + 1 < mid)
126 {
127 b_idx = ret_idx + 1;
128 leftret = _push_prog(P, moduli, perm, b_idx, start, mid);
129 if (!P->good)
130 {
131 return -1;
132 }
133 leftmodulus = P->prog[leftret].modulus;
134 }
135 else
136 {
137 b_idx = -1 - perm[start].idx;
138 leftmodulus = (nmod_poly_struct *) moduli[perm[start].idx];
139 }
140
141 /* compile right [mid, end) */
142 if (mid + 1 < stop)
143 {
144 c_idx = ret_idx + 2;
145 rightret = _push_prog(P, moduli, perm, c_idx, mid, stop);
146 if (!P->good)
147 {
148 return -1;
149 }
150 rightmodulus = P->prog[rightret].modulus;
151 }
152 else
153 {
154 c_idx = -1 - perm[mid].idx;
155 rightmodulus = (nmod_poly_struct *) moduli[perm[mid].idx];
156 }
157
158 /* check if nmod_poly_invmod is going to throw */
159 if (nmod_poly_degree(leftmodulus) < 1 || nmod_poly_degree(rightmodulus) < 1)
160 {
161 P->good = 0;
162 return -1;
163 }
164
165 /* compile [start, end) */
166 i = P->length;
167 _nmod_poly_multi_crt_fit_length(P, i + 1);
168 nmod_poly_init_mod(P->prog[i].modulus, rightmodulus->mod);
169 nmod_poly_init_mod(P->prog[i].idem, rightmodulus->mod);
170 P->good = P->good && nmod_poly_invmod(P->prog[i].modulus, leftmodulus, rightmodulus);
171 nmod_poly_mul(P->prog[i].idem, leftmodulus, P->prog[i].modulus);
172 nmod_poly_mul(P->prog[i].modulus, leftmodulus, rightmodulus);
173 P->prog[i].a_idx = ret_idx;
174 P->prog[i].b_idx = b_idx;
175 P->prog[i].c_idx = c_idx;
176 P->length = i + 1;
177
178 return i;
179 }
180
181
index_deg_pair_cmp(const index_deg_pair * lhs,const index_deg_pair * rhs)182 static int index_deg_pair_cmp(
183 const index_deg_pair * lhs,
184 const index_deg_pair * rhs)
185 {
186 return (lhs->degree < rhs->degree) ? -1 : (lhs->degree > rhs->degree);
187 }
188
189 /*
190 Return 1 if moduli can be CRT'ed, 0 otherwise.
191 A return of 0 means that future calls to run will leave output undefined.
192 */
/*
    Compile into P a program for CRT with the len moduli pointed to by
    moduli (len > 0, all over the same word modulus).

    Any previously compiled program in P is discarded. On return
    P->localsize includes two extra slots used as temporaries when the
    program is run. The return value is P->good; when it is 0 the program
    is left empty.
*/
int nmod_poly_multi_crt_precompute_p(
    nmod_poly_multi_crt_t P,
    const nmod_poly_struct * const * moduli,
    slong len)
{
    slong j;
    index_deg_pair * order;
    TMP_INIT;

    FLINT_ASSERT(len > 0);
    for (j = 1; j < len; j++)
    {
        FLINT_ASSERT(moduli[j - 1]->mod.n == moduli[j]->mod.n);
    }

    TMP_START;
    order = (index_deg_pair *) TMP_ALLOC(len * sizeof(index_deg_pair));

    /* record (index, degree) for each modulus ... */
    for (j = 0; j < len; j++)
    {
        order[j].degree = nmod_poly_degree(moduli[j]);
        order[j].idx = j;
    }

    /* ... and sort the records by nondecreasing degree */
    qsort(order, len, sizeof(index_deg_pair),
                 (int(*)(const void*, const void*)) index_deg_pair_cmp);
    for (j = 0; j < len; j++)
    {
        FLINT_ASSERT(order[j].degree
                              == nmod_poly_degree(moduli[order[j].idx]));
        FLINT_ASSERT(j == 0 || order[j - 1].degree <= order[j].degree);
    }

    /* start from an empty program with room for every instruction */
    _nmod_poly_multi_crt_fit_length(P, FLINT_MAX(WORD(1), len - 1));
    _nmod_poly_multi_crt_set_length(P, 0);
    P->localsize = 1;
    P->good = 1;

    if (len <= 1)
    {
        /*
            A single modulus is handled by one instruction with a zero idem:
                output[0] = input[0] + 0*(input[0] - input[0]) mod moduli[0]
        */
        nmod_poly_init_mod(P->prog[0].modulus, moduli[0]->mod);
        nmod_poly_init_mod(P->prog[0].idem, moduli[0]->mod);
        nmod_poly_set(P->prog[0].modulus, moduli[0]);
        P->prog[0].a_idx = 0;
        P->prog[0].b_idx = -WORD(1);
        P->prog[0].c_idx = -WORD(1);
        P->length = 1;

        P->good = !nmod_poly_is_zero(moduli[0]);
    }
    else
    {
        _push_prog(P, moduli, order, 0, 0, len);
    }

    /* a failed compilation leaves no instructions behind */
    if (!P->good)
    {
        _nmod_poly_multi_crt_set_length(P, 0);
    }

    /* reserve the two slots after the locals as temporaries */
    P->temp1loc = P->localsize;
    P->temp2loc = P->localsize + 1;
    P->localsize += 2;

    TMP_END;

    return P->good;
}
266
/*
    Same as nmod_poly_multi_crt_precompute_p, but the len moduli are given
    as a contiguous array of structs instead of an array of pointers.
    Returns the value returned by nmod_poly_multi_crt_precompute_p.
*/
int nmod_poly_multi_crt_precompute(
    nmod_poly_multi_crt_t P,
    const nmod_poly_struct * moduli,
    slong len)
{
    int ok;
    slong j;
    const nmod_poly_struct ** ptrs;
    TMP_INIT;

    FLINT_ASSERT(len > 0);

    TMP_START;

    /* build a temporary table of pointers to the individual moduli */
    ptrs = (const nmod_poly_struct **) TMP_ALLOC(
                                          len*sizeof(nmod_poly_struct *));
    for (j = 0; j < len; j++)
    {
        ptrs[j] = &moduli[j];
    }

    ok = nmod_poly_multi_crt_precompute_p(P,
                              (const nmod_poly_struct * const *) ptrs, len);

    TMP_END;

    return ok;
}
293
294
/*
    Run the compiled program P on the array inputs and leave the result in
    output.

    P->localsize scratch polynomials are created using the modulus of
    inputs[0]. Slot 0 is swapped with output around the run, so the program
    writes its answer into the caller's polynomial.
*/
void nmod_poly_multi_crt_precomp(
    nmod_poly_t output,
    const nmod_poly_multi_crt_t P,
    const nmod_poly_struct * inputs)
{
    slong j;
    nmod_poly_struct * slots;
    TMP_INIT;

    TMP_START;

    slots = (nmod_poly_struct *) TMP_ALLOC(
                                   P->localsize*sizeof(nmod_poly_struct));
    for (j = 0; j < P->localsize; j++)
    {
        nmod_poly_init_mod(&slots[j], inputs->mod);
    }

    /* let slot 0 be the caller's polynomial for the duration of the run */
    nmod_poly_swap(&slots[0], output);
    _nmod_poly_multi_crt_run(slots, P, inputs);
    nmod_poly_swap(&slots[0], output);

    for (j = 0; j < P->localsize; j++)
    {
        nmod_poly_clear(&slots[j]);
    }

    TMP_END;
}
323
/*
    Run the compiled program P on the inputs given as an array of pointers
    and leave the result in output.

    P->localsize scratch polynomials are created using the modulus of
    *inputs[0]. Slot 0 is swapped with output around the run, so the
    program writes its answer into the caller's polynomial.
*/
void nmod_poly_multi_crt_precomp_p(
    nmod_poly_t output,
    const nmod_poly_multi_crt_t P,
    const nmod_poly_struct * const * inputs)
{
    slong j;
    nmod_poly_struct * slots;
    TMP_INIT;

    TMP_START;

    slots = (nmod_poly_struct *) TMP_ALLOC(
                                   P->localsize*sizeof(nmod_poly_struct));
    for (j = 0; j < P->localsize; j++)
    {
        nmod_poly_init_mod(&slots[j], (*inputs)->mod);
    }

    /* let slot 0 be the caller's polynomial for the duration of the run */
    nmod_poly_swap(&slots[0], output);
    _nmod_poly_multi_crt_run_p(slots, P, inputs);
    nmod_poly_swap(&slots[0], output);

    for (j = 0; j < P->localsize; j++)
    {
        nmod_poly_clear(&slots[j]);
    }

    TMP_END;
}
352
/*
    One-shot CRT: compile a program for the len moduli, run it on values
    and put the result in output.

    Returns the flag produced by nmod_poly_multi_crt_precompute. The run
    step is executed whatever that flag is, so output is only meaningful
    when the return value is nonzero.
*/
int nmod_poly_multi_crt(
    nmod_poly_t output,
    const nmod_poly_struct * moduli,
    const nmod_poly_struct * values,
    slong len)
{
    int ok;
    slong j;
    nmod_poly_multi_crt_t P;
    nmod_poly_struct * slots;
    TMP_INIT;

    FLINT_ASSERT(len > 0);

    TMP_START;

    /* compile */
    nmod_poly_multi_crt_init(P);
    ok = nmod_poly_multi_crt_precompute(P, moduli, len);

    /* scratch space for the run, slot 0 doubling as the output */
    slots = (nmod_poly_struct *) TMP_ALLOC(
                                   P->localsize*sizeof(nmod_poly_struct));
    for (j = 0; j < P->localsize; j++)
    {
        nmod_poly_init_mod(&slots[j], values->mod);
    }

    /* run */
    nmod_poly_swap(&slots[0], output);
    _nmod_poly_multi_crt_run(slots, P, values);
    nmod_poly_swap(&slots[0], output);

    /* clean up */
    for (j = 0; j < P->localsize; j++)
    {
        nmod_poly_clear(&slots[j]);
    }

    nmod_poly_multi_crt_clear(P);

    TMP_END;

    return ok;
}
394
/*
    If P was set with a call to nmod_poly_multi_crt_precompute(P, m, len),
    return in outputs[0] the polynomial r of smallest degree such that
        r = inputs[0] mod m[0]
        r = inputs[1] mod m[1]
        ...
        r = inputs[len-1] mod m[len-1]
    For thread safety "outputs" is expected to have enough space for all
    temporaries, and thus should be at least as long as P->localsize.
*/
/*
    Execute the instructions of P in order. Operands with a negative index
    -1 - j are read from inputs[j]; all other operands and every result
    live in the outputs array, which also provides the two temporaries at
    P->temp1loc and P->temp2loc.
*/
void _nmod_poly_multi_crt_run(
    nmod_poly_struct * outputs,
    const nmod_poly_multi_crt_t P,
    const nmod_poly_struct * inputs)
{
    slong k;
    nmod_poly_struct * const tmp1 = outputs + P->temp1loc;
    nmod_poly_struct * const tmp2 = outputs + P->temp2loc;

    for (k = 0; k < P->length; k++)
    {
        _nmod_poly_multi_crt_prog_instr * ins = P->prog + k;
        const nmod_poly_struct * lhs, * rhs;
        nmod_poly_struct * dest;

        FLINT_ASSERT(ins->a_idx >= 0);
        dest = outputs + ins->a_idx;

        /* resolve the two operands */
        if (ins->b_idx < 0)
        {
            lhs = inputs + (-ins->b_idx - 1);
        }
        else
        {
            lhs = outputs + ins->b_idx;
        }

        if (ins->c_idx < 0)
        {
            rhs = inputs + (-ins->c_idx - 1);
        }
        else
        {
            rhs = outputs + ins->c_idx;
        }

        FLINT_ASSERT(dest->mod.n == ins->modulus->mod.n);
        FLINT_ASSERT(lhs->mod.n == ins->modulus->mod.n);
        FLINT_ASSERT(rhs->mod.n == ins->modulus->mod.n);

        /* tmp1 = lhs - idem*(lhs - rhs) = lhs + idem*(rhs - lhs) */
        nmod_poly_sub(tmp1, lhs, rhs);
        nmod_poly_mul(tmp2, ins->idem, tmp1);
        nmod_poly_sub(tmp1, lhs, tmp2);

        /* reduce modulo the combined modulus only when it is needed */
        if (nmod_poly_degree(tmp1) >= nmod_poly_degree(ins->modulus))
        {
            nmod_poly_rem(dest, tmp1, ins->modulus);
        }
        else
        {
            nmod_poly_swap(dest, tmp1);
        }

        /* the final instruction has to deliver the answer in outputs[0] */
        FLINT_ASSERT(k + 1 < P->length || dest == outputs + 0);
    }
}
453
/*
    Execute the instructions of P in order, with the inputs supplied as an
    array of pointers. Operands with a negative index -1 - j are read from
    *inputs[j]; all other operands and every result live in the outputs
    array, which also provides the two temporaries at P->temp1loc and
    P->temp2loc. The answer ends up in outputs[0].
*/
void _nmod_poly_multi_crt_run_p(
    nmod_poly_struct * outputs,
    const nmod_poly_multi_crt_t P,
    const nmod_poly_struct * const * inputs)
{
    slong k;
    nmod_poly_struct * const tmp1 = outputs + P->temp1loc;
    nmod_poly_struct * const tmp2 = outputs + P->temp2loc;

    for (k = 0; k < P->length; k++)
    {
        _nmod_poly_multi_crt_prog_instr * ins = P->prog + k;
        const nmod_poly_struct * lhs, * rhs;
        nmod_poly_struct * dest;

        FLINT_ASSERT(ins->a_idx >= 0);
        dest = outputs + ins->a_idx;

        /* resolve the two operands */
        if (ins->b_idx < 0)
        {
            lhs = inputs[-ins->b_idx - 1];
        }
        else
        {
            lhs = outputs + ins->b_idx;
        }

        if (ins->c_idx < 0)
        {
            rhs = inputs[-ins->c_idx - 1];
        }
        else
        {
            rhs = outputs + ins->c_idx;
        }

        FLINT_ASSERT(dest->mod.n == ins->modulus->mod.n);
        FLINT_ASSERT(lhs->mod.n == ins->modulus->mod.n);
        FLINT_ASSERT(rhs->mod.n == ins->modulus->mod.n);

        /* tmp1 = lhs - idem*(lhs - rhs) = lhs + idem*(rhs - lhs) */
        nmod_poly_sub(tmp1, lhs, rhs);
        nmod_poly_mul(tmp2, ins->idem, tmp1);
        nmod_poly_sub(tmp1, lhs, tmp2);

        /* reduce modulo the combined modulus only when it is needed */
        if (nmod_poly_degree(tmp1) >= nmod_poly_degree(ins->modulus))
        {
            nmod_poly_rem(dest, tmp1, ins->modulus);
        }
        else
        {
            nmod_poly_swap(dest, tmp1);
        }

        /* the final instruction has to deliver the answer in outputs[0] */
        FLINT_ASSERT(k + 1 < P->length || dest == outputs + 0);
    }
}
502