1 extern crate num_integer;
2 extern crate num_traits;
3
4 use num_integer::Roots;
5 use num_traits::checked_pow;
6 use num_traits::{AsPrimitive, PrimInt, Signed};
7 use std::f64::MANTISSA_DIGITS;
8 use std::fmt::Debug;
9 use std::mem;
10
11 trait TestInteger: Roots + PrimInt + Debug + AsPrimitive<f64> + 'static {}
12
13 impl<T> TestInteger for T where T: Roots + PrimInt + Debug + AsPrimitive<f64> + 'static {}
14
15 /// Check that each root is correct
16 ///
17 /// If `x` is positive, check `rⁿ ≤ x < (r+1)ⁿ`.
18 /// If `x` is negative, check `(r-1)ⁿ < x ≤ rⁿ`.
check<T>(v: &[T], n: u32) where T: TestInteger,19 fn check<T>(v: &[T], n: u32)
20 where
21 T: TestInteger,
22 {
23 for i in v {
24 let rt = i.nth_root(n);
25 // println!("nth_root({:?}, {}) = {:?}", i, n, rt);
26 if n == 2 {
27 assert_eq!(rt, i.sqrt());
28 } else if n == 3 {
29 assert_eq!(rt, i.cbrt());
30 }
31 if *i >= T::zero() {
32 let rt1 = rt + T::one();
33 assert!(rt.pow(n) <= *i);
34 if let Some(x) = checked_pow(rt1, n as usize) {
35 assert!(*i < x);
36 }
37 } else {
38 let rt1 = rt - T::one();
39 assert!(rt < T::zero());
40 assert!(*i <= rt.pow(n));
41 if let Some(x) = checked_pow(rt1, n as usize) {
42 assert!(x < *i);
43 }
44 };
45 }
46 }
47
48 /// Get the maximum value that will round down as `f64` (if any),
49 /// and its successor that will round up.
50 ///
51 /// Important because the `std` implementations cast to `f64` to
52 /// get a close approximation of the roots.
mantissa_max<T>() -> Option<(T, T)> where T: TestInteger,53 fn mantissa_max<T>() -> Option<(T, T)>
54 where
55 T: TestInteger,
56 {
57 let bits = if T::min_value().is_zero() {
58 8 * mem::size_of::<T>()
59 } else {
60 8 * mem::size_of::<T>() - 1
61 };
62 if bits > MANTISSA_DIGITS as usize {
63 let rounding_bit = T::one() << (bits - MANTISSA_DIGITS as usize - 1);
64 let x = T::max_value() - rounding_bit;
65
66 let x1 = x + T::one();
67 let x2 = x1 + T::one();
68 assert!(x.as_() < x1.as_());
69 assert_eq!(x1.as_(), x2.as_());
70
71 Some((x, x1))
72 } else {
73 None
74 }
75 }
76
extend<T>(v: &mut Vec<T>, start: T, end: T) where T: TestInteger,77 fn extend<T>(v: &mut Vec<T>, start: T, end: T)
78 where
79 T: TestInteger,
80 {
81 let mut i = start;
82 while i < end {
83 v.push(i);
84 i = i + T::one();
85 }
86 v.push(i);
87 }
88
extend_shl<T>(v: &mut Vec<T>, start: T, end: T, mask: T) where T: TestInteger,89 fn extend_shl<T>(v: &mut Vec<T>, start: T, end: T, mask: T)
90 where
91 T: TestInteger,
92 {
93 let mut i = start;
94 while i != end {
95 v.push(i);
96 i = (i << 1) & mask;
97 }
98 }
99
extend_shr<T>(v: &mut Vec<T>, start: T, end: T) where T: TestInteger,100 fn extend_shr<T>(v: &mut Vec<T>, start: T, end: T)
101 where
102 T: TestInteger,
103 {
104 let mut i = start;
105 while i != end {
106 v.push(i);
107 i = i >> 1;
108 }
109 }
110
pos<T>() -> Vec<T> where T: TestInteger, i8: AsPrimitive<T>,111 fn pos<T>() -> Vec<T>
112 where
113 T: TestInteger,
114 i8: AsPrimitive<T>,
115 {
116 let mut v: Vec<T> = vec![];
117 if mem::size_of::<T>() == 1 {
118 extend(&mut v, T::zero(), T::max_value());
119 } else {
120 extend(&mut v, T::zero(), i8::max_value().as_());
121 extend(
122 &mut v,
123 T::max_value() - i8::max_value().as_(),
124 T::max_value(),
125 );
126 if let Some((i, j)) = mantissa_max::<T>() {
127 v.push(i);
128 v.push(j);
129 }
130 extend_shl(&mut v, T::max_value(), T::zero(), !T::min_value());
131 extend_shr(&mut v, T::max_value(), T::zero());
132 }
133 v
134 }
135
neg<T>() -> Vec<T> where T: TestInteger + Signed, i8: AsPrimitive<T>,136 fn neg<T>() -> Vec<T>
137 where
138 T: TestInteger + Signed,
139 i8: AsPrimitive<T>,
140 {
141 let mut v: Vec<T> = vec![];
142 if mem::size_of::<T>() <= 1 {
143 extend(&mut v, T::min_value(), T::zero());
144 } else {
145 extend(&mut v, i8::min_value().as_(), T::zero());
146 extend(
147 &mut v,
148 T::min_value(),
149 T::min_value() - i8::min_value().as_(),
150 );
151 if let Some((i, j)) = mantissa_max::<T>() {
152 v.push(-i);
153 v.push(-j);
154 }
155 extend_shl(&mut v, -T::one(), T::min_value(), !T::zero());
156 extend_shr(&mut v, T::min_value(), -T::one());
157 }
158 v
159 }
160
/// Generate a pair of test modules: one for the signed type `$I` and one for
/// its unsigned counterpart `$U`.
///
/// Both modules exercise `sqrt`, `cbrt` and `nth_root` over the samples from
/// `pos`; the signed module also covers negative inputs via `neg`.
macro_rules! test_roots {
    ($I:ident, $U:ident) => {
        mod $I {
            use check;
            use neg;
            use num_integer::Roots;
            use pos;
            use std::mem;

            #[test]
            #[should_panic]
            fn zeroth_root() {
                (123 as $I).nth_root(0);
            }

            #[test]
            fn sqrt() {
                check(&pos::<$I>(), 2);
            }

            #[test]
            #[should_panic]
            fn sqrt_neg() {
                (-123 as $I).sqrt();
            }

            #[test]
            fn cbrt() {
                check(&pos::<$I>(), 3);
            }

            #[test]
            fn cbrt_neg() {
                check(&neg::<$I>(), 3);
            }

            #[test]
            fn nth_root() {
                // Signed types have one fewer value bit than their width.
                let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                let pos = pos::<$I>();
                for n in 4..bits {
                    check(&pos, n);
                }
            }

            #[test]
            fn nth_root_neg() {
                let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                let neg = neg::<$I>();
                // Negative inputs are only checked with odd exponents 5, 7, 9, ...
                // (`sqrt_neg` shows an even root of a negative panics).
                for n in 2..bits / 2 {
                    check(&neg, 2 * n + 1);
                }
            }

            #[test]
            fn bit_size() {
                let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                assert_eq!($I::max_value().nth_root(bits - 1), 2);
                assert_eq!($I::max_value().nth_root(bits), 1);
                assert_eq!($I::min_value().nth_root(bits), -2);
                assert_eq!(($I::min_value() + 1).nth_root(bits), -1);
            }
        }

        mod $U {
            use check;
            use num_integer::Roots;
            use pos;
            use std::mem;

            #[test]
            #[should_panic]
            fn zeroth_root() {
                (123 as $U).nth_root(0);
            }

            #[test]
            fn sqrt() {
                check(&pos::<$U>(), 2);
            }

            #[test]
            fn cbrt() {
                check(&pos::<$U>(), 3);
            }

            #[test]
            fn nth_root() {
                // This module must exercise the unsigned type itself, and an
                // unsigned type uses every bit of its width for the value.
                let bits = 8 * mem::size_of::<$U>() as u32;
                let pos = pos::<$U>();
                for n in 4..bits {
                    check(&pos, n);
                }
            }

            #[test]
            fn bit_size() {
                let bits = 8 * mem::size_of::<$U>() as u32;
                assert_eq!($U::max_value().nth_root(bits - 1), 2);
                assert_eq!($U::max_value().nth_root(bits), 1);
            }
        }
    };
}
265
266 test_roots!(i8, u8);
267 test_roots!(i16, u16);
268 test_roots!(i32, u32);
269 test_roots!(i64, u64);
270 #[cfg(has_i128)]
271 test_roots!(i128, u128);
272 test_roots!(isize, usize);
273