1 /*
2 mpn_mulmid.c: middle products of integers
3
4 Copyright (C) 2007, 2008, David Harvey
5
6 This file is part of the zn_poly library (version 0.9).
7
8 This program is free software: you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation, either version 2 of the License, or
11 (at your option) version 3 of the License.
12
13 This program is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
17
18 You should have received a copy of the GNU General Public License
19 along with this program. If not, see <http://www.gnu.org/licenses/>.
20
21 */
22
23 #include "zn_poly_internal.h"
24 #include <string.h>
25
26
/*
   Computes the "simple" middle product (SMP) of op1[0, n1) and op2[0, n2)
   by the schoolbook method: the sum of op1[i] * op2[j] * B^(i+j-(n2-1))
   over the diagonals n2-1 <= i+j <= n1-1 only.  Carries coming in from
   the discarded lower diagonals are NOT accounted for (callers such as
   ZNP_mpn_mulmid detect and correct for this).

   The result is n1 - n2 + 3 limbs, stored at res[0, n1 - n2 + 3);
   the top two limbs receive the accumulated row overflow (hi1:hi0).

   Must have n1 >= n2 >= 1.
*/
void
ZNP_mpn_smp_basecase (mp_limb_t* res,
                      const mp_limb_t* op1, size_t n1,
                      const mp_limb_t* op2, size_t n2)
{
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n2 >= 1);

#if GMP_NAIL_BITS == 0 && ULONG_BITS == GMP_NUMB_BITS

   mp_limb_t hi0, hi1, hi;
   size_t s, j;

   // s = n1 - n2 + 1 is the width of each diagonal row (and of the
   // low part of the output); there are n2 rows in total.
   j = n2 - 1;
   s = n1 - j;
   op2 += j;         // start from the top limb of op2

   // first row: res[0, s) = op1[0, s) * op2[n2-1]
   hi0 = mpn_mul_1 (res, op1, s, *op2);
   hi1 = 0;

   // remaining rows: shift op1 up and op2 down one limb each time, so
   // every row contributes to the same s output limbs; the high limb
   // returned by mpn_addmul_1 is accumulated into the two-limb
   // counter (hi1:hi0) via ZNP_ADD_WIDE.
   for (op1++, op2--; j; j--, op1++, op2--)
   {
      hi = mpn_addmul_1 (res, op1, s, *op2);
      ZNP_ADD_WIDE (hi1, hi0, hi1, hi0, 0, hi);
   }

   // store the accumulated overflow as the two guard limbs
   res[s] = hi0;
   res[s + 1] = hi1;

#else
#error Not nails-safe yet
#endif
}
60
61
/*
   Let x = op1[0, 2*n-1),
       y = op2[0, n),
       z = op3[0, n).

   If y >= z, this function computes y - z and the correction term
      SMP(x, y) - SMP(x, z) - SMP(x, y - z)
   and returns 0.

   If y < z, it computes z - y and the correction term
      SMP(x, z) - SMP(x, y) - SMP(x, z - y)
   and returns 1.

   In both cases abs(y - z) is stored at res[0, n).

   The correction term is v - u*B^n, where u is stored at hi[0, 2) and
   v is stored at lo[0, 2).

   None of the output buffers are allowed to overlap either each other or
   the input buffers.
*/
#define bilinear2_sub_fixup \
    ZNP_bilinear2_sub_fixup
int
bilinear2_sub_fixup (mp_limb_t* hi, mp_limb_t* lo, mp_limb_t* res,
                     const mp_limb_t* op1, const mp_limb_t* op2,
                     const mp_limb_t* op3, size_t n)
{
   ZNP_ASSERT (n >= 1);

#if GMP_NAIL_BITS == 0 && ULONG_BITS == GMP_NUMB_BITS

   int sign = 0;
   if (mpn_cmp (op2, op3, n) < 0)
   {
      // swap y and z if necessary
      const mp_limb_t* temp = op2;
      op2 = op3;
      op3 = temp;
      sign = 1;
   }
   // now can assume y >= z

   // The correction term is computed as follows.  Let
   //
   //    y_0 - z_0          = u_0 - c_0 B,
   //    y_1 - z_1 - c_0    = u_1 - c_1 B,
   //    y_2 - z_2 - c_1    = u_2 - c_2 B,
   //    ...
   //    y_{n-1} - z_{n-1} - c_{n-2} = u_{n-1},
   //
   // i.e. where c_j is the borrow (0 or 1) from the j-th limb of the
   // subtraction y - z, and where u_j is the j-th digit of y - z.  Note
   // that c_{-1} = c_{n-1} = 0.  By definition we want to compute
   //
   //    \sum_{0 <= i < 2n-1, 0 <= j < n, n-1 <= i+j < 2n-1}
   //                  (c_{j-1} - c_j B) x_i B^{i+j-(n-1)}
   //
   // After some algebra this collapses down to
   //
   //    \sum_{0 <= i < n-1} c_i (x_{n-2-i} - B^n x_{2n-2-i}).

   // First compute y - z using mpn_sub_n (fast)
   mpn_sub_n (res, op2, op3, n);

   // Now loop through and figure out where the borrows happened
   size_t i;
   mp_limb_t hi0 = 0, hi1 = 0;
   mp_limb_t lo0 = 0, lo1 = 0;

   // The index i runs DOWN from n-1 while op1 runs UP from x_0: the
   // iteration with index i sees borrow c_{i-1}, i.e. c_j for
   // j = n-2, ..., 0, and pairs it with x_{n-2-j} = op1[0] and
   // x_{2n-2-j} = op1[n], exactly as in the sum above.
   for (i = n - 1; i; i--, op1++)
   {
      // res[i] == y_i - z_i - c_{i-1} (mod B), so this expression is
      // -c_{i-1} mod B: zero when no borrow entered limb i, and
      // all-ones otherwise.  It is used below as a bit mask selecting
      // the x limbs to accumulate.
      mp_limb_t borrow = res[i] - op2[i] + op3[i];
      ZNP_ADD_WIDE (lo1, lo0, lo1, lo0, 0, borrow & op1[0]);
      ZNP_ADD_WIDE (hi1, hi0, hi1, hi0, 0, borrow & op1[n]);
   }

   // write out the two-limb accumulators: u = (hi1:hi0), v = (lo1:lo0)
   hi[0] = hi0;
   hi[1] = hi1;
   lo[0] = lo0;
   lo[1] = lo1;

   return sign;
#else
#error Not nails-safe yet
#endif
}
149
150
/*
   Let x = op1[0, 2*n-1),
       y = op2[0, 2*n-1),
       z = op3[0, n).

   This function computes x + y mod B^(2n-1) and the correction term
      SMP(x, z) + SMP(y, z) - SMP((x + y) mod B^(2n-1), z).

   The value x + y mod B^(2n-1) is stored at res[0, 2n-1).

   The correction term is u*B^n - v, where u is stored at hi[0, 2) and
   v is stored at lo[0, 2).

   None of the output buffers are allowed to overlap either each other or
   the input buffers.
*/
#define bilinear1_add_fixup \
    ZNP_bilinear1_add_fixup
void
bilinear1_add_fixup (mp_limb_t* hi, mp_limb_t* lo, mp_limb_t* res,
                     const mp_limb_t* op1, const mp_limb_t* op2,
                     const mp_limb_t* op3, size_t n)
{
   ZNP_ASSERT (n >= 1);

#if GMP_NAIL_BITS == 0 && ULONG_BITS == GMP_NUMB_BITS

   // The correction term is computed as follows.  Let
   //
   //    x_0 + y_0            = u_0 + c_0 B,
   //    x_1 + y_1 + c_0      = u_1 + c_1 B,
   //    x_2 + y_2 + c_1      = u_2 + c_2 B,
   //    ...
   //    x_{2n-2} + y_{2n-2} + c_{2n-3} = u_{2n-2} + c_{2n-1} B,
   //
   // i.e. where c_j is the carry (0 or 1) from the j-th limb of the
   // addition x + y, and u_j is the j-th digit of x + y.  Note that
   // c_{-1} = 0.  By definition we want to compute
   //
   //    \sum_{0 <= i < 2n-1, 0 <= j < n, n-1 <= i+j < 2n-1}
   //                  (c_i B - c_{i-1}) z_j B^{i+j-(n-1)}
   //
   // After some algebra this collapses down to
   //
   //    -\sum_{0 <= j < n-1} c_j z_{n-2-j}  +
   //       B^n \sum_{n-1 <= j < 2n-1} c_j z_{2n-2-j}.

   // First compute x + y using mpn_add_n (fast)
   mp_limb_t last_carry = mpn_add_n (res, op1, op2, 2*n - 1);

   // Now loop through and figure out where the carries happened
   size_t j;
   mp_limb_t fix0 = 0, fix1 = 0;
   op3 += n - 2;      // point at z_{n-2}
   // NOTE(review): for n == 1 this forms the out-of-bounds pointer
   // op3 - 1; it is stepped back into range before any dereference
   // (both loops below are empty in that case), but the arithmetic
   // itself is technically UB -- TODO confirm this is acceptable here.

   // low accumulator: v = \sum_{0 <= j < n-1} c_j z_{n-2-j}
   for (j = 0; j < n - 1; j++, op3--)
   {
      // res[j+1] == x_{j+1} + y_{j+1} + c_j (mod B), so this expression
      // is -c_j mod B: zero if limb j produced no carry, all-ones
      // otherwise.  Used as a bit mask selecting z limbs.
      // carry = -1 if there was a carry in the j-th limb addition
      mp_limb_t carry = op1[j+1] + op2[j+1] - res[j+1];
      ZNP_ADD_WIDE (fix1, fix0, fix1, fix0, 0, carry & *op3);
   }

   lo[0] = fix0;
   lo[1] = fix1;

   // high accumulator: u = \sum_{n-1 <= j < 2n-1} c_j z_{2n-2-j}
   fix0 = fix1 = 0;
   op3 += n;          // jump from z_{-1} up to z_{n-1}

   for (; j < 2*n - 2; j++, op3--)
   {
      // carry = -1 if there was a carry in the j-th limb addition
      mp_limb_t carry = op1[j+1] + op2[j+1] - res[j+1];
      ZNP_ADD_WIDE (fix1, fix0, fix1, fix0, 0, carry & *op3);
   }

   // final term j = 2n-2: c_{2n-2} is the carry out of mpn_add_n,
   // paired with z_0 (op3 has decremented down to it by now)
   ZNP_ADD_WIDE (fix1, fix0, fix1, fix0, 0, (-last_carry) & *op3);

   hi[0] = fix0;
   hi[1] = fix1;
#else
#error Not nails-safe yet
#endif
}
234
235
236
/*
   Computes the simple middle product SMP(op1[0, 2*n-1), op2[0, n)) by
   karatsuba-style splitting, storing the n + 2 limb result at
   res[0, n + 2).

   Like ZNP_mpn_smp_basecase, carries arriving from diagonals below the
   middle-product region are ignored.

   Must have n >= 2.  Recurses on half-size subproblems, dropping to the
   basecase below ZNP_mpn_smp_kara_thresh.
*/
void
ZNP_mpn_smp_kara (mp_limb_t* res, const mp_limb_t* op1, const mp_limb_t* op2,
                  size_t n)
{
   ZNP_ASSERT (n >= 2);

   if (n & 1)
   {
      // If n is odd, we strip off the bottom row and last diagonal and
      // handle them separately at the end (stuff marked O in the diagram
      // below); the remainder gets handled via karatsuba (stuff marked E):

      // EEEEO....
      // .EEEEO...
      // ..EEEEO..
      // ...EEEEO.
      // ....OOOOO

      op2++;
   }

   size_t k = n / 2;

   // scratch space for the pointwise sums/differences and the B-C
   // product; 6642 is presumably the FASTALLOC stack threshold used
   // throughout zn_poly -- see ZNP_FASTALLOC's definition
   ZNP_FASTALLOC (temp, mp_limb_t, 6642, 2 * k + 2);

   // two-limb correction terms produced by the fixup routines
   mp_limb_t hi[2], lo[2];

   // The following diagram shows the contributions from various regions
   // for k = 3:

   // AAABBB.....
   // .AAABBB....
   // ..AAABBB...
   // ...CCCDDD..
   // ....CCCDDD.
   // .....CCCDDD

   // ------------------------------------------------------------------------
   // Step 1: compute contribution from A + contribution from B

   // Let x = op1[0, 2*k-1)
   //     y = op1[k, 3*k-1)
   //     z = op2[k, 2*k).

   // Need to compute SMP(x, z) + SMP(y, z).  To do this, we will compute
   // SMP((x + y) mod B^(2k-1), z) and a correction term.

   // First compute x + y mod B^(2k-1) and the correction term.
   bilinear1_add_fixup (hi, lo, temp, op1, op1 + k, op2 + k, k);

   // Now compute SMP(x + y mod B^(2k-1), z).
   // Store result in first half of output.
   if (k < ZNP_mpn_smp_kara_thresh)
      ZNP_mpn_smp_basecase (res, temp, 2 * k - 1, op2 + k, k);
   else
      ZNP_mpn_smp_kara (res, temp, op2 + k, k);

   // Add in the correction term (u*B^n - v, per bilinear1_add_fixup).
   mpn_sub (res, res, k + 2, lo, 2);
   mpn_add_n (res + k, res + k, hi, 2);

   // Save the last two limbs (they're about to get overwritten)
   mp_limb_t saved[2];
   saved[0] = res[k];
   saved[1] = res[k + 1];

   // ------------------------------------------------------------------------
   // Step 2: compute contribution from C + contribution from D

   // Let x = op1[k, 3*k-1)
   //     y = op1[2*k, 4*k-1)
   //     z = op2[0, k).

   // Need to compute SMP(x, z) + SMP(y, z).  To do this, we will compute
   // SMP((x + y) mod B^(2k-1), z) and a correction term.

   // First compute x + y mod B^(2k-1) and the correction term.
   bilinear1_add_fixup (hi, lo, temp, op1 + k, op1 + 2 * k, op2, k);

   // Now compute SMP(x + y mod B^(2k-1), z).
   // Store result in second half of output.
   if (k < ZNP_mpn_smp_kara_thresh)
      ZNP_mpn_smp_basecase (res + k, temp, 2 * k - 1, op2, k);
   else
      ZNP_mpn_smp_kara (res + k, temp, op2, k);

   // Add in the correction term.
   mpn_sub (res + k, res + k, k + 2, lo, 2);
   mpn_add_n (res + 2 * k, res + 2 * k, hi, 2);

   // Add back the saved limbs.
   mpn_add (res + k, res + k, k + 2, saved, 2);

   // ------------------------------------------------------------------------
   // Step 3: compute contribution from B - contribution from C

   // Let x = op1[k, 3*k-1)
   //     y = op2[k, 2*k).
   //     z = op2[0, k)

   // Need to compute SMP(x, y) - SMP(x, z).  To do this, we will compute
   // SMP(x, abs(y - z)), and a correction term.

   // First compute abs(y - z) and the correction term.
   int sign = bilinear2_sub_fixup (hi, lo, temp, op1 + k, op2 + k, op2, k);

   // Now compute SMP(x, abs(y - z)).
   // Store it in second half of temp space, in two's complement (mod B^(k+2))
   if (k < ZNP_mpn_smp_kara_thresh)
      ZNP_mpn_smp_basecase (temp + k, op1 + k, 2 * k - 1, temp, k);
   else
      ZNP_mpn_smp_kara (temp + k, op1 + k, temp, k);

   // Add in the correction term (v - u*B^n, per bilinear2_sub_fixup);
   // the borrow out of the top is carried forward into step 4.
   mpn_add (temp + k, temp + k, k + 2, lo, 2);
   mp_limb_t borrow = mpn_sub_n (temp + 2 * k, temp + 2 * k, hi, 2);

   // ------------------------------------------------------------------------
   // Step 4: put the pieces together

   // First half of output is A + C = t4 - t2
   // Second half of output is B + D = t6 + t2
   // (t2/t4/t6 presumably refer to the accompanying derivation's
   // intermediates; t2 is the signed B - C product from step 3.)
   if (sign)
   {
      mpn_add (res, res, 2 * k + 2, temp + k, k + 2);
      mpn_sub_1 (res + k + 2, res + k + 2, k, borrow);
      mpn_sub (res + k, res + k, k + 2, temp + k, k + 2);
   }
   else
   {
      mpn_sub (res, res, 2 * k + 2, temp + k, k + 2);
      mpn_add_1 (res + k + 2, res + k + 2, k, borrow);
      mpn_add (res + k, res + k, k + 2, temp + k, k + 2);
   }

   // ------------------------------------------------------------------------
   // Step 5: add in correction if the length was odd

#if GMP_NAIL_BITS == 0 && ULONG_BITS == GMP_NUMB_BITS

   if (n & 1)
   {
      op2--;      // undo the op2++ from the top of the function

      // bottom row: op1[n-1, 2n-1) times the stripped low limb of op2
      mp_limb_t hi0 = mpn_addmul_1 (res, op1 + n - 1, n, *op2);
      mp_limb_t hi1 = 0, lo0 = 0, lo1 = 0;

      // last diagonal: accumulate the cross products x_{2n-2-i} * z_i
      // into separate high/low two-limb accumulators
      size_t i;
      for (i = n - 1; i; i--)
      {
         mp_limb_t y0, y1;
         ZNP_MUL_WIDE (y1, y0, op1[2 * n - i - 2], op2[i]);
         ZNP_ADD_WIDE (hi1, hi0, hi1, hi0, 0, y1);
         ZNP_ADD_WIDE (lo1, lo0, lo1, lo0, 0, y0);
      }

      // fold the accumulators into the top of the result
      res[n + 1] = hi1;
      mpn_add_1 (res + n, res + n, 2, hi0);
      mpn_add_1 (res + n, res + n, 2, lo1);
      mpn_add_1 (res + n - 1, res + n - 1, 3, lo0);
   }

   ZNP_FASTFREE (temp);

#else
#error Not nails-safe yet
#endif
}
405
406
407 void
ZNP_mpn_smp_n(mp_limb_t * res,const mp_limb_t * op1,const mp_limb_t * op2,size_t n)408 ZNP_mpn_smp_n (mp_limb_t* res, const mp_limb_t* op1, const mp_limb_t* op2,
409 size_t n)
410 {
411 if (n < ZNP_mpn_smp_kara_thresh)
412 ZNP_mpn_smp_basecase (res, op1, 2*n - 1, op2, n);
413 else
414 ZNP_mpn_smp_kara (res, op1, op2, n);
415 }
416
417
/*
   Computes the simple middle product SMP(op1[0, n1), op2[0, n2)),
   storing the n3 + 2 = n1 - n2 + 3 limb result at res[0, n3 + 2).

   The trapezoidal region is sliced into (near-)square chunks so each
   chunk can be fed to the karatsuba routine; the pieces are combined
   by limb addition.  Must have n1 >= n2 >= 1.
*/
void
ZNP_mpn_smp (mp_limb_t* res,
             const mp_limb_t* op1, size_t n1,
             const mp_limb_t* op2, size_t n2)
{
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n2 >= 1);

   // n3 = number of diagonals = width of the output region
   size_t n3 = n1 - n2 + 1;

   if (n3 < ZNP_mpn_smp_kara_thresh)
   {
      // region is too narrow to make karatsuba worthwhile for any portion
      ZNP_mpn_smp_basecase (res, op1, n1, op2, n2);
      return;
   }

   if (n2 > n3)
   {
      // region is tall: slice into chunks horizontally, i.e. like this:

      // AA.....
      // .AA....
      // ..BB...
      // ...BB..
      // ....CC.
      // .....CC

      // first chunk (marked A in the above diagram)
      op2 += n2 - n3;
      ZNP_mpn_smp_kara (res, op1, op2, n3);

      // remaining chunks (B, C, etc); each is computed into temp and
      // accumulated into res, since all chunks cover the same diagonals
      ZNP_FASTALLOC (temp, mp_limb_t, 6642, n3 + 2);

      n1 -= n3;
      n2 -= n3;

      while (n2 >= n3)
      {
         // step one chunk down-left along the trapezoid
         op1 += n3;
         op2 -= n3;
         ZNP_mpn_smp_kara (temp, op1, op2, n3);
         mpn_add_n (res, res, temp, n3 + 2);
         n1 -= n3;
         n2 -= n3;
      }

      if (n2)
      {
         // last remaining chunk (shorter than n3; recurse)
         op1 += n3;
         op2 -= n2;
         ZNP_mpn_smp (temp, op1, n1, op2, n2);
         mpn_add_n (res, res, temp, n3 + 2);
      }

      ZNP_FASTFREE (temp);
   }
   else
   {
      mp_limb_t save[2];

      // region is wide: slice into chunks diagonally, i.e. like this:

      // AAABBBCC..
      // .AAABBBCC.
      // ..AAABBBCC

      // first chunk (marked A in the above diagram)
      ZNP_mpn_smp_n (res, op1, op2, n2);

      n1 -= n2;
      n3 -= n2;

      // remaining chunks (B, C, etc); each chunk's two low guard limbs
      // land where the previous chunk's two high limbs already are, so
      // those are saved and added back
      while (n3 >= n2)
      {
         op1 += n2;
         res += n2;

         // save two limbs which are going to be overwritten
         save[0] = res[0];
         save[1] = res[1];

         ZNP_mpn_smp_n (res, op1, op2, n2);

         // add back saved limbs
         mpn_add (res, res, n2 + 2, save, 2);

         n1 -= n2;
         n3 -= n2;
      }

      if (n3)
      {
         // last remaining chunk (narrower than n2; recurse)
         op1 += n2;
         res += n2;

         save[0] = res[0];
         save[1] = res[1];

         ZNP_mpn_smp (res, op1, n1, op2, n2);

         mpn_add (res, res, n3 + 2, save, 2);
      }
   }
}
527
528
529 void
ZNP_mpn_mulmid_fallback(mp_limb_t * res,const mp_limb_t * op1,size_t n1,const mp_limb_t * op2,size_t n2)530 ZNP_mpn_mulmid_fallback (mp_limb_t* res,
531 const mp_limb_t* op1, size_t n1,
532 const mp_limb_t* op2, size_t n2)
533 {
534 if (n1 < n2 + 1)
535 return;
536
537 ZNP_FASTALLOC (temp, mp_limb_t, 6642, n1 + n2);
538 ZNP_mpn_mul (temp, op1, n1, op2, n2);
539 memcpy (res + 2, temp + n2 + 1, sizeof(mp_limb_t) * (n1 - n2 - 1));
540 ZNP_FASTFREE (temp);
541 }
542
543
/*
   Computes the middle product of op1[0, n1) and op2[0, n2), storing the
   n1 - n2 + 3 limb result at res[0, n1 - n2 + 3).

   Strategy: first try the fast "simple" middle product (which ignores
   carries coming in from the diagonals below the region), then fall
   back to a full product if those ignored carries could have corrupted
   the result.  Must have n1 >= n2 >= 1.
*/
void
ZNP_mpn_mulmid (mp_limb_t* res, const mp_limb_t* op1, size_t n1,
                const mp_limb_t* op2, size_t n2)
{
   ZNP_ASSERT (n1 >= n2);
   ZNP_ASSERT (n2 >= 1);

   // for large n2 the fallback's full product is the better option anyway
   if (n2 >= ZNP_mpn_mulmid_fallback_thresh)
   {
      ZNP_mpn_mulmid_fallback (res, op1, n1, op2, n2);
      return;
   }

   // try using the simple middle product
   ZNP_mpn_smp (res, op1, n1, op2, n2);

#if GMP_NAIL_BITS == 0 && ULONG_BITS == GMP_NUMB_BITS

   // If there's a possibility of overflow from lower diagonals, we just give
   // up and do the whole product.  (Note: this should happen extremely rarely
   // on uniform random input.  However, on data generated by mpn_random2, it
   // seems to happen with non-negligible probability.)
   //
   // NOTE(review): the test checks whether the second guard limb res[1]
   // is within n2 of wrapping, n2 presumably being a bound on the total
   // carry the discarded diagonals can feed into it -- confirm against
   // the SMP error analysis before changing.
   if (res[1] >= -(mp_limb_t)(n2))
      ZNP_mpn_mulmid_fallback (res, op1, n1, op2, n2);

#else
#error Not nails-safe yet
#endif
}
573
574
575 // end of file ****************************************************************
576