1 /*
2     Copyright (C) 2017-2019 Daniel Schultz
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
10 */
11 
12 #include <string.h>
13 #include <gmp.h>
14 #include <stdlib.h>
15 #include "thread_pool.h"
16 #include "nmod_mpoly.h"
17 #include "long_extras.h"
18 
19 
20 /*
21     set A = the part of B*C with exps in [start, end)
22     this functions reallocates A and returns the length of A
23     version for N == 1
24 */
_nmod_mpoly_mul_heap_part1(nmod_mpoly_t A,const mp_limb_t * Bcoeff,const ulong * Bexp,slong Blen,const mp_limb_t * Ccoeff,const ulong * Cexp,slong Clen,slong * start,slong * end,slong * hind,const nmod_mpoly_stripe_t S)25 void _nmod_mpoly_mul_heap_part1(
26     nmod_mpoly_t A,
27     const mp_limb_t * Bcoeff, const ulong * Bexp, slong Blen,
28     const mp_limb_t * Ccoeff, const ulong * Cexp, slong Clen,
29     slong * start,
30     slong * end,
31     slong * hind,
32     const nmod_mpoly_stripe_t S)
33 {
34     const ulong cmpmask = S->cmpmask[0];
35     slong i, j;
36     ulong exp;
37     mpoly_heap_t * x;
38     slong next_loc;
39     slong heap_len;
40     mpoly_heap1_s * heap;
41     mpoly_heap_t * chain;
42     slong * store, * store_base;
43     slong Alen;
44     ulong * Aexp = A->exps;
45     mp_limb_t * Acoeff = A->coeffs;
46     mp_limb_t acc0, acc1, acc2, pp0, pp1;
47 
48     FLINT_ASSERT(S->N == 1);
49 
50     /* tmp allocs from S->big_mem */
51     i = 0;
52     store = store_base = (slong *) (S->big_mem + i);
53     i += 2*Blen*sizeof(slong);
54     heap = (mpoly_heap1_s *)(S->big_mem + i);
55     i += (Blen + 1)*sizeof(mpoly_heap1_s);
56     chain = (mpoly_heap_t *)(S->big_mem + i);
57     i += Blen*sizeof(mpoly_heap_t);
58     FLINT_ASSERT(i <= S->big_mem_alloc);
59 
60     /* put all the starting nodes on the heap */
61     heap_len = 1; /* heap zero index unused */
62     next_loc = Blen + 4;   /* something bigger than heap can ever be */
63     for (i = 0; i < Blen; i++)
64     {
65         hind[i] = 2*start[i] + 1;
66     }
67     for (i = 0; i < Blen; i++)
68     {
69         if (  (start[i] < end[i])
70            && (  (i == 0)
71               || (start[i] < start[i - 1])
72               )
73            )
74         {
75             x = chain + i;
76             x->i = i;
77             x->j = start[i];
78             x->next = NULL;
79             hind[x->i] = 2*(x->j + 1) + 0;
80             _mpoly_heap_insert1(heap, Bexp[x->i] + Cexp[x->j], x,
81                                                 &next_loc, &heap_len, cmpmask);
82         }
83     }
84 
85     Alen = 0;
86     while (heap_len > 1)
87     {
88         exp = heap[1].exp;
89 
90         _nmod_mpoly_fit_length(&Acoeff, &A->coeffs_alloc,
91                                &Aexp, &A->exps_alloc, 1, Alen + 1);
92 
93         Aexp[Alen] = exp;
94 
95         acc0 = acc1 = acc2 = 0;
96         do
97         {
98             x = _mpoly_heap_pop1(heap, &heap_len, cmpmask);
99 
100             hind[x->i] |= WORD(1);
101             *store++ = x->i;
102             *store++ = x->j;
103             umul_ppmm(pp1, pp0, Bcoeff[x->i], Ccoeff[x->j]);
104             add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, WORD(0), pp1, pp0);
105 
106             while ((x = x->next) != NULL)
107             {
108                 hind[x->i] |= WORD(1);
109                 *store++ = x->i;
110                 *store++ = x->j;
111                 umul_ppmm(pp1, pp0, Bcoeff[x->i], Ccoeff[x->j]);
112                 add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, WORD(0), pp1, pp0);
113             }
114         } while (heap_len > 1 && heap[1].exp == exp);
115 
116         NMOD_RED3(Acoeff[Alen], acc2, acc1, acc0, S->mod);
117         Alen += (Acoeff[Alen] != UWORD(0));
118 
119         /* for each node temporarily stored */
120         while (store > store_base)
121         {
122             j = *--store;
123             i = *--store;
124 
125             /* should we go right? */
126             if (  (i + 1 < Blen)
127                && (j + 0 < end[i + 1])
128                && (hind[i + 1] == 2*j + 1)
129                )
130             {
131                 x = chain + i + 1;
132                 x->i = i + 1;
133                 x->j = j;
134                 x->next = NULL;
135                 hind[x->i] = 2*(x->j+1) + 0;
136                 _mpoly_heap_insert1(heap, Bexp[x->i] + Cexp[x->j], x,
137                                                 &next_loc, &heap_len, cmpmask);
138             }
139 
140             /* should we go up? */
141             if (  (j + 1 < end[i + 0])
142                && ((hind[i] & 1) == 1)
143                && (  (i == 0)
144                   || (hind[i - 1] >= 2*(j + 2) + 1)
145                   )
146                )
147             {
148                 x = chain + i;
149                 x->i = i;
150                 x->j = j + 1;
151                 x->next = NULL;
152                 hind[x->i] = 2*(x->j+1) + 0;
153                 _mpoly_heap_insert1(heap, Bexp[x->i] + Cexp[x->j], x,
154                                                 &next_loc, &heap_len, cmpmask);
155             }
156         }
157     }
158 
159     A->coeffs = Acoeff;
160     A->exps = Aexp;
161     A->length = Alen;
162 }
163 
164 
_nmod_mpoly_mul_heap_part(nmod_mpoly_t A,const mp_limb_t * Bcoeff,const ulong * Bexp,slong Blen,const mp_limb_t * Ccoeff,const ulong * Cexp,slong Clen,slong * start,slong * end,slong * hind,const nmod_mpoly_stripe_t S)165 void _nmod_mpoly_mul_heap_part(
166     nmod_mpoly_t A,
167     const mp_limb_t * Bcoeff, const ulong * Bexp, slong Blen,
168     const mp_limb_t * Ccoeff, const ulong * Cexp, slong Clen,
169     slong * start,
170     slong * end,
171     slong * hind,
172     const nmod_mpoly_stripe_t S)
173 {
174     flint_bitcnt_t bits = S->bits;
175     slong N = S->N;
176     const ulong * cmpmask = S->cmpmask;
177     slong i, j;
178     slong next_loc;
179     slong heap_len;
180     ulong * exp, * exps;
181     ulong ** exp_list;
182     slong exp_next;
183     mpoly_heap_t * x;
184     mpoly_heap_s * heap;
185     mpoly_heap_t * chain;
186     slong * store, * store_base;
187     slong Alen;
188     ulong * Aexp = A->exps;
189     mp_limb_t * Acoeff = A->coeffs;
190     ulong acc0, acc1, acc2, pp0, pp1;
191 
192     /* tmp allocs from S->big_mem */
193     i = 0;
194     store = store_base = (slong *) (S->big_mem + i);
195     i += 2*Blen*sizeof(slong);
196     exp_list = (ulong **) (S->big_mem + i);
197     i += Blen*sizeof(ulong *);
198     exps = (ulong *) (S->big_mem + i);
199     i += Blen*N*sizeof(ulong);
200     heap = (mpoly_heap_s *) (S->big_mem + i);
201     i += (Blen + 1)*sizeof(mpoly_heap_s);
202     chain = (mpoly_heap_t *) (S->big_mem + i);
203     i += Blen*sizeof(mpoly_heap_t);
204     FLINT_ASSERT(i <= S->big_mem_alloc);
205 
206     /* put all the starting nodes on the heap */
207     heap_len = 1; /* heap zero index unused */
208     next_loc = Blen + 4;   /* something bigger than heap can ever be */
209     exp_next = 0;
210     for (i = 0; i < Blen; i++)
211         exp_list[i] = exps + N*i;
212     for (i = 0; i < Blen; i++)
213         hind[i] = 2*start[i] + 1;
214     for (i = 0; i < Blen; i++)
215     {
216         if (  (start[i] < end[i])
217            && (  (i == 0)
218               || (start[i] < start[i - 1])
219               )
220            )
221         {
222             x = chain + i;
223             x->i = i;
224             x->j = start[i];
225             x->next = NULL;
226             hind[x->i] = 2*(x->j + 1) + 0;
227 
228             if (bits <= FLINT_BITS)
229                 mpoly_monomial_add(exp_list[exp_next], Bexp + N*x->i,
230                                                        Cexp + N*x->j, N);
231             else
232                 mpoly_monomial_add_mp(exp_list[exp_next], Bexp + N*x->i,
233                                                           Cexp + N*x->j, N);
234 
235             exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
236                                              &next_loc, &heap_len, N, cmpmask);
237         }
238     }
239 
240     Alen = 0;
241     while (heap_len > 1)
242     {
243         exp = heap[1].exp;
244 
245         _nmod_mpoly_fit_length(&Acoeff, &A->coeffs_alloc,
246                                &Aexp, &A->exps_alloc, N, Alen + 1);
247 
248         mpoly_monomial_set(Aexp + N*Alen, exp, N);
249 
250         acc0 = acc1 = acc2 = 0;
251         do
252         {
253             exp_list[--exp_next] = heap[1].exp;
254 
255             x = _mpoly_heap_pop(heap, &heap_len, N, cmpmask);
256 
257             hind[x->i] |= WORD(1);
258             *store++ = x->i;
259             *store++ = x->j;
260             umul_ppmm(pp1, pp0, Bcoeff[x->i], Ccoeff[x->j]);
261             add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, WORD(0), pp1, pp0);
262 
263             while ((x = x->next) != NULL)
264             {
265                 hind[x->i] |= WORD(1);
266                 *store++ = x->i;
267                 *store++ = x->j;
268                 umul_ppmm(pp1, pp0, Bcoeff[x->i], Ccoeff[x->j]);
269                 add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, WORD(0), pp1, pp0);
270             }
271         } while (heap_len > 1 && mpoly_monomial_equal(heap[1].exp, exp, N));
272 
273         NMOD_RED3(Acoeff[Alen], acc2, acc1, acc0, S->mod);
274         Alen += (Acoeff[Alen] != UWORD(0));
275 
276         /* for each node temporarily stored */
277         while (store > store_base)
278         {
279             j = *--store;
280             i = *--store;
281 
282             /* should we go right? */
283             if (  (i + 1 < Blen)
284                && (j + 0 < end[i + 1])
285                && (hind[i + 1] == 2*j + 1)
286                )
287             {
288                 x = chain + i + 1;
289                 x->i = i + 1;
290                 x->j = j;
291                 x->next = NULL;
292 
293                 hind[x->i] = 2*(x->j + 1) + 0;
294 
295                 if (bits <= FLINT_BITS)
296                     mpoly_monomial_add(exp_list[exp_next], Bexp + N*x->i,
297                                                            Cexp + N*x->j, N);
298                 else
299                     mpoly_monomial_add_mp(exp_list[exp_next], Bexp + N*x->i,
300                                                               Cexp + N*x->j, N);
301 
302                 exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
303                                              &next_loc, &heap_len, N, cmpmask);
304             }
305 
306             /* should we go up? */
307             if (  (j + 1 < end[i + 0])
308                && ((hind[i] & 1) == 1)
309                && (  (i == 0)
310                   || (hind[i - 1] >= 2*(j + 2) + 1)
311                   )
312                )
313             {
314                 x = chain + i;
315                 x->i = i;
316                 x->j = j + 1;
317                 x->next = NULL;
318 
319                 hind[x->i] = 2*(x->j + 1) + 0;
320 
321                 if (bits <= FLINT_BITS)
322                     mpoly_monomial_add(exp_list[exp_next], Bexp + N*x->i,
323                                                            Cexp + N*x->j, N);
324                 else
325                     mpoly_monomial_add_mp(exp_list[exp_next], Bexp + N*x->i,
326                                                               Cexp + N*x->j, N);
327 
328                 exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
329                                              &next_loc, &heap_len, N, cmpmask);
330             }
331         }
332     }
333 
334     A->coeffs = Acoeff;
335     A->exps = Aexp;
336     A->length = Alen;
337 }
338 
339 
340 /*
341     The workers calculate product terms from 4*n divisions, where n is the
342     number of threads.
343 */
344 
345 typedef struct
346 {
347     volatile int idx;
348 #if FLINT_USES_PTHREAD
349     pthread_mutex_t mutex;
350 #endif
351     slong nthreads;
352     slong ndivs;
353     const nmod_mpoly_ctx_struct * ctx;
354     mp_limb_t * Acoeff;
355     ulong * Aexp;
356     const mp_limb_t * Bcoeff;
357     const ulong * Bexp;
358     slong Blen;
359     const mp_limb_t * Ccoeff;
360     const ulong * Cexp;
361     slong Clen;
362     slong N;
363     flint_bitcnt_t bits;
364     const ulong * cmpmask;
365 }
366 _base_struct;
367 
368 typedef _base_struct _base_t[1];
369 
/*
    One slice of the product B*C, computed as a unit by a single worker.
*/
typedef struct
{
    slong lower;        /* target passed (as both bounds) to
                           mpoly_search_monomials to locate this division's
                           boundary; a fraction of Blen*Clen */
    slong upper;        /* set equal to lower; not read by the code here */
    slong thread_idx;   /* index of the thread that computed this division,
                           -1 until claimed */
    slong Aoffset;      /* term offset of this division's output inside the
                           final A, -1 until known */
    nmod_mpoly_t A;     /* terms computed for this division */
}
_div_struct;
379 
380 typedef struct
381 {
382     nmod_mpoly_stripe_t S;
383     slong idx;
384     slong time;
385     _base_struct * base;
386     _div_struct * divs;
387 #if FLINT_USES_PTHREAD
388     pthread_mutex_t mutex;
389     pthread_cond_t cond;
390 #endif
391     slong * t1, * t2, * t3, * t4;
392     ulong * exp;
393 }
394 _worker_arg_struct;
395 
396 
397 /*
398     The workers simply take the next available division and calculate all
399     product terms in this division.
400 */
401 
/*
    Exchange the pointer values stored in xx and yy.  The expansion uses a
    temporary named tt, which must already be declared, with a compatible
    pointer type, in the scope where the macro is used.
*/
#define SWAP_PTRS(xx, yy) \
    do {                  \
        tt = (xx);        \
        (xx) = (yy);      \
        (yy) = tt;        \
    } while (0)
408 
_nmod_mpoly_mul_heap_threaded_worker(void * arg_ptr)409 static void _nmod_mpoly_mul_heap_threaded_worker(void * arg_ptr)
410 {
411     _worker_arg_struct * arg = (_worker_arg_struct *) arg_ptr;
412     nmod_mpoly_stripe_struct * S = arg->S;
413     _div_struct * divs = arg->divs;
414     _base_struct * base = arg->base;
415     slong Blen = base->Blen;
416     slong N = base->N;
417     slong i, j;
418     ulong * exp;
419     slong score;
420     slong * start, * end, * t1, * t2, * t3, * t4, * tt;
421 
422     exp = (ulong *) flint_malloc(N*sizeof(ulong));
423     t1 = (slong *) flint_malloc(Blen*sizeof(slong));
424     t2 = (slong *) flint_malloc(Blen*sizeof(slong));
425     t3 = (slong *) flint_malloc(Blen*sizeof(slong));
426     t4 = (slong *) flint_malloc(Blen*sizeof(slong));
427 
428     S->N = N;
429     S->bits = base->bits;
430     S->cmpmask = base->cmpmask;
431     S->ctx = base->ctx;
432     S->mod = base->ctx->mod;
433 
434     S->big_mem_alloc = 0;
435     if (N == 1)
436     {
437         S->big_mem_alloc += 2*Blen*sizeof(slong);
438         S->big_mem_alloc += (Blen + 1)*sizeof(mpoly_heap1_s);
439         S->big_mem_alloc += Blen*sizeof(mpoly_heap_t);
440     }
441     else
442     {
443         S->big_mem_alloc += 2*Blen*sizeof(slong);
444         S->big_mem_alloc += (Blen + 1)*sizeof(mpoly_heap_s);
445         S->big_mem_alloc += Blen*sizeof(mpoly_heap_t);
446         S->big_mem_alloc += Blen*S->N*sizeof(ulong);
447         S->big_mem_alloc += Blen*sizeof(ulong *);
448     }
449     S->big_mem = (char *) flint_malloc(S->big_mem_alloc);
450 
451     /* get index to start working on */
452     if (arg->idx + 1 < base->nthreads)
453     {
454 #if FLINT_USES_PTHREAD
455         pthread_mutex_lock(&base->mutex);
456 #endif
457 	i = base->idx - 1;
458         base->idx = i;
459 #if FLINT_USES_PTHREAD
460         pthread_mutex_unlock(&base->mutex);
461 #endif
462     }
463     else
464     {
465         i = base->ndivs - 1;
466     }
467 
468     while (i >= 0)
469     {
470         FLINT_ASSERT(divs[i].thread_idx == -WORD(1));
471         divs[i].thread_idx = arg->idx;
472 
473         /* calculate start */
474         if (i + 1 < base-> ndivs)
475         {
476             mpoly_search_monomials(
477                 &start, exp, &score, t1, t2, t3,
478                             divs[i].lower, divs[i].lower,
479                             base->Bexp, base->Blen, base->Cexp, base->Clen,
480                                           base->N, base->cmpmask);
481             if (start == t2)
482             {
483                 SWAP_PTRS(t1, t2);
484             }
485             else if (start == t3)
486             {
487                 SWAP_PTRS(t1, t3);
488             }
489         }
490         else
491         {
492             start = t1;
493             for (j = 0; j < base->Blen; j++)
494                 start[j] = 0;
495         }
496 
497         /* calculate end */
498         if (i > 0)
499         {
500             mpoly_search_monomials(
501                 &end, exp, &score, t2, t3, t4,
502                             divs[i - 1].lower, divs[i - 1].lower,
503                             base->Bexp, base->Blen, base->Cexp, base->Clen,
504                                           base->N, base->cmpmask);
505 
506 
507             if (end == t3)
508             {
509                 SWAP_PTRS(t2, t3);
510             }
511             else if (end == t4)
512             {
513                 SWAP_PTRS(t2, t4);
514             }
515         }
516         else
517         {
518             end = t2;
519             for (j = 0; j < base->Blen; j++)
520                 end[j] = base->Clen;
521         }
522         /* t3 and t4 are free for workspace at this point */
523 
524         /* calculate products in [start, end) */
525         _nmod_mpoly_fit_length(&divs[i].A->coeffs, &divs[i].A->coeffs_alloc,
526                              &divs[i].A->exps, &divs[i].A->exps_alloc, N, 256);
527         if (N == 1)
528         {
529             _nmod_mpoly_mul_heap_part1(divs[i].A,
530                                       base->Bcoeff,  base->Bexp,  base->Blen,
531                                       base->Ccoeff,  base->Cexp,  base->Clen,
532                                                             start, end, t3, S);
533         }
534         else
535         {
536             _nmod_mpoly_mul_heap_part(divs[i].A,
537                                       base->Bcoeff,  base->Bexp,  base->Blen,
538                                       base->Ccoeff,  base->Cexp,  base->Clen,
539                                                             start, end, t3, S);
540         }
541 
542         /* get next index to work on */
543 #if FLINT_USES_PTHREAD
544         pthread_mutex_lock(&base->mutex);
545 #endif
546 	i = base->idx - 1;
547         base->idx = i;
548 #if FLINT_USES_PTHREAD
549         pthread_mutex_unlock(&base->mutex);
550 #endif
551     }
552 
553     /* clean up */
554     flint_free(S->big_mem);
555     flint_free(t4);
556     flint_free(t3);
557     flint_free(t2);
558     flint_free(t1);
559     flint_free(exp);
560 }
561 
562 
_join_worker(void * varg)563 static void _join_worker(void * varg)
564 {
565     _worker_arg_struct * arg = (_worker_arg_struct *) varg;
566     _div_struct * divs = arg->divs;
567     _base_struct * base = arg->base;
568     slong N = base->N;
569     slong i;
570 
571     for (i = base->ndivs - 2; i >= 0; i--)
572     {
573         FLINT_ASSERT(divs[i].thread_idx != -WORD(1));
574 
575         if (divs[i].thread_idx != arg->idx)
576             continue;
577 
578         FLINT_ASSERT(divs[i].A->coeffs != NULL);
579         FLINT_ASSERT(divs[i].A->exps != NULL);
580 
581         memcpy(base->Acoeff + divs[i].Aoffset, divs[i].A->coeffs,
582                                           divs[i].A->length*sizeof(mp_limb_t));
583 
584         memcpy(base->Aexp + N*divs[i].Aoffset, divs[i].A->exps,
585                                             N*divs[i].A->length*sizeof(ulong));
586 
587         flint_free(divs[i].A->coeffs);
588         flint_free(divs[i].A->exps);
589     }
590 }
591 
_nmod_mpoly_mul_heap_threaded(nmod_mpoly_t A,const mp_limb_t * Bcoeff,const ulong * Bexp,slong Blen,const mp_limb_t * Ccoeff,const ulong * Cexp,slong Clen,flint_bitcnt_t bits,slong N,const ulong * cmpmask,const nmod_mpoly_ctx_t ctx,const thread_pool_handle * handles,slong num_handles)592 void _nmod_mpoly_mul_heap_threaded(
593     nmod_mpoly_t A,
594     const mp_limb_t * Bcoeff, const ulong * Bexp, slong Blen,
595     const mp_limb_t * Ccoeff, const ulong * Cexp, slong Clen,
596     flint_bitcnt_t bits,
597     slong N,
598     const ulong * cmpmask,
599     const nmod_mpoly_ctx_t ctx,
600     const thread_pool_handle * handles,
601     slong num_handles)
602 {
603     slong i;
604     _base_t base;
605     _div_struct * divs;
606     _worker_arg_struct * args;
607     slong BClen, Alen;
608 
609     /* bail if product of lengths overflows a word */
610     if (z_mul_checked(&BClen, Blen, Clen))
611     {
612         _nmod_mpoly_mul_johnson(A, Bcoeff, Bexp, Blen,
613                                Ccoeff, Cexp, Clen, bits, N, cmpmask, ctx->mod);
614         return;
615     }
616 
617     base->nthreads = num_handles + 1;
618     base->ndivs    = base->nthreads*4;  /* number of divisons */
619     base->Bcoeff = Bcoeff;
620     base->Bexp = Bexp;
621     base->Blen = Blen;
622     base->Ccoeff = Ccoeff;
623     base->Cexp = Cexp;
624     base->Clen = Clen;
625     base->bits = bits;
626     base->N = N;
627     base->cmpmask = cmpmask;
628     base->idx = base->ndivs - 1;    /* decremented by worker threads */
629     base->ctx = ctx;
630 
631     divs = (_div_struct *) flint_malloc(base->ndivs*sizeof(_div_struct));
632     args = (_worker_arg_struct *) flint_malloc(base->nthreads
633                                                   *sizeof(_worker_arg_struct));
634 
635     /* allocate space and set the boundary for each division */
636     FLINT_ASSERT(BClen/Blen == Clen);
637     for (i = base->ndivs - 1; i >= 0; i--)
638     {
639         double d = (double)(i + 1) / (double)(base->ndivs);
640 
641         /* divisions decrease in size so that no worker finishes too early */
642         divs[i].lower = (d * d) * BClen;
643         divs[i].lower = FLINT_MIN(divs[i].lower, BClen);
644         divs[i].lower = FLINT_MAX(divs[i].lower, WORD(0));
645         divs[i].upper = divs[i].lower;
646         divs[i].Aoffset = -WORD(1);
647         divs[i].thread_idx = -WORD(1);
648 
649         if (i == base->ndivs - 1)
650         {
651             /* highest division writes to original poly */
652             *divs[i].A = *A;
653         }
654         else
655         {
656             /* lower divisions write to a new worker poly */
657             divs[i].A->exps = NULL;
658             divs[i].A->coeffs = NULL;
659             divs[i].A->bits = A->bits;
660             divs[i].A->coeffs_alloc = 0;
661             divs[i].A->exps_alloc = 0;
662         }
663         divs[i].A->length = 0;
664     }
665 
666     /* compute each chunk in parallel */
667 #if FLINT_USES_PTHREAD
668     pthread_mutex_init(&base->mutex, NULL);
669 #endif
670     for (i = 0; i < num_handles; i++)
671     {
672         args[i].idx = i;
673         args[i].base = base;
674         args[i].divs = divs;
675         thread_pool_wake(global_thread_pool, handles[i], 0,
676                                _nmod_mpoly_mul_heap_threaded_worker, &args[i]);
677     }
678     i = num_handles;
679     args[i].idx = i;
680     args[i].base = base;
681     args[i].divs = divs;
682     _nmod_mpoly_mul_heap_threaded_worker(&args[i]);
683     for (i = 0; i < num_handles; i++)
684     {
685         thread_pool_wait(global_thread_pool, handles[i]);
686     }
687 
688     /* calculate and allocate space for final answer */
689     i = base->ndivs - 1;
690     *A = *divs[i].A;
691     Alen = A->length;
692     for (i = base->ndivs - 2; i >= 0; i--)
693     {
694         divs[i].Aoffset = Alen;
695         Alen += divs[i].A->length;
696     }
697     nmod_mpoly_fit_length(A, Alen, ctx);
698 
699     base->Acoeff = A->coeffs;
700     base->Aexp = A->exps;
701 
702     /* join answers */
703     for (i = 0; i < num_handles; i++)
704         thread_pool_wake(global_thread_pool, handles[i], 0, _join_worker, &args[i]);
705     _join_worker(&args[num_handles]);
706     for (i = 0; i < num_handles; i++)
707         thread_pool_wait(global_thread_pool, handles[i]);
708 
709     A->length = Alen;
710 
711 #if FLINT_USES_PTHREAD
712     pthread_mutex_destroy(&base->mutex);
713 #endif
714 
715     flint_free(args);
716     flint_free(divs);
717 }
718 
719 
/* maxBfields gets clobbered: on return it holds maxBfields + maxCfields */
_nmod_mpoly_mul_heap_threaded_pool_maxfields(nmod_mpoly_t A,const nmod_mpoly_t B,fmpz * maxBfields,const nmod_mpoly_t C,fmpz * maxCfields,const nmod_mpoly_ctx_t ctx,const thread_pool_handle * handles,slong num_handles)721 void _nmod_mpoly_mul_heap_threaded_pool_maxfields(
722     nmod_mpoly_t A,
723     const nmod_mpoly_t B, fmpz * maxBfields,
724     const nmod_mpoly_t C, fmpz * maxCfields,
725     const nmod_mpoly_ctx_t ctx,
726     const thread_pool_handle * handles,
727     slong num_handles)
728 {
729     slong N;
730     flint_bitcnt_t Abits;
731     ulong * cmpmask;
732     ulong * Bexp, * Cexp;
733     int freeBexp, freeCexp;
734     nmod_mpoly_struct * a, T[1];
735 
736     TMP_INIT;
737 
738     TMP_START;
739 
740     _fmpz_vec_add(maxBfields, maxBfields, maxCfields, ctx->minfo->nfields);
741 
742     Abits = 1 + _fmpz_vec_max_bits(maxBfields, ctx->minfo->nfields);
743     Abits = FLINT_MAX(Abits, B->bits);
744     Abits = FLINT_MAX(Abits, C->bits);
745     Abits = mpoly_fix_bits(Abits, ctx->minfo);
746 
747     N = mpoly_words_per_exp(Abits, ctx->minfo);
748     cmpmask = (ulong*) TMP_ALLOC(N*sizeof(ulong));
749     mpoly_get_cmpmask(cmpmask, N, Abits, ctx->minfo);
750 
751     /* ensure input exponents are packed into same sized fields as output */
752     freeBexp = 0;
753     Bexp = B->exps;
754     if (Abits > B->bits)
755     {
756         freeBexp = 1;
757         Bexp = (ulong *) flint_malloc(N*B->length*sizeof(ulong));
758         mpoly_repack_monomials(Bexp, Abits, B->exps, B->bits, B->length, ctx->minfo);
759     }
760 
761     freeCexp = 0;
762     Cexp = C->exps;
763     if (Abits > C->bits)
764     {
765         freeCexp = 1;
766         Cexp = (ulong *) flint_malloc(N*C->length*sizeof(ulong));
767         mpoly_repack_monomials(Cexp, Abits, C->exps, C->bits, C->length, ctx->minfo);
768     }
769 
770     if (A == B || A == C)
771     {
772         nmod_mpoly_init(T, ctx);
773         a = T;
774     }
775     else
776     {
777         a = A;
778     }
779 
780     nmod_mpoly_fit_length_reset_bits(a, B->length + C->length, Abits, ctx);
781 
782     if (B->length > C->length)
783     {
784         _nmod_mpoly_mul_heap_threaded(a, C->coeffs, Cexp, C->length,
785                                          B->coeffs, Bexp, B->length,
786                              Abits, N, cmpmask, ctx, handles, num_handles);
787     }
788     else
789     {
790         _nmod_mpoly_mul_heap_threaded(a, B->coeffs, Bexp, B->length,
791                                          C->coeffs, Cexp, C->length,
792                              Abits, N, cmpmask, ctx, handles, num_handles);
793     }
794 
795     if (A == B || A == C)
796     {
797         nmod_mpoly_swap(A, T, ctx);
798         nmod_mpoly_clear(T, ctx);
799     }
800 
801     if (freeBexp)
802         flint_free(Bexp);
803 
804     if (freeCexp)
805         flint_free(Cexp);
806 
807     TMP_END;
808 }
809 
810 
nmod_mpoly_mul_heap_threaded(nmod_mpoly_t A,const nmod_mpoly_t B,const nmod_mpoly_t C,const nmod_mpoly_ctx_t ctx)811 void nmod_mpoly_mul_heap_threaded(
812     nmod_mpoly_t A,
813     const nmod_mpoly_t B,
814     const nmod_mpoly_t C,
815     const nmod_mpoly_ctx_t ctx)
816 {
817     slong i;
818     fmpz * maxBfields, * maxCfields;
819     thread_pool_handle * handles;
820     slong num_handles;
821     slong thread_limit = FLINT_MIN(B->length, C->length)/16;
822     TMP_INIT;
823 
824     if (B->length == 0 || C->length == 0)
825     {
826         nmod_mpoly_zero(A, ctx);
827         return;
828     }
829 
830     TMP_START;
831 
832     maxBfields = (fmpz *) TMP_ALLOC(ctx->minfo->nfields*sizeof(fmpz));
833     maxCfields = (fmpz *) TMP_ALLOC(ctx->minfo->nfields*sizeof(fmpz));
834     for (i = 0; i < ctx->minfo->nfields; i++)
835     {
836         fmpz_init(maxBfields + i);
837         fmpz_init(maxCfields + i);
838     }
839     mpoly_max_fields_fmpz(maxBfields, B->exps, B->length, B->bits, ctx->minfo);
840     mpoly_max_fields_fmpz(maxCfields, C->exps, C->length, C->bits, ctx->minfo);
841 
842     num_handles = flint_request_threads(&handles, thread_limit);
843 
844     _nmod_mpoly_mul_heap_threaded_pool_maxfields(A, B, maxBfields, C, maxCfields,
845                                                     ctx, handles, num_handles);
846 
847     flint_give_back_threads(handles, num_handles);
848 
849     for (i = 0; i < ctx->minfo->nfields; i++)
850     {
851         fmpz_clear(maxBfields + i);
852         fmpz_clear(maxCfields + i);
853     }
854 
855     TMP_END;
856 }
857