1 /*
2     Copyright (C) 2018 Daniel Schultz
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
10 */
11 
12 #include "nmod_poly.h"
13 
nmod_poly_multi_crt_init(nmod_poly_multi_crt_t P)14 void nmod_poly_multi_crt_init(nmod_poly_multi_crt_t P)
15 {
16     P->prog = NULL;
17     P->alloc = 0;
18     P->length = 0;
19     P->localsize = 1;
20     P->temp1loc = 0;
21     P->temp2loc = 0;
22     P->good = 0;
23 }
24 
_nmod_poly_multi_crt_fit_length(nmod_poly_multi_crt_t P,slong k)25 static void _nmod_poly_multi_crt_fit_length(nmod_poly_multi_crt_t P, slong k)
26 {
27     k = FLINT_MAX(WORD(1), k);
28 
29     if (P->alloc == 0)
30     {
31         FLINT_ASSERT(P->prog == NULL);
32         P->prog = (_nmod_poly_multi_crt_prog_instr *) flint_malloc(k
33                                      *sizeof(_nmod_poly_multi_crt_prog_instr));
34         P->alloc = k;
35     }
36     else if (k > P->alloc)
37     {
38         FLINT_ASSERT(P->prog != NULL);
39         P->prog = (_nmod_poly_multi_crt_prog_instr *) flint_realloc(P->prog, k
40                                      *sizeof(_nmod_poly_multi_crt_prog_instr));
41         P->alloc = k;
42     }
43 }
44 
_nmod_poly_multi_crt_set_length(nmod_poly_multi_crt_t P,slong k)45 static void _nmod_poly_multi_crt_set_length(nmod_poly_multi_crt_t P, slong k)
46 {
47     slong i;
48 
49     FLINT_ASSERT(k <= P->length);
50 
51     for (i = k; i < P->length; i++)
52     {
53         nmod_poly_clear(P->prog[i].modulus);
54         nmod_poly_clear(P->prog[i].idem);
55     }
56     P->length = k;
57 }
58 
nmod_poly_multi_crt_clear(nmod_poly_multi_crt_t P)59 void nmod_poly_multi_crt_clear(nmod_poly_multi_crt_t P)
60 {
61     _nmod_poly_multi_crt_set_length(P, 0);
62 
63     if (P->alloc > 0)
64     {
65         flint_free(P->prog);
66     }
67 }
68 
69 typedef struct {
70     slong idx;
71     slong degree;
72 } index_deg_pair;
73 
74 /*
75     combine all moduli in [start, stop)
76     return index of instruction that computes the result
77 */
_push_prog(nmod_poly_multi_crt_t P,const nmod_poly_struct * const * moduli,const index_deg_pair * perm,slong ret_idx,slong start,slong stop)78 static slong _push_prog(
79     nmod_poly_multi_crt_t P,
80     const nmod_poly_struct * const * moduli,
81     const index_deg_pair * perm,
82     slong ret_idx,
83     slong start,
84     slong stop)
85 {
86     slong i, mid;
87     slong b_idx, c_idx;
88     slong lefttot, righttot;
89     slong leftret, rightret;
90     nmod_poly_struct * leftmodulus, * rightmodulus;
91 
92     /* we should have at least 2 moduli */
93     FLINT_ASSERT(start + 1 < stop);
94 
95     mid = start + (stop - start)/2;
96 
97     FLINT_ASSERT(start < mid);
98     FLINT_ASSERT(mid < stop);
99 
100     lefttot = 0;
101     for (i = start; i < mid; i++)
102     {
103         lefttot += perm[i].degree;
104     }
105 
106     righttot = 0;
107     for (i = mid; i < stop; i++)
108     {
109         righttot += perm[i].degree;
110     }
111 
112     /* try to balance the total degree on left and right */
113     while (lefttot < righttot
114             && mid + 1 < stop
115             && perm[mid].degree < righttot - lefttot)
116     {
117         lefttot += perm[mid].degree;
118         righttot -= perm[mid].degree;
119         mid++;
120     }
121 
122     P->localsize = FLINT_MAX(P->localsize, 1 + ret_idx);
123 
124     /* compile left [start, mid) */
125     if (start + 1 < mid)
126     {
127         b_idx = ret_idx + 1;
128         leftret = _push_prog(P, moduli, perm, b_idx, start, mid);
129         if (!P->good)
130         {
131             return -1;
132         }
133         leftmodulus = P->prog[leftret].modulus;
134     }
135     else
136     {
137         b_idx = -1 - perm[start].idx;
138         leftmodulus = (nmod_poly_struct *) moduli[perm[start].idx];
139     }
140 
141     /* compile right [mid, end) */
142     if (mid + 1 < stop)
143     {
144         c_idx = ret_idx + 2;
145         rightret = _push_prog(P, moduli, perm, c_idx, mid, stop);
146         if (!P->good)
147         {
148             return -1;
149         }
150         rightmodulus = P->prog[rightret].modulus;
151     }
152     else
153     {
154         c_idx = -1 - perm[mid].idx;
155         rightmodulus = (nmod_poly_struct *) moduli[perm[mid].idx];
156     }
157 
158     /* check if nmod_poly_invmod is going to throw */
159     if (nmod_poly_degree(leftmodulus) < 1 || nmod_poly_degree(rightmodulus) < 1)
160     {
161         P->good = 0;
162         return -1;
163     }
164 
165     /* compile [start, end) */
166     i = P->length;
167     _nmod_poly_multi_crt_fit_length(P, i + 1);
168     nmod_poly_init_mod(P->prog[i].modulus, rightmodulus->mod);
169     nmod_poly_init_mod(P->prog[i].idem, rightmodulus->mod);
170     P->good = P->good && nmod_poly_invmod(P->prog[i].modulus, leftmodulus, rightmodulus);
171     nmod_poly_mul(P->prog[i].idem, leftmodulus, P->prog[i].modulus);
172     nmod_poly_mul(P->prog[i].modulus, leftmodulus, rightmodulus);
173     P->prog[i].a_idx = ret_idx;
174     P->prog[i].b_idx = b_idx;
175     P->prog[i].c_idx = c_idx;
176     P->length = i + 1;
177 
178     return i;
179 }
180 
181 
index_deg_pair_cmp(const index_deg_pair * lhs,const index_deg_pair * rhs)182 static int index_deg_pair_cmp(
183     const index_deg_pair * lhs,
184     const index_deg_pair * rhs)
185 {
186     return (lhs->degree < rhs->degree) ? -1 : (lhs->degree > rhs->degree);
187 }
188 
189 /*
190     Return 1 if moduli can be CRT'ed, 0 otherwise.
191     A return of 0 means that future calls to run will leave output undefined.
192 */
nmod_poly_multi_crt_precompute_p(nmod_poly_multi_crt_t P,const nmod_poly_struct * const * moduli,slong len)193 int nmod_poly_multi_crt_precompute_p(
194     nmod_poly_multi_crt_t P,
195     const nmod_poly_struct * const * moduli,
196     slong len)
197 {
198     slong i;
199     index_deg_pair * perm;
200     TMP_INIT;
201 
202     FLINT_ASSERT(len > 0);
203     for (i = 1; i < len; i++)
204     {
205         FLINT_ASSERT(moduli[i - 1]->mod.n == moduli[i]->mod.n);
206     }
207 
208     TMP_START;
209     perm = (index_deg_pair *) TMP_ALLOC(len * sizeof(index_deg_pair));
210 
211     for (i = 0; i < len; i++)
212     {
213         perm[i].idx = i;
214         perm[i].degree = nmod_poly_degree(moduli[i]);
215     }
216 
217     /* make perm sort the degs so that degs[perm[i-1]] <= degs[perm[i-0]] */
218     qsort(perm, len, sizeof(index_deg_pair),
219                         (int(*)(const void*, const void*)) index_deg_pair_cmp);
220     for (i = 0; i < len; i++)
221     {
222         FLINT_ASSERT(perm[i].degree == nmod_poly_degree(moduli[perm[i].idx]));
223         FLINT_ASSERT(i == 0 || perm[i - 1].degree <= perm[i].degree);
224     }
225 
226     _nmod_poly_multi_crt_fit_length(P, FLINT_MAX(WORD(1), len - 1));
227     _nmod_poly_multi_crt_set_length(P, 0);
228     P->localsize = 1;
229     P->good = 1;
230 
231     if (1 < len)
232     {
233         _push_prog(P, moduli, perm, 0, 0, len);
234     }
235     else
236     {
237         /*
238             There is only one modulus. Let's compute as
239                 output[0] = input[0] + 0*(input[0] - input[0]) mod moduli[0]
240         */
241         i = 0;
242         nmod_poly_init_mod(P->prog[i].modulus, moduli[0]->mod);
243         nmod_poly_init_mod(P->prog[i].idem, moduli[0]->mod);
244         nmod_poly_set(P->prog[i].modulus, moduli[0]);
245         P->prog[i].a_idx = 0;
246         P->prog[i].b_idx = -WORD(1);
247         P->prog[i].c_idx = -WORD(1);
248         P->length = i + 1;
249 
250         P->good = !nmod_poly_is_zero(moduli[0]);
251     }
252 
253     if (!P->good)
254     {
255         _nmod_poly_multi_crt_set_length(P, 0);
256     }
257 
258     /* two more spots for temporaries */
259     P->temp1loc = P->localsize++;
260     P->temp2loc = P->localsize++;
261 
262     TMP_END;
263 
264     return P->good;
265 }
266 
nmod_poly_multi_crt_precompute(nmod_poly_multi_crt_t P,const nmod_poly_struct * moduli,slong len)267 int nmod_poly_multi_crt_precompute(
268     nmod_poly_multi_crt_t P,
269     const nmod_poly_struct * moduli,
270     slong len)
271 {
272     int success;
273     slong i;
274     const nmod_poly_struct ** m;
275     TMP_INIT;
276 
277     FLINT_ASSERT(len > 0);
278 
279     TMP_START;
280 
281     m = (const nmod_poly_struct **) TMP_ALLOC(len*sizeof(nmod_poly_struct *));
282     for (i = 0; i < len; i++)
283     {
284         m[i] = moduli + i;
285     }
286 
287     success = nmod_poly_multi_crt_precompute_p(P,
288                                     (const nmod_poly_struct * const *) m, len);
289     TMP_END;
290 
291     return success;
292 }
293 
294 
nmod_poly_multi_crt_precomp(nmod_poly_t output,const nmod_poly_multi_crt_t P,const nmod_poly_struct * inputs)295 void nmod_poly_multi_crt_precomp(
296     nmod_poly_t output,
297     const nmod_poly_multi_crt_t P,
298     const nmod_poly_struct * inputs)
299 {
300     slong i;
301     nmod_poly_struct * out;
302     TMP_INIT;
303 
304     TMP_START;
305     out = (nmod_poly_struct *) TMP_ALLOC(P->localsize
306                                                     *sizeof(nmod_poly_struct));
307     for (i = 0; i < P->localsize; i++)
308     {
309         nmod_poly_init_mod(out + i, inputs[0].mod);
310     }
311 
312     nmod_poly_swap(out + 0, output);
313     _nmod_poly_multi_crt_run(out, P, inputs);
314     nmod_poly_swap(out + 0, output);
315 
316     for (i = 0; i < P->localsize; i++)
317     {
318         nmod_poly_clear(out + i);
319     }
320 
321     TMP_END;
322 }
323 
nmod_poly_multi_crt_precomp_p(nmod_poly_t output,const nmod_poly_multi_crt_t P,const nmod_poly_struct * const * inputs)324 void nmod_poly_multi_crt_precomp_p(
325     nmod_poly_t output,
326     const nmod_poly_multi_crt_t P,
327     const nmod_poly_struct * const * inputs)
328 {
329     slong i;
330     nmod_poly_struct * out;
331     TMP_INIT;
332 
333     TMP_START;
334     out = (nmod_poly_struct *) TMP_ALLOC(P->localsize
335                                                     *sizeof(nmod_poly_struct));
336     for (i = 0; i < P->localsize; i++)
337     {
338         nmod_poly_init_mod(out + i, inputs[0]->mod);
339     }
340 
341     nmod_poly_swap(out + 0, output);
342     _nmod_poly_multi_crt_run_p(out, P, inputs);
343     nmod_poly_swap(out + 0, output);
344 
345     for (i = 0; i < P->localsize; i++)
346     {
347         nmod_poly_clear(out + i);
348     }
349 
350     TMP_END;
351 }
352 
nmod_poly_multi_crt(nmod_poly_t output,const nmod_poly_struct * moduli,const nmod_poly_struct * values,slong len)353 int nmod_poly_multi_crt(
354     nmod_poly_t output,
355     const nmod_poly_struct * moduli,
356     const nmod_poly_struct * values,
357     slong len)
358 {
359     int success;
360     slong i;
361     nmod_poly_multi_crt_t P;
362     nmod_poly_struct * out;
363     TMP_INIT;
364 
365     FLINT_ASSERT(len > 0);
366 
367     TMP_START;
368 
369     nmod_poly_multi_crt_init(P);
370     success = nmod_poly_multi_crt_precompute(P, moduli, len);
371 
372     out = (nmod_poly_struct *) TMP_ALLOC(P->localsize
373                                                     *sizeof(nmod_poly_struct));
374     for (i = 0; i < P->localsize; i++)
375     {
376         nmod_poly_init_mod(out + i, values[0].mod);
377     }
378 
379     nmod_poly_swap(out + 0, output);
380     _nmod_poly_multi_crt_run(out, P, values);
381     nmod_poly_swap(out + 0, output);
382 
383     for (i = 0; i < P->localsize; i++)
384     {
385         nmod_poly_clear(out + i);
386     }
387 
388     nmod_poly_multi_crt_clear(P);
389 
390     TMP_END;
391 
392     return success;
393 }
394 
395 /*
396     If P was set with a call to nmod_poly_multi_crt_compile(P, m, len), return
397     in outputs[0] polynomial r of smallest degree such that
398         r = inputs[0] mod m[0]
399         r = inputs[1] mod m[1]
400             ...
401         r = inputs[len-1] mod m[len-1]
402     For thread safety "outputs" is expected to have enough space for all
403     temporaries, thus should be at least as long as P->localsize.
404 */
_nmod_poly_multi_crt_run(nmod_poly_struct * outputs,const nmod_poly_multi_crt_t P,const nmod_poly_struct * inputs)405 void _nmod_poly_multi_crt_run(
406     nmod_poly_struct * outputs,
407     const nmod_poly_multi_crt_t P,
408     const nmod_poly_struct * inputs)
409 {
410     slong i;
411     slong a, b, c;
412     const nmod_poly_struct * B, * C;
413     nmod_poly_struct * A, * t1, * t2;
414 
415     t1 = outputs + P->temp1loc;
416     t2 = outputs + P->temp2loc;
417 
418     for (i = 0; i < P->length; i++)
419     {
420         a = P->prog[i].a_idx;
421         b = P->prog[i].b_idx;
422         c = P->prog[i].c_idx;
423         FLINT_ASSERT(a >= 0);
424         A = outputs + a;
425         B = b < 0 ? inputs + (-b-1) : outputs + b;
426         C = c < 0 ? inputs + (-c-1) : outputs + c;
427 
428         FLINT_ASSERT(A->mod.n == P->prog[i].modulus->mod.n);
429         FLINT_ASSERT(B->mod.n == P->prog[i].modulus->mod.n);
430         FLINT_ASSERT(C->mod.n == P->prog[i].modulus->mod.n);
431 
432         /* A = B + I*(C - B) mod M */
433         nmod_poly_sub(t1, B, C);
434         nmod_poly_mul(t2, P->prog[i].idem, t1);
435         nmod_poly_sub(t1, B, t2);
436 
437         if (nmod_poly_degree(t1) < nmod_poly_degree(P->prog[i].modulus))
438         {
439             nmod_poly_swap(A, t1);
440         }
441         else
442         {
443             nmod_poly_rem(A, t1, P->prog[i].modulus);
444         }
445 
446         /* last calculation should write answer to outputs[0] */
447         if (i + 1 >= P->length)
448         {
449             FLINT_ASSERT(A == outputs + 0);
450         }
451     }
452 }
453 
_nmod_poly_multi_crt_run_p(nmod_poly_struct * outputs,const nmod_poly_multi_crt_t P,const nmod_poly_struct * const * inputs)454 void _nmod_poly_multi_crt_run_p(
455     nmod_poly_struct * outputs,
456     const nmod_poly_multi_crt_t P,
457     const nmod_poly_struct * const * inputs)
458 {
459     slong i;
460     slong a, b, c;
461     const nmod_poly_struct * B, * C;
462     nmod_poly_struct * A, * t1, * t2;
463 
464     t1 = outputs + P->temp1loc;
465     t2 = outputs + P->temp2loc;
466 
467     for (i = 0; i < P->length; i++)
468     {
469         a = P->prog[i].a_idx;
470         b = P->prog[i].b_idx;
471         c = P->prog[i].c_idx;
472         FLINT_ASSERT(a >= 0);
473         A = outputs + a;
474         B = b < 0 ? inputs[-b-1] : outputs + b;
475         C = c < 0 ? inputs[-c-1] : outputs + c;
476 
477         FLINT_ASSERT(A->mod.n == P->prog[i].modulus->mod.n);
478         FLINT_ASSERT(B->mod.n == P->prog[i].modulus->mod.n);
479         FLINT_ASSERT(C->mod.n == P->prog[i].modulus->mod.n);
480 
481         /* A = B + I*(C - B) mod M */
482         nmod_poly_sub(t1, B, C);
483         nmod_poly_mul(t2, P->prog[i].idem, t1);
484         nmod_poly_sub(t1, B, t2);
485 
486         if (nmod_poly_degree(t1) < nmod_poly_degree(P->prog[i].modulus))
487         {
488             nmod_poly_swap(A, t1);
489         }
490         else
491         {
492             nmod_poly_rem(A, t1, P->prog[i].modulus);
493         }
494 
495         /* last calculation should write answer to outputs[0] */
496         if (i + 1 >= P->length)
497         {
498             FLINT_ASSERT(A == outputs + 0);
499         }
500     }
501 }
502