1 /*
2 Copyright (C) 2007, 2008 David Harvey (zn_poly)
3 Copyright (C) 2013 William Hart
4
5 This file is part of FLINT.
6
7 FLINT is free software: you can redistribute it and/or modify it under
8 the terms of the GNU Lesser General Public License (LGPL) as published
9 by the Free Software Foundation; either version 2.1 of the License, or
10 (at your option) any later version. See <https://www.gnu.org/licenses/>.
11 */
12
13 #include <stdlib.h>
14 #include <gmp.h>
15 #include "flint.h"
16 #include "nmod_vec.h"
17 #include "nmod_poly.h"
18
19 /*
20 Multiplication/squaring using Kronecker substitution at 2^b, -2^b,
21 2^(-b) and -2^(-b).
22 */
void
_nmod_poly_mul_KS4(mp_ptr res, mp_srcptr op1, slong n1,
                   mp_srcptr op2, slong n2, nmod_t mod)
{
    /*
       Computes res = op1 * op2 (length n1 + n2 - 1) over Z/nZ, using
       Kronecker substitution at the four points B, -B, 1/B and -1/B
       where B = 2^b. The four evaluations let us separate even/odd
       coefficients in both normal and reciprocal order, quartering the
       packing width compared to plain KS1.

       Preconditions (as in the other KS variants in this module):
       n1 >= n2 >= 1; res must have room for n1 + n2 - 1 coefficients.
       If op1 == op2 and n1 == n2, a squaring-specific path is used.
    */
    int sqr, v3m_neg;
    ulong bits, b, w, a1, a2, a3;
    slong n1o, n1e, n2o, n2e, n3o, n3e, n3, k1, k2, k3;
    mp_ptr v1_buf0, v2_buf0, v1_buf1, v2_buf1, v1_buf2, v2_buf2, v1_buf3, v2_buf3, v1_buf4, v2_buf4;
    mp_ptr v1on, v1en, v1pn, v1mn, v2on, v2en, v2pn, v2mn, v3on, v3en, v3pn, v3mn;
    mp_ptr v1or, v1er, v1pr, v1mr, v2or, v2er, v2pr, v2mr, v3or, v3er, v3pr, v3mr;
    mp_ptr z, zn, zr;
    TMP_INIT;

    if (n2 == 1)
    {
        /* code below needs n2 > 1, so fall back on scalar multiplication */
        _nmod_vec_scalar_mul_nmod(res, op1, n1, op2[0], mod);
        return;
    }

    TMP_START;

    sqr = (op1 == op2 && n1 == n2);

    /* bits in each output coefficient */
    bits = 2 * (FLINT_BITS - mod.norm) + FLINT_CLOG2(n2);

    /*
       we're evaluating at x = B, -B, 1/B, -1/B,
       where B = 2^b, and b = ceil(bits / 4)
    */
    b = (bits + 3) / 4;

    /* number of ulongs required to store each base-B^2 digit */
    w = (2*b - 1)/FLINT_BITS + 1;

    /*
       Write f1(x) = f1e(x^2) + x * f1o(x^2)
             f2(x) = f2e(x^2) + x * f2o(x^2)
              h(x) =  he(x^2) + x *  ho(x^2)
       "e" = even, "o" = odd
    */

    n1o = n1 / 2;
    n1e = n1 - n1o;

    n2o = n2 / 2;
    n2e = n2 - n2o;

    n3 = n1 + n2 - 1;    /* length of h */
    n3o = n3 / 2;
    n3e = n3 - n3o;

    /*
       Put k1 = number of limbs needed to store f1(B) and |f1(-B)|.
       In f1(B), the leading coefficient starts at bit position b * (n1 - 1)
       and has length 2b, and the coefficients overlap so we need an extra bit
       for the carry: this gives (n1 + 1) * b + 1 bits. Ditto for f2.
    */
    k1 = ((n1 + 1) * b)/FLINT_BITS + 1;
    k2 = ((n2 + 1) * b)/FLINT_BITS + 1;
    k3 = k1 + k2;

    /* allocate space */
    v1_buf0 = TMP_ALLOC(sizeof(mp_limb_t) * 5 * k3);  /* k1 limbs */
    v2_buf0 = v1_buf0 + k1;  /* k2 limbs */
    v1_buf1 = v2_buf0 + k2;  /* k1 limbs */
    v2_buf1 = v1_buf1 + k1;  /* k2 limbs */
    v1_buf2 = v2_buf1 + k2;  /* k1 limbs */
    v2_buf2 = v1_buf2 + k1;  /* k2 limbs */
    v1_buf3 = v2_buf2 + k2;  /* k1 limbs */
    v2_buf3 = v1_buf3 + k1;  /* k2 limbs */
    v1_buf4 = v2_buf3 + k2;  /* k1 limbs */
    v2_buf4 = v1_buf4 + k1;  /* k2 limbs */

    /*
       arrange overlapping buffers to minimise memory use
       "p" = plus, "m" = minus
       "n" = normal order, "r" = reciprocal order
    */
    v1en = v1_buf0;
    v1on = v1_buf1;
    v1pn = v1_buf2;
    v1mn = v1_buf0;
    v2en = v2_buf0;
    v2on = v2_buf1;
    v2pn = v2_buf2;
    v2mn = v2_buf0;
    v3pn = v1_buf1;
    v3mn = v1_buf2;
    v3en = v1_buf0;
    v3on = v1_buf1;

    v1er = v1_buf2;
    v1or = v1_buf3;
    v1pr = v1_buf4;
    v1mr = v1_buf2;
    v2er = v2_buf2;
    v2or = v2_buf3;
    v2pr = v2_buf4;
    v2mr = v2_buf2;
    v3pr = v1_buf3;
    v3mr = v1_buf4;
    v3er = v1_buf2;
    v3or = v1_buf3;

    z = TMP_ALLOC(sizeof(mp_limb_t) * 2*w*(n3e + 1));
    zn = z;
    zr = z + w*(n3e + 1);

    /* -------------------------------------------------------------------------
       "normal" evaluation points
    */

    if (!sqr)
    {
        /* multiplication version */

        /*
           evaluate f1e(B^2) and B * f1o(B^2)
           We need max(2 * b*n1e, 2 * b*n1o + b) bits for this packing step,
           which is safe since (n1 + 1) * b + 1 >= max(2 * b*n1e, 2 * b*n1o + b).
           Ditto for f2 below.
        */
        _nmod_poly_KS2_pack(v1en, op1, n1e, 2, 2 * b, 0, k1);
        _nmod_poly_KS2_pack(v1on, op1 + 1, n1o, 2, 2 * b, b, k1);

        /*
           compute  f1(B)   =  f1e(B^2) + B * f1o(B^2)
              and  |f1(-B)| = |f1e(B^2) - B * f1o(B^2)|
        */
        mpn_add_n(v1pn, v1en, v1on, k1);
        v3m_neg = signed_mpn_sub_n(v1mn, v1en, v1on, k1);

        /* evaluate f2e(B^2) and B * f2o(B^2) */
        _nmod_poly_KS2_pack(v2en, op2, n2e, 2, 2 * b, 0, k2);
        _nmod_poly_KS2_pack(v2on, op2 + 1, n2o, 2, 2 * b, b, k2);

        /*
           compute  f2(B)   =  f2e(B^2) + B * f2o(B^2)
              and  |f2(-B)| = |f2e(B^2) - B * f2o(B^2)|
        */
        mpn_add_n(v2pn, v2en, v2on, k2);
        v3m_neg ^= signed_mpn_sub_n(v2mn, v2en, v2on, k2);

        /*
           compute  h(B)   =  f1(B)   *  f2(B)
              and  |h(-B)| = |f1(-B)| * |f2(-B)|
           v3m_neg is set if h(-B) is negative
        */
        mpn_mul(v3pn, v1pn, k1, v2pn, k2);
        mpn_mul(v3mn, v1mn, k1, v2mn, k2);
    }
    else
    {
        /* squaring version */

        /* evaluate f1e(B^2) and B * f1o(B^2) */
        _nmod_poly_KS2_pack(v1en, op1, n1e, 2, 2 * b, 0, k1);
        _nmod_poly_KS2_pack(v1on, op1 + 1, n1o, 2, 2 * b, b, k1);

        /*
           compute  f1(B)   =  f1e(B^2) + B * f1o(B^2)
              and  |f1(-B)| = |f1e(B^2) - B * f1o(B^2)|
        */
        mpn_add_n(v1pn, v1en, v1on, k1);
        signed_mpn_sub_n(v1mn, v1en, v1on, k1);

        /*
           compute  h(B)  = f1(B)^2
              and  h(-B)  = |f1(-B)|^2
           v3m_neg is cleared since h(-B) is never negative
        */
        mpn_sqr(v3pn, v1pn, k1);
        mpn_sqr(v3mn, v1mn, k1);
        v3m_neg = 0;
    }

    /*
       Each coefficient of h(B) is up to 4b bits long, so h(B) needs at most
       ((n1 + n2 + 2) * b + 1) bits. (The extra +1 is to accommodate carries
       generated by overlapping coefficients.) The buffer has at least
       ((n1 + n2 + 2) * b + 2) bits. Therefore we can safely store 2*h(B) etc.
    */

    /*
       compute     2 * he(B^2) = h(B) + h(-B)
          and  B * 2 * ho(B^2) = h(B) - h(-B)
    */
    if (v3m_neg)
    {
        mpn_sub_n(v3en, v3pn, v3mn, k3);
        mpn_add_n(v3on, v3pn, v3mn, k3);
    }
    else
    {
        mpn_add_n(v3en, v3pn, v3mn, k3);
        mpn_sub_n(v3on, v3pn, v3mn, k3);
    }

    /* -------------------------------------------------------------------------
       "reciprocal" evaluation points
    */

    /*
       correction factors to take into account that if a polynomial has even
       length, its even and odd coefficients are swapped when the polynomial
       is reversed
    */
    a1 = (n1 & 1) ? 0 : b;
    a2 = (n2 & 1) ? 0 : b;
    a3 = (n3 & 1) ? 0 : b;

    if (!sqr)
    {
        /* multiplication version */

        /* evaluate B^(n1-1) * f1e(1/B^2) and B^(n1-2) * f1o(1/B^2) */
        _nmod_poly_KS2_pack(v1er, op1 + 2*(n1e - 1), n1e, -2, 2 * b, a1, k1);
        _nmod_poly_KS2_pack(v1or, op1 + 1 + 2*(n1o - 1), n1o, -2, 2 * b, b - a1, k1);

        /*
           compute  B^(n1-1) * f1(1/B) =
                        B^(n1-1) * f1e(1/B^2) + B^(n1-2) * f1o(1/B^2)
              and  |B^(n1-1) * f1(-1/B)| =
                       |B^(n1-1) * f1e(1/B^2) - B^(n1-2) * f1o(1/B^2)|
        */
        mpn_add_n(v1pr, v1er, v1or, k1);
        v3m_neg = signed_mpn_sub_n(v1mr, v1er, v1or, k1);

        /* evaluate B^(n2-1) * f2e(1/B^2) and B^(n2-2) * f2o(1/B^2) */
        _nmod_poly_KS2_pack(v2er, op2 + 2*(n2e - 1), n2e, -2, 2 * b, a2, k2);
        _nmod_poly_KS2_pack(v2or, op2 + 1 + 2*(n2o - 1), n2o, -2, 2 * b, b - a2, k2);

        /*
           compute  B^(n2-1) * f2(1/B) =
                        B^(n2-1) * f2e(1/B^2) + B^(n2-2) * f2o(1/B^2)
              and  |B^(n2-1) * f2(-1/B)| =
                       |B^(n2-1) * f2e(1/B^2) - B^(n2-2) * f2o(1/B^2)|
        */
        mpn_add_n(v2pr, v2er, v2or, k2);
        v3m_neg ^= signed_mpn_sub_n(v2mr, v2er, v2or, k2);

        /*
           compute  B^(n3-1) * h(1/B) =
                        (B^(n1-1) * f1(1/B)) * (B^(n2-1) * f2(1/B))
              and  |B^(n3-1) * h(-1/B)| =
                       |B^(n1-1) * f1(-1/B)| * |B^(n2-1) * f2(-1/B)|
           v3m_neg is set if h(-1/B) is negative
        */
        mpn_mul(v3pr, v1pr, k1, v2pr, k2);
        mpn_mul(v3mr, v1mr, k1, v2mr, k2);
    }
    else
    {
        /* squaring version */

        /* evaluate B^(n1-1) * f1e(1/B^2) and B^(n1-2) * f1o(1/B^2) */
        _nmod_poly_KS2_pack(v1er, op1 + 2*(n1e - 1), n1e, -2, 2 * b, a1, k1);
        _nmod_poly_KS2_pack(v1or, op1 + 1 + 2*(n1o - 1), n1o, -2, 2 * b, b - a1, k1);

        /*
           compute  B^(n1-1) * f1(1/B) =
                        B^(n1-1) * f1e(1/B^2) + B^(n1-2) * f1o(1/B^2)
              and  |B^(n1-1) * f1(-1/B)| =
                       |B^(n1-1) * f1e(1/B^2) - B^(n1-2) * f1o(1/B^2)|
        */
        mpn_add_n(v1pr, v1er, v1or, k1);
        signed_mpn_sub_n(v1mr, v1er, v1or, k1);

        /*
           compute  B^(n3-1) * h(1/B)  = (B^(n1-1) * f1(1/B))^2
              and  B^(n3-1) * h(-1/B) = |B^(n1-1) * f1(-1/B)|^2
           v3m_neg is cleared since h(-1/B) is never negative
        */
        mpn_sqr(v3pr, v1pr, k1);
        mpn_sqr(v3mr, v1mr, k1);
        v3m_neg = 0;
    }

    /*
       compute  2 * B^(n3-1) * he(1/B^2)
                     = B^(n3-1) * h(1/B) + B^(n3-1) * h(-1/B)
          and  2 * B^(n3-2) * ho(1/B^2)
                     = B^(n3-1) * h(1/B) - B^(n3-1) * h(-1/B)
    */
    if (v3m_neg)
    {
        mpn_sub_n(v3er, v3pr, v3mr, k3);
        mpn_add_n(v3or, v3pr, v3mr, k3);
    }
    else
    {
        mpn_add_n(v3er, v3pr, v3mr, k3);
        mpn_sub_n(v3or, v3pr, v3mr, k3);
    }

    /* -------------------------------------------------------------------------
       combine "normal" and "reciprocal" information
    */

    /* decompose he(B^2) and B^(2*(n3e-1)) * he(1/B^2) into base-B^2 digits */
    _nmod_poly_KS2_unpack(zn, v3en, n3e + 1, 2 * b, 1);
    _nmod_poly_KS2_unpack(zr, v3er, n3e + 1, 2 * b, a3 + 1);

    /* combine he(B^2) and he(1/B^2) information to get even coefficients of h */
    _nmod_poly_KS2_recover_reduce(res, 2, zn, zr, n3e, 2 * b, mod);

    /* decompose ho(B^2) and B^(2*(n3o-1)) * ho(1/B^2) into base-B^2 digits */
    _nmod_poly_KS2_unpack(zn, v3on, n3o + 1, 2 * b, b + 1);
    _nmod_poly_KS2_unpack(zr, v3or, n3o + 1, 2 * b, b - a3 + 1);

    /* combine ho(B^2) and ho(1/B^2) information to get odd coefficients of h */
    _nmod_poly_KS2_recover_reduce(res + 1, 2, zn, zr, n3o, 2 * b, mod);

    TMP_END;
}
340
/*
   Public wrapper for KS4 multiplication: sets res = poly1 * poly2.

   Handles the zero-polynomial cases, orders the operands so the longer
   one is passed first (as _nmod_poly_mul_KS4 requires n1 >= n2), and
   uses a temporary when res aliases one of the inputs.
*/
void
nmod_poly_mul_KS4(nmod_poly_t res,
                  const nmod_poly_t poly1, const nmod_poly_t poly2)
{
    slong len_out;

    if ((poly1->length == 0) || (poly2->length == 0))
    {
        /* product with the zero polynomial is zero */
        nmod_poly_zero(res);
        return;
    }

    len_out = poly1->length + poly2->length - 1;

    if (res == poly1 || res == poly2)
    {
        /* output aliases an input: multiply into a temporary, then swap */
        nmod_poly_t temp;
        nmod_poly_init2_preinv(temp, poly1->mod.n, poly1->mod.ninv, len_out);
        if (poly1->length >= poly2->length)
            _nmod_poly_mul_KS4(temp->coeffs, poly1->coeffs, poly1->length,
                               poly2->coeffs, poly2->length,
                               poly1->mod);
        else
            _nmod_poly_mul_KS4(temp->coeffs, poly2->coeffs, poly2->length,
                               poly1->coeffs, poly1->length,
                               poly1->mod);
        nmod_poly_swap(res, temp);
        nmod_poly_clear(temp);
    }
    else
    {
        nmod_poly_fit_length(res, len_out);
        if (poly1->length >= poly2->length)
            _nmod_poly_mul_KS4(res->coeffs, poly1->coeffs, poly1->length,
                               poly2->coeffs, poly2->length,
                               poly1->mod);
        else
            _nmod_poly_mul_KS4(res->coeffs, poly2->coeffs, poly2->length,
                               poly1->coeffs, poly1->length,
                               poly1->mod);
    }

    res->length = len_out;
    _nmod_poly_normalise(res);
}
386