1 extern crate num_integer;
2 extern crate num_traits;
3 
4 use num_integer::Roots;
5 use num_traits::checked_pow;
6 use num_traits::{AsPrimitive, PrimInt, Signed};
7 use std::f64::MANTISSA_DIGITS;
8 use std::fmt::Debug;
9 use std::mem;
10 
11 trait TestInteger: Roots + PrimInt + Debug + AsPrimitive<f64> + 'static {}
12 
13 impl<T> TestInteger for T where T: Roots + PrimInt + Debug + AsPrimitive<f64> + 'static {}
14 
15 /// Check that each root is correct
16 ///
17 /// If `x` is positive, check `rⁿ ≤ x < (r+1)ⁿ`.
18 /// If `x` is negative, check `(r-1)ⁿ < x ≤ rⁿ`.
check<T>(v: &[T], n: u32) where T: TestInteger,19 fn check<T>(v: &[T], n: u32)
20 where
21     T: TestInteger,
22 {
23     for i in v {
24         let rt = i.nth_root(n);
25         // println!("nth_root({:?}, {}) = {:?}", i, n, rt);
26         if n == 2 {
27             assert_eq!(rt, i.sqrt());
28         } else if n == 3 {
29             assert_eq!(rt, i.cbrt());
30         }
31         if *i >= T::zero() {
32             let rt1 = rt + T::one();
33             assert!(rt.pow(n) <= *i);
34             if let Some(x) = checked_pow(rt1, n as usize) {
35                 assert!(*i < x);
36             }
37         } else {
38             let rt1 = rt - T::one();
39             assert!(rt < T::zero());
40             assert!(*i <= rt.pow(n));
41             if let Some(x) = checked_pow(rt1, n as usize) {
42                 assert!(x < *i);
43             }
44         };
45     }
46 }
47 
48 /// Get the maximum value that will round down as `f64` (if any),
49 /// and its successor that will round up.
50 ///
51 /// Important because the `std` implementations cast to `f64` to
52 /// get a close approximation of the roots.
mantissa_max<T>() -> Option<(T, T)> where T: TestInteger,53 fn mantissa_max<T>() -> Option<(T, T)>
54 where
55     T: TestInteger,
56 {
57     let bits = if T::min_value().is_zero() {
58         8 * mem::size_of::<T>()
59     } else {
60         8 * mem::size_of::<T>() - 1
61     };
62     if bits > MANTISSA_DIGITS as usize {
63         let rounding_bit = T::one() << (bits - MANTISSA_DIGITS as usize - 1);
64         let x = T::max_value() - rounding_bit;
65 
66         let x1 = x + T::one();
67         let x2 = x1 + T::one();
68         assert!(x.as_() < x1.as_());
69         assert_eq!(x1.as_(), x2.as_());
70 
71         Some((x, x1))
72     } else {
73         None
74     }
75 }
76 
extend<T>(v: &mut Vec<T>, start: T, end: T) where T: TestInteger,77 fn extend<T>(v: &mut Vec<T>, start: T, end: T)
78 where
79     T: TestInteger,
80 {
81     let mut i = start;
82     while i < end {
83         v.push(i);
84         i = i + T::one();
85     }
86     v.push(i);
87 }
88 
extend_shl<T>(v: &mut Vec<T>, start: T, end: T, mask: T) where T: TestInteger,89 fn extend_shl<T>(v: &mut Vec<T>, start: T, end: T, mask: T)
90 where
91     T: TestInteger,
92 {
93     let mut i = start;
94     while i != end {
95         v.push(i);
96         i = (i << 1) & mask;
97     }
98 }
99 
extend_shr<T>(v: &mut Vec<T>, start: T, end: T) where T: TestInteger,100 fn extend_shr<T>(v: &mut Vec<T>, start: T, end: T)
101 where
102     T: TestInteger,
103 {
104     let mut i = start;
105     while i != end {
106         v.push(i);
107         i = i >> 1;
108     }
109 }
110 
pos<T>() -> Vec<T> where T: TestInteger, i8: AsPrimitive<T>,111 fn pos<T>() -> Vec<T>
112 where
113     T: TestInteger,
114     i8: AsPrimitive<T>,
115 {
116     let mut v: Vec<T> = vec![];
117     if mem::size_of::<T>() == 1 {
118         extend(&mut v, T::zero(), T::max_value());
119     } else {
120         extend(&mut v, T::zero(), i8::max_value().as_());
121         extend(
122             &mut v,
123             T::max_value() - i8::max_value().as_(),
124             T::max_value(),
125         );
126         if let Some((i, j)) = mantissa_max::<T>() {
127             v.push(i);
128             v.push(j);
129         }
130         extend_shl(&mut v, T::max_value(), T::zero(), !T::min_value());
131         extend_shr(&mut v, T::max_value(), T::zero());
132     }
133     v
134 }
135 
neg<T>() -> Vec<T> where T: TestInteger + Signed, i8: AsPrimitive<T>,136 fn neg<T>() -> Vec<T>
137 where
138     T: TestInteger + Signed,
139     i8: AsPrimitive<T>,
140 {
141     let mut v: Vec<T> = vec![];
142     if mem::size_of::<T>() <= 1 {
143         extend(&mut v, T::min_value(), T::zero());
144     } else {
145         extend(&mut v, i8::min_value().as_(), T::zero());
146         extend(
147             &mut v,
148             T::min_value(),
149             T::min_value() - i8::min_value().as_(),
150         );
151         if let Some((i, j)) = mantissa_max::<T>() {
152             v.push(-i);
153             v.push(-j);
154         }
155         extend_shl(&mut v, -T::one(), T::min_value(), !T::zero());
156         extend_shr(&mut v, T::min_value(), -T::one());
157     }
158     v
159 }
160 
/// Generate two test modules for a signed/unsigned integer pair: `mod $I`
/// for the signed type and `mod $U` for the unsigned one.
///
/// Both modules check `sqrt`, `cbrt` and `nth_root` (via `check`) over the
/// inputs from `pos`. They also check that a zeroth root panics and pin the
/// roots of the extreme values at the bit-width boundary. The signed module
/// additionally checks odd roots of negative inputs (from `neg`) and that
/// the square root of a negative value panics.
macro_rules! test_roots {
    ($I:ident, $U:ident) => {
        mod $I {
            use check;
            use neg;
            use num_integer::Roots;
            use pos;
            use std::mem;

            #[test]
            #[should_panic]
            fn zeroth_root() {
                (123 as $I).nth_root(0);
            }

            #[test]
            fn sqrt() {
                check(&pos::<$I>(), 2);
            }

            #[test]
            #[should_panic]
            fn sqrt_neg() {
                (-123 as $I).sqrt();
            }

            #[test]
            fn cbrt() {
                check(&pos::<$I>(), 3);
            }

            #[test]
            fn cbrt_neg() {
                check(&neg::<$I>(), 3);
            }

            #[test]
            fn nth_root() {
                // Magnitude bits only: one bit of a signed type is the sign.
                let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                let pos = pos::<$I>();
                for n in 4..bits {
                    check(&pos, n);
                }
            }

            #[test]
            fn nth_root_neg() {
                let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                let neg = neg::<$I>();
                // Odd roots only (5, 7, 9, ...); even roots of negatives panic.
                for n in 2..bits / 2 {
                    check(&neg, 2 * n + 1);
                }
            }

            #[test]
            fn bit_size() {
                // `max == 2^bits - 1` and `min == -2^bits`, so these roots sit
                // exactly at the boundary between 1 and 2 (or -1 and -2).
                let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                assert_eq!($I::max_value().nth_root(bits - 1), 2);
                assert_eq!($I::max_value().nth_root(bits), 1);
                assert_eq!($I::min_value().nth_root(bits), -2);
                assert_eq!(($I::min_value() + 1).nth_root(bits), -1);
            }
        }

        mod $U {
            use check;
            use num_integer::Roots;
            use pos;
            use std::mem;

            #[test]
            #[should_panic]
            fn zeroth_root() {
                (123 as $U).nth_root(0);
            }

            #[test]
            fn sqrt() {
                check(&pos::<$U>(), 2);
            }

            #[test]
            fn cbrt() {
                check(&pos::<$U>(), 3);
            }

            #[test]
            fn nth_root() {
                // Unsigned types have no sign bit, so every bit carries
                // magnitude; this must exercise `$U` itself, not `$I`.
                let bits = 8 * mem::size_of::<$U>() as u32;
                let pos = pos::<$U>();
                for n in 4..bits {
                    check(&pos, n);
                }
            }

            #[test]
            fn bit_size() {
                // `max == 2^bits - 1`, so its roots sit at the 1/2 boundary.
                let bits = 8 * mem::size_of::<$U>() as u32;
                assert_eq!($U::max_value().nth_root(bits - 1), 2);
                assert_eq!($U::max_value().nth_root(bits), 1);
            }
        }
    };
}
265 
266 test_roots!(i8, u8);
267 test_roots!(i16, u16);
268 test_roots!(i32, u32);
269 test_roots!(i64, u64);
270 #[cfg(has_i128)]
271 test_roots!(i128, u128);
272 test_roots!(isize, usize);
273