1 /*
2     Copyright (C) 2018 Daniel Schultz
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <http://www.gnu.org/licenses/>.
10 */
11 
#include <stdlib.h>
#include "string.h"
#include "thread_pool.h"
#include "nmod_mpoly.h"
15 
/*
    BLOCK is not referenced anywhere in this file; it was originally
    annotated as a constant to "improve locality".
*/
#define BLOCK 128
/*
    Upper limit on the number of cells of the dense coefficient array and,
    in the DEGLEX/DEGREVLEX case, on the bound of the pulled-out field.
*/
#define MAX_ARRAY_SIZE (WORD(300000))
/* Upper limit on the bound of the main (pulled-out) field in the LEX case. */
#define MAX_LEX_SIZE (WORD(300))
20 
21 
22 typedef struct
23 {
24     slong idx;
25     slong work;
26     slong len;
27     nmod_mpoly_t poly;
28 }
29 _chunk_struct;
30 
31 
32 typedef struct
33 {
34 #if HAVE_PTHREAD
35     pthread_mutex_t mutex;
36 #endif
37     volatile int idx;
38     slong nthreads;
39     slong Al, Bl, Pl;
40     mp_limb_t * Acoeffs, * Bcoeffs;
41     slong * Amain, * Bmain;
42     ulong * Apexp, * Bpexp;
43     slong * perm;
44     slong nvars;
45     const ulong * mults;
46     slong array_size;
47     slong degb;
48     const nmod_mpoly_ctx_struct * ctx;
49     _chunk_struct * Pchunks;
50     int rev;
51 }
52 _base_struct;
53 
54 typedef _base_struct _base_t[1];
55 
56 
57 typedef struct
58 {
59     slong idx;
60     slong time;
61     _base_struct * base;
62     ulong * exp;
63 }
64 _worker_arg_struct;
65 
66 
67 /******************
68     LEX
69 ******************/
70 
_nmod_mpoly_mul_array_threaded_worker_LEX(void * varg)71 static void _nmod_mpoly_mul_array_threaded_worker_LEX(void * varg)
72 {
73     slong i, j, Pi;
74     _worker_arg_struct * arg = (_worker_arg_struct *) varg;
75     _base_struct * base = arg->base;
76     slong  Al = base->Al;
77     slong Bl = base->Bl;
78     slong Pl = base->Pl;
79     slong * Amain = base->Amain;
80     slong * Bmain = base->Bmain;
81     ulong * coeff_array;
82 
83     TMP_INIT;
84 
85     TMP_START;
86     coeff_array = (ulong *) TMP_ALLOC(3*base->array_size*sizeof(ulong));
87     for (j = 0; j < 3*base->array_size; j++)
88         coeff_array[j] = 0;
89 
90 #if HAVE_PTHREAD
91     pthread_mutex_lock(&base->mutex);
92 #endif
93     Pi = base->idx;
94     base->idx = Pi + 1;
95 #if HAVE_PTHREAD
96     pthread_mutex_unlock(&base->mutex);
97 #endif
98 
99     while (Pi < Pl)
100     {
101         slong len;
102         mp_limb_t t2, t1, t0, u1, u0;
103 
104         Pi = base->perm[Pi];
105 
106         /* work out bit counts for this chunk */
107         len = 0;
108         for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
109         {
110             if (j < Bl)
111             {
112                 len += FLINT_MIN(Amain[i + 1] - Amain[i],
113                                  Bmain[j + 1] - Bmain[j]);
114             }
115         }
116 
117         umul_ppmm(t1, t0, base->ctx->ffinfo->mod.n - 1,
118                           base->ctx->ffinfo->mod.n - 1);
119         umul_ppmm(t2, t1, t1, len);
120         umul_ppmm(u1, u0, t0, len);
121         add_sssaaaaaa(t2, t1, t0,  t2, t1, UWORD(0),  UWORD(0), u1, u0);
122 
123         (base->Pchunks + Pi)->len = 0;
124 
125         if (t2 != 0)
126         {
127             /* need three words */
128             for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
129             {
130                 if (j >= Bl)
131                     continue;
132 
133                 _nmod_mpoly_addmul_array1_ulong3(coeff_array,
134                         base->Acoeffs + base->Amain[i],
135                             base->Apexp + base->Amain[i],
136                             base->Amain[i + 1] - base->Amain[i],
137                         base->Bcoeffs + base->Bmain[j],
138                             base->Bpexp + base->Bmain[j],
139                             base->Bmain[j + 1] - base->Bmain[j]);
140             }
141 
142             (base->Pchunks + Pi)->len =
143                 nmod_mpoly_append_array_sm3_LEX(
144                     (base->Pchunks + Pi)->poly, 0,
145                     coeff_array, base->mults, base->nvars - 1,
146                     base->array_size, Pl - Pi - 1, base->ctx);
147         }
148         else if (t1 != 0)
149         {
150             /* fits into two words */
151             for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
152             {
153                 if (j >= Bl)
154                     continue;
155 
156                 _nmod_mpoly_addmul_array1_ulong2(coeff_array,
157                         base->Acoeffs + base->Amain[i],
158                             base->Apexp + base->Amain[i],
159                             base->Amain[i + 1] - base->Amain[i],
160                         base->Bcoeffs + base->Bmain[j],
161                             base->Bpexp + base->Bmain[j],
162                             base->Bmain[j + 1] - base->Bmain[j]);
163             }
164 
165             (base->Pchunks + Pi)->len =
166                 nmod_mpoly_append_array_sm2_LEX(
167                     (base->Pchunks + Pi)->poly, 0,
168                     coeff_array, base->mults, base->nvars - 1,
169                     base->array_size, Pl - Pi - 1, base->ctx);
170 
171         }
172         else if (t0 != 0)
173         {
174             /* fits into one word */
175             for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
176             {
177                 if (j >= Bl)
178                     continue;
179 
180                 _nmod_mpoly_addmul_array1_ulong1(coeff_array,
181                         base->Acoeffs + base->Amain[i],
182                             base->Apexp + base->Amain[i],
183                             base->Amain[i + 1] - base->Amain[i],
184                         base->Bcoeffs + base->Bmain[j],
185                             base->Bpexp + base->Bmain[j],
186                             base->Bmain[j + 1] - base->Bmain[j]);
187             }
188 
189             (base->Pchunks + Pi)->len =
190                 nmod_mpoly_append_array_sm1_LEX(
191                     (base->Pchunks + Pi)->poly, 0,
192                     coeff_array, base->mults, base->nvars - 1,
193                     base->array_size, Pl - Pi - 1, base->ctx);
194         }
195 
196 #if HAVE_PTHREAD
197         pthread_mutex_lock(&base->mutex);
198 #endif
199 	Pi = base->idx;
200         base->idx = Pi + 1;
201 #if HAVE_PTHREAD
202         pthread_mutex_unlock(&base->mutex);
203 #endif
204     }
205 
206     TMP_END;
207 }
208 
_nmod_mpoly_mul_array_chunked_threaded_LEX(nmod_mpoly_t P,const nmod_mpoly_t A,const nmod_mpoly_t B,const ulong * mults,const nmod_mpoly_ctx_t ctx,const thread_pool_handle * handles,slong num_handles)209 void _nmod_mpoly_mul_array_chunked_threaded_LEX(
210     nmod_mpoly_t P,
211     const nmod_mpoly_t A,
212     const nmod_mpoly_t B,
213     const ulong * mults,
214     const nmod_mpoly_ctx_t ctx,
215     const thread_pool_handle * handles,
216     slong num_handles)
217 {
218     slong nvars = ctx->minfo->nvars;
219     slong Pi, i, j, Plen, Pl, Al, Bl, array_size;
220     slong * Amain, * Bmain;
221     ulong * Apexp, * Bpexp;
222     _base_t base;
223     _worker_arg_struct * args;
224     _chunk_struct * Pchunks;
225     slong * perm;
226     TMP_INIT;
227 
228     array_size = 1;
229     for (i = 0; i < nvars - 1; i++) {
230         array_size *= mults[i];
231     }
232 
233     /* compute lengths of poly2 and poly3 in chunks */
234     Al = 1 + (slong) (A->exps[0] >> (A->bits*(nvars - 1)));
235     Bl = 1 + (slong) (B->exps[0] >> (B->bits*(nvars - 1)));
236 
237     TMP_START;
238 
239     /* compute indices and lengths of coefficients of polys in main variable */
240     Amain = (slong *) TMP_ALLOC((Al + 1)*sizeof(slong));
241     Bmain = (slong *) TMP_ALLOC((Bl + 1)*sizeof(slong));
242     Apexp = (ulong *) flint_malloc(A->length*sizeof(ulong));
243     Bpexp = (ulong *) flint_malloc(B->length*sizeof(ulong));
244     mpoly_main_variable_split_LEX(Amain, Apexp, A->exps, Al, A->length,
245                                                     mults, nvars - 1, A->bits);
246     mpoly_main_variable_split_LEX(Bmain, Bpexp, B->exps, Bl, B->length,
247                                                     mults, nvars - 1, B->bits);
248 
249     Pl = Al + Bl - 1;
250 
251     /* work out data for each chunk of the output */
252     Pchunks = (_chunk_struct *) TMP_ALLOC(Pl*sizeof(_chunk_struct));
253     perm = (slong *) TMP_ALLOC(Pl*sizeof(slong));
254     for (Pi = 0; Pi < Pl; Pi++)
255     {
256         nmod_mpoly_init2((Pchunks + Pi)->poly, 8, ctx);
257         nmod_mpoly_fit_bits((Pchunks + Pi)->poly, P->bits, ctx);
258         (Pchunks + Pi)->work = 0;
259         perm[Pi] = Pi;
260         for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
261         {
262             if (j >= Bl)
263                 continue;
264 
265             (Pchunks + Pi)->work += (Amain[i + 1] - Amain[i])
266                                    *(Bmain[j + 1] - Bmain[j]);
267         }
268     }
269 
270     for (i = 0; i < Pl; i++)
271     {
272         for (j = i; j > 0 && (Pchunks + perm[j - 1])->work
273                                    < (Pchunks + perm[j])->work; j--)
274         {
275             slong t = perm[j - 1];
276             perm[j - 1] = perm[j];
277             perm[j] = t;
278         }
279     }
280 
281     base->nthreads = num_handles + 1;
282     base->Al = Al;
283     base->Bl = Bl;
284     base->Pl = Pl;
285     base->Acoeffs = A->coeffs;
286     base->Amain = Amain;
287     base->Apexp = Apexp;
288     base->Bcoeffs = B->coeffs;
289     base->Bmain = Bmain;
290     base->Bpexp = Bpexp;
291     base->idx = 0;
292     base->perm = perm;
293     base->nvars = nvars;
294     base->ctx = ctx;
295     base->Pchunks = Pchunks;
296     base->array_size = array_size;
297     base->mults = mults;
298 
299     args = (_worker_arg_struct *) TMP_ALLOC(base->nthreads
300                                                   *sizeof(_worker_arg_struct));
301 
302 #if HAVE_PTHREAD
303     pthread_mutex_init(&base->mutex, NULL);
304 #endif
305     for (i = 0; i < num_handles; i++)
306     {
307         args[i].idx = i;
308         args[i].base = base;
309         thread_pool_wake(global_thread_pool, handles[i], 0,
310                           _nmod_mpoly_mul_array_threaded_worker_LEX, &args[i]);
311     }
312     i = num_handles;
313     args[i].idx = i;
314     args[i].base = base;
315     _nmod_mpoly_mul_array_threaded_worker_LEX(&args[i]);
316     for (i = 0; i < num_handles; i++)
317     {
318         thread_pool_wait(global_thread_pool, handles[i]);
319     }
320 #if HAVE_PTHREAD
321     pthread_mutex_destroy(&base->mutex);
322 #endif
323 
324     /* join answers */
325     Plen = 0;
326     for (Pi = 0; Pi < Pl; Pi++)
327     {
328         _nmod_mpoly_fit_length(&P->coeffs, &P->exps, &P->alloc,
329                                                 Plen + (Pchunks + Pi)->len, 1);
330 
331         FLINT_ASSERT((Pchunks + Pi)->poly->coeffs != NULL);
332         FLINT_ASSERT((Pchunks + Pi)->poly->exps != NULL);
333 
334         memcpy(P->exps + Plen, (Pchunks + Pi)->poly->exps, (Pchunks + Pi)->len*sizeof(ulong));
335         memcpy(P->coeffs + Plen, (Pchunks + Pi)->poly->coeffs, (Pchunks + Pi)->len*sizeof(mp_limb_t));
336 
337         Plen += (Pchunks + Pi)->len;
338 
339         flint_free((Pchunks + Pi)->poly->coeffs);
340         flint_free((Pchunks + Pi)->poly->exps);
341     }
342 
343     P->length = Plen;
344 
345     flint_free(Apexp);
346     flint_free(Bpexp);
347     TMP_END;
348 }
349 
350 
_nmod_mpoly_mul_array_threaded_pool_LEX(nmod_mpoly_t A,const nmod_mpoly_t B,fmpz * maxBfields,const nmod_mpoly_t C,fmpz * maxCfields,const nmod_mpoly_ctx_t ctx,const thread_pool_handle * handles,slong num_handles)351 int _nmod_mpoly_mul_array_threaded_pool_LEX(
352     nmod_mpoly_t A,
353     const nmod_mpoly_t B, fmpz * maxBfields,
354     const nmod_mpoly_t C, fmpz * maxCfields,
355     const nmod_mpoly_ctx_t ctx,
356     const thread_pool_handle * handles,
357     slong num_handles)
358 {
359     slong i, exp_bits, array_size;
360     ulong max, * mults;
361     int success;
362     TMP_INIT;
363 
364     FLINT_ASSERT(B->length != 0);
365     FLINT_ASSERT(C->length != 0);
366 
367     FLINT_ASSERT(ctx->minfo->ord == ORD_LEX);
368 
369     FLINT_ASSERT(1 == mpoly_words_per_exp(B->bits, ctx->minfo));
370     FLINT_ASSERT(1 == mpoly_words_per_exp(C->bits, ctx->minfo));
371 
372     TMP_START;
373 
374     /* compute maximum exponents for each variable */
375     mults = (ulong *) TMP_ALLOC(ctx->minfo->nfields*sizeof(ulong));
376 
377     /* the field of index n-1 is the one that wil be pulled out */
378     i = ctx->minfo->nfields - 1;
379     FLINT_ASSERT(fmpz_fits_si(maxBfields + i));
380     FLINT_ASSERT(fmpz_fits_si(maxCfields + i));
381     mults[i] = 1 + fmpz_get_ui(maxBfields + i) + fmpz_get_ui(maxCfields + i);
382     max = mults[i];
383     if (((slong) mults[i]) <= 0 || mults[i] > MAX_LEX_SIZE)
384     {
385         success = 0;
386         goto cleanup;
387     }
388 
389     /* the fields of index n-2...0, contribute to the array size */
390     array_size = WORD(1);
391     for (i--; i >= 0; i--)
392     {
393         ulong hi;
394         FLINT_ASSERT(fmpz_fits_si(maxBfields + i));
395         FLINT_ASSERT(fmpz_fits_si(maxCfields + i));
396         mults[i] = 1 + fmpz_get_ui(maxBfields + i) + fmpz_get_ui(maxCfields + i);
397         max |= mults[i];
398         umul_ppmm(hi, array_size, array_size, mults[i]);
399         if (hi != 0 || (slong) mults[i] <= 0
400                     || array_size <= 0
401                     || array_size > MAX_ARRAY_SIZE)
402         {
403             success = 0;
404             goto cleanup;
405         }
406     }
407 
408     exp_bits = FLINT_MAX(MPOLY_MIN_BITS, FLINT_BIT_COUNT(max) + 1);
409     exp_bits = mpoly_fix_bits(exp_bits, ctx->minfo);
410 
411     /* array multiplication assumes result fits into 1 word */
412     if (1 != mpoly_words_per_exp(exp_bits, ctx->minfo))
413     {
414         success = 0;
415         goto cleanup;
416     }
417 
418     /* handle aliasing and do array multiplication */
419     if (A == B || A == C)
420     {
421         nmod_mpoly_t T;
422         nmod_mpoly_init3(T, B->length + C->length - 1, exp_bits, ctx);
423         _nmod_mpoly_mul_array_chunked_threaded_LEX(T, C, B, mults, ctx,
424                                                          handles, num_handles);
425         nmod_mpoly_swap(T, A, ctx);
426         nmod_mpoly_clear(T, ctx);
427     }
428     else
429     {
430         nmod_mpoly_fit_length(A, B->length + C->length - 1, ctx);
431         nmod_mpoly_fit_bits(A, exp_bits, ctx);
432         A->bits = exp_bits;
433         _nmod_mpoly_mul_array_chunked_threaded_LEX(A, C, B, mults, ctx,
434                                                          handles, num_handles);
435     }
436     success = 1;
437 
438 cleanup:
439 
440     TMP_END;
441 
442     return success;
443 }
444 
445 
446 
447 
448 /*****************************
449     DEGLEX and DEGREVLEX
450 *****************************/
451 
_nmod_mpoly_mul_array_threaded_worker_DEG(void * varg)452 static void _nmod_mpoly_mul_array_threaded_worker_DEG(void * varg)
453 {
454     slong i, j, Pi;
455     _worker_arg_struct * arg = (_worker_arg_struct *) varg;
456     _base_struct * base = arg->base;
457     slong  Al = base->Al;
458     slong Bl = base->Bl;
459     slong Pl = base->Pl;
460     slong * Amain = base->Amain;
461     slong * Bmain = base->Bmain;
462     ulong * coeff_array;
463     slong (* upack_sm1)(nmod_mpoly_t, slong, ulong *, slong, slong, slong, const nmod_mpoly_ctx_t);
464     slong (* upack_sm2)(nmod_mpoly_t, slong, ulong *, slong, slong, slong, const nmod_mpoly_ctx_t);
465     slong (* upack_sm3)(nmod_mpoly_t, slong, ulong *, slong, slong, slong, const nmod_mpoly_ctx_t);
466     TMP_INIT;
467 
468     upack_sm1 = &nmod_mpoly_append_array_sm1_DEGLEX;
469     upack_sm2 = &nmod_mpoly_append_array_sm2_DEGLEX;
470     upack_sm3 = &nmod_mpoly_append_array_sm3_DEGLEX;
471     if (base->rev)
472     {
473         upack_sm1 = &nmod_mpoly_append_array_sm1_DEGREVLEX;
474         upack_sm2 = &nmod_mpoly_append_array_sm2_DEGREVLEX;
475         upack_sm3 = &nmod_mpoly_append_array_sm3_DEGREVLEX;
476     }
477 
478     TMP_START;
479     coeff_array = (ulong *) TMP_ALLOC(3*base->array_size*sizeof(ulong));
480     for (j = 0; j < 3*base->array_size; j++)
481         coeff_array[j] = 0;
482 
483 #if HAVE_PTHREAD
484     pthread_mutex_lock(&base->mutex);
485 #endif
486     Pi = base->idx;
487     base->idx = Pi + 1;
488 #if HAVE_PTHREAD
489     pthread_mutex_unlock(&base->mutex);
490 #endif
491 
492     while (Pi < Pl)
493     {
494         slong len;
495         mp_limb_t t2, t1, t0, u1, u0;
496 
497         Pi = base->perm[Pi];
498 
499         /* work out bit counts for this chunk */
500         len = 0;
501         for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
502         {
503             if (j < Bl)
504             {
505                 len += FLINT_MIN(Amain[i + 1] - Amain[i],
506                                  Bmain[j + 1] - Bmain[j]);
507             }
508         }
509 
510         umul_ppmm(t1, t0, base->ctx->ffinfo->mod.n - 1,
511                           base->ctx->ffinfo->mod.n - 1);
512         umul_ppmm(t2, t1, t1, len);
513         umul_ppmm(u1, u0, t0, len);
514         add_sssaaaaaa(t2, t1, t0,  t2, t1, UWORD(0),  UWORD(0), u1, u0);
515 
516         (base->Pchunks + Pi)->len = 0;
517 
518         if (t2 != 0)
519         {
520             /* need three words */
521             for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
522             {
523                 if (j >= Bl)
524                     continue;
525 
526                 _nmod_mpoly_addmul_array1_ulong3(coeff_array,
527                         base->Acoeffs + base->Amain[i],
528                             base->Apexp + base->Amain[i],
529                             base->Amain[i + 1] - base->Amain[i],
530                         base->Bcoeffs + base->Bmain[j],
531                             base->Bpexp + base->Bmain[j],
532                             base->Bmain[j + 1] - base->Bmain[j]);
533             }
534 
535             (base->Pchunks + Pi)->len = upack_sm3((base->Pchunks + Pi)->poly, 0,
536                        coeff_array, Pl - Pi - 1, base->nvars, base->degb, base->ctx);
537         }
538         else if (t1 != 0)
539         {
540             /* fits into two words */
541             for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
542             {
543                 if (j >= Bl)
544                     continue;
545 
546                 _nmod_mpoly_addmul_array1_ulong2(coeff_array,
547                         base->Acoeffs + base->Amain[i],
548                             base->Apexp + base->Amain[i],
549                             base->Amain[i + 1] - base->Amain[i],
550                         base->Bcoeffs + base->Bmain[j],
551                             base->Bpexp + base->Bmain[j],
552                             base->Bmain[j + 1] - base->Bmain[j]);
553             }
554 
555             (base->Pchunks + Pi)->len = upack_sm2((base->Pchunks + Pi)->poly, 0,
556                        coeff_array, Pl - Pi - 1, base->nvars, base->degb, base->ctx);
557         }
558         else if (t0 != 0)
559         {
560             /* fits into one word */
561             for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
562             {
563                 if (j >= Bl)
564                     continue;
565 
566                 _nmod_mpoly_addmul_array1_ulong1(coeff_array,
567                         base->Acoeffs + base->Amain[i],
568                             base->Apexp + base->Amain[i],
569                             base->Amain[i + 1] - base->Amain[i],
570                         base->Bcoeffs + base->Bmain[j],
571                             base->Bpexp + base->Bmain[j],
572                             base->Bmain[j + 1] - base->Bmain[j]);
573             }
574 
575             (base->Pchunks + Pi)->len = upack_sm1((base->Pchunks + Pi)->poly, 0,
576                        coeff_array, Pl - Pi - 1, base->nvars, base->degb, base->ctx);
577         }
578 
579 #if HAVE_PTHREAD
580 	pthread_mutex_lock(&base->mutex);
581 #endif
582         Pi = base->idx;
583         base->idx = Pi + 1;
584 #if HAVE_PTHREAD
585         pthread_mutex_unlock(&base->mutex);
586 #endif
587     }
588 
589     TMP_END;
590 }
591 
592 
593 
_nmod_mpoly_mul_array_chunked_threaded_DEG(nmod_mpoly_t P,const nmod_mpoly_t A,const nmod_mpoly_t B,ulong degb,const nmod_mpoly_ctx_t ctx,const thread_pool_handle * handles,slong num_handles)594 void _nmod_mpoly_mul_array_chunked_threaded_DEG(
595     nmod_mpoly_t P,
596     const nmod_mpoly_t A,
597     const nmod_mpoly_t B,
598     ulong degb,
599     const nmod_mpoly_ctx_t ctx,
600     const thread_pool_handle * handles,
601     slong num_handles)
602 {
603     slong nvars = ctx->minfo->nvars;
604     slong Pi, i, j, Plen, Pl, Al, Bl, array_size;
605     slong * Amain, * Bmain;
606     ulong * Apexp, * Bpexp;
607     _base_t base;
608     _worker_arg_struct * args;
609     _chunk_struct * Pchunks;
610     slong * perm;
611     TMP_INIT;
612 
613     /* compute lengths of poly2 and poly3 in chunks */
614     Al = 1 + (slong) (A->exps[0] >> (A->bits*nvars));
615     Bl = 1 + (slong) (B->exps[0] >> (B->bits*nvars));
616 
617     array_size = 1;
618     for (i = 0; i < nvars-1; i++) {
619         array_size *= degb;
620     }
621 
622     TMP_START;
623 
624     /* compute indices and lengths of coefficients of polys in main variable */
625     Amain = (slong *) TMP_ALLOC((Al + 1)*sizeof(slong));
626     Bmain = (slong *) TMP_ALLOC((Bl + 1)*sizeof(slong));
627     Apexp = (ulong *) flint_malloc(A->length*sizeof(ulong));
628     Bpexp = (ulong *) flint_malloc(B->length*sizeof(ulong));
629     mpoly_main_variable_split_DEG(Amain, Apexp, A->exps, Al, A->length,
630                                                          degb, nvars, A->bits);
631     mpoly_main_variable_split_DEG(Bmain, Bpexp, B->exps, Bl, B->length,
632                                                          degb, nvars, B->bits);
633 
634     Pl = Al + Bl - 1;
635     FLINT_ASSERT(Pl == degb);
636 
637     /* work out data for each chunk of the output */
638     Pchunks = (_chunk_struct *) TMP_ALLOC(Pl*sizeof(_chunk_struct));
639     perm = (slong *) TMP_ALLOC(Pl*sizeof(slong));
640     for (Pi = 0; Pi < Pl; Pi++)
641     {
642         nmod_mpoly_init2((Pchunks + Pi)->poly, 8, ctx);
643         nmod_mpoly_fit_bits((Pchunks + Pi)->poly, P->bits, ctx);
644         (Pchunks + Pi)->work = 0;
645         perm[Pi] = Pi;
646         for (i = 0, j = Pi; i < Al && j >= 0; i++, j--)
647         {
648             if (j < Bl)
649             {
650                 (Pchunks + Pi)->work += (Amain[i + 1] - Amain[i])
651                                        *(Bmain[j + 1] - Bmain[j]);
652             }
653         }
654     }
655 
656     for (i = 0; i < Pl; i++)
657     {
658         for (j = i; j > 0 && (Pchunks + perm[j-1])->work
659                                          < (Pchunks + perm[j])->work; j--)
660         {
661             slong t = perm[j - 1];
662             perm[j - 1] = perm[j];
663             perm[j] = t;
664         }
665     }
666 
667     base->nthreads = num_handles + 1;
668     base->Al = Al;
669     base->Bl = Bl;
670     base->Pl = Pl;
671     base->Acoeffs = A->coeffs;
672     base->Amain = Amain;
673     base->Apexp = Apexp;
674     base->Bcoeffs = B->coeffs;
675     base->Bmain = Bmain;
676     base->Bpexp = Bpexp;
677     base->idx = 0;
678     base->perm = perm;
679     base->nvars = nvars;
680     base->Pchunks = Pchunks;
681     base->ctx = ctx;
682     base->array_size = array_size;
683     base->degb = degb;
684     base->rev = (ctx->minfo->ord == ORD_DEGREVLEX);
685 
686     args = (_worker_arg_struct *) TMP_ALLOC(base->nthreads
687                                                   *sizeof(_worker_arg_struct));
688 
689 #if HAVE_PTHREAD
690     pthread_mutex_init(&base->mutex, NULL);
691 #endif
692     for (i = 0; i < num_handles; i++)
693     {
694         args[i].idx = i;
695         args[i].base = base;
696 
697         thread_pool_wake(global_thread_pool, handles[i], 0,
698                           _nmod_mpoly_mul_array_threaded_worker_DEG, &args[i]);
699     }
700     i = num_handles;
701     args[i].idx = i;
702     args[i].base = base;
703     _nmod_mpoly_mul_array_threaded_worker_DEG(&args[i]);
704     for (i = 0; i < num_handles; i++)
705     {
706         thread_pool_wait(global_thread_pool, handles[i]);
707     }
708 #if HAVE_PTHREAD
709     pthread_mutex_destroy(&base->mutex);
710 #endif
711 
712     /* join answers */
713     Plen = 0;
714     for (Pi = 0; Pi < Pl; Pi++)
715     {
716         _nmod_mpoly_fit_length(&P->coeffs, &P->exps, &P->alloc,
717                                                 Plen + (Pchunks + Pi)->len, 1);
718 
719         FLINT_ASSERT((Pchunks + Pi)->poly->coeffs != NULL);
720         FLINT_ASSERT((Pchunks + Pi)->poly->exps != NULL);
721 
722         memcpy(P->exps + Plen, (Pchunks + Pi)->poly->exps, (Pchunks + Pi)->len*sizeof(ulong));
723         memcpy(P->coeffs + Plen, (Pchunks + Pi)->poly->coeffs, (Pchunks + Pi)->len*sizeof(mp_limb_t));
724 
725         Plen += (Pchunks + Pi)->len;
726 
727         flint_free((Pchunks + Pi)->poly->coeffs);
728         flint_free((Pchunks + Pi)->poly->exps);
729     }
730 
731     P->length = Plen;
732 
733     flint_free(Apexp);
734     flint_free(Bpexp);
735     TMP_END;
736 }
737 
_nmod_mpoly_mul_array_threaded_pool_DEG(nmod_mpoly_t A,const nmod_mpoly_t B,fmpz * maxBfields,const nmod_mpoly_t C,fmpz * maxCfields,const nmod_mpoly_ctx_t ctx,const thread_pool_handle * handles,slong num_handles)738 int _nmod_mpoly_mul_array_threaded_pool_DEG(
739     nmod_mpoly_t A,
740     const nmod_mpoly_t B, fmpz * maxBfields,
741     const nmod_mpoly_t C, fmpz * maxCfields,
742     const nmod_mpoly_ctx_t ctx,
743     const thread_pool_handle * handles,
744     slong num_handles)
745 {
746     slong i, exp_bits, array_size;
747     ulong deg;
748     int success;
749 
750     FLINT_ASSERT(B->length != 0);
751     FLINT_ASSERT(C->length != 0);
752 
753     FLINT_ASSERT(  ctx->minfo->ord == ORD_DEGREVLEX
754                 || ctx->minfo->ord == ORD_DEGLEX);
755 
756     FLINT_ASSERT(1 == mpoly_words_per_exp(B->bits, ctx->minfo));
757     FLINT_ASSERT(1 == mpoly_words_per_exp(C->bits, ctx->minfo));
758 
759     /* the field of index n-1 is the one that wil be pulled out */
760     i = ctx->minfo->nfields - 1;
761     FLINT_ASSERT(fmpz_fits_si(maxBfields + i));
762     FLINT_ASSERT(fmpz_fits_si(maxCfields + i));
763     deg = 1 + fmpz_get_ui(maxBfields + i) + fmpz_get_ui(maxCfields + i);
764     if (((slong) deg) <= 0 || deg > MAX_ARRAY_SIZE)
765     {
766         success = 0;
767         goto cleanup;
768     }
769 
770     /* the fields of index n-2...1, contribute to the array size */
771     array_size = WORD(1);
772     for (i--; i >= 1; i--)
773     {
774         ulong hi;
775         umul_ppmm(hi, array_size, array_size, deg);
776         if (hi != WORD(0) || array_size <= 0
777                           || array_size > MAX_ARRAY_SIZE)
778         {
779             success = 0;
780             goto cleanup;
781         }
782     }
783 
784     exp_bits = FLINT_MAX(MPOLY_MIN_BITS, FLINT_BIT_COUNT(deg) + 1);
785     exp_bits = mpoly_fix_bits(exp_bits, ctx->minfo);
786 
787     /* array multiplication assumes result fit into 1 word */
788     if (1 != mpoly_words_per_exp(exp_bits, ctx->minfo))
789     {
790         success = 0;
791         goto cleanup;
792     }
793 
794     /* handle aliasing and do array multiplication */
795     if (A == B || A == C)
796     {
797         nmod_mpoly_t T;
798         nmod_mpoly_init3(T, B->length + C->length - 1, exp_bits, ctx);
799         _nmod_mpoly_mul_array_chunked_threaded_DEG(T, C, B, deg, ctx,
800                                                          handles, num_handles);
801         nmod_mpoly_swap(T, A, ctx);
802         nmod_mpoly_clear(T, ctx);
803     }
804     else
805     {
806         nmod_mpoly_fit_length(A, B->length + C->length - 1, ctx);
807         nmod_mpoly_fit_bits(A, exp_bits, ctx);
808         A->bits = exp_bits;
809         _nmod_mpoly_mul_array_chunked_threaded_DEG(A, C, B, deg, ctx,
810                                                          handles, num_handles);
811     }
812     success = 1;
813 
814 cleanup:
815 
816     return success;
817 }
818 
819 
nmod_mpoly_mul_array_threaded(nmod_mpoly_t A,const nmod_mpoly_t B,const nmod_mpoly_t C,const nmod_mpoly_ctx_t ctx)820 int nmod_mpoly_mul_array_threaded(
821     nmod_mpoly_t A,
822     const nmod_mpoly_t B,
823     const nmod_mpoly_t C,
824     const nmod_mpoly_ctx_t ctx)
825 {
826     slong i;
827     int success;
828     fmpz * maxBfields, * maxCfields;
829     thread_pool_handle * handles;
830     slong num_handles;
831     slong thread_limit = FLINT_MIN(B->length, C->length)/16;
832     TMP_INIT;
833 
834     if (B->length == 0 || C->length == 0)
835     {
836         nmod_mpoly_zero(A, ctx);
837         return 1;
838     }
839 
840     if (1 != mpoly_words_per_exp(B->bits, ctx->minfo) ||
841         1 != mpoly_words_per_exp(C->bits, ctx->minfo))
842     {
843         return 0;
844     }
845 
846     TMP_START;
847 
848     maxBfields = (fmpz *) TMP_ALLOC(ctx->minfo->nfields*sizeof(fmpz));
849     maxCfields = (fmpz *) TMP_ALLOC(ctx->minfo->nfields*sizeof(fmpz));
850     for (i = 0; i < ctx->minfo->nfields; i++)
851     {
852         fmpz_init(maxBfields + i);
853         fmpz_init(maxCfields + i);
854     }
855     mpoly_max_fields_fmpz(maxBfields, B->exps, B->length, B->bits, ctx->minfo);
856     mpoly_max_fields_fmpz(maxCfields, C->exps, C->length, C->bits, ctx->minfo);
857 
858     num_handles = flint_request_threads(&handles, thread_limit);
859 
860     switch (ctx->minfo->ord)
861     {
862         case ORD_LEX:
863         {
864             success = _nmod_mpoly_mul_array_threaded_pool_LEX(A,
865                       B, maxBfields, C, maxCfields, ctx, handles, num_handles);
866             break;
867         }
868         case ORD_DEGREVLEX:
869         case ORD_DEGLEX:
870         {
871             success = _nmod_mpoly_mul_array_threaded_pool_DEG(A,
872                       B, maxBfields, C, maxCfields, ctx, handles, num_handles);
873             break;
874         }
875         default:
876         {
877             success = 0;
878             break;
879         }
880     }
881 
882     flint_give_back_threads(handles, num_handles);
883 
884     for (i = 0; i < ctx->minfo->nfields; i++)
885     {
886         fmpz_clear(maxBfields + i);
887         fmpz_clear(maxCfields + i);
888     }
889 
890     TMP_END;
891     return success;
892 }
893