1 /*
2     Copyright (C) 2011 Fredrik Johansson
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <http://www.gnu.org/licenses/>.
10 */
11 
12 #include "arith.h"
13 
/*
    Offset table used when a factor 4 is split off the modulus: it is
    indexed by (k2 / 2) % 8 and the entry is added to n before reducing
    modulo 4.
*/
static const int mod4_tab[8] = {
    2, 1, 3, 0,
    0, 3, 1, 2
};

/* gcd24_tab[i] == gcd(i, 24) for 0 <= i < 24; indexed by k % 24. */
static const int gcd24_tab[24] = {
    24, 1, 2, 3,  4, 1, 6, 1,     /*  0 ..  7 */
     8, 3, 2, 1, 12, 1, 2, 3,     /*  8 .. 15 */
     8, 1, 6, 1,  4, 3, 2, 1      /* 16 .. 23 */
};
20 
21 static mp_limb_t
n_sqrtmod_2exp(mp_limb_t a,int k)22 n_sqrtmod_2exp(mp_limb_t a, int k)
23 {
24     mp_limb_t x;
25     int i;
26 
27     if (a == 0 || k == 0)
28         return 0;
29 
30     if (k == 1)
31         return 1;
32 
33     if (k == 2)
34     {
35         if (a == 1)
36             return 1;
37         return 0;
38     }
39 
40     x = 1;
41     for (i = 3; i < k; i++)
42         x += (a - x * x) / 2;
43 
44     if (k < FLINT_BITS)
45         x &= ((UWORD(1) << k) - 1);
46 
47     return x;
48 }
49 
50 static mp_limb_t
n_sqrtmod_ppow(mp_limb_t a,mp_limb_t p,int k,mp_limb_t pk,mp_limb_t pkinv)51 n_sqrtmod_ppow(mp_limb_t a, mp_limb_t p, int k, mp_limb_t pk, mp_limb_t pkinv)
52 {
53     mp_limb_t r, t;
54     int i;
55 
56     /* n_sqrtmod assumes that a is reduced */
57     r = n_sqrtmod(a % p, p);
58     if (r == 0)
59         return r;
60 
61     i = 1;
62     while (i < k)
63     {
64         t = n_mulmod2_preinv(r, r, pk, pkinv);
65         t = n_submod(t, a, pk);
66         t = n_mulmod2_preinv(t, n_invmod(n_addmod(r, r, pk), pk), pk, pkinv);
67         r = n_submod(r, t, pk);
68         i *= 2;
69     }
70 
71     return r;
72 }
73 
74 void
trigprod_mul_prime_power(trig_prod_t prod,mp_limb_t k,mp_limb_t n,mp_limb_t p,int exp)75 trigprod_mul_prime_power(trig_prod_t prod, mp_limb_t k, mp_limb_t n,
76                                                 mp_limb_t p, int exp)
77 {
78     mp_limb_t m, mod, inv;
79 
80     if (k <= 3)
81     {
82         if (k == 0)
83         {
84             prod->prefactor = 0;
85         }
86         else if (k == 2 && (n % 2 == 1))
87         {
88             prod->prefactor *= -1;
89         }
90         else if (k == 3)
91         {
92             switch (n % 3)
93             {
94                 case 0:
95                     prod->prefactor *= 2;
96                     prod->cos_p[prod->n] = 1;
97                     prod->cos_q[prod->n] = 18;
98                     break;
99                 case 1:
100                     prod->prefactor *= -2;
101                     prod->cos_p[prod->n] = 7;
102                     prod->cos_q[prod->n] = 18;
103                     break;
104                 case 2:
105                     prod->prefactor *= -2;
106                     prod->cos_p[prod->n] = 5;
107                     prod->cos_q[prod->n] = 18;
108                     break;
109             }
110             prod->n++;
111         }
112         return;
113     }
114 
115     /* Power of 2 */
116     if (p == 2)
117     {
118         mod = 8 * k;
119         inv = n_preinvert_limb(mod);
120 
121         m = n_submod(1, n_mod2_preinv(24 * n, mod, inv), mod);
122         m = n_sqrtmod_2exp(m, exp + 3);
123         m = n_mulmod2_preinv(m, n_invmod(3, mod), mod, inv);
124 
125         prod->prefactor *= n_jacobi(-1, m);
126         if (exp % 2 == 1)
127             prod->prefactor *= -1;
128         prod->sqrt_p *= k;
129         prod->cos_p[prod->n] = (mp_limb_signed_t)(k - m);
130         prod->cos_q[prod->n] = 2 * k;
131         prod->n++;
132         return;
133     }
134 
135     /* Power of 3 */
136     if (p == 3)
137     {
138         mod = 3 * k;
139         inv = n_preinvert_limb(mod);
140 
141         m = n_submod(1, n_mod2_preinv(24 * n, mod, inv), mod);
142         m = n_sqrtmod_ppow(m, p, exp + 1, mod, inv);
143         m = n_mulmod2_preinv(m, n_invmod(8, mod), mod, inv);
144 
145         prod->prefactor *= (2 * n_jacobi_unsigned(m, 3));
146         if (exp % 2 == 0)
147             prod->prefactor *= -1;
148         prod->sqrt_p *= k;
149         prod->sqrt_q *= 3;
150         prod->cos_p[prod->n] = (mp_limb_signed_t)(3 * k - 8 * m);
151         prod->cos_q[prod->n] = 6 * k;
152         prod->n++;
153         return;
154     }
155 
156     /* Power of prime greater than 3 */
157     inv = n_preinvert_limb(k);
158     m = n_submod(1, n_mod2_preinv(24 * n, k, inv), k);
159 
160     if (m % p == 0)
161     {
162         if (exp == 1)
163         {
164             prod->prefactor *= n_jacobi(3, k);
165             prod->sqrt_p *= k;
166         }
167         else
168             prod->prefactor = 0;
169         return;
170     }
171 
172     m = n_sqrtmod_ppow(m, p, exp, k, inv);
173 
174     if (m == 0)
175     {
176         prod->prefactor = 0;
177         return;
178     }
179 
180     prod->prefactor *= 2;
181     prod->prefactor *= n_jacobi(3, k);
182     prod->sqrt_p *= k;
183     prod->cos_p[prod->n] = 4 * n_mulmod2_preinv(m, n_invmod(24 >= k ? n_mod2_preinv(24, k, inv) : 24, k), k, inv);
184     prod->cos_q[prod->n] = k;
185     prod->n++;
186 }
187 
188 /*
189 Solve (k2^2 * d2 * e) * n1 = (d2 * e * n + (k2^2 - 1) / d1)   mod k2
190 
191 TODO: test this on 32 bit
192 */
193 static mp_limb_t
solve_n1(mp_limb_t n,mp_limb_t k1,mp_limb_t k2,mp_limb_t d1,mp_limb_t d2,mp_limb_t e)194 solve_n1(mp_limb_t n, mp_limb_t k1, mp_limb_t k2,
195         mp_limb_t d1, mp_limb_t d2, mp_limb_t e)
196 {
197     mp_limb_t inv, n1, u, t[2];
198 
199     inv = n_preinvert_limb(k1);
200 
201     umul_ppmm(t[1], t[0], k2, k2);
202     sub_ddmmss(t[1], t[0], t[1], t[0], UWORD(0), UWORD(1));
203     mpn_divrem_1(t, 0, t, 2, d1);
204 
205     n1 = n_ll_mod_preinv(t[1], t[0], k1, inv);
206     n1 = n_mod2_preinv(n1 + d2*e*n, k1, inv);
207 
208     u = n_mulmod2_preinv(k2, k2, k1, inv);
209     u = n_invmod(n_mod2_preinv(u * d2 * e, k1, inv), k1);
210     n1 = n_mulmod2_preinv(n1, u, k1, inv);
211 
212     return n1;
213 }
214 
215 
216 void
arith_hrr_expsum_factored(trig_prod_t prod,mp_limb_t k,mp_limb_t n)217 arith_hrr_expsum_factored(trig_prod_t prod, mp_limb_t k, mp_limb_t n)
218 {
219     n_factor_t fac;
220     int i;
221 
222     if (k <= 1)
223     {
224         prod->prefactor = k;
225         return;
226     }
227 
228     n_factor_init(&fac);
229     n_factor(&fac, k, 0);
230 
231     /* Repeatedly factor A_k(n) into A_k1(n1)*A_k2(n2) with k1, k2 coprime */
232     for (i = 0; i + 1 < fac.num && prod->prefactor != 0; i++)
233     {
234         mp_limb_t p, k1, k2, inv, n1, n2;
235 
236         p = fac.p[i];
237 
238         /* k = 2 * k1 with k1 odd */
239         if (p == UWORD(2) && fac.exp[i] == 1)
240         {
241             k2 = k / 2;
242             inv = n_preinvert_limb(k2);
243 
244             n2 = n_invmod(32 >= k2 ? n_mod2_preinv(32, k2, inv) : 32, k2);
245             n2 = n_mulmod2_preinv(n2,
246                     n_mod2_preinv(8*n + 1, k2, inv), k2, inv);
247             n1 = ((k2 % 8 == 3) || (k2 % 8 == 5)) ^ (n & 1);
248 
249             trigprod_mul_prime_power(prod, 2, n1, 2, 1);
250             k = k2;
251             n = n2;
252         }
253         /* k = 4 * k1 with k1 odd */
254         else if (p == UWORD(2) && fac.exp[i] == 2)
255         {
256             k2 = k / 4;
257             inv = n_preinvert_limb(k2);
258 
259             n2 = n_invmod(128 >= k2 ? n_mod2_preinv(128, k2, inv) : 128, k2);
260             n2 = n_mulmod2_preinv(n2,
261                 n_mod2_preinv(8*n + 5, k2, inv), k2, inv);
262             n1 = (n + mod4_tab[(k2 / 2) % 8]) % 4;
263 
264             trigprod_mul_prime_power(prod, 4, n1, 2, 2);
265             prod->prefactor *= -1;
266             k = k2;
267             n = n2;
268         }
269         /* k = k1 * k2 with k1 odd or divisible by 8 */
270         else
271         {
272             mp_limb_t d1, d2, e;
273 
274             k1 = n_pow(fac.p[i], fac.exp[i]);
275             k2 = k / k1;
276 
277             d1 = gcd24_tab[k1 % 24];
278             d2 = gcd24_tab[k2 % 24];
279             e = 24 / (d1 * d2);
280 
281             n1 = solve_n1(n, k1, k2, d1, d2, e);
282             n2 = solve_n1(n, k2, k1, d2, d1, e);
283 
284             trigprod_mul_prime_power(prod, k1, n1, fac.p[i], fac.exp[i]);
285             k = k2;
286             n = n2;
287         }
288     }
289 
290     if (fac.num != 0 && prod->prefactor != 0)
291         trigprod_mul_prime_power(prod, k, n,
292             fac.p[fac.num - 1], fac.exp[fac.num - 1]);
293 
294 }
295