1 mod biguint {
2     use num_bigint::BigUint;
3     use num_traits::{One, Zero};
4     use std::{i32, u32};
5 
check<T: Into<BigUint>>(x: T, n: u32)6     fn check<T: Into<BigUint>>(x: T, n: u32) {
7         let x: BigUint = x.into();
8         let root = x.nth_root(n);
9         println!("check {}.nth_root({}) = {}", x, n, root);
10 
11         if n == 2 {
12             assert_eq!(root, x.sqrt())
13         } else if n == 3 {
14             assert_eq!(root, x.cbrt())
15         }
16 
17         let lo = root.pow(n);
18         assert!(lo <= x);
19         assert_eq!(lo.nth_root(n), root);
20         if !lo.is_zero() {
21             assert_eq!((&lo - 1u32).nth_root(n), &root - 1u32);
22         }
23 
24         let hi = (&root + 1u32).pow(n);
25         assert!(hi > x);
26         assert_eq!(hi.nth_root(n), &root + 1u32);
27         assert_eq!((&hi - 1u32).nth_root(n), root);
28     }
29 
30     #[test]
test_sqrt()31     fn test_sqrt() {
32         check(99u32, 2);
33         check(100u32, 2);
34         check(120u32, 2);
35     }
36 
37     #[test]
test_cbrt()38     fn test_cbrt() {
39         check(8u32, 3);
40         check(26u32, 3);
41     }
42 
43     #[test]
test_nth_root()44     fn test_nth_root() {
45         check(0u32, 1);
46         check(10u32, 1);
47         check(100u32, 4);
48     }
49 
50     #[test]
51     #[should_panic]
test_nth_root_n_is_zero()52     fn test_nth_root_n_is_zero() {
53         check(4u32, 0);
54     }
55 
56     #[test]
test_nth_root_big()57     fn test_nth_root_big() {
58         let x = BigUint::from(123_456_789_u32);
59         let expected = BigUint::from(6u32);
60 
61         assert_eq!(x.nth_root(10), expected);
62         check(x, 10);
63     }
64 
65     #[test]
test_nth_root_googol()66     fn test_nth_root_googol() {
67         let googol = BigUint::from(10u32).pow(100u32);
68 
69         // perfect divisors of 100
70         for &n in &[2, 4, 5, 10, 20, 25, 50, 100] {
71             let expected = BigUint::from(10u32).pow(100u32 / n);
72             assert_eq!(googol.nth_root(n), expected);
73             check(googol.clone(), n);
74         }
75     }
76 
77     #[test]
test_nth_root_twos()78     fn test_nth_root_twos() {
79         const EXP: u32 = 12;
80         const LOG2: usize = 1 << EXP;
81         let x = BigUint::one() << LOG2;
82 
83         // the perfect divisors are just powers of two
84         for exp in 1..=EXP {
85             let n = 2u32.pow(exp);
86             let expected = BigUint::one() << (LOG2 / n as usize);
87             assert_eq!(x.nth_root(n), expected);
88             check(x.clone(), n);
89         }
90 
91         // degenerate cases should return quickly
92         assert!(x.nth_root(x.bits() as u32).is_one());
93         assert!(x.nth_root(i32::MAX as u32).is_one());
94         assert!(x.nth_root(u32::MAX).is_one());
95     }
96 
97     #[test]
test_roots_rand1()98     fn test_roots_rand1() {
99         // A random input that found regressions
100         let s = "575981506858479247661989091587544744717244516135539456183849\
101                  986593934723426343633698413178771587697273822147578889823552\
102                  182702908597782734558103025298880194023243541613924361007059\
103                  353344183590348785832467726433749431093350684849462759540710\
104                  026019022227591412417064179299354183441181373862905039254106\
105                  4781867";
106         let x: BigUint = s.parse().unwrap();
107 
108         check(x.clone(), 2);
109         check(x.clone(), 3);
110         check(x.clone(), 10);
111         check(x, 100);
112     }
113 }
114 
115 mod bigint {
116     use num_bigint::BigInt;
117     use num_traits::Signed;
118 
check(x: i64, n: u32)119     fn check(x: i64, n: u32) {
120         let big_x = BigInt::from(x);
121         let res = big_x.nth_root(n);
122 
123         if n == 2 {
124             assert_eq!(&res, &big_x.sqrt())
125         } else if n == 3 {
126             assert_eq!(&res, &big_x.cbrt())
127         }
128 
129         if big_x.is_negative() {
130             assert!(res.pow(n) >= big_x);
131             assert!((res - 1u32).pow(n) < big_x);
132         } else {
133             assert!(res.pow(n) <= big_x);
134             assert!((res + 1u32).pow(n) > big_x);
135         }
136     }
137 
138     #[test]
test_nth_root()139     fn test_nth_root() {
140         check(-100, 3);
141     }
142 
143     #[test]
144     #[should_panic]
test_nth_root_x_neg_n_even()145     fn test_nth_root_x_neg_n_even() {
146         check(-100, 4);
147     }
148 
149     #[test]
150     #[should_panic]
test_sqrt_x_neg()151     fn test_sqrt_x_neg() {
152         check(-4, 2);
153     }
154 
155     #[test]
test_cbrt()156     fn test_cbrt() {
157         check(8, 3);
158         check(-8, 3);
159     }
160 }
161