/*
    Copyright (C) 2020 Daniel Schultz

    This file is part of FLINT.

    FLINT is free software: you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License (LGPL) as published
    by the Free Software Foundation; either version 2.1 of the License, or
    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
*/
11 
12 #include "fmpz_mod_mpoly.h"
13 #include "mpn_extras.h"
14 
15 
/*
    Set A to the product of (Bcoeffs, Bexps, Blen) and (Ccoeffs, Cexps, Clen)
    using a heap (Johnson) multiplication, for the case where every exponent
    vector fits in a single word.

    Exponents of the product are formed as plain word sums
    Bexps[i] + Cexps[j] and compared through cmpmask by the heap routines.
    NOTE(review): this presumably relies on the caller having chosen a bit
    width for which these sums cannot overflow a field -- confirm at caller.

    The heap holds at most one entry per term of B, so the temporary space is
    proportional to Blen; the caller in this file passes the shorter
    polynomial as B.

    Requires Blen > 0 and Clen > 0 (term 0 of both inputs is read
    unconditionally).  The arrays of A are grown as needed and A->length is
    set on return; terms whose coefficient reduces to zero are not stored.
    The caller in this file never passes an A that aliases B or C.
*/
void _fmpz_mod_mpoly_mul_johnson1(
    fmpz_mod_mpoly_t A,
    const fmpz * Bcoeffs, const ulong * Bexps, slong Blen,
    const fmpz * Ccoeffs, const ulong * Cexps, slong Clen,
    ulong cmpmask,
    const fmpz_mod_ctx_t ctx)
{
    slong n = fmpz_size(fmpz_mod_ctx_modulus(ctx));
    slong i, j;
    slong next_loc;
    slong heap_len = 2; /* heap zero index unused */
    mpoly_heap1_s * heap;
    mpoly_heap_t * chain;
    slong * store, * store_base;
    mpoly_heap_t * x;
    slong * hind;
    ulong exp;
    fmpz * Acoeffs = A->coeffs;
    ulong * Aexps = A->exps;
    slong Alen;
    mpz_t t, acc, modulus;
    mp_limb_t * Bcoeffs_packed = NULL;
    mp_limb_t * Ccoeffs_packed = NULL;
    TMP_INIT;

    /* same preconditions as the multi-word version below */
    FLINT_ASSERT(Blen > 0);
    FLINT_ASSERT(Clen > 0);

    TMP_START;

    mpz_init(t);
    mpz_init(acc);
    fmpz_mod_ctx_get_modulus_mpz_read_only(modulus, ctx);

    next_loc = Blen + 4;   /* something bigger than heap can ever be */
    heap = (mpoly_heap1_s *) TMP_ALLOC((Blen + 1)*sizeof(mpoly_heap1_s));
    chain = (mpoly_heap_t *) TMP_ALLOC(Blen*sizeof(mpoly_heap_t));
    store = store_base = (slong *) TMP_ALLOC(2*Blen*sizeof(slong));
    hind = (slong *) TMP_ALLOC(Blen*sizeof(slong));

    /*
        hind[i] tracks the progress of row i (the products B[i]*C[j]):
        it is set to 2*(j + 1) when (i, j) is put on the heap and gets its
        low bit set when that pair is popped.  The initial value 1 means
        that nothing from row i has been inserted yet.
    */
    for (i = 0; i < Blen; i++)
        hind[i] = 1;

    /*
        For long enough inputs relative to the modulus size, copy all
        coefficients into fixed n-limb arrays so that products can be formed
        with mpn_mul_n instead of going through the fmpz/mpz interface.
        Both packed arrays live in one allocation.
    */
    if (Blen > 8*n)
    {
        Bcoeffs_packed = FLINT_ARRAY_ALLOC(n*(Blen + Clen), mp_limb_t);
        Ccoeffs_packed = Bcoeffs_packed + n*Blen;
        for (i = 0; i < Blen; i++)
            fmpz_get_ui_array(Bcoeffs_packed + n*i, n, Bcoeffs + i);
        for (i = 0; i < Clen; i++)
            fmpz_get_ui_array(Ccoeffs_packed + n*i, n, Ccoeffs + i);
    }

    /* put (0, 0, exp2[0] + exp3[0]) on heap */
    x = chain + 0;
    x->i = 0;
    x->j = 0;
    x->next = NULL;

    HEAP_ASSIGN(heap[1], Bexps[0] + Cexps[0], x);
    hind[0] = 2*1 + 0;

    Alen = 0;
    while (heap_len > 1)
    {
        /* the exponent at the top of the heap is the next output term */
        exp = heap[1].exp;

        _fmpz_mod_mpoly_fit_length(&Acoeffs, &A->coeffs_alloc,
                                   &Aexps, &A->exps_alloc, 1, Alen + 1);
        Aexps[Alen] = exp;

        if (Bcoeffs_packed)
        {
            mp_limb_t * acc_d, * t_d;
            slong acc_len;

            /* acc: 2n limbs of product plus one limb for carries */
            FLINT_MPZ_REALLOC(acc, 2*n+1);
            FLINT_MPZ_REALLOC(t, 2*n);
            acc_d = acc->_mp_d;
            t_d = t->_mp_d;

            flint_mpn_zero(acc_d, 2*n+1);
            /* pop every heap entry with this exponent and sum the products */
            do {
                x = _mpoly_heap_pop1(heap, &heap_len, cmpmask);
                do {
                    /* remember the pair so its successors can be inserted */
                    *store++ = x->i;
                    *store++ = x->j;
                    hind[x->i] |= WORD(1);
                    mpn_mul_n(t_d, Bcoeffs_packed + n*x->i,
                                   Ccoeffs_packed + n*x->j, n);
                    /* carry out of the low 2n limbs goes to the top limb */
                    acc_d[2*n] += mpn_add_n(acc_d, acc_d, t_d, 2*n);
                } while ((x = x->next) != NULL);
            } while (heap_len > 1 && heap[1].exp == exp);

            /* strip leading zero limbs and publish the size to the mpz */
            acc_len = 2*n+1;
            MPN_NORM(acc_d, acc_len);
            acc->_mp_size = acc_len;
        }
        else
        {
            mpz_set_ui(acc, 0);
            /* pop every heap entry with this exponent and sum the products */
            do {
                x = _mpoly_heap_pop1(heap, &heap_len, cmpmask);
                do {
                    fmpz Bi = Bcoeffs[x->i];
                    fmpz Cj = Ccoeffs[x->j];

                    /* remember the pair so its successors can be inserted */
                    *store++ = x->i;
                    *store++ = x->j;

                    hind[x->i] |= WORD(1);

                    /*
                        Dispatch on the representation of each factor.
                        Small values are passed on as unsigned words;
                        NOTE(review): presumably the input coefficients are
                        reduced and hence nonnegative -- confirm.
                    */
                    if (COEFF_IS_MPZ(Bi) && COEFF_IS_MPZ(Cj))
                    {
                        mpz_addmul(acc, COEFF_TO_PTR(Bi), COEFF_TO_PTR(Cj));
                    }
                    else if (COEFF_IS_MPZ(Bi) && !COEFF_IS_MPZ(Cj))
                    {
                        flint_mpz_addmul_ui(acc, COEFF_TO_PTR(Bi), Cj);
                    }
                    else if (!COEFF_IS_MPZ(Bi) && COEFF_IS_MPZ(Cj))
                    {
                        flint_mpz_addmul_ui(acc, COEFF_TO_PTR(Cj), Bi);
                    }
                    else
                    {
                        ulong pp1, pp0;
                        umul_ppmm(pp1, pp0, Bi, Cj);
                        flint_mpz_add_uiui(acc, acc, pp1, pp0);
                    }
                } while ((x = x->next) != NULL);
            } while (heap_len > 1 && heap[1].exp == exp);
        }

        /*
            Reduce the accumulated sum: the remainder is written straight
            into the output coefficient, the quotient in t is discarded.
            The term is kept only if the remainder is nonzero.
        */
        mpz_tdiv_qr(t, _fmpz_promote(Acoeffs + Alen), acc, modulus);
        _fmpz_demote_val(Acoeffs + Alen);
        Alen += !fmpz_is_zero(Acoeffs + Alen);

        /* for each pair just popped, insert its successors where allowed */
        while (store > store_base)
        {
            j = *--store;
            i = *--store;

            /* should we go right? */
            if ((i + 1 < Blen) &&
                (hind[i + 1] == 2*j + 1))
            {
                x = chain + i + 1;
                x->i = i + 1;
                x->j = j;
                x->next = NULL;

                hind[x->i] = 2*(x->j + 1) + 0;
                _mpoly_heap_insert1(heap, Bexps[x->i] + Cexps[x->j], x,
                                                 &next_loc, &heap_len, cmpmask);
            }

            /* should we go up? */
            if ((j + 1 < Clen) &&
                ((hind[i] & 1) == 1) &&
                ((i == 0) || (hind[i - 1] >= 2*(j + 2) + 1)))
            {
                x = chain + i;
                x->i = i;
                x->j = j + 1;
                x->next = NULL;

                hind[x->i] = 2*(x->j + 1) + 0;
                _mpoly_heap_insert1(heap, Bexps[x->i] + Cexps[x->j], x,
                                                 &next_loc, &heap_len, cmpmask);
            }
        }
    }

    /* the arrays may have been reallocated while growing */
    A->coeffs = Acoeffs;
    A->exps = Aexps;
    A->length = Alen;

    mpz_clear(t);
    mpz_clear(acc);
    flint_free(Bcoeffs_packed);

    TMP_END;
}
197 
198 
/*
    Set A to the product of (Bcoeffs, Bexps, Blen) and (Ccoeffs, Cexps, Clen)
    using a heap (Johnson) multiplication.  All exponent vectors, including
    those written to A, occupy N words each and are packed with the given
    number of bits per field; cmpmask (N words) is handed to the heap
    routines for comparing exponent vectors.

    Requires Blen > 0, Clen > 0 and A->bits == bits.  When N == 1 the work is
    delegated to _fmpz_mod_mpoly_mul_johnson1.

    The heap holds at most one entry per term of B, so the temporary space is
    proportional to Blen; the caller in this file passes the shorter
    polynomial as B.  The arrays of A are grown as needed and A->length is
    set on return; terms whose coefficient reduces to zero are not stored.
    The caller in this file never passes an A that aliases B or C.
*/
void _fmpz_mod_mpoly_mul_johnson(
    fmpz_mod_mpoly_t A,
    const fmpz * Bcoeffs, const ulong * Bexps, slong Blen,
    const fmpz * Ccoeffs, const ulong * Cexps, slong Clen,
    flint_bitcnt_t bits,
    slong N,
    const ulong * cmpmask,
    const fmpz_mod_ctx_t ctx)
{
    slong n = fmpz_size(fmpz_mod_ctx_modulus(ctx));
    slong i, j;
    slong next_loc;
    slong heap_len = 2; /* heap zero index unused */
    mpoly_heap_s * heap;
    mpoly_heap_t * chain;
    slong * store, * store_base;
    mpoly_heap_t * x;
    ulong * exps;
    ulong ** exp_list;
    slong exp_next;
    slong * hind;
    fmpz * Acoeffs = A->coeffs;
    ulong * Aexps = A->exps;
    slong Alen;
    mpz_t t, acc, modulus;
    mp_limb_t * Bcoeffs_packed = NULL;
    mp_limb_t * Ccoeffs_packed = NULL;
    TMP_INIT;

    FLINT_ASSERT(Blen > 0);
    FLINT_ASSERT(Clen > 0);
    FLINT_ASSERT(A->bits == bits);

    /* single-word exponents have a dedicated implementation */
    if (N == 1)
    {
        _fmpz_mod_mpoly_mul_johnson1(A, Bcoeffs, Bexps, Blen,
                                        Ccoeffs, Cexps, Clen, cmpmask[0], ctx);
        return;
    }

    TMP_START;

    mpz_init(t);
    mpz_init(acc);
    fmpz_mod_ctx_get_modulus_mpz_read_only(modulus, ctx);

    next_loc = Blen + 4;   /* something bigger than heap can ever be */
    heap = (mpoly_heap_s *) TMP_ALLOC((Blen + 1)*sizeof(mpoly_heap_s));
    chain = (mpoly_heap_t *) TMP_ALLOC(Blen*sizeof(mpoly_heap_t));
    store = store_base = (slong *) TMP_ALLOC(2*Blen*sizeof(slong));
    exps = (ulong *) TMP_ALLOC(Blen*N*sizeof(ulong));
    exp_list = (ulong **) TMP_ALLOC(Blen*sizeof(ulong *));
    hind = (slong *) TMP_ALLOC(Blen*sizeof(slong));

    /*
        exp_list is a stack of pointers to Blen scratch exponent vectors of
        N words each; entries at index >= exp_next are free for use.

        hind[i] tracks the progress of row i (the products B[i]*C[j]):
        it is set to 2*(j + 1) when (i, j) is put on the heap and gets its
        low bit set when that pair is popped.  The initial value 1 means
        that nothing from row i has been inserted yet.
    */
    for (i = 0; i < Blen; i++)
    {
        exp_list[i] = exps + i*N;
        hind[i] = 1;
    }

    /*
        For long enough inputs relative to the modulus size, copy all
        coefficients into fixed n-limb arrays so that products can be formed
        with mpn_mul_n instead of going through the fmpz/mpz interface.
        Both packed arrays live in one allocation.
    */
    if (Blen > 8*n)
    {
        Bcoeffs_packed = FLINT_ARRAY_ALLOC(n*(Blen + Clen), mp_limb_t);
        Ccoeffs_packed = Bcoeffs_packed + n*Blen;
        for (i = 0; i < Blen; i++)
            fmpz_get_ui_array(Bcoeffs_packed + n*i, n, Bcoeffs + i);
        for (i = 0; i < Clen; i++)
            fmpz_get_ui_array(Ccoeffs_packed + n*i, n, Ccoeffs + i);
    }

    /* start with no heap nodes and no exponent vectors in use */
    exp_next = 0;

    /* put (0, 0, exp2[0] + exp3[0]) on heap */
    x = chain + 0;
    x->i = 0;
    x->j = 0;
    x->next = NULL;

    heap[1].next = x;
    heap[1].exp = exp_list[exp_next++];

    mpoly_monomial_add_mp(heap[1].exp, Bexps + N*0, Cexps + N*0, N);

    hind[0] = 2*1 + 0;

    Alen = 0;
    while (heap_len > 1)
    {
        _fmpz_mod_mpoly_fit_length(&Acoeffs, &A->coeffs_alloc,
                                   &Aexps, &A->exps_alloc, N, Alen + 1);

        /* the exponent at the top of the heap is the next output term */
        mpoly_monomial_set(Aexps + N*Alen, heap[1].exp, N);

        if (Bcoeffs_packed)
        {
            mp_limb_t * acc_d, * t_d;
            slong acc_len;

            /* acc: 2n limbs of product plus one limb for carries */
            FLINT_MPZ_REALLOC(acc, 2*n+1);
            FLINT_MPZ_REALLOC(t, 2*n);
            acc_d = acc->_mp_d;
            t_d = t->_mp_d;

            flint_mpn_zero(acc_d, 2*n+1);
            /* pop every heap entry with this exponent and sum the products */
            do {
                /* hand the top entry's exponent vector back to the free stack */
                exp_list[--exp_next] = heap[1].exp;
                x = _mpoly_heap_pop(heap, &heap_len, N, cmpmask);
                do {
                    /* remember the pair so its successors can be inserted */
                    *store++ = x->i;
                    *store++ = x->j;
                    hind[x->i] |= WORD(1);
                    mpn_mul_n(t_d, Bcoeffs_packed + n*x->i,
                                   Ccoeffs_packed + n*x->j, n);
                    /* carry out of the low 2n limbs goes to the top limb */
                    acc_d[2*n] += mpn_add_n(acc_d, acc_d, t_d, 2*n);
                } while ((x = x->next) != NULL);
            } while (heap_len > 1 &&
                     mpoly_monomial_equal(heap[1].exp, Aexps + N*Alen, N));
            /* strip leading zero limbs and publish the size to the mpz */
            acc_len = 2*n+1;
            MPN_NORM(acc_d, acc_len);
            acc->_mp_size = acc_len;
        }
        else
        {
            mpz_set_ui(acc, 0);
            /* pop every heap entry with this exponent and sum the products */
            do {
                /* hand the top entry's exponent vector back to the free stack */
                exp_list[--exp_next] = heap[1].exp;
                x = _mpoly_heap_pop(heap, &heap_len, N, cmpmask);
                do {
                    fmpz Bi, Cj;

                    /* remember the pair so its successors can be inserted */
                    *store++ = x->i;
                    *store++ = x->j;

                    Bi = Bcoeffs[x->i];
                    Cj = Ccoeffs[x->j];

                    hind[x->i] |= WORD(1);

                    /*
                        Dispatch on the representation of each factor.
                        Small values are passed on as unsigned words;
                        NOTE(review): presumably the input coefficients are
                        reduced and hence nonnegative -- confirm.
                    */
                    if (COEFF_IS_MPZ(Bi) && COEFF_IS_MPZ(Cj))
                    {
                        mpz_addmul(acc, COEFF_TO_PTR(Bi), COEFF_TO_PTR(Cj));
                    }
                    else if (COEFF_IS_MPZ(Bi) && !COEFF_IS_MPZ(Cj))
                    {
                        flint_mpz_addmul_ui(acc, COEFF_TO_PTR(Bi), Cj);
                    }
                    else if (!COEFF_IS_MPZ(Bi) && COEFF_IS_MPZ(Cj))
                    {
                        flint_mpz_addmul_ui(acc, COEFF_TO_PTR(Cj), Bi);
                    }
                    else
                    {
                        ulong pp1, pp0;
                        umul_ppmm(pp1, pp0, Bi, Cj);
                        flint_mpz_add_uiui(acc, acc, pp1, pp0);
                    }
                } while ((x = x->next) != NULL);
            } while (heap_len > 1 &&
                     mpoly_monomial_equal(heap[1].exp, Aexps + N*Alen, N));
        }

        /*
            Reduce the accumulated sum: the remainder is written straight
            into the output coefficient, the quotient in t is discarded.
            The term is kept only if the remainder is nonzero.
        */
        mpz_tdiv_qr(t, _fmpz_promote(Acoeffs + Alen), acc, modulus);
        _fmpz_demote_val(Acoeffs + Alen);
        Alen += !fmpz_is_zero(Acoeffs + Alen);

        /*
            For each pair just popped, insert its successors where allowed.
            The candidate exponent is built in the next free scratch vector;
            exp_next advances by the return value of _mpoly_heap_insert
            (presumably 1 when the heap kept that vector, 0 otherwise --
            TODO confirm against the mpoly heap implementation).
        */
        while (store > store_base)
        {
            j = *--store;
            i = *--store;

            /* should we go right? */
            if ((i + 1 < Blen) &&
                (hind[i + 1] == 2*j + 1))
            {
                x = chain + i + 1;
                x->i = i + 1;
                x->j = j;
                x->next = NULL;

                hind[x->i] = 2*(x->j + 1) + 0;

                mpoly_monomial_add_mp(exp_list[exp_next], Bexps + N*x->i,
                                                          Cexps + N*x->j, N);

                exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
                                             &next_loc, &heap_len, N, cmpmask);
            }

            /* should we go up? */
            if ((j + 1 < Clen) &&
                ((hind[i] & 1) == 1) &&
                ((i == 0) || (hind[i - 1] >= 2*(j + 2) + 1)))
            {
                x = chain + i;
                x->i = i;
                x->j = j + 1;
                x->next = NULL;

                hind[x->i] = 2*(x->j + 1) + 0;

                mpoly_monomial_add_mp(exp_list[exp_next], Bexps + N*x->i,
                                                          Cexps + N*x->j, N);

                exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
                                             &next_loc, &heap_len, N, cmpmask);
            }
        }
    }

    /* the arrays may have been reallocated while growing */
    A->coeffs = Acoeffs;
    A->exps = Aexps;
    A->length = Alen;

    mpz_clear(t);
    mpz_clear(acc);
    flint_free(Bcoeffs_packed);

    TMP_END;
}
419 
/*
    Set A = B*C with the heap (Johnson) method, given the per-field maxima
    of the exponents of B and C in maxBfields and maxCfields
    (ctx->minfo->nfields entries each).

    B and C must both be nonzero.  A may alias B or C; in that case the
    product is built in a temporary and swapped into A.

    NOTE: maxBfields is overwritten with maxBfields + maxCfields; the sum is
    used to choose a bit width for the exponents of the product.
*/
void _fmpz_mod_mpoly_mul_johnson_maxfields(
    fmpz_mod_mpoly_t A,
    const fmpz_mod_mpoly_t B, fmpz * maxBfields,
    const fmpz_mod_mpoly_t C, fmpz * maxCfields,
    const fmpz_mod_mpoly_ctx_t ctx)
{
    slong N;
    flint_bitcnt_t Abits;
    ulong * cmpmask;
    ulong * Bexps = B->exps, * Cexps = C->exps;
    int freeBexps = 0, freeCexps = 0;
    fmpz_mod_mpoly_struct * P, T[1];
    TMP_INIT;

    FLINT_ASSERT(B->length > 0 && C->length > 0);

    TMP_START;

    /* field-wise bound on the exponents appearing in the product */
    _fmpz_vec_add(maxBfields, maxBfields, maxCfields, ctx->minfo->nfields);

    /* one bit more than the bound needs, and no narrower than either input */
    Abits = 1 + _fmpz_vec_max_bits(maxBfields, ctx->minfo->nfields);
    Abits = FLINT_MAX(Abits, B->bits);
    Abits = FLINT_MAX(Abits, C->bits);
    Abits = mpoly_fix_bits(Abits, ctx->minfo);

    N = mpoly_words_per_exp(Abits, ctx->minfo);
    cmpmask = (ulong *) TMP_ALLOC(N*sizeof(ulong));
    mpoly_get_cmpmask(cmpmask, N, Abits, ctx->minfo);

    /* ensure input exponents are packed into same sized fields as output */
    if (Abits != B->bits)
    {
        freeBexps = 1;
        Bexps = (ulong *) flint_malloc(N*B->length*sizeof(ulong));
        mpoly_repack_monomials(Bexps, Abits, B->exps, B->bits, B->length, ctx->minfo);
    }

    if (Abits != C->bits)
    {
        freeCexps = 1;
        Cexps = (ulong *) flint_malloc(N*C->length*sizeof(ulong));
        mpoly_repack_monomials(Cexps, Abits, C->exps, C->bits, C->length, ctx->minfo);
    }

    /* the kernel writes its output while reading the inputs: avoid aliasing */
    if (A == B || A == C)
    {
        fmpz_mod_mpoly_init(T, ctx);
        P = T;
    }
    else
    {
        P = A;
    }

    /* initial allocation only; the kernel grows the output as needed */
    fmpz_mod_mpoly_fit_length_reset_bits(P, B->length + C->length, Abits, ctx);

    /* pass the shorter polynomial first: the kernel's heap is sized by it */
    if (B->length > C->length)
    {
        _fmpz_mod_mpoly_mul_johnson(P, C->coeffs, Cexps, C->length,
                  B->coeffs, Bexps, B->length, Abits, N, cmpmask, ctx->ffinfo);
    }
    else
    {
        _fmpz_mod_mpoly_mul_johnson(P, B->coeffs, Bexps, B->length,
                  C->coeffs, Cexps, C->length, Abits, N, cmpmask, ctx->ffinfo);
    }

    if (A == B || A == C)
    {
        fmpz_mod_mpoly_swap(A, T, ctx);
        fmpz_mod_mpoly_clear(T, ctx);
    }

    /* release the repacked exponent arrays, if any were made */
    if (freeBexps)
        flint_free(Bexps);

    if (freeCexps)
        flint_free(Cexps);

    TMP_END;
}
501 
/*
    Set A = B*C using the heap (Johnson) multiplication.

    If either B or C has no terms, A is set to zero.  Otherwise the
    per-field exponent maxima of B and C are computed and the work is passed
    to _fmpz_mod_mpoly_mul_johnson_maxfields, which also handles the case
    where A aliases B or C.
*/
void fmpz_mod_mpoly_mul_johnson(
    fmpz_mod_mpoly_t A,
    const fmpz_mod_mpoly_t B,
    const fmpz_mod_mpoly_t C,
    const fmpz_mod_mpoly_ctx_t ctx)
{
    slong i;
    fmpz * maxBfields, * maxCfields;
    TMP_INIT;

    if (B->length < 1 || C->length < 1)
    {
        fmpz_mod_mpoly_zero(A, ctx);
        return;
    }

    TMP_START;

    /* one allocation holds both vectors: B's maxima first, then C's */
    maxBfields = TMP_ARRAY_ALLOC(2*ctx->minfo->nfields, fmpz);
    maxCfields = maxBfields + ctx->minfo->nfields;
    for (i = 0; i < 2*ctx->minfo->nfields; i++)
        fmpz_init(maxBfields + i);

    mpoly_max_fields_fmpz(maxBfields, B->exps, B->length, B->bits, ctx->minfo);
    mpoly_max_fields_fmpz(maxCfields, C->exps, C->length, C->bits, ctx->minfo);

    /* note: this call overwrites maxBfields */
    _fmpz_mod_mpoly_mul_johnson_maxfields(A, B, maxBfields, C, maxCfields, ctx);

    for (i = 0; i < 2*ctx->minfo->nfields; i++)
        fmpz_clear(maxBfields + i);

    TMP_END;
}
535 
536