1 use crate::std_alloc::{Cow, Vec};
2 use core::cmp;
3 use core::cmp::Ordering::{self, Equal, Greater, Less};
4 use core::iter::repeat;
5 use core::mem;
6 use num_traits::{One, PrimInt, Zero};
7 
8 use crate::biguint::biguint_from_vec;
9 use crate::biguint::BigUint;
10 
11 use crate::bigint::BigInt;
12 use crate::bigint::Sign;
13 use crate::bigint::Sign::{Minus, NoSign, Plus};
14 
15 use crate::big_digit::{self, BigDigit, DoubleBigDigit, SignedDoubleBigDigit};
16 
17 // Generic functions for add/subtract/multiply with carry/borrow:
18 
19 // Add with carry:
20 #[inline]
adc(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit21 fn adc(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
22     *acc += DoubleBigDigit::from(a);
23     *acc += DoubleBigDigit::from(b);
24     let lo = *acc as BigDigit;
25     *acc >>= big_digit::BITS;
26     lo
27 }
28 
29 // Subtract with borrow:
30 #[inline]
sbb(a: BigDigit, b: BigDigit, acc: &mut SignedDoubleBigDigit) -> BigDigit31 fn sbb(a: BigDigit, b: BigDigit, acc: &mut SignedDoubleBigDigit) -> BigDigit {
32     *acc += SignedDoubleBigDigit::from(a);
33     *acc -= SignedDoubleBigDigit::from(b);
34     let lo = *acc as BigDigit;
35     *acc >>= big_digit::BITS;
36     lo
37 }
38 
39 #[inline]
mac_with_carry( a: BigDigit, b: BigDigit, c: BigDigit, acc: &mut DoubleBigDigit, ) -> BigDigit40 pub(crate) fn mac_with_carry(
41     a: BigDigit,
42     b: BigDigit,
43     c: BigDigit,
44     acc: &mut DoubleBigDigit,
45 ) -> BigDigit {
46     *acc += DoubleBigDigit::from(a);
47     *acc += DoubleBigDigit::from(b) * DoubleBigDigit::from(c);
48     let lo = *acc as BigDigit;
49     *acc >>= big_digit::BITS;
50     lo
51 }
52 
53 #[inline]
mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit54 pub(crate) fn mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
55     *acc += DoubleBigDigit::from(a) * DoubleBigDigit::from(b);
56     let lo = *acc as BigDigit;
57     *acc >>= big_digit::BITS;
58     lo
59 }
60 
61 /// Divide a two digit numerator by a one digit divisor, returns quotient and remainder:
62 ///
63 /// Note: the caller must ensure that both the quotient and remainder will fit into a single digit.
64 /// This is _not_ true for an arbitrary numerator/denominator.
65 ///
66 /// (This function also matches what the x86 divide instruction does).
67 #[inline]
div_wide(hi: BigDigit, lo: BigDigit, divisor: BigDigit) -> (BigDigit, BigDigit)68 fn div_wide(hi: BigDigit, lo: BigDigit, divisor: BigDigit) -> (BigDigit, BigDigit) {
69     debug_assert!(hi < divisor);
70 
71     let lhs = big_digit::to_doublebigdigit(hi, lo);
72     let rhs = DoubleBigDigit::from(divisor);
73     ((lhs / rhs) as BigDigit, (lhs % rhs) as BigDigit)
74 }
75 
76 /// For small divisors, we can divide without promoting to `DoubleBigDigit` by
77 /// using half-size pieces of digit, like long-division.
78 #[inline]
div_half(rem: BigDigit, digit: BigDigit, divisor: BigDigit) -> (BigDigit, BigDigit)79 fn div_half(rem: BigDigit, digit: BigDigit, divisor: BigDigit) -> (BigDigit, BigDigit) {
80     use crate::big_digit::{HALF, HALF_BITS};
81     use num_integer::Integer;
82 
83     debug_assert!(rem < divisor && divisor <= HALF);
84     let (hi, rem) = ((rem << HALF_BITS) | (digit >> HALF_BITS)).div_rem(&divisor);
85     let (lo, rem) = ((rem << HALF_BITS) | (digit & HALF)).div_rem(&divisor);
86     ((hi << HALF_BITS) | lo, rem)
87 }
88 
89 #[inline]
div_rem_digit(mut a: BigUint, b: BigDigit) -> (BigUint, BigDigit)90 pub(crate) fn div_rem_digit(mut a: BigUint, b: BigDigit) -> (BigUint, BigDigit) {
91     let mut rem = 0;
92 
93     if b <= big_digit::HALF {
94         for d in a.data.iter_mut().rev() {
95             let (q, r) = div_half(rem, *d, b);
96             *d = q;
97             rem = r;
98         }
99     } else {
100         for d in a.data.iter_mut().rev() {
101             let (q, r) = div_wide(rem, *d, b);
102             *d = q;
103             rem = r;
104         }
105     }
106 
107     (a.normalized(), rem)
108 }
109 
110 #[inline]
rem_digit(a: &BigUint, b: BigDigit) -> BigDigit111 pub(crate) fn rem_digit(a: &BigUint, b: BigDigit) -> BigDigit {
112     let mut rem = 0;
113 
114     if b <= big_digit::HALF {
115         for &digit in a.data.iter().rev() {
116             let (_, r) = div_half(rem, digit, b);
117             rem = r;
118         }
119     } else {
120         for &digit in a.data.iter().rev() {
121             let (_, r) = div_wide(rem, digit, b);
122             rem = r;
123         }
124     }
125 
126     rem
127 }
128 
129 /// Two argument addition of raw slices, `a += b`, returning the carry.
130 ///
131 /// This is used when the data `Vec` might need to resize to push a non-zero carry, so we perform
132 /// the addition first hoping that it will fit.
133 ///
134 /// The caller _must_ ensure that `a` is at least as long as `b`.
135 #[inline]
__add2(a: &mut [BigDigit], b: &[BigDigit]) -> BigDigit136 pub(crate) fn __add2(a: &mut [BigDigit], b: &[BigDigit]) -> BigDigit {
137     debug_assert!(a.len() >= b.len());
138 
139     let mut carry = 0;
140     let (a_lo, a_hi) = a.split_at_mut(b.len());
141 
142     for (a, b) in a_lo.iter_mut().zip(b) {
143         *a = adc(*a, *b, &mut carry);
144     }
145 
146     if carry != 0 {
147         for a in a_hi {
148             *a = adc(*a, 0, &mut carry);
149             if carry == 0 {
150                 break;
151             }
152         }
153     }
154 
155     carry as BigDigit
156 }
157 
158 /// Two argument addition of raw slices:
159 /// a += b
160 ///
161 /// The caller _must_ ensure that a is big enough to store the result - typically this means
162 /// resizing a to max(a.len(), b.len()) + 1, to fit a possible carry.
add2(a: &mut [BigDigit], b: &[BigDigit])163 pub(crate) fn add2(a: &mut [BigDigit], b: &[BigDigit]) {
164     let carry = __add2(a, b);
165 
166     debug_assert!(carry == 0);
167 }
168 
sub2(a: &mut [BigDigit], b: &[BigDigit])169 pub(crate) fn sub2(a: &mut [BigDigit], b: &[BigDigit]) {
170     let mut borrow = 0;
171 
172     let len = cmp::min(a.len(), b.len());
173     let (a_lo, a_hi) = a.split_at_mut(len);
174     let (b_lo, b_hi) = b.split_at(len);
175 
176     for (a, b) in a_lo.iter_mut().zip(b_lo) {
177         *a = sbb(*a, *b, &mut borrow);
178     }
179 
180     if borrow != 0 {
181         for a in a_hi {
182             *a = sbb(*a, 0, &mut borrow);
183             if borrow == 0 {
184                 break;
185             }
186         }
187     }
188 
189     // note: we're _required_ to fail on underflow
190     assert!(
191         borrow == 0 && b_hi.iter().all(|x| *x == 0),
192         "Cannot subtract b from a because b is larger than a."
193     );
194 }
195 
196 // Only for the Sub impl. `a` and `b` must have same length.
197 #[inline]
__sub2rev(a: &[BigDigit], b: &mut [BigDigit]) -> BigDigit198 pub(crate) fn __sub2rev(a: &[BigDigit], b: &mut [BigDigit]) -> BigDigit {
199     debug_assert!(b.len() == a.len());
200 
201     let mut borrow = 0;
202 
203     for (ai, bi) in a.iter().zip(b) {
204         *bi = sbb(*ai, *bi, &mut borrow);
205     }
206 
207     borrow as BigDigit
208 }
209 
sub2rev(a: &[BigDigit], b: &mut [BigDigit])210 pub(crate) fn sub2rev(a: &[BigDigit], b: &mut [BigDigit]) {
211     debug_assert!(b.len() >= a.len());
212 
213     let len = cmp::min(a.len(), b.len());
214     let (a_lo, a_hi) = a.split_at(len);
215     let (b_lo, b_hi) = b.split_at_mut(len);
216 
217     let borrow = __sub2rev(a_lo, b_lo);
218 
219     assert!(a_hi.is_empty());
220 
221     // note: we're _required_ to fail on underflow
222     assert!(
223         borrow == 0 && b_hi.iter().all(|x| *x == 0),
224         "Cannot subtract b from a because b is larger than a."
225     );
226 }
227 
sub_sign(a: &[BigDigit], b: &[BigDigit]) -> (Sign, BigUint)228 pub(crate) fn sub_sign(a: &[BigDigit], b: &[BigDigit]) -> (Sign, BigUint) {
229     // Normalize:
230     let a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
231     let b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
232 
233     match cmp_slice(a, b) {
234         Greater => {
235             let mut a = a.to_vec();
236             sub2(&mut a, b);
237             (Plus, biguint_from_vec(a))
238         }
239         Less => {
240             let mut b = b.to_vec();
241             sub2(&mut b, a);
242             (Minus, biguint_from_vec(b))
243         }
244         _ => (NoSign, Zero::zero()),
245     }
246 }
247 
248 /// Three argument multiply accumulate:
249 /// acc += b * c
mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit)250 pub(crate) fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
251     if c == 0 {
252         return;
253     }
254 
255     let mut carry = 0;
256     let (a_lo, a_hi) = acc.split_at_mut(b.len());
257 
258     for (a, &b) in a_lo.iter_mut().zip(b) {
259         *a = mac_with_carry(*a, b, c, &mut carry);
260     }
261 
262     let mut a = a_hi.iter_mut();
263     while carry != 0 {
264         let a = a.next().expect("carry overflow during multiplication!");
265         *a = adc(*a, 0, &mut carry);
266     }
267 }
268 
bigint_from_slice(slice: &[BigDigit]) -> BigInt269 fn bigint_from_slice(slice: &[BigDigit]) -> BigInt {
270     BigInt::from(biguint_from_vec(slice.to_vec()))
271 }
272 
/// Three argument multiply accumulate:
/// acc += b * c
///
/// The shorter operand is called `x` and the longer one `y`; the algorithm is
/// selected by the length of `x` (long multiplication, Karatsuba or Toom-3).
///
/// The caller must supply an `acc` large enough to absorb the product: the slice
/// helpers used below panic (or debug-assert) when they run out of room.
fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
    let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };

    // We use three algorithms for different input sizes.
    //
    // - For small inputs, long multiplication is fastest.
    // - Next we use Karatsuba multiplication (Toom-2), which we have optimized
    //   to avoid unnecessary allocations for intermediate values.
    // - For the largest inputs we use Toom-3, which better optimizes the
    //   number of operations, but uses more temporary allocations.
    //
    // The thresholds are somewhat arbitrary, chosen by evaluating the results
    // of `cargo bench --bench bigint multiply`.

    if x.len() <= 32 {
        // Long multiplication:
        for (i, xi) in x.iter().enumerate() {
            mac_digit(&mut acc[i..], y, *xi);
        }
    } else if x.len() <= 256 {
        /*
         * Karatsuba multiplication:
         *
         * The idea is that we break x and y up into two smaller numbers that each have about half
         * as many digits, like so (note that multiplying by b is just a shift):
         *
         * x = x0 + x1 * b
         * y = y0 + y1 * b
         *
         * With some algebra, we can compute x * y with three smaller products, where the inputs to
         * each of the smaller products have only about half as many digits as x and y:
         *
         * x * y = (x0 + x1 * b) * (y0 + y1 * b)
         *
         * x * y = x0 * y0
         *       + x0 * y1 * b
         *       + x1 * y0 * b
         *       + x1 * y1 * b^2
         *
         * Let p0 = x0 * y0 and p2 = x1 * y1:
         *
         * x * y = p0
         *       + (x0 * y1 + x1 * y0) * b
         *       + p2 * b^2
         *
         * The real trick is that middle term:
         *
         *         x0 * y1 + x1 * y0
         *
         *       = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
         *
         *       = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
         *
         * Now we factor the first four terms:
         *
         *       = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
         *
         *       = -((x1 - x0) * (y1 - y0)) + p0 + p2
         *
         * Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
         *
         * x * y = p0
         *       + (p0 + p2 - p1) * b
         *       + p2 * b^2
         *
         * Where the three intermediate products are:
         *
         * p0 = x0 * y0
         * p1 = (x1 - x0) * (y1 - y0)
         * p2 = x1 * y1
         *
         * In doing the computation, we take great care to avoid unnecessary temporary variables
         * (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
         * bit so we can use the same temporary variable for all the intermediate products:
         *
         * x * y = p2 * b^2 + p2 * b
         *       + p0 * b + p0
         *       - p1 * b
         *
         * The other trick we use is instead of doing explicit shifts, we slice acc at the
         * appropriate offset when doing the add.
         */

        /*
         * When x is smaller than y, it's significantly faster to pick b such that x is split in
         * half, not y:
         */
        let b = x.len() / 2;
        let (x0, x1) = x.split_at(b);
        let (y0, y1) = y.split_at(b);

        /*
         * We reuse the same BigUint for all the intermediate multiplies and have to size p
         * appropriately here: x1.len() >= x0.len() and y1.len() >= y0.len():
         */
        let len = x1.len() + y1.len() + 1;
        let mut p = BigUint { data: vec![0; len] };

        // p2 = x1 * y1
        mac3(&mut p.data[..], x1, y1);

        // Not required, but the adds go faster if we drop any unneeded 0s from the end:
        p.normalize();

        // Add p2 * b and p2 * b^2 by slicing acc at the matching digit offsets.
        add2(&mut acc[b..], &p.data[..]);
        add2(&mut acc[b * 2..], &p.data[..]);

        // Zero out p before the next multiply:
        p.data.truncate(0);
        p.data.extend(repeat(0).take(len));

        // p0 = x0 * y0
        mac3(&mut p.data[..], x0, y0);
        p.normalize();

        // Add p0 and p0 * b.
        add2(&mut acc[..], &p.data[..]);
        add2(&mut acc[b..], &p.data[..]);

        // p1 = (x1 - x0) * (y1 - y0)
        // We do this one last, since it may be negative and acc can't ever be negative:
        let (j0_sign, j0) = sub_sign(x1, x0);
        let (j1_sign, j1) = sub_sign(y1, y0);

        match j0_sign * j1_sign {
            Plus => {
                // p1 is positive, so `- p1 * b` is a genuine subtraction from acc.
                p.data.truncate(0);
                p.data.extend(repeat(0).take(len));

                mac3(&mut p.data[..], &j0.data[..], &j1.data[..]);
                p.normalize();

                sub2(&mut acc[b..], &p.data[..]);
            }
            Minus => {
                // p1 is negative, so `- p1 * b` adds |p1| * b; accumulate it directly into acc.
                mac3(&mut acc[b..], &j0.data[..], &j1.data[..]);
            }
            // One of the differences is zero, so p1 contributes nothing.
            NoSign => (),
        }
    } else {
        // Toom-3 multiplication:
        //
        // Toom-3 is like Karatsuba above, but dividing the inputs into three parts.
        // Both are instances of Toom-Cook, using `k=3` and `k=2` respectively.
        //
        // The general idea is to treat the large integers digits as
        // polynomials of a certain degree and determine the coefficients/digits
        // of the product of the two via interpolation of the polynomial product.
        let i = y.len() / 3 + 1;

        let x0_len = cmp::min(x.len(), i);
        let x1_len = cmp::min(x.len() - x0_len, i);

        let y0_len = i;
        let y1_len = cmp::min(y.len() - y0_len, i);

        // Break x and y into three parts, representing an order two polynomial.
        // t is chosen to be the size of a digit so we can use faster shifts
        // in place of multiplications.
        //
        // x(t) = x2*t^2 + x1*t + x0
        let x0 = bigint_from_slice(&x[..x0_len]);
        let x1 = bigint_from_slice(&x[x0_len..x0_len + x1_len]);
        let x2 = bigint_from_slice(&x[x0_len + x1_len..]);

        // y(t) = y2*t^2 + y1*t + y0
        let y0 = bigint_from_slice(&y[..y0_len]);
        let y1 = bigint_from_slice(&y[y0_len..y0_len + y1_len]);
        let y2 = bigint_from_slice(&y[y0_len + y1_len..]);

        // Let w(t) = x(t) * y(t)
        //
        // This gives us the following order-4 polynomial.
        //
        // w(t) = w4*t^4 + w3*t^3 + w2*t^2 + w1*t + w0
        //
        // We need to find the coefficients w4, w3, w2, w1 and w0. Instead
        // of simply multiplying the x and y in total, we can evaluate w
        // at 5 points. An n-degree polynomial is uniquely identified by (n + 1)
        // points.
        //
        // It is arbitrary as to what points we evaluate w at but we use the
        // following.
        //
        // w(t) at t = 0, 1, -1, -2 and inf
        //
        // The values for w(t) in terms of x(t)*y(t) at these points are:
        //
        // let a = w(0)   = x0 * y0
        // let b = w(1)   = (x2 + x1 + x0) * (y2 + y1 + y0)
        // let c = w(-1)  = (x2 - x1 + x0) * (y2 - y1 + y0)
        // let d = w(-2)  = (4*x2 - 2*x1 + x0) * (4*y2 - 2*y1 + y0)
        // let e = w(inf) = x2 * y2 as t -> inf

        // x0 + x2, avoiding temporaries
        let p = &x0 + &x2;

        // y0 + y2, avoiding temporaries
        let q = &y0 + &y2;

        // x2 - x1 + x0, avoiding temporaries
        let p2 = &p - &x1;

        // y2 - y1 + y0, avoiding temporaries
        let q2 = &q - &y1;

        // w(0)
        let r0 = &x0 * &y0;

        // w(inf)
        let r4 = &x2 * &y2;

        // w(1)
        let r1 = (p + x1) * (q + y1);

        // w(-1)
        let r2 = &p2 * &q2;

        // w(-2)
        // (p2 + x2) * 2 - x0 = (2*x2 - x1 + x0) * 2 - x0 = 4*x2 - 2*x1 + x0, likewise for y.
        let r3 = ((p2 + x2) * 2 - x0) * ((q2 + y2) * 2 - y0);

        // Evaluating these points gives us the following system of linear equations.
        //
        //  0  0  0  0  1 | a
        //  1  1  1  1  1 | b
        //  1 -1  1 -1  1 | c
        // 16 -8  4 -2  1 | d
        //  1  0  0  0  0 | e
        //
        // The solved equation (after gaussian elimination or similar)
        // in terms of its coefficients:
        //
        // w0 = w(0)
        // w1 = w(0)/2 + w(1)/3 - w(-1) + w(-2)/6 - 2*w(inf)
        // w2 = -w(0) + w(1)/2 + w(-1)/2 - w(inf)
        // w3 = -w(0)/2 + w(1)/6 + w(-1)/2 - w(-2)/6 + 2*w(inf)
        // w4 = w(inf)
        //
        // This particular sequence is given by Bodrato and is an interpolation
        // of the above equations.
        let mut comp3: BigInt = (r3 - &r1) / 3;
        let mut comp1: BigInt = (r1 - &r2) / 2;
        let mut comp2: BigInt = r2 - &r0;
        comp3 = (&comp2 - comp3) / 2 + &r4 * 2;
        comp2 += &comp1 - &r4;
        comp1 -= &comp3;

        // Recomposition. The coefficients of the polynomial are now known.
        //
        // Evaluate at w(t) where t is our given base to get the result.
        let bits = u64::from(big_digit::BITS) * i as u64;
        let result = r0
            + (comp1 << bits)
            + (comp2 << (2 * bits))
            + (comp3 << (3 * bits))
            + (r4 << (4 * bits));
        // The product of two unsigned inputs is never negative, so the conversion is
        // expected to succeed.
        let result_pos = result.to_biguint().unwrap();
        add2(&mut acc[..], &result_pos.data);
    }
}
534 
mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint535 pub(crate) fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
536     let len = x.len() + y.len() + 1;
537     let mut prod = BigUint { data: vec![0; len] };
538 
539     mac3(&mut prod.data[..], x, y);
540     prod.normalized()
541 }
542 
scalar_mul(a: &mut [BigDigit], b: BigDigit) -> BigDigit543 pub(crate) fn scalar_mul(a: &mut [BigDigit], b: BigDigit) -> BigDigit {
544     let mut carry = 0;
545     for a in a.iter_mut() {
546         *a = mul_with_carry(*a, b, &mut carry);
547     }
548     carry as BigDigit
549 }
550 
div_rem(mut u: BigUint, mut d: BigUint) -> (BigUint, BigUint)551 pub(crate) fn div_rem(mut u: BigUint, mut d: BigUint) -> (BigUint, BigUint) {
552     if d.is_zero() {
553         panic!("attempt to divide by zero")
554     }
555     if u.is_zero() {
556         return (Zero::zero(), Zero::zero());
557     }
558 
559     if d.data.len() == 1 {
560         if d.data == [1] {
561             return (u, Zero::zero());
562         }
563         let (div, rem) = div_rem_digit(u, d.data[0]);
564         // reuse d
565         d.data.clear();
566         d += rem;
567         return (div, d);
568     }
569 
570     // Required or the q_len calculation below can underflow:
571     match u.cmp(&d) {
572         Less => return (Zero::zero(), u),
573         Equal => {
574             u.set_one();
575             return (u, Zero::zero());
576         }
577         Greater => {} // Do nothing
578     }
579 
580     // This algorithm is from Knuth, TAOCP vol 2 section 4.3, algorithm D:
581     //
582     // First, normalize the arguments so the highest bit in the highest digit of the divisor is
583     // set: the main loop uses the highest digit of the divisor for generating guesses, so we
584     // want it to be the largest number we can efficiently divide by.
585     //
586     let shift = d.data.last().unwrap().leading_zeros() as usize;
587 
588     let (q, r) = if shift == 0 {
589         // no need to clone d
590         div_rem_core(u, &d)
591     } else {
592         div_rem_core(u << shift, &(d << shift))
593     };
594     // renormalize the remainder
595     (q, r >> shift)
596 }
597 
div_rem_ref(u: &BigUint, d: &BigUint) -> (BigUint, BigUint)598 pub(crate) fn div_rem_ref(u: &BigUint, d: &BigUint) -> (BigUint, BigUint) {
599     if d.is_zero() {
600         panic!("attempt to divide by zero")
601     }
602     if u.is_zero() {
603         return (Zero::zero(), Zero::zero());
604     }
605 
606     if d.data.len() == 1 {
607         if d.data == [1] {
608             return (u.clone(), Zero::zero());
609         }
610 
611         let (div, rem) = div_rem_digit(u.clone(), d.data[0]);
612         return (div, rem.into());
613     }
614 
615     // Required or the q_len calculation below can underflow:
616     match u.cmp(d) {
617         Less => return (Zero::zero(), u.clone()),
618         Equal => return (One::one(), Zero::zero()),
619         Greater => {} // Do nothing
620     }
621 
622     // This algorithm is from Knuth, TAOCP vol 2 section 4.3, algorithm D:
623     //
624     // First, normalize the arguments so the highest bit in the highest digit of the divisor is
625     // set: the main loop uses the highest digit of the divisor for generating guesses, so we
626     // want it to be the largest number we can efficiently divide by.
627     //
628     let shift = d.data.last().unwrap().leading_zeros() as usize;
629 
630     let (q, r) = if shift == 0 {
631         // no need to clone d
632         div_rem_core(u.clone(), d)
633     } else {
634         div_rem_core(u << shift, &(d << shift))
635     };
636     // renormalize the remainder
637     (q, r >> shift)
638 }
639 
/// An implementation of Knuth, TAOCP vol 2 section 4.3, algorithm D.
///
/// Divides `a` by `b` and returns `(quotient, remainder)`.
///
/// # Correctness
///
/// This function requires the following conditions to run correctly and/or effectively
///
/// - `a > b`
/// - `b.data.len() > 1`
/// - `b.data.last().unwrap().leading_zeros() == 0`
fn div_rem_core(mut a: BigUint, b: &BigUint) -> (BigUint, BigUint) {
    // The algorithm works by incrementally calculating "guesses", q0, for part of the
    // remainder. Once we have any number q0 such that q0 * b <= a, we can set
    //
    //     q += q0
    //     a -= q0 * b
    //
    // and then iterate until a < b. Then, (q, a) will be our desired quotient and remainder.
    //
    // q0, our guess, is calculated by dividing the last few digits of a by the last digit of b
    // - this should give us a guess that is "close" to the actual quotient, but is possibly
    // greater than the actual quotient. If q0 * b > a, we simply use iterated subtraction
    // until we have a guess such that q0 * b <= a.
    //

    // bn is the most significant digit of the divisor.
    let bn = *b.data.last().unwrap();
    let q_len = a.data.len() - b.data.len() + 1;
    let mut q = BigUint {
        data: vec![0; q_len],
    };

    // We reuse the same temporary to avoid hitting the allocator in our inner loop - this is
    // sized to hold a0 (in the common case; if a particular digit of the quotient is zero a0
    // can be bigger).
    //
    let mut tmp = BigUint {
        data: Vec::with_capacity(2),
    };

    for j in (0..q_len).rev() {
        /*
         * When calculating our next guess q0, we don't need to consider the digits below j
         * + b.data.len() - 1: we're guessing digit j of the quotient (i.e. q0 << j) from
         * digit bn of the divisor (i.e. bn << (b.data.len() - 1)) - so the product of those
         * two numbers will be zero in all digits up to (j + b.data.len() - 1).
         */
        let offset = j + b.data.len() - 1;
        // a shrinks as it is reduced; nothing is left at this position, so the digit stays 0.
        if offset >= a.data.len() {
            continue;
        }

        /* just avoiding a heap allocation: */
        // a0 takes over tmp's buffer and is refilled with the top digits of a.
        let mut a0 = tmp;
        a0.data.truncate(0);
        a0.data.extend(a.data[offset..].iter().cloned());

        /*
         * q0 << j * big_digit::BITS is our actual quotient estimate - we do the shifts
         * implicitly at the end, when adding and subtracting to a and q. Not only do we
         * save the cost of the shifts, the rest of the arithmetic gets to work with
         * smaller numbers.
         */
        let (mut q0, _) = div_rem_digit(a0, bn);
        let mut prod = b * &q0;

        // Correct an over-estimate by stepping q0 down until q0 * b fits under a[j..].
        while cmp_slice(&prod.data[..], &a.data[j..]) == Greater {
            q0 -= 1u32;
            prod -= b;
        }

        add2(&mut q.data[j..], &q0.data[..]);
        sub2(&mut a.data[j..], &prod.data[..]);
        a.normalize();

        // Hand the buffer back so the next iteration can reuse it.
        tmp = q0;
    }

    debug_assert!(a < *b);

    (q.normalized(), a)
}
720 
721 /// Find last set bit
722 /// fls(0) == 0, fls(u32::MAX) == 32
fls<T: PrimInt>(v: T) -> u8723 pub(crate) fn fls<T: PrimInt>(v: T) -> u8 {
724     mem::size_of::<T>() as u8 * 8 - v.leading_zeros() as u8
725 }
726 
ilog2<T: PrimInt>(v: T) -> u8727 pub(crate) fn ilog2<T: PrimInt>(v: T) -> u8 {
728     fls(v) - 1
729 }
730 
731 #[inline]
biguint_shl<T: PrimInt>(n: Cow<'_, BigUint>, shift: T) -> BigUint732 pub(crate) fn biguint_shl<T: PrimInt>(n: Cow<'_, BigUint>, shift: T) -> BigUint {
733     if shift < T::zero() {
734         panic!("attempt to shift left with negative");
735     }
736     if n.is_zero() {
737         return n.into_owned();
738     }
739     let bits = T::from(big_digit::BITS).unwrap();
740     let digits = (shift / bits).to_usize().expect("capacity overflow");
741     let shift = (shift % bits).to_u8().unwrap();
742     biguint_shl2(n, digits, shift)
743 }
744 
biguint_shl2(n: Cow<'_, BigUint>, digits: usize, shift: u8) -> BigUint745 fn biguint_shl2(n: Cow<'_, BigUint>, digits: usize, shift: u8) -> BigUint {
746     let mut data = match digits {
747         0 => n.into_owned().data,
748         _ => {
749             let len = digits.saturating_add(n.data.len() + 1);
750             let mut data = Vec::with_capacity(len);
751             data.extend(repeat(0).take(digits));
752             data.extend(n.data.iter());
753             data
754         }
755     };
756 
757     if shift > 0 {
758         let mut carry = 0;
759         let carry_shift = big_digit::BITS as u8 - shift;
760         for elem in data[digits..].iter_mut() {
761             let new_carry = *elem >> carry_shift;
762             *elem = (*elem << shift) | carry;
763             carry = new_carry;
764         }
765         if carry != 0 {
766             data.push(carry);
767         }
768     }
769 
770     biguint_from_vec(data)
771 }
772 
773 #[inline]
biguint_shr<T: PrimInt>(n: Cow<'_, BigUint>, shift: T) -> BigUint774 pub(crate) fn biguint_shr<T: PrimInt>(n: Cow<'_, BigUint>, shift: T) -> BigUint {
775     if shift < T::zero() {
776         panic!("attempt to shift right with negative");
777     }
778     if n.is_zero() {
779         return n.into_owned();
780     }
781     let bits = T::from(big_digit::BITS).unwrap();
782     let digits = (shift / bits).to_usize().unwrap_or(core::usize::MAX);
783     let shift = (shift % bits).to_u8().unwrap();
784     biguint_shr2(n, digits, shift)
785 }
786 
biguint_shr2(n: Cow<'_, BigUint>, digits: usize, shift: u8) -> BigUint787 fn biguint_shr2(n: Cow<'_, BigUint>, digits: usize, shift: u8) -> BigUint {
788     if digits >= n.data.len() {
789         let mut n = n.into_owned();
790         n.set_zero();
791         return n;
792     }
793     let mut data = match n {
794         Cow::Borrowed(n) => n.data[digits..].to_vec(),
795         Cow::Owned(mut n) => {
796             n.data.drain(..digits);
797             n.data
798         }
799     };
800 
801     if shift > 0 {
802         let mut borrow = 0;
803         let borrow_shift = big_digit::BITS as u8 - shift;
804         for elem in data.iter_mut().rev() {
805             let new_borrow = *elem << borrow_shift;
806             *elem = (*elem >> shift) | borrow;
807             borrow = new_borrow;
808         }
809     }
810 
811     biguint_from_vec(data)
812 }
813 
cmp_slice(a: &[BigDigit], b: &[BigDigit]) -> Ordering814 pub(crate) fn cmp_slice(a: &[BigDigit], b: &[BigDigit]) -> Ordering {
815     debug_assert!(a.last() != Some(&0));
816     debug_assert!(b.last() != Some(&0));
817 
818     match Ord::cmp(&a.len(), &b.len()) {
819         Equal => Iterator::cmp(a.iter().rev(), b.iter().rev()),
820         other => other,
821     }
822 }
823 
#[cfg(test)]
mod algorithm_tests {
    use crate::big_digit::BigDigit;
    use crate::{BigInt, BigUint};
    use num_traits::Num;

    /// `sub_sign` must agree with signed `BigInt` subtraction in both operand orders.
    #[test]
    fn test_sub_sign() {
        use super::sub_sign;

        // Packs the (sign, magnitude) pair into a `BigInt` for comparison.
        fn sub_sign_i(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
            let (sign, magnitude) = sub_sign(a, b);
            BigInt::from_biguint(sign, magnitude)
        }

        let larger = BigUint::from_str_radix("265252859812191058636308480000000", 10).unwrap();
        let smaller = BigUint::from_str_radix("26525285981219105863630848000000", 10).unwrap();
        let larger_i = BigInt::from(larger.clone());
        let smaller_i = BigInt::from(smaller.clone());

        assert_eq!(
            sub_sign_i(&larger.data, &smaller.data),
            &larger_i - &smaller_i
        );
        assert_eq!(
            sub_sign_i(&smaller.data, &larger.data),
            &smaller_i - &larger_i
        );
    }
}
848