1 /*
2     Copyright (C) 2017-2019 Daniel Schultz
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
10 */
11 
12 #include <string.h>
13 #include "thread_pool.h"
14 #include "fmpz_mpoly.h"
15 
/*
    Heap-based multiplication of B by C restricted to a stripe of product
    terms, for exponents stored as a single word per monomial.

    For each i in [0, Blen) only the products B[i]*C[j] with
    start[i] <= j < end[i] are formed. The output terms are written to
    (*A_coeff, *A_exp), which are grown with _fmpz_mpoly_fit_length as needed;
    the possibly reallocated pointers and the new allocation are stored back
    through A_coeff, A_exp and A_alloc. Terms whose coefficient sums to zero
    are dropped. The return value is the number of terms written.

    hind is caller-supplied workspace of Blen entries. While running,
    hind[i] = 2*(next j to be inserted for row i) + b, where b = 1 when row i
    currently has no node in the heap and b = 0 when it does.

    S supplies flint_small, cmpmask[0] and the scratch buffer big_mem, which
    is carved into the store stack, the heap and the chain nodes below.
    Clen is not referenced in the body.
*/
slong _fmpz_mpoly_mul_heap_part1(fmpz ** A_coeff, ulong ** A_exp, slong * A_alloc,
              const fmpz * Bcoeff, const ulong * Bexp, slong Blen,
              const fmpz * Ccoeff, const ulong * Cexp, slong Clen,
         slong * start, slong * end, slong * hind, const fmpz_mpoly_stripe_t S)
{
    const int flint_small = S->flint_small;
    const ulong cmpmask = S->cmpmask[0];
    slong i, j;
    ulong exp;
    mpoly_heap_t * x;
    slong next_loc;
    slong heap_len;
    mpoly_heap1_s * heap;
    mpoly_heap_t * chain;
    slong * store, * store_base;
    slong Alen;
    fmpz * Acoeff = *A_coeff;
    ulong * Aexp = *A_exp;
    slong Aalloc = *A_alloc;
    ulong acc[3], p[3];
    int first_prod;

    /*
        carve the scratch buffer: 2*Blen slongs for the (i, j) store stack,
        Blen + 1 heap entries (index 0 is unused) and Blen chain nodes
    */
    i = 0;
    store = store_base = (slong *) (S->big_mem + i);
    i += 2*Blen*sizeof(slong);
    heap = (mpoly_heap1_s *)(S->big_mem + i);
    i += (Blen + 1)*sizeof(mpoly_heap1_s);
    chain = (mpoly_heap_t *)(S->big_mem + i);
    i += Blen*sizeof(mpoly_heap_t);
    FLINT_ASSERT(i <= S->big_mem_alloc);

    /* put all the starting nodes on the heap */
    heap_len = 1; /* heap zero index unused */
    next_loc = Blen + 4;   /* something bigger than heap can ever be */
    for (i = 0; i < Blen; i++)
    {
        hind[i] = 2*start[i] + 1;
    }
    for (i = 0; i < Blen; i++)
    {
        /* row i is seeded only if it is nonempty and starts below row i - 1 */
        if (  (start[i] < end[i])
           && (  (i == 0)
              || (start[i] < start[i - 1])
              )
           )
        {
            x = chain + i;
            x->i = i;
            x->j = start[i];
            x->next = NULL;
            hind[x->i] = 2*(x->j + 1) + 0;
            _mpoly_heap_insert1(heap, Bexp[x->i] + Cexp[x->j], x,
                                                &next_loc, &heap_len, cmpmask);
        }
    }

    Alen = 0;
    while (heap_len > 1)
    {
        /* the exponent at the top of the heap becomes the next output term */
        exp = heap[1].exp;

        _fmpz_mpoly_fit_length(&Acoeff, &Aexp, &Aalloc, Alen + 1, 1);

        Aexp[Alen] = exp;

        acc[0] = acc[1] = acc[2] = 0;
        first_prod = 1;
        /* pop and accumulate every node that carries this exponent */
        while (heap_len > 1 && heap[1].exp == exp)
        {
            x = _mpoly_heap_pop1(heap, &heap_len, cmpmask);

            /* mark row as out of the heap and remember (i, j) for later */
            hind[x->i] |= WORD(1);
            *store++ = x->i;
            *store++ = x->j;

            if (flint_small)
            {
                /* accumulate signed two-word products in three words */
                smul_ppmm(p[1], p[0], Bcoeff[x->i], Ccoeff[x->j]);
                p[2] = FLINT_SIGN_EXT(p[1]);
                add_sssaaaaaa(acc[2], acc[1], acc[0], acc[2], acc[1], acc[0],
                                                         p[2], p[1], p[0]);
                first_prod = 0;
                while ((x = x->next) != NULL)
                {
                    smul_ppmm(p[1], p[0], Bcoeff[x->i], Ccoeff[x->j]);
                    p[2] = FLINT_SIGN_EXT(p[1]);
                    add_sssaaaaaa(acc[2], acc[1], acc[0], acc[2], acc[1], acc[0],
                                                             p[2], p[1], p[0]);
                    hind[x->i] |= WORD(1);
                    *store++ = x->i;
                    *store++ = x->j;
                }
            }
            else /* output coeffs require multiprecision */
            {
                /* first product overwrites the slot, later ones add to it */
                if (first_prod)
                    fmpz_mul(Acoeff + Alen, Bcoeff + x->i, Ccoeff + x->j);
                else
                    fmpz_addmul(Acoeff + Alen, Bcoeff + x->i, Ccoeff + x->j);
                first_prod = 0;
                while ((x = x->next) != NULL)
                {
                    fmpz_addmul(Acoeff + Alen, Bcoeff + x->i, Ccoeff + x->j);
                    hind[x->i] |= WORD(1);
                    *store++ = x->i;
                    *store++ = x->j;
                }
            }
        }

        /* for each node temporarily stored */
        while (store > store_base)
        {
            j = *--store;
            i = *--store;

            /*
                should we go right? i.e. insert (i + 1, j): only if row i + 1
                is idle and j is exactly its next unconsumed index
            */
            if (  (i + 1 < Blen)
               && (j + 0 < end[i + 1])
               && (hind[i + 1] == 2*j + 1)
               )
            {
                x = chain + i + 1;
                x->i = i + 1;
                x->j = j;
                x->next = NULL;
                hind[x->i] = 2*(x->j + 1) + 0;
                _mpoly_heap_insert1(heap, Bexp[x->i] + Cexp[x->j], x,
                                                &next_loc, &heap_len, cmpmask);
            }

            /*
                should we go up? i.e. insert (i, j + 1): only if row i is
                idle and row i - 1 (if any) has already popped index j + 1
            */
            if (  (j + 1 < end[i + 0])
               && ((hind[i] & 1) == 1)
               && (  (i == 0)
                  || (hind[i - 1] >= 2*(j + 2) + 1)
                  )
               )
            {
                x = chain + i;
                x->i = i;
                x->j = j + 1;
                x->next = NULL;
                hind[x->i] = 2*(x->j+1) + 0;
                _mpoly_heap_insert1(heap, Bexp[x->i] + Cexp[x->j], x,
                                                &next_loc, &heap_len, cmpmask);
            }
        }

        /* set output poly coeff from temporary accumulation, if not multiprec */
        if (flint_small)
        {
            fmpz_set_signed_uiuiui(Acoeff + Alen, acc[2], acc[1], acc[0]);
        }

        /* keep the term only if its coefficient did not cancel to zero */
        Alen += !fmpz_is_zero(Acoeff + Alen);
    }

    *A_coeff = Acoeff;
    *A_exp = Aexp;
    *A_alloc = Aalloc;
    return Alen;
}
179 
180 
/*
    Heap-based multiplication of B by C restricted to a stripe of product
    terms, for exponents stored as N = S->N words per monomial.

    For each i in [0, Blen) only the products B[i]*C[j] with
    start[i] <= j < end[i] are formed. The output terms are written to
    (*A_coeff, *A_exp), which are grown with _fmpz_mpoly_fit_length as needed;
    the possibly reallocated pointers and the new allocation are stored back
    through A_coeff, A_exp and A_alloc. Terms whose coefficient sums to zero
    are dropped. The return value is the number of terms written.

    hind is caller-supplied workspace of Blen entries. While running,
    hind[i] = 2*(next j to be inserted for row i) + b, where b = 1 when row i
    currently has no node in the heap and b = 0 when it does.

    S supplies flint_small, bits, N, cmpmask and the scratch buffer big_mem,
    which is carved into the store stack, a pool of Blen exponent buffers,
    the heap and the chain nodes below. Exponents are added with
    mpoly_monomial_add when bits <= FLINT_BITS and mpoly_monomial_add_mp
    otherwise. Clen is not referenced in the body.
*/
slong _fmpz_mpoly_mul_heap_part(fmpz ** A_coeff, ulong ** A_exp, slong * A_alloc,
                 const fmpz * Bcoeff, const ulong * Bexp, slong Blen,
                 const fmpz * Ccoeff, const ulong * Cexp, slong Clen,
         slong * start, slong * end, slong * hind, const fmpz_mpoly_stripe_t S)
{
    const int flint_small = S->flint_small;
    flint_bitcnt_t bits = S->bits;
    slong N = S->N;
    const ulong * cmpmask = S->cmpmask;
    slong i, j;
    ulong * exp, * exps;
    ulong ** exp_list;
    slong exp_next;
    mpoly_heap_t * x;
    slong next_loc;
    slong heap_len;
    mpoly_heap_s * heap;
    mpoly_heap_t * chain;
    slong * store, * store_base;
    slong Alen;
    ulong * Aexp = *A_exp;
    slong Aalloc = *A_alloc;
    fmpz * Acoeff = *A_coeff;
    ulong acc[3], p[3];
    int first_prod;

    /*
        carve the scratch buffer: 2*Blen slongs for the (i, j) store stack,
        Blen exponent pointers, Blen*N exponent words, Blen + 1 heap entries
        (index 0 is unused) and Blen chain nodes
    */
    i = 0;
    store = store_base = (slong *) (S->big_mem + i);
    i += 2*Blen*sizeof(slong);
    exp_list = (ulong **) (S->big_mem + i);
    i += Blen*sizeof(ulong *);
    exps = (ulong *) (S->big_mem + i);
    i += Blen*N*sizeof(ulong);
    heap = (mpoly_heap_s *) (S->big_mem + i);
    i += (Blen + 1)*sizeof(mpoly_heap_s);
    chain = (mpoly_heap_t *) (S->big_mem + i);
    i += Blen*sizeof(mpoly_heap_t);
    FLINT_ASSERT(i <= S->big_mem_alloc);

    /* put all the starting nodes on the heap */
    heap_len = 1; /* heap zero index unused */
    next_loc = Blen + 4;   /* something bigger than heap can ever be */
    /* exp_list[exp_next] is always the next free exponent buffer */
    exp_next = 0;
    for (i = 0; i < Blen; i++)
        exp_list[i] = exps + i*N;
    for (i = 0; i < Blen; i++)
        hind[i] = 2*start[i] + 1;
    for (i = 0; i < Blen; i++)
    {
        /* row i is seeded only if it is nonempty and starts below row i - 1 */
        if (  (start[i] < end[i])
           && (  (i == 0)
              || (start[i] < start[i - 1])
              )
           )
        {
            x = chain + i;
            x->i = i;
            x->j = start[i];
            x->next = NULL;
            hind[x->i] = 2*(x->j + 1) + 0;
            if (bits <= FLINT_BITS)
                mpoly_monomial_add(exp_list[exp_next], Bexp + x->i*N,
                                                       Cexp + x->j*N, N);
            else
                mpoly_monomial_add_mp(exp_list[exp_next], Bexp + x->i*N,
                                                          Cexp + x->j*N, N);

            /*
                the return value advances exp_next; presumably it is 1 when
                the heap kept the exponent buffer and 0 when the node was
                chained onto an existing entry -- confirm against
                _mpoly_heap_insert
            */
            exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
                                             &next_loc, &heap_len, N, cmpmask);
        }
    }

    Alen = 0;
    while (heap_len > 1)
    {
        /* the exponent at the top of the heap becomes the next output term */
        exp = heap[1].exp;

        _fmpz_mpoly_fit_length(&Acoeff, &Aexp, &Aalloc, Alen + 1, N);

        mpoly_monomial_set(Aexp + N*Alen, exp, N);

        acc[0] = acc[1] = acc[2] = 0;
        first_prod = 1;
        /* pop and accumulate every node that carries this exponent */
        while (heap_len > 1 && mpoly_monomial_equal(heap[1].exp, exp, N))
        {
            /* hand the exponent buffer of the popped entry back to the pool */
            exp_list[--exp_next] = heap[1].exp;

            x = _mpoly_heap_pop(heap, &heap_len, N, cmpmask);

            /* mark row as out of the heap and remember (i, j) for later */
            hind[x->i] |= WORD(1);
            *store++ = x->i;
            *store++ = x->j;

            /* if output coeffs will fit in three words */
            if (flint_small)
            {
                smul_ppmm(p[1], p[0], Bcoeff[x->i], Ccoeff[x->j]);
                p[2] = FLINT_SIGN_EXT(p[1]);
                add_sssaaaaaa(acc[2], acc[1], acc[0], acc[2], acc[1], acc[0],
                                                         p[2], p[1], p[0]);
                first_prod = 0;
                while ((x = x->next) != NULL)
                {
                    smul_ppmm(p[1], p[0], Bcoeff[x->i], Ccoeff[x->j]);
                    p[2] = FLINT_SIGN_EXT(p[1]);
                    add_sssaaaaaa(acc[2], acc[1], acc[0], acc[2], acc[1], acc[0],
                                                             p[2], p[1], p[0]);
                    hind[x->i] |= WORD(1);
                    *store++ = x->i;
                    *store++ = x->j;
                }
            }
            else /* output coeffs require multiprecision */
            {
                /* first product overwrites the slot, later ones add to it */
                if (first_prod)
                    fmpz_mul(Acoeff + Alen, Bcoeff + x->i, Ccoeff + x->j);
                else
                    fmpz_addmul(Acoeff + Alen, Bcoeff + x->i, Ccoeff + x->j);
                first_prod = 0;
                while ((x = x->next) != NULL)
                {
                    fmpz_addmul(Acoeff + Alen, Bcoeff + x->i, Ccoeff + x->j);
                    hind[x->i] |= WORD(1);
                    *store++ = x->i;
                    *store++ = x->j;
                }
            }
        }

        /* for each node temporarily stored */
        while (store > store_base)
        {
            j = *--store;
            i = *--store;

            /*
                should we go right? i.e. insert (i + 1, j): only if row i + 1
                is idle and j is exactly its next unconsumed index
            */
            if (  (i + 1 < Blen)
               && (j + 0 < end[i + 1])
               && (hind[i + 1] == 2*j + 1)
               )
            {
                x = chain + i + 1;
                x->i = i + 1;
                x->j = j;
                x->next = NULL;

                hind[x->i] = 2*(x->j + 1) + 0;

                if (bits <= FLINT_BITS)
                    mpoly_monomial_add(exp_list[exp_next], Bexp + x->i*N,
                                                           Cexp + x->j*N, N);
                else
                    mpoly_monomial_add_mp(exp_list[exp_next], Bexp + x->i*N,
                                                              Cexp + x->j*N, N);

                exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
                                             &next_loc, &heap_len, N, cmpmask);
            }

            /*
                should we go up? i.e. insert (i, j + 1): only if row i is
                idle and row i - 1 (if any) has already popped index j + 1
            */
            if (  (j + 1 < end[i + 0])
               && ((hind[i] & 1) == 1)
               && (  (i == 0)
                  || (hind[i - 1] >= 2*(j + 2) + 1)
                  )
               )
            {
                x = chain + i;
                x->i = i;
                x->j = j + 1;
                x->next = NULL;

                hind[x->i] = 2*(x->j + 1) + 0;

                if (bits <= FLINT_BITS)
                    mpoly_monomial_add(exp_list[exp_next], Bexp + x->i*N,
                                                           Cexp + x->j*N, N);
                else
                    mpoly_monomial_add_mp(exp_list[exp_next], Bexp + x->i*N,
                                                              Cexp + x->j*N, N);

                exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
                                             &next_loc, &heap_len, N, cmpmask);
            }
        }

        /* set output poly coeff from temporary accumulation, if not multiprec */
        if (flint_small)
        {
            fmpz_set_signed_uiuiui(Acoeff + Alen, acc[2], acc[1], acc[0]);
        }

        /* keep the term only if its coefficient did not cancel to zero */
        Alen += !fmpz_is_zero(Acoeff + Alen);
    }

    *A_coeff = Acoeff;
    *A_exp = Aexp;
    *A_alloc = Aalloc;
    return Alen;
}
381 
382 
/*
    The product terms are split into 4*n divisions, where n is the number of
    threads, and the workers calculate the terms of each division.
*/
387 
388 typedef struct
389 {
390     volatile int idx;
391 #if FLINT_USES_PTHREAD
392     pthread_mutex_t mutex;
393 #endif
394     slong nthreads;
395     slong ndivs;
396     fmpz * Acoeff;
397     ulong * Aexp;
398     const fmpz * Bcoeff;
399     const ulong * Bexp;
400     slong Blen;
401     const fmpz * Ccoeff;
402     const ulong * Cexp;
403     slong Clen;
404     slong N;
405     flint_bitcnt_t bits;
406     const ulong * cmpmask;
407     int flint_small;
408 }
409 _base_struct;
410 
411 typedef _base_struct _base_t[1];
412 
413 typedef struct
414 {
415     slong lower;
416     slong upper;
417     slong thread_idx;
418     slong Aoffset;
419     slong Alen;
420     slong Aalloc;
421     ulong * Aexp;
422     fmpz * Acoeff;
423 }
424 _div_struct;
425 
426 typedef struct
427 {
428     fmpz_mpoly_stripe_t S;
429     slong idx;
430     slong time;
431     _base_struct * base;
432     _div_struct * divs;
433 #if FLINT_USES_PTHREAD
434     pthread_mutex_t mutex;
435     pthread_cond_t cond;
436 #endif
437     slong * t1, * t2, * t3, * t4;
438     ulong * exp;
439 }
440 _worker_arg_struct;
441 
442 
/*
    Each worker repeatedly takes the next available division and calculates
    all of the product terms in that division.
*/
447 
/*
    Exchange two pointer lvalues. The expansion uses a temporary named tt,
    which must be declared with a compatible type in the enclosing scope.
*/
#define SWAP_PTRS(xx, yy) \
   do { \
      tt = (xx); \
      (xx) = (yy); \
      (yy) = tt; \
   } while (0)
454 
_fmpz_mpoly_mul_heap_threaded_worker(void * varg)455 static void _fmpz_mpoly_mul_heap_threaded_worker(void * varg)
456 {
457     _worker_arg_struct * arg = (_worker_arg_struct *) varg;
458     fmpz_mpoly_stripe_struct * S = arg->S;
459     _div_struct * divs = arg->divs;
460     _base_struct * base = arg->base;
461     slong Blen = base->Blen;
462     slong N = base->N;
463     slong i, j;
464     ulong *exp;
465     slong score;
466     slong *start, *end, *t1, *t2, *t3, *t4, *tt;
467 
468     exp = (ulong *) flint_malloc(N*sizeof(ulong));
469     t1 = (slong *) flint_malloc(Blen*sizeof(slong));
470     t2 = (slong *) flint_malloc(Blen*sizeof(slong));
471     t3 = (slong *) flint_malloc(Blen*sizeof(slong));
472     t4 = (slong *) flint_malloc(Blen*sizeof(slong));
473 
474     S->N = N;
475     S->bits = base->bits;
476     S->cmpmask = base->cmpmask;
477     S->flint_small = base->flint_small;
478 
479     S->big_mem_alloc = 0;
480     if (N == 1)
481     {
482         S->big_mem_alloc += 2*Blen*sizeof(slong);
483         S->big_mem_alloc += (Blen + 1)*sizeof(mpoly_heap1_s);
484         S->big_mem_alloc += Blen*sizeof(mpoly_heap_t);
485     }
486     else
487     {
488         S->big_mem_alloc += 2*Blen*sizeof(slong);
489         S->big_mem_alloc += (Blen + 1)*sizeof(mpoly_heap_s);
490         S->big_mem_alloc += Blen*sizeof(mpoly_heap_t);
491         S->big_mem_alloc += Blen*S->N*sizeof(ulong);
492         S->big_mem_alloc += Blen*sizeof(ulong *);
493     }
494     S->big_mem = (char *) flint_malloc(S->big_mem_alloc);
495 
496     /* get index to start working on */
497     if (arg->idx + 1 < base->nthreads)
498     {
499 #if FLINT_USES_PTHREAD
500         pthread_mutex_lock(&base->mutex);
501 #endif
502         i = base->idx - 1;
503         base->idx = i;
504 #if FLINT_USES_PTHREAD
505         pthread_mutex_unlock(&base->mutex);
506 #endif
507     }
508     else
509     {
510         i = base->ndivs - 1;
511     }
512 
513     while (i >= 0)
514     {
515         FLINT_ASSERT(divs[i].thread_idx == -WORD(1));
516         divs[i].thread_idx = arg->idx;
517 
518         /* calculate start */
519         if (i + 1 < base-> ndivs)
520         {
521             mpoly_search_monomials(
522                 &start, exp, &score, t1, t2, t3,
523                             divs[i].lower, divs[i].lower,
524                             base->Bexp, base->Blen, base->Cexp, base->Clen,
525                                           base->N, base->cmpmask);
526             if (start == t2)
527             {
528                 SWAP_PTRS(t1, t2);
529             }
530             else if (start == t3)
531             {
532                 SWAP_PTRS(t1, t3);
533             }
534         }
535         else
536         {
537             start = t1;
538             for (j = 0; j < base->Blen; j++)
539                 start[j] = 0;
540         }
541 
542         /* calculate end */
543         if (i > 0)
544         {
545             mpoly_search_monomials(
546                 &end, exp, &score, t2, t3, t4,
547                             divs[i - 1].lower, divs[i - 1].lower,
548                             base->Bexp, base->Blen, base->Cexp, base->Clen,
549                                           base->N, base->cmpmask);
550             if (end == t3)
551             {
552                 SWAP_PTRS(t2, t3);
553             }
554             else if (end == t4)
555             {
556                 SWAP_PTRS(t2, t4);
557             }
558         }
559         else
560         {
561             end = t2;
562             for (j = 0; j < base->Blen; j++)
563                 end[j] = base->Clen;
564         }
565         /* t3 and t4 are free for workspace at this point */
566 
567         /* join code assumes all divisions have been allocated */
568         _fmpz_mpoly_fit_length(&divs[i].Acoeff, &divs[i].Aexp, &divs[i].Aalloc,
569                                                                        256, N);
570         /* calculate products in [start, end) */
571         if (N == 1)
572         {
573             divs[i].Alen = _fmpz_mpoly_mul_heap_part1(
574                          &divs[i].Acoeff, &divs[i].Aexp, &divs[i].Aalloc,
575                               base->Bcoeff, base->Bexp, base->Blen,
576                               base->Ccoeff, base->Cexp, base->Clen,
577                                                           start, end, t3, S);
578         }
579         else
580         {
581             divs[i].Alen = _fmpz_mpoly_mul_heap_part(
582                          &divs[i].Acoeff, &divs[i].Aexp, &divs[i].Aalloc,
583                               base->Bcoeff, base->Bexp, base->Blen,
584                               base->Ccoeff, base->Cexp, base->Clen,
585                                                           start, end, t3, S);
586         }
587 
588         /* get next index to work on */
589 #if FLINT_USES_PTHREAD
590         pthread_mutex_lock(&base->mutex);
591 #endif
592 	i = base->idx - 1;
593         base->idx = i;
594 #if FLINT_USES_PTHREAD
595 	pthread_mutex_unlock(&base->mutex);
596 #endif
597     }
598 
599     flint_free(S->big_mem);
600     flint_free(t4);
601     flint_free(t3);
602     flint_free(t2);
603     flint_free(t1);
604     flint_free(exp);
605 }
606 
_join_worker(void * varg)607 static void _join_worker(void * varg)
608 {
609     _worker_arg_struct * arg = (_worker_arg_struct *) varg;
610     _div_struct * divs = arg->divs;
611     _base_struct * base = arg->base;
612     slong N = base->N;
613     slong i;
614 
615     for (i = base->ndivs - 2; i >= 0; i--)
616     {
617         FLINT_ASSERT(divs[i].thread_idx != -WORD(1));
618 
619         if (divs[i].thread_idx != arg->idx)
620             continue;
621 
622         FLINT_ASSERT(divs[i].Acoeff != NULL);
623         FLINT_ASSERT(divs[i].Aexp != NULL);
624 
625         memcpy(base->Acoeff + divs[i].Aoffset, divs[i].Acoeff,
626                                                     divs[i].Alen*sizeof(fmpz));
627 
628         memcpy(base->Aexp + N*divs[i].Aoffset, divs[i].Aexp,
629                                                  N*divs[i].Alen*sizeof(ulong));
630 
631         flint_free(divs[i].Acoeff);
632         flint_free(divs[i].Aexp);
633     }
634 }
635 
/*
    Set A to the product of (Bcoeff, Bexp, Blen) and (Ccoeff, Cexp, Clen)
    using the calling thread plus the num_handles thread pool handles.

    Exponents of both inputs must be packed with the given bits and N words
    per monomial; cmpmask is the matching comparison mask. The work is split
    into 4*(num_handles + 1) divisions. The highest division is written
    straight into A's existing arrays; the others are written to private
    arrays and afterwards copied behind it by _join_worker.

    If Blen*Clen does not fit in a nonnegative slong, the product is computed
    by _fmpz_mpoly_mul_johnson instead.

    The code asserts BClen/Blen == Clen, so Blen is expected to be nonzero.
*/
void _fmpz_mpoly_mul_heap_threaded(
    fmpz_mpoly_t A,
    const fmpz * Bcoeff, const ulong * Bexp, slong Blen,
    const fmpz * Ccoeff, const ulong * Cexp, slong Clen,
    flint_bitcnt_t bits,
    slong N,
    const ulong * cmpmask,
    const thread_pool_handle * handles,
    slong num_handles)
{
    slong i, j;
    slong BClen, hi;
    _base_t base;
    _div_struct * divs;
    _worker_arg_struct * args;
    slong Aalloc;
    slong Alen;
    fmpz * Acoeff;
    ulong * Aexp;

    /* bail if product of lengths overflows a word */
    umul_ppmm(hi, BClen, Blen, Clen);
    if (hi != 0 || BClen < 0)
    {
        Alen = _fmpz_mpoly_mul_johnson(&A->coeffs, &A->exps, &A->alloc,
                                         Bcoeff, Bexp, Blen,
                                         Ccoeff, Cexp, Clen, bits, N, cmpmask);
        _fmpz_mpoly_set_length(A, Alen, NULL);
        return;

    }

    base->nthreads = num_handles + 1;
    base->ndivs = base->nthreads*4;  /* number of divisions */
    base->Bcoeff = Bcoeff;
    base->Bexp = Bexp;
    base->Blen = Blen;
    base->Ccoeff = Ccoeff;
    base->Cexp = Cexp;
    base->Clen = Clen;
    base->bits = bits;
    base->N = N;
    base->cmpmask = cmpmask;
    base->idx = base->ndivs - 1;    /* decremented by worker threads */
    base->flint_small =   _fmpz_mpoly_fits_small(Bcoeff, Blen)
                       && _fmpz_mpoly_fits_small(Ccoeff, Clen);

    divs = (_div_struct *) flint_malloc(base->ndivs*sizeof(_div_struct));
    args = (_worker_arg_struct *) flint_malloc(base->nthreads
                                                  *sizeof(_worker_arg_struct));

    /* allocate space and set the boundary for each division */
    FLINT_ASSERT(BClen/Blen == Clen);
    for (i = base->ndivs - 1; i >= 0; i--)
    {
        double d = (double)(i + 1) / (double)(base->ndivs);

        /* divisions decrease in size so that no worker finishes too early */
        divs[i].lower = (d * d) * BClen;
        divs[i].lower = FLINT_MIN(divs[i].lower, BClen);
        divs[i].lower = FLINT_MAX(divs[i].lower, WORD(0));
        divs[i].upper = divs[i].lower;
        divs[i].Aoffset = -WORD(1);
        divs[i].thread_idx = -WORD(1);

        divs[i].Alen = 0;
        if (i == base->ndivs - 1)
        {
            /* highest division writes to original poly */
            divs[i].Aalloc = A->alloc;
            divs[i].Aexp = A->exps;
            divs[i].Acoeff = A->coeffs;
            /* must clear output coefficients before working in parallel */
            for (j = 0; j < A->length; j++)
               _fmpz_demote(A->coeffs + j);
        }
        else
        {
            /* lower divisions write to a new worker poly */
            divs[i].Aalloc = 0;
            divs[i].Aexp = NULL;
            divs[i].Acoeff = NULL;
        }
    }

    /* compute each chunk in parallel */
#if FLINT_USES_PTHREAD
    pthread_mutex_init(&base->mutex, NULL);
#endif
    for (i = 0; i < num_handles; i++)
    {
        args[i].idx = i;
        args[i].base = base;
        args[i].divs = divs;
        thread_pool_wake(global_thread_pool, handles[i], 0,
                               _fmpz_mpoly_mul_heap_threaded_worker, &args[i]);
    }
    /* the calling thread acts as the last worker */
    i = num_handles;
    args[i].idx = i;
    args[i].base = base;
    args[i].divs = divs;
    _fmpz_mpoly_mul_heap_threaded_worker(&args[i]);
    for (i = 0; i < num_handles; i++)
    {
        thread_pool_wait(global_thread_pool, handles[i]);
    }

    /* calculate and allocate space for final answer */
    i = base->ndivs - 1;
    Alen = divs[i].Alen;
    Acoeff = divs[i].Acoeff;
    Aexp = divs[i].Aexp;
    Aalloc = divs[i].Aalloc;
    /* lower divisions are laid out after the highest one, in descending order */
    for (i = base->ndivs - 2; i >= 0; i--)
    {
        divs[i].Aoffset = Alen;
        Alen += divs[i].Alen;
    }
    if (Alen > Aalloc)
    {
        Acoeff = (fmpz *) flint_realloc(Acoeff, Alen*sizeof(fmpz));
        Aexp = (ulong *) flint_realloc(Aexp, Alen*N*sizeof(ulong));
        Aalloc = Alen;
    }
    base->Acoeff = Acoeff;
    base->Aexp = Aexp;

    /* join answers */
    for (i = 0; i < num_handles; i++)
    {
        thread_pool_wake(global_thread_pool, handles[i], 0, _join_worker, &args[i]);
    }
    _join_worker(&args[num_handles]);
    for (i = 0; i < num_handles; i++)
    {
        thread_pool_wait(global_thread_pool, handles[i]);
    }

#if FLINT_USES_PTHREAD
    pthread_mutex_destroy(&base->mutex);
#endif

    flint_free(args);
    flint_free(divs);

    /* we should have managed to keep coefficients past length demoted */
    FLINT_ASSERT(Alen <= Aalloc);
#if FLINT_WANT_ASSERT
    for (i = Alen; i < Aalloc; i++)
    {
        FLINT_ASSERT(!COEFF_IS_MPZ(*(Acoeff + i)));
    }
#endif

    A->coeffs = Acoeff;
    A->exps = Aexp;
    A->alloc = Aalloc;
    A->length = Alen;
}
795 
796 
797 
798 /* maxBfields gets clobbered */
_fmpz_mpoly_mul_heap_threaded_pool_maxfields(fmpz_mpoly_t A,const fmpz_mpoly_t B,fmpz * maxBfields,const fmpz_mpoly_t C,fmpz * maxCfields,const fmpz_mpoly_ctx_t ctx,const thread_pool_handle * handles,slong num_handles)799 void _fmpz_mpoly_mul_heap_threaded_pool_maxfields(
800     fmpz_mpoly_t A,
801     const fmpz_mpoly_t B, fmpz * maxBfields,
802     const fmpz_mpoly_t C, fmpz * maxCfields,
803     const fmpz_mpoly_ctx_t ctx,
804     const thread_pool_handle * handles,
805     slong num_handles)
806 {
807     slong N;
808     flint_bitcnt_t exp_bits;
809     ulong * cmpmask;
810     ulong * Bexp, * Cexp;
811     int freeBexp, freeCexp;
812     TMP_INIT;
813 
814     TMP_START;
815 
816     _fmpz_vec_add(maxBfields, maxBfields, maxCfields, ctx->minfo->nfields);
817 
818     exp_bits = _fmpz_vec_max_bits(maxBfields, ctx->minfo->nfields);
819     exp_bits = FLINT_MAX(MPOLY_MIN_BITS, exp_bits + 1);
820     exp_bits = FLINT_MAX(exp_bits, B->bits);
821     exp_bits = FLINT_MAX(exp_bits, C->bits);
822     exp_bits = mpoly_fix_bits(exp_bits, ctx->minfo);
823 
824     N = mpoly_words_per_exp(exp_bits, ctx->minfo);
825     cmpmask = (ulong *) TMP_ALLOC(N*sizeof(ulong));
826     mpoly_get_cmpmask(cmpmask, N, exp_bits, ctx->minfo);
827 
828     /* ensure input exponents are packed into same sized fields as output */
829     freeBexp = 0;
830     Bexp = B->exps;
831     if (exp_bits > B->bits)
832     {
833         freeBexp = 1;
834         Bexp = (ulong *) flint_malloc(N*B->length*sizeof(ulong));
835         mpoly_repack_monomials(Bexp, exp_bits, B->exps, B->bits,
836                                                         B->length, ctx->minfo);
837     }
838 
839     freeCexp = 0;
840     Cexp = C->exps;
841     if (exp_bits > C->bits)
842     {
843         freeCexp = 1;
844         Cexp = (ulong *) flint_malloc(N*C->length*sizeof(ulong));
845         mpoly_repack_monomials(Cexp, exp_bits, C->exps, C->bits,
846                                                         C->length, ctx->minfo);
847     }
848 
849     /* deal with aliasing and do multiplication */
850     if (A == B || A == C)
851     {
852         fmpz_mpoly_t T;
853         fmpz_mpoly_init3(T, 0, exp_bits, ctx);
854 
855         /* algorithm more efficient if smaller poly first */
856         if (B->length >= C->length)
857         {
858             _fmpz_mpoly_mul_heap_threaded(T, C->coeffs, Cexp, C->length,
859                                              B->coeffs, Bexp, B->length,
860                                    exp_bits, N, cmpmask, handles, num_handles);
861         }
862         else
863         {
864             _fmpz_mpoly_mul_heap_threaded(T, B->coeffs, Bexp, B->length,
865                                              C->coeffs, Cexp, C->length,
866                                    exp_bits, N, cmpmask, handles, num_handles);
867         }
868 
869         fmpz_mpoly_swap(T, A, ctx);
870         fmpz_mpoly_clear(T, ctx);
871     }
872     else
873     {
874         fmpz_mpoly_fit_length_reset_bits(A, A->length, exp_bits, ctx);
875 
876         /* algorithm more efficient if smaller poly first */
877         if (B->length > C->length)
878         {
879             _fmpz_mpoly_mul_heap_threaded(A, C->coeffs, Cexp, C->length,
880                                              B->coeffs, Bexp, B->length,
881                                    exp_bits, N, cmpmask, handles, num_handles);
882         }
883         else
884         {
885             _fmpz_mpoly_mul_heap_threaded(A, B->coeffs, Bexp, B->length,
886                                              C->coeffs, Cexp, C->length,
887                                    exp_bits, N, cmpmask, handles, num_handles);
888         }
889     }
890 
891     if (freeBexp)
892         flint_free(Bexp);
893 
894     if (freeCexp)
895         flint_free(Cexp);
896 
897     TMP_END;
898 }
899 
900 
/*
    Set A to B*C, using up to min(B->length, C->length)/16 threads requested
    through flint_request_threads in addition to the calling thread.

    If either input is zero, A is set to zero and no threads are requested.
    Otherwise the per-field values of B and C are computed with
    mpoly_max_fields_fmpz and the work is delegated to
    _fmpz_mpoly_mul_heap_threaded_pool_maxfields.
*/
void fmpz_mpoly_mul_heap_threaded(
    fmpz_mpoly_t A,
    const fmpz_mpoly_t B,
    const fmpz_mpoly_t C,
    const fmpz_mpoly_ctx_t ctx)
{
    slong i;
    fmpz * maxBfields, * maxCfields;
    thread_pool_handle * handles;
    slong num_handles;
    /*
        The limit must depend on the two inputs. A is the output and its
        length on entry is unrelated to the amount of work (it is zero for a
        freshly initialised polynomial), so it must not be used here.
    */
    slong thread_limit = FLINT_MIN(B->length, C->length)/16;
    TMP_INIT;

    if (B->length == 0 || C->length == 0)
    {
        fmpz_mpoly_zero(A, ctx);
        return;
    }

    TMP_START;

    maxBfields = (fmpz *) TMP_ALLOC(ctx->minfo->nfields*sizeof(fmpz));
    maxCfields = (fmpz *) TMP_ALLOC(ctx->minfo->nfields*sizeof(fmpz));
    for (i = 0; i < ctx->minfo->nfields; i++)
    {
        fmpz_init(maxBfields + i);
        fmpz_init(maxCfields + i);
    }
    mpoly_max_fields_fmpz(maxBfields, B->exps, B->length, B->bits, ctx->minfo);
    mpoly_max_fields_fmpz(maxCfields, C->exps, C->length, C->bits, ctx->minfo);

    num_handles = flint_request_threads(&handles, thread_limit);

    _fmpz_mpoly_mul_heap_threaded_pool_maxfields(A, B, maxBfields, C, maxCfields,
                                                    ctx, handles, num_handles);

    flint_give_back_threads(handles, num_handles);

    for (i = 0; i < ctx->minfo->nfields; i++)
    {
        fmpz_clear(maxBfields + i);
        fmpz_clear(maxCfields + i);
    }

    TMP_END;
}
947