1 use super::monty::monty_modpow;
2 use super::BigUint;
3 
4 use crate::big_digit::{self, BigDigit};
5 
6 use num_integer::Integer;
7 use num_traits::{One, Pow, ToPrimitive, Zero};
8 
9 impl<'b> Pow<&'b BigUint> for BigUint {
10     type Output = BigUint;
11 
12     #[inline]
pow(self, exp: &BigUint) -> BigUint13     fn pow(self, exp: &BigUint) -> BigUint {
14         if self.is_one() || exp.is_zero() {
15             BigUint::one()
16         } else if self.is_zero() {
17             BigUint::zero()
18         } else if let Some(exp) = exp.to_u64() {
19             self.pow(exp)
20         } else if let Some(exp) = exp.to_u128() {
21             self.pow(exp)
22         } else {
23             // At this point, `self >= 2` and `exp >= 2¹²⁸`. The smallest possible result given
24             // `2.pow(2¹²⁸)` would require far more memory than 64-bit targets can address!
25             panic!("memory overflow")
26         }
27     }
28 }
29 
30 impl Pow<BigUint> for BigUint {
31     type Output = BigUint;
32 
33     #[inline]
pow(self, exp: BigUint) -> BigUint34     fn pow(self, exp: BigUint) -> BigUint {
35         Pow::pow(self, &exp)
36     }
37 }
38 
39 impl<'a, 'b> Pow<&'b BigUint> for &'a BigUint {
40     type Output = BigUint;
41 
42     #[inline]
pow(self, exp: &BigUint) -> BigUint43     fn pow(self, exp: &BigUint) -> BigUint {
44         if self.is_one() || exp.is_zero() {
45             BigUint::one()
46         } else if self.is_zero() {
47             BigUint::zero()
48         } else {
49             self.clone().pow(exp)
50         }
51     }
52 }
53 
54 impl<'a> Pow<BigUint> for &'a BigUint {
55     type Output = BigUint;
56 
57     #[inline]
pow(self, exp: BigUint) -> BigUint58     fn pow(self, exp: BigUint) -> BigUint {
59         Pow::pow(self, &exp)
60     }
61 }
62 
63 macro_rules! pow_impl {
64     ($T:ty) => {
65         impl Pow<$T> for BigUint {
66             type Output = BigUint;
67 
68             fn pow(self, mut exp: $T) -> BigUint {
69                 if exp == 0 {
70                     return BigUint::one();
71                 }
72                 let mut base = self;
73 
74                 while exp & 1 == 0 {
75                     base = &base * &base;
76                     exp >>= 1;
77                 }
78 
79                 if exp == 1 {
80                     return base;
81                 }
82 
83                 let mut acc = base.clone();
84                 while exp > 1 {
85                     exp >>= 1;
86                     base = &base * &base;
87                     if exp & 1 == 1 {
88                         acc *= &base;
89                     }
90                 }
91                 acc
92             }
93         }
94 
95         impl<'b> Pow<&'b $T> for BigUint {
96             type Output = BigUint;
97 
98             #[inline]
99             fn pow(self, exp: &$T) -> BigUint {
100                 Pow::pow(self, *exp)
101             }
102         }
103 
104         impl<'a> Pow<$T> for &'a BigUint {
105             type Output = BigUint;
106 
107             #[inline]
108             fn pow(self, exp: $T) -> BigUint {
109                 if exp == 0 {
110                     return BigUint::one();
111                 }
112                 Pow::pow(self.clone(), exp)
113             }
114         }
115 
116         impl<'a, 'b> Pow<&'b $T> for &'a BigUint {
117             type Output = BigUint;
118 
119             #[inline]
120             fn pow(self, exp: &$T) -> BigUint {
121                 Pow::pow(self, *exp)
122             }
123         }
124     };
125 }
126 
127 pow_impl!(u8);
128 pow_impl!(u16);
129 pow_impl!(u32);
130 pow_impl!(u64);
131 pow_impl!(usize);
132 pow_impl!(u128);
133 
modpow(x: &BigUint, exponent: &BigUint, modulus: &BigUint) -> BigUint134 pub(super) fn modpow(x: &BigUint, exponent: &BigUint, modulus: &BigUint) -> BigUint {
135     assert!(
136         !modulus.is_zero(),
137         "attempt to calculate with zero modulus!"
138     );
139 
140     if modulus.is_odd() {
141         // For an odd modulus, we can use Montgomery multiplication in base 2^32.
142         monty_modpow(x, exponent, modulus)
143     } else {
144         // Otherwise do basically the same as `num::pow`, but with a modulus.
145         plain_modpow(x, &exponent.data, modulus)
146     }
147 }
148 
plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint149 fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint {
150     assert!(
151         !modulus.is_zero(),
152         "attempt to calculate with zero modulus!"
153     );
154 
155     let i = match exp_data.iter().position(|&r| r != 0) {
156         None => return BigUint::one(),
157         Some(i) => i,
158     };
159 
160     let mut base = base % modulus;
161     for _ in 0..i {
162         for _ in 0..big_digit::BITS {
163             base = &base * &base % modulus;
164         }
165     }
166 
167     let mut r = exp_data[i];
168     let mut b = 0u8;
169     while r.is_even() {
170         base = &base * &base % modulus;
171         r >>= 1;
172         b += 1;
173     }
174 
175     let mut exp_iter = exp_data[i + 1..].iter();
176     if exp_iter.len() == 0 && r.is_one() {
177         return base;
178     }
179 
180     let mut acc = base.clone();
181     r >>= 1;
182     b += 1;
183 
184     {
185         let mut unit = |exp_is_odd| {
186             base = &base * &base % modulus;
187             if exp_is_odd {
188                 acc *= &base;
189                 acc %= modulus;
190             }
191         };
192 
193         if let Some(&last) = exp_iter.next_back() {
194             // consume exp_data[i]
195             for _ in b..big_digit::BITS {
196                 unit(r.is_odd());
197                 r >>= 1;
198             }
199 
200             // consume all other digits before the last
201             for &r in exp_iter {
202                 let mut r = r;
203                 for _ in 0..big_digit::BITS {
204                     unit(r.is_odd());
205                     r >>= 1;
206                 }
207             }
208             r = last;
209         }
210 
211         debug_assert_ne!(r, 0);
212         while !r.is_zero() {
213             unit(r.is_odd());
214             r >>= 1;
215         }
216     }
217     acc
218 }
219 
220 #[test]
test_plain_modpow()221 fn test_plain_modpow() {
222     let two = &BigUint::from(2u32);
223     let modulus = BigUint::from(0x1100u32);
224 
225     let exp = vec![0, 0b1];
226     assert_eq!(
227         two.pow(0b1_00000000_u32) % &modulus,
228         plain_modpow(&two, &exp, &modulus)
229     );
230     let exp = vec![0, 0b10];
231     assert_eq!(
232         two.pow(0b10_00000000_u32) % &modulus,
233         plain_modpow(&two, &exp, &modulus)
234     );
235     let exp = vec![0, 0b110010];
236     assert_eq!(
237         two.pow(0b110010_00000000_u32) % &modulus,
238         plain_modpow(&two, &exp, &modulus)
239     );
240     let exp = vec![0b1, 0b1];
241     assert_eq!(
242         two.pow(0b1_00000001_u32) % &modulus,
243         plain_modpow(&two, &exp, &modulus)
244     );
245     let exp = vec![0b1100, 0, 0b1];
246     assert_eq!(
247         two.pow(0b1_00000000_00001100_u32) % &modulus,
248         plain_modpow(&two, &exp, &modulus)
249     );
250 }
251 
252 #[test]
test_pow_biguint()253 fn test_pow_biguint() {
254     let base = BigUint::from(5u8);
255     let exponent = BigUint::from(3u8);
256 
257     assert_eq!(BigUint::from(125u8), base.pow(exponent));
258 }
259