1 /*
2     Copyright (C) 2020 Daniel Schultz
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
10 */
11 
12 #include "fmpz_mod_mpoly.h"
13 
14 /* try to prove that A is not a square */
_is_proved_not_square(int count,flint_rand_t state,const fmpz * Acoeffs,const ulong * Aexps,slong Alen,flint_bitcnt_t Abits,const mpoly_ctx_t mctx,const fmpz_mod_ctx_t fctx)15 static int _is_proved_not_square(
16     int count,
17     flint_rand_t state,
18     const fmpz * Acoeffs,
19     const ulong * Aexps,
20     slong Alen,
21     flint_bitcnt_t Abits,
22     const mpoly_ctx_t mctx,
23     const fmpz_mod_ctx_t fctx)
24 {
25     int tries_left, success = 0;
26     slong i, N = mpoly_words_per_exp(Abits, mctx);
27     fmpz eval[1], * alphas;
28     ulong * t;
29     TMP_INIT;
30 
31     FLINT_ASSERT(Alen > 0);
32 
33     TMP_START;
34     t = (ulong *) TMP_ALLOC(N*sizeof(ulong));
35 
36     if (count == 1)
37     {
38         success = mpoly_is_proved_not_square(Aexps, Alen, Abits, N, t);
39         if (success)
40             goto cleanup;
41     }
42 
43     tries_left = 3*count;
44 
45     fmpz_init(eval);
46     alphas = (fmpz *) TMP_ALLOC(mctx->nvars*sizeof(fmpz));
47     for (i = 0; i < mctx->nvars; i++)
48         fmpz_init(alphas + i);
49 
50 next_p:
51 
52     for (i = 0; i < mctx->nvars; i++)
53         fmpz_randm(alphas + i, state, fmpz_mod_ctx_modulus(fctx));
54 
55     _fmpz_mod_mpoly_eval_all_fmpz_mod(eval, Acoeffs, Aexps, Alen, Abits,
56                                                            alphas, mctx, fctx);
57 
58     success = fmpz_jacobi(eval, fmpz_mod_ctx_modulus(fctx)) < 0;
59 
60     if (!success && --tries_left >= 0)
61         goto next_p;
62 
63     fmpz_clear(eval);
64     for (i = 0; i < mctx->nvars; i++)
65         fmpz_clear(alphas + i);
66 
67 cleanup:
68 
69     TMP_END;
70 
71     return success;
72 }
73 
74 
/*
    Heap-based square root for exponent vectors packed into a single word
    (mpoly_words_per_exp(bits, mctx) == 1 is asserted).

    (Acoeffs, Aexps, Alen) are the terms of A.  The terms of the candidate
    root are written into the coefficient and exponent arrays of Q, which are
    grown as needed; Q->coeffs, Q->exps and Q->length are updated on exit and
    no other field of Q is written here.

    Returns 1 when a root was produced (Q->length > 0) and 0 otherwise, in
    which case Q->length is set to 0.

    NOTE(review): fmpz_sqrtmod and the inversion of twice the leading
    coefficient presumably require the modulus of fctx to be an odd prime;
    confirm against the callers.
*/
static int _fmpz_mod_mpoly_sqrt_heap1(
    fmpz_mod_mpoly_t Q,
    const fmpz * Acoeffs,
    const ulong * Aexps,
    slong Alen,
    flint_bitcnt_t bits,
    const mpoly_ctx_t mctx,
    const fmpz_mod_ctx_t fctx)
{
    slong i, j, Qlen, Ai;
    slong next_loc, heap_len = 1, heap_alloc;
    mpoly_heap1_s * heap;
    mpoly_heap_t * chain_nodes[64];
    mpoly_heap_t ** chain;
    slong exp_alloc;
    slong * store, * store_base;
    mpoly_heap_t * x;
    fmpz * Qcoeffs = Q->coeffs;
    ulong * Qexps = Q->exps;
    ulong mask, exp, exp3 = 0;
    ulong cmpmask;
    mpz_t acc, acc2, modulus;
    fmpz zero = 0;
    const fmpz * s;
    fmpz_t lc_inv;
    int lt_divides;
    flint_rand_t heuristic_state;
    int heuristic_count = 0;

    fmpz_init(lc_inv);
    mpz_init(modulus);
    mpz_init(acc);
    mpz_init(acc2);
    fmpz_get_mpz(modulus, fmpz_mod_ctx_modulus(fctx));

    FLINT_ASSERT(mpoly_words_per_exp(bits, mctx) == 1);
    mpoly_get_cmpmask(&cmpmask, 1, bits, mctx);

    flint_randinit(heuristic_state);

    /* alloc array of heap nodes which can be chained together */
    next_loc = 2*n_sqrt(Alen) + 4;   /* something bigger than heap can ever be */
    heap_alloc = next_loc - 3;
    heap = (mpoly_heap1_s *) flint_malloc((heap_alloc + 1)*sizeof(mpoly_heap1_s));
    chain_nodes[0] = (mpoly_heap_t *) flint_malloc(heap_alloc*sizeof(mpoly_heap_t));
    chain = (mpoly_heap_t **) flint_malloc(heap_alloc*sizeof(mpoly_heap_t*));
    /* store holds an (i, j) pair of slongs for every node popped in one pass */
    store = store_base = (slong *) flint_malloc(2*heap_alloc*sizeof(slong));

    for (i = 0; i < heap_alloc; i++)
       chain[i] = chain_nodes[0] + i;

    /* number of chain_nodes[] slabs allocated so far */
    exp_alloc = 1;

    mask = mpoly_overflow_mask_sp(bits);

    /* term 0 of A is consumed by the first-term code below; resume at 1 */
    Ai = 1;

    /* compute first term */
    Qlen = 0;
    _fmpz_mod_mpoly_fit_length(&Qcoeffs, &Q->coeffs_alloc,
                               &Qexps, &Q->exps_alloc, 1, Qlen + 1);

    if (!fmpz_sqrtmod(Qcoeffs + 0, Acoeffs + 0, fmpz_mod_ctx_modulus(fctx)))
        goto not_sqrt;

    Qlen = 1;

    /* precompute leading coefficient info: lc_inv = 1/(2*Qcoeffs[0]) */
    fmpz_mod_add(lc_inv, Qcoeffs + 0, Qcoeffs + 0, fctx);
    fmpz_mod_inv(lc_inv, lc_inv, fctx);

    if (!mpoly_monomial_halves1(Qexps + 0, Aexps[0], mask))
        goto not_sqrt; /* exponent is not square */

    /* optimisation, reject early using the last term of A */
    {
        if (fmpz_jacobi(Acoeffs + Alen - 1, fmpz_mod_ctx_modulus(fctx)) < 0)
            goto not_sqrt;

        if (!mpoly_monomial_halves1(&exp3, Aexps[Alen - 1], mask))
            goto not_sqrt; /* exponent is not square */

        exp3 += Qexps[0]; /* overflow not possible */
    }

    while (heap_len > 1 || Ai < Alen)
    {
        _fmpz_mod_mpoly_fit_length(&Qcoeffs, &Q->coeffs_alloc,
                                   &Qexps, &Q->exps_alloc, 1, Qlen + 1);

        if (heap_len > 1 && Ai < Alen && Aexps[Ai] == heap[1].exp)
        {
            /* take from both A and heap */
            exp = Aexps[Ai];
            s = Acoeffs + Ai;
            Ai++;
        }
        else if (heap_len > 1 && (Ai >= Alen ||
                          mpoly_monomial_gt1(heap[1].exp, Aexps[Ai], cmpmask)))
        {
            /* take only from heap */
            exp = heap[1].exp;
            s = &zero;
            if (mpoly_monomial_overflows1(exp, mask))
                goto not_sqrt;
        }
        else
        {
            FLINT_ASSERT(Ai < Alen);

            /* take only from A */
            exp = Aexps[Ai];
            s = Acoeffs + Ai;
            Ai++;

            goto skip_heap;
        }

        /*
            total is always acc + 2*acc2: products Q[i]*Q[j] with i != j are
            collected in acc2 and doubled at the end, those with i == j go
            straight to acc
        */
        mpz_set_ui(acc, 0);
        mpz_set_ui(acc2, 0);
        do {
            x = _mpoly_heap_pop1(heap, &heap_len, cmpmask);
            do {
                mpz_ptr t;
                fmpz Qi, Qj;

                /* remember the pair so its successor can be inserted later */
                *store++ = x->i;
                *store++ = x->j;

                Qi = Qcoeffs[x->i];
                Qj = Qcoeffs[x->j];
                t = (x->i != x->j) ? acc2 : acc;

                if (COEFF_IS_MPZ(Qi) && COEFF_IS_MPZ(Qj))
                {
                    mpz_addmul(t, COEFF_TO_PTR(Qi), COEFF_TO_PTR(Qj));
                }
                else if (COEFF_IS_MPZ(Qi) && !COEFF_IS_MPZ(Qj))
                {
                    flint_mpz_addmul_ui(t, COEFF_TO_PTR(Qi), Qj);
                }
                else if (!COEFF_IS_MPZ(Qi) && COEFF_IS_MPZ(Qj))
                {
                    flint_mpz_addmul_ui(t, COEFF_TO_PTR(Qj), Qi);
                }
                else
                {
                    ulong pp1, pp0;
                    umul_ppmm(pp1, pp0, Qcoeffs[x->i], Qcoeffs[x->j]);
                    flint_mpz_add_uiui(t, t, pp1, pp0);
                }
            } while ((x = x->next) != NULL);
        } while (heap_len > 1 && heap[1].exp == exp);

        /* reduce the total; the remainder lands in Qcoeffs[Qlen] */
        mpz_addmul_ui(acc, acc2, 2);
        mpz_tdiv_qr(acc2, _fmpz_promote(Qcoeffs + Qlen), acc, modulus);
        _fmpz_demote_val(Qcoeffs + Qlen);

        /* Qcoeffs[Qlen] = (coefficient taken from A, or 0) - total */
        fmpz_mod_sub(Qcoeffs + Qlen, s, Qcoeffs + Qlen, fctx);
        s = Qcoeffs + Qlen;

        /* process nodes taken from the heap */
        while (store > store_base)
        {
            j = *--store;
            i = *--store;

            /* should we go right */
            if (j < i)
            {
                x = chain[i];
                x->i = i;
                x->j = j + 1;
                x->next = NULL;

                _mpoly_heap_insert1(heap, Qexps[x->i] + Qexps[x->j], x,
                                                &next_loc, &heap_len, cmpmask);
            }
        }

    skip_heap:

        /* candidate coefficient of the next term of Q */
        fmpz_mod_mul(Qcoeffs + Qlen, s, lc_inv, fctx);
        if (fmpz_is_zero(Qcoeffs + Qlen))
            continue;

        lt_divides = mpoly_monomial_divides1(Qexps + Qlen, exp, Qexps[0], mask);
        if (!lt_divides)
            goto not_sqrt;

        if (Qlen >= heap_alloc)
        {
            /* run some tests if the square root is getting long */
            if (Qlen > Alen && _is_proved_not_square(
                                        ++heuristic_count, heuristic_state,
                                       Acoeffs, Aexps, Alen, bits, mctx, fctx))
            {
                goto not_sqrt;
            }

            /* double the heap and add a new slab of chain nodes */
            heap_alloc *= 2;
            heap = (mpoly_heap1_s *) flint_realloc(heap, (heap_alloc + 1)*sizeof(mpoly_heap1_s));
            chain_nodes[exp_alloc] = (mpoly_heap_t *) flint_malloc((heap_alloc/2)*sizeof(mpoly_heap_t));
            chain = (mpoly_heap_t **) flint_realloc(chain, heap_alloc*sizeof(mpoly_heap_t*));
            store = store_base = (slong *) flint_realloc(store_base, 2*heap_alloc*sizeof(slong));
            for (i = 0; i < heap_alloc/2; i++)
                chain[i + heap_alloc/2] = chain_nodes[exp_alloc] + i;
            exp_alloc++;
        }

        /* put (Qlen, 1) in heap */
        i = Qlen;
        x = chain[i];
        x->i = i;
        x->j = 1;
        x->next = NULL;

        _mpoly_heap_insert1(heap, Qexps[i] + Qexps[1], x,
                                                &next_loc, &heap_len, cmpmask);

        Qlen++;
    }

cleanup:

    flint_randclear(heuristic_state);

    /* the arrays may have been reallocated, so hand them back to Q */
    Q->coeffs = Qcoeffs;
    Q->exps = Qexps;
    Q->length = Qlen;

    fmpz_clear(lc_inv);
    mpz_clear(modulus);
    mpz_clear(acc);
    mpz_clear(acc2);

    flint_free(heap);
    flint_free(chain);
    flint_free(store_base);
    for (i = 0; i < exp_alloc; i++)
        flint_free(chain_nodes[i]);

    return Qlen > 0;

not_sqrt:

    Qlen = 0;
    goto cleanup;
}
326 
/*
    Heap-based square root for exponent vectors of N words each, where
    N = mpoly_words_per_exp(bits, mctx).  The case N == 1 is handed to
    _fmpz_mod_mpoly_sqrt_heap1.

    (Acoeffs, Aexps, Alen) are the terms of A.  The terms of the candidate
    root are written into the coefficient and exponent arrays of Q, which are
    grown as needed; Q->coeffs, Q->exps and Q->length are updated on exit and
    no other field of Q is written here.

    Returns 1 when a root was produced (Q->length > 0) and 0 otherwise, in
    which case Q->length is set to 0.

    NOTE(review): fmpz_sqrtmod and the inversion of twice the leading
    coefficient presumably require the modulus of fctx to be an odd prime;
    confirm against the callers.
*/
static int _fmpz_mod_mpoly_sqrt_heap(
    fmpz_mod_mpoly_t Q,
    const fmpz * Acoeffs,
    const ulong * Aexps,
    slong Alen,
    flint_bitcnt_t bits,
    const mpoly_ctx_t mctx,
    const fmpz_mod_ctx_t fctx)
{
    slong N = mpoly_words_per_exp(bits, mctx);
    ulong * cmpmask;
    slong i, j, Qlen, Ai;
    slong next_loc;
    slong heap_len = 1, heap_alloc;
    int exp_alloc;
    mpoly_heap_s * heap;
    mpoly_heap_t * chain_nodes[64];
    mpoly_heap_t ** chain;
    slong * store, * store_base;
    mpoly_heap_t * x;
    fmpz * Qcoeffs = Q->coeffs;
    ulong * Qexps = Q->exps;
    ulong * exp, * exp3;
    ulong * exps[64];
    ulong ** exp_list;
    slong exp_next;
    ulong mask;
    mpz_t acc, acc2, modulus;
    fmpz zero = 0;
    const fmpz * s;
    fmpz_t lc_inv;
    int halves, lt_divides;
    flint_rand_t heuristic_state;
    int heuristic_count = 0;
    TMP_INIT;

    if (N == 1)
        return _fmpz_mod_mpoly_sqrt_heap1(Q, Acoeffs, Aexps, Alen, bits,
                                                                   mctx, fctx);

    fmpz_init(lc_inv);
    mpz_init(modulus);
    mpz_init(acc);
    mpz_init(acc2);
    fmpz_get_mpz(modulus, fmpz_mod_ctx_modulus(fctx));

    TMP_START;

    cmpmask = (ulong *) TMP_ALLOC(N*sizeof(ulong));
    mpoly_get_cmpmask(cmpmask, N, bits, mctx);

    flint_randinit(heuristic_state);

    /* alloc array of heap nodes which can be chained together */
    next_loc = 2*n_sqrt(Alen) + 4;   /* something bigger than heap can ever be */
    heap_alloc = next_loc - 3;
    heap = (mpoly_heap_s *) flint_malloc((heap_alloc + 1)*sizeof(mpoly_heap_s));
    chain_nodes[0] = (mpoly_heap_t *) flint_malloc(heap_alloc*sizeof(mpoly_heap_t));
    chain = (mpoly_heap_t **) flint_malloc(heap_alloc*sizeof(mpoly_heap_t*));
    /* store holds an (i, j) pair of slongs for every node popped in one pass */
    store = store_base = (slong *) flint_malloc(2*heap_alloc*sizeof(slong));

    for (i = 0; i < heap_alloc; i++)
       chain[i] = chain_nodes[0] + i;

    /* array of exponent vectors, each of "N" words */
    exps[0] = (ulong *) flint_malloc(heap_alloc*N*sizeof(ulong));
    /* number of exps[] / chain_nodes[] slabs allocated so far */
    exp_alloc = 1;
    /* list of pointers to available exponent vectors */
    exp_list = (ulong **) flint_malloc(heap_alloc*sizeof(ulong *));
    /* space to save copy of current exponent vector */
    exp = (ulong *) TMP_ALLOC(N*sizeof(ulong));
    /* final exponent */
    exp3 = (ulong *) TMP_ALLOC(N*sizeof(ulong));
    /* set up list of available exponent vectors */
    exp_next = 0;
    for (i = 0; i < heap_alloc; i++)
        exp_list[i] = exps[0] + i*N;

    /* the single-word overflow mask is only meaningful for packed exponents */
    mask = (bits <= FLINT_BITS) ? mpoly_overflow_mask_sp(bits) : 0;

    /* term 0 of A is consumed by the first-term code below; resume at 1 */
    Ai = 1;

    /* compute first term; each exponent vector occupies N words of Qexps */
    Qlen = 0;
    _fmpz_mod_mpoly_fit_length(&Qcoeffs, &Q->coeffs_alloc,
                               &Qexps, &Q->exps_alloc, N, Qlen + 1);

    FLINT_ASSERT(Alen > 0);
    FLINT_ASSERT(!fmpz_is_zero(Acoeffs + 0));
    FLINT_ASSERT(fmpz_mod_is_canonical(Acoeffs + 0, fctx));

    if (!fmpz_sqrtmod(Qcoeffs + 0, Acoeffs + 0, fmpz_mod_ctx_modulus(fctx)))
        goto not_sqrt;

    Qlen = 1;

    /* precompute leading coefficient info: lc_inv = 1/(2*Qcoeffs[0]) */
    fmpz_mod_add(lc_inv, Qcoeffs + 0, Qcoeffs + 0, fctx);
    fmpz_mod_inv(lc_inv, lc_inv, fctx);

    if (bits <= FLINT_BITS)
        halves = mpoly_monomial_halves(Qexps + 0, Aexps + 0, N, mask);
    else
        halves = mpoly_monomial_halves_mp(Qexps + 0, Aexps + 0, N, bits);

    if (!halves)
        goto not_sqrt; /* exponent is not square */

    /* optimisation, reject early using the last term of A */
    {
        if (fmpz_jacobi(Acoeffs + Alen - 1, fmpz_mod_ctx_modulus(fctx)) < 0)
            goto not_sqrt;

        if (bits <= FLINT_BITS)
            halves = mpoly_monomial_halves(exp3, Aexps + (Alen - 1)*N, N, mask);
        else
            halves = mpoly_monomial_halves_mp(exp3, Aexps + (Alen - 1)*N, N, bits);

        if (!halves)
            goto not_sqrt; /* exponent is not square */

        if (bits <= FLINT_BITS)
            mpoly_monomial_add(exp3, exp3, Qexps + 0, N);
        else
            mpoly_monomial_add_mp(exp3, exp3, Qexps + 0, N);
    }

    while (heap_len > 1 || Ai < Alen)
    {
        _fmpz_mod_mpoly_fit_length(&Qcoeffs, &Q->coeffs_alloc,
                                   &Qexps, &Q->exps_alloc, N, Qlen + 1);

        if (heap_len > 1 && Ai < Alen &&
            mpoly_monomial_equal(Aexps + N*Ai, heap[1].exp, N))
        {
            /* take from both A and heap */
            mpoly_monomial_set(exp, Aexps + N*Ai, N);
            s = Acoeffs + Ai;
            Ai++;
        }
        else if (heap_len > 1 && (Ai >= Alen || mpoly_monomial_lt(
                                       Aexps + N*Ai, heap[1].exp, N, cmpmask)))
        {
            /* take only from heap */
            mpoly_monomial_set(exp, heap[1].exp, N);
            s = &zero;
            if (bits <= FLINT_BITS ? mpoly_monomial_overflows(exp, N, mask)
                                   : mpoly_monomial_overflows_mp(exp, N, bits))
                goto not_sqrt;
        }
        else
        {
            FLINT_ASSERT(Ai < Alen);

            /* take only from A */
            mpoly_monomial_set(exp, Aexps + N*Ai, N);
            s = Acoeffs + Ai;
            Ai++;

            goto skip_heap;
        }

        /*
            total is always acc + 2*acc2: products Q[i]*Q[j] with i != j are
            collected in acc2 and doubled at the end, those with i == j go
            straight to acc
        */
        mpz_set_ui(acc, 0);
        mpz_set_ui(acc2, 0);
        do {
            /* hand the exponent vector of the popped entry back to the pool */
            exp_list[--exp_next] = heap[1].exp;
            x = _mpoly_heap_pop(heap, &heap_len, N, cmpmask);
            do {
                mpz_ptr t;
                fmpz Qi, Qj;

                /* remember the pair so its successor can be inserted later */
                *store++ = x->i;
                *store++ = x->j;

                Qi = Qcoeffs[x->i];
                Qj = Qcoeffs[x->j];
                t = (x->i != x->j) ? acc2 : acc;

                if (COEFF_IS_MPZ(Qi) && COEFF_IS_MPZ(Qj))
                {
                    mpz_addmul(t, COEFF_TO_PTR(Qi), COEFF_TO_PTR(Qj));
                }
                else if (COEFF_IS_MPZ(Qi) && !COEFF_IS_MPZ(Qj))
                {
                    flint_mpz_addmul_ui(t, COEFF_TO_PTR(Qi), Qj);
                }
                else if (!COEFF_IS_MPZ(Qi) && COEFF_IS_MPZ(Qj))
                {
                    flint_mpz_addmul_ui(t, COEFF_TO_PTR(Qj), Qi);
                }
                else
                {
                    ulong pp1, pp0;
                    umul_ppmm(pp1, pp0, Qcoeffs[x->i], Qcoeffs[x->j]);
                    flint_mpz_add_uiui(t, t, pp1, pp0);
                }
            } while ((x = x->next) != NULL);
        } while (heap_len > 1 && mpoly_monomial_equal(heap[1].exp, exp, N));

        /* reduce the total; the remainder lands in Qcoeffs[Qlen] */
        mpz_addmul_ui(acc, acc2, 2);
        mpz_tdiv_qr(acc2, _fmpz_promote(Qcoeffs + Qlen), acc, modulus);
        _fmpz_demote_val(Qcoeffs + Qlen);

        /* Qcoeffs[Qlen] = (coefficient taken from A, or 0) - total */
        fmpz_mod_sub(Qcoeffs + Qlen, s, Qcoeffs + Qlen, fctx);
        s = Qcoeffs + Qlen;

        /* process nodes taken from the heap */
        while (store > store_base)
        {
            j = *--store;
            i = *--store;

            /* should we go right */
            if (j < i)
            {
                x = chain[i];
                x->i = i;
                x->j = j + 1;
                x->next = NULL;

                if (bits <= FLINT_BITS)
                    mpoly_monomial_add(exp_list[exp_next], Qexps + N*x->i,
                                                            Qexps + N*x->j, N);
                else
                    mpoly_monomial_add_mp(exp_list[exp_next], Qexps + N*x->i,
                                                            Qexps + N*x->j, N);

                /* the pool slot is consumed according to the insert result */
                exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
                                             &next_loc, &heap_len, N, cmpmask);
            }
        }

    skip_heap:

        /* candidate coefficient of the next term of Q */
        fmpz_mod_mul(Qcoeffs + Qlen, s, lc_inv, fctx);
        if (fmpz_is_zero(Qcoeffs + Qlen))
            continue;

        if (bits <= FLINT_BITS)
            lt_divides = mpoly_monomial_divides(Qexps + N*Qlen,
                                                exp, Qexps + N*0, N, mask);
        else
            lt_divides = mpoly_monomial_divides_mp(Qexps + N*Qlen,
                                                exp, Qexps + N*0, N, bits);
        if (!lt_divides)
            goto not_sqrt;

        if (Qlen >= heap_alloc)
        {
            /* run some tests if the square root is getting long */
            if (Qlen > Alen && _is_proved_not_square(
                                          ++heuristic_count, heuristic_state,
                                       Acoeffs, Aexps, Alen, bits, mctx, fctx))
            {
                goto not_sqrt;
            }

            /* double the heap and add new slabs of nodes and exponents */
            heap_alloc *= 2;
            heap = (mpoly_heap_s *) flint_realloc(heap, (heap_alloc + 1)*sizeof(mpoly_heap_s));
            chain_nodes[exp_alloc] = (mpoly_heap_t *) flint_malloc((heap_alloc/2)*sizeof(mpoly_heap_t));
            chain = (mpoly_heap_t **) flint_realloc(chain, heap_alloc*sizeof(mpoly_heap_t*));
            store = store_base = (slong *) flint_realloc(store_base, 2*heap_alloc*sizeof(slong));
            exps[exp_alloc] = (ulong *) flint_malloc((heap_alloc/2)*N*sizeof(ulong));
            exp_list = (ulong **) flint_realloc(exp_list, heap_alloc*sizeof(ulong *));
            for (i = 0; i < heap_alloc/2; i++)
            {
               chain[i + heap_alloc/2] = chain_nodes[exp_alloc] + i;
               exp_list[i + heap_alloc/2] = exps[exp_alloc] + i*N;
            }
            exp_alloc++;
        }

        /* put (Qlen, 1) in heap */
        i = Qlen;
        x = chain[i];
        x->i = i;
        x->j = 1;
        x->next = NULL;

        if (bits <= FLINT_BITS)
            mpoly_monomial_add(exp_list[exp_next], Qexps + x->i*N,
                                                      Qexps + x->j*N, N);
        else
            mpoly_monomial_add_mp(exp_list[exp_next], Qexps + x->i*N,
                                                         Qexps + x->j*N, N);

        exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
                                         &next_loc, &heap_len, N, cmpmask);

        Qlen++;
    }

cleanup:

    flint_randclear(heuristic_state);

    /* the arrays may have been reallocated, so hand them back to Q */
    Q->coeffs = Qcoeffs;
    Q->exps = Qexps;
    Q->length = Qlen;

    fmpz_clear(lc_inv);
    mpz_clear(modulus);
    mpz_clear(acc);
    mpz_clear(acc2);

    flint_free(heap);
    flint_free(chain);
    flint_free(store_base);
    flint_free(exp_list);
    for (i = 0; i < exp_alloc; i++)
    {
        flint_free(exps[i]);
        flint_free(chain_nodes[i]);
    }

    TMP_END;

    return Qlen > 0;

not_sqrt:
    Qlen = 0;
    goto cleanup;
}
652 
653 
/*
    Square root of A, written to Q.  Returns the success flag of whichever
    routine did the work; for A == 0, Q is set to zero and 1 is returned.

    When the modulus fits in a single ulong the computation is delegated to
    nmod_mpoly_sqrt_heap via a conversion to nmod_mpoly and back.  Otherwise
    _fmpz_mod_mpoly_sqrt_heap is used, going through a temporary when Q
    aliases A.
*/
int fmpz_mod_mpoly_sqrt_heap(
    fmpz_mod_mpoly_t Q,
    const fmpz_mod_mpoly_t A,
    const fmpz_mod_mpoly_ctx_t ctx)
{
    int is_square;
    slong root_len_guess;

    /* trivial case: zero input gives zero output */
    if (fmpz_mod_mpoly_is_zero(A, ctx))
    {
        fmpz_mod_mpoly_zero(Q, ctx);
        return 1;
    }

    /* word-sized modulus: convert, take the root over nmod, convert back */
    if (fmpz_abs_fits_ui(fmpz_mod_ctx_modulus(ctx->ffinfo)))
    {
        nmod_mpoly_ctx_t small_ctx;
        nmod_mpoly_t small_root, small_A;

        /* reuse the monomial info and build the matching nmod modulus */
        *small_ctx->minfo = *ctx->minfo;
        nmod_init(&small_ctx->mod,
                            fmpz_get_ui(fmpz_mod_ctx_modulus(ctx->ffinfo)));
        nmod_mpoly_init(small_root, small_ctx);
        nmod_mpoly_init(small_A, small_ctx);

        _fmpz_mod_mpoly_get_nmod_mpoly(small_A, small_ctx, A, ctx);
        is_square = nmod_mpoly_sqrt_heap(small_root, small_A, small_ctx);
        _fmpz_mod_mpoly_set_nmod_mpoly(Q, ctx, small_root, small_ctx);

        nmod_mpoly_clear(small_A, small_ctx);
        nmod_mpoly_clear(small_root, small_ctx);

        return is_square;
    }

    /* initial allocation for the root */
    root_len_guess = n_sqrt(A->length);

    if (Q != A)
    {
        /* no aliasing: the root can be built directly in Q */
        fmpz_mod_mpoly_fit_length_reset_bits(Q, root_len_guess, A->bits, ctx);
        is_square = _fmpz_mod_mpoly_sqrt_heap(Q, A->coeffs, A->exps,
                              A->length, A->bits, ctx->minfo, ctx->ffinfo);
    }
    else
    {
        /* Q aliases A: build the root in a temporary, then swap it in */
        fmpz_mod_mpoly_t T;

        fmpz_mod_mpoly_init3(T, root_len_guess, A->bits, ctx);
        is_square = _fmpz_mod_mpoly_sqrt_heap(T, A->coeffs, A->exps,
                              A->length, A->bits, ctx->minfo, ctx->ffinfo);
        fmpz_mod_mpoly_swap(Q, T, ctx);
        fmpz_mod_mpoly_clear(T, ctx);
    }

    return is_square;
}
708