1 /*
2     Copyright (C) 2021 Daniel Schultz
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
10 */
11 
12 #include "fmpz_mod_mpoly.h"
13 
14 /*
15     sort terms in [left, right) by exponent
16     assuming that bits in position >= pos are already sorted
17     and assuming exponent vectors fit into one word
18     and assuming that all bit positions that need to be sorted are in totalmask
19 */
_fmpz_mod_mpoly_radix_sort1(fmpz * Acoeffs,ulong * Aexps,slong left,slong right,flint_bitcnt_t pos,ulong cmpmask,ulong totalmask)20 void _fmpz_mod_mpoly_radix_sort1(
21     fmpz * Acoeffs,
22     ulong * Aexps,
23     slong left, slong right,
24     flint_bitcnt_t pos,
25     ulong cmpmask,
26     ulong totalmask)
27 {
28     ulong mask, cmp;
29     slong mid, cur;
30 
31     while (pos > 0)
32     {
33         pos--;
34 
35         FLINT_ASSERT(left <= right);
36         FLINT_ASSERT(pos < FLINT_BITS);
37 
38         mask = UWORD(1) << pos;
39         cmp = cmpmask & mask;
40 
41         /* insertion base case */
42         if (right - left < 10)
43         {
44             slong i, j;
45 
46             for (i = left + 1; i < right; i++)
47             {
48                 for (j = i; j > left && mpoly_monomial_gt1(Aexps[j],
49                                                    Aexps[j - 1], cmpmask); j--)
50                 {
51                     fmpz_swap(Acoeffs + j, Acoeffs + j - 1);
52                     ULONG_SWAP(Aexps[j], Aexps[j - 1]);
53                 }
54             }
55 
56             return;
57         }
58 
59         /* return if there is no information to sort on this bit */
60         if ((totalmask & mask) == 0)
61             continue;
62 
63         /* find first 'zero' */
64         mid = left;
65         while (mid < right && (Aexps[mid] & mask) != cmp)
66             mid++;
67 
68         /* make sure [left,mid)  doesn't match cmpmask in position pos 'one'
69                      [mid,right)    does match cmpmask in position pos 'zero' */
70         cur = mid;
71         while (++cur < right)
72         {
73             if ((Aexps[cur] & mask) != cmp)
74             {
75                 fmpz_swap(Acoeffs + cur, Acoeffs + mid);
76                 ULONG_SWAP(Aexps[cur], Aexps[mid]);
77                 mid++;
78             }
79         }
80 
81         if (mid - left < right - mid)
82         {
83             _fmpz_mod_mpoly_radix_sort1(Acoeffs, Aexps, left, mid,
84                                                       pos, cmpmask, totalmask);
85             left = mid;
86         }
87         else
88         {
89             _fmpz_mod_mpoly_radix_sort1(Acoeffs, Aexps, mid, right,
90                                                       pos, cmpmask, totalmask);
91             right = mid;
92         }
93     }
94 }
95 
96 
97 /*
98     sort terms in [left, right) by exponent
99     assuming that bits in position >= pos are already sorted
100 */
_fmpz_mod_mpoly_radix_sort(fmpz * Acoeffs,ulong * Aexps,slong left,slong right,flint_bitcnt_t pos,slong N,ulong * cmpmask)101 void _fmpz_mod_mpoly_radix_sort(
102     fmpz * Acoeffs,
103     ulong * Aexps,
104     slong left, slong right,
105     flint_bitcnt_t pos,
106     slong N,
107     ulong * cmpmask)
108 {
109     ulong off, bit, mask, cmp;
110     slong mid, check;
111 
112     while (pos > 0)
113     {
114         pos--;
115 
116         FLINT_ASSERT(left <= right);
117         FLINT_ASSERT(pos < N*FLINT_BITS);
118 
119         off = pos/FLINT_BITS;
120         bit = pos%FLINT_BITS;
121         mask = UWORD(1) << bit;
122         cmp = cmpmask[off] & mask;
123 
124         /* insertion base case */
125         if (right - left < 20)
126         {
127             slong i, j;
128 
129             for (i = left + 1; i < right; i++)
130             {
131                 for (j = i; j > left && mpoly_monomial_gt(Aexps + N*j,
132                                          Aexps + N*(j - 1), N, cmpmask); j--)
133                 {
134                     fmpz_swap(Acoeffs + j, Acoeffs + j - 1);
135                     mpoly_monomial_swap(Aexps + N*j, Aexps + N*(j - 1), N);
136                 }
137             }
138 
139             return;
140         }
141 
142         /* find first 'zero' */
143         mid = left;
144         while (mid < right && ((Aexps+N*mid)[off] & mask) != cmp)
145             mid++;
146 
147         /* make sure [left,mid)  doesn't match cmpmask in position pos 'one'
148                      [mid,right)    does match cmpmask in position pos 'zero' */
149         check = mid;
150         while (++check < right)
151         {
152             if (((Aexps + N*check)[off] & mask) != cmp)
153             {
154                 fmpz_swap(Acoeffs + check, Acoeffs + mid);
155                 mpoly_monomial_swap(Aexps + N*check, Aexps + N*mid, N);
156                 mid++;
157             }
158         }
159 
160         FLINT_ASSERT(left <= mid && mid <= right);
161 
162         if (mid - left < right - mid)
163         {
164             _fmpz_mod_mpoly_radix_sort(Acoeffs, Aexps, left, mid,
165                                                               pos, N, cmpmask);
166             left = mid;
167         }
168         else
169         {
170             _fmpz_mod_mpoly_radix_sort(Acoeffs, Aexps, mid, right,
171                                                               pos, N, cmpmask);
172             right = mid;
173         }
174     }
175 }
176 
177 
178 /*
179     sort the terms in A by exponent
180     assuming that the exponents are valid (other than being in order)
181 */
fmpz_mod_mpoly_sort_terms(fmpz_mod_mpoly_t A,const fmpz_mod_mpoly_ctx_t ctx)182 void fmpz_mod_mpoly_sort_terms(fmpz_mod_mpoly_t A, const fmpz_mod_mpoly_ctx_t ctx)
183 {
184     slong i, N;
185     flint_bitcnt_t pos;
186     fmpz * Acoeffs = A->coeffs;
187     ulong * Aexps = A->exps;
188     ulong himask, * ptempexp;
189     TMP_INIT;
190 
191     TMP_START;
192     N = mpoly_words_per_exp(A->bits, ctx->minfo);
193     ptempexp = (ulong *) TMP_ALLOC(N*sizeof(ulong));
194     mpoly_get_cmpmask(ptempexp, N, A->bits, ctx->minfo);
195 
196     himask = 0;
197     for (i = 0; i < A->length; i++)
198         himask |= (Aexps + N*i)[N - 1];
199 
200     pos = FLINT_BIT_COUNT(himask);
201     if (N == 1)
202         _fmpz_mod_mpoly_radix_sort1(Acoeffs, Aexps, 0, A->length,
203                                                      pos, ptempexp[0], himask);
204     else
205         _fmpz_mod_mpoly_radix_sort(Acoeffs, Aexps, 0, A->length,
206                                         (N - 1)*FLINT_BITS + pos, N, ptempexp);
207 
208     TMP_END;
209 }
210