1 
2 use num_traits::{Zero, One, FromPrimitive, PrimInt, Signed};
3 use std::mem::swap;
4 
primitive_root(prime: u64) -> Option<u64>5 pub fn primitive_root(prime: u64) -> Option<u64> {
6     let test_exponents: Vec<u64> = distinct_prime_factors(prime - 1)
7         .iter()
8         .map(|factor| (prime - 1) / factor)
9         .collect();
10     'next: for potential_root in 2..prime {
11         // for each distinct factor, if potential_root^(p-1)/factor mod p is 1, reject it
12         for exp in &test_exponents {
13             if modular_exponent(potential_root, *exp, prime) == 1 {
14                 continue 'next;
15             }
16         }
17 
18         // if we reach this point, it means this root was not rejected, so return it
19         return Some(potential_root);
20     }
21     None
22 }
23 
24 /// computes base^exponent % modulo using the standard exponentiation by squaring algorithm
modular_exponent<T: PrimInt>(mut base: T, mut exponent: T, modulo: T) -> T25 pub fn modular_exponent<T: PrimInt>(mut base: T, mut exponent: T, modulo: T) -> T {
26     let one = T::one();
27 
28     let mut result = one;
29 
30     while exponent > Zero::zero() {
31         if exponent & one == one {
32             result = result * base % modulo;
33         }
34         exponent = exponent >> One::one();
35         base = (base * base) % modulo;
36     }
37 
38     result
39 }
40 
multiplicative_inverse<T: PrimInt + FromPrimitive>(a: T, n: T) -> T41 pub fn multiplicative_inverse<T: PrimInt + FromPrimitive>(a: T, n: T) -> T {
42     // we're going to use a modified version extended euclidean algorithm
43     // we only need half the output
44 
45     let mut t = Zero::zero();
46     let mut t_new = One::one();
47 
48     let mut r = n;
49     let mut r_new = a;
50 
51     while r_new > Zero::zero() {
52         let quotient = r / r_new;
53 
54         r = r - quotient * r_new;
55         swap(&mut r, &mut r_new);
56 
57         // t might go negative here, so we have to do a checked subtract
58         // if it underflows, wrap it around to the other end of the modulo
59         // IE, 3 - 4 mod 5  =  -1 mod 5  =  4
60         let t_subtract = quotient * t_new;
61         t = if t_subtract < t {
62             t - t_subtract
63         } else {
64             n - (t_subtract - t) % n
65         };
66         swap(&mut t, &mut t_new);
67     }
68 
69     t
70 }
71 
extended_euclidean_algorithm<T: PrimInt + Signed + FromPrimitive>(a: T, b: T) -> (T, T, T)72 pub fn extended_euclidean_algorithm<T: PrimInt + Signed + FromPrimitive>(a: T,
73                                                                                    b: T)
74                                                                                    -> (T, T, T) {
75     let mut s = Zero::zero();
76     let mut s_old = One::one();
77 
78     let mut t = One::one();
79     let mut t_old = Zero::zero();
80 
81     let mut r = b;
82     let mut r_old = a;
83 
84     while r > Zero::zero() {
85         let quotient = r_old / r;
86 
87         r_old = r_old - quotient * r;
88         swap(&mut r_old, &mut r);
89 
90         s_old = s_old - quotient * s;
91         swap(&mut s_old, &mut s);
92 
93         t_old = t_old - quotient * t;
94         swap(&mut t_old, &mut t);
95     }
96 
97     (r_old, s_old, t_old)
98 }
99 
/// Returns the distinct prime factors of `n` in ascending order, each listed once
/// (duplicate factors are omitted, e.g. `162 = 2 * 3^4` yields `[2, 3]`).
///
/// `0` and `1` have no prime factorization to report and yield an empty vector.
pub fn distinct_prime_factors(mut n: u64) -> Vec<u64> {
    let mut result = Vec::new();

    // every integer divides zero, so dividing factors out of it would never terminate
    if n == 0 {
        return result;
    }

    // handle 2 separately so the main loop can step through odd divisors only
    if n % 2 == 0 {
        while n % 2 == 0 {
            n /= 2;
        }
        result.push(2);
    }

    // Trial division by odd candidates. `divisor <= n / divisor` is the exact integer form of
    // `divisor <= sqrt(n)`. Unlike a float square root it does not round for large `n`, and
    // unlike `divisor * divisor <= n` it cannot overflow. It is re-evaluated every iteration,
    // so the bound shrinks as factors are divided out of `n`.
    let mut divisor = 3;
    while divisor <= n / divisor {
        if n % divisor == 0 {
            // remove every copy of this factor, but record it only once
            while n % divisor == 0 {
                n /= divisor;
            }
            result.push(divisor);
        }
        divisor += 2;
    }

    // whatever remains has no divisor <= its square root, so it is itself prime
    if n > 1 {
        result.push(n);
    }

    result
}
137 
/// Factors an integer into its prime factors.
///
/// Returns the prime factors of `n` in ascending order, each prime repeated as many times as
/// it divides `n` (e.g. `360` yields `[2, 2, 2, 3, 3, 5]`). The product of the result is `n`
/// for every `n >= 1`. `0` and `1` yield an empty vector.
pub fn prime_factors(mut n: usize) -> Vec<usize> {
    let mut result = Vec::new();

    // every integer divides zero: without this guard the loop below would push 2 forever
    if n == 0 {
        return result;
    }

    while n % 2 == 0 {
        n /= 2;
        result.push(2);
    }

    // Trial division by odd candidates. `divisor <= n / divisor` is the exact, overflow-free
    // integer form of `divisor <= sqrt(n)`, and shrinks automatically as `n` is divided down.
    let mut divisor = 3;
    while divisor <= n / divisor {
        while n % divisor == 0 {
            n /= divisor;
            result.push(divisor);
        }
        divisor += 2;
    }

    // whatever remains has no divisor <= its square root, so it is itself prime
    if n > 1 {
        result.push(n);
    }

    result
}
167 
#[cfg(test)]
mod unit_tests {
    use super::*;

    #[test]
    fn test_modular_exponent() {
        // make sure to test something that would overflow under ordinary circumstances,
        // i.e. 3 ^ 416788 mod 47
        let test_list = vec![
            ((2, 8, 300), 256),
            ((2, 9, 300), 212),
            ((1, 9, 300), 1),
            ((3, 416788, 47), 8),
        ];

        for ((base, exponent, modulo), expected) in test_list {
            let result = modular_exponent(base, exponent, modulo);

            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_multiplicative_inverse() {
        let prime_list = vec![3, 5, 7, 11, 13, 17, 19, 23, 29];

        for modulo in prime_list {
            for i in 2..modulo {
                let inverse = multiplicative_inverse(i, modulo);

                assert_eq!(i * inverse % modulo, 1);
            }
        }
    }

    #[test]
    fn test_extended_euclidean() {
        let test_list = vec![
            ((3, 5), (1, 2, -1)),
            ((15, 12), (3, 1, -1)),
            ((16, 21), (1, 4, -3)),
        ];

        for ((a, b), expected) in test_list {
            let result = extended_euclidean_algorithm(a, b);
            assert_eq!(expected, result);

            let (gcd, mut a_inverse, mut b_inverse) = result;

            // sanity check: if gcd == 1, then a * a_inverse mod b should equal 1 and vice versa
            if gcd == 1 {
                if a_inverse < 0 {
                    a_inverse += b;
                }
                if b_inverse < 0 {
                    b_inverse += a;
                }

                assert_eq!(1, a * a_inverse % b);
                assert_eq!(1, b * b_inverse % a);
            }
        }
    }

    #[test]
    fn test_primitive_root() {
        let test_list = vec![(3, 2), (7, 3), (11, 2), (13, 2), (47, 5), (7919, 7)];

        for (input, expected) in test_list {
            let root = primitive_root(input).unwrap();

            assert_eq!(root, expected);
        }
    }

    #[test]
    fn test_distinct_prime_factors() {
        let test_list: Vec<(u64, Vec<u64>)> = vec![
            (1, vec![]),
            (46, vec![2, 23]),
            (2, vec![2]),
            (3, vec![3]),
            (162, vec![2, 3]),
            (360, vec![2, 3, 5]),
        ];

        for (input, expected) in test_list {
            let factors = distinct_prime_factors(input);

            assert_eq!(factors, expected);
        }
    }

    #[test]
    fn test_prime_factors() {
        // unlike distinct_prime_factors, repeated factors are kept
        let test_list: Vec<(usize, Vec<usize>)> = vec![
            (1, vec![]),
            (2, vec![2]),
            (46, vec![2, 23]),
            (97, vec![97]),
            (162, vec![2, 3, 3, 3, 3]),
            (360, vec![2, 2, 2, 3, 3, 5]),
        ];

        for (input, expected) in test_list {
            let factors = prime_factors(input);

            assert_eq!(factors, expected);
        }
    }
}
263