1 /*
2 Copyright (C) 2011 Fredrik Johansson
3
4 This file is part of FLINT.
5
6 FLINT is free software: you can redistribute it and/or modify it under
7 the terms of the GNU Lesser General Public License (LGPL) as published
8 by the Free Software Foundation; either version 2.1 of the License, or
9 (at your option) any later version. See <http://www.gnu.org/licenses/>.
10 */
11
12 #include "arith.h"
13
/* Residue offsets (mod 4) used when a factor 4 is split off from k:
   the entry at index (k2 / 2) % 8 is added to n before reducing mod 4. */
static const int mod4_tab[8] = {
    2, 1, 3, 0,
    0, 3, 1, 2
};
15
/* gcd24_tab[i] = gcd(i, 24) for 0 <= i < 24 (with gcd(0, 24) = 24).
   Indexed with k % 24 to obtain gcd(k, 24) without computing a gcd. */
static const int gcd24_tab[24] = {
    24, 1, 2, 3, 4, 1, 6, 1,
     8, 3, 2, 1, 12, 1, 2, 3,
     8, 1, 6, 1, 4, 3, 2, 1
};
20
21 static mp_limb_t
n_sqrtmod_2exp(mp_limb_t a,int k)22 n_sqrtmod_2exp(mp_limb_t a, int k)
23 {
24 mp_limb_t x;
25 int i;
26
27 if (a == 0 || k == 0)
28 return 0;
29
30 if (k == 1)
31 return 1;
32
33 if (k == 2)
34 {
35 if (a == 1)
36 return 1;
37 return 0;
38 }
39
40 x = 1;
41 for (i = 3; i < k; i++)
42 x += (a - x * x) / 2;
43
44 if (k < FLINT_BITS)
45 x &= ((UWORD(1) << k) - 1);
46
47 return x;
48 }
49
50 static mp_limb_t
n_sqrtmod_ppow(mp_limb_t a,mp_limb_t p,int k,mp_limb_t pk,mp_limb_t pkinv)51 n_sqrtmod_ppow(mp_limb_t a, mp_limb_t p, int k, mp_limb_t pk, mp_limb_t pkinv)
52 {
53 mp_limb_t r, t;
54 int i;
55
56 /* n_sqrtmod assumes that a is reduced */
57 r = n_sqrtmod(a % p, p);
58 if (r == 0)
59 return r;
60
61 i = 1;
62 while (i < k)
63 {
64 t = n_mulmod2_preinv(r, r, pk, pkinv);
65 t = n_submod(t, a, pk);
66 t = n_mulmod2_preinv(t, n_invmod(n_addmod(r, r, pk), pk), pk, pkinv);
67 r = n_submod(r, t, pk);
68 i *= 2;
69 }
70
71 return r;
72 }
73
74 void
trigprod_mul_prime_power(trig_prod_t prod,mp_limb_t k,mp_limb_t n,mp_limb_t p,int exp)75 trigprod_mul_prime_power(trig_prod_t prod, mp_limb_t k, mp_limb_t n,
76 mp_limb_t p, int exp)
77 {
78 mp_limb_t m, mod, inv;
79
80 if (k <= 3)
81 {
82 if (k == 0)
83 {
84 prod->prefactor = 0;
85 }
86 else if (k == 2 && (n % 2 == 1))
87 {
88 prod->prefactor *= -1;
89 }
90 else if (k == 3)
91 {
92 switch (n % 3)
93 {
94 case 0:
95 prod->prefactor *= 2;
96 prod->cos_p[prod->n] = 1;
97 prod->cos_q[prod->n] = 18;
98 break;
99 case 1:
100 prod->prefactor *= -2;
101 prod->cos_p[prod->n] = 7;
102 prod->cos_q[prod->n] = 18;
103 break;
104 case 2:
105 prod->prefactor *= -2;
106 prod->cos_p[prod->n] = 5;
107 prod->cos_q[prod->n] = 18;
108 break;
109 }
110 prod->n++;
111 }
112 return;
113 }
114
115 /* Power of 2 */
116 if (p == 2)
117 {
118 mod = 8 * k;
119 inv = n_preinvert_limb(mod);
120
121 m = n_submod(1, n_mod2_preinv(24 * n, mod, inv), mod);
122 m = n_sqrtmod_2exp(m, exp + 3);
123 m = n_mulmod2_preinv(m, n_invmod(3, mod), mod, inv);
124
125 prod->prefactor *= n_jacobi(-1, m);
126 if (exp % 2 == 1)
127 prod->prefactor *= -1;
128 prod->sqrt_p *= k;
129 prod->cos_p[prod->n] = (mp_limb_signed_t)(k - m);
130 prod->cos_q[prod->n] = 2 * k;
131 prod->n++;
132 return;
133 }
134
135 /* Power of 3 */
136 if (p == 3)
137 {
138 mod = 3 * k;
139 inv = n_preinvert_limb(mod);
140
141 m = n_submod(1, n_mod2_preinv(24 * n, mod, inv), mod);
142 m = n_sqrtmod_ppow(m, p, exp + 1, mod, inv);
143 m = n_mulmod2_preinv(m, n_invmod(8, mod), mod, inv);
144
145 prod->prefactor *= (2 * n_jacobi_unsigned(m, 3));
146 if (exp % 2 == 0)
147 prod->prefactor *= -1;
148 prod->sqrt_p *= k;
149 prod->sqrt_q *= 3;
150 prod->cos_p[prod->n] = (mp_limb_signed_t)(3 * k - 8 * m);
151 prod->cos_q[prod->n] = 6 * k;
152 prod->n++;
153 return;
154 }
155
156 /* Power of prime greater than 3 */
157 inv = n_preinvert_limb(k);
158 m = n_submod(1, n_mod2_preinv(24 * n, k, inv), k);
159
160 if (m % p == 0)
161 {
162 if (exp == 1)
163 {
164 prod->prefactor *= n_jacobi(3, k);
165 prod->sqrt_p *= k;
166 }
167 else
168 prod->prefactor = 0;
169 return;
170 }
171
172 m = n_sqrtmod_ppow(m, p, exp, k, inv);
173
174 if (m == 0)
175 {
176 prod->prefactor = 0;
177 return;
178 }
179
180 prod->prefactor *= 2;
181 prod->prefactor *= n_jacobi(3, k);
182 prod->sqrt_p *= k;
183 prod->cos_p[prod->n] = 4 * n_mulmod2_preinv(m, n_invmod(24 >= k ? n_mod2_preinv(24, k, inv) : 24, k), k, inv);
184 prod->cos_q[prod->n] = k;
185 prod->n++;
186 }
187
188 /*
189 Solve (k2^2 * d2 * e) * n1 = (d2 * e * n + (k2^2 - 1) / d1) mod k2
190
191 TODO: test this on 32 bit
192 */
193 static mp_limb_t
solve_n1(mp_limb_t n,mp_limb_t k1,mp_limb_t k2,mp_limb_t d1,mp_limb_t d2,mp_limb_t e)194 solve_n1(mp_limb_t n, mp_limb_t k1, mp_limb_t k2,
195 mp_limb_t d1, mp_limb_t d2, mp_limb_t e)
196 {
197 mp_limb_t inv, n1, u, t[2];
198
199 inv = n_preinvert_limb(k1);
200
201 umul_ppmm(t[1], t[0], k2, k2);
202 sub_ddmmss(t[1], t[0], t[1], t[0], UWORD(0), UWORD(1));
203 mpn_divrem_1(t, 0, t, 2, d1);
204
205 n1 = n_ll_mod_preinv(t[1], t[0], k1, inv);
206 n1 = n_mod2_preinv(n1 + d2*e*n, k1, inv);
207
208 u = n_mulmod2_preinv(k2, k2, k1, inv);
209 u = n_invmod(n_mod2_preinv(u * d2 * e, k1, inv), k1);
210 n1 = n_mulmod2_preinv(n1, u, k1, inv);
211
212 return n1;
213 }
214
215
216 void
arith_hrr_expsum_factored(trig_prod_t prod,mp_limb_t k,mp_limb_t n)217 arith_hrr_expsum_factored(trig_prod_t prod, mp_limb_t k, mp_limb_t n)
218 {
219 n_factor_t fac;
220 int i;
221
222 if (k <= 1)
223 {
224 prod->prefactor = k;
225 return;
226 }
227
228 n_factor_init(&fac);
229 n_factor(&fac, k, 0);
230
231 /* Repeatedly factor A_k(n) into A_k1(n1)*A_k2(n2) with k1, k2 coprime */
232 for (i = 0; i + 1 < fac.num && prod->prefactor != 0; i++)
233 {
234 mp_limb_t p, k1, k2, inv, n1, n2;
235
236 p = fac.p[i];
237
238 /* k = 2 * k1 with k1 odd */
239 if (p == UWORD(2) && fac.exp[i] == 1)
240 {
241 k2 = k / 2;
242 inv = n_preinvert_limb(k2);
243
244 n2 = n_invmod(32 >= k2 ? n_mod2_preinv(32, k2, inv) : 32, k2);
245 n2 = n_mulmod2_preinv(n2,
246 n_mod2_preinv(8*n + 1, k2, inv), k2, inv);
247 n1 = ((k2 % 8 == 3) || (k2 % 8 == 5)) ^ (n & 1);
248
249 trigprod_mul_prime_power(prod, 2, n1, 2, 1);
250 k = k2;
251 n = n2;
252 }
253 /* k = 4 * k1 with k1 odd */
254 else if (p == UWORD(2) && fac.exp[i] == 2)
255 {
256 k2 = k / 4;
257 inv = n_preinvert_limb(k2);
258
259 n2 = n_invmod(128 >= k2 ? n_mod2_preinv(128, k2, inv) : 128, k2);
260 n2 = n_mulmod2_preinv(n2,
261 n_mod2_preinv(8*n + 5, k2, inv), k2, inv);
262 n1 = (n + mod4_tab[(k2 / 2) % 8]) % 4;
263
264 trigprod_mul_prime_power(prod, 4, n1, 2, 2);
265 prod->prefactor *= -1;
266 k = k2;
267 n = n2;
268 }
269 /* k = k1 * k2 with k1 odd or divisible by 8 */
270 else
271 {
272 mp_limb_t d1, d2, e;
273
274 k1 = n_pow(fac.p[i], fac.exp[i]);
275 k2 = k / k1;
276
277 d1 = gcd24_tab[k1 % 24];
278 d2 = gcd24_tab[k2 % 24];
279 e = 24 / (d1 * d2);
280
281 n1 = solve_n1(n, k1, k2, d1, d2, e);
282 n2 = solve_n1(n, k2, k1, d2, d1, e);
283
284 trigprod_mul_prime_power(prod, k1, n1, fac.p[i], fac.exp[i]);
285 k = k2;
286 n = n2;
287 }
288 }
289
290 if (fac.num != 0 && prod->prefactor != 0)
291 trigprod_mul_prime_power(prod, k, n,
292 fac.p[fac.num - 1], fac.exp[fac.num - 1]);
293
294 }
295