1 /*
2 Copyright (C) 2017-2019 Daniel Schultz
3
4 This file is part of FLINT.
5
6 FLINT is free software: you can redistribute it and/or modify it under
7 the terms of the GNU Lesser General Public License (LGPL) as published
8 by the Free Software Foundation; either version 2.1 of the License, or
9 (at your option) any later version. See <https://www.gnu.org/licenses/>.
10 */
11
12 #include <string.h>
13 #include <gmp.h>
14 #include <stdlib.h>
15 #include "thread_pool.h"
16 #include "nmod_mpoly.h"
17 #include "long_extras.h"
18
19
20 /*
21 set A = the part of B*C with exps in [start, end)
22 this functions reallocates A and returns the length of A
23 version for N == 1
24 */
_nmod_mpoly_mul_heap_part1(nmod_mpoly_t A,const mp_limb_t * Bcoeff,const ulong * Bexp,slong Blen,const mp_limb_t * Ccoeff,const ulong * Cexp,slong Clen,slong * start,slong * end,slong * hind,const nmod_mpoly_stripe_t S)25 void _nmod_mpoly_mul_heap_part1(
26 nmod_mpoly_t A,
27 const mp_limb_t * Bcoeff, const ulong * Bexp, slong Blen,
28 const mp_limb_t * Ccoeff, const ulong * Cexp, slong Clen,
29 slong * start,
30 slong * end,
31 slong * hind,
32 const nmod_mpoly_stripe_t S)
33 {
34 const ulong cmpmask = S->cmpmask[0];
35 slong i, j;
36 ulong exp;
37 mpoly_heap_t * x;
38 slong next_loc;
39 slong heap_len;
40 mpoly_heap1_s * heap;
41 mpoly_heap_t * chain;
42 slong * store, * store_base;
43 slong Alen;
44 ulong * Aexp = A->exps;
45 mp_limb_t * Acoeff = A->coeffs;
46 mp_limb_t acc0, acc1, acc2, pp0, pp1;
47
48 FLINT_ASSERT(S->N == 1);
49
50 /* tmp allocs from S->big_mem */
51 i = 0;
52 store = store_base = (slong *) (S->big_mem + i);
53 i += 2*Blen*sizeof(slong);
54 heap = (mpoly_heap1_s *)(S->big_mem + i);
55 i += (Blen + 1)*sizeof(mpoly_heap1_s);
56 chain = (mpoly_heap_t *)(S->big_mem + i);
57 i += Blen*sizeof(mpoly_heap_t);
58 FLINT_ASSERT(i <= S->big_mem_alloc);
59
60 /* put all the starting nodes on the heap */
61 heap_len = 1; /* heap zero index unused */
62 next_loc = Blen + 4; /* something bigger than heap can ever be */
63 for (i = 0; i < Blen; i++)
64 {
65 hind[i] = 2*start[i] + 1;
66 }
67 for (i = 0; i < Blen; i++)
68 {
69 if ( (start[i] < end[i])
70 && ( (i == 0)
71 || (start[i] < start[i - 1])
72 )
73 )
74 {
75 x = chain + i;
76 x->i = i;
77 x->j = start[i];
78 x->next = NULL;
79 hind[x->i] = 2*(x->j + 1) + 0;
80 _mpoly_heap_insert1(heap, Bexp[x->i] + Cexp[x->j], x,
81 &next_loc, &heap_len, cmpmask);
82 }
83 }
84
85 Alen = 0;
86 while (heap_len > 1)
87 {
88 exp = heap[1].exp;
89
90 _nmod_mpoly_fit_length(&Acoeff, &A->coeffs_alloc,
91 &Aexp, &A->exps_alloc, 1, Alen + 1);
92
93 Aexp[Alen] = exp;
94
95 acc0 = acc1 = acc2 = 0;
96 do
97 {
98 x = _mpoly_heap_pop1(heap, &heap_len, cmpmask);
99
100 hind[x->i] |= WORD(1);
101 *store++ = x->i;
102 *store++ = x->j;
103 umul_ppmm(pp1, pp0, Bcoeff[x->i], Ccoeff[x->j]);
104 add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, WORD(0), pp1, pp0);
105
106 while ((x = x->next) != NULL)
107 {
108 hind[x->i] |= WORD(1);
109 *store++ = x->i;
110 *store++ = x->j;
111 umul_ppmm(pp1, pp0, Bcoeff[x->i], Ccoeff[x->j]);
112 add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, WORD(0), pp1, pp0);
113 }
114 } while (heap_len > 1 && heap[1].exp == exp);
115
116 NMOD_RED3(Acoeff[Alen], acc2, acc1, acc0, S->mod);
117 Alen += (Acoeff[Alen] != UWORD(0));
118
119 /* for each node temporarily stored */
120 while (store > store_base)
121 {
122 j = *--store;
123 i = *--store;
124
125 /* should we go right? */
126 if ( (i + 1 < Blen)
127 && (j + 0 < end[i + 1])
128 && (hind[i + 1] == 2*j + 1)
129 )
130 {
131 x = chain + i + 1;
132 x->i = i + 1;
133 x->j = j;
134 x->next = NULL;
135 hind[x->i] = 2*(x->j+1) + 0;
136 _mpoly_heap_insert1(heap, Bexp[x->i] + Cexp[x->j], x,
137 &next_loc, &heap_len, cmpmask);
138 }
139
140 /* should we go up? */
141 if ( (j + 1 < end[i + 0])
142 && ((hind[i] & 1) == 1)
143 && ( (i == 0)
144 || (hind[i - 1] >= 2*(j + 2) + 1)
145 )
146 )
147 {
148 x = chain + i;
149 x->i = i;
150 x->j = j + 1;
151 x->next = NULL;
152 hind[x->i] = 2*(x->j+1) + 0;
153 _mpoly_heap_insert1(heap, Bexp[x->i] + Cexp[x->j], x,
154 &next_loc, &heap_len, cmpmask);
155 }
156 }
157 }
158
159 A->coeffs = Acoeff;
160 A->exps = Aexp;
161 A->length = Alen;
162 }
163
164
/*
    set A = the part of B*C with exps in [start, end)
    general version for any number N of words per packed exponent;
    the number of terms produced is stored in A->length
*/
void _nmod_mpoly_mul_heap_part(
    nmod_mpoly_t A,
    const mp_limb_t * Bcoeff, const ulong * Bexp, slong Blen,
    const mp_limb_t * Ccoeff, const ulong * Cexp, slong Clen,
    slong * start,
    slong * end,
    slong * hind,
    const nmod_mpoly_stripe_t S)
{
    flint_bitcnt_t bits = S->bits;
    slong N = S->N;
    const ulong * cmpmask = S->cmpmask;
    slong i, j;
    slong next_loc;
    slong heap_len;
    ulong * exp, * exps;
    ulong ** exp_list;      /* pool of N-word exponent buffers, recycled on pop */
    slong exp_next;         /* index of next free buffer in exp_list */
    mpoly_heap_t * x;
    mpoly_heap_s * heap;
    mpoly_heap_t * chain;
    slong * store, * store_base;
    slong Alen;
    ulong * Aexp = A->exps;
    mp_limb_t * Acoeff = A->coeffs;
    ulong acc0, acc1, acc2, pp0, pp1;   /* 3-word accumulator, 2-word product */

    /*
        tmp allocs from S->big_mem; the offsets here must match the sizing
        done by the caller when it computed S->big_mem_alloc
    */
    i = 0;
    store = store_base = (slong *) (S->big_mem + i);
    i += 2*Blen*sizeof(slong);
    exp_list = (ulong **) (S->big_mem + i);
    i += Blen*sizeof(ulong *);
    exps = (ulong *) (S->big_mem + i);
    i += Blen*N*sizeof(ulong);
    heap = (mpoly_heap_s *) (S->big_mem + i);
    i += (Blen + 1)*sizeof(mpoly_heap_s);
    chain = (mpoly_heap_t *) (S->big_mem + i);
    i += Blen*sizeof(mpoly_heap_t);
    FLINT_ASSERT(i <= S->big_mem_alloc);

    /*
        put all the starting nodes on the heap

        hind[i] tracks row i of the implicit Blen x Clen product matrix:
        the value is 2*j + e where j is one past the last column of row i
        placed on the heap, and e is 1 iff row i currently has no node in
        the heap.
    */
    heap_len = 1; /* heap zero index unused */
    next_loc = Blen + 4; /* something bigger than heap can ever be */
    exp_next = 0;
    for (i = 0; i < Blen; i++)
        exp_list[i] = exps + N*i;
    for (i = 0; i < Blen; i++)
        hind[i] = 2*start[i] + 1;
    for (i = 0; i < Blen; i++)
    {
        /* only seed a row if its start is strictly below the row above's */
        if (  (start[i] < end[i])
           && (  (i == 0)
              || (start[i] < start[i - 1])
              )
           )
        {
            x = chain + i;
            x->i = i;
            x->j = start[i];
            x->next = NULL;
            hind[x->i] = 2*(x->j + 1) + 0;

            /* multiprecision fields need the carry-aware monomial add */
            if (bits <= FLINT_BITS)
                mpoly_monomial_add(exp_list[exp_next], Bexp + N*x->i,
                                                       Cexp + N*x->j, N);
            else
                mpoly_monomial_add_mp(exp_list[exp_next], Bexp + N*x->i,
                                                          Cexp + N*x->j, N);

            /* insert returns 0 if it chained onto an equal exponent,
               in which case the buffer is reused for the next insert */
            exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
                                      &next_loc, &heap_len, N, cmpmask);
        }
    }

    Alen = 0;
    while (heap_len > 1)
    {
        /* exponent of the next output term */
        exp = heap[1].exp;

        _nmod_mpoly_fit_length(&Acoeff, &A->coeffs_alloc,
                               &Aexp, &A->exps_alloc, N, Alen + 1);

        mpoly_monomial_set(Aexp + N*Alen, exp, N);

        /* accumulate every coefficient product with this exponent */
        acc0 = acc1 = acc2 = 0;
        do
        {
            /* reclaim the exponent buffer of the entry about to be popped */
            exp_list[--exp_next] = heap[1].exp;

            x = _mpoly_heap_pop(heap, &heap_len, N, cmpmask);

            hind[x->i] |= WORD(1);  /* row x->i no longer on the heap */
            *store++ = x->i;        /* remember (i, j) for requeueing below */
            *store++ = x->j;
            umul_ppmm(pp1, pp0, Bcoeff[x->i], Ccoeff[x->j]);
            add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, WORD(0), pp1, pp0);

            /* nodes with equal exponents are chained off one heap entry */
            while ((x = x->next) != NULL)
            {
                hind[x->i] |= WORD(1);
                *store++ = x->i;
                *store++ = x->j;
                umul_ppmm(pp1, pp0, Bcoeff[x->i], Ccoeff[x->j]);
                add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, WORD(0), pp1, pp0);
            }
        } while (heap_len > 1 && mpoly_monomial_equal(heap[1].exp, exp, N));

        /* reduce the 3-word accumulator mod n; drop the term if it is zero */
        NMOD_RED3(Acoeff[Alen], acc2, acc1, acc0, S->mod);
        Alen += (Acoeff[Alen] != UWORD(0));

        /* for each node temporarily stored */
        while (store > store_base)
        {
            j = *--store;
            i = *--store;

            /* should we go right? i.e. insert (i + 1, j) if row i + 1 is ready */
            if (  (i + 1 < Blen)
               && (j + 0 < end[i + 1])
               && (hind[i + 1] == 2*j + 1)
               )
            {
                x = chain + i + 1;
                x->i = i + 1;
                x->j = j;
                x->next = NULL;

                hind[x->i] = 2*(x->j + 1) + 0;

                if (bits <= FLINT_BITS)
                    mpoly_monomial_add(exp_list[exp_next], Bexp + N*x->i,
                                                           Cexp + N*x->j, N);
                else
                    mpoly_monomial_add_mp(exp_list[exp_next], Bexp + N*x->i,
                                                              Cexp + N*x->j, N);

                exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
                                          &next_loc, &heap_len, N, cmpmask);
            }

            /*
                should we go up? i.e. insert (i, j + 1) when row i is idle
                and row i - 1 is far enough ahead that (i, j + 1) cannot be
                reached by a "go right" step later
            */
            if (  (j + 1 < end[i + 0])
               && ((hind[i] & 1) == 1)
               && (  (i == 0)
                  || (hind[i - 1] >= 2*(j + 2) + 1)
                  )
               )
            {
                x = chain + i;
                x->i = i;
                x->j = j + 1;
                x->next = NULL;

                hind[x->i] = 2*(x->j + 1) + 0;

                if (bits <= FLINT_BITS)
                    mpoly_monomial_add(exp_list[exp_next], Bexp + N*x->i,
                                                           Cexp + N*x->j, N);
                else
                    mpoly_monomial_add_mp(exp_list[exp_next], Bexp + N*x->i,
                                                              Cexp + N*x->j, N);

                exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
                                          &next_loc, &heap_len, N, cmpmask);
            }
        }
    }

    A->coeffs = Acoeff;
    A->exps = Aexp;
    A->length = Alen;
}
338
339
340 /*
341 The workers calculate product terms from 4*n divisions, where n is the
342 number of threads.
343 */
344
/* state shared (read-mostly) by all threads of one multiplication */
typedef struct
{
    volatile int idx;       /* next division to hand out; decremented under mutex */
#if FLINT_USES_PTHREAD
    pthread_mutex_t mutex;  /* protects idx */
#endif
    slong nthreads;         /* worker threads plus the calling thread */
    slong ndivs;            /* number of divisions (4*nthreads) */
    const nmod_mpoly_ctx_struct * ctx;
    mp_limb_t * Acoeff;     /* coefficient array of the final product */
    ulong * Aexp;           /* exponent array of the final product */
    const mp_limb_t * Bcoeff;
    const ulong * Bexp;
    slong Blen;
    const mp_limb_t * Ccoeff;
    const ulong * Cexp;
    slong Clen;
    slong N;                /* words per packed exponent */
    flint_bitcnt_t bits;
    const ulong * cmpmask;  /* mask for monomial comparisons */
}
_base_struct;

typedef _base_struct _base_t[1];
369
/* one division (slice) of the product terms */
typedef struct
{
    slong lower;        /* lower bound on the number of product terms below this division */
    slong upper;        /* set equal to lower by the caller */
    slong thread_idx;   /* index of the thread that computed this division, or -1 */
    slong Aoffset;      /* offset of this division's terms in the final answer */
    nmod_mpoly_t A;     /* per-division output polynomial */
}
_div_struct;
379
/* per-thread argument passed to the worker */
typedef struct
{
    nmod_mpoly_stripe_t S;  /* per-thread scratch descriptor filled in by the worker */
    slong idx;              /* this thread's index; the caller runs as the last index */
    slong time;             /* NOTE(review): unused in this file — confirm other users */
    _base_struct * base;
    _div_struct * divs;
#if FLINT_USES_PTHREAD
    pthread_mutex_t mutex;  /* NOTE(review): unused in this file — confirm other users */
    pthread_cond_t cond;    /* NOTE(review): unused in this file — confirm other users */
#endif
    slong * t1, * t2, * t3, * t4;   /* NOTE(review): workers allocate their own t1..t4 locally */
    ulong * exp;                    /* NOTE(review): likewise allocated locally by the worker */
}
_worker_arg_struct;
395
396
397 /*
398 The workers simply take the next available division and calculate all
399 product terms in this division.
400 */
401
/*
    Swap two pointer variables. Requires a compatible pointer variable
    tt in the enclosing scope as scratch (here it is slong *). The
    arguments are parenthesized so expression arguments expand safely.
*/
#define SWAP_PTRS(xx, yy) \
   do {                   \
      tt = (xx);          \
      (xx) = (yy);        \
      (yy) = tt;          \
   } while (0)
408
/*
    Worker: repeatedly claim the next unclaimed division (highest index
    first) and compute all product terms falling in that division.
*/
static void _nmod_mpoly_mul_heap_threaded_worker(void * arg_ptr)
{
    _worker_arg_struct * arg = (_worker_arg_struct *) arg_ptr;
    nmod_mpoly_stripe_struct * S = arg->S;
    _div_struct * divs = arg->divs;
    _base_struct * base = arg->base;
    slong Blen = base->Blen;
    slong N = base->N;
    slong i, j;
    ulong * exp;
    slong score;
    slong * start, * end, * t1, * t2, * t3, * t4, * tt;  /* tt is SWAP_PTRS scratch */

    /* per-thread workspace: one exponent and four Blen-long index arrays */
    exp = (ulong *) flint_malloc(N*sizeof(ulong));
    t1 = (slong *) flint_malloc(Blen*sizeof(slong));
    t2 = (slong *) flint_malloc(Blen*sizeof(slong));
    t3 = (slong *) flint_malloc(Blen*sizeof(slong));
    t4 = (slong *) flint_malloc(Blen*sizeof(slong));

    S->N = N;
    S->bits = base->bits;
    S->cmpmask = base->cmpmask;
    S->ctx = base->ctx;
    S->mod = base->ctx->mod;

    /*
        size S->big_mem to cover exactly the suballocations made by
        _nmod_mpoly_mul_heap_part1 (N == 1) or _nmod_mpoly_mul_heap_part;
        the terms here must stay in sync with those functions
    */
    S->big_mem_alloc = 0;
    if (N == 1)
    {
        S->big_mem_alloc += 2*Blen*sizeof(slong);
        S->big_mem_alloc += (Blen + 1)*sizeof(mpoly_heap1_s);
        S->big_mem_alloc += Blen*sizeof(mpoly_heap_t);
    }
    else
    {
        S->big_mem_alloc += 2*Blen*sizeof(slong);
        S->big_mem_alloc += (Blen + 1)*sizeof(mpoly_heap_s);
        S->big_mem_alloc += Blen*sizeof(mpoly_heap_t);
        S->big_mem_alloc += Blen*S->N*sizeof(ulong);
        S->big_mem_alloc += Blen*sizeof(ulong *);
    }
    S->big_mem = (char *) flint_malloc(S->big_mem_alloc);

    /*
        get index to start working on: the last thread (the caller) takes
        the highest division without touching the shared counter; the
        others claim the next one under the mutex
    */
    if (arg->idx + 1 < base->nthreads)
    {
#if FLINT_USES_PTHREAD
        pthread_mutex_lock(&base->mutex);
#endif
        i = base->idx - 1;
        base->idx = i;
#if FLINT_USES_PTHREAD
        pthread_mutex_unlock(&base->mutex);
#endif
    }
    else
    {
        i = base->ndivs - 1;
    }

    while (i >= 0)
    {
        FLINT_ASSERT(divs[i].thread_idx == -WORD(1));
        divs[i].thread_idx = arg->idx;

        /*
            calculate start: the boundary where about divs[i].lower product
            terms lie below (divs[i].lower == divs[i].upper here, hence the
            repeated argument). The search writes its answer into one of
            t1, t2, t3; swap so the result always ends up in t1.
        */
        if (i + 1 < base->ndivs)
        {
            mpoly_search_monomials(
                &start, exp, &score, t1, t2, t3,
                divs[i].lower, divs[i].lower,
                base->Bexp, base->Blen, base->Cexp, base->Clen,
                base->N, base->cmpmask);
            if (start == t2)
            {
                SWAP_PTRS(t1, t2);
            }
            else if (start == t3)
            {
                SWAP_PTRS(t1, t3);
            }
        }
        else
        {
            /* highest division starts at the very beginning */
            start = t1;
            for (j = 0; j < base->Blen; j++)
                start[j] = 0;
        }

        /* calculate end: the start boundary of the division above */
        if (i > 0)
        {
            mpoly_search_monomials(
                &end, exp, &score, t2, t3, t4,
                divs[i - 1].lower, divs[i - 1].lower,
                base->Bexp, base->Blen, base->Cexp, base->Clen,
                base->N, base->cmpmask);

            /* keep the result in t2 */
            if (end == t3)
            {
                SWAP_PTRS(t2, t3);
            }
            else if (end == t4)
            {
                SWAP_PTRS(t2, t4);
            }
        }
        else
        {
            /* lowest division ends at the very end */
            end = t2;
            for (j = 0; j < base->Blen; j++)
                end[j] = base->Clen;
        }
        /* t3 and t4 are free for workspace at this point */

        /* calculate products in [start, end); t3 serves as hind workspace */
        _nmod_mpoly_fit_length(&divs[i].A->coeffs, &divs[i].A->coeffs_alloc,
                               &divs[i].A->exps, &divs[i].A->exps_alloc, N, 256);
        if (N == 1)
        {
            _nmod_mpoly_mul_heap_part1(divs[i].A,
                         base->Bcoeff, base->Bexp, base->Blen,
                         base->Ccoeff, base->Cexp, base->Clen,
                                             start, end, t3, S);
        }
        else
        {
            _nmod_mpoly_mul_heap_part(divs[i].A,
                         base->Bcoeff, base->Bexp, base->Blen,
                         base->Ccoeff, base->Cexp, base->Clen,
                                             start, end, t3, S);
        }

        /* get next index to work on */
#if FLINT_USES_PTHREAD
        pthread_mutex_lock(&base->mutex);
#endif
        i = base->idx - 1;
        base->idx = i;
#if FLINT_USES_PTHREAD
        pthread_mutex_unlock(&base->mutex);
#endif
    }

    /* clean up */
    flint_free(S->big_mem);
    flint_free(t4);
    flint_free(t3);
    flint_free(t2);
    flint_free(t1);
    flint_free(exp);
}
561
562
_join_worker(void * varg)563 static void _join_worker(void * varg)
564 {
565 _worker_arg_struct * arg = (_worker_arg_struct *) varg;
566 _div_struct * divs = arg->divs;
567 _base_struct * base = arg->base;
568 slong N = base->N;
569 slong i;
570
571 for (i = base->ndivs - 2; i >= 0; i--)
572 {
573 FLINT_ASSERT(divs[i].thread_idx != -WORD(1));
574
575 if (divs[i].thread_idx != arg->idx)
576 continue;
577
578 FLINT_ASSERT(divs[i].A->coeffs != NULL);
579 FLINT_ASSERT(divs[i].A->exps != NULL);
580
581 memcpy(base->Acoeff + divs[i].Aoffset, divs[i].A->coeffs,
582 divs[i].A->length*sizeof(mp_limb_t));
583
584 memcpy(base->Aexp + N*divs[i].Aoffset, divs[i].A->exps,
585 N*divs[i].A->length*sizeof(ulong));
586
587 flint_free(divs[i].A->coeffs);
588 flint_free(divs[i].A->exps);
589 }
590 }
591
/*
    Threaded multiplication A = B*C. The product terms are split into
    base->ndivs divisions of decreasing size, computed in parallel, then
    joined into A. Falls back to the serial johnson multiplication if
    Blen*Clen overflows a word.
*/
void _nmod_mpoly_mul_heap_threaded(
    nmod_mpoly_t A,
    const mp_limb_t * Bcoeff, const ulong * Bexp, slong Blen,
    const mp_limb_t * Ccoeff, const ulong * Cexp, slong Clen,
    flint_bitcnt_t bits,
    slong N,
    const ulong * cmpmask,
    const nmod_mpoly_ctx_t ctx,
    const thread_pool_handle * handles,
    slong num_handles)
{
    slong i;
    _base_t base;
    _div_struct * divs;
    _worker_arg_struct * args;
    slong BClen, Alen;

    /* bail if product of lengths overflows a word */
    if (z_mul_checked(&BClen, Blen, Clen))
    {
        _nmod_mpoly_mul_johnson(A, Bcoeff, Bexp, Blen,
                         Ccoeff, Cexp, Clen, bits, N, cmpmask, ctx->mod);
        return;
    }

    base->nthreads = num_handles + 1;
    base->ndivs = base->nthreads*4;  /* number of divisons */
    base->Bcoeff = Bcoeff;
    base->Bexp = Bexp;
    base->Blen = Blen;
    base->Ccoeff = Ccoeff;
    base->Cexp = Cexp;
    base->Clen = Clen;
    base->bits = bits;
    base->N = N;
    base->cmpmask = cmpmask;
    base->idx = base->ndivs - 1;    /* decremented by worker threads */
    base->ctx = ctx;

    divs = (_div_struct *) flint_malloc(base->ndivs*sizeof(_div_struct));
    args = (_worker_arg_struct *) flint_malloc(base->nthreads
                                                  *sizeof(_worker_arg_struct));

    /* allocate space and set the boundary for each division */
    FLINT_ASSERT(BClen/Blen == Clen);
    for (i = base->ndivs - 1; i >= 0; i--)
    {
        double d = (double)(i + 1) / (double)(base->ndivs);

        /* divisions decrease in size so that no worker finishes too early */
        divs[i].lower = (d * d) * BClen;    /* quadratic spacing of boundaries */
        divs[i].lower = FLINT_MIN(divs[i].lower, BClen);
        divs[i].lower = FLINT_MAX(divs[i].lower, WORD(0));
        divs[i].upper = divs[i].lower;
        divs[i].Aoffset = -WORD(1);
        divs[i].thread_idx = -WORD(1);

        if (i == base->ndivs - 1)
        {
            /* highest division writes to original poly (shallow struct copy) */
            *divs[i].A = *A;
        }
        else
        {
            /* lower divisions write to a new worker poly */
            divs[i].A->exps = NULL;
            divs[i].A->coeffs = NULL;
            divs[i].A->bits = A->bits;
            divs[i].A->coeffs_alloc = 0;
            divs[i].A->exps_alloc = 0;
        }
        divs[i].A->length = 0;
    }

    /* compute each chunk in parallel; the caller works as the last thread */
#if FLINT_USES_PTHREAD
    pthread_mutex_init(&base->mutex, NULL);
#endif
    for (i = 0; i < num_handles; i++)
    {
        args[i].idx = i;
        args[i].base = base;
        args[i].divs = divs;
        thread_pool_wake(global_thread_pool, handles[i], 0,
                                _nmod_mpoly_mul_heap_threaded_worker, &args[i]);
    }
    i = num_handles;
    args[i].idx = i;
    args[i].base = base;
    args[i].divs = divs;
    _nmod_mpoly_mul_heap_threaded_worker(&args[i]);
    for (i = 0; i < num_handles; i++)
    {
        thread_pool_wait(global_thread_pool, handles[i]);
    }

    /*
        calculate and allocate space for final answer: take back the highest
        division's buffers (shallow copy), then assign each lower division
        its offset in the concatenated result
    */
    i = base->ndivs - 1;
    *A = *divs[i].A;
    Alen = A->length;
    for (i = base->ndivs - 2; i >= 0; i--)
    {
        divs[i].Aoffset = Alen;
        Alen += divs[i].A->length;
    }
    nmod_mpoly_fit_length(A, Alen, ctx);

    base->Acoeff = A->coeffs;
    base->Aexp = A->exps;

    /* join answers: each thread copies the divisions it computed */
    for (i = 0; i < num_handles; i++)
        thread_pool_wake(global_thread_pool, handles[i], 0, _join_worker, &args[i]);
    _join_worker(&args[num_handles]);
    for (i = 0; i < num_handles; i++)
        thread_pool_wait(global_thread_pool, handles[i]);

    A->length = Alen;

#if FLINT_USES_PTHREAD
    pthread_mutex_destroy(&base->mutex);
#endif

    flint_free(args);
    flint_free(divs);
}
718
719
720 /* maxBfields gets clobbered */
_nmod_mpoly_mul_heap_threaded_pool_maxfields(nmod_mpoly_t A,const nmod_mpoly_t B,fmpz * maxBfields,const nmod_mpoly_t C,fmpz * maxCfields,const nmod_mpoly_ctx_t ctx,const thread_pool_handle * handles,slong num_handles)721 void _nmod_mpoly_mul_heap_threaded_pool_maxfields(
722 nmod_mpoly_t A,
723 const nmod_mpoly_t B, fmpz * maxBfields,
724 const nmod_mpoly_t C, fmpz * maxCfields,
725 const nmod_mpoly_ctx_t ctx,
726 const thread_pool_handle * handles,
727 slong num_handles)
728 {
729 slong N;
730 flint_bitcnt_t Abits;
731 ulong * cmpmask;
732 ulong * Bexp, * Cexp;
733 int freeBexp, freeCexp;
734 nmod_mpoly_struct * a, T[1];
735
736 TMP_INIT;
737
738 TMP_START;
739
740 _fmpz_vec_add(maxBfields, maxBfields, maxCfields, ctx->minfo->nfields);
741
742 Abits = 1 + _fmpz_vec_max_bits(maxBfields, ctx->minfo->nfields);
743 Abits = FLINT_MAX(Abits, B->bits);
744 Abits = FLINT_MAX(Abits, C->bits);
745 Abits = mpoly_fix_bits(Abits, ctx->minfo);
746
747 N = mpoly_words_per_exp(Abits, ctx->minfo);
748 cmpmask = (ulong*) TMP_ALLOC(N*sizeof(ulong));
749 mpoly_get_cmpmask(cmpmask, N, Abits, ctx->minfo);
750
751 /* ensure input exponents are packed into same sized fields as output */
752 freeBexp = 0;
753 Bexp = B->exps;
754 if (Abits > B->bits)
755 {
756 freeBexp = 1;
757 Bexp = (ulong *) flint_malloc(N*B->length*sizeof(ulong));
758 mpoly_repack_monomials(Bexp, Abits, B->exps, B->bits, B->length, ctx->minfo);
759 }
760
761 freeCexp = 0;
762 Cexp = C->exps;
763 if (Abits > C->bits)
764 {
765 freeCexp = 1;
766 Cexp = (ulong *) flint_malloc(N*C->length*sizeof(ulong));
767 mpoly_repack_monomials(Cexp, Abits, C->exps, C->bits, C->length, ctx->minfo);
768 }
769
770 if (A == B || A == C)
771 {
772 nmod_mpoly_init(T, ctx);
773 a = T;
774 }
775 else
776 {
777 a = A;
778 }
779
780 nmod_mpoly_fit_length_reset_bits(a, B->length + C->length, Abits, ctx);
781
782 if (B->length > C->length)
783 {
784 _nmod_mpoly_mul_heap_threaded(a, C->coeffs, Cexp, C->length,
785 B->coeffs, Bexp, B->length,
786 Abits, N, cmpmask, ctx, handles, num_handles);
787 }
788 else
789 {
790 _nmod_mpoly_mul_heap_threaded(a, B->coeffs, Bexp, B->length,
791 C->coeffs, Cexp, C->length,
792 Abits, N, cmpmask, ctx, handles, num_handles);
793 }
794
795 if (A == B || A == C)
796 {
797 nmod_mpoly_swap(A, T, ctx);
798 nmod_mpoly_clear(T, ctx);
799 }
800
801 if (freeBexp)
802 flint_free(Bexp);
803
804 if (freeCexp)
805 flint_free(Cexp);
806
807 TMP_END;
808 }
809
810
/*
    Public entry point: A = B*C using a pool of threads. Computes the
    maximum exponent fields of the inputs, requests threads, and
    delegates to the pool version.
*/
void nmod_mpoly_mul_heap_threaded(
    nmod_mpoly_t A,
    const nmod_mpoly_t B,
    const nmod_mpoly_t C,
    const nmod_mpoly_ctx_t ctx)
{
    slong k;
    fmpz * Bmax, * Cmax;
    thread_pool_handle * handles;
    slong num_handles;
    slong thread_limit = FLINT_MIN(B->length, C->length)/16;
    TMP_INIT;

    /* a product with an empty factor is zero */
    if (B->length == 0 || C->length == 0)
    {
        nmod_mpoly_zero(A, ctx);
        return;
    }

    TMP_START;

    /* maximum exponent fields of each input */
    Bmax = (fmpz *) TMP_ALLOC(ctx->minfo->nfields*sizeof(fmpz));
    Cmax = (fmpz *) TMP_ALLOC(ctx->minfo->nfields*sizeof(fmpz));
    for (k = 0; k < ctx->minfo->nfields; k++)
    {
        fmpz_init(Bmax + k);
        fmpz_init(Cmax + k);
    }

    mpoly_max_fields_fmpz(Bmax, B->exps, B->length, B->bits, ctx->minfo);
    mpoly_max_fields_fmpz(Cmax, C->exps, C->length, C->bits, ctx->minfo);

    num_handles = flint_request_threads(&handles, thread_limit);

    _nmod_mpoly_mul_heap_threaded_pool_maxfields(A, B, Bmax, C, Cmax,
                                                    ctx, handles, num_handles);

    flint_give_back_threads(handles, num_handles);

    for (k = 0; k < ctx->minfo->nfields; k++)
    {
        fmpz_clear(Bmax + k);
        fmpz_clear(Cmax + k);
    }

    TMP_END;
}
857