1 // Copyright 2021 the V8 project authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 // FFT-based multiplication, due to Schönhage and Strassen.
6 // This implementation mostly follows the description given in:
7 // Christoph Lüders: Fast Multiplication of Large Integers,
8 // http://arxiv.org/abs/1503.04955
9 
10 #include "src/bigint/bigint-internal.h"
11 #include "src/bigint/digit-arithmetic.h"
12 #include "src/bigint/util.h"
13 #include "src/bigint/vector-arithmetic.h"
14 
15 namespace v8 {
16 namespace bigint {
17 
18 namespace {
19 
20 ////////////////////////////////////////////////////////////////////////////////
21 // Part 1: Functions for "mod F_n" arithmetic.
22 // F_n is of the shape 2^K + 1, and for convenience we use K to count the
23 // number of digits rather than the number of bits, so F_n (or K) are implicit
24 // and deduced from the length {len} of the digits array.
25 
26 // Helper function for {ModFn} below.
ModFn_Helper(digit_t * x,int len,signed_digit_t high)27 void ModFn_Helper(digit_t* x, int len, signed_digit_t high) {
28   if (high > 0) {
29     digit_t borrow = high;
30     x[len - 1] = 0;
31     for (int i = 0; i < len; i++) {
32       x[i] = digit_sub(x[i], borrow, &borrow);
33       if (borrow == 0) break;
34     }
35   } else {
36     digit_t carry = -high;
37     x[len - 1] = 0;
38     for (int i = 0; i < len; i++) {
39       x[i] = digit_add2(x[i], carry, &carry);
40       if (carry == 0) break;
41     }
42   }
43 }
44 
45 // {x} := {x} mod F_n, assuming that {x} is "slightly" larger than F_n (e.g.
46 // after addition of two numbers that were mod-F_n-normalized before).
ModFn(digit_t * x,int len)47 void ModFn(digit_t* x, int len) {
48   int K = len - 1;
49   signed_digit_t high = x[K];
50   if (high == 0) return;
51   ModFn_Helper(x, len, high);
52   high = x[K];
53   if (high == 0) return;
54   DCHECK(high == 1 || high == -1);
55   ModFn_Helper(x, len, high);
56   high = x[K];
57   if (high == -1) ModFn_Helper(x, len, high);
58 }
59 
60 // {dest} := {src} mod F_n, assuming that {src} is about twice as long as F_n
61 // (e.g. after multiplication of two numbers that were mod-F_n-normalized
62 // before).
63 // {len} is length of {dest}; {src} is twice as long.
ModFnDoubleWidth(digit_t * dest,const digit_t * src,int len)64 void ModFnDoubleWidth(digit_t* dest, const digit_t* src, int len) {
65   int K = len - 1;
66   digit_t borrow = 0;
67   for (int i = 0; i < K; i++) {
68     dest[i] = digit_sub2(src[i], src[i + K], borrow, &borrow);
69   }
70   dest[K] = digit_sub2(0, src[2 * K], borrow, &borrow);
71   // {borrow} may be non-zero here, that's OK as {ModFn} will take care of it.
72   ModFn(dest, len);
73 }
74 
75 // Sets {sum} := {a} + {b} and {diff} := {a} - {b}, which is more efficient
76 // than computing sum and difference separately. Applies "mod F_n" normalization
77 // to both results.
SumDiff(digit_t * sum,digit_t * diff,const digit_t * a,const digit_t * b,int len)78 void SumDiff(digit_t* sum, digit_t* diff, const digit_t* a, const digit_t* b,
79              int len) {
80   digit_t carry = 0;
81   digit_t borrow = 0;
82   for (int i = 0; i < len; i++) {
83     // Read both values first, because inputs and outputs can overlap.
84     digit_t ai = a[i];
85     digit_t bi = b[i];
86     sum[i] = digit_add3(ai, bi, carry, &carry);
87     diff[i] = digit_sub2(ai, bi, borrow, &borrow);
88   }
89   ModFn(sum, len);
90   ModFn(diff, len);
91 }
92 
// {result} := ({input} << shift) mod F_n, where shift >= K.
// The shift is given as {digit_shift} whole digits (must be >= K on entry)
// plus {bits_shift} additional bits. Both {input} and {result} are accessed
// at indices 0 through K.
void ShiftModFn_Large(digit_t* result, const digit_t* input, int digit_shift,
                      int bits_shift, int K) {
  // If {digit_shift} is at least K, we use the following transformation
  // (where, since everything is mod 2^K + 1, we are allowed to add or
  // subtract any multiple of 2^K + 1 at any time):
  //      x * 2^{K+m}   mod 2^K + 1
  //   == x * 2^K * 2^m - (2^K + 1)*(x * 2^m)   mod 2^K + 1
  //   == x * 2^K * 2^m - x * 2^K * 2^m - x * 2^m   mod 2^K + 1
  //   == -x * 2^m   mod 2^K + 1
  // So the flow is the same as for m < K, but we invert the subtraction's
  // operands. In order to avoid underflow, we virtually initialize the
  // result to 2^K + 1:
  //   input  =  [ iK ][iK-1] ....  .... [ i1 ][ i0 ]
  //   result =  [   1][0000] ....  .... [0000][0001]
  //            +                  [ iK ] .... [ iX ]
  //            -      [iX-1] .... [ i0 ]
  DCHECK(digit_shift >= K);
  // From here on, {digit_shift} is the "m" (in digits) of the derivation
  // above.
  digit_shift -= K;
  digit_t borrow = 0;
  if (bits_shift == 0) {
    // Whole-digit shift. The initial carry of 1 is the low "+1" of the
    // virtual 2^K + 1 initialization.
    digit_t carry = 1;
    // Low result digits: add input digits [iX] to [iK-1].
    for (int i = 0; i < digit_shift; i++) {
      result[i] = digit_add2(input[i + K - digit_shift], carry, &carry);
    }
    // The digit where addition and subtraction meet: [iK] (plus the pending
    // carry) minus [i0].
    result[digit_shift] = digit_sub(input[K] + carry, input[0], &borrow);
    // High result digits: subtract input digits [i1] to [iX-1] from zero.
    for (int i = digit_shift + 1; i < K; i++) {
      digit_t d = input[i - digit_shift];
      result[i] = digit_sub2(0, d, borrow, &borrow);
    }
  } else {
    // Same flow as above, but every input digit is additionally shifted
    // left by {bits_shift}; {input_carry} holds the bits shifted out of the
    // previously processed digit.
    digit_t add_carry = 1;
    digit_t input_carry =
        input[K - digit_shift - 1] >> (kDigitBits - bits_shift);
    for (int i = 0; i < digit_shift; i++) {
      digit_t d = input[i + K - digit_shift];
      digit_t summand = (d << bits_shift) | input_carry;
      result[i] = digit_add2(summand, add_carry, &add_carry);
      input_carry = d >> (kDigitBits - bits_shift);
    }
    {
      // result[digit_shift] = (add_carry + iK_part) - i0_part
      digit_t d = input[K];
      digit_t iK_part = (d << bits_shift) | input_carry;
      digit_t iK_carry = d >> (kDigitBits - bits_shift);
      digit_t sum = digit_add2(add_carry, iK_part, &add_carry);
      // {iK_carry} is less than a full digit, so we can merge {add_carry}
      // into it without overflow.
      iK_carry += add_carry;
      d = input[0];
      digit_t i0_part = d << bits_shift;
      result[digit_shift] = digit_sub(sum, i0_part, &borrow);
      input_carry = d >> (kDigitBits - bits_shift);
      // The bits shifted out of [iK] belong one digit higher; that digit
      // only exists if it is still below index K.
      if (digit_shift + 1 < K) {
        d = input[1];
        digit_t subtrahend = (d << bits_shift) | input_carry;
        result[digit_shift + 1] =
            digit_sub2(iK_carry, subtrahend, borrow, &borrow);
        input_carry = d >> (kDigitBits - bits_shift);
      }
    }
    // Remaining high result digits: subtract the shifted input from zero.
    for (int i = digit_shift + 2; i < K; i++) {
      digit_t d = input[i - digit_shift];
      digit_t subtrahend = (d << bits_shift) | input_carry;
      result[i] = digit_sub2(0, subtrahend, borrow, &borrow);
      input_carry = d >> (kDigitBits - bits_shift);
    }
  }
  // The virtual 1 in result[K] should be eliminated by {borrow}. If there
  // is no borrow, then the virtual initialization was too much. Subtract
  // 2^K + 1.
  result[K] = 0;
  if (borrow != 1) {
    // Subtract the low "+1"; the 2^K part is already gone because result[K]
    // was set to zero above.
    borrow = 1;
    for (int i = 0; i < K; i++) {
      result[i] = digit_sub(result[i], borrow, &borrow);
      if (borrow == 0) break;
    }
    if (borrow != 0) {
      // The result must be 2^K.
      for (int i = 0; i < K; i++) result[i] = 0;
      result[K] = 1;
    }
  }
}
178 
// Sets {result} := {input} * 2^{power_of_two} mod 2^{K} + 1.
// This function is highly relevant for overall performance.
// {K} is counted in digits; {input} and {result} are accessed at indices 0
// through K.
// {zero_above} is an optional hint: input digits at indices i with
// zero_above <= i < K are treated as known zeros and are not read (this is
// only verified by DCHECKs). input[K] is always read. The default value
// effectively disables the hint.
void ShiftModFn(digit_t* result, const digit_t* input, int power_of_two, int K,
                int zero_above = 0x7FFFFFFF) {
  // The modulo-reduction amounts to a subtraction, which we combine
  // with the shift as follows:
  //   input  =  [ iK ][iK-1] ....  .... [ i1 ][ i0 ]
  //   result =        [iX-1] .... [ i0 ] <<<<<<<<<<< shift by {power_of_two}
  //            -                  [ iK ] .... [ iX ]
  // where "X" is the index "K - digit_shift".
  int digit_shift = power_of_two / kDigitBits;
  int bits_shift = power_of_two % kDigitBits;
  // By an analogous construction to the "digit_shift >= K" case,
  // it turns out that:
  //    x * 2^{2K+m} == x * 2^m   mod 2^K + 1.
  while (digit_shift >= 2 * K) digit_shift -= 2 * K;  // Faster than '%'!
  // Shifts by K digits or more flip the sign of the result and are handled
  // by a dedicated function.
  if (digit_shift >= K) {
    return ShiftModFn_Large(result, input, digit_shift, bits_shift, K);
  }
  digit_t borrow = 0;
  if (bits_shift == 0) {
    // We do a single pass over {input}, starting by copying digits [i1] to
    // [iX-1] to result indices digit_shift+1 to K-1.
    int i = 1;
    // Read input digits unless we know they are zero.
    int cap = std::min(K - digit_shift, zero_above);
    for (; i < cap; i++) {
      result[i + digit_shift] = input[i];
    }
    // Any remaining work can hard-code the knowledge that input[i] == 0.
    for (; i < K - digit_shift; i++) {
      DCHECK(input[i] == 0);  // NOLINT(readability/check)
      result[i + digit_shift] = 0;
    }
    // Second phase: subtract input digits [iX] to [iK] from (virtually) zero-
    // initialized result indices 0 to digit_shift-1.
    cap = std::min(K, zero_above);
    for (; i < cap; i++) {
      digit_t d = input[i];
      result[i - K + digit_shift] = digit_sub2(0, d, borrow, &borrow);
    }
    // Any remaining work can hard-code the knowledge that input[i] == 0.
    for (; i < K; i++) {
      DCHECK(input[i] == 0);  // NOLINT(readability/check)
      result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
    }
    // Last step: subtract [iK] from [i0] and store at result index digit_shift.
    result[digit_shift] = digit_sub2(input[0], input[K], borrow, &borrow);
  } else {
    // Same flow as before, but taking bits_shift != 0 into account.
    // First phase: result indices digit_shift to K-1 (here, the pass starts
    // at input index 0). {carry} holds the bits shifted out of the previous
    // input digit.
    digit_t carry = 0;
    int i = 0;
    // Read input digits unless we know they are zero.
    int cap = std::min(K - digit_shift, zero_above);
    for (; i < cap; i++) {
      digit_t d = input[i];
      result[i + digit_shift] = (d << bits_shift) | carry;
      carry = d >> (kDigitBits - bits_shift);
    }
    // Any remaining work can hard-code the knowledge that input[i] == 0.
    // Only the first such digit can still receive non-zero carry bits.
    for (; i < K - digit_shift; i++) {
      DCHECK(input[i] == 0);  // NOLINT(readability/check)
      result[i + digit_shift] = carry;
      carry = 0;
    }
    // Second phase: result indices 0 to digit_shift - 1.
    cap = std::min(K, zero_above);
    for (; i < cap; i++) {
      digit_t d = input[i];
      result[i - K + digit_shift] =
          digit_sub2(0, (d << bits_shift) | carry, borrow, &borrow);
      carry = d >> (kDigitBits - bits_shift);
    }
    // Any remaining work can hard-code the knowledge that input[i] == 0.
    // The first known-zero digit still has to consume the pending {carry}.
    if (i < K) {
      DCHECK(input[i] == 0);  // NOLINT(readability/check)
      result[i - K + digit_shift] = digit_sub2(0, carry, borrow, &borrow);
      carry = 0;
      i++;
    }
    // After that, only the borrow needs to be propagated.
    for (; i < K; i++) {
      DCHECK(input[i] == 0);  // NOLINT(readability/check)
      result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
    }
    // Last step: compute result[digit_shift].
    digit_t d = input[K];
    result[digit_shift] = digit_sub2(
        result[digit_shift], (d << bits_shift) | carry, borrow, &borrow);
    // No carry left.
    DCHECK((d >> (kDigitBits - bits_shift)) == 0);  // NOLINT(readability/check)
  }
  // Propagate any borrow left over from the wrap-around subtraction through
  // the digits above index {digit_shift}.
  result[K] = 0;
  for (int i = digit_shift + 1; i <= K && borrow > 0; i++) {
    result[i] = digit_sub(result[i], borrow, &borrow);
  }
  if (borrow > 0) {
    // Underflow means we subtracted too much. Add 2^K + 1.
    // First the "+1" at the low end...
    digit_t carry = 1;
    for (int i = 0; i <= K; i++) {
      result[i] = digit_add2(result[i], carry, &carry);
      if (carry == 0) break;
    }
    // ...then the "2^K", i.e. a 1 added to the top digit.
    result[K] = digit_add2(result[K], 1, &carry);
  }
}
285 
286 ////////////////////////////////////////////////////////////////////////////////
287 // Part 2: FFT-based multiplication is very sensitive to appropriate choice
288 // of parameters. The following functions choose the parameters that the
289 // subsequent actual computation will use. This is partially based on formal
290 // constraints and partially on experimentally-determined heuristics.
291 
// Bundle of parameters describing one layer of FFT multiplication, as filled
// in by {ComputeParameters} and {ComputeParameters_Inner}.
struct Parameters {
  // The zero defaults below are never relied upon, but leaving these fields
  // uninitialized saddens and confuses MSan.

  // Exponent with n == 2^m. Only set by {ComputeParameters_Inner}.
  int m = 0;
  // K, stored in digits.
  int K = 0;
  // Number of chunks, 2^m.
  int n = 0;
  // Chunk size, stored in digits.
  int s = 0;
  // Which multiple K (in bits) is of n/2 (outer layer) or of n (inner layer).
  int r = 0;
};
301 
302 // Computes parameters for the main calculation, given a bit length {N} and
303 // an {m}. See the paper for details.
ComputeParameters(int N,int m,Parameters * params)304 void ComputeParameters(int N, int m, Parameters* params) {
305   N *= kDigitBits;
306   int n = 1 << m;  // 2^m
307   int nhalf = n >> 1;
308   int s = (N + n - 1) >> m;  // ceil(N/n)
309   s = RoundUp(s, kDigitBits);
310   int K = m + 2 * s + 1;  // K must be at least this big...
311   K = RoundUp(K, nhalf);  // ...and a multiple of n/2.
312   int r = K >> (m - 1);   // Which multiple?
313 
314   // We want recursive calls to make progress, so force K to be a multiple
315   // of 8 if it's above the recursion threshold. Otherwise, K must be a
316   // multiple of kDigitBits.
317   const int threshold = (K + 1 >= kFftInnerThreshold * kDigitBits)
318                             ? 3 + kLog2DigitBits
319                             : kLog2DigitBits;
320   int K_tz = CountTrailingZeros(K);
321   while (K_tz < threshold) {
322     K += (1 << K_tz);
323     r = K >> (m - 1);
324     K_tz = CountTrailingZeros(K);
325   }
326 
327   DCHECK(K % kDigitBits == 0);  // NOLINT(readability/check)
328   DCHECK(s % kDigitBits == 0);  // NOLINT(readability/check)
329   params->K = K / kDigitBits;
330   params->s = s / kDigitBits;
331   params->n = n;
332   params->r = r;
333 }
334 
335 // Computes parameters for recursive invocations ("inner layer").
ComputeParameters_Inner(int N,Parameters * params)336 void ComputeParameters_Inner(int N, Parameters* params) {
337   int max_m = CountTrailingZeros(N);
338   int N_bits = BitLength(N);
339   int m = N_bits - 4;  // Don't let s get too small.
340   m = std::min(max_m, m);
341   N *= kDigitBits;
342   int n = 1 << m;  // 2^m
343   // We can't round up s in the inner layer, because N = n*s is fixed.
344   int s = N >> m;
345   DCHECK(N == s * n);
346   int K = m + 2 * s + 1;  // K must be at least this big...
347   K = RoundUp(K, n);      // ...and a multiple of n and kDigitBits.
348   K = RoundUp(K, kDigitBits);
349   params->r = K >> m;           // Which multiple?
350   DCHECK(K % kDigitBits == 0);  // NOLINT(readability/check)
351   DCHECK(s % kDigitBits == 0);  // NOLINT(readability/check)
352   params->K = K / kDigitBits;
353   params->s = s / kDigitBits;
354   params->n = n;
355   params->m = m;
356 }
357 
PredictInnerK(int N)358 int PredictInnerK(int N) {
359   Parameters params;
360   ComputeParameters_Inner(N, &params);
361   return params.K;
362 }
363 
364 // Applies heuristics to decide whether {m} should be decremented, by looking
365 // at what would happen to {K} and {s} if {m} was decremented.
ShouldDecrementM(const Parameters & current,const Parameters & next,const Parameters & after_next)366 bool ShouldDecrementM(const Parameters& current, const Parameters& next,
367                       const Parameters& after_next) {
368   // K == 64 seems to work particularly well.
369   if (current.K == 64 && next.K >= 112) return false;
370   // Small values for s are never efficient.
371   if (current.s < 6) return true;
372   // The time is roughly determined by K * n. When we decrement m, then
373   // n always halves, and K usually gets bigger, by up to 2x.
374   // For not-quite-so-small s, look at how much bigger K would get: if
375   // the K increase is small enough, making n smaller is worth it.
376   // Empirically, it's most meaningful to look at the K *after* next.
377   // The specific threshold values have been chosen by running many
378   // benchmarks on inputs of many sizes, and manually selecting thresholds
379   // that seemed to produce good results.
380   double factor = static_cast<double>(after_next.K) / current.K;
381   if ((current.s == 6 && factor < 3.85) ||  // --
382       (current.s == 7 && factor < 3.73) ||  // --
383       (current.s == 8 && factor < 3.55) ||  // --
384       (current.s == 9 && factor < 3.50) ||  // --
385       factor < 3.4) {
386     return true;
387   }
388   // If K is just below the recursion threshold, make sure we do recurse,
389   // unless doing so would be particularly inefficient (large inner_K).
390   // If K is just above the recursion threshold, doubling it often makes
391   // the inner call more efficient.
392   if (current.K >= 160 && current.K < 250 && PredictInnerK(next.K) < 28) {
393     return true;
394   }
395   // If we found no reason to decrement, keep m as large as possible.
396   return false;
397 }
398 
399 // Decides what parameters to use for a given input bit length {N}.
400 // Returns the chosen m.
GetParameters(int N,Parameters * params)401 int GetParameters(int N, Parameters* params) {
402   int N_bits = BitLength(N);
403   int max_m = N_bits - 3;                   // Larger m make s too small.
404   max_m = std::max(kLog2DigitBits, max_m);  // Smaller m break the logic below.
405   int m = max_m;
406   Parameters current;
407   ComputeParameters(N, m, &current);
408   Parameters next;
409   ComputeParameters(N, m - 1, &next);
410   while (m > 2) {
411     Parameters after_next;
412     ComputeParameters(N, m - 2, &after_next);
413     if (ShouldDecrementM(current, next, after_next)) {
414       m--;
415       current = next;
416       next = after_next;
417     } else {
418       break;
419     }
420   }
421   *params = current;
422   return m;
423 }
424 
425 ////////////////////////////////////////////////////////////////////////////////
426 // Part 3: Fast Fourier Transformation.
427 
428 class FFTContainer {
429  public:
430   // {n} is the number of chunks, whose length is {K}+1.
431   // {K} determines F_n = 2^(K * kDigitBits) + 1.
FFTContainer(int n,int K,ProcessorImpl * processor)432   FFTContainer(int n, int K, ProcessorImpl* processor)
433       : n_(n), K_(K), length_(K + 1), processor_(processor) {
434     storage_ = new digit_t[length_ * n_];
435     part_ = new digit_t*[n_];
436     digit_t* ptr = storage_;
437     for (int i = 0; i < n; i++, ptr += length_) {
438       part_[i] = ptr;
439     }
440     temp_ = new digit_t[length_ * 2];
441   }
442   FFTContainer() = delete;
443   FFTContainer(const FFTContainer&) = delete;
444   FFTContainer& operator=(const FFTContainer&) = delete;
445 
~FFTContainer()446   ~FFTContainer() {
447     delete[] storage_;
448     delete[] part_;
449     delete[] temp_;
450   }
451 
452   void Start_Default(Digits X, int chunk_size, int theta, int omega);
453   void Start(Digits X, int chunk_size, int theta, int omega);
454 
455   void NormalizeAndRecombine(int omega, int m, RWDigits Z, int chunk_size);
456   void CounterWeightAndRecombine(int theta, int m, RWDigits Z, int chunk_size);
457 
458   void FFT_ReturnShuffledThreadsafe(int start, int len, int omega,
459                                     digit_t* temp);
460   void FFT_Recurse(int start, int half, int omega, digit_t* temp);
461 
462   void BackwardFFT(int start, int len, int omega);
463   void BackwardFFT_Threadsafe(int start, int len, int omega, digit_t* temp);
464 
465   void PointwiseMultiply(const FFTContainer& other);
466   void DoPointwiseMultiplication(const FFTContainer& other, int start, int end,
467                                  digit_t* temp);
468 
length() const469   int length() const { return length_; }
470 
471  private:
472   const int n_;       // Number of parts.
473   const int K_;       // Always length_ - 1.
474   const int length_;  // Length of each part, in digits.
475   ProcessorImpl* processor_;
476   digit_t* storage_;  // Combined storage of all parts.
477   digit_t** part_;    // Pointers to each part.
478   digit_t* temp_;     // Temporary storage with size 2 * length_.
479 };
480 
CopyAndZeroExtend(digit_t * dst,const digit_t * src,int digits_to_copy,size_t total_bytes)481 inline void CopyAndZeroExtend(digit_t* dst, const digit_t* src,
482                               int digits_to_copy, size_t total_bytes) {
483   size_t bytes_to_copy = digits_to_copy * sizeof(digit_t);
484   memcpy(dst, src, bytes_to_copy);
485   memset(dst + digits_to_copy, 0, total_bytes - bytes_to_copy);
486 }
487 
488 // Reads {X} into the FFTContainer's internal storage, dividing it into chunks
489 // while doing so; then performs the forward FFT.
Start_Default(Digits X,int chunk_size,int theta,int omega)490 void FFTContainer::Start_Default(Digits X, int chunk_size, int theta,
491                                  int omega) {
492   int len = X.len();
493   const digit_t* pointer = X.digits();
494   const size_t part_length_in_bytes = length_ * sizeof(digit_t);
495   int current_theta = 0;
496   int i = 0;
497   for (; i < n_ && len > 0; i++, current_theta += theta) {
498     chunk_size = std::min(chunk_size, len);
499     // For invocations via MultiplyFFT_Inner, X.len() == n_ * chunk_size + 1,
500     // because the outer layer's "K" is passed as the inner layer's "N".
501     // Since X is (mod Fn)-normalized on the outer layer, there is the rare
502     // corner case where X[n_ * chunk_size] == 1. Detect that case, and handle
503     // the extra bit as part of the last chunk; we always have the space.
504     if (i == n_ - 1 && len == chunk_size + 1) {
505       DCHECK(X[n_ * chunk_size] <= 1);  // NOLINT(readability/check)
506       DCHECK(length_ >= chunk_size + 1);
507       chunk_size++;
508     }
509     if (current_theta != 0) {
510       // Multiply with theta^i, and reduce modulo 2^K + 1.
511       // We pass theta as a shift amount; it really means 2^theta.
512       CopyAndZeroExtend(temp_, pointer, chunk_size, part_length_in_bytes);
513       ShiftModFn(part_[i], temp_, current_theta, K_, chunk_size);
514     } else {
515       CopyAndZeroExtend(part_[i], pointer, chunk_size, part_length_in_bytes);
516     }
517     pointer += chunk_size;
518     len -= chunk_size;
519   }
520   DCHECK(len == 0);  // NOLINT(readability/check)
521   for (; i < n_; i++) {
522     memset(part_[i], 0, part_length_in_bytes);
523   }
524   FFT_ReturnShuffledThreadsafe(0, n_, omega, temp_);
525 }
526 
527 // This version of Start is optimized for the case where ~half of the
528 // container will be filled with padding zeros.
Start(Digits X,int chunk_size,int theta,int omega)529 void FFTContainer::Start(Digits X, int chunk_size, int theta, int omega) {
530   int len = X.len();
531   if (len > n_ * chunk_size / 2) {
532     return Start_Default(X, chunk_size, theta, omega);
533   }
534   DCHECK(theta == 0);  // NOLINT(readability/check)
535   const digit_t* pointer = X.digits();
536   const size_t part_length_in_bytes = length_ * sizeof(digit_t);
537   int nhalf = n_ / 2;
538   // Unrolled first iteration.
539   CopyAndZeroExtend(part_[0], pointer, chunk_size, part_length_in_bytes);
540   CopyAndZeroExtend(part_[nhalf], pointer, chunk_size, part_length_in_bytes);
541   pointer += chunk_size;
542   len -= chunk_size;
543   int i = 1;
544   for (; i < nhalf && len > 0; i++) {
545     chunk_size = std::min(chunk_size, len);
546     CopyAndZeroExtend(part_[i], pointer, chunk_size, part_length_in_bytes);
547     int w = omega * i;
548     ShiftModFn(part_[i + nhalf], part_[i], w, K_, chunk_size);
549     pointer += chunk_size;
550     len -= chunk_size;
551   }
552   for (; i < nhalf; i++) {
553     memset(part_[i], 0, part_length_in_bytes);
554     memset(part_[i + nhalf], 0, part_length_in_bytes);
555   }
556   FFT_Recurse(0, nhalf, omega, temp_);
557 }
558 
559 // Forward transformation.
560 // We use the "DIF" aka "decimation in frequency" transform, because it
561 // leaves the result in "bit reversed" order, which is precisely what we
562 // need as input for the "DIT" aka "decimation in time" backwards transform.
FFT_ReturnShuffledThreadsafe(int start,int len,int omega,digit_t * temp)563 void FFTContainer::FFT_ReturnShuffledThreadsafe(int start, int len, int omega,
564                                                 digit_t* temp) {
565   DCHECK((len & 1) == 0);  // {len} must be even. NOLINT(readability/check)
566   int half = len / 2;
567   SumDiff(part_[start], part_[start + half], part_[start], part_[start + half],
568           length_);
569   for (int k = 1; k < half; k++) {
570     SumDiff(part_[start + k], temp, part_[start + k], part_[start + half + k],
571             length_);
572     int w = omega * k;
573     ShiftModFn(part_[start + half + k], temp, w, K_);
574   }
575   FFT_Recurse(start, half, omega, temp);
576 }
577 
578 // Recursive step of the above, factored out for additional callers.
FFT_Recurse(int start,int half,int omega,digit_t * temp)579 void FFTContainer::FFT_Recurse(int start, int half, int omega, digit_t* temp) {
580   if (half > 1) {
581     FFT_ReturnShuffledThreadsafe(start, half, 2 * omega, temp);
582     FFT_ReturnShuffledThreadsafe(start + half, half, 2 * omega, temp);
583   }
584 }
585 
586 // Backward transformation.
587 // We use the "DIT" aka "decimation in time" transform here, because it
588 // turns bit-reversed input into normally sorted output.
BackwardFFT(int start,int len,int omega)589 void FFTContainer::BackwardFFT(int start, int len, int omega) {
590   BackwardFFT_Threadsafe(start, len, omega, temp_);
591 }
592 
BackwardFFT_Threadsafe(int start,int len,int omega,digit_t * temp)593 void FFTContainer::BackwardFFT_Threadsafe(int start, int len, int omega,
594                                           digit_t* temp) {
595   DCHECK((len & 1) == 0);  // {len} must be even. NOLINT(readability/check)
596   int half = len / 2;
597   // Don't recurse for half == 2, as PointwiseMultiply already performed
598   // the first level of the backwards FFT.
599   if (half > 2) {
600     BackwardFFT_Threadsafe(start, half, 2 * omega, temp);
601     BackwardFFT_Threadsafe(start + half, half, 2 * omega, temp);
602   }
603   SumDiff(part_[start], part_[start + half], part_[start], part_[start + half],
604           length_);
605   for (int k = 1; k < half; k++) {
606     int w = omega * (len - k);
607     ShiftModFn(temp, part_[start + half + k], w, K_);
608     SumDiff(part_[start + k], part_[start + half + k], part_[start + k], temp,
609             length_);
610   }
611 }
612 
613 // Recombines the result's parts into {Z}, after backwards FFT.
NormalizeAndRecombine(int omega,int m,RWDigits Z,int chunk_size)614 void FFTContainer::NormalizeAndRecombine(int omega, int m, RWDigits Z,
615                                          int chunk_size) {
616   Z.Clear();
617   int z_index = 0;
618   const int shift = n_ * omega - m;
619   for (int i = 0; i < n_; i++, z_index += chunk_size) {
620     digit_t* part = part_[i];
621     ShiftModFn(temp_, part, shift, K_);
622     digit_t carry = 0;
623     int zi = z_index;
624     int j = 0;
625     for (; j < length_ && zi < Z.len(); j++, zi++) {
626       Z[zi] = digit_add3(Z[zi], temp_[j], carry, &carry);
627     }
628     for (; j < length_; j++) {
629       DCHECK(temp_[j] == 0);  // NOLINT(readability/check)
630     }
631     if (carry != 0) {
632       DCHECK(zi < Z.len());
633       Z[zi] = carry;
634     }
635   }
636 }
637 
638 // Helper function for {CounterWeightAndRecombine} below.
ShouldBeNegative(const digit_t * x,int xlen,digit_t threshold,int s)639 bool ShouldBeNegative(const digit_t* x, int xlen, digit_t threshold, int s) {
640   if (x[2 * s] >= threshold) return true;
641   for (int i = 2 * s + 1; i < xlen; i++) {
642     if (x[i] > 0) return true;
643   }
644   return false;
645 }
646 
// Same as {NormalizeAndRecombine} above, but for the needs of the recursive
// invocation ("inner layer") of FFT multiplication, where an additional
// counter-weighting step is required.
// Part k is shifted by (-theta * k - m) bits mod 2^K + 1 (made non-negative
// by adding 2 * n_ * theta if necessary), and then added into {Z} at an
// offset of {s} digits per part. Parts for which {ShouldBeNegative} returns
// true have F_n subtracted from them before they are added.
void FFTContainer::CounterWeightAndRecombine(int theta, int m, RWDigits Z,
                                             int s) {
  Z.Clear();
  int z_index = 0;
  for (int k = 0; k < n_; k++, z_index += s) {
    int shift = -theta * k - m;
    if (shift < 0) shift += 2 * n_ * theta;
    DCHECK(shift >= 0);  // NOLINT(readability/check)
    digit_t* input = part_[k];
    ShiftModFn(temp_, input, shift, K_);
    // Number of digits of Z that are available from this part's offset on.
    int remaining_z = Z.len() - z_index;
    if (ShouldBeNegative(temp_, length_, k + 1, s)) {
      // Subtract F_n from input before adding to result. We use the following
      // transformation (knowing that X < F_n):
      // Z + (X - F_n) == Z - (F_n - X)
      // Two borrow chains run in parallel: {borrow_Fn} for computing the
      // digits of (F_n - X), and {borrow_z} for subtracting those from Z.
      digit_t borrow_z = 0;
      digit_t borrow_Fn = 0;
      {
        // i == 0:
        // The lowest digit of F_n is 1.
        digit_t d = digit_sub(1, temp_[0], &borrow_Fn);
        Z[z_index] = digit_sub(Z[z_index], d, &borrow_z);
      }
      int i = 1;
      // The middle digits of F_n are 0.
      for (; i < K_ && i < remaining_z; i++) {
        digit_t d = digit_sub2(0, temp_[i], borrow_Fn, &borrow_Fn);
        Z[z_index + i] = digit_sub2(Z[z_index + i], d, borrow_z, &borrow_z);
      }
      // The loop above is expected to have run all the way up to K_, i.e.
      // not to have been cut short by {remaining_z}.
      DCHECK(i == K_ && K_ == length_ - 1);
      // The top digit of F_n (at index K_) is 1.
      for (; i < length_ && i < remaining_z; i++) {
        digit_t d = digit_sub2(1, temp_[i], borrow_Fn, &borrow_Fn);
        Z[z_index + i] = digit_sub2(Z[z_index + i], d, borrow_z, &borrow_z);
      }
      DCHECK(borrow_Fn == 0);  // NOLINT(readability/check)
      // Propagate the remaining borrow through the higher digits of Z.
      for (; borrow_z > 0 && i < remaining_z; i++) {
        Z[z_index + i] = digit_sub(Z[z_index + i], borrow_z, &borrow_z);
      }
    } else {
      // Plain addition of the shifted part into Z.
      digit_t carry = 0;
      int i = 0;
      for (; i < length_ && i < remaining_z; i++) {
        Z[z_index + i] = digit_add3(Z[z_index + i], temp_[i], carry, &carry);
      }
      // Whatever didn't fit into Z must have been zero.
      for (; i < length_; i++) {
        DCHECK(temp_[i] == 0);  // NOLINT(readability/check)
      }
      // Propagate the remaining carry through the higher digits of Z.
      for (; carry > 0 && i < remaining_z; i++) {
        Z[z_index + i] = digit_add2(Z[z_index + i], carry, &carry);
      }
      // {carry} might be != 0 here if Z was negative before. That's fine.
    }
  }
}
702 
703 // Main FFT function for recursive invocations ("inner layer").
MultiplyFFT_Inner(RWDigits Z,Digits X,Digits Y,const Parameters & params,ProcessorImpl * processor)704 void MultiplyFFT_Inner(RWDigits Z, Digits X, Digits Y, const Parameters& params,
705                        ProcessorImpl* processor) {
706   int omega = 2 * params.r;  // really: 2^(2r)
707   int theta = params.r;      // really: 2^r
708 
709   FFTContainer a(params.n, params.K, processor);
710   a.Start_Default(X, params.s, theta, omega);
711   FFTContainer b(params.n, params.K, processor);
712   b.Start_Default(Y, params.s, theta, omega);
713 
714   a.PointwiseMultiply(b);
715   if (processor->should_terminate()) return;
716 
717   FFTContainer& c = a;
718   c.BackwardFFT(0, params.n, omega);
719 
720   c.CounterWeightAndRecombine(theta, params.m, Z, params.s);
721 }
722 
723 // Actual implementation of pointwise multiplications.
DoPointwiseMultiplication(const FFTContainer & other,int start,int end,digit_t * temp)724 void FFTContainer::DoPointwiseMultiplication(const FFTContainer& other,
725                                              int start, int end,
726                                              digit_t* temp) {
727   // The (K_ & 3) != 0 condition makes sure that the inner FFT gets
728   // to split the work into at least 4 chunks.
729   bool use_fft = length_ >= kFftInnerThreshold && (K_ & 3) == 0;
730   Parameters params;
731   if (use_fft) ComputeParameters_Inner(K_, &params);
732   RWDigits result(temp, 2 * length_);
733   for (int i = start; i < end; i++) {
734     Digits A(part_[i], length_);
735     Digits B(other.part_[i], length_);
736     if (use_fft) {
737       MultiplyFFT_Inner(result, A, B, params, processor_);
738     } else {
739       processor_->Multiply(result, A, B);
740     }
741     if (processor_->should_terminate()) return;
742     ModFnDoubleWidth(part_[i], result.digits(), length_);
743     // To improve cache friendliness, we perform the first level of the
744     // backwards FFT here.
745     if ((i & 1) == 1) {
746       SumDiff(part_[i - 1], part_[i], part_[i - 1], part_[i], length_);
747     }
748   }
749 }
750 
751 // Convenient entry point for pointwise multiplications.
PointwiseMultiply(const FFTContainer & other)752 void FFTContainer::PointwiseMultiply(const FFTContainer& other) {
753   DCHECK(n_ == other.n_);
754   DoPointwiseMultiplication(other, 0, n_, temp_);
755 }
756 
757 }  // namespace
758 
759 ////////////////////////////////////////////////////////////////////////////////
760 // Part 4: Tying everything together into a multiplication algorithm.
761 
762 // TODO(jkummerow): Consider doing a "Mersenne transform" and CRT reconstruction
763 // of the final result. Might yield a few percent of perf improvement.
764 
// TODO(jkummerow): Consider implementing the "sqrt(2) trick".
// Gaudry/Kruppa/Zimmermann report that it saved them around 10%.
767 
MultiplyFFT(RWDigits Z,Digits X,Digits Y)768 void ProcessorImpl::MultiplyFFT(RWDigits Z, Digits X, Digits Y) {
769   Parameters params;
770   int m = GetParameters(X.len() + Y.len(), &params);
771   int omega = params.r;  // really: 2^r
772 
773   FFTContainer a(params.n, params.K, this);
774   a.Start(X, params.s, 0, omega);
775   if (X == Y) {
776     // Squaring.
777     a.PointwiseMultiply(a);
778   } else {
779     FFTContainer b(params.n, params.K, this);
780     b.Start(Y, params.s, 0, omega);
781     a.PointwiseMultiply(b);
782   }
783   if (should_terminate()) return;
784 
785   a.BackwardFFT(0, params.n, omega);
786   a.NormalizeAndRecombine(omega, m, Z, params.s);
787 }
788 
789 }  // namespace bigint
790 }  // namespace v8
791