/*
   mpn_mulmid.c:  middle products of integers

   Copyright (C) 2007, 2008, David Harvey

   This file is part of the zn_poly library (version 0.9).

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 2 of the License, or
   (at your option) version 3 of the License.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.

*/

#include "zn_poly_internal.h"
#include <string.h>


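/*
   Computes the "simple" middle product of {op1, n1} by {op2, n2}, i.e.
   the sum of op1[i] * op2[j] * B^(i+j-(n2-1)) over all i, j with
   n2-1 <= i+j <= n1-1 (the same index region as in the SMP definitions
   in the comments further down). The result occupies n1 - n2 + 3 limbs
   at res; the top two limbs accumulate the carries out of the band.
   Carries coming in from the diagonals below the band are *not*
   included, which is what distinguishes this from a true middle product.
*/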
void
ZNP_mpn_smp_basecase (mp_limb_t* res,
                      const mp_limb_t* op1, size_t n1,
                      const mp_limb_t* op2, size_t n2)
{
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n2 >= 1);

#if GMP_NAIL_BITS == 0  &&  ULONG_BITS == GMP_NUMB_BITS

   mp_limb_t hi0, hi1, hi;
   size_t s, j;

   j = n2 - 1;
   s = n1 - j;
   op2 += j;

   hi0 = mpn_mul_1 (res, op1, s, *op2);
   hi1 = 0;

   for (op1++, op2--; j; j--, op1++, op2--)
   {
      hi = mpn_addmul_1 (res, op1, s, *op2);
      ZNP_ADD_WIDE (hi1, hi0, hi1, hi0, 0, hi);
   }

   res[s] = hi0;
   res[s + 1] = hi1;

#else
#error Not nails-safe yet
#endif
}


/*
   Let x = op1[0, 2*n-1),
       y = op2[0, n),
       z = op3[0, n).

   If y >= z, this function computes y - z and the correction term
      SMP(x, y) - SMP(x, z) - SMP(x, y - z)
   and returns 0.

   If y < z, it computes z - y and the correction term
      SMP(x, z) - SMP(x, y) - SMP(x, z - y)
   and returns 1.

   In both cases abs(y - z) is stored at res[0, n).

   The correction term is v - u*B^n, where u is stored at hi[0, 2) and
   v is stored at lo[0, 2).

   None of the output buffers are allowed to overlap either each other or
   the input buffers.
*/
#define bilinear2_sub_fixup ZNP_bilinear2_sub_fixup
int
bilinear2_sub_fixup (mp_limb_t* hi, mp_limb_t* lo, mp_limb_t* res,
                     const mp_limb_t* op1, const mp_limb_t* op2,
                     const mp_limb_t* op3, size_t n)
{
   ZNP_ASSERT (n >= 1);

#if GMP_NAIL_BITS == 0  &&  ULONG_BITS == GMP_NUMB_BITS

   int sign = 0;
   if (mpn_cmp (op2, op3, n) < 0)
   {
      // swap y and z if necessary
      const mp_limb_t* temp = op2;
      op2 = op3;
      op3 = temp;
      sign = 1;
   }
   // now can assume y >= z

   // The correction term is computed as follows. Let
   //
   //    y_0     - z_0                =  u_0     - c_0 B,
   //    y_1     - z_1     - c_0      =  u_1     - c_1 B,
   //    y_2     - z_2     - c_1      =  u_2     - c_2 B,
   //                                ...
   //    y_{n-1} - z_{n-1} - c_{n-2}  =  u_{n-1},
   //
   // i.e. where c_j is the borrow (0 or 1) from the j-th limb of the
   // subtraction y - z, and where u_j is the j-th digit of y - z. Note
   // that c_{-1} = c_{n-1} = 0. By definition we want to compute
   //
   //    \sum_{0 <= i < 2n-1, 0 <= j < n, n-1 <= i+j < 2n-1}
   //                                  (c_{j-1} - c_j B) x_i B^{i+j-(n-1)}
   //
   // After some algebra this collapses down to
   //
   //    \sum_{0 <= i < n-1} c_i (x_{n-2-i} - B^n x_{2n-2-i}).

   // First compute y - z using mpn_sub_n (fast)
   mpn_sub_n (res, op2, op3, n);

   // Now loop through and figure out where the borrows happened
   size_t i;
   mp_limb_t hi0 = 0, hi1 = 0;
   mp_limb_t lo0 = 0, lo1 = 0;

   for (i = n - 1; i; i--, op1++)
   {
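      // res[i] = op2[i] - op3[i] - c_{i-1} mod B, so the line below
      // recovers -c_{i-1}: all ones if a borrow propagated into limb i
      // of y - z, else zero. It is then used as a mask to conditionally
      // accumulate x_{n-1-i} (low part) and x_{2n-1-i} (high part) of
      // the correction term.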
      mp_limb_t borrow = res[i] - op2[i] + op3[i];
      ZNP_ADD_WIDE (lo1, lo0, lo1, lo0, 0, borrow & op1[0]);
      ZNP_ADD_WIDE (hi1, hi0, hi1, hi0, 0, borrow & op1[n]);
   }

   hi[0] = hi0;
   hi[1] = hi1;
   lo[0] = lo0;
   lo[1] = lo1;

   return sign;
#else
#error Not nails-safe yet
#endif
}


/*
   Let x = op1[0, 2*n-1),
       y = op2[0, 2*n-1),
       z = op3[0, n).

   This function computes x + y mod B^(2n-1) and the correction term
      SMP(x, z) + SMP(y, z) - SMP((x + y) mod B^(2n-1), z).

   The value x + y mod B^(2n-1) is stored at res[0, 2n-1).

   The correction term is u*B^n - v, where u is stored at hi[0, 2) and
   v is stored at lo[0, 2).

   None of the output buffers are allowed to overlap either each other or
   the input buffers.
*/
#define bilinear1_add_fixup ZNP_bilinear1_add_fixup
void
bilinear1_add_fixup (mp_limb_t* hi, mp_limb_t* lo, mp_limb_t* res,
                     const mp_limb_t* op1, const mp_limb_t* op2,
                     const mp_limb_t* op3, size_t n)
{
   ZNP_ASSERT (n >= 1);

#if GMP_NAIL_BITS == 0  &&  ULONG_BITS == GMP_NUMB_BITS

   // The correction term is computed as follows. Let
   //
   //    x_0      + y_0                  =  u_0      + c_0 B,
   //    x_1      + y_1      + c_0       =  u_1      + c_1 B,
   //    x_2      + y_2      + c_1       =  u_2      + c_2 B,
   //                                   ...
   //    x_{2n-2} + y_{2n-2} + c_{2n-3}  =  u_{2n-2} + c_{2n-2} B,
   //
   // i.e. where c_j is the carry (0 or 1) from the j-th limb of the
   // addition x + y, and u_j is the j-th digit of x + y. Note that
   // c_{-1} = 0. By definition we want to compute
   //
   //    \sum_{0 <= i < 2n-1, 0 <= j < n, n-1 <= i+j < 2n-1}
   //                                  (c_i B - c_{i-1}) z_j B^{i+j-(n-1)}
   //
   // After some algebra this collapses down to
   //
   //     -\sum_{0 <= j < n-1}    c_j z_{n-2-j}  +
   //  B^n \sum_{n-1 <= j < 2n-1} c_j z_{2n-2-j}.

   // First compute x + y using mpn_add_n (fast)
   mp_limb_t last_carry = mpn_add_n (res, op1, op2, 2*n - 1);

   // Now loop through and figure out where the carries happened
   size_t j;
   mp_limb_t fix0 = 0, fix1 = 0;
   op3 += n - 2;

   for (j = 0; j < n - 1; j++, op3--)
   {
      // carry = -1 if there was a carry in the j-th limb addition
      mp_limb_t carry = op1[j+1] + op2[j+1] - res[j+1];
      ZNP_ADD_WIDE (fix1, fix0, fix1, fix0, 0, carry & *op3);
   }

   lo[0] = fix0;
   lo[1] = fix1;

   fix0 = fix1 = 0;
   op3 += n;

   for (; j < 2*n - 2; j++, op3--)
   {
      // carry = -1 if there was a carry in the j-th limb addition
      mp_limb_t carry = op1[j+1] + op2[j+1] - res[j+1];
      ZNP_ADD_WIDE (fix1, fix0, fix1, fix0, 0, carry & *op3);
   }

   ZNP_ADD_WIDE (fix1, fix0, fix1, fix0, 0, (-last_carry) & *op3);

   hi[0] = fix0;
   hi[1] = fix1;
#else
#error Not nails-safe yet
#endif
}


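/*
   Computes SMP(op1[0, 2*n-1), op2[0, n)) for n >= 2 by karatsuba-style
   splitting (see the diagrams inside the function); the result occupies
   n + 2 limbs at res. Recursion continues while the half-size is at
   least ZNP_mpn_smp_kara_thresh, below which ZNP_mpn_smp_basecase takes
   over.
*/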
void
ZNP_mpn_smp_kara (mp_limb_t* res, const mp_limb_t* op1, const mp_limb_t* op2,
                  size_t n)
{
   ZNP_ASSERT (n >= 2);

   if (n & 1)
   {
      // If n is odd, we strip off the bottom row and last diagonal and
      // handle them separately at the end (stuff marked O in the diagram
      // below); the remainder gets handled via karatsuba (stuff marked E):

      // EEEEO....
      // .EEEEO...
      // ..EEEEO..
      // ...EEEEO.
      // ....OOOOO

      op2++;
   }

   size_t k = n / 2;

   ZNP_FASTALLOC (temp, mp_limb_t, 6642, 2 * k + 2);

   mp_limb_t hi[2], lo[2];

   // The following diagram shows the contributions from various regions
   // for k = 3:

   //  AAABBB.....
   //  .AAABBB....
   //  ..AAABBB...
   //  ...CCCDDD..
   //  ....CCCDDD.
   //  .....CCCDDD

   // ------------------------------------------------------------------------
   // Step 1: compute contribution from A + contribution from B

   // Let x = op1[0, 2*k-1)
   //     y = op1[k, 3*k-1)
   //     z = op2[k, 2*k).

   // Need to compute SMP(x, z) + SMP(y, z). To do this, we will compute
   // SMP((x + y) mod B^(2k-1), z) and a correction term.

   // First compute x + y mod B^(2k-1) and the correction term.
   bilinear1_add_fixup (hi, lo, temp, op1, op1 + k, op2 + k, k);

   // Now compute SMP(x + y mod B^(2k-1), z).
   // Store result in first half of output.
   if (k < ZNP_mpn_smp_kara_thresh)
      ZNP_mpn_smp_basecase (res, temp, 2 * k - 1, op2 + k, k);
   else
      ZNP_mpn_smp_kara (res, temp, op2 + k, k);

   // Add in the correction term.
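   // (The correction is hi*B^k - lo, i.e. u*B^n - v in the notation of
   // bilinear1_add_fixup above with n = k, hence the subtraction at
   // offset 0 and the addition at offset k.)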
   mpn_sub (res, res, k + 2, lo, 2);
   mpn_add_n (res + k, res + k, hi, 2);

   // Save the last two limbs (they're about to get overwritten)
   mp_limb_t saved[2];
   saved[0] = res[k];
   saved[1] = res[k + 1];

   // ------------------------------------------------------------------------
   // Step 2: compute contribution from C + contribution from D

   // Let x = op1[k, 3*k-1)
   //     y = op1[2*k, 4*k-1)
   //     z = op2[0, k).

   // Need to compute SMP(x, z) + SMP(y, z). To do this, we will compute
   // SMP((x + y) mod B^(2k-1), z) and a correction term.

   // First compute x + y mod B^(2k-1) and the correction term.
   bilinear1_add_fixup (hi, lo, temp, op1 + k, op1 + 2 * k, op2, k);

   // Now compute SMP(x + y mod B^(2k-1), z).
   // Store result in second half of output.
   if (k < ZNP_mpn_smp_kara_thresh)
      ZNP_mpn_smp_basecase (res + k, temp, 2 * k - 1, op2, k);
   else
      ZNP_mpn_smp_kara (res + k, temp, op2, k);

   // Add in the correction term.
   mpn_sub (res + k, res + k, k + 2, lo, 2);
   mpn_add_n (res + 2 * k, res + 2 * k, hi, 2);

   // Add back the saved limbs.
   mpn_add (res + k, res + k, k + 2, saved, 2);

   // ------------------------------------------------------------------------
   // Step 3: compute contribution from B - contribution from C

   // Let x = op1[k, 3*k-1)
   //     y = op2[k, 2*k)
   //     z = op2[0, k).

   // Need to compute SMP(x, y) - SMP(x, z). To do this, we will compute
   // SMP(x, abs(y - z)), and a correction term.

   // First compute abs(y - z) and the correction term.
   int sign = bilinear2_sub_fixup (hi, lo, temp, op1 + k, op2 + k, op2, k);

   // Now compute SMP(x, abs(y - z)).
   // Store it in second half of temp space, in two's complement (mod B^(k+2))
   if (k < ZNP_mpn_smp_kara_thresh)
      ZNP_mpn_smp_basecase (temp + k, op1 + k, 2 * k - 1, temp, k);
   else
      ZNP_mpn_smp_kara (temp + k, op1 + k, temp, k);

   // Add in the correction term.
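   // (The correction is lo - hi*B^k, i.e. v - u*B^n in the notation of
   // bilinear2_sub_fixup above with n = k. The subtraction of hi may
   // borrow past the top limb, which is why the result is kept in two's
   // complement mod B^(k+2); the borrow is reapplied in step 4.)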
   mpn_add (temp + k, temp + k, k + 2, lo, 2);
   mp_limb_t borrow = mpn_sub_n (temp + 2 * k, temp + 2 * k, hi, 2);

   // ------------------------------------------------------------------------
   // Step 4: put the pieces together

   // First half of output needs A + C; second half needs B + D. Steps 1
   // and 2 left A + B and C + D there, and step 3 gave B - C (negated if
   // sign is set), so subtract it from the first half and add it to the
   // second (or vice versa when sign is set).
   if (sign)
   {
      mpn_add (res, res, 2 * k + 2, temp + k, k + 2);
      mpn_sub_1 (res + k + 2, res + k + 2, k, borrow);
      mpn_sub (res + k, res + k, k + 2, temp + k, k + 2);
   }
   else
   {
      mpn_sub (res, res, 2 * k + 2, temp + k, k + 2);
      mpn_add_1 (res + k + 2, res + k + 2, k, borrow);
      mpn_add (res + k, res + k, k + 2, temp + k, k + 2);
   }

   // ------------------------------------------------------------------------
   // Step 5: add in correction if the length was odd

#if GMP_NAIL_BITS == 0  &&  ULONG_BITS == GMP_NUMB_BITS

   if (n & 1)
   {
      op2--;

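      // Add in the stripped-off cells (marked O in the diagram above):
      // the bottom row is op2[0] times op1[n-1 .. 2n-2], handled by a
      // single addmul below. The remaining O cells all lie on the last
      // anti-diagonal (i + j = 2n - 2), so they all land at limb
      // position n - 1; we just sum their products into two double-limb
      // accumulators and add them in afterwards.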
      mp_limb_t hi0 = mpn_addmul_1 (res, op1 + n - 1, n, *op2);
      mp_limb_t hi1 = 0, lo0 = 0, lo1 = 0;

      size_t i;
      for (i = n - 1; i; i--)
      {
         mp_limb_t y0, y1;
         ZNP_MUL_WIDE (y1, y0, op1[2 * n - i - 2], op2[i]);
         ZNP_ADD_WIDE (hi1, hi0, hi1, hi0, 0, y1);
         ZNP_ADD_WIDE (lo1, lo0, lo1, lo0, 0, y0);
      }

      res[n + 1] = hi1;
      mpn_add_1 (res + n, res + n, 2, hi0);
      mpn_add_1 (res + n, res + n, 2, lo1);
      mpn_add_1 (res + n - 1, res + n - 1, 3, lo0);
   }

   ZNP_FASTFREE (temp);

#else
#error Not nails-safe yet
#endif
}


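/*
   Computes SMP(op1[0, 2*n-1), op2[0, n)), writing n + 2 limbs to res,
   dispatching between the basecase and karatsuba routines.
*/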
void
ZNP_mpn_smp_n (mp_limb_t* res, const mp_limb_t* op1, const mp_limb_t* op2,
               size_t n)
{
   if (n < ZNP_mpn_smp_kara_thresh)
      ZNP_mpn_smp_basecase (res, op1, 2*n - 1, op2, n);
   else
      ZNP_mpn_smp_kara (res, op1, op2, n);
}


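/*
   Computes SMP({op1, n1}, {op2, n2}) for n1 >= n2 >= 1, writing
   n1 - n2 + 3 limbs to res. The middle product region is sliced into
   square blocks handled by the routines above, plus at most one narrower
   leftover block handled recursively; see the diagrams below.
*/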
void
ZNP_mpn_smp (mp_limb_t* res,
             const mp_limb_t* op1, size_t n1,
             const mp_limb_t* op2, size_t n2)
{
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n2 >= 1);

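   // n3 = number of diagonals in the middle product region, i.e. the
   // width of each block row; the output occupies n3 + 2 limbs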
   size_t n3 = n1 - n2 + 1;

   if (n3 < ZNP_mpn_smp_kara_thresh)
   {
      // region is too narrow to make karatsuba worthwhile for any portion
      ZNP_mpn_smp_basecase (res, op1, n1, op2, n2);
      return;
   }

   if (n2 > n3)
   {
      // slice region into chunks horizontally, i.e. like this:

      // AA.....
      // .AA....
      // ..BB...
      // ...BB..
      // ....CC.
      // .....CC

      // first chunk (marked A in the above diagram)
      op2 += n2 - n3;
      ZNP_mpn_smp_kara (res, op1, op2, n3);

      // remaining chunks (B, C, etc)
      ZNP_FASTALLOC (temp, mp_limb_t, 6642, n3 + 2);

      n1 -= n3;
      n2 -= n3;

      while (n2 >= n3)
      {
         op1 += n3;
         op2 -= n3;
         ZNP_mpn_smp_kara (temp, op1, op2, n3);
         mpn_add_n (res, res, temp, n3 + 2);
         n1 -= n3;
         n2 -= n3;
      }

      if (n2)
      {
         // last remaining chunk
         op1 += n3;
         op2 -= n2;
         ZNP_mpn_smp (temp, op1, n1, op2, n2);
         mpn_add_n (res, res, temp, n3 + 2);
      }

      ZNP_FASTFREE (temp);
   }
   else
   {
      mp_limb_t save[2];

      // slice region into chunks diagonally, i.e. like this:

      // AAABBBCC..
      // .AAABBBCC.
      // ..AAABBBCC

      // first chunk (marked A in the above diagram)
      ZNP_mpn_smp_n (res, op1, op2, n2);

      n1 -= n2;
      n3 -= n2;

      // remaining chunks (B, C, etc)
      while (n3 >= n2)
      {
         op1 += n2;
         res += n2;

         // save two limbs which are going to be overwritten
         save[0] = res[0];
         save[1] = res[1];

         ZNP_mpn_smp_n (res, op1, op2, n2);

         // add back saved limbs
         mpn_add (res, res, n2 + 2, save, 2);

         n1 -= n2;
         n3 -= n2;
      }

      if (n3)
      {
         // last remaining chunk
         op1 += n2;
         res += n2;

         save[0] = res[0];
         save[1] = res[1];

         ZNP_mpn_smp (res, op1, n1, op2, n2);

         mpn_add (res, res, n3 + 2, save, 2);
      }
   }
}


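/*
   Fallback for ZNP_mpn_mulmid: computes the full product {op1, n1} *
   {op2, n2} and copies out the middle, i.e. product limbs n2+1 .. n1-1
   are written to res[2 .. n1-n2]. The first two limbs of res (and
   anything above res[n1 - n2]) are left untouched. Does nothing if
   n1 < n2 + 1.
*/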
void
ZNP_mpn_mulmid_fallback (mp_limb_t* res,
                         const mp_limb_t* op1, size_t n1,
                         const mp_limb_t* op2, size_t n2)
{
   if (n1 < n2 + 1)
      return;

   ZNP_FASTALLOC (temp, mp_limb_t, 6642, n1 + n2);
   ZNP_mpn_mul (temp, op1, n1, op2, n2);
   memcpy (res + 2, temp + n2 + 1, sizeof(mp_limb_t) * (n1 - n2 - 1));
   ZNP_FASTFREE (temp);
}


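/*
   Computes the middle product of {op1, n1} and {op2, n2} into
   n1 - n2 + 3 limbs at res: it first tries the simple middle product and
   falls back to a full multiplication when the omitted low-diagonal
   carries could have corrupted the result.

   A minimal usage sketch (hypothetical caller, not part of zn_poly),
   assuming the output convention implied by ZNP_mpn_mulmid_fallback
   above, i.e. that res[2 .. n1-n2] end up equal to limbs n2+1 .. n1-1 of
   the full product and the remaining limbs of res are junk:

      mp_limb_t res[n1 - n2 + 3];
      ZNP_mpn_mulmid (res, op1, n1, op2, n2);
      // use res[2 .. n1-n2]; ignore res[0], res[1] and the top two limbs
*/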
void
ZNP_mpn_mulmid (mp_limb_t* res, const mp_limb_t* op1, size_t n1,
                const mp_limb_t* op2, size_t n2)
{
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n2 >= 1);

   if (n2 >= ZNP_mpn_mulmid_fallback_thresh)
   {
      ZNP_mpn_mulmid_fallback (res, op1, n1, op2, n2);
      return;
   }

   // try using the simple middle product
   ZNP_mpn_smp (res, op1, n1, op2, n2);

#if GMP_NAIL_BITS == 0  &&  ULONG_BITS == GMP_NUMB_BITS

   // If there's a possibility of overflow from lower diagonals, we just give
   // up and do the whole product. (Note: this should happen extremely rarely
   // on uniform random input. However, on data generated by mpn_random2, it
   // seems to happen with non-negligible probability.)
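   // (The diagonals omitted by smp can increase res[1] by at most roughly
   // n2, so a value of res[1] this close to B means a carry might have
   // propagated into res[2] and corrupted the meaningful limbs.)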
   if (res[1] >= -(mp_limb_t)(n2))
      ZNP_mpn_mulmid_fallback (res, op1, n1, op2, n2);

#else
#error Not nails-safe yet
#endif
}


// end of file ****************************************************************