1 /*
2     Copyright (C) 2017 Daniel Schultz
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <http://www.gnu.org/licenses/>.
10 */
11 
12 #include "nmod_mpoly.h"
13 
14 
/*
    Heap-based multiplication for exponents stored in a single word per term.

    Multiplies (coeff2, exp2, len2) by (coeff3, exp3, len3) and writes the
    terms with nonzero coefficient to (*coeff1, *exp1), growing the output
    through _nmod_mpoly_fit_length as needed (*alloc is updated by that call).
    The possibly reallocated output pointers are written back to *coeff1 and
    *exp1 before returning. The return value is the number of terms written.

    The heap and all scratch arrays are sized by len2. exp2[0] and exp3[0] are
    read unconditionally, so both inputs must have at least one term.
    maskhi is forwarded unchanged to the heap pop/insert routines.
*/
slong _nmod_mpoly_mul_johnson1(mp_limb_t ** coeff1, ulong ** exp1, slong * alloc,
              const mp_limb_t * coeff2, const ulong * exp2, slong len2,
              const mp_limb_t * coeff3, const ulong * exp3, slong len3,
                                          ulong maskhi, const nmodf_ctx_t fctx)
{
    slong k;
    slong bi, ci;               /* term indices into input 2 and input 3 */
    slong next_loc;
    slong heap_len = 2;         /* heap zero index unused */
    slong store_len = 0;
    slong out_len = 0;
    mpoly_heap1_s * heap;
    mpoly_heap_t * chain;
    mpoly_heap_t * node;
    slong * store;
    slong * hind;
    mp_limb_t * out_coeffs = *coeff1;
    ulong * out_exps = *exp1;
    ulong cur_exp;
    ulong acc0, acc1, acc2, pp0, pp1;
    TMP_INIT;

    TMP_START;

    /* something bigger than heap can ever be */
    next_loc = len2 + 4;

    heap = (mpoly_heap1_s *) TMP_ALLOC((len2 + 1)*sizeof(mpoly_heap1_s));
    chain = (mpoly_heap_t *) TMP_ALLOC(len2*sizeof(mpoly_heap_t));
    store = (slong *) TMP_ALLOC(2*len2*sizeof(slong));

    /*
        hind[i] is set to the even value 2*(j + 1) when node (i, j) is put on
        the heap and gets its low bit set when that node is popped
    */
    hind = (slong *) TMP_ALLOC(len2*sizeof(slong));
    for (k = 0; k < len2; k++)
        hind[k] = 1;

    /* seed the heap with (0, 0, exp2[0] + exp3[0]) */
    node = chain + 0;
    node->i = 0;
    node->j = 0;
    node->next = NULL;

    HEAP_ASSIGN(heap[1], exp2[0] + exp3[0], node);
    hind[0] = 2*1 + 0;

    while (heap_len > 1)
    {
        /* the exponent at the top of the heap is the next output exponent */
        cur_exp = heap[1].exp;

        _nmod_mpoly_fit_length(&out_coeffs, &out_exps, alloc, out_len + 1, 1);

        out_exps[out_len] = cur_exp;

        /* accumulate every product with this exponent in three limbs */
        acc0 = acc1 = acc2 = 0;
        do
        {
            node = _mpoly_heap_pop1(heap, &heap_len, maskhi);

            /* walk the popped node and everything chained behind it */
            do
            {
                hind[node->i] |= WORD(1);
                store[store_len++] = node->i;
                store[store_len++] = node->j;
                umul_ppmm(pp1, pp0, coeff2[node->i], coeff3[node->j]);
                add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, WORD(0), pp1, pp0);
                node = node->next;
            } while (node != NULL);
        } while (heap_len > 1 && heap[1].exp == cur_exp);

        NMOD_RED3(out_coeffs[out_len], acc2, acc1, acc0, fctx->mod);

        /* a zero coefficient leaves the slot to be overwritten next round */
        if (out_coeffs[out_len] != 0)
            out_len++;

        /* for each (i, j) just consumed, decide which neighbours to push */
        while (store_len > 0)
        {
            store_len -= 2;
            bi = store[store_len];
            ci = store[store_len + 1];

            /* should we go right? */
            if (bi + 1 < len2 && hind[bi + 1] == 2*ci + 1)
            {
                node = chain + (bi + 1);
                node->i = bi + 1;
                node->j = ci;
                node->next = NULL;

                hind[bi + 1] = 2*(ci + 1) + 0;
                _mpoly_heap_insert1(heap, exp2[bi + 1] + exp3[ci], node,
                                                 &next_loc, &heap_len, maskhi);
            }

            /* should we go up? */
            if (ci + 1 < len3
                && (hind[bi] & 1) == 1
                && (bi == 0 || hind[bi - 1] >= 2*(ci + 2) + 1))
            {
                node = chain + bi;
                node->i = bi;
                node->j = ci + 1;
                node->next = NULL;

                hind[bi] = 2*(ci + 2) + 0;
                _mpoly_heap_insert1(heap, exp2[bi] + exp3[ci + 1], node,
                                                 &next_loc, &heap_len, maskhi);
            }
        }
    }

    /* hand the (possibly reallocated) output arrays back to the caller */
    (* coeff1) = out_coeffs;
    (* exp1) = out_exps;

    TMP_END;

    return out_len;
}
136 
137 
/*
    Heap-based multiplication for exponent vectors of N words per term.

    Multiplies (coeff2, exp2, len2) by (coeff3, exp3, len3) and writes the
    terms with nonzero coefficient to (*coeff1, *exp1), growing the output
    through _nmod_mpoly_fit_length as needed (*alloc is updated by that call).
    The possibly reallocated output pointers are written back to *coeff1 and
    *exp1 before returning. The return value is the number of terms written.

    When N == 1 the work is delegated to _nmod_mpoly_mul_johnson1 with
    cmpmask[0] as its mask. Otherwise exponent sums are formed with
    mpoly_monomial_add when bits <= FLINT_BITS and with mpoly_monomial_add_mp
    when bits is larger; cmpmask is forwarded to the heap pop/insert routines.

    The heap and all scratch arrays are sized by len2. The first term of each
    input is read unconditionally, so both inputs must have at least one term.
*/
slong _nmod_mpoly_mul_johnson(mp_limb_t ** coeff1, ulong ** exp1, slong * alloc,
                 const mp_limb_t * coeff2, const ulong * exp2, slong len2,
                 const mp_limb_t * coeff3, const ulong * exp3, slong len3,
      flint_bitcnt_t bits, slong N, const ulong * cmpmask, const nmodf_ctx_t fctx)
{
    slong k;
    slong bi, ci;               /* term indices into input 2 and input 3 */
    slong next_loc;
    slong heap_len = 2;         /* heap zero index unused */
    slong store_len = 0;
    slong out_len = 0;
    slong exp_next;
    mpoly_heap_s * heap;
    mpoly_heap_t * chain;
    mpoly_heap_t * node;
    slong * store;
    slong * hind;
    mp_limb_t * out_coeffs = *coeff1;
    ulong * out_exps = *exp1;
    ulong * cur_exp, * exps;
    ulong ** exp_list;
    ulong acc0, acc1, acc2, pp0, pp1;
    TMP_INIT;

    /* single-word exponents have a dedicated routine */
    if (N == 1)
    {
        return _nmod_mpoly_mul_johnson1(coeff1, exp1, alloc,
                     coeff2, exp2, len2, coeff3, exp3, len3, cmpmask[0], fctx);
    }

    TMP_START;

    /* something bigger than heap can ever be */
    next_loc = len2 + 4;

    heap = (mpoly_heap_s *) TMP_ALLOC((len2 + 1)*sizeof(mpoly_heap_s));
    chain = (mpoly_heap_t *) TMP_ALLOC(len2*sizeof(mpoly_heap_t));
    store = (slong *) TMP_ALLOC(2*len2*sizeof(slong));

    /*
        pool of len2 exponent vectors of N words each; exp_list[exp_next] is
        the next unused vector and entries below exp_next are handed out
    */
    exps = (ulong *) TMP_ALLOC(len2*N*sizeof(ulong));
    exp_list = (ulong **) TMP_ALLOC(len2*sizeof(ulong *));
    for (k = 0; k < len2; k++)
        exp_list[k] = exps + k*N;

    /*
        hind[i] is set to the even value 2*(j + 1) when node (i, j) is put on
        the heap and gets its low bit set when that node is popped
    */
    hind = (slong *) TMP_ALLOC(len2*sizeof(slong));
    for (k = 0; k < len2; k++)
        hind[k] = 1;

    /* start with no heap nodes and no exponent vectors in use */
    exp_next = 0;

    /* seed the heap with (0, 0, exp2[0] + exp3[0]) */
    node = chain + 0;
    node->i = 0;
    node->j = 0;
    node->next = NULL;

    heap[1].next = node;
    heap[1].exp = exp_list[exp_next++];

    if (bits <= FLINT_BITS)
        mpoly_monomial_add(heap[1].exp, exp2, exp3, N);
    else
        mpoly_monomial_add_mp(heap[1].exp, exp2, exp3, N);

    hind[0] = 2*1 + 0;

    while (heap_len > 1)
    {
        /* the exponent at the top of the heap is the next output exponent */
        cur_exp = heap[1].exp;

        _nmod_mpoly_fit_length(&out_coeffs, &out_exps, alloc, out_len + 1, N);

        mpoly_monomial_set(out_exps + out_len*N, cur_exp, N);

        /* accumulate every product with this exponent in three limbs */
        acc0 = acc1 = acc2 = 0;
        do
        {
            /* give the top entry's exponent vector back to the pool */
            exp_list[--exp_next] = heap[1].exp;

            node = _mpoly_heap_pop(heap, &heap_len, N, cmpmask);

            /* walk the popped node and everything chained behind it */
            do
            {
                hind[node->i] |= WORD(1);
                store[store_len++] = node->i;
                store[store_len++] = node->j;
                umul_ppmm(pp1, pp0, coeff2[node->i], coeff3[node->j]);
                add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, WORD(0), pp1, pp0);
                node = node->next;
            } while (node != NULL);
        } while (heap_len > 1 && mpoly_monomial_equal(heap[1].exp, cur_exp, N));

        NMOD_RED3(out_coeffs[out_len], acc2, acc1, acc0, fctx->mod);

        /* a zero coefficient leaves the slot to be overwritten next round */
        if (out_coeffs[out_len] != 0)
            out_len++;

        /* for each (i, j) just consumed, decide which neighbours to push */
        while (store_len > 0)
        {
            /* take node from store */
            store_len -= 2;
            bi = store[store_len];
            ci = store[store_len + 1];

            /* should we go right? */
            if (bi + 1 < len2 && hind[bi + 1] == 2*ci + 1)
            {
                node = chain + (bi + 1);
                node->i = bi + 1;
                node->j = ci;
                node->next = NULL;

                hind[bi + 1] = 2*(ci + 1) + 0;

                if (bits <= FLINT_BITS)
                {
                    mpoly_monomial_add(exp_list[exp_next],
                                       exp2 + (bi + 1)*N, exp3 + ci*N, N);
                }
                else
                {
                    mpoly_monomial_add_mp(exp_list[exp_next],
                                          exp2 + (bi + 1)*N, exp3 + ci*N, N);
                }

                /* the vector stays in use only if the insert reports nonzero */
                if (_mpoly_heap_insert(heap, exp_list[exp_next], node,
                                       &next_loc, &heap_len, N, cmpmask))
                {
                    exp_next++;
                }
            }

            /* should we go up? */
            if (ci + 1 < len3
                && (hind[bi] & 1) == 1
                && (bi == 0 || hind[bi - 1] >= 2*(ci + 2) + 1))
            {
                node = chain + bi;
                node->i = bi;
                node->j = ci + 1;
                node->next = NULL;

                hind[bi] = 2*(ci + 2) + 0;

                if (bits <= FLINT_BITS)
                {
                    mpoly_monomial_add(exp_list[exp_next],
                                       exp2 + bi*N, exp3 + (ci + 1)*N, N);
                }
                else
                {
                    mpoly_monomial_add_mp(exp_list[exp_next],
                                          exp2 + bi*N, exp3 + (ci + 1)*N, N);
                }

                /* the vector stays in use only if the insert reports nonzero */
                if (_mpoly_heap_insert(heap, exp_list[exp_next], node,
                                       &next_loc, &heap_len, N, cmpmask))
                {
                    exp_next++;
                }
            }
        }
    }

    /* hand the (possibly reallocated) output arrays back to the caller */
    (* coeff1) = out_coeffs;
    (* exp1) = out_exps;

    TMP_END;

    return out_len;
}
295 
296 /* maxBfields gets clobbered */
_nmod_mpoly_mul_johnson_maxfields(nmod_mpoly_t A,const nmod_mpoly_t B,fmpz * maxBfields,const nmod_mpoly_t C,fmpz * maxCfields,const nmod_mpoly_ctx_t ctx)297 void _nmod_mpoly_mul_johnson_maxfields(
298     nmod_mpoly_t A,
299     const nmod_mpoly_t B, fmpz * maxBfields,
300     const nmod_mpoly_t C, fmpz * maxCfields,
301     const nmod_mpoly_ctx_t ctx)
302 {
303     slong N;
304     flint_bitcnt_t Abits;
305     ulong * cmpmask;
306     ulong * Bexp, * Cexp;
307     int freeBexp, freeCexp;
308     TMP_INIT;
309 
310     TMP_START;
311 
312     _fmpz_vec_add(maxBfields, maxBfields, maxCfields, ctx->minfo->nfields);
313 
314     Abits = _fmpz_vec_max_bits(maxBfields, ctx->minfo->nfields);
315     Abits = FLINT_MAX(MPOLY_MIN_BITS, Abits + 1);
316     Abits = FLINT_MAX(Abits, B->bits);
317     Abits = FLINT_MAX(Abits, C->bits);
318     Abits = mpoly_fix_bits(Abits, ctx->minfo);
319 
320     N = mpoly_words_per_exp(Abits, ctx->minfo);
321     cmpmask = (ulong*) TMP_ALLOC(N*sizeof(ulong));
322     mpoly_get_cmpmask(cmpmask, N, Abits, ctx->minfo);
323 
324     /* ensure input exponents are packed into same sized fields as output */
325     freeBexp = 0;
326     Bexp = B->exps;
327     if (Abits > B->bits)
328     {
329         freeBexp = 1;
330         Bexp = (ulong *) flint_malloc(N*B->length*sizeof(ulong));
331         mpoly_repack_monomials(Bexp, Abits, B->exps, B->bits,
332                                                         B->length, ctx->minfo);
333     }
334 
335     freeCexp = 0;
336     Cexp = C->exps;
337     if (Abits > C->bits)
338     {
339         freeCexp = 1;
340         Cexp = (ulong *) flint_malloc(N*C->length*sizeof(ulong));
341         mpoly_repack_monomials(Cexp, Abits, C->exps, C->bits,
342                                                         C->length, ctx->minfo);
343     }
344 
345     /* deal with aliasing and do multiplication */
346     if (A == B || A == C)
347     {
348         nmod_mpoly_t T;
349         nmod_mpoly_init2(T, B->length + C->length - 1, ctx);
350         nmod_mpoly_fit_bits(T, Abits, ctx);
351         T->bits = Abits;
352 
353         if (B->length > C->length)
354         {
355             T->length = _nmod_mpoly_mul_johnson(&T->coeffs, &T->exps, &T->alloc,
356                                                   C->coeffs, Cexp, C->length,
357                                                   B->coeffs, Bexp, B->length,
358                                                Abits, N, cmpmask, ctx->ffinfo);
359         }
360         else
361         {
362             T->length = _nmod_mpoly_mul_johnson(&T->coeffs, &T->exps, &T->alloc,
363                                                   C->coeffs, Cexp, C->length,
364                                                   B->coeffs, Bexp, B->length,
365                                                Abits, N, cmpmask, ctx->ffinfo);
366         }
367 
368         nmod_mpoly_swap(T, A, ctx);
369         nmod_mpoly_clear(T, ctx);
370     }
371     else
372     {
373         nmod_mpoly_fit_length(A, B->length + C->length - 1, ctx);
374         nmod_mpoly_fit_bits(A, Abits, ctx);
375         A->bits = Abits;
376 
377         if (B->length > C->length)
378         {
379             A->length = _nmod_mpoly_mul_johnson(&A->coeffs, &A->exps, &A->alloc,
380                                                   C->coeffs, Cexp, C->length,
381                                                   B->coeffs, Bexp, B->length,
382                                                Abits, N, cmpmask, ctx->ffinfo);
383         }
384         else
385         {
386             A->length = _nmod_mpoly_mul_johnson(&A->coeffs, &A->exps, &A->alloc,
387                                                   C->coeffs, Cexp, C->length,
388                                                   B->coeffs, Bexp, B->length,
389                                                Abits, N, cmpmask, ctx->ffinfo);
390         }
391     }
392 
393     if (freeBexp)
394         flint_free(Bexp);
395 
396     if (freeCexp)
397         flint_free(Cexp);
398 
399     TMP_END;
400 }
401 
/*
    Set A to the product of B and C.

    A zero B or C gives A = 0 immediately. Otherwise the per-field maxima of
    the exponents of B and C are computed with mpoly_max_fields_fmpz and the
    multiplication is carried out by _nmod_mpoly_mul_johnson_maxfields.
*/
void nmod_mpoly_mul_johnson(
    nmod_mpoly_t A,
    const nmod_mpoly_t B,
    const nmod_mpoly_t C,
    const nmod_mpoly_ctx_t ctx)
{
    slong k;
    slong nfields;
    fmpz * Bmax, * Cmax;
    TMP_INIT;

    /* a zero factor means a zero product; no scratch space is needed */
    if (B->length == 0 || C->length == 0)
    {
        nmod_mpoly_zero(A, ctx);
        return;
    }

    nfields = ctx->minfo->nfields;

    TMP_START;

    /* one fmpz per field for each input */
    Bmax = (fmpz *) TMP_ALLOC(nfields*sizeof(fmpz));
    Cmax = (fmpz *) TMP_ALLOC(nfields*sizeof(fmpz));

    for (k = 0; k < nfields; k++)
        fmpz_init(Bmax + k);

    for (k = 0; k < nfields; k++)
        fmpz_init(Cmax + k);

    mpoly_max_fields_fmpz(Bmax, B->exps, B->length, B->bits, ctx->minfo);
    mpoly_max_fields_fmpz(Cmax, C->exps, C->length, C->bits, ctx->minfo);

    _nmod_mpoly_mul_johnson_maxfields(A, B, Bmax, C, Cmax, ctx);

    for (k = 0; k < nfields; k++)
        fmpz_clear(Bmax + k);

    for (k = 0; k < nfields; k++)
        fmpz_clear(Cmax + k);

    TMP_END;
}
440