/*
   mulmid_ks.c:  polynomial middle products by Kronecker substitution

   Copyright (C) 2007, 2008, David Harvey

   This file is part of the zn_poly library (version 0.9).

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 2 of the License, or
   (at your option) version 3 of the License.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.

*/

#include "zn_poly_internal.h"

/*
   In the routines below, we denote by f1(x) and f2(x) the input polynomials
   op1[0, n1) and op2[0, n2), and by h(x) their product in Z[x].

   We write h(x) = LO(x) + x^(n2-1) * g(x) + x^n1 * HI(x), where
   len(LO) = len(HI) = n2 - 1 and len(g) = n1 - n2 + 1. Our goal is to
   compute the middle segment g.

   The basic strategy is: if X is an evaluation point (i.e. X = 2^b, -2^b,
   2^(-b) or -2^(-b)) then g(X) corresponds roughly to the integer middle
   product of f1(X) and f2(X), and we will use mpn_mulmid() to compute the
   latter.

   Unfortunately there are some complications.

   First, mpn_mulmid() works in terms of whole limb counts, not bit counts,
   and moreover the first two and last two limbs of the output of mpn_mulmid()
   are always garbage. We handle this issue using zero-padding as follows.
   Suppose that we need s bits of g(X) starting at bit index r. We compute
   f2(X) as usual. Let k2 = number of limbs used to store f2(X). Instead of
   evaluating f1(X), we evaluate 2^p * f1(X), i.e. zero-pad by p bits, where

      p = (k2 + 1) * GMP_NUMB_BITS - r.

   (We will verify in each case that p >= 0.) This shifts g(X) left by p
   bits, and ensures that bit #r of g(X) starts exactly at the first bit of
   the third limb of the output of mpn_mulmid(). Let k1 = number of limbs
   used to store f1(X). To guarantee that we obtain s correct bits of g(X),
   we need to have

      (k1 - k2 - 1) * GMP_NUMB_BITS >= s,

   or equivalently (since p + r = (k2 + 1) * GMP_NUMB_BITS)

      k1 * GMP_NUMB_BITS >= p + r + s.           (*)

   We zero-pad 2^p * f1(X) on the right to ensure that (*) holds. In every
   case, it turns out that the total amount of zero-padding is O(1) bits.

   Second, in the "reciprocal" variants (KS3 and KS4) there is the problem of
   overlapping coefficients: when we compute the integer middle product, the
   low bits of g(X) are polluted by the high bits of LO(X). To deal with
   this we need to compute the low coefficient of g(X) separately, and remove
   its effect from the overlapping portion. Similarly at the other end. The
   diagonal_sum() function accomplishes this.

*/
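
/*
   Worked example of the padding arithmetic above (illustrative numbers, not
   from the original text). Take GMP_NUMB_BITS = 64 with b = 40, n1 = 8,
   n2 = 5, as in the KS1 routine below. Then

      r  = (n2 - 1) * b = 160,           s  = (n1 - n2 + 1) * b = 160,
      k2 = ceil(n2 * b / 64) = 4,        p  = (k2 + 1) * 64 - r = 160,
      k1 = ceil((p + n1 * b) / 64) = 8,

   and indeed (k1 - k2 - 1) * 64 = 192 >= s, so (*) holds with 32 bits to
   spare.
*/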


/*
   Middle product using Kronecker substitution at 2^b.
*/
void
zn_array_mulmid_KS1 (ulong* res,
                     const ulong* op1, size_t n1,
                     const ulong* op2, size_t n2,
                     int redc, const zn_mod_t mod)
{
   ZNP_ASSERT (n2 >= 1);
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n1 <= ULONG_MAX);
   ZNP_ASSERT ((mod->m & 1) || !redc);

   // length of g
   size_t n3 = n1 - n2 + 1;

   // bits in each output coefficient
   unsigned b = 2 * mod->bits + ceil_lg (n2);

   // number of ulongs required to store each output coefficient
   unsigned w = CEIL_DIV (b, ULONG_BITS);
   ZNP_ASSERT (w <= 3);

   // number of limbs needed to store f2(2^b)
   size_t k2 = CEIL_DIV (n2 * b, GMP_NUMB_BITS);

   // We need r = (n2 - 1) * b and s = (n1 - n2 + 1) * b. Note that p is
   // non-negative since k2 * GMP_NUMB_BITS >= n2 * b.
   unsigned p = GMP_NUMB_BITS * (k2 + 1) - (n2 - 1) * b;

   // For (*) to hold we need k1 * GMP_NUMB_BITS >= p + n1 * b.
   size_t k1 = CEIL_DIV (p + n1 * b, GMP_NUMB_BITS);

   // allocate space
   ZNP_FASTALLOC (limbs, mp_limb_t, 6624, 2 * k1 + 3);
   mp_limb_t* v1 = limbs;       // k1 limbs
   mp_limb_t* v2 = v1 + k1;     // k2 limbs
   mp_limb_t* v3 = v2 + k2;     // k1 - k2 + 3 limbs

   // evaluate 2^p * f1(2^b) and f2(2^b)
   zn_array_pack (v1, op1, n1, 1, b, p, 0);
   zn_array_pack (v2, op2, n2, 1, b, 0, 0);

   // compute segment of f1(2^b) * f2(2^b) starting at bit index r
   ZNP_mpn_mulmid (v3, v1, k1, v2, k2);

   // unpack coefficients of g, and reduce mod m
   ZNP_FASTALLOC (z, ulong, 6624, n3 * w);
   zn_array_unpack_SAFE (z, v3 + 2, n3, b, 0, k1 - k2 - 1);
   array_reduce (res, 1, z, n3, w, redc, mod);

   ZNP_FASTFREE (z);
   ZNP_FASTFREE (limbs);
}



/*
   Middle product using Kronecker substitution at 2^b and -2^b.
*/
void
zn_array_mulmid_KS2 (ulong* res,
                     const ulong* op1, size_t n1,
                     const ulong* op2, size_t n2,
                     int redc, const zn_mod_t mod)
{
   ZNP_ASSERT (n2 >= 1);
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n1 <= ULONG_MAX);
   ZNP_ASSERT ((mod->m & 1) || !redc);

   if (n2 == 1)
   {
      // code below needs n2 > 1, so fall back on scalar multiplication
      _zn_array_scalar_mul (res, op1, n1, op2[0], redc, mod);
      return;
   }

   // bits in each output coefficient
   unsigned bits = 2 * mod->bits + ceil_lg (n2);

   // we're evaluating at x = B and -B, where B = 2^b, and b = ceil(bits / 2)
   unsigned b = (bits + 1) / 2;

   // number of ulongs required to store each output coefficient
   unsigned w = CEIL_DIV (2 * b, ULONG_BITS);
   ZNP_ASSERT (w <= 3);

   // Write f1(x) = f1e(x^2) + x * f1o(x^2)
   //       f2(x) = f2e(x^2) + x * f2o(x^2)
   //        h(x) =  he(x^2) + x *  ho(x^2)
   //        g(x) =  ge(x^2) + x *  go(x^2)
   // "e" = even, "o" = odd

   // When evaluating f2e(B^2) and B * f2o(B^2) the bit-packing routine needs
   // room for the last chunk of 2b bits, so we need to allow room for
   // (n2 + 1) * b bits.
   size_t k2 = CEIL_DIV ((n2 + 1) * b, GMP_NUMB_BITS);

   // We need r = (n2 - 2) * b + 1 and s = (n1 - n2 + 3) * b.
   // Note that p is non-negative (since k2 * GMP_NUMB_BITS >= (n2 + 1) * b
   // >= (n2 - 2) * b - 1).
   unsigned p = GMP_NUMB_BITS * (k2 + 1) - (n2 - 2) * b - 1;

   // For (*) to hold we need k1 * GMP_NUMB_BITS >= p + (n1 + 1) * b + 1.
   // Also, to ensure that there is enough room for bit-packing (as for k2
   // above), we need k1 * GMP_NUMB_BITS >= p + (n1 + 1) * b; this condition
   // is subsumed by the first one.
   size_t k1 = CEIL_DIV (p + (n1 + 1) * b + 1, GMP_NUMB_BITS);

   size_t k3 = k1 - k2 + 3;
   ZNP_ASSERT (k3 >= 5);

   // allocate space
   ZNP_FASTALLOC (limbs, mp_limb_t, 6624, 3 * k3 + 5 * k2);
   mp_limb_t* v2_buf0 = limbs;             // k2 limbs
   mp_limb_t* v3_buf0 = v2_buf0 + k2;      // k3 limbs
   mp_limb_t* v2_buf1 = v3_buf0 + k3;      // k2 limbs
   mp_limb_t* v3_buf1 = v2_buf1 + k2;      // k3 limbs
   mp_limb_t* v2_buf2 = v3_buf1 + k3;      // k2 limbs
   mp_limb_t* v3_buf2 = v2_buf2 + k2;      // k3 limbs
   mp_limb_t* v2_buf3 = v3_buf2 + k3;      // k2 limbs
   mp_limb_t* v2_buf4 = v2_buf3 + k2;      // k2 limbs

   // arrange overlapping buffers to minimise memory use
   // "p" = plus, "m" = minus
   mp_limb_t* v1e = v2_buf0;
   mp_limb_t* v1o = v2_buf2;
   mp_limb_t* v1p = v2_buf1;
   mp_limb_t* v1m = v2_buf0;
   mp_limb_t* v2e = v2_buf2;
   mp_limb_t* v2o = v2_buf4;
   mp_limb_t* v2p = v2_buf3;
   mp_limb_t* v2m = v2_buf2;
   mp_limb_t* v3m = v3_buf2;
   mp_limb_t* v3p = v3_buf0;
   mp_limb_t* v3e = v3_buf1;
   mp_limb_t* v3o = v3_buf1;

   // length of g
   size_t n3 = n1 - n2 + 1;

   ZNP_FASTALLOC (z, ulong, 6624, w * ((n3 + 1) / 2));

   // evaluate 2^p * f1e(B^2) and 2^p * B * f1o(B^2)
   zn_array_pack (v1e, op1, (n1 + 1) / 2, 2, 2 * b, p, k1);
   zn_array_pack (v1o, op1 + 1, n1 / 2, 2, 2 * b, p + b, k1);

   // compute   2^p * f1(B)   =  2^p * (f1e(B^2) + B * f1o(B^2))
   //     and  |2^p * f1(-B)| = |2^p * (f1e(B^2) - B * f1o(B^2))|
   // v3m_neg is set if f1(-B) is negative
   ZNP_ASSERT_NOCARRY (mpn_add_n (v1p, v1e, v1o, k1));
   int v3m_neg = signed_mpn_sub_n (v1m, v1e, v1o, k1);

   // evaluate f2e(B^2) and B * f2o(B^2)
   zn_array_pack (v2e, op2, (n2 + 1) / 2, 2, 2 * b, 0, k2);
   zn_array_pack (v2o, op2 + 1, n2 / 2, 2, 2 * b, b, k2);

   // compute    f2(B)   =   f2e(B^2) + B * f2o(B^2)
   //     and   |f2(-B)| =  |f2e(B^2) - B * f2o(B^2)|
   // v3m_neg is set if f1(-B) and f2(-B) have opposite signs
   ZNP_ASSERT_NOCARRY (mpn_add_n (v2p, v2e, v2o, k2));
   v3m_neg ^= signed_mpn_sub_n (v2m, v2e, v2o, k2);

   // compute segment starting at bit index r of
   //           h(B)   =  f1(B)   *  f2(B)
   //    and   |h(-B)| = |f1(-B)| * |f2(-B)|
   // v3m_neg is set if h(-B) is negative
   ZNP_mpn_mulmid (v3m, v1m, k1, v2m, k2);
   ZNP_mpn_mulmid (v3p, v1p, k1, v2p, k2);

   // compute segment starting at bit index r of
   //         2     * he(B^2) = h(B) + h(-B)       (if n2 is odd)
   //    or   2 * B * ho(B^2) = h(B) - h(-B)       (if n2 is even)
   // i.e. the segment of he(B^2) or B * ho(B^2) starting at bit index
   // r - 1 = (n2 - 2) * b. This encodes the coefficients of ge(x).

   // Note that when we do the addition (resp. subtraction) below, we might
   // miss a carry (resp. borrow) from the unknown previous limbs. We arrange
   // so that the answers are either correct or one too big, by adding 1
   // appropriately.

   if (v3m_neg ^ (n2 & 1))
   {
      mpn_add_n (v3e, v3p + 2, v3m + 2, k3 - 4);    // miss carry?
      mpn_add_1 (v3e, v3e, k3 - 4, 1);
   }
   else
      mpn_sub_n (v3e, v3p + 2, v3m + 2, k3 - 4);    // miss borrow?

   // Now we extract ge(x). The first coefficient we want is the coefficient
   // of x^(n2 - 1) in h(x); this starts at bit index b in v3e. We want
   // ceil(n3 / 2) coefficients altogether, with 2b bits each. This accounts
   // for the definition of s.

   // Claim: if we committed a "one-too-big" error above, this does not affect
   // the coefficients we extract. Proof: the first b bits of v3e are the top
   // half of the coefficient of x^(n2 - 2) in h(x). The base-B digit in those
   // b bits has value at most B - 2. Therefore adding 1 to it will never
   // overflow those b bits.

   zn_array_unpack_SAFE (z, v3e, (n3 + 1) / 2, 2 * b, b, k3 - 4);
   array_reduce (res, 2, z, (n3 + 1) / 2, w, redc, mod);

   // Now repeat all the above for go(x).

   if (v3m_neg ^ (n2 & 1))
      mpn_sub_n (v3o, v3p + 2, v3m + 2, k3 - 4);
   else
   {
      mpn_add_n (v3o, v3p + 2, v3m + 2, k3 - 4);
      mpn_add_1 (v3o, v3o, k3 - 4, 1);
   }

   zn_array_unpack_SAFE (z, v3o, n3 / 2, 2 * b, 2 * b, k3 - 4);
   array_reduce (res + 1, 2, z, n3 / 2, w, redc, mod);

   ZNP_FASTFREE (z);
   ZNP_FASTFREE (limbs);
}



/*
   Computes the sum

      op1[0] * op2[n-1] + ... + op1[n-1] * op2[0]

   as an *integer*. The result is assumed to fit into w ulongs (where
   1 <= w <= 3), and is written to res[0, w). The return value is the result
   reduced modulo mod->m (using redc if requested).
*/
#define diagonal_sum \
    ZNP_diagonal_sum
ulong
diagonal_sum (ulong* res, const ulong* op1, const ulong* op2,
              size_t n, unsigned w, int redc, const zn_mod_t mod)
{
   ZNP_ASSERT (n >= 1);
   ZNP_ASSERT (w >= 1);
   ZNP_ASSERT (w <= 3);

   size_t i;

   if (w == 1)
   {
      ulong sum = op1[0] * op2[n - 1];

      for (i = 1; i < n; i++)
         sum += op1[i] * op2[n - 1 - i];

      res[0] = sum;
      return redc ? zn_mod_reduce_redc (sum, mod) : zn_mod_reduce (sum, mod);
   }
   else if (w == 2)
   {
      ulong lo, hi, sum0, sum1;

      ZNP_MUL_WIDE (sum1, sum0, op1[0], op2[n - 1]);

      for (i = 1; i < n; i++)
      {
         ZNP_MUL_WIDE (hi, lo, op1[i], op2[n - 1 - i]);
         ZNP_ADD_WIDE (sum1, sum0, sum1, sum0, hi, lo);
      }

      res[0] = sum0;
      res[1] = sum1;
      return redc ? zn_mod_reduce2_redc (sum1, sum0, mod)
                  : zn_mod_reduce2 (sum1, sum0, mod);
   }
   else    // w == 3
   {
      ulong lo, hi, sum0, sum1, sum2 = 0;

      ZNP_MUL_WIDE (sum1, sum0, op1[0], op2[n - 1]);

      for (i = 1; i < n; i++)
      {
         ZNP_MUL_WIDE (hi, lo, op1[i], op2[n - 1 - i]);
         ZNP_ADD_WIDE (sum1, sum0, sum1, sum0, hi, lo);
         // carry into third limb: the two-limb addition overflowed iff
         // the result is less than (hi, lo) as a double-limb comparison
         sum2 += (sum1 < hi || (sum1 == hi && sum0 < lo));
      }

      res[0] = sum0;
      res[1] = sum1;
      res[2] = sum2;
      return redc ? zn_mod_reduce3_redc (sum2, sum1, sum0, mod)
                  : zn_mod_reduce3 (sum2, sum1, sum0, mod);
   }
}
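

/*
   A minimal sketch (not part of the library) of what the w == 2 branch of
   diagonal_sum() computes, written with GCC's unsigned __int128 in place of
   the ZNP_MUL_WIDE / ZNP_ADD_WIDE macros. It assumes a 64-bit ulong and
   that the sum fits in two limbs, which is exactly the w == 2 precondition.
*/
#if 0
static void
diagonal_sum_u128_sketch (ulong* res, const ulong* op1, const ulong* op2,
                          size_t n)
{
   unsigned __int128 sum = 0;
   size_t i;

   for (i = 0; i < n; i++)
      sum += (unsigned __int128) op1[i] * op2[n - 1 - i];

   res[0] = (ulong) sum;                    // low limb
   res[1] = (ulong) (sum >> ULONG_BITS);    // high limb
}
#endif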


/*
   Subtracts 2^i * x from res[0, n), in place.
   x is an array of w ulongs, where 1 <= w <= 3.
   i may be any non-negative integer.
*/
#define subtract_ulongs \
    ZNP_subtract_ulongs
void
subtract_ulongs (mp_limb_t* res, size_t n, size_t i, ulong* x, unsigned w)
{
   ZNP_ASSERT (w >= 1);
   ZNP_ASSERT (w <= 3);

#if GMP_NAIL_BITS == 0  &&  ULONG_BITS == GMP_NUMB_BITS
   size_t k = i / GMP_NUMB_BITS;

   if (k >= n)
      return;

   unsigned j = i % GMP_NUMB_BITS;
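   // e.g. (illustrative) with GMP_NUMB_BITS == 64 and i == 70, we get
   // k == 1 and j == 6: x is shifted left by 6 bits and then subtracted
   // starting at limb index 1 of res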

   if (j == 0)
      mpn_sub (res + k, res + k, n - k, (mp_srcptr) x, ZNP_MIN (n - k, w));
   else
   {
      mp_limb_t y[4];
      y[w] = mpn_lshift (y, (mp_srcptr) x, w, j);
      mpn_sub (res + k, res + k, n - k, y, ZNP_MIN (n - k, w + 1));
   }
#else
#error Not nails-safe yet
#endif
}


/*
   Middle product using Kronecker substitution at 2^b and 2^(-b).

   Note: this routine does not appear to be competitive in practice with the
   other KS routines. It's here just for fun.
*/
void
zn_array_mulmid_KS3 (ulong* res,
                     const ulong* op1, size_t n1,
                     const ulong* op2, size_t n2,
                     int redc, const zn_mod_t mod)
{
   ZNP_ASSERT (n2 >= 1);
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n1 <= ULONG_MAX);
   ZNP_ASSERT ((mod->m & 1) || !redc);

   // length of g
   size_t n3 = n1 - n2 + 1;

   // bits in each output coefficient
   unsigned bits = 2 * mod->bits + ceil_lg (n2);

   // we're evaluating at x = B and 1/B, where B = 2^b, and b = ceil(bits / 2)
   unsigned b = (bits + 1) / 2;

   // number of ulongs required to store each base-B digit
   unsigned w = CEIL_DIV (b, ULONG_BITS);
   ZNP_ASSERT (w <= 2);

   // number of ulongs needed to store each output coefficient
   unsigned ww = CEIL_DIV (2 * b, ULONG_BITS);
   ZNP_ASSERT (ww <= 3);

   // directly compute coefficient of x^0 in g(x)
   ulong dlo[3];
   res[0] = diagonal_sum (dlo, op1, op2, n2, ww, redc, mod);
   if (n3 == 1)
      return;      // only need one coefficient of output

   // directly compute coefficient of x^(n3-1) in g(x)
   ulong dhi[3];
   res[n3 - 1] = diagonal_sum (dhi, op1 + n3 - 1, op2, n2, ww, redc, mod);
   if (n3 == 2)
      return;      // only need two coefficients of output

   // limbs needed to store f2(B) and B^(n2-1) * f2(1/B)
   size_t k2 = CEIL_DIV (n2 * b, GMP_NUMB_BITS);

   // we need r = (n2 - 1) * b and s = (n1 - n2 + 1) * b, thus p is:
   unsigned p = GMP_NUMB_BITS * (k2 + 1) - (n2 - 1) * b;

   // for (*) we need k1 * GMP_NUMB_BITS >= p + n1 * b
   size_t k1 = CEIL_DIV (p + n1 * b, GMP_NUMB_BITS);

   size_t k3 = k1 - k2 + 3;
   ZNP_ASSERT (k3 >= 5);

   // allocate space
   ZNP_FASTALLOC (limbs, mp_limb_t, 6624, 2 * k1 + 3);
   mp_limb_t* v1 = limbs;        // k1 limbs
   mp_limb_t* v2 = v1 + k1;      // k2 limbs
   mp_limb_t* v3 = v2 + k2;      // k1 - k2 + 3 limbs

   ZNP_FASTALLOC (z, ulong, 6624, 2 * w * n3);
   // "n" = normal order, "r" = reciprocal order
   ulong* zn = z;
   ulong* zr = z + w * n3;

   // -------------------------------------------------------------------------
   //     "normal" evaluation point

   // evaluate 2^p * f1(B) and f2(B)
   zn_array_pack (v1, op1, n1, 1, b, p, k1);
   zn_array_pack (v2, op2, n2, 1, b, 0, k2);

   // compute segment starting at bit index r of h(B) = f1(B) * f2(B)
   ZNP_mpn_mulmid (v3, v1, k1, v2, k2);

   // remove x^0 and x^(n3 - 1) coefficients of g(x)
   subtract_ulongs (v3 + 2, k3 - 4, 0, dlo, ww);
   subtract_ulongs (v3 + 2, k3 - 4, (n3 - 1) * b, dhi, ww);

   // decompose relevant portion of h(B) into base-B digits
   zn_array_unpack_SAFE (zn, v3 + 2, n3 - 1, b, b, k3 - 4);

   // At this stage zn contains (n3 - 1) base-B digits, representing the
   // integer g[1] + g[2]*B + ... + g[n3-2]*B^(n3-3)

   // -------------------------------------------------------------------------
   //     "reciprocal" evaluation point

   // evaluate 2^p * B^(n1-1) * f1(1/B) and B^(n2-1) * f2(1/B)
   zn_array_pack (v1, op1 + n1 - 1, n1, -1, b, p, k1);
   zn_array_pack (v2, op2 + n2 - 1, n2, -1, b, 0, k2);

   // compute segment starting at bit index r of B^(n1+n2-2) * h(1/B) =
   // (B^(n1-1) * f1(1/B)) * (B^(n2-1) * f2(1/B))
   ZNP_mpn_mulmid (v3, v1, k1, v2, k2);

   // remove x^0 and x^(n3 - 1) coefficients of g(x)
   subtract_ulongs (v3 + 2, k3 - 4, 0, dhi, ww);
   subtract_ulongs (v3 + 2, k3 - 4, (n3 - 1) * b, dlo, ww);

   // decompose relevant portion of B^(n1+n2-2) * h(1/B) into base-B digits
   zn_array_unpack_SAFE (zr, v3 + 2, n3 - 1, b, b, k3 - 4);

   // At this stage zr contains (n3 - 1) base-B digits, representing the
   // integer g[n3-2] + g[n3-3]*B + ... + g[1]*B^(n3-3)
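
   // Each coefficient of g is 2*b bits wide but the digits above are only
   // b bits wide, so consecutive digits overlap pairs of coefficients and
   // neither zn nor zr alone determines them; zn_array_recover_reduce()
   // below uses the two opposite orderings to disentangle the overlaps
   // (mirroring the reconstruction step of ordinary KS3 multiplication).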

   // -------------------------------------------------------------------------
   //     combine "normal" and "reciprocal" information

   zn_array_recover_reduce (res + 1, 1, zn, zr, n3 - 2, b, redc, mod);

   ZNP_FASTFREE (z);
   ZNP_FASTFREE (limbs);
}



/*
   Middle product using Kronecker substitution at 2^b, -2^b, 2^(-b)
   and -2^(-b).
*/
void
zn_array_mulmid_KS4 (ulong* res,
                     const ulong* op1, size_t n1,
                     const ulong* op2, size_t n2,
                     int redc, const zn_mod_t mod)
{
   ZNP_ASSERT (n2 >= 1);
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n1 <= ULONG_MAX);
   ZNP_ASSERT ((mod->m & 1) || !redc);

   if (n2 == 1)
   {
      // code below needs n2 > 1, so fall back on scalar multiplication
      _zn_array_scalar_mul (res, op1, n1, op2[0], redc, mod);
      return;
   }

   // bits in each output coefficient
   unsigned bits = 2 * mod->bits + ceil_lg (n2);

   // we're evaluating at x = B, -B, 1/B, -1/B,
   // where B = 2^b, and b = ceil(bits / 4)
   unsigned b = (bits + 3) / 4;

   // number of ulongs required to store each base-B^2 digit
   unsigned w = CEIL_DIV (2 * b, ULONG_BITS);
   ZNP_ASSERT (w <= 2);

   // number of ulongs needed to store each output coefficient
   unsigned ww = CEIL_DIV (4 * b, ULONG_BITS);
   ZNP_ASSERT (ww <= 3);

   // mask = 2^c - 1, where c = number of bits used in high ulong of each
   // base-B^2 digit
   ulong mask;
   if (w == 1)
      mask = ((2 * b) < ULONG_BITS) ? ((1UL << (2 * b)) - 1) : (-1UL);
   else   // w == 2
      mask = (1UL << ((2 * b) - ULONG_BITS)) - 1;
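
   // (illustrative: with ULONG_BITS == 64 and b == 20, w == 1 and
   // mask == 2^40 - 1; with b == 40, w == 2 and mask == 2^16 - 1)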

   // Write f1(x) = f1e(x^2) + x * f1o(x^2)
   //       f2(x) = f2e(x^2) + x * f2o(x^2)
   //        h(x) =  he(x^2) + x *  ho(x^2)
   //        g(x) =  ge(x^2) + x *  go(x^2)
   // "e" = even, "o" = odd

   size_t n1o = n1 / 2;
   size_t n1e = n1 - n1o;

   size_t n2o = n2 / 2;
   size_t n2e = n2 - n2o;

   size_t n3 = n1 - n2 + 1;   // length of g
   size_t n3o = n3 / 2;
   size_t n3e = n3 - n3o;

   // directly compute coefficient of x^0 in ge(x)
   ulong delo[3];
   res[0] = diagonal_sum (delo, op1, op2, n2, ww, redc, mod);
   if (n3 == 1)
      return;      // only need one coefficient of output

   // directly compute coefficient of x^0 in go(x)
   ulong dolo[3];
   res[1] = diagonal_sum (dolo, op1 + 1, op2, n2, ww, redc, mod);
   if (n3 == 2)
      return;      // only need two coefficients of output

   // directly compute coefficient of x^(n3e - 1) in ge(x)
   ulong dehi[3];
   res[2*n3e - 2] = diagonal_sum (dehi, op1 + 2*n3e - 2, op2,
                                  n2, ww, redc, mod);
   if (n3 == 3)
      return;      // only need three coefficients of output

   // directly compute coefficient of x^(n3o - 1) in go(x)
   ulong dohi[3];
   res[2*n3o - 1] = diagonal_sum (dohi, op1 + 2*n3o - 1, op2,
                                  n2, ww, redc, mod);
   if (n3 == 4)
      return;      // only need four coefficients of output
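
   // These four boundary coefficients are exactly the ones affected by the
   // overlap problem described at the top of the file: in the middle
   // products below they pollute the adjacent segments, so their exact
   // values are subtracted via subtract_ulongs() before unpacking.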

   // In f2(B), the leading coefficient starts at bit position b * (n2 - 1)
   // and has length 2*b, and the coefficients overlap so we need an extra bit
   // for the carry: this gives (n2 + 1) * b + 1 bits.
   size_t k2 = CEIL_DIV ((n2 + 1) * b + 1, GMP_NUMB_BITS);

   // We need r = (n2 - 1) * b + 1 and s = (n3 + 1) * b.
   // Note that p is non-negative (since k2 * GMP_NUMB_BITS >= (n2 + 1) * b
   // >= (n2 - 1) * b - 1).
   unsigned p = GMP_NUMB_BITS * (k2 + 1) - (n2 - 1) * b - 1;

   // For (*) we need k1 * GMP_NUMB_BITS >= p + (n1 + 1) * b + 1.
   size_t k1 = CEIL_DIV (p + (n1 + 1) * b + 1, GMP_NUMB_BITS);

   size_t k3 = k1 - k2 + 3;
   ZNP_ASSERT (k3 >= 5);

   // allocate space
   ZNP_FASTALLOC (limbs, mp_limb_t, 6624, 5 * (k2 + k3));
   mp_limb_t* v2_buf0 = limbs;             // k2 limbs
   mp_limb_t* v3_buf0 = v2_buf0 + k2;      // k3 limbs
   mp_limb_t* v2_buf1 = v3_buf0 + k3;      // k2 limbs
   mp_limb_t* v3_buf1 = v2_buf1 + k2;      // k3 limbs
   mp_limb_t* v2_buf2 = v3_buf1 + k3;      // k2 limbs
   mp_limb_t* v3_buf2 = v2_buf2 + k2;      // k3 limbs
   mp_limb_t* v2_buf3 = v3_buf2 + k3;      // k2 limbs
   mp_limb_t* v3_buf3 = v2_buf3 + k2;      // k3 limbs
   mp_limb_t* v2_buf4 = v3_buf3 + k3;      // k2 limbs
   mp_limb_t* v3_buf4 = v2_buf4 + k2;      // k3 limbs

   // arrange overlapping buffers to minimise memory use
   // "p" = plus, "m" = minus
   // "n" = normal order, "r" = reciprocal order
   mp_limb_t* v1en = v2_buf1;
   mp_limb_t* v1on = v2_buf2;
   mp_limb_t* v1pn = v2_buf0;
   mp_limb_t* v1mn = v2_buf1;
   mp_limb_t* v2en = v2_buf3;
   mp_limb_t* v2on = v2_buf4;
   mp_limb_t* v2pn = v2_buf2;
   mp_limb_t* v2mn = v2_buf3;
   mp_limb_t* v3mn = v3_buf2;
   mp_limb_t* v3pn = v3_buf3;
   mp_limb_t* v3en = v3_buf4;
   mp_limb_t* v3on = v3_buf3;

   mp_limb_t* v1er = v2_buf1;
   mp_limb_t* v1or = v2_buf2;
   mp_limb_t* v1pr = v2_buf0;
   mp_limb_t* v1mr = v2_buf1;
   mp_limb_t* v2er = v2_buf3;
   mp_limb_t* v2or = v2_buf4;
   mp_limb_t* v2pr = v2_buf2;
   mp_limb_t* v2mr = v2_buf3;
   mp_limb_t* v3mr = v3_buf2;
   mp_limb_t* v3pr = v3_buf1;
   mp_limb_t* v3er = v3_buf0;
   mp_limb_t* v3or = v3_buf1;

   ZNP_FASTALLOC (z, ulong, 6624, 2 * w * (n3e - 1));
   ulong* zn = z;
   ulong* zr = z + w * (n3e - 1);

   int v3m_neg;

   // -------------------------------------------------------------------------
   //     "normal" evaluation point

   // evaluate 2^p * f1e(B^2) and 2^p * B * f1o(B^2)
   zn_array_pack (v1en, op1, n1e, 2, 2 * b, p, k1);
   zn_array_pack (v1on, op1 + 1, n1o, 2, 2 * b, p + b, k1);

   // compute 2^p *   f1(B)  =  2^p * f1e(B^2) + 2^p * B * f1o(B^2)
   //    and  2^p * |f1(-B)| = |2^p * f1e(B^2) - 2^p * B * f1o(B^2)|
   ZNP_ASSERT_NOCARRY (mpn_add_n (v1pn, v1en, v1on, k1));
   v3m_neg = signed_mpn_sub_n (v1mn, v1en, v1on, k1);

   // evaluate f2e(B^2) and B * f2o(B^2)
   zn_array_pack (v2en, op2, n2e, 2, 2 * b, 0, k2);
   zn_array_pack (v2on, op2 + 1, n2o, 2, 2 * b, b, k2);

   // compute   f2(B)  =  f2e(B^2) + B * f2o(B^2)
   //    and  |f2(-B)| = |f2e(B^2) - B * f2o(B^2)|
   ZNP_ASSERT_NOCARRY (mpn_add_n (v2pn, v2en, v2on, k2));
   v3m_neg ^= signed_mpn_sub_n (v2mn, v2en, v2on, k2);

   // compute segment starting at bit index r of
   //            h(B)  =   f1(B)  *  f2(B)
   //    and   |h(-B)| = |f1(-B)| * |f2(-B)|
   // v3m_neg is set if h(-B) is negative
   ZNP_mpn_mulmid (v3mn, v1mn, k1, v2mn, k2);
   ZNP_mpn_mulmid (v3pn, v1pn, k1, v2pn, k2);

   // compute segments starting at bit index r of
   //         2     * he(B^2) = h(B) + h(-B)
   //    and  2 * B * ho(B^2) = h(B) - h(-B)
   // i.e. the segments of he(B^2) and B * ho(B^2) starting at bit index r - 1.

   // If n2 is odd, the former encodes ge(x) and the latter encodes go(x).
   // Otherwise the situation is reversed. We write the results to v3en/v3on
   // accordingly.

   // Note that when we do the addition (resp. subtraction) below, we might
   // miss a carry (resp. borrow) from the unknown previous limbs. We arrange
   // so that the answers are either correct or one too big, by adding 1
   // appropriately.

   if (v3m_neg ^ (n2 & 1))
   {
      mpn_add_n (v3en + 2, v3pn + 2, v3mn + 2, k3 - 4);    // miss carry?
      mpn_add_1 (v3en + 2, v3en + 2, k3 - 4, 1);
      mpn_sub_n (v3on + 2, v3pn + 2, v3mn + 2, k3 - 4);    // miss borrow?
   }
   else
   {
      mpn_sub_n (v3en + 2, v3pn + 2, v3mn + 2, k3 - 4);
      mpn_add_n (v3on + 2, v3pn + 2, v3mn + 2, k3 - 4);
      mpn_add_1 (v3on + 2, v3on + 2, k3 - 4, 1);
   }

   // remove x^0 and x^(n3e - 1) coefficients of ge(x),
   //    and x^0 and x^(n3o - 1) coefficients of go(x).
   subtract_ulongs (v3en + 2, k3 - 4, 0, delo, ww);
   subtract_ulongs (v3en + 2, k3 - 4, (2*n3e - 2) * b, dehi, ww);
   subtract_ulongs (v3on + 2, k3 - 4, b, dolo, ww);
   subtract_ulongs (v3on + 2, k3 - 4, (2*n3o - 1) * b, dohi, ww);

   // At this stage, the integer
   //   g[2] + g[4]*B^2 + ... + g[2*n3e - 4]*B^(2*n3e - 6)
   // appears in v3en + 2, starting at bit index 2*b, occupying
   // (2 * n3e - 2) * b bits. The integer
   //   g[3] + g[5]*B^2 + ... + g[2*n3o - 3]*B^(2*n3o - 6)
   // appears in v3on + 2, starting at bit index 3*b, occupying
   // (2 * n3o - 2) * b bits.

   // -------------------------------------------------------------------------
   //     "reciprocal" evaluation point

   // evaluate 2^p * B^(n1-1) * f1e(1/B^2) and 2^p * B^(n1-2) * f1o(1/B^2)
   zn_array_pack (v1er, op1 + 2*(n1e - 1), n1e, -2, 2 * b,
                  p + ((n1 & 1) ? 0 : b), k1);
   zn_array_pack (v1or, op1 + 1 + 2*(n1o - 1), n1o, -2, 2 * b,
                  p + ((n1 & 1) ? b : 0), k1);

   // compute 2^p * B^(n1-1) * f1(1/B) =
   //              2^p * B^(n1-1) * f1e(1/B^2) + 2^p * B^(n1-2) * f1o(1/B^2)
   //   and  |2^p * B^(n1-1) * f1(-1/B)| =
   //             |2^p * B^(n1-1) * f1e(1/B^2) - 2^p * B^(n1-2) * f1o(1/B^2)|
   ZNP_ASSERT_NOCARRY (mpn_add_n (v1pr, v1er, v1or, k1));
   v3m_neg = signed_mpn_sub_n (v1mr, v1er, v1or, k1);

   // evaluate B^(n2-1) * f2e(1/B^2) and B^(n2-2) * f2o(1/B^2)
   zn_array_pack (v2er, op2 + 2*(n2e - 1), n2e, -2, 2 * b,
                  (n2 & 1) ? 0 : b, k2);
   zn_array_pack (v2or, op2 + 1 + 2*(n2o - 1), n2o, -2, 2 * b,
                  (n2 & 1) ? b : 0, k2);

   // compute B^(n2-1) * f2(1/B) =
   //                B^(n2-1) * f2e(1/B^2) + B^(n2-2) * f2o(1/B^2)
   //   and  |B^(n2-1) * f2(-1/B)| =
   //               |B^(n2-1) * f2e(1/B^2) - B^(n2-2) * f2o(1/B^2)|
   ZNP_ASSERT_NOCARRY (mpn_add_n (v2pr, v2er, v2or, k2));
   v3m_neg ^= signed_mpn_sub_n (v2mr, v2er, v2or, k2);

   // compute segment starting at bit index r of
   //       B^(n1+n2-2) * h(1/B) =
   //                    (B^(n1-1) * f1(1/B)) * (B^(n2-1) * f2(1/B))
   //  and |B^(n1+n2-2) * h(-1/B)| =
   //                    |B^(n1-1) * f1(-1/B)| * |B^(n2-1) * f2(-1/B)|
   // v3m_neg is set if h(-1/B) is negative
   ZNP_mpn_mulmid (v3mr, v1mr, k1, v2mr, k2);
   ZNP_mpn_mulmid (v3pr, v1pr, k1, v2pr, k2);

   // compute segments starting at bit index r of
   //    2 * B^(n1+n2-2) * he(1/B^2) = B^(n1+n2-2) * (h(1/B) + h(-1/B))
   //  and 2 * B^(n1+n2-3) * ho(1/B^2) = B^(n1+n2-2) * (h(1/B) - h(-1/B))
   // i.e. the segments of B^(n1+n2-2) * he(1/B^2) and B^(n1+n2-3) * ho(1/B^2)
   // starting at bit index r - 1.

   // If n2 is odd, the former encodes ge(x) and the latter encodes go(x).
   // Otherwise the situation is reversed. We write the results to v3er/v3or
   // accordingly.

   if (v3m_neg ^ (n2 & 1))
   {
      mpn_add_n (v3er + 2, v3pr + 2, v3mr + 2, k3 - 4);    // miss carry?
      mpn_add_1 (v3er + 2, v3er + 2, k3 - 4, 1);
      mpn_sub_n (v3or + 2, v3pr + 2, v3mr + 2, k3 - 4);    // miss borrow?
   }
   else
   {
      mpn_sub_n (v3er + 2, v3pr + 2, v3mr + 2, k3 - 4);
      mpn_add_n (v3or + 2, v3pr + 2, v3mr + 2, k3 - 4);
      mpn_add_1 (v3or + 2, v3or + 2, k3 - 4, 1);
   }

   unsigned s = (n3 & 1) ? 0 : b;

   // remove x^0 and x^(n3e - 1) coefficients of ge(x),
   //    and x^0 and x^(n3o - 1) coefficients of go(x).
   subtract_ulongs (v3er + 2, k3 - 4, s, dehi, ww);
   subtract_ulongs (v3er + 2, k3 - 4, (2*n3e - 2) * b + s, delo, ww);
   subtract_ulongs (v3or + 2, k3 - 4, b - s, dohi, ww);
   subtract_ulongs (v3or + 2, k3 - 4, (2*n3o - 2) * b + b - s, dolo, ww);

   // At this stage, the integer
   //   g[2*n3e - 4] + g[2*n3e - 6]*B^2 + ... + g[2]*B^(2*n3e - 6)
   // appears in v3er + 2, starting at bit index 2*b if n3 is odd, or 3*b
   // if n3 is even, and occupying (2 * n3e - 2) * b bits. The integer
   //   g[2*n3o - 3] + g[2*n3o - 5]*B^2 + ... + g[3]*B^(2*n3o - 6)
   // appears in v3or + 2, starting at bit index 3*b if n3 is odd, or 2*b
   // if n3 is even, and occupying (2 * n3o - 2) * b bits.

   // -------------------------------------------------------------------------
   //     combine "normal" and "reciprocal" information

   // decompose relevant portion of ge(B^2) and ge(1/B^2) into base-B^2 digits
   zn_array_unpack_SAFE (zn, v3en + 2, n3e - 1, 2 * b, 2 * b, k3 - 4);
   zn_array_unpack_SAFE (zr, v3er + 2, n3e - 1, 2 * b, 2 * b + s, k3 - 4);

   // combine ge(B^2) and ge(1/B^2) information to get even coefficients of g
   zn_array_recover_reduce (res + 2, 2, zn, zr, n3e - 2, 2 * b, redc, mod);

   // decompose relevant portion of go(B^2) and go(1/B^2) into base-B^2 digits
   zn_array_unpack_SAFE (zn, v3on + 2, n3o - 1, 2 * b, 3 * b, k3 - 4);
   zn_array_unpack_SAFE (zr, v3or + 2, n3o - 1, 2 * b, 3 * b - s, k3 - 4);

   // combine go(B^2) and go(1/B^2) information to get odd coefficients of g
   zn_array_recover_reduce (res + 3, 2, zn, zr, n3o - 2, 2 * b, redc, mod);

   ZNP_FASTFREE (z);
   ZNP_FASTFREE (limbs);
}

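/*
   Usage sketch (illustrative, not part of the library). Using the public
   zn_mod interface, the middle product of a length-7 array by a length-3
   array yields the n1 - n2 + 1 = 5 coefficients

      res[i] = op1[i] * op2[2] + op1[i+1] * op2[1] + op1[i+2] * op2[0]  mod m

   for 0 <= i < 5.
*/
#if 0
void
example (void)
{
   zn_mod_t mod;
   zn_mod_init (mod, 123456789UL);

   ulong op1[7] = {1, 2, 3, 4, 5, 6, 7};
   ulong op2[3] = {1, 1, 1};
   ulong res[5];

   zn_array_mulmid_KS1 (res, op1, 7, op2, 3, 0, mod);
   // res == {6, 9, 12, 15, 18} for this input

   zn_mod_clear (mod);
}
#endif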

// end of file ****************************************************************