1 /*
2 Copyright (C) 2018 Daniel Schultz
3
4 This file is part of FLINT.
5
6 FLINT is free software: you can redistribute it and/or modify it under
7 the terms of the GNU Lesser General Public License (LGPL) as published
8 by the Free Software Foundation; either version 2.1 of the License, or
9 (at your option) any later version. See <http://www.gnu.org/licenses/>.
10 */
11
12 #include "string.h"
13 #include "thread_pool.h"
14 #include "nmod_mpoly.h"
15
/*
    Tuning constants.

    BLOCK:          block length intended to improve locality; it is not
                    referenced by the code visible in this file.
    MAX_ARRAY_SIZE: upper limit on the dense coefficient array size (and on
                    the total degree bound in the DEG case); larger problems
                    make the array method report failure.
    MAX_LEX_SIZE:   upper limit on the multiplier of the pulled-out (last)
                    field in the LEX case.
*/
#define BLOCK 128
#define MAX_ARRAY_SIZE (WORD(300000))
#define MAX_LEX_SIZE (WORD(300))
20
21
22 typedef struct
23 {
24 slong idx;
25 slong work;
26 slong len;
27 nmod_mpoly_t poly;
28 }
29 _chunk_struct;
30
31
32 typedef struct
33 {
34 #if HAVE_PTHREAD
35 pthread_mutex_t mutex;
36 #endif
37 volatile int idx;
38 slong nthreads;
39 slong Al, Bl, Pl;
40 mp_limb_t * Acoeffs, * Bcoeffs;
41 slong * Amain, * Bmain;
42 ulong * Apexp, * Bpexp;
43 slong * perm;
44 slong nvars;
45 const ulong * mults;
46 slong array_size;
47 slong degb;
48 const nmod_mpoly_ctx_struct * ctx;
49 _chunk_struct * Pchunks;
50 int rev;
51 }
52 _base_struct;
53
54 typedef _base_struct _base_t[1];
55
56
57 typedef struct
58 {
59 slong idx;
60 slong time;
61 _base_struct * base;
62 ulong * exp;
63 }
64 _worker_arg_struct;
65
66
67 /******************
68 LEX
69 ******************/
70
_nmod_mpoly_mul_array_threaded_worker_LEX(void * varg)71 static void _nmod_mpoly_mul_array_threaded_worker_LEX(void * varg)
72 {
73 slong i, j, Pi;
74 _worker_arg_struct * arg = (_worker_arg_struct *) varg;
75 _base_struct * base = arg->base;
76 slong Al = base->Al;
77 slong Bl = base->Bl;
78 slong Pl = base->Pl;
79 slong * Amain = base->Amain;
80 slong * Bmain = base->Bmain;
81 ulong * coeff_array;
82
83 TMP_INIT;
84
85 TMP_START;
86 coeff_array = (ulong *) TMP_ALLOC(3*base->array_size*sizeof(ulong));
87 for (j = 0; j < 3*base->array_size; j++)
88 coeff_array[j] = 0;
89
90 #if HAVE_PTHREAD
91 pthread_mutex_lock(&base->mutex);
92 #endif
93 Pi = base->idx;
94 base->idx = Pi + 1;
95 #if HAVE_PTHREAD
96 pthread_mutex_unlock(&base->mutex);
97 #endif
98
99 while (Pi < Pl)
100 {
101 slong len;
102 mp_limb_t t2, t1, t0, u1, u0;
103
104 Pi = base->perm[Pi];
105
106 /* work out bit counts for this chunk */
107 len = 0;
108 for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
109 {
110 if (j < Bl)
111 {
112 len += FLINT_MIN(Amain[i + 1] - Amain[i],
113 Bmain[j + 1] - Bmain[j]);
114 }
115 }
116
117 umul_ppmm(t1, t0, base->ctx->ffinfo->mod.n - 1,
118 base->ctx->ffinfo->mod.n - 1);
119 umul_ppmm(t2, t1, t1, len);
120 umul_ppmm(u1, u0, t0, len);
121 add_sssaaaaaa(t2, t1, t0, t2, t1, UWORD(0), UWORD(0), u1, u0);
122
123 (base->Pchunks + Pi)->len = 0;
124
125 if (t2 != 0)
126 {
127 /* need three words */
128 for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
129 {
130 if (j >= Bl)
131 continue;
132
133 _nmod_mpoly_addmul_array1_ulong3(coeff_array,
134 base->Acoeffs + base->Amain[i],
135 base->Apexp + base->Amain[i],
136 base->Amain[i + 1] - base->Amain[i],
137 base->Bcoeffs + base->Bmain[j],
138 base->Bpexp + base->Bmain[j],
139 base->Bmain[j + 1] - base->Bmain[j]);
140 }
141
142 (base->Pchunks + Pi)->len =
143 nmod_mpoly_append_array_sm3_LEX(
144 (base->Pchunks + Pi)->poly, 0,
145 coeff_array, base->mults, base->nvars - 1,
146 base->array_size, Pl - Pi - 1, base->ctx);
147 }
148 else if (t1 != 0)
149 {
150 /* fits into two words */
151 for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
152 {
153 if (j >= Bl)
154 continue;
155
156 _nmod_mpoly_addmul_array1_ulong2(coeff_array,
157 base->Acoeffs + base->Amain[i],
158 base->Apexp + base->Amain[i],
159 base->Amain[i + 1] - base->Amain[i],
160 base->Bcoeffs + base->Bmain[j],
161 base->Bpexp + base->Bmain[j],
162 base->Bmain[j + 1] - base->Bmain[j]);
163 }
164
165 (base->Pchunks + Pi)->len =
166 nmod_mpoly_append_array_sm2_LEX(
167 (base->Pchunks + Pi)->poly, 0,
168 coeff_array, base->mults, base->nvars - 1,
169 base->array_size, Pl - Pi - 1, base->ctx);
170
171 }
172 else if (t0 != 0)
173 {
174 /* fits into one word */
175 for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
176 {
177 if (j >= Bl)
178 continue;
179
180 _nmod_mpoly_addmul_array1_ulong1(coeff_array,
181 base->Acoeffs + base->Amain[i],
182 base->Apexp + base->Amain[i],
183 base->Amain[i + 1] - base->Amain[i],
184 base->Bcoeffs + base->Bmain[j],
185 base->Bpexp + base->Bmain[j],
186 base->Bmain[j + 1] - base->Bmain[j]);
187 }
188
189 (base->Pchunks + Pi)->len =
190 nmod_mpoly_append_array_sm1_LEX(
191 (base->Pchunks + Pi)->poly, 0,
192 coeff_array, base->mults, base->nvars - 1,
193 base->array_size, Pl - Pi - 1, base->ctx);
194 }
195
196 #if HAVE_PTHREAD
197 pthread_mutex_lock(&base->mutex);
198 #endif
199 Pi = base->idx;
200 base->idx = Pi + 1;
201 #if HAVE_PTHREAD
202 pthread_mutex_unlock(&base->mutex);
203 #endif
204 }
205
206 TMP_END;
207 }
208
/*
    Set P to A*B using the chunked dense array method for ORD_LEX, sharing
    the chunks of the product between the calling thread and the
    num_handles pool threads in handles.

    mults[i] for 0 <= i < nvars - 1 give the dimensions of the dense array;
    their product is used as the array size without further checking, so the
    caller is expected to have validated it. P->bits must already be set:
    the per-chunk polynomials are created with that bit count.
*/
void _nmod_mpoly_mul_array_chunked_threaded_LEX(
    nmod_mpoly_t P,
    const nmod_mpoly_t A,
    const nmod_mpoly_t B,
    const ulong * mults,
    const nmod_mpoly_ctx_t ctx,
    const thread_pool_handle * handles,
    slong num_handles)
{
    const slong nvars = ctx->minfo->nvars;
    slong Al, Bl, Pl, Plen, array_size;
    slong i, j, k;
    slong * Amain, * Bmain, * perm;
    ulong * Apexp, * Bpexp;
    _chunk_struct * Pchunks;
    _worker_arg_struct * args;
    _base_t base;
    TMP_INIT;

    /* chunk counts come from the top field of the leading exponents */
    Al = 1 + (slong) (A->exps[0] >> (A->bits*(nvars - 1)));
    Bl = 1 + (slong) (B->exps[0] >> (B->bits*(nvars - 1)));
    Pl = Al + Bl - 1;

    /* dense array has one entry per exponent tuple of the other fields */
    array_size = 1;
    for (i = 0; i < nvars - 1; i++)
        array_size *= mults[i];

    TMP_START;

    /* split both inputs into chunks with respect to the main variable */
    Amain = (slong *) TMP_ALLOC((Al + 1)*sizeof(slong));
    Bmain = (slong *) TMP_ALLOC((Bl + 1)*sizeof(slong));
    Apexp = (ulong *) flint_malloc(A->length*sizeof(ulong));
    Bpexp = (ulong *) flint_malloc(B->length*sizeof(ulong));
    mpoly_main_variable_split_LEX(Amain, Apexp, A->exps, Al, A->length,
                                                 mults, nvars - 1, A->bits);
    mpoly_main_variable_split_LEX(Bmain, Bpexp, B->exps, Bl, B->length,
                                                 mults, nvars - 1, B->bits);

    /* set up one output chunk per main exponent and estimate its cost */
    Pchunks = (_chunk_struct *) TMP_ALLOC(Pl*sizeof(_chunk_struct));
    perm = (slong *) TMP_ALLOC(Pl*sizeof(slong));
    for (k = 0; k < Pl; k++)
    {
        _chunk_struct * chunk = Pchunks + k;

        nmod_mpoly_init2(chunk->poly, 8, ctx);
        nmod_mpoly_fit_bits(chunk->poly, P->bits, ctx);
        chunk->work = 0;
        perm[k] = k;

        for (i = 0, j = k; i < Al && j >= 0; i++, j--)
        {
            if (j < Bl)
            {
                chunk->work += (Amain[i + 1] - Amain[i])
                              *(Bmain[j + 1] - Bmain[j]);
            }
        }
    }

    /* stable insertion sort of perm by decreasing work */
    for (i = 1; i < Pl; i++)
    {
        slong moving = perm[i];

        j = i;
        while (j > 0 && (Pchunks + perm[j - 1])->work
                                               < (Pchunks + moving)->work)
        {
            perm[j] = perm[j - 1];
            j--;
        }
        perm[j] = moving;
    }

    /* shared state read by every worker */
    base->idx = 0;
    base->nthreads = num_handles + 1;
    base->Al = Al;
    base->Bl = Bl;
    base->Pl = Pl;
    base->Acoeffs = A->coeffs;
    base->Bcoeffs = B->coeffs;
    base->Amain = Amain;
    base->Bmain = Bmain;
    base->Apexp = Apexp;
    base->Bpexp = Bpexp;
    base->perm = perm;
    base->nvars = nvars;
    base->mults = mults;
    base->array_size = array_size;
    base->ctx = ctx;
    base->Pchunks = Pchunks;

    args = (_worker_arg_struct *) TMP_ALLOC(base->nthreads
                                                *sizeof(_worker_arg_struct));

#if HAVE_PTHREAD
    pthread_mutex_init(&base->mutex, NULL);
#endif

    /* start the pool threads, then take part from the calling thread */
    for (k = 0; k < num_handles; k++)
    {
        args[k].idx = k;
        args[k].base = base;
        thread_pool_wake(global_thread_pool, handles[k], 0,
                        _nmod_mpoly_mul_array_threaded_worker_LEX, &args[k]);
    }
    args[num_handles].idx = num_handles;
    args[num_handles].base = base;
    _nmod_mpoly_mul_array_threaded_worker_LEX(&args[num_handles]);

    for (k = 0; k < num_handles; k++)
        thread_pool_wait(global_thread_pool, handles[k]);

#if HAVE_PTHREAD
    pthread_mutex_destroy(&base->mutex);
#endif

    /* concatenate the chunks in index order and release their storage */
    Plen = 0;
    for (k = 0; k < Pl; k++)
    {
        _chunk_struct * chunk = Pchunks + k;

        _nmod_mpoly_fit_length(&P->coeffs, &P->exps, &P->alloc,
                                                      Plen + chunk->len, 1);

        FLINT_ASSERT(chunk->poly->coeffs != NULL);
        FLINT_ASSERT(chunk->poly->exps != NULL);

        memcpy(P->exps + Plen, chunk->poly->exps,
                                                  chunk->len*sizeof(ulong));
        memcpy(P->coeffs + Plen, chunk->poly->coeffs,
                                              chunk->len*sizeof(mp_limb_t));
        Plen += chunk->len;

        flint_free(chunk->poly->coeffs);
        flint_free(chunk->poly->exps);
    }
    P->length = Plen;

    flint_free(Apexp);
    flint_free(Bpexp);
    TMP_END;
}
349
350
/*
    Try to set A to B*C with the threaded dense array method for ORD_LEX.

    maxBfields and maxCfields hold the maximum of each exponent field of B
    and C. Returns 1 if the product was computed, and 0 (leaving A alone) if
    the problem is unsuitable: the multiplier of the last field exceeds
    MAX_LEX_SIZE, the array size overflows or exceeds MAX_ARRAY_SIZE, or the
    exponents of the result do not fit in one word.
    B and C must be nonzero with one-word exponents; A may alias B or C.
*/
int _nmod_mpoly_mul_array_threaded_pool_LEX(
    nmod_mpoly_t A,
    const nmod_mpoly_t B, fmpz * maxBfields,
    const nmod_mpoly_t C, fmpz * maxCfields,
    const nmod_mpoly_ctx_t ctx,
    const thread_pool_handle * handles,
    slong num_handles)
{
    slong i, exp_bits, array_size;
    ulong max, * mults;
    int success = 0;    /* stays 0 on every early exit */
    TMP_INIT;

    FLINT_ASSERT(B->length != 0);
    FLINT_ASSERT(C->length != 0);

    FLINT_ASSERT(ctx->minfo->ord == ORD_LEX);

    FLINT_ASSERT(1 == mpoly_words_per_exp(B->bits, ctx->minfo));
    FLINT_ASSERT(1 == mpoly_words_per_exp(C->bits, ctx->minfo));

    TMP_START;

    /* mults[i] = 1 + (largest exponent the product can have in field i) */
    mults = (ulong *) TMP_ALLOC(ctx->minfo->nfields*sizeof(ulong));

    /* the field of index n-1 is the one that will be pulled out */
    i = ctx->minfo->nfields - 1;
    FLINT_ASSERT(fmpz_fits_si(maxBfields + i));
    FLINT_ASSERT(fmpz_fits_si(maxCfields + i));
    mults[i] = 1 + fmpz_get_ui(maxBfields + i) + fmpz_get_ui(maxCfields + i);
    max = mults[i];
    if (((slong) mults[i]) <= 0 || mults[i] > MAX_LEX_SIZE)
        goto cleanup;

    /* the fields of index n-2...0 contribute to the array size */
    array_size = WORD(1);
    for (i = ctx->minfo->nfields - 2; i >= 0; i--)
    {
        ulong hi;

        FLINT_ASSERT(fmpz_fits_si(maxBfields + i));
        FLINT_ASSERT(fmpz_fits_si(maxCfields + i));
        mults[i] = 1 + fmpz_get_ui(maxBfields + i)
                     + fmpz_get_ui(maxCfields + i);
        max |= mults[i];

        /* multiply with overflow detection through the high word */
        umul_ppmm(hi, array_size, array_size, mults[i]);
        if (hi != 0 || (slong) mults[i] <= 0
                    || array_size <= 0
                    || array_size > MAX_ARRAY_SIZE)
        {
            goto cleanup;
        }
    }

    exp_bits = FLINT_MAX(MPOLY_MIN_BITS, FLINT_BIT_COUNT(max) + 1);
    exp_bits = mpoly_fix_bits(exp_bits, ctx->minfo);

    /* array multiplication assumes the result fits into 1 word */
    if (1 != mpoly_words_per_exp(exp_bits, ctx->minfo))
        goto cleanup;

    if (A != B && A != C)
    {
        /* no aliasing: write straight into A */
        nmod_mpoly_fit_length(A, B->length + C->length - 1, ctx);
        nmod_mpoly_fit_bits(A, exp_bits, ctx);
        A->bits = exp_bits;
        _nmod_mpoly_mul_array_chunked_threaded_LEX(A, C, B, mults, ctx,
                                                       handles, num_handles);
    }
    else
    {
        /* A aliases an input: build the product in a temporary */
        nmod_mpoly_t T;
        nmod_mpoly_init3(T, B->length + C->length - 1, exp_bits, ctx);
        _nmod_mpoly_mul_array_chunked_threaded_LEX(T, C, B, mults, ctx,
                                                       handles, num_handles);
        nmod_mpoly_swap(T, A, ctx);
        nmod_mpoly_clear(T, ctx);
    }
    success = 1;

cleanup:

    TMP_END;

    return success;
}
444
445
446
447
448 /*****************************
449 DEGLEX and DEGREVLEX
450 *****************************/
451
_nmod_mpoly_mul_array_threaded_worker_DEG(void * varg)452 static void _nmod_mpoly_mul_array_threaded_worker_DEG(void * varg)
453 {
454 slong i, j, Pi;
455 _worker_arg_struct * arg = (_worker_arg_struct *) varg;
456 _base_struct * base = arg->base;
457 slong Al = base->Al;
458 slong Bl = base->Bl;
459 slong Pl = base->Pl;
460 slong * Amain = base->Amain;
461 slong * Bmain = base->Bmain;
462 ulong * coeff_array;
463 slong (* upack_sm1)(nmod_mpoly_t, slong, ulong *, slong, slong, slong, const nmod_mpoly_ctx_t);
464 slong (* upack_sm2)(nmod_mpoly_t, slong, ulong *, slong, slong, slong, const nmod_mpoly_ctx_t);
465 slong (* upack_sm3)(nmod_mpoly_t, slong, ulong *, slong, slong, slong, const nmod_mpoly_ctx_t);
466 TMP_INIT;
467
468 upack_sm1 = &nmod_mpoly_append_array_sm1_DEGLEX;
469 upack_sm2 = &nmod_mpoly_append_array_sm2_DEGLEX;
470 upack_sm3 = &nmod_mpoly_append_array_sm3_DEGLEX;
471 if (base->rev)
472 {
473 upack_sm1 = &nmod_mpoly_append_array_sm1_DEGREVLEX;
474 upack_sm2 = &nmod_mpoly_append_array_sm2_DEGREVLEX;
475 upack_sm3 = &nmod_mpoly_append_array_sm3_DEGREVLEX;
476 }
477
478 TMP_START;
479 coeff_array = (ulong *) TMP_ALLOC(3*base->array_size*sizeof(ulong));
480 for (j = 0; j < 3*base->array_size; j++)
481 coeff_array[j] = 0;
482
483 #if HAVE_PTHREAD
484 pthread_mutex_lock(&base->mutex);
485 #endif
486 Pi = base->idx;
487 base->idx = Pi + 1;
488 #if HAVE_PTHREAD
489 pthread_mutex_unlock(&base->mutex);
490 #endif
491
492 while (Pi < Pl)
493 {
494 slong len;
495 mp_limb_t t2, t1, t0, u1, u0;
496
497 Pi = base->perm[Pi];
498
499 /* work out bit counts for this chunk */
500 len = 0;
501 for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
502 {
503 if (j < Bl)
504 {
505 len += FLINT_MIN(Amain[i + 1] - Amain[i],
506 Bmain[j + 1] - Bmain[j]);
507 }
508 }
509
510 umul_ppmm(t1, t0, base->ctx->ffinfo->mod.n - 1,
511 base->ctx->ffinfo->mod.n - 1);
512 umul_ppmm(t2, t1, t1, len);
513 umul_ppmm(u1, u0, t0, len);
514 add_sssaaaaaa(t2, t1, t0, t2, t1, UWORD(0), UWORD(0), u1, u0);
515
516 (base->Pchunks + Pi)->len = 0;
517
518 if (t2 != 0)
519 {
520 /* need three words */
521 for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
522 {
523 if (j >= Bl)
524 continue;
525
526 _nmod_mpoly_addmul_array1_ulong3(coeff_array,
527 base->Acoeffs + base->Amain[i],
528 base->Apexp + base->Amain[i],
529 base->Amain[i + 1] - base->Amain[i],
530 base->Bcoeffs + base->Bmain[j],
531 base->Bpexp + base->Bmain[j],
532 base->Bmain[j + 1] - base->Bmain[j]);
533 }
534
535 (base->Pchunks + Pi)->len = upack_sm3((base->Pchunks + Pi)->poly, 0,
536 coeff_array, Pl - Pi - 1, base->nvars, base->degb, base->ctx);
537 }
538 else if (t1 != 0)
539 {
540 /* fits into two words */
541 for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
542 {
543 if (j >= Bl)
544 continue;
545
546 _nmod_mpoly_addmul_array1_ulong2(coeff_array,
547 base->Acoeffs + base->Amain[i],
548 base->Apexp + base->Amain[i],
549 base->Amain[i + 1] - base->Amain[i],
550 base->Bcoeffs + base->Bmain[j],
551 base->Bpexp + base->Bmain[j],
552 base->Bmain[j + 1] - base->Bmain[j]);
553 }
554
555 (base->Pchunks + Pi)->len = upack_sm2((base->Pchunks + Pi)->poly, 0,
556 coeff_array, Pl - Pi - 1, base->nvars, base->degb, base->ctx);
557 }
558 else if (t0 != 0)
559 {
560 /* fits into one word */
561 for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
562 {
563 if (j >= Bl)
564 continue;
565
566 _nmod_mpoly_addmul_array1_ulong1(coeff_array,
567 base->Acoeffs + base->Amain[i],
568 base->Apexp + base->Amain[i],
569 base->Amain[i + 1] - base->Amain[i],
570 base->Bcoeffs + base->Bmain[j],
571 base->Bpexp + base->Bmain[j],
572 base->Bmain[j + 1] - base->Bmain[j]);
573 }
574
575 (base->Pchunks + Pi)->len = upack_sm1((base->Pchunks + Pi)->poly, 0,
576 coeff_array, Pl - Pi - 1, base->nvars, base->degb, base->ctx);
577 }
578
579 #if HAVE_PTHREAD
580 pthread_mutex_lock(&base->mutex);
581 #endif
582 Pi = base->idx;
583 base->idx = Pi + 1;
584 #if HAVE_PTHREAD
585 pthread_mutex_unlock(&base->mutex);
586 #endif
587 }
588
589 TMP_END;
590 }
591
592
593
/*
    Set P to A*B using the chunked dense array method for ORD_DEGLEX and
    ORD_DEGREVLEX, sharing the chunks of the product between the calling
    thread and the num_handles pool threads in handles.

    degb is the degree bound; the dense array has degb^(nvars - 1) entries,
    computed without further checking, so the caller is expected to have
    validated it. P->bits must already be set: the per-chunk polynomials are
    created with that bit count.
*/
void _nmod_mpoly_mul_array_chunked_threaded_DEG(
    nmod_mpoly_t P,
    const nmod_mpoly_t A,
    const nmod_mpoly_t B,
    ulong degb,
    const nmod_mpoly_ctx_t ctx,
    const thread_pool_handle * handles,
    slong num_handles)
{
    const slong nvars = ctx->minfo->nvars;
    slong Al, Bl, Pl, Plen, array_size;
    slong i, j, k;
    slong * Amain, * Bmain, * perm;
    ulong * Apexp, * Bpexp;
    _chunk_struct * Pchunks;
    _worker_arg_struct * args;
    _base_t base;
    TMP_INIT;

    /* chunk counts come from the top field of the leading exponents */
    Al = 1 + (slong) (A->exps[0] >> (A->bits*nvars));
    Bl = 1 + (slong) (B->exps[0] >> (B->bits*nvars));
    Pl = Al + Bl - 1;

    array_size = 1;
    for (i = 0; i < nvars - 1; i++)
        array_size *= degb;

    TMP_START;

    /* split both inputs into chunks with respect to the main field */
    Amain = (slong *) TMP_ALLOC((Al + 1)*sizeof(slong));
    Bmain = (slong *) TMP_ALLOC((Bl + 1)*sizeof(slong));
    Apexp = (ulong *) flint_malloc(A->length*sizeof(ulong));
    Bpexp = (ulong *) flint_malloc(B->length*sizeof(ulong));
    mpoly_main_variable_split_DEG(Amain, Apexp, A->exps, Al, A->length,
                                                      degb, nvars, A->bits);
    mpoly_main_variable_split_DEG(Bmain, Bpexp, B->exps, Bl, B->length,
                                                      degb, nvars, B->bits);

    FLINT_ASSERT(Pl == degb);

    /* set up one output chunk per main exponent and estimate its cost */
    Pchunks = (_chunk_struct *) TMP_ALLOC(Pl*sizeof(_chunk_struct));
    perm = (slong *) TMP_ALLOC(Pl*sizeof(slong));
    for (k = 0; k < Pl; k++)
    {
        _chunk_struct * chunk = Pchunks + k;

        nmod_mpoly_init2(chunk->poly, 8, ctx);
        nmod_mpoly_fit_bits(chunk->poly, P->bits, ctx);
        chunk->work = 0;
        perm[k] = k;

        for (i = 0, j = k; i < Al && j >= 0; i++, j--)
        {
            if (j >= Bl)
                continue;

            chunk->work += (Amain[i + 1] - Amain[i])
                          *(Bmain[j + 1] - Bmain[j]);
        }
    }

    /* stable insertion sort of perm by decreasing work */
    for (i = 1; i < Pl; i++)
    {
        slong moving = perm[i];

        j = i;
        while (j > 0 && (Pchunks + perm[j - 1])->work
                                               < (Pchunks + moving)->work)
        {
            perm[j] = perm[j - 1];
            j--;
        }
        perm[j] = moving;
    }

    /* shared state read by every worker */
    base->idx = 0;
    base->nthreads = num_handles + 1;
    base->Al = Al;
    base->Bl = Bl;
    base->Pl = Pl;
    base->Acoeffs = A->coeffs;
    base->Bcoeffs = B->coeffs;
    base->Amain = Amain;
    base->Bmain = Bmain;
    base->Apexp = Apexp;
    base->Bpexp = Bpexp;
    base->perm = perm;
    base->nvars = nvars;
    base->array_size = array_size;
    base->degb = degb;
    base->rev = (ctx->minfo->ord == ORD_DEGREVLEX);
    base->ctx = ctx;
    base->Pchunks = Pchunks;

    args = (_worker_arg_struct *) TMP_ALLOC(base->nthreads
                                                *sizeof(_worker_arg_struct));

#if HAVE_PTHREAD
    pthread_mutex_init(&base->mutex, NULL);
#endif

    /* start the pool threads, then take part from the calling thread */
    for (k = 0; k < num_handles; k++)
    {
        args[k].idx = k;
        args[k].base = base;
        thread_pool_wake(global_thread_pool, handles[k], 0,
                        _nmod_mpoly_mul_array_threaded_worker_DEG, &args[k]);
    }
    args[num_handles].idx = num_handles;
    args[num_handles].base = base;
    _nmod_mpoly_mul_array_threaded_worker_DEG(&args[num_handles]);

    for (k = 0; k < num_handles; k++)
        thread_pool_wait(global_thread_pool, handles[k]);

#if HAVE_PTHREAD
    pthread_mutex_destroy(&base->mutex);
#endif

    /* concatenate the chunks in index order and release their storage */
    Plen = 0;
    for (k = 0; k < Pl; k++)
    {
        _chunk_struct * chunk = Pchunks + k;

        _nmod_mpoly_fit_length(&P->coeffs, &P->exps, &P->alloc,
                                                      Plen + chunk->len, 1);

        FLINT_ASSERT(chunk->poly->coeffs != NULL);
        FLINT_ASSERT(chunk->poly->exps != NULL);

        memcpy(P->exps + Plen, chunk->poly->exps,
                                                  chunk->len*sizeof(ulong));
        memcpy(P->coeffs + Plen, chunk->poly->coeffs,
                                              chunk->len*sizeof(mp_limb_t));
        Plen += chunk->len;

        flint_free(chunk->poly->coeffs);
        flint_free(chunk->poly->exps);
    }
    P->length = Plen;

    flint_free(Apexp);
    flint_free(Bpexp);
    TMP_END;
}
737
/*
    Try to set A to B*C with the threaded dense array method for ORD_DEGLEX
    and ORD_DEGREVLEX.

    maxBfields and maxCfields hold the maximum of each exponent field of B
    and C; only the last field is read here. Returns 1 if the product was
    computed, and 0 (leaving A alone) if the problem is unsuitable: the
    degree bound or the array size overflows or exceeds MAX_ARRAY_SIZE, or
    the exponents of the result do not fit in one word.
    B and C must be nonzero with one-word exponents; A may alias B or C.
*/
int _nmod_mpoly_mul_array_threaded_pool_DEG(
    nmod_mpoly_t A,
    const nmod_mpoly_t B, fmpz * maxBfields,
    const nmod_mpoly_t C, fmpz * maxCfields,
    const nmod_mpoly_ctx_t ctx,
    const thread_pool_handle * handles,
    slong num_handles)
{
    slong i, exp_bits, array_size;
    ulong deg;

    FLINT_ASSERT(B->length != 0);
    FLINT_ASSERT(C->length != 0);

    FLINT_ASSERT(  ctx->minfo->ord == ORD_DEGREVLEX
                || ctx->minfo->ord == ORD_DEGLEX);

    FLINT_ASSERT(1 == mpoly_words_per_exp(B->bits, ctx->minfo));
    FLINT_ASSERT(1 == mpoly_words_per_exp(C->bits, ctx->minfo));

    /* the field of index n-1 is the one that will be pulled out */
    i = ctx->minfo->nfields - 1;
    FLINT_ASSERT(fmpz_fits_si(maxBfields + i));
    FLINT_ASSERT(fmpz_fits_si(maxCfields + i));
    deg = 1 + fmpz_get_ui(maxBfields + i) + fmpz_get_ui(maxCfields + i);
    if (((slong) deg) <= 0 || deg > MAX_ARRAY_SIZE)
        return 0;

    /* the fields of index n-2...1 each contribute a factor deg */
    array_size = WORD(1);
    for (i = ctx->minfo->nfields - 2; i >= 1; i--)
    {
        ulong hi;

        /* multiply with overflow detection through the high word */
        umul_ppmm(hi, array_size, array_size, deg);
        if (hi != 0 || array_size <= 0 || array_size > MAX_ARRAY_SIZE)
            return 0;
    }

    exp_bits = FLINT_MAX(MPOLY_MIN_BITS, FLINT_BIT_COUNT(deg) + 1);
    exp_bits = mpoly_fix_bits(exp_bits, ctx->minfo);

    /* array multiplication assumes the result fits into 1 word */
    if (1 != mpoly_words_per_exp(exp_bits, ctx->minfo))
        return 0;

    if (A != B && A != C)
    {
        /* no aliasing: write straight into A */
        nmod_mpoly_fit_length(A, B->length + C->length - 1, ctx);
        nmod_mpoly_fit_bits(A, exp_bits, ctx);
        A->bits = exp_bits;
        _nmod_mpoly_mul_array_chunked_threaded_DEG(A, C, B, deg, ctx,
                                                       handles, num_handles);
    }
    else
    {
        /* A aliases an input: build the product in a temporary */
        nmod_mpoly_t T;
        nmod_mpoly_init3(T, B->length + C->length - 1, exp_bits, ctx);
        _nmod_mpoly_mul_array_chunked_threaded_DEG(T, C, B, deg, ctx,
                                                       handles, num_handles);
        nmod_mpoly_swap(T, A, ctx);
        nmod_mpoly_clear(T, ctx);
    }

    return 1;
}
818
819
/*
    Try to set A to B*C using the threaded dense array method.

    Returns 1 on success. If B or C is zero, A is set to zero and 1 is
    returned. Returns 0, without computing the product, when B or C needs
    more than one word per exponent vector, when the monomial ordering is
    not LEX, DEGLEX or DEGREVLEX, or when the ordering-specific routine
    reports that the problem is unsuitable.

    Up to min(B->length, C->length)/16 threads are requested and are handed
    back before returning.
*/
int nmod_mpoly_mul_array_threaded(
    nmod_mpoly_t A,
    const nmod_mpoly_t B,
    const nmod_mpoly_t C,
    const nmod_mpoly_ctx_t ctx)
{
    const slong nfields = ctx->minfo->nfields;
    slong k;
    slong num_handles;
    thread_pool_handle * handles;
    fmpz * maxBfields, * maxCfields;
    int success;
    TMP_INIT;

    if (B->length == 0 || C->length == 0)
    {
        nmod_mpoly_zero(A, ctx);
        return 1;
    }

    /* the array method works on single-word exponent vectors only */
    if (mpoly_words_per_exp(B->bits, ctx->minfo) != 1 ||
        mpoly_words_per_exp(C->bits, ctx->minfo) != 1)
    {
        return 0;
    }

    TMP_START;

    /* field-wise maxima of the exponents of both inputs */
    maxBfields = (fmpz *) TMP_ALLOC(nfields*sizeof(fmpz));
    maxCfields = (fmpz *) TMP_ALLOC(nfields*sizeof(fmpz));
    for (k = 0; k < nfields; k++)
    {
        fmpz_init(maxBfields + k);
        fmpz_init(maxCfields + k);
    }
    mpoly_max_fields_fmpz(maxBfields, B->exps, B->length, B->bits, ctx->minfo);
    mpoly_max_fields_fmpz(maxCfields, C->exps, C->length, C->bits, ctx->minfo);

    num_handles = flint_request_threads(&handles,
                                    FLINT_MIN(B->length, C->length)/16);

    /* dispatch on the monomial ordering */
    if (ctx->minfo->ord == ORD_LEX)
    {
        success = _nmod_mpoly_mul_array_threaded_pool_LEX(A,
                     B, maxBfields, C, maxCfields, ctx, handles, num_handles);
    }
    else if (ctx->minfo->ord == ORD_DEGREVLEX
                                          || ctx->minfo->ord == ORD_DEGLEX)
    {
        success = _nmod_mpoly_mul_array_threaded_pool_DEG(A,
                     B, maxBfields, C, maxCfields, ctx, handles, num_handles);
    }
    else
    {
        success = 0;
    }

    flint_give_back_threads(handles, num_handles);

    for (k = 0; k < nfields; k++)
    {
        fmpz_clear(maxBfields + k);
        fmpz_clear(maxCfields + k);
    }

    TMP_END;
    return success;
}
893