1 /*
2 mulmid_ks.c: polynomial middle products by Kronecker substitution
3
4 Copyright (C) 2007, 2008, David Harvey
5
6 This file is part of the zn_poly library (version 0.9).
7
8 This program is free software: you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation, either version 2 of the License, or
11 (at your option) version 3 of the License.
12
13 This program is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
17
18 You should have received a copy of the GNU General Public License
19 along with this program. If not, see <http://www.gnu.org/licenses/>.
20
21 */
22
23 #include "zn_poly_internal.h"
24
25
26 /*
27 In the routines below, we denote by f1(x) and f2(x) the input polynomials
28 op1[0, n1) and op2[0, n2), and by h(x) their product in Z[x].
29
30 We write h(x) = LO(x) + x^(n2-1) * g(x) + x^n1 * HI(x), where
31 len(LO) = len(HI) = n2 - 1 and len(g) = n1 - n2 + 1. Our goal is to
32 compute the middle segment g.
33
34 The basic strategy is: if X is an evaluation point (i.e. X = 2^b, -2^b,
35 2^(-b) or -2^(-b)) then g(X) corresponds roughly to the integer middle
36 product of f1(X) and f2(X), and we will use mpn_mulmid() to compute the
37 latter.
38
39 Unfortunately there are some complications.
40
41 First, mpn_mulmid() works in terms of whole limb counts, not bit counts,
42 and moreover the first two and last two limbs of the output of mpn_mulmid()
43 are always garbage. We handle this issue using zero-padding as follows.
44 Suppose that we need s bits of g(X) starting at bit index r. We compute
45 f2(X) as usual. Let k2 = number of limbs used to store f2(X). Instead of
46 evaluating f1(X), we evaluate 2^p * f1(X), i.e. zero-pad by p bits, where
47
48 p = (k2 + 1) * GMP_NUMB_BITS - r.
49
50 (We will verify in each case that p >= 0.) This shifts g(X) left by p bits,
51 and ensures that bit #r of g(X) starts exactly at the first bit of the
52 third limb of the output of mpn_mulmid(). Let k1 = number of limbs used to
53 store f1(X). To be guaranteed of obtaining s correct bits of g(X), we need
54 to have
55
56 (k1 - k2 - 1) * GMP_NUMB_BITS >= s,
57
58 or equivalently
59
60 k1 * GMP_NUMB_BITS >= p + r + s. (*)
61
62 We zero-pad 2^p * f1(X) on the right to ensure that (*) holds. In every
63 case, it turns out that the total amount of zero-padding is O(1) bits.
64
65 Second, in the "reciprocal" variants (KS3 and KS4) there is the problem of
66 overlapping coefficients, e.g. when we compute the integer middle product,
67 the low bits of g(X) are polluted by the high bits of LO(X). To deal with
68 this we need to compute the low coefficient of g(X) separately, and remove
69 its effect from the overlapping portion. Similarly at the other end. The
70 diagonal() function accomplishes this.
71
72 */
73
74
75
/*
   Middle product using Kronecker substitution at 2^b.

   Computes the middle segment g (length n1 - n2 + 1) of the product of
   op1[0, n1) and op2[0, n2), reduces it mod mod->m, and stores it at
   res[0, n1 - n2 + 1).  If redc is set, reduction uses REDC (requires
   mod->m odd).
*/
void
zn_array_mulmid_KS1 (ulong* res,
                     const ulong* op1, size_t n1,
                     const ulong* op2, size_t n2,
                     int redc, const zn_mod_t mod)
{
   ZNP_ASSERT (n2 >= 1);
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n1 <= ULONG_MAX);
   // REDC reduction only makes sense for an odd modulus
   ZNP_ASSERT ((mod->m & 1) || !redc);

   // length of g
   size_t n3 = n1 - n2 + 1;

   // bits in each output coefficient
   unsigned b = 2 * mod->bits + ceil_lg (n2);

   // number of ulongs required to store each output coefficient
   unsigned w = CEIL_DIV (b, ULONG_BITS);
   ZNP_ASSERT (w <= 3);

   // number of limbs needed to store f2(2^b)
   size_t k2 = CEIL_DIV (n2 * b, GMP_NUMB_BITS);

   // We need r = (n2 - 1) * b and s = (n1 - n2 + 1) * b.  Note that p is
   // non-negative since k2 * GMP_NUMB_BITS >= n2 * b.
   unsigned p = GMP_NUMB_BITS * (k2 + 1) - (n2 - 1) * b;

   // For (*) to hold we need k1 * GMP_NUMB_BITS >= p + n1 * b.
   size_t k1 = CEIL_DIV (p + n1 * b, GMP_NUMB_BITS);

   // allocate space
   ZNP_FASTALLOC (limbs, mp_limb_t, 6624, 2 * k1 + 3);
   mp_limb_t* v1 = limbs;         // k1 limbs
   mp_limb_t* v2 = v1 + k1;       // k2 limbs
   mp_limb_t* v3 = v2 + k2;       // k1 - k2 + 3 limbs

   // evaluate 2^p * f1(2^b) and f2(2^b)
   zn_array_pack (v1, op1, n1, 1, b, p, 0);
   zn_array_pack (v2, op2, n2, 1, b, 0, 0);

   // compute segment of f1(2^b) * f2(2^b) starting at bit index r
   ZNP_mpn_mulmid (v3, v1, k1, v2, k2);

   // unpack coefficients of g, and reduce mod m
   // (skip the first two limbs of v3: the first and last two limbs of
   // mpn_mulmid output are garbage, and only k1 - k2 - 1 limbs are valid)
   ZNP_FASTALLOC (z, ulong, 6624, n3 * w);
   zn_array_unpack_SAFE (z, v3 + 2, n3, b, 0, k1 - k2 - 1);
   array_reduce (res, 1, z, n3, w, redc, mod);

   ZNP_FASTFREE (z);
   ZNP_FASTFREE (limbs);
}
131
132
133
/*
   Middle product using Kronecker substitution at 2^b and -2^b.

   Computes the middle segment g (length n1 - n2 + 1) of the product of
   op1[0, n1) and op2[0, n2), reduces it mod mod->m, and stores it at res.
   If redc is set, reduction uses REDC (requires mod->m odd).

   Evaluating at both B = 2^b and -B allows b to be about half the size
   used by KS1: the even and odd parts of h(x) are separated via
   h(B) +/- h(-B) and recovered independently.
*/
void
zn_array_mulmid_KS2 (ulong* res,
                     const ulong* op1, size_t n1,
                     const ulong* op2, size_t n2,
                     int redc, const zn_mod_t mod)
{
   ZNP_ASSERT (n2 >= 1);
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n1 <= ULONG_MAX);
   // REDC reduction only makes sense for an odd modulus
   ZNP_ASSERT ((mod->m & 1) || !redc);

   if (n2 == 1)
   {
      // code below needs n2 > 1, so fall back on scalar multiplication
      _zn_array_scalar_mul (res, op1, n1, op2[0], redc, mod);
      return;
   }

   // bits in each output coefficient
   unsigned bits = 2 * mod->bits + ceil_lg (n2);

   // we're evaluating at x = B and -B, where B = 2^b, and b = ceil(bits / 2)
   unsigned b = (bits + 1) / 2;

   // number of ulongs required to store each output coefficient
   unsigned w = CEIL_DIV (2 * b, ULONG_BITS);
   ZNP_ASSERT (w <= 3);

   // Write f1(x) = f1e(x^2) + x * f1o(x^2)
   //       f2(x) = f2e(x^2) + x * f2o(x^2)
   //        h(x) = he(x^2) + x * ho(x^2)
   //        g(x) = ge(x^2) + x * go(x^2)
   // "e" = even, "o" = odd

   // When evaluating f2e(B^2) and B * f2o(B^2) the bit-packing routine needs
   // room for the last chunk of 2b bits, so we need to allow room for
   // (n2 + 1) * b bits.
   size_t k2 = CEIL_DIV ((n2 + 1) * b, GMP_NUMB_BITS);

   // We need r = (n2 - 2) * b + 1 and s = (n1 - n2 + 3) * b.
   // Note that p is non-negative (since k2 * GMP_NUMB_BITS >= (n2 + 1) * b
   // >= (n2 - 2) * b - 1).
   unsigned p = GMP_NUMB_BITS * (k2 + 1) - (n2 - 2) * b - 1;

   // For (*) to hold we need k1 * GMP_NUMB_BITS >= p + (n1 + 1) * b + 1.
   // Also, to ensure that there is enough room for bit-packing (as for k2
   // above), we need k1 * GMP_NUMB_BITS >= p + (n1 + 1) * b; this condition
   // is subsumed by the first one.
   size_t k1 = CEIL_DIV (p + (n1 + 1) * b + 1, GMP_NUMB_BITS);

   size_t k3 = k1 - k2 + 3;
   ZNP_ASSERT (k3 >= 5);

   // allocate space
   ZNP_FASTALLOC (limbs, mp_limb_t, 6624, 3 * k3 + 5 * k2);
   mp_limb_t* v2_buf0 = limbs;            // k2 limbs
   mp_limb_t* v3_buf0 = v2_buf0 + k2;     // k3 limbs
   mp_limb_t* v2_buf1 = v3_buf0 + k3;     // k2 limbs
   mp_limb_t* v3_buf1 = v2_buf1 + k2;     // k3 limbs
   mp_limb_t* v2_buf2 = v3_buf1 + k3;     // k2 limbs
   mp_limb_t* v3_buf2 = v2_buf2 + k2;     // k3 limbs
   mp_limb_t* v2_buf3 = v3_buf2 + k3;     // k2 limbs
   mp_limb_t* v2_buf4 = v2_buf3 + k2;     // k2 limbs

   // arrange overlapping buffers to minimise memory use
   // "p" = plus, "m" = minus
   // NOTE: the v1* values occupy k1 limbs each, not k2.  Each one is placed
   // at a v2_bufX that is immediately followed by a v3_bufY, which yields
   // k2 + k3 = k1 + 3 contiguous limbs; the order of operations below is
   // arranged so the overlapping regions are never live at the same time.
   mp_limb_t* v1e = v2_buf0;
   mp_limb_t* v1o = v2_buf2;
   mp_limb_t* v1p = v2_buf1;
   mp_limb_t* v1m = v2_buf0;
   mp_limb_t* v2e = v2_buf2;
   mp_limb_t* v2o = v2_buf4;
   mp_limb_t* v2p = v2_buf3;
   mp_limb_t* v2m = v2_buf2;
   mp_limb_t* v3m = v3_buf2;
   mp_limb_t* v3p = v3_buf0;
   mp_limb_t* v3e = v3_buf1;
   mp_limb_t* v3o = v3_buf1;

   // length of g
   size_t n3 = n1 - n2 + 1;

   ZNP_FASTALLOC (z, ulong, 6624, w * ((n3 + 1) / 2));

   // evaluate 2^p * f1e(B^2) and 2^p * B * f1o(B^2)
   zn_array_pack (v1e, op1, (n1 + 1) / 2, 2, 2 * b, p, k1);
   zn_array_pack (v1o, op1 + 1, n1 / 2, 2, 2 * b, p + b, k1);

   // compute  2^p * f1(B)    =  2^p * (f1e(B^2) + B * f1o(B^2))
   //    and  |2^p * f1(-B)|  = |2^p * (f1e(B^2) - B * f1o(B^2))|
   // v3m_neg is set if f1(-B) is negative
   ZNP_ASSERT_NOCARRY (mpn_add_n (v1p, v1e, v1o, k1));
   int v3m_neg = signed_mpn_sub_n (v1m, v1e, v1o, k1);

   // evaluate f2e(B^2) and B * f2o(B^2)
   zn_array_pack (v2e, op2, (n2 + 1) / 2, 2, 2 * b, 0, k2);
   zn_array_pack (v2o, op2 + 1, n2 / 2, 2, 2 * b, b, k2);

   // compute  f2(B)   = f2e(B^2) + B * f2o(B^2)
   //    and  |f2(-B)| = |f2e(B^2) - B * f2o(B^2)|
   // v3m_neg is set if f1(-B) and f2(-B) have opposite signs
   ZNP_ASSERT_NOCARRY (mpn_add_n (v2p, v2e, v2o, k2));
   v3m_neg ^= signed_mpn_sub_n (v2m, v2e, v2o, k2);

   // compute segment starting at bit index r of
   //         h(B)   = f1(B)    * f2(B)
   //   and  |h(-B)| = |f1(-B)| * |f2(-B)|
   // v3m_neg is set if h(-B) is negative
   ZNP_mpn_mulmid (v3m, v1m, k1, v2m, k2);
   ZNP_mpn_mulmid (v3p, v1p, k1, v2p, k2);

   // compute segment starting at bit index r of
   //        2 * he(B^2)     = h(B) + h(-B)    (if n2 is odd)
   //   or   2 * B * ho(B^2) = h(B) - h(-B)    (if n2 is even)
   // i.e. the segment of he(B^2) or B * ho(B^2) starting at bit index
   // r - 1 = (n2 - 2) * b.  This encodes the coefficients of ge(x).

   // Note that when we do the addition (resp. subtraction) below, we might
   // miss a carry (resp. borrow) from the unknown previous limbs.  We arrange
   // so that the answers are either correct or one too big, by adding 1
   // appropriately.

   if (v3m_neg ^ (n2 & 1))
   {
      mpn_add_n (v3e, v3p + 2, v3m + 2, k3 - 4);    // miss carry?
      mpn_add_1 (v3e, v3e, k3 - 4, 1);
   }
   else
      mpn_sub_n (v3e, v3p + 2, v3m + 2, k3 - 4);    // miss borrow?

   // Now we extract ge(x).  The first coefficient we want is the coefficient
   // of x^(n2 - 1) in h(x); this starts at bit index b in v3e.  We want
   // ceil(n3 / 2) coefficients altogether, with 2b bits each.  This accounts
   // for the definition of s.

   // Claim: if we committed a "one-too-big" error above, this does not affect
   // the coefficients we extract.  Proof: the first b bits of v3e are the top
   // half of the coefficient of x^(n2 - 2) in h(x).  The base-B digit in those
   // b bits has value at most B - 2.  Therefore adding 1 to it will never
   // overflow those b bits.

   zn_array_unpack_SAFE (z, v3e, (n3 + 1) / 2, 2 * b, b, k3 - 4);
   array_reduce (res, 2, z, (n3 + 1) / 2, w, redc, mod);

   // Now repeat all the above for go(x).

   if (v3m_neg ^ (n2 & 1))
      mpn_sub_n (v3o, v3p + 2, v3m + 2, k3 - 4);
   else
   {
      mpn_add_n (v3o, v3p + 2, v3m + 2, k3 - 4);
      mpn_add_1 (v3o, v3o, k3 - 4, 1);
   }

   zn_array_unpack_SAFE (z, v3o, n3 / 2, 2 * b, 2 * b, k3 - 4);
   array_reduce (res + 1, 2, z, n3 / 2, w, redc, mod);

   ZNP_FASTFREE (z);
   ZNP_FASTFREE (limbs);
}
297
298
299
300 /*
301 Computes the sum
302
303 op1[0] * op2[n-1] + ... + op1[n-1] * op2[0]
304
305 as an *integer*. The result is assumed to fit into w ulongs (where
306 1 <= w <= 3), and is written to res[0, w). The return value is the result
307 reduced modulo mod->m (using redc if requested).
308 */
309 #define diagonal_sum \
310 ZNP_diagonal_sum
311 ulong
diagonal_sum(ulong * res,const ulong * op1,const ulong * op2,size_t n,unsigned w,int redc,const zn_mod_t mod)312 diagonal_sum (ulong* res, const ulong* op1, const ulong* op2,
313 size_t n, unsigned w, int redc, const zn_mod_t mod)
314 {
315 ZNP_ASSERT (n >= 1);
316 ZNP_ASSERT (w >= 1);
317 ZNP_ASSERT (w <= 3);
318
319 size_t i;
320
321 if (w == 1)
322 {
323 ulong sum = op1[0] * op2[n - 1];
324
325 for (i = 1; i < n; i++)
326 sum += op1[i] * op2[n - 1 - i];
327
328 res[0] = sum;
329 return redc ? zn_mod_reduce_redc (sum, mod) : zn_mod_reduce (sum, mod);
330 }
331 else if (w == 2)
332 {
333 ulong lo, hi, sum0, sum1;
334
335 ZNP_MUL_WIDE (sum1, sum0, op1[0], op2[n - 1]);
336
337 for (i = 1; i < n; i++)
338 {
339 ZNP_MUL_WIDE (hi, lo, op1[i], op2[n - 1 - i]);
340 ZNP_ADD_WIDE (sum1, sum0, sum1, sum0, hi, lo);
341 }
342
343 res[0] = sum0;
344 res[1] = sum1;
345 return redc ? zn_mod_reduce2_redc (sum1, sum0, mod)
346 : zn_mod_reduce2 (sum1, sum0, mod);
347 }
348 else // w == 3
349 {
350 ulong lo, hi, sum0, sum1, sum2 = 0;
351
352 ZNP_MUL_WIDE (sum1, sum0, op1[0], op2[n - 1]);
353
354 for (i = 1; i < n; i++)
355 {
356 ZNP_MUL_WIDE (hi, lo, op1[i], op2[n - 1 - i]);
357 ZNP_ADD_WIDE (sum1, sum0, sum1, sum0, hi, lo);
358 // carry into third limb:
359 if (sum1 <= hi)
360 sum2 += (sum1 < hi || sum0 < lo);
361 }
362
363 res[0] = sum0;
364 res[1] = sum1;
365 res[2] = sum2;
366 return redc ? zn_mod_reduce3_redc (sum2, sum1, sum0, mod)
367 : zn_mod_reduce3 (sum2, sum1, sum0, mod);
368 }
369 }
370
371
372 /*
373 Inplace subtract 2^i*x from res[0, n).
374 x is an array of w ulongs, where 1 <= w <= 3.
375 i may be any non-negative integer.
376 */
377 #define subtract_ulongs \
378 ZNP_subtract_ulongs
379 void
subtract_ulongs(mp_limb_t * res,size_t n,size_t i,ulong * x,unsigned w)380 subtract_ulongs (mp_limb_t* res, size_t n, size_t i, ulong* x, unsigned w)
381 {
382 ZNP_ASSERT (w >= 1);
383 ZNP_ASSERT (w <= 3);
384
385 #if GMP_NAIL_BITS == 0 && ULONG_BITS == GMP_NUMB_BITS
386 size_t k = i / GMP_NUMB_BITS;
387
388 if (k >= n)
389 return;
390
391 unsigned j = i % GMP_NUMB_BITS;
392
393 if (j == 0)
394 mpn_sub (res + k, res + k, n - k, (mp_srcptr) x, ZNP_MIN (n - k, w));
395 else
396 {
397 mp_limb_t y[4];
398 y[w] = mpn_lshift (y, (mp_srcptr) x, w, j);
399 mpn_sub (res + k, res + k, n - k, y, ZNP_MIN (n - k, w + 1));
400 }
401 #else
402 #error Not nails-safe yet
403 #endif
404 }
405
406
/*
   Middle product using Kronecker substitution at 2^b and 2^(-b).

   Computes the middle segment g (length n1 - n2 + 1) of the product of
   op1[0, n1) and op2[0, n2), reduces it mod mod->m, and stores it at res.
   If redc is set, reduction uses REDC (requires mod->m odd).

   Note: this routine does not appear to be competitive in practice with the
   other KS routines. It's here just for fun.
*/
void
zn_array_mulmid_KS3 (ulong* res,
                     const ulong* op1, size_t n1,
                     const ulong* op2, size_t n2,
                     int redc, const zn_mod_t mod)
{
   ZNP_ASSERT (n2 >= 1);
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n1 <= ULONG_MAX);
   // REDC reduction only makes sense for an odd modulus
   ZNP_ASSERT ((mod->m & 1) || !redc);

   // length of g
   size_t n3 = n1 - n2 + 1;

   // bits in each output coefficient
   unsigned bits = 2 * mod->bits + ceil_lg (n2);

   // we're evaluating at x = B and 1/B, where B = 2^b, and b = ceil(bits / 2)
   unsigned b = (bits + 1) / 2;

   // number of ulongs required to store each base-B digit
   unsigned w = CEIL_DIV (b, ULONG_BITS);
   ZNP_ASSERT (w <= 2);

   // number of ulongs needed to store each output coefficient
   unsigned ww = CEIL_DIV (2 * b, ULONG_BITS);
   ZNP_ASSERT (ww <= 3);

   // The two extreme coefficients of g overlap with neighbouring data in the
   // integer middle products (see notes at the top of the file), so we
   // compute them directly via diagonal_sum() and subtract them out below.

   // directly compute coefficient of x^0 in g(x)
   ulong dlo[3];
   res[0] = diagonal_sum (dlo, op1, op2, n2, ww, redc, mod);
   if (n3 == 1)
      return;    // only need one coefficient of output

   // directly compute coefficient of x^(n3-1) in g(x)
   ulong dhi[3];
   res[n3 - 1] = diagonal_sum (dhi, op1 + n3 - 1, op2, n2, ww, redc, mod);
   if (n3 == 2)
      return;    // only need two coefficients of output

   // limbs needed to store f2(B) and B^(n2-1) * f2(1/B)
   size_t k2 = CEIL_DIV (n2 * b, GMP_NUMB_BITS);

   // we need r = (n2 - 1) * b and s = (n1 - n2 + 1) * b, thus p is:
   unsigned p = GMP_NUMB_BITS * (k2 + 1) - (n2 - 1) * b;

   // for (*) we need k1 * GMP_NUMB_BITS >= p + n1 * b
   size_t k1 = CEIL_DIV (p + n1 * b, GMP_NUMB_BITS);

   size_t k3 = k1 - k2 + 3;
   ZNP_ASSERT (k3 >= 5);

   // allocate space
   ZNP_FASTALLOC (limbs, mp_limb_t, 6624, 2 * k1 + 3);
   mp_limb_t* v1 = limbs;         // k1 limbs
   mp_limb_t* v2 = v1 + k1;       // k2 limbs
   mp_limb_t* v3 = v2 + k2;       // k1 - k2 + 3 limbs

   ZNP_FASTALLOC (z, ulong, 6624, 2 * w * n3);
   // "n" = normal order, "r" = reciprocal order
   ulong* zn = z;
   ulong* zr = z + w * n3;

   // -------------------------------------------------------------------------
   // "normal" evaluation point

   // evaluate 2^p * f1(B) and f2(B)
   zn_array_pack (v1, op1, n1, 1, b, p, k1);
   zn_array_pack (v2, op2, n2, 1, b, 0, k2);

   // compute segment starting at bit index r of h(B) = f1(B) * f2(B)
   ZNP_mpn_mulmid (v3, v1, k1, v2, k2);

   // remove x^0 and x^(n3 - 1) coefficient of g(x)
   subtract_ulongs (v3 + 2, k3 - 4, 0, dlo, ww);
   subtract_ulongs (v3 + 2, k3 - 4, (n3 - 1) * b, dhi, ww);

   // decompose relevant portion of h(B) into base-B digits
   zn_array_unpack_SAFE (zn, v3 + 2, n3 - 1, b, b, k3 - 4);

   // At this stage zn contains (n3 - 1) base-B digits, representing the
   // integer g[1] + g[2]*B + ... + g[n3-2]*B^(n3-3)

   // -------------------------------------------------------------------------
   // "reciprocal" evaluation point

   // evaluate 2^p * B^(n1-1) * f1(1/B) and B^(n2-1) * f2(1/B)
   zn_array_pack (v1, op1 + n1 - 1, n1, -1, b, p, k1);
   zn_array_pack (v2, op2 + n2 - 1, n2, -1, b, 0, k2);

   // compute segment starting at bit index r of B^(n1+n2-2) * h(1/B) =
   // (B^(n1-1) * f1(1/B)) * (B^(n2-1) * f2(1/B))
   ZNP_mpn_mulmid (v3, v1, k1, v2, k2);

   // remove x^0 and x^(n3 - 1) coefficient of g(x)
   // (their positions are swapped relative to the "normal" point above,
   // since the coefficients appear here in reciprocal order)
   subtract_ulongs (v3 + 2, k3 - 4, 0, dhi, ww);
   subtract_ulongs (v3 + 2, k3 - 4, (n3 - 1) * b, dlo, ww);

   // decompose relevant portion of B^(n1+n2-2) * h(1/B) into base-B digits
   zn_array_unpack_SAFE (zr, v3 + 2, n3 - 1, b, b, k3 - 4);

   // At this stage zr contains (n3 - 1) base-B digits, representing the
   // integer g[n3-2] + g[n3-3]*B + ... + g[1]*B^(n3-3)

   // -------------------------------------------------------------------------
   // combine "normal" and "reciprocal" information

   zn_array_recover_reduce (res + 1, 1, zn, zr, n3 - 2, b, redc, mod);

   ZNP_FASTFREE (z);
   ZNP_FASTFREE (limbs);
}
525
526
527
/*
   Middle product using Kronecker substitution at 2^b, -2^b, 2^(-b)
   and -2^(-b).

   Computes the middle segment g (length n1 - n2 + 1) of the product of
   op1[0, n1) and op2[0, n2), reduces it mod mod->m, and stores it at res.
   If redc is set, reduction uses REDC (requires mod->m odd).

   Using four evaluation points allows b to be about a quarter the size
   needed by KS1: the even/odd split (as in KS2) is combined with the
   reciprocal trick (as in KS3).
*/
void
zn_array_mulmid_KS4 (ulong* res,
                     const ulong* op1, size_t n1,
                     const ulong* op2, size_t n2,
                     int redc, const zn_mod_t mod)
{
   ZNP_ASSERT (n2 >= 1);
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n1 <= ULONG_MAX);
   // REDC reduction only makes sense for an odd modulus
   ZNP_ASSERT ((mod->m & 1) || !redc);

   if (n2 == 1)
   {
      // code below needs n2 > 1, so fall back on scalar multiplication
      _zn_array_scalar_mul (res, op1, n1, op2[0], redc, mod);
      return;
   }

   // bits in each output coefficient
   unsigned bits = 2 * mod->bits + ceil_lg (n2);

   // we're evaluating at x = B, -B, 1/B, -1/B,
   // where B = 2^b, and b = ceil(bits / 4)
   unsigned b = (bits + 3) / 4;

   // number of ulongs required to store each base-B^2 digit
   unsigned w = CEIL_DIV (2 * b, ULONG_BITS);
   ZNP_ASSERT (w <= 2);

   // number of ulongs needed to store each output coefficient
   unsigned ww = CEIL_DIV (4 * b, ULONG_BITS);
   ZNP_ASSERT (ww <= 3);

   // mask = 2^c - 1, where c = number of bits used in high ulong of each
   // base-B^2 digit
   // NOTE(review): mask does not appear to be used in the remainder of this
   // function -- confirm whether it can be removed.
   ulong mask;
   if (w == 1)
      mask = ((2 * b) < ULONG_BITS) ? ((1UL << (2 * b)) - 1) : (-1UL);
   else    // w == 2
      mask = (1UL << ((2 * b) - ULONG_BITS)) - 1;

   // Write f1(x) = f1e(x^2) + x * f1o(x^2)
   //       f2(x) = f2e(x^2) + x * f2o(x^2)
   //        h(x) = he(x^2) + x * ho(x^2)
   //        g(x) = ge(x^2) + x * go(x^2)
   // "e" = even, "o" = odd

   // lengths of the even/odd parts of f1, f2 and g
   size_t n1o = n1 / 2;
   size_t n1e = n1 - n1o;

   size_t n2o = n2 / 2;
   size_t n2e = n2 - n2o;

   size_t n3 = n1 - n2 + 1;    // length of g
   size_t n3o = n3 / 2;
   size_t n3e = n3 - n3o;

   // The extreme coefficients of ge(x) and go(x) overlap with neighbouring
   // data in the integer middle products (see notes at the top of the file),
   // so compute them directly via diagonal_sum() and subtract them out below.

   // directly compute coefficient of x^0 in ge(x)
   ulong delo[3];
   res[0] = diagonal_sum (delo, op1, op2, n2, ww, redc, mod);
   if (n3 == 1)
      return;    // only need one coefficient of output

   // directly compute coefficient of x^0 in go(x)
   ulong dolo[3];
   res[1] = diagonal_sum (dolo, op1 + 1, op2, n2, ww, redc, mod);
   if (n3 == 2)
      return;    // only need two coefficients of output

   // directly compute coefficient of x^(n3e - 1) in ge(x)
   ulong dehi[3];
   res[2*n3e - 2] = diagonal_sum (dehi, op1 + 2*n3e - 2, op2,
                                  n2, ww, redc, mod);
   if (n3 == 3)
      return;    // only need three coefficients of output

   // directly compute coefficient of x^(n3o - 1) in go(x)
   ulong dohi[3];
   res[2*n3o - 1] = diagonal_sum (dohi, op1 + 2*n3o - 1, op2,
                                  n2, ww, redc, mod);
   if (n3 == 4)
      return;    // only need four coefficients of output

   // In f2(B), the leading coefficient starts at bit position b * (n2 - 1)
   // and has length 2*b, and the coefficients overlap so we need an extra bit
   // for the carry: this gives (n2 + 1) * b + 1 bits.
   size_t k2 = CEIL_DIV ((n2 + 1) * b + 1, GMP_NUMB_BITS);

   // We need r = (n2 - 1) * b + 1 and s = (n3 + 1) * b.
   // Note that p is non-negative (since k2 * GMP_NUMB_BITS >= (n2 + 1) * b
   // >= (n2 - 1) * b - 1).
   unsigned p = GMP_NUMB_BITS * (k2 + 1) - (n2 - 1) * b - 1;

   // For (*) we need k1 * GMP_NUMB_BITS >= p + (n1 + 1) * b + 1.
   size_t k1 = CEIL_DIV (p + (n1 + 1) * b + 1, GMP_NUMB_BITS);

   size_t k3 = k1 - k2 + 3;
   ZNP_ASSERT (k3 >= 5);

   // allocate space
   ZNP_FASTALLOC (limbs, mp_limb_t, 6624, 5 * (k2 + k3));
   mp_limb_t* v2_buf0 = limbs;            // k2 limbs
   mp_limb_t* v3_buf0 = v2_buf0 + k2;     // k3 limbs
   mp_limb_t* v2_buf1 = v3_buf0 + k3;     // k2 limbs
   mp_limb_t* v3_buf1 = v2_buf1 + k2;     // k3 limbs
   mp_limb_t* v2_buf2 = v3_buf1 + k3;     // k2 limbs
   mp_limb_t* v3_buf2 = v2_buf2 + k2;     // k3 limbs
   mp_limb_t* v2_buf3 = v3_buf2 + k3;     // k2 limbs
   mp_limb_t* v3_buf3 = v2_buf3 + k2;     // k3 limbs
   mp_limb_t* v2_buf4 = v3_buf3 + k3;     // k2 limbs
   mp_limb_t* v3_buf4 = v2_buf4 + k2;     // k3 limbs

   // arrange overlapping buffers to minimise memory use
   // "p" = plus, "m" = minus
   // "n" = normal order, "r" = reciprocal order
   // NOTE: the v1* values occupy k1 limbs each, not k2.  Each one is placed
   // at a v2_bufX that is immediately followed by a v3_bufY, which yields
   // k2 + k3 = k1 + 3 contiguous limbs; the order of operations below is
   // arranged so the overlapping regions are never live at the same time.
   mp_limb_t* v1en = v2_buf1;
   mp_limb_t* v1on = v2_buf2;
   mp_limb_t* v1pn = v2_buf0;
   mp_limb_t* v1mn = v2_buf1;
   mp_limb_t* v2en = v2_buf3;
   mp_limb_t* v2on = v2_buf4;
   mp_limb_t* v2pn = v2_buf2;
   mp_limb_t* v2mn = v2_buf3;
   mp_limb_t* v3mn = v3_buf2;
   mp_limb_t* v3pn = v3_buf3;
   mp_limb_t* v3en = v3_buf4;
   mp_limb_t* v3on = v3_buf3;

   mp_limb_t* v1er = v2_buf1;
   mp_limb_t* v1or = v2_buf2;
   mp_limb_t* v1pr = v2_buf0;
   mp_limb_t* v1mr = v2_buf1;
   mp_limb_t* v2er = v2_buf3;
   mp_limb_t* v2or = v2_buf4;
   mp_limb_t* v2pr = v2_buf2;
   mp_limb_t* v2mr = v2_buf3;
   mp_limb_t* v3mr = v3_buf2;
   mp_limb_t* v3pr = v3_buf1;
   mp_limb_t* v3er = v3_buf0;
   mp_limb_t* v3or = v3_buf1;

   ZNP_FASTALLOC (z, ulong, 6624, 2 * w * (n3e - 1));
   ulong* zn = z;
   ulong* zr = z + w * (n3e - 1);

   int v3m_neg;

   // -------------------------------------------------------------------------
   // "normal" evaluation point

   // evaluate 2^p * f1e(B^2) and 2^p * B * f1o(B^2)
   zn_array_pack (v1en, op1, n1e, 2, 2 * b, p, k1);
   zn_array_pack (v1on, op1 + 1, n1o, 2, 2 * b, p + b, k1);

   // compute  2^p * f1(B)    = 2^p * f1e(B^2) + 2^p * B * f1o(B^2)
   //    and   2^p * |f1(-B)| = |2^p * f1e(B^2) - 2^p * B * f1o(B^2)|
   ZNP_ASSERT_NOCARRY (mpn_add_n (v1pn, v1en, v1on, k1));
   v3m_neg = signed_mpn_sub_n (v1mn, v1en, v1on, k1);

   // evaluate f2e(B^2) and B * f2o(B^2)
   zn_array_pack (v2en, op2, n2e, 2, 2 * b, 0, k2);
   zn_array_pack (v2on, op2 + 1, n2o, 2, 2 * b, b, k2);

   // compute  f2(B)   = f2e(B^2) + B * f2o(B^2)
   //    and  |f2(-B)| = |f2e(B^2) - B * f2o(B^2)|
   ZNP_ASSERT_NOCARRY (mpn_add_n (v2pn, v2en, v2on, k2));
   v3m_neg ^= signed_mpn_sub_n (v2mn, v2en, v2on, k2);

   // compute segment starting at bit index r of
   //         h(B)   = f1(B)    * f2(B)
   //   and  |h(-B)| = |f1(-B)| * |f2(-B)|
   // v3m_neg is set if h(-B) is negative
   ZNP_mpn_mulmid (v3mn, v1mn, k1, v2mn, k2);
   ZNP_mpn_mulmid (v3pn, v1pn, k1, v2pn, k2);

   // compute segments starting at bit index r of
   //         2 * he(B^2)     = h(B) + h(-B)
   //   and   2 * B * ho(B^2) = h(B) - h(-B)
   // ie. the segments of he(B^2) and B * ho(B^2) starting at bit index r - 1.

   // If n2 is odd, the former encodes ge(x) and the latter encodes go(x).
   // Otherwise the situation is reversed.  We write the results to v3en/v3on
   // accordingly.

   // Note that when we do the addition (resp. subtraction) below, we might
   // miss a carry (resp. borrow) from the unknown previous limbs.  We arrange
   // so that the answers are either correct or one too big, by adding 1
   // appropriately.

   if (v3m_neg ^ (n2 & 1))
   {
      mpn_add_n (v3en + 2, v3pn + 2, v3mn + 2, k3 - 4);    // miss carry?
      mpn_add_1 (v3en + 2, v3en + 2, k3 - 4, 1);
      mpn_sub_n (v3on + 2, v3pn + 2, v3mn + 2, k3 - 4);    // miss borrow?
   }
   else
   {
      mpn_sub_n (v3en + 2, v3pn + 2, v3mn + 2, k3 - 4);
      mpn_add_n (v3on + 2, v3pn + 2, v3mn + 2, k3 - 4);
      mpn_add_1 (v3on + 2, v3on + 2, k3 - 4, 1);
   }

   // remove x^0 and x^(n3e - 1) coefficients of ge(x),
   // and x^0 and x^(n3o - 1) coefficients of go(x).
   subtract_ulongs (v3en + 2, k3 - 4, 0, delo, ww);
   subtract_ulongs (v3en + 2, k3 - 4, (2*n3e - 2) * b, dehi, ww);
   subtract_ulongs (v3on + 2, k3 - 4, b, dolo, ww);
   subtract_ulongs (v3on + 2, k3 - 4, (2*n3o - 1) * b, dohi, ww);

   // At this stage, the integer
   //    g[2] + g[4]*B^2 + ... + g[2*n3e - 4]*B^(2*n3e - 6)
   // appears in v3en + 2, starting at bit index 2*b, occupying
   // (2 * n3e - 2) * b bits.  The integer
   //    g[3] + g[5]*B^2 + ... + g[2*n3o - 3]*B^(2*n3o - 6)
   // appears in v3on + 2, starting at bit index 3*b, occupying
   // (2 * n3o - 2) * b bits.

   // -------------------------------------------------------------------------
   // "reciprocal" evaluation point

   // evaluate 2^p * B^(n1-1) * f1e(1/B^2) and 2^p * B^(n1-2) * f1o(1/B^2)
   zn_array_pack (v1er, op1 + 2*(n1e - 1), n1e, -2, 2 * b,
                  p + ((n1 & 1) ? 0 : b), k1);
   zn_array_pack (v1or, op1 + 1 + 2*(n1o - 1), n1o, -2, 2 * b,
                  p + ((n1 & 1) ? b : 0), k1);

   // compute  2^p * B^(n1-1) * f1(1/B) =
   //              2^p * B^(n1-1) * f1e(1/B^2) + 2^p * B^(n1-2) * f1o(1/B^2)
   //    and  |2^p * B^(n1-1) * f1(-1/B)| =
   //             |2^p * B^(n1-1) * f1e(1/B^2) - 2^p * B^(n1-2) * f1o(1/B^2)|
   ZNP_ASSERT_NOCARRY (mpn_add_n (v1pr, v1er, v1or, k1));
   v3m_neg = signed_mpn_sub_n (v1mr, v1er, v1or, k1);

   // evaluate B^(n2-1) * f2e(1/B^2) and B^(n2-2) * f2o(1/B^2)
   zn_array_pack (v2er, op2 + 2*(n2e - 1), n2e, -2, 2 * b,
                  (n2 & 1) ? 0 : b, k2);
   zn_array_pack (v2or, op2 + 1 + 2*(n2o - 1), n2o, -2, 2 * b,
                  (n2 & 1) ? b : 0, k2);

   // compute  B^(n2-1) * f2(1/B) =
   //              B^(n2-1) * f2e(1/B^2) + B^(n2-2) * f2o(1/B^2)
   //    and  |B^(n2-1) * f2(-1/B)| =
   //             |B^(n2-1) * f2e(1/B^2) - B^(n2-2) * f2o(1/B^2)|
   ZNP_ASSERT_NOCARRY (mpn_add_n (v2pr, v2er, v2or, k2));
   v3m_neg ^= signed_mpn_sub_n (v2mr, v2er, v2or, k2);

   // compute segment starting at bit index r of
   //       B^(n3-1) * h(1/B) = (B^(n1-1) * f1(1/B)) * (B^(n2-1) * f2(1/B))
   // and  |B^(n3-1) * h(-1/B)| = |B^(n1-1) * f1(-1/B)| * |B^(n2-1) * f2(-1/B)|
   // v3m_neg is set if h(-1/B) is negative
   ZNP_mpn_mulmid (v3mr, v1mr, k1, v2mr, k2);
   ZNP_mpn_mulmid (v3pr, v1pr, k1, v2pr, k2);

   // compute segments starting at bit index r of
   //        2 * B^(n3-1) * he(1/B^2) = B^(n3-1) * h(1/B) + B^(n3-1) * h(-1/B)
   //   and  2 * B^(n3-2) * ho(1/B^2) = B^(n3-1) * h(1/B) - B^(n3-1) * h(-1/B)
   // ie. the segments of B^(n3-1) * he(1/B^2) and B^(n3-2) * ho(1/B^2)
   // starting at bit index r - 1.

   // If n2 is odd, the former encodes ge(x) and the latter encodes go(x).
   // Otherwise the situation is reversed.  We write the results to v3er/v3or
   // accordingly.

   if (v3m_neg ^ (n2 & 1))
   {
      mpn_add_n (v3er + 2, v3pr + 2, v3mr + 2, k3 - 4);    // miss carry?
      mpn_add_1 (v3er + 2, v3er + 2, k3 - 4, 1);
      mpn_sub_n (v3or + 2, v3pr + 2, v3mr + 2, k3 - 4);    // miss borrow?
   }
   else
   {
      mpn_sub_n (v3er + 2, v3pr + 2, v3mr + 2, k3 - 4);
      mpn_add_n (v3or + 2, v3pr + 2, v3mr + 2, k3 - 4);
      mpn_add_1 (v3or + 2, v3or + 2, k3 - 4, 1);
   }

   // bit offset correction depending on the parity of n3
   unsigned s = (n3 & 1) ? 0 : b;

   // remove x^0 and x^(n3e - 1) coefficients of ge(x),
   // and x^0 and x^(n3o - 1) coefficients of go(x).
   // (the extreme coefficients appear in swapped positions relative to the
   // "normal" point, since the coefficients are in reciprocal order here)
   subtract_ulongs (v3er + 2, k3 - 4, s, dehi, ww);
   subtract_ulongs (v3er + 2, k3 - 4, (2*n3e - 2) * b + s, delo, ww);
   subtract_ulongs (v3or + 2, k3 - 4, b - s, dohi, ww);
   subtract_ulongs (v3or + 2, k3 - 4, (2*n3o - 2) * b + b - s, dolo, ww);

   // At this stage, the integer
   //    g[2*n3e - 4] + g[2*n3e - 6]*B^2 + ... + g[2]*B^(2*n3e - 6)
   // appears in v3er + 2, starting at bit index 2*b if n3 is odd, or 3*b
   // if n3 is even, and occupying (2 * n3e - 2) * b bits.  The integer
   //    g[2*n3o - 3] + g[2*n3o - 5]*B^2 + ... + g[3]*B^(2*n3o - 6)
   // appears in v3or + 2, starting at bit index 3*b if n3 is odd, or 2*b
   // if n3 is even, and occupying (2 * n3o - 2) * b bits.

   // -------------------------------------------------------------------------
   // combine "normal" and "reciprocal" information

   // decompose relevant portion of ge(B^2) and ge(1/B^2) into base-B^2 digits
   zn_array_unpack_SAFE (zn, v3en + 2, n3e - 1, 2 * b, 2 * b, k3 - 4);
   zn_array_unpack_SAFE (zr, v3er + 2, n3e - 1, 2 * b, 2 * b + s, k3 - 4);

   // combine ge(B^2) and ge(1/B^2) information to get even coefficients of g
   zn_array_recover_reduce (res + 2, 2, zn, zr, n3e - 2, 2 * b, redc, mod);

   // decompose relevant portion of go(B^2) and go(1/B^2) into base-B^2 digits
   zn_array_unpack_SAFE (zn, v3on + 2, n3o - 1, 2 * b, 3 * b, k3 - 4);
   zn_array_unpack_SAFE (zr, v3or + 2, n3o - 1, 2 * b, 3 * b - s, k3 - 4);

   // combine go(B^2) and go(1/B^2) information to get odd coefficients of g
   zn_array_recover_reduce (res + 3, 2, zn, zr, n3o - 2, 2 * b, redc, mod);

   ZNP_FASTFREE (z);
   ZNP_FASTFREE (limbs);
}
845
846
847 // end of file ****************************************************************
848