1
2 use num_traits::{Zero, One, FromPrimitive, PrimInt, Signed};
3 use std::mem::swap;
4
primitive_root(prime: u64) -> Option<u64>5 pub fn primitive_root(prime: u64) -> Option<u64> {
6 let test_exponents: Vec<u64> = distinct_prime_factors(prime - 1)
7 .iter()
8 .map(|factor| (prime - 1) / factor)
9 .collect();
10 'next: for potential_root in 2..prime {
11 // for each distinct factor, if potential_root^(p-1)/factor mod p is 1, reject it
12 for exp in &test_exponents {
13 if modular_exponent(potential_root, *exp, prime) == 1 {
14 continue 'next;
15 }
16 }
17
18 // if we reach this point, it means this root was not rejected, so return it
19 return Some(potential_root);
20 }
21 None
22 }
23
24 /// computes base^exponent % modulo using the standard exponentiation by squaring algorithm
modular_exponent<T: PrimInt>(mut base: T, mut exponent: T, modulo: T) -> T25 pub fn modular_exponent<T: PrimInt>(mut base: T, mut exponent: T, modulo: T) -> T {
26 let one = T::one();
27
28 let mut result = one;
29
30 while exponent > Zero::zero() {
31 if exponent & one == one {
32 result = result * base % modulo;
33 }
34 exponent = exponent >> One::one();
35 base = (base * base) % modulo;
36 }
37
38 result
39 }
40
multiplicative_inverse<T: PrimInt + FromPrimitive>(a: T, n: T) -> T41 pub fn multiplicative_inverse<T: PrimInt + FromPrimitive>(a: T, n: T) -> T {
42 // we're going to use a modified version extended euclidean algorithm
43 // we only need half the output
44
45 let mut t = Zero::zero();
46 let mut t_new = One::one();
47
48 let mut r = n;
49 let mut r_new = a;
50
51 while r_new > Zero::zero() {
52 let quotient = r / r_new;
53
54 r = r - quotient * r_new;
55 swap(&mut r, &mut r_new);
56
57 // t might go negative here, so we have to do a checked subtract
58 // if it underflows, wrap it around to the other end of the modulo
59 // IE, 3 - 4 mod 5 = -1 mod 5 = 4
60 let t_subtract = quotient * t_new;
61 t = if t_subtract < t {
62 t - t_subtract
63 } else {
64 n - (t_subtract - t) % n
65 };
66 swap(&mut t, &mut t_new);
67 }
68
69 t
70 }
71
extended_euclidean_algorithm<T: PrimInt + Signed + FromPrimitive>(a: T, b: T) -> (T, T, T)72 pub fn extended_euclidean_algorithm<T: PrimInt + Signed + FromPrimitive>(a: T,
73 b: T)
74 -> (T, T, T) {
75 let mut s = Zero::zero();
76 let mut s_old = One::one();
77
78 let mut t = One::one();
79 let mut t_old = Zero::zero();
80
81 let mut r = b;
82 let mut r_old = a;
83
84 while r > Zero::zero() {
85 let quotient = r_old / r;
86
87 r_old = r_old - quotient * r;
88 swap(&mut r_old, &mut r);
89
90 s_old = s_old - quotient * s;
91 swap(&mut s_old, &mut s);
92
93 t_old = t_old - quotient * t;
94 swap(&mut t_old, &mut t);
95 }
96
97 (r_old, s_old, t_old)
98 }
99
/// Returns the distinct prime factors of `n` in ascending order, each listed once.
///
/// For example, `distinct_prime_factors(162)` (= 2 * 3^4) returns `[2, 3]`.
/// Inputs 0 and 1 have no prime factors and yield an empty vector.
pub fn distinct_prime_factors(mut n: u64) -> Vec<u64> {
    let mut result = Vec::new();
    if n == 0 {
        // Every prime divides 0; the halving loop below would never terminate.
        return result;
    }

    // Handle 2 separately so the trial-division loop can step over odd numbers only.
    if n % 2 == 0 {
        while n % 2 == 0 {
            n /= 2;
        }
        result.push(2);
    }

    // Trial division while divisor^2 <= n. `divisor <= n / divisor` is the exact,
    // overflow-free form of that test. An `f32` square root loses precision above 2^24
    // and can end the search too early, misreporting a composite remainder as prime.
    let mut divisor = 3;
    while divisor <= n / divisor {
        if n % divisor == 0 {
            // Strip every copy of this factor so it is reported only once.
            while n % divisor == 0 {
                n /= divisor;
            }
            result.push(divisor);
        }
        divisor += 2;
    }

    // Anything left above 1 has no factor up to its square root, so it is prime.
    if n > 1 {
        result.push(n);
    }

    result
}
137
/// Factors an integer into its prime factors, with multiplicity, in ascending order.
///
/// For example, `prime_factors(12)` returns `[2, 2, 3]`.
/// Inputs 0 and 1 have no prime factors and yield an empty vector.
pub fn prime_factors(mut n: usize) -> Vec<usize> {
    let mut result = Vec::new();
    if n == 0 {
        // Every prime divides 0; the halving loop below would never terminate.
        return result;
    }

    while n % 2 == 0 {
        n /= 2;
        result.push(2);
    }

    // Trial division over odd candidates while divisor^2 <= n. The check shrinks
    // automatically as factors are removed from `n`, and unlike an `f32` square root it
    // is exact for every `usize` and cannot overflow.
    let mut divisor = 3;
    while divisor <= n / divisor {
        while n % divisor == 0 {
            n /= divisor;
            result.push(divisor);
        }
        divisor += 2;
    }

    // Anything left above 1 has no factor up to its square root, so it is prime.
    if n > 1 {
        result.push(n);
    }

    result
}
167
#[cfg(test)]
mod unit_tests {
    use super::*;

    #[test]
    fn test_modular_exponent() {
        // make sure to test something that would overflow under ordinary circumstances
        // ie 3 ^ 416788 mod 47
        let test_list = vec![
            ((2, 8, 300), 256),
            ((2, 9, 300), 212),
            ((1, 9, 300), 1),
            ((3, 416788, 47), 8),
            // exponent 0 yields 1 for any modulus above 1
            ((5, 0, 7), 1),
            // a base larger than the modulus: 10^3 = 1000 = 142 * 7 + 6
            ((10, 3, 7), 6),
        ];

        for (input, expected) in test_list {
            let (base, exponent, modulo) = input;

            let result = modular_exponent(base, exponent, modulo);

            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_multiplicative_inverse() {
        let prime_list = vec![3, 5, 7, 11, 13, 17, 19, 23, 29];

        for modulo in prime_list {
            for i in 2..modulo {
                let inverse = multiplicative_inverse(i, modulo);

                assert_eq!(i * inverse % modulo, 1);
            }
        }
    }

    #[test]
    fn test_extended_euclidean() {
        let test_list = vec![
            ((3, 5), (1, 2, -1)),
            ((15, 12), (3, 1, -1)),
            ((16, 21), (1, 4, -3)),
        ];

        for (input, expected) in test_list {
            let (a, b) = input;

            let result = extended_euclidean_algorithm(a, b);
            assert_eq!(expected, result);

            let (gcd, mut a_inverse, mut b_inverse) = result;

            // Bezout's identity must always hold
            assert_eq!(gcd, a * a_inverse + b * b_inverse);

            // sanity check: if gcd=1, then a*a_inverse mod b should equal 1 and vice versa
            if gcd == 1 {
                if a_inverse < 0 {
                    a_inverse += b;
                }
                if b_inverse < 0 {
                    b_inverse += a;
                }

                assert_eq!(1, a * a_inverse % b);
                assert_eq!(1, b * b_inverse % a);
            }
        }
    }

    #[test]
    fn test_primitive_root() {
        let test_list = vec![(3, 2), (7, 3), (11, 2), (13, 2), (47, 5), (7919, 7)];

        for (input, expected) in test_list {
            let root = primitive_root(input).unwrap();

            assert_eq!(root, expected);
        }
    }

    #[test]
    fn test_distinct_prime_factors() {
        let test_list = vec![
            (1, vec![]),
            (46, vec![2, 23]),
            (2, vec![2]),
            (3, vec![3]),
            (162, vec![2, 3]),
            (30030, vec![2, 3, 5, 7, 11, 13]),
        ];

        for (input, expected) in test_list {
            let factors = distinct_prime_factors(input);

            assert_eq!(factors, expected);
        }
    }

    #[test]
    fn test_prime_factors() {
        // unlike distinct_prime_factors, repeated factors are listed once per occurrence
        let test_list = vec![
            (1, vec![]),
            (2, vec![2]),
            (12, vec![2, 2, 3]),
            (97, vec![97]),
            (360, vec![2, 2, 2, 3, 3, 5]),
        ];

        for (input, expected) in test_list {
            let factors = prime_factors(input);

            assert_eq!(factors, expected);
        }
    }
}
263