1 // Copyright 2015-2016 Brian Smith.
2 //
3 // Permission to use, copy, modify, and/or distribute this software for any
4 // purpose with or without fee is hereby granted, provided that the above
5 // copyright notice and this permission notice appear in all copies.
6 //
7 // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
8 // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
10 // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 
15 //! Multi-precision integers.
16 //!
17 //! # Modular Arithmetic.
18 //!
19 //! Modular arithmetic is done in finite commutative rings ℤ/mℤ for some
20 //! modulus *m*. We work in finite commutative rings instead of finite fields
21 //! because the RSA public modulus *n* is not prime, which means ℤ/nℤ contains
22 //! nonzero elements that have no multiplicative inverse, so ℤ/nℤ is not a
23 //! finite field.
24 //!
25 //! In some calculations we need to deal with multiple rings at once. For
26 //! example, RSA private key operations operate in the rings ℤ/nℤ, ℤ/pℤ, and
27 //! ℤ/qℤ. Types and functions dealing with such rings are all parameterized
28 //! over a type `M` to ensure that we don't wrongly mix up the math, e.g. by
29 //! multiplying an element of ℤ/pℤ by an element of ℤ/qℤ modulo q. This follows
30 //! the "unit" pattern described in [Static checking of units in Servo].
31 //!
//! `Elem` also uses the static unit checking pattern, via its `E` parameter,
//! to statically track the Montgomery factors that need to be canceled out of
//! each value.
35 //!
36 //! [Static checking of units in Servo]:
37 //!     https://blog.mozilla.org/research/2014/06/23/static-checking-of-units-in-servo/
38 
39 use crate::{
40     arithmetic::montgomery::*,
41     bits, bssl, c, error,
42     limb::{self, Limb, LimbMask, LIMB_BITS, LIMB_BYTES},
43 };
44 use alloc::{borrow::ToOwned as _, boxed::Box, vec, vec::Vec};
45 use core::{
46     marker::PhantomData,
47     ops::{Deref, DerefMut},
48 };
49 use untrusted;
50 
/// Marker for modulus types `M` whose modulus is prime.
///
/// Implementing this trait requires `unsafe impl`: the implementor asserts a
/// property (primality) that the compiler cannot check. In this file it gates
/// `PrivateExponent::for_flt`, which builds the exponent `p - 2` (presumably
/// for Fermat's-little-theorem inversion, hence `flt` — confirm at call sites).
pub unsafe trait Prime {}
52 
/// The number of limbs used to store values associated with modulus `M`.
///
/// The width is tagged with the modulus type (the "unit" pattern described in
/// the module docs) so a width derived from one modulus can't silently be used
/// to size values belonging to another.
struct Width<M> {
    /// Number of `Limb`s.
    num_limbs: usize,

    /// The modulus *m* that the width originated from.
    m: PhantomData<M>,
}
59 
/// A heap-allocated, fixed-length run of limbs tied to modulus `M`, least
/// significant limb first.
///
/// All `BoxedLimbs<M>` are stored in the same number of limbs.
struct BoxedLimbs<M> {
    limbs: Box<[Limb]>,

    /// The modulus *m* that determines the size of `limbs`.
    m: PhantomData<M>,
}
67 
68 impl<M> Deref for BoxedLimbs<M> {
69     type Target = [Limb];
70     #[inline]
deref(&self) -> &Self::Target71     fn deref(&self) -> &Self::Target {
72         &self.limbs
73     }
74 }
75 
76 impl<M> DerefMut for BoxedLimbs<M> {
77     #[inline]
deref_mut(&mut self) -> &mut Self::Target78     fn deref_mut(&mut self) -> &mut Self::Target {
79         &mut self.limbs
80     }
81 }
82 
83 // TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
84 // is resolved or restrict `M: Clone`.
85 impl<M> Clone for BoxedLimbs<M> {
clone(&self) -> Self86     fn clone(&self) -> Self {
87         Self {
88             limbs: self.limbs.clone(),
89             m: self.m.clone(),
90         }
91     }
92 }
93 
94 impl<M> BoxedLimbs<M> {
positive_minimal_width_from_be_bytes( input: untrusted::Input, ) -> Result<Self, error::KeyRejected>95     fn positive_minimal_width_from_be_bytes(
96         input: untrusted::Input,
97     ) -> Result<Self, error::KeyRejected> {
98         // Reject leading zeros. Also reject the value zero ([0]) because zero
99         // isn't positive.
100         if untrusted::Reader::new(input).peek(0) {
101             return Err(error::KeyRejected::invalid_encoding());
102         }
103         let num_limbs = (input.len() + LIMB_BYTES - 1) / LIMB_BYTES;
104         let mut r = Self::zero(Width {
105             num_limbs,
106             m: PhantomData,
107         });
108         limb::parse_big_endian_and_pad_consttime(input, &mut r)
109             .map_err(|error::Unspecified| error::KeyRejected::unexpected_error())?;
110         Ok(r)
111     }
112 
minimal_width_from_unpadded(limbs: &[Limb]) -> Self113     fn minimal_width_from_unpadded(limbs: &[Limb]) -> Self {
114         debug_assert_ne!(limbs.last(), Some(&0));
115         Self {
116             limbs: limbs.to_owned().into_boxed_slice(),
117             m: PhantomData,
118         }
119     }
120 
from_be_bytes_padded_less_than( input: untrusted::Input, m: &Modulus<M>, ) -> Result<Self, error::Unspecified>121     fn from_be_bytes_padded_less_than(
122         input: untrusted::Input,
123         m: &Modulus<M>,
124     ) -> Result<Self, error::Unspecified> {
125         let mut r = Self::zero(m.width());
126         limb::parse_big_endian_and_pad_consttime(input, &mut r)?;
127         if limb::limbs_less_than_limbs_consttime(&r, &m.limbs) != LimbMask::True {
128             return Err(error::Unspecified);
129         }
130         Ok(r)
131     }
132 
133     #[inline]
is_zero(&self) -> bool134     fn is_zero(&self) -> bool {
135         limb::limbs_are_zero_constant_time(&self.limbs) == LimbMask::True
136     }
137 
zero(width: Width<M>) -> Self138     fn zero(width: Width<M>) -> Self {
139         Self {
140             limbs: vec![0; width.num_limbs].to_owned().into_boxed_slice(),
141             m: PhantomData,
142         }
143     }
144 
width(&self) -> Width<M>145     fn width(&self) -> Width<M> {
146         Width {
147             num_limbs: self.limbs.len(),
148             m: PhantomData,
149         }
150     }
151 }
152 
/// A modulus *s* that is smaller than another modulus *l* so every element of
/// ℤ/sℤ is also an element of ℤ/lℤ.
///
/// Implementing this trait requires `unsafe impl` because the compiler cannot
/// check the relationship between the two moduli. `elem_widen` and
/// `Modulus::to_elem` rely on it.
pub unsafe trait SmallerModulus<L> {}
156 
/// A modulus *s* where s < l < 2*s for the given larger modulus *l*. This is
/// the precondition for reduction by conditional subtraction,
/// `elem_reduce_once()`.
///
/// Required by `elem_reduced_once`.
pub unsafe trait SlightlySmallerModulus<L>: SmallerModulus<L> {}
161 
/// A modulus *s* where √l <= s < l for the given larger modulus *l*. This is
/// the precondition for the more general Montgomery reduction from ℤ/lℤ to
/// ℤ/sℤ.
///
/// Required by `elem_reduced`.
pub unsafe trait NotMuchSmallerModulus<L>: SmallerModulus<L> {}
166 
/// Marker for modulus types whose value is presumably not secret (inferred
/// from the name; confirm against implementors). `Modulus<M>` implements
/// `Debug` only when `M: PublicModulus`.
pub unsafe trait PublicModulus {}
168 
/// The x86 implementation of `GFp_bn_mul_mont`, at least, requires at least 4
/// limbs. For a long time we have required 4 limbs for all targets, though
/// this may be unnecessary. TODO: Replace this with
/// `n.len() < 256 / LIMB_BITS` so that 32-bit and 64-bit platforms behave the
/// same.
///
/// `Modulus::from_boxed_limbs` rejects moduli with fewer limbs than this.
pub const MODULUS_MIN_LIMBS: usize = 4;
175 
/// The maximum number of limbs in a modulus: enough limbs for 8192 bits.
///
/// `Modulus::from_boxed_limbs` rejects longer moduli as too large. This
/// constant also sizes the fixed-length local scratch arrays in this module.
pub const MODULUS_MAX_LIMBS: usize = 8192 / LIMB_BITS;
177 
/// The modulus *m* for a ring ℤ/mℤ, along with the precomputed values needed
/// for efficient Montgomery multiplication modulo *m*. The value must be odd
/// and at least 3; `Modulus::from_boxed_limbs` rejects anything else. (For an
/// odd value, "larger than 1" and "larger than 2" are the same condition.) The
/// larger-than-1 requirement is imposed, at least, by the modular inversion
/// code.
pub struct Modulus<M> {
    limbs: BoxedLimbs<M>, // Also `value >= 3`.

    // n0 * n == -1 (mod r).
    //
    // r == 2**(N0_LIMBS_USED * LIMB_BITS) and LG_LITTLE_R == lg(r). This
    // ensures that we can do integer division by |r| by simply ignoring
    // `N0_LIMBS_USED` limbs. Similarly, we can calculate values modulo `r` by
    // just looking at the lowest `N0_LIMBS_USED` limbs. This is what makes
    // Montgomery multiplication efficient.
    //
    // As shown in Algorithm 1 of "Fast Prime Field Elliptic Curve Cryptography
    // with 256 Bit Primes" by Shay Gueron and Vlad Krasnov, in the loop of a
    // multi-limb Montgomery multiplication of a * b (mod n), given the
    // unreduced product t == a * b, we repeatedly calculate:
    //
    //    t1 := t % r         |t1| is |t|'s lowest limb (see previous paragraph).
    //    t2 := t1*n0*n
    //    t3 := t + t2
    //    t := t3 / r         copy all limbs of |t3| except the lowest to |t|.
    //
    // In the last step, it would only make sense to ignore the lowest limb of
    // |t3| if it were zero. The middle steps ensure that this is the case:
    //
    //                            t3 ==  0 (mod r)
    //                        t + t2 ==  0 (mod r)
    //                   t + t1*n0*n ==  0 (mod r)
    //                       t1*n0*n == -t (mod r)
    //                        t*n0*n == -t (mod r)
    //                          n0*n == -1 (mod r)
    //                            n0 == -1/n (mod r)
    //
    // Thus, in each iteration of the loop, we multiply by the constant factor
    // n0, the negative inverse of n (mod r).
    //
    // TODO(perf): Not all 32-bit platforms actually make use of n0[1]. For the
    // ones that don't, we could use a shorter `R` value and use faster `Limb`
    // calculations instead of double-precision `u64` calculations.
    n0: N0,

    // R**2 (mod m), i.e. the value 1 Montgomery-encoded twice. It is computed
    // by `One::newRR` when the modulus is constructed and is used to convert
    // unencoded values into Montgomery form.
    oneRR: One<M, RR>,
}
224 
225 impl<M: PublicModulus> core::fmt::Debug for Modulus<M> {
fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error>226     fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error> {
227         fmt.debug_struct("Modulus")
228             // TODO: Print modulus value.
229             .finish()
230     }
231 }
232 
233 impl<M> Modulus<M> {
from_be_bytes_with_bit_length( input: untrusted::Input, ) -> Result<(Self, bits::BitLength), error::KeyRejected>234     pub fn from_be_bytes_with_bit_length(
235         input: untrusted::Input,
236     ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
237         let limbs = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?;
238         Self::from_boxed_limbs(limbs)
239     }
240 
from_nonnegative_with_bit_length( n: Nonnegative, ) -> Result<(Self, bits::BitLength), error::KeyRejected>241     pub fn from_nonnegative_with_bit_length(
242         n: Nonnegative,
243     ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
244         let limbs = BoxedLimbs {
245             limbs: n.limbs.into_boxed_slice(),
246             m: PhantomData,
247         };
248         Self::from_boxed_limbs(limbs)
249     }
250 
from_boxed_limbs(n: BoxedLimbs<M>) -> Result<(Self, bits::BitLength), error::KeyRejected>251     fn from_boxed_limbs(n: BoxedLimbs<M>) -> Result<(Self, bits::BitLength), error::KeyRejected> {
252         if n.len() > MODULUS_MAX_LIMBS {
253             return Err(error::KeyRejected::too_large());
254         }
255         if n.len() < MODULUS_MIN_LIMBS {
256             return Err(error::KeyRejected::unexpected_error());
257         }
258         if limb::limbs_are_even_constant_time(&n) != LimbMask::False {
259             return Err(error::KeyRejected::invalid_component());
260         }
261         if limb::limbs_less_than_limb_constant_time(&n, 3) != LimbMask::False {
262             return Err(error::KeyRejected::unexpected_error());
263         }
264 
265         // n_mod_r = n % r. As explained in the documentation for `n0`, this is
266         // done by taking the lowest `N0_LIMBS_USED` limbs of `n`.
267         let n0 = {
268             extern "C" {
269                 fn GFp_bn_neg_inv_mod_r_u64(n: u64) -> u64;
270             }
271 
272             // XXX: u64::from isn't guaranteed to be constant time.
273             let mut n_mod_r: u64 = u64::from(n[0]);
274 
275             if N0_LIMBS_USED == 2 {
276                 // XXX: If we use `<< LIMB_BITS` here then 64-bit builds
277                 // fail to compile because of `deny(exceeding_bitshifts)`.
278                 debug_assert_eq!(LIMB_BITS, 32);
279                 n_mod_r |= u64::from(n[1]) << 32;
280             }
281             N0::from(unsafe { GFp_bn_neg_inv_mod_r_u64(n_mod_r) })
282         };
283 
284         let bits = limb::limbs_minimal_bits(&n.limbs);
285         let oneRR = {
286             let partial = PartialModulus {
287                 limbs: &n.limbs,
288                 n0: n0.clone(),
289                 m: PhantomData,
290             };
291 
292             One::newRR(&partial, bits)
293         };
294 
295         Ok((
296             Self {
297                 limbs: n,
298                 n0,
299                 oneRR,
300             },
301             bits,
302         ))
303     }
304 
305     #[inline]
width(&self) -> Width<M>306     fn width(&self) -> Width<M> {
307         self.limbs.width()
308     }
309 
zero<E>(&self) -> Elem<M, E>310     fn zero<E>(&self) -> Elem<M, E> {
311         Elem {
312             limbs: BoxedLimbs::zero(self.width()),
313             encoding: PhantomData,
314         }
315     }
316 
317     // TODO: Get rid of this
one(&self) -> Elem<M, Unencoded>318     fn one(&self) -> Elem<M, Unencoded> {
319         let mut r = self.zero();
320         r.limbs[0] = 1;
321         r
322     }
323 
oneRR(&self) -> &One<M, RR>324     pub fn oneRR(&self) -> &One<M, RR> {
325         &self.oneRR
326     }
327 
to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded> where M: SmallerModulus<L>,328     pub fn to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded>
329     where
330         M: SmallerModulus<L>,
331     {
332         // TODO: Encode this assertion into the `where` above.
333         assert_eq!(self.width().num_limbs, l.width().num_limbs);
334         let limbs = self.limbs.clone();
335         Elem {
336             limbs: BoxedLimbs {
337                 limbs: limbs.limbs,
338                 m: PhantomData,
339             },
340             encoding: PhantomData,
341         }
342     }
343 
as_partial(&self) -> PartialModulus<M>344     fn as_partial(&self) -> PartialModulus<M> {
345         PartialModulus {
346             limbs: &self.limbs,
347             n0: self.n0.clone(),
348             m: PhantomData,
349         }
350     }
351 }
352 
/// The parts of a `Modulus` needed for Montgomery arithmetic: its limbs
/// (borrowed) and `n0`, but not the precomputed `oneRR`.
///
/// This lets `Modulus::from_boxed_limbs` compute `oneRR` (via `One::newRR`)
/// before the full `Modulus` exists. `Modulus::as_partial` derives one from a
/// complete modulus.
struct PartialModulus<'a, M> {
    limbs: &'a [Limb],
    n0: N0,
    m: PhantomData<M>,
}
358 
359 impl<M> PartialModulus<'_, M> {
360     // TODO: XXX Avoid duplication with `Modulus`.
zero(&self) -> Elem<M, R>361     fn zero(&self) -> Elem<M, R> {
362         let width = Width {
363             num_limbs: self.limbs.len(),
364             m: PhantomData,
365         };
366         Elem {
367             limbs: BoxedLimbs::zero(width),
368             encoding: PhantomData,
369         }
370     }
371 }
372 
/// Elements of ℤ/mℤ for some modulus *m*.
//
// Defaulting `E` to `Unencoded` is a convenience for callers from outside this
// submodule. However, for maximum clarity, we always explicitly use
// `Unencoded` within the `bigint` submodule.
pub struct Elem<M, E = Unencoded> {
    /// The (possibly Montgomery-encoded) value, stored in the width of `M`.
    limbs: BoxedLimbs<M>,

    /// The number of Montgomery factors that need to be canceled out from
    /// `limbs` to get the actual value.
    encoding: PhantomData<E>,
}
385 
386 // TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
387 // is resolved or restrict `M: Clone` and `E: Clone`.
388 impl<M, E> Clone for Elem<M, E> {
clone(&self) -> Self389     fn clone(&self) -> Self {
390         Self {
391             limbs: self.limbs.clone(),
392             encoding: self.encoding.clone(),
393         }
394     }
395 }
396 
397 impl<M, E> Elem<M, E> {
398     #[inline]
is_zero(&self) -> bool399     pub fn is_zero(&self) -> bool {
400         self.limbs.is_zero()
401     }
402 }
403 
404 impl<M, E: ReductionEncoding> Elem<M, E> {
decode_once(self, m: &Modulus<M>) -> Elem<M, <E as ReductionEncoding>::Output>405     fn decode_once(self, m: &Modulus<M>) -> Elem<M, <E as ReductionEncoding>::Output> {
406         // A multiplication isn't required since we're multiplying by the
407         // unencoded value one (1); only a Montgomery reduction is needed.
408         // However the only non-multiplication Montgomery reduction function we
409         // have requires the input to be large, so we avoid using it here.
410         let mut limbs = self.limbs;
411         let num_limbs = m.width().num_limbs;
412         let mut one = [0; MODULUS_MAX_LIMBS];
413         one[0] = 1;
414         let one = &one[..num_limbs]; // assert!(num_limbs <= MODULUS_MAX_LIMBS);
415         limbs_mont_mul(&mut limbs, &one, &m.limbs, &m.n0);
416         Elem {
417             limbs,
418             encoding: PhantomData,
419         }
420     }
421 }
422 
423 impl<M> Elem<M, R> {
424     #[inline]
into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded>425     pub fn into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded> {
426         self.decode_once(m)
427     }
428 }
429 
430 impl<M> Elem<M, Unencoded> {
from_be_bytes_padded( input: untrusted::Input, m: &Modulus<M>, ) -> Result<Self, error::Unspecified>431     pub fn from_be_bytes_padded(
432         input: untrusted::Input,
433         m: &Modulus<M>,
434     ) -> Result<Self, error::Unspecified> {
435         Ok(Elem {
436             limbs: BoxedLimbs::from_be_bytes_padded_less_than(input, m)?,
437             encoding: PhantomData,
438         })
439     }
440 
441     #[inline]
fill_be_bytes(&self, out: &mut [u8])442     pub fn fill_be_bytes(&self, out: &mut [u8]) {
443         // See Falko Strenzke, "Manger's Attack revisited", ICICS 2010.
444         limb::big_endian_from_limbs(&self.limbs, out)
445     }
446 
into_modulus<MM>(self) -> Result<Modulus<MM>, error::KeyRejected>447     pub fn into_modulus<MM>(self) -> Result<Modulus<MM>, error::KeyRejected> {
448         let (m, _bits) =
449             Modulus::from_boxed_limbs(BoxedLimbs::minimal_width_from_unpadded(&self.limbs))?;
450         Ok(m)
451     }
452 
is_one(&self) -> bool453     fn is_one(&self) -> bool {
454         limb::limbs_equal_limb_constant_time(&self.limbs, 1) == LimbMask::True
455     }
456 }
457 
elem_mul<M, AF, BF>( a: &Elem<M, AF>, b: Elem<M, BF>, m: &Modulus<M>, ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output> where (AF, BF): ProductEncoding,458 pub fn elem_mul<M, AF, BF>(
459     a: &Elem<M, AF>,
460     b: Elem<M, BF>,
461     m: &Modulus<M>,
462 ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
463 where
464     (AF, BF): ProductEncoding,
465 {
466     elem_mul_(a, b, &m.as_partial())
467 }
468 
elem_mul_<M, AF, BF>( a: &Elem<M, AF>, mut b: Elem<M, BF>, m: &PartialModulus<M>, ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output> where (AF, BF): ProductEncoding,469 fn elem_mul_<M, AF, BF>(
470     a: &Elem<M, AF>,
471     mut b: Elem<M, BF>,
472     m: &PartialModulus<M>,
473 ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
474 where
475     (AF, BF): ProductEncoding,
476 {
477     limbs_mont_mul(&mut b.limbs, &a.limbs, &m.limbs, &m.n0);
478     Elem {
479         limbs: b.limbs,
480         encoding: PhantomData,
481     }
482 }
483 
elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>)484 fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>) {
485     extern "C" {
486         fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
487     }
488     unsafe {
489         LIMBS_shl_mod(
490             a.limbs.as_mut_ptr(),
491             a.limbs.as_ptr(),
492             m.limbs.as_ptr(),
493             m.limbs.len(),
494         );
495     }
496 }
497 
elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>( a: &Elem<Larger, Unencoded>, m: &Modulus<Smaller>, ) -> Elem<Smaller, Unencoded>498 pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
499     a: &Elem<Larger, Unencoded>,
500     m: &Modulus<Smaller>,
501 ) -> Elem<Smaller, Unencoded> {
502     let mut r = a.limbs.clone();
503     assert!(r.len() <= m.limbs.len());
504     limb::limbs_reduce_once_constant_time(&mut r, &m.limbs);
505     Elem {
506         limbs: BoxedLimbs {
507             limbs: r.limbs,
508             m: PhantomData,
509         },
510         encoding: PhantomData,
511     }
512 }
513 
514 #[inline]
elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>( a: &Elem<Larger, Unencoded>, m: &Modulus<Smaller>, ) -> Result<Elem<Smaller, RInverse>, error::Unspecified>515 pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(
516     a: &Elem<Larger, Unencoded>,
517     m: &Modulus<Smaller>,
518 ) -> Result<Elem<Smaller, RInverse>, error::Unspecified> {
519     extern "C" {
520         fn GFp_bn_from_montgomery_in_place(
521             r: *mut Limb,
522             num_r: c::size_t,
523             a: *mut Limb,
524             num_a: c::size_t,
525             n: *const Limb,
526             num_n: c::size_t,
527             n0: &N0,
528         ) -> bssl::Result;
529     }
530 
531     let mut tmp = [0; MODULUS_MAX_LIMBS];
532     let tmp = &mut tmp[..a.limbs.len()];
533     tmp.copy_from_slice(&a.limbs);
534 
535     let mut r = m.zero();
536     Result::from(unsafe {
537         GFp_bn_from_montgomery_in_place(
538             r.limbs.as_mut_ptr(),
539             r.limbs.len(),
540             tmp.as_mut_ptr(),
541             tmp.len(),
542             m.limbs.as_ptr(),
543             m.limbs.len(),
544             &m.n0,
545         )
546     })?;
547     Ok(r)
548 }
549 
elem_squared<M, E>( mut a: Elem<M, E>, m: &PartialModulus<M>, ) -> Elem<M, <(E, E) as ProductEncoding>::Output> where (E, E): ProductEncoding,550 fn elem_squared<M, E>(
551     mut a: Elem<M, E>,
552     m: &PartialModulus<M>,
553 ) -> Elem<M, <(E, E) as ProductEncoding>::Output>
554 where
555     (E, E): ProductEncoding,
556 {
557     limbs_mont_square(&mut a.limbs, &m.limbs, &m.n0);
558     Elem {
559         limbs: a.limbs,
560         encoding: PhantomData,
561     }
562 }
563 
elem_widen<Larger, Smaller: SmallerModulus<Larger>>( a: Elem<Smaller, Unencoded>, m: &Modulus<Larger>, ) -> Elem<Larger, Unencoded>564 pub fn elem_widen<Larger, Smaller: SmallerModulus<Larger>>(
565     a: Elem<Smaller, Unencoded>,
566     m: &Modulus<Larger>,
567 ) -> Elem<Larger, Unencoded> {
568     let mut r = m.zero();
569     r.limbs[..a.limbs.len()].copy_from_slice(&a.limbs);
570     r
571 }
572 
573 // TODO: Document why this works for all Montgomery factors.
elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E>574 pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
575     extern "C" {
576         // `r` and `a` may alias.
577         fn LIMBS_add_mod(
578             r: *mut Limb,
579             a: *const Limb,
580             b: *const Limb,
581             m: *const Limb,
582             num_limbs: c::size_t,
583         );
584     }
585     unsafe {
586         LIMBS_add_mod(
587             a.limbs.as_mut_ptr(),
588             a.limbs.as_ptr(),
589             b.limbs.as_ptr(),
590             m.limbs.as_ptr(),
591             m.limbs.len(),
592         )
593     }
594     a
595 }
596 
597 // TODO: Document why this works for all Montgomery factors.
elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E>598 pub fn elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
599     extern "C" {
600         // `r` and `a` may alias.
601         fn LIMBS_sub_mod(
602             r: *mut Limb,
603             a: *const Limb,
604             b: *const Limb,
605             m: *const Limb,
606             num_limbs: c::size_t,
607         );
608     }
609     unsafe {
610         LIMBS_sub_mod(
611             a.limbs.as_mut_ptr(),
612             a.limbs.as_ptr(),
613             b.limbs.as_ptr(),
614             m.limbs.as_ptr(),
615             m.limbs.len(),
616         );
617     }
618     a
619 }
620 
// The value 1, Montgomery-encoded some number of times; `E` records how many.
// For example, `One<M, RR>` holds R**2 (mod m) and is built by `One::newRR`.
pub struct One<M, E>(Elem<M, E>);
623 
impl<M> One<M, RR> {
    // Returns RR = R**2 (mod m), where R = 2**r is the smallest power of
    // 2**LIMB_BITS such that R > m.
    //
    // Even though the assembly on some 32-bit platforms works with 64-bit
    // values, using `LIMB_BITS` here, rather than `N0_LIMBS_USED * LIMB_BITS`,
    // is correct because R**2 will still be a multiple of the latter as
    // `N0_LIMBS_USED` is either one or two.
    //
    // Only modular doubling (`elem_mul_by_2`) and Montgomery exponentiation
    // over a `PartialModulus` (`elem_exp_vartime_`) are used, so no
    // precomputed RR is needed to compute RR itself.
    fn newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self {
        let m_bits = m_bits.as_usize_bits();
        // r = `m_bits` rounded up to a whole number of limbs.
        let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;

        // base = 2**(lg m - 1), i.e. only m's most significant bit set. When
        // called from `Modulus::from_boxed_limbs`, m has been checked to be
        // odd and >= 3, so it is not a power of two and base < m.
        let bit = m_bits - 1;
        let mut base = m.zero();
        base.limbs[bit / LIMB_BITS] = 1 << (bit % LIMB_BITS);

        // Double `base` so that base == R == 2**r (mod m). For normal moduli
        // that have the high bit of the highest limb set, this requires one
        // doubling. Unusual moduli require more doublings but we are less
        // concerned about the performance of those.
        //
        // Then double `base` `lg_base` more times so that
        // base == 2**lg_base * R (mod m), i.e. the value `2**lg_base` in
        // Montgomery form (`elem_exp_vartime_()` requires the base to be in
        // Montgomery form). Raising it to the power r / lg_base gives
        // (2**lg_base)**(r / lg_base) == 2**r == R, whose Montgomery form is
        // R * R == RR (mod m).
        //
        // Take advantage of the fact that `elem_mul_by_2` is faster than
        // `elem_squared` by replacing some of the early squarings with shifts.
        // TODO: Benchmark shift vs. squaring performance to determine the
        // optimal value of `lg_base`.
        let lg_base = 2usize; // Shifts vs. squaring trade-off.
        // Must be 2**n for n >= 0; presumably so that `r / lg_base` below is
        // exact, since `r` is a multiple of `LIMB_BITS`.
        debug_assert_eq!(lg_base.count_ones(), 1); // Must be 2**n for n >= 0.
        // bit + shifts == r + lg_base.
        let shifts = r - bit + lg_base;
        let exponent = (r / lg_base) as u64;
        for _ in 0..shifts {
            elem_mul_by_2(&mut base, m)
        }
        let RR = elem_exp_vartime_(base, exponent, m);

        Self(Elem {
            limbs: RR.limbs,
            encoding: PhantomData, // PhantomData<RR>
        })
    }
}
670 
671 impl<M, E> AsRef<Elem<M, E>> for One<M, E> {
as_ref(&self) -> &Elem<M, E>672     fn as_ref(&self) -> &Elem<M, E> {
673         &self.0
674     }
675 }
676 
/// A non-secret odd positive value in the range
/// [3, PUBLIC_EXPONENT_MAX_VALUE].
///
/// `PublicExponent::from_be_bytes` is the constructor that enforces this
/// range; the field is private to this module.
#[derive(Clone, Copy, Debug)]
pub struct PublicExponent(u64);
681 
682 impl PublicExponent {
from_be_bytes( input: untrusted::Input, min_value: u64, ) -> Result<Self, error::KeyRejected>683     pub fn from_be_bytes(
684         input: untrusted::Input,
685         min_value: u64,
686     ) -> Result<Self, error::KeyRejected> {
687         if input.len() > 5 {
688             return Err(error::KeyRejected::too_large());
689         }
690         let value = input.read_all(error::KeyRejected::invalid_encoding(), |input| {
691             // The exponent can't be zero and it can't be prefixed with
692             // zero-valued bytes.
693             if input.peek(0) {
694                 return Err(error::KeyRejected::invalid_encoding());
695             }
696             let mut value = 0u64;
697             loop {
698                 let byte = input
699                     .read_byte()
700                     .map_err(|untrusted::EndOfInput| error::KeyRejected::invalid_encoding())?;
701                 value = (value << 8) | u64::from(byte);
702                 if input.at_end() {
703                     return Ok(value);
704                 }
705             }
706         })?;
707 
708         // Step 2 / Step b. NIST SP800-89 defers to FIPS 186-3, which requires
709         // `e >= 65537`. We enforce this when signing, but are more flexible in
710         // verification, for compatibility. Only small public exponents are
711         // supported.
712         if value & 1 != 1 {
713             return Err(error::KeyRejected::invalid_component());
714         }
715         debug_assert!(min_value & 1 == 1);
716         debug_assert!(min_value <= PUBLIC_EXPONENT_MAX_VALUE);
717         if min_value < 3 {
718             return Err(error::KeyRejected::invalid_component());
719         }
720         if value < min_value {
721             return Err(error::KeyRejected::too_small());
722         }
723         if value > PUBLIC_EXPONENT_MAX_VALUE {
724             return Err(error::KeyRejected::too_large());
725         }
726 
727         Ok(Self(value))
728     }
729 }
730 
// Upper bound on accepted RSA public exponents: the low 33 bits set, i.e.
// 2**33 - 1.
//
// The bound keeps the simple exponentiation-by-squaring in
// `elem_exp_vartime` cheap, which mitigates theoretical resource-exhaustion
// attacks. 33 bits follows the recommendations in [1] and [2]. Windows
// CryptoAPI (at least older versions) doesn't support values larger than
// 32 bits [3], so exponents beyond 32 bits are unlikely to be in common use.
//
// [1] https://www.imperialviolet.org/2012/03/16/rsae.html
// [2] https://www.imperialviolet.org/2012/03/17/rsados.html
// [3] https://msdn.microsoft.com/en-us/library/aa387685(VS.85).aspx
const PUBLIC_EXPONENT_MAX_VALUE: u64 = !0u64 >> (64 - 33);
743 
744 /// Calculates base**exponent (mod m).
745 // TODO: The test coverage needs to be expanded, e.g. test with the largest
746 // accepted exponent and with the most common values of 65537 and 3.
elem_exp_vartime<M>( base: Elem<M, Unencoded>, PublicExponent(exponent): PublicExponent, m: &Modulus<M>, ) -> Elem<M, R>747 pub fn elem_exp_vartime<M>(
748     base: Elem<M, Unencoded>,
749     PublicExponent(exponent): PublicExponent,
750     m: &Modulus<M>,
751 ) -> Elem<M, R> {
752     let base = elem_mul(m.oneRR().as_ref(), base, &m);
753     elem_exp_vartime_(base, exponent, &m.as_partial())
754 }
755 
756 /// Calculates base**exponent (mod m).
elem_exp_vartime_<M>(base: Elem<M, R>, exponent: u64, m: &PartialModulus<M>) -> Elem<M, R>757 fn elem_exp_vartime_<M>(base: Elem<M, R>, exponent: u64, m: &PartialModulus<M>) -> Elem<M, R> {
758     // Use what [Knuth] calls the "S-and-X binary method", i.e. variable-time
759     // square-and-multiply that scans the exponent from the most significant
760     // bit to the least significant bit (left-to-right). Left-to-right requires
761     // less storage compared to right-to-left scanning, at the cost of needing
762     // to compute `exponent.leading_zeros()`, which we assume to be cheap.
763     //
764     // During RSA public key operations the exponent is almost always either 65537
765     // (0b10000000000000001) or 3 (0b11), both of which have a Hamming weight
766     // of 2. During Montgomery setup the exponent is almost always a power of two,
767     // with Hamming weight 1. As explained in [Knuth], exponentiation by squaring
768     // is the most efficient algorithm when the Hamming weight is 2 or less. It
769     // isn't the most efficient for all other, uncommon, exponent values but any
770     // suboptimality is bounded by `PUBLIC_EXPONENT_MAX_VALUE`.
771     //
772     // This implementation is slightly simplified by taking advantage of the
773     // fact that we require the exponent to be a positive integer.
774     //
775     // [Knuth]: The Art of Computer Programming, Volume 2: Seminumerical
776     //          Algorithms (3rd Edition), Section 4.6.3.
777     assert!(exponent >= 1);
778     assert!(exponent <= PUBLIC_EXPONENT_MAX_VALUE);
779     let mut acc = base.clone();
780     let mut bit = 1 << (64 - 1 - exponent.leading_zeros());
781     debug_assert!((exponent & bit) != 0);
782     while bit > 1 {
783         bit >>= 1;
784         acc = elem_squared(acc, m);
785         if (exponent & bit) != 0 {
786             acc = elem_mul_(&base, acc, m);
787         }
788     }
789     acc
790 }
791 
792 // `M` represents the prime modulus for which the exponent is in the interval
793 // [1, `m` - 1).
794 pub struct PrivateExponent<M> {
795     limbs: BoxedLimbs<M>,
796 }
797 
798 impl<M> PrivateExponent<M> {
from_be_bytes_padded( input: untrusted::Input, p: &Modulus<M>, ) -> Result<Self, error::Unspecified>799     pub fn from_be_bytes_padded(
800         input: untrusted::Input,
801         p: &Modulus<M>,
802     ) -> Result<Self, error::Unspecified> {
803         let dP = BoxedLimbs::from_be_bytes_padded_less_than(input, p)?;
804 
805         // Proof that `dP < p - 1`:
806         //
807         // If `dP < p` then either `dP == p - 1` or `dP < p - 1`. Since `p` is
808         // odd, `p - 1` is even. `d` is odd, and an odd number modulo an even
809         // number is odd. Therefore `dP` must be odd. But then it cannot be
810         // `p - 1` and so we know `dP < p - 1`.
811         //
812         // Further we know `dP != 0` because `dP` is not even.
813         if limb::limbs_are_even_constant_time(&dP) != LimbMask::False {
814             return Err(error::Unspecified);
815         }
816 
817         Ok(Self { limbs: dP })
818     }
819 }
820 
821 impl<M: Prime> PrivateExponent<M> {
822     // Returns `p - 2`.
for_flt(p: &Modulus<M>) -> Self823     fn for_flt(p: &Modulus<M>) -> Self {
824         let two = elem_add(p.one(), p.one(), p);
825         let p_minus_2 = elem_sub(p.zero(), &two, p);
826         Self {
827             limbs: p_minus_2.limbs,
828         }
829     }
830 }
831 
/// Calculates `base**exponent (mod m)` using fixed 5-bit windows (portable,
/// non-x86_64 version).
///
/// `base` must be Montgomery (`R`) encoded; the result is returned unencoded.
/// A table holding `base**0 ..= base**31` (Montgomery encoded) is built first.
/// Each exponent window then costs five squarings plus one multiplication by a
/// table entry. Table entries are fetched with `LIMBS_select_512_32`, which is
/// implemented in C and presumably performs a constant-time selection (not
/// visible here).
///
/// This currently always returns `Ok`. A failure reported by
/// `LIMBS_select_512_32` panics via `unwrap` instead.
#[cfg(not(target_arch = "x86_64"))]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs.len();

    // `TABLE_ENTRIES` consecutive entries of `num_limbs` limbs each; entry `i`
    // ends up holding `base**i` in Montgomery form.
    let mut table = vec![0; TABLE_ENTRIES * num_limbs];

    /// Copies table entry `i` into `r` using `LIMBS_select_512_32`.
    fn gather<M>(table: &[Limb], i: Window, r: &mut Elem<M, R>) {
        extern "C" {
            fn LIMBS_select_512_32(
                r: *mut Limb,
                table: *const Limb,
                num_limbs: c::size_t,
                i: Window,
            ) -> bssl::Result;
        }
        Result::from(unsafe {
            LIMBS_select_512_32(r.limbs.as_mut_ptr(), table.as_ptr(), r.limbs.len(), i)
        })
        .unwrap();
    }

    /// Applies one exponent window: `acc = acc**(2**WINDOW_BITS) * table[i]`.
    /// `tmp` is scratch space; it is handed back so it can be reused.
    fn power<M>(
        table: &[Limb],
        i: Window,
        mut acc: Elem<M, R>,
        mut tmp: Elem<M, R>,
        m: &Modulus<M>,
    ) -> (Elem<M, R>, Elem<M, R>) {
        for _ in 0..WINDOW_BITS {
            acc = elem_squared(acc, &m.as_partial());
        }
        gather(table, i, &mut tmp);
        let acc = elem_mul(&tmp, acc, m);
        (acc, tmp)
    }

    /// Borrows entry `i` of a table of `num_limbs`-limb entries.
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    /// Mutably borrows entry `i` of a table of `num_limbs`-limb entries.
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }

    // `tmp` = 1, converted to Montgomery form by multiplying with `oneRR`.
    let tmp = m.one();
    let tmp = elem_mul(m.oneRR().as_ref(), tmp, m);

    // table[0] = base**0, table[1] = base**1.
    entry_mut(&mut table, 0, num_limbs).copy_from_slice(&tmp.limbs);
    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);

    // Compute the remaining entries from earlier ones:
    // table[i] = table[i / 2]**2 for even i, and table[i - 1] * table[1] for odd i.
    for i in 2..TABLE_ENTRIES {
        let (src1, src2) = if i % 2 == 0 {
            (i / 2, i / 2)
        } else {
            (i - 1, 1)
        };
        // Split so that the sources (earlier entries) and the destination
        // (entry `i`, the start of `rest`) can be borrowed simultaneously.
        let (previous, rest) = table.split_at_mut(num_limbs * i);
        let src1 = entry(previous, src1, num_limbs);
        let src2 = entry(previous, src2, num_limbs);
        let dst = entry_mut(rest, 0, num_limbs);
        limbs_mont_product(dst, src1, src2, &m.limbs, &m.n0);
    }

    // The first window selects the starting accumulator directly from the
    // table. `base`'s limb storage is reused for it. Every later window goes
    // through `power`.
    let (r, _) = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            let mut r = Elem {
                limbs: base.limbs,
                encoding: PhantomData,
            };
            gather(&table, initial_window, &mut r);
            (r, tmp)
        },
        |(acc, tmp), window| power(&table, window, acc, tmp, m),
    );

    let r = r.into_unencoded(m);

    Ok(r)
}
919 
920 /// Uses Fermat's Little Theorem to calculate modular inverse in constant time.
elem_inverse_consttime<M: Prime>( a: Elem<M, R>, m: &Modulus<M>, ) -> Result<Elem<M, Unencoded>, error::Unspecified>921 pub fn elem_inverse_consttime<M: Prime>(
922     a: Elem<M, R>,
923     m: &Modulus<M>,
924 ) -> Result<Elem<M, Unencoded>, error::Unspecified> {
925     elem_exp_consttime(a, &PrivateExponent::for_flt(&m), m)
926 }
927 
/// Calculates `base**exponent (mod m)` using 5-bit windows and the x86_64
/// assembly routines `GFp_bn_scatter5`, `GFp_bn_gather5`,
/// `GFp_bn_mul_mont_gather5`, `GFp_bn_power5` and `GFp_bn_from_montgomery`.
///
/// `base` must be Montgomery (`R`) encoded; the result is returned unencoded.
/// Returns an error only if `GFp_bn_from_montgomery` reports failure.
#[cfg(target_arch = "x86_64")]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    // The x86_64 assembly was written under the assumption that the input data
    // is aligned to `MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH` bytes, which was/is
    // 64 in OpenSSL. Similarly, OpenSSL uses the x86_64 assembly functions by
    // giving it only inputs `tmp`, `am`, and `np` that immediately follow the
    // table. The code seems to "work" even when the inputs aren't exactly
    // like that but the side channel defenses might not be as effective. All
    // the awkwardness here stems from trying to use the assembly code like
    // OpenSSL does.

    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs.len();

    const ALIGNMENT: usize = 64;
    assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
    // One buffer holds the `TABLE_ENTRIES` table entries followed by three
    // state entries (ACC, BASE and M, below), plus slack for alignment. The
    // start is advanced by at most `ALIGNMENT / LIMB_BYTES` limbs, so
    // `ALIGNMENT` limbs of slack is more than enough.
    let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
    let (table, state) = {
        // Advance to the next `ALIGNMENT`-byte boundary. An already-aligned
        // buffer is advanced by a full `ALIGNMENT` bytes, which stays aligned.
        let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
        let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
        assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
        // `table` holds the powers of `base`. `state` holds the entries that
        // immediately follow it, matching OpenSSL's layout.
        table.split_at_mut(TABLE_ENTRIES * num_limbs)
    };

    /// Borrows entry `i` of a buffer of `num_limbs`-limb entries.
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    /// Mutably borrows entry `i` of a buffer of `num_limbs`-limb entries.
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }

    // Entry indices within `state`.
    const ACC: usize = 0; // `tmp` in OpenSSL
    const BASE: usize = ACC + 1; // `am` in OpenSSL
    const M: usize = BASE + 1; // `np` in OpenSSL

    entry_mut(state, BASE, num_limbs).copy_from_slice(&base.limbs);
    entry_mut(state, M, num_limbs).copy_from_slice(&m.limbs);

    /// Stores the ACC entry of `state` into table slot `i` (`GFp_bn_scatter5`).
    fn scatter(table: &mut [Limb], state: &[Limb], i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_scatter5(a: *const Limb, a_len: c::size_t, table: *mut Limb, i: Window);
        }
        unsafe {
            GFp_bn_scatter5(
                entry(state, ACC, num_limbs).as_ptr(),
                num_limbs,
                table.as_mut_ptr(),
                i,
            )
        }
    }

    /// Loads table slot `i` into the ACC entry of `state` (`GFp_bn_gather5`).
    fn gather(table: &[Limb], state: &mut [Limb], i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_gather5(r: *mut Limb, a_len: c::size_t, table: *const Limb, i: Window);
        }
        unsafe {
            GFp_bn_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                num_limbs,
                table.as_ptr(),
                i,
            )
        }
    }

    /// ACC = table[i]**2 (Montgomery squaring).
    fn gather_square(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        gather(table, state, i, num_limbs);
        assert_eq!(ACC, 0);
        // Split so that ACC (mutable) and M (shared) can be borrowed at the
        // same time. ACC is the first entry, so M sits at index `M - 1` of
        // `rest`.
        let (acc, rest) = state.split_at_mut(num_limbs);
        let m = entry(rest, M - 1, num_limbs);
        limbs_mont_square(acc, m, n0);
    }

    /// ACC = BASE * table[i] (Montgomery product), via
    /// `GFp_bn_mul_mont_gather5`.
    fn gather_mul_base(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_mul_mont_gather5(
                rp: *mut Limb,
                ap: *const Limb,
                table: *const Limb,
                np: *const Limb,
                n0: &N0,
                num: c::size_t,
                power: Window,
            );
        }
        unsafe {
            GFp_bn_mul_mont_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry(state, BASE, num_limbs).as_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    /// Applies one exponent window to ACC in place via `GFp_bn_power5`.
    /// Presumably ACC = ACC**(2**5) * table[i], like the portable `power`.
    /// The exact behavior is in the assembly; confirm there.
    fn power(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_power5(
                r: *mut Limb,
                a: *const Limb,
                table: *const Limb,
                n: *const Limb,
                n0: &N0,
                num: c::size_t,
                i: Window,
            );
        }
        unsafe {
            // Input and output are both ACC, i.e. the update is in place.
            GFp_bn_power5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    // table[0] = base**0.
    {
        // ACC is still all zeros here because only BASE and M were written.
        // Setting the low limb makes ACC == 1. Montgomery-multiplying by
        // `oneRR` then converts it to Montgomery form.
        let acc = entry_mut(state, ACC, num_limbs);
        acc[0] = 1;
        limbs_mont_mul(acc, &m.oneRR.0.limbs, &m.limbs, &m.n0);
    }
    scatter(table, state, 0, num_limbs);

    // table[1] = base**1.
    entry_mut(state, ACC, num_limbs).copy_from_slice(&base.limbs);
    scatter(table, state, 1, num_limbs);

    // table[i] for i >= 2: even i as table[i / 2]**2, odd i as
    // base * table[i - 1].
    for i in 2..(TABLE_ENTRIES as Window) {
        if i % 2 == 0 {
            // TODO: Optimize this to avoid gathering
            gather_square(table, state, &m.n0, i / 2, num_limbs);
        } else {
            gather_mul_base(table, state, &m.n0, i - 1, num_limbs)
        };
        scatter(table, state, i, num_limbs);
    }

    // The first exponent window loads its table entry into ACC. Each
    // subsequent window is applied with `power`.
    let state = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            gather(table, state, initial_window, num_limbs);
            state
        },
        |state, window| {
            power(table, state, &m.n0, window, num_limbs);
            state
        },
    );

    extern "C" {
        fn GFp_bn_from_montgomery(
            r: *mut Limb,
            a: *const Limb,
            not_used: *const Limb,
            n: *const Limb,
            n0: &N0,
            num: c::size_t,
        ) -> bssl::Result;
    }
    // Convert ACC out of Montgomery form in place (output and input are both
    // ACC). A failure is propagated to the caller as `error::Unspecified`.
    Result::from(unsafe {
        GFp_bn_from_montgomery(
            entry_mut(state, ACC, num_limbs).as_mut_ptr(),
            entry(state, ACC, num_limbs).as_ptr(),
            core::ptr::null(),
            entry(state, M, num_limbs).as_ptr(),
            &m.n0,
            num_limbs,
        )
    })?;
    // Reuse `base`'s limb allocation for the result and copy ACC into it.
    let mut r = Elem {
        limbs: base.limbs,
        encoding: PhantomData,
    };
    r.limbs.copy_from_slice(entry(state, ACC, num_limbs));
    Ok(r)
}
1121 
1122 /// Verified a == b**-1 (mod m), i.e. a**-1 == b (mod m).
verify_inverses_consttime<M>( a: &Elem<M, R>, b: Elem<M, Unencoded>, m: &Modulus<M>, ) -> Result<(), error::Unspecified>1123 pub fn verify_inverses_consttime<M>(
1124     a: &Elem<M, R>,
1125     b: Elem<M, Unencoded>,
1126     m: &Modulus<M>,
1127 ) -> Result<(), error::Unspecified> {
1128     if elem_mul(a, b, m).is_one() {
1129         Ok(())
1130     } else {
1131         Err(error::Unspecified)
1132     }
1133 }
1134 
1135 #[inline]
elem_verify_equal_consttime<M, E>( a: &Elem<M, E>, b: &Elem<M, E>, ) -> Result<(), error::Unspecified>1136 pub fn elem_verify_equal_consttime<M, E>(
1137     a: &Elem<M, E>,
1138     b: &Elem<M, E>,
1139 ) -> Result<(), error::Unspecified> {
1140     if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True {
1141         Ok(())
1142     } else {
1143         Err(error::Unspecified)
1144     }
1145 }
1146 
1147 /// Nonnegative integers.
1148 pub struct Nonnegative {
1149     limbs: Vec<Limb>,
1150 }
1151 
1152 impl Nonnegative {
from_be_bytes_with_bit_length( input: untrusted::Input, ) -> Result<(Self, bits::BitLength), error::Unspecified>1153     pub fn from_be_bytes_with_bit_length(
1154         input: untrusted::Input,
1155     ) -> Result<(Self, bits::BitLength), error::Unspecified> {
1156         let mut limbs = vec![0; (input.len() + LIMB_BYTES - 1) / LIMB_BYTES];
1157         // Rejects empty inputs.
1158         limb::parse_big_endian_and_pad_consttime(input, &mut limbs)?;
1159         while limbs.last() == Some(&0) {
1160             let _ = limbs.pop();
1161         }
1162         let r_bits = limb::limbs_minimal_bits(&limbs);
1163         Ok((Self { limbs }, r_bits))
1164     }
1165 
1166     #[inline]
is_odd(&self) -> bool1167     pub fn is_odd(&self) -> bool {
1168         limb::limbs_are_even_constant_time(&self.limbs) != LimbMask::True
1169     }
1170 
verify_less_than(&self, other: &Self) -> Result<(), error::Unspecified>1171     pub fn verify_less_than(&self, other: &Self) -> Result<(), error::Unspecified> {
1172         if !greater_than(other, self) {
1173             return Err(error::Unspecified);
1174         }
1175         Ok(())
1176     }
1177 
to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified>1178     pub fn to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified> {
1179         self.verify_less_than_modulus(&m)?;
1180         let mut r = m.zero();
1181         r.limbs[0..self.limbs.len()].copy_from_slice(&self.limbs);
1182         Ok(r)
1183     }
1184 
verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified>1185     pub fn verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified> {
1186         if self.limbs.len() > m.limbs.len() {
1187             return Err(error::Unspecified);
1188         }
1189         if self.limbs.len() == m.limbs.len() {
1190             if limb::limbs_less_than_limbs_consttime(&self.limbs, &m.limbs) != LimbMask::True {
1191                 return Err(error::Unspecified);
1192             }
1193         }
1194         return Ok(());
1195     }
1196 }
1197 
1198 // Returns a > b.
greater_than(a: &Nonnegative, b: &Nonnegative) -> bool1199 fn greater_than(a: &Nonnegative, b: &Nonnegative) -> bool {
1200     if a.limbs.len() == b.limbs.len() {
1201         limb::limbs_less_than_limbs_vartime(&b.limbs, &a.limbs)
1202     } else {
1203         a.limbs.len() > b.limbs.len()
1204     }
1205 }
1206 
/// The Montgomery constant `n0` handed to the C/assembly routines by
/// reference (`&N0`).
///
/// It is always two limbs wide so that a 64-bit value fits on both 32-bit and
/// 64-bit targets (see `From<u64> for N0`). `#[repr(transparent)]` gives it
/// the same layout as `[Limb; 2]`. Presumably this matches the C side's
/// two-element `n0` array; confirm against the C prototypes.
#[derive(Clone)]
#[repr(transparent)]
struct N0([Limb; 2]);

/// Number of `N0` limbs that carry the 64-bit value: `64 / LIMB_BITS`, i.e. 1
/// with 64-bit limbs and 2 with 32-bit limbs.
const N0_LIMBS_USED: usize = 64 / LIMB_BITS;
1212 
1213 impl From<u64> for N0 {
1214     #[inline]
from(n0: u64) -> Self1215     fn from(n0: u64) -> Self {
1216         #[cfg(target_pointer_width = "64")]
1217         {
1218             Self([n0, 0])
1219         }
1220 
1221         #[cfg(target_pointer_width = "32")]
1222         {
1223             Self([n0 as Limb, (n0 >> LIMB_BITS) as Limb])
1224         }
1225     }
1226 }
1227 
1228 /// r *= a
limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0)1229 fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0) {
1230     debug_assert_eq!(r.len(), m.len());
1231     debug_assert_eq!(a.len(), m.len());
1232     unsafe {
1233         GFp_bn_mul_mont(
1234             r.as_mut_ptr(),
1235             r.as_ptr(),
1236             a.as_ptr(),
1237             m.as_ptr(),
1238             n0,
1239             r.len(),
1240         )
1241     }
1242 }
1243 
/// r = a * b (Montgomery multiplication modulo `m`).
///
/// All four slices are expected to have the same length. This is checked in
/// debug builds only.
#[cfg(not(target_arch = "x86_64"))]
fn limbs_mont_product(r: &mut [Limb], a: &[Limb], b: &[Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    debug_assert_eq!(b.len(), m.len());
    let num_limbs = r.len();
    let (a_ptr, b_ptr, m_ptr) = (a.as_ptr(), b.as_ptr(), m.as_ptr());
    // SAFETY: every buffer is expected to hold `num_limbs` limbs (debug
    // asserted above).
    unsafe { GFp_bn_mul_mont(r.as_mut_ptr(), a_ptr, b_ptr, m_ptr, n0, num_limbs) }
}
1261 
1262 /// r = r**2
limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0)1263 fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0) {
1264     debug_assert_eq!(r.len(), m.len());
1265     unsafe {
1266         GFp_bn_mul_mont(
1267             r.as_mut_ptr(),
1268             r.as_ptr(),
1269             r.as_ptr(),
1270             m.as_ptr(),
1271             n0,
1272             r.len(),
1273         )
1274     }
1275 }
1276 
extern "C" {
    // Montgomery multiplication of `a` and `b` modulo `n`, writing the result
    // to `r`; presumably `r = a * b * R**-1 (mod n)`. The exact contract lives
    // in the C/assembly implementation. Every buffer holds `num_limbs` limbs,
    // and `n0` is the Montgomery constant for `n`.
    //
    // `r` and/or 'a' and/or 'b' may alias.
    fn GFp_bn_mul_mont(
        r: *mut Limb,
        a: *const Limb,
        b: *const Limb,
        n: *const Limb,
        n0: &N0,
        num_limbs: c::size_t,
    );
}
1288 
#[cfg(test)]
mod tests {
    //! Unit tests driven by the `bigint_*_tests.txt` test-vector files.

    use super::*;
    use crate::test;
    use alloc::format;
    use untrusted;

    // Type-level representation of an arbitrary modulus.
    struct M {}

    unsafe impl PublicModulus for M {}

    #[test]
    fn test_elem_exp_consttime() {
        test::run(
            test_file!("bigint_elem_exp_consttime_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModExp", &m);
                let base = consume_elem(test_case, "A", &m);
                let e = {
                    let bytes = test_case.consume_bytes("E");
                    PrivateExponent::from_be_bytes_padded(untrusted::Input::from(&bytes), &m)
                        .expect("valid exponent")
                };
                let base = into_encoded(base, &m);
                let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // TODO: fn test_elem_exp_vartime() using
    // "src/rsa/bigint_elem_exp_vartime_tests.txt". See that file for details.
    // In the meantime, the function is tested indirectly via the RSA
    // verification and signing tests.
    #[test]
    fn test_elem_mul() {
        test::run(
            test_file!("bigint_elem_mul_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModMul", &m);
                let a = consume_elem(test_case, "A", &m);
                let b = consume_elem(test_case, "B", &m);

                let b = into_encoded(b, &m);
                let a = into_encoded(a, &m);
                let actual_result = elem_mul(&a, b, &m);
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_squared() {
        test::run(
            test_file!("bigint_elem_squared_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModSquare", &m);
                let a = consume_elem(test_case, "A", &m);

                let a = into_encoded(a, &m);
                let actual_result = elem_squared(a, &m.as_partial());
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_reduced() {
        test::run(
            test_file!("bigint_elem_reduced_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct MM {}
                unsafe impl SmallerModulus<MM> for M {}
                unsafe impl NotMuchSmallerModulus<MM> for M {}

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "R", &m);
                let a =
                    consume_elem_unchecked::<MM>(test_case, "A", expected_result.limbs.len() * 2);

                let actual_result = elem_reduced(&a, &m).unwrap();
                let oneRR = m.oneRR();
                let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_reduced_once() {
        test::run(
            test_file!("bigint_elem_reduced_once_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct N {}
                struct QQ {}
                unsafe impl SmallerModulus<N> for QQ {}
                unsafe impl SlightlySmallerModulus<N> for QQ {}

                let qq = consume_modulus::<QQ>(test_case, "QQ");
                let expected_result = consume_elem::<QQ>(test_case, "R", &qq);
                let n = consume_modulus::<N>(test_case, "N");
                let a = consume_elem::<N>(test_case, "A", &n);

                let actual_result = elem_reduced_once(&a, &qq);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_modulus_debug() {
        let (modulus, _) = Modulus::<M>::from_be_bytes_with_bit_length(untrusted::Input::from(
            &vec![0xff; LIMB_BYTES * MODULUS_MIN_LIMBS],
        ))
        .unwrap();
        assert_eq!("Modulus", format!("{:?}", modulus));
    }

    #[test]
    fn test_public_exponent_debug() {
        let exponent =
            PublicExponent::from_be_bytes(untrusted::Input::from(&[0x1, 0x00, 0x01]), 65537)
                .unwrap();
        assert_eq!("PublicExponent(65537)", format!("{:?}", exponent));
    }

    /// Reads test value `name` as an element of ℤ/mℤ (must be less than `m`).
    fn consume_elem<M>(
        test_case: &mut test::TestCase,
        name: &str,
        m: &Modulus<M>,
    ) -> Elem<M, Unencoded> {
        let value = test_case.consume_bytes(name);
        Elem::from_be_bytes_padded(untrusted::Input::from(&value), m).unwrap()
    }

    /// Reads test value `name` into an element `num_limbs` limbs wide,
    /// without checking it against any modulus.
    fn consume_elem_unchecked<M>(
        test_case: &mut test::TestCase,
        name: &str,
        num_limbs: usize,
    ) -> Elem<M, Unencoded> {
        let value = consume_nonnegative(test_case, name);
        let mut limbs = BoxedLimbs::zero(Width {
            num_limbs,
            m: PhantomData,
        });
        limbs[0..value.limbs.len()].copy_from_slice(&value.limbs);
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }

    /// Reads test value `name` as a modulus.
    fn consume_modulus<M>(test_case: &mut test::TestCase, name: &str) -> Modulus<M> {
        let value = test_case.consume_bytes(name);
        let (value, _) =
            Modulus::from_be_bytes_with_bit_length(untrusted::Input::from(&value)).unwrap();
        value
    }

    /// Reads test value `name` as a `Nonnegative`.
    fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative {
        let bytes = test_case.consume_bytes(name);
        let (r, _r_bits) =
            Nonnegative::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        r
    }

    /// Panics unless `a` and `b` have identical limbs.
    fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
        elem_verify_equal_consttime(a, b).unwrap()
    }

    /// Converts an unencoded element to Montgomery (`R`) form.
    fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
        elem_mul(m.oneRR().as_ref(), a, m)
    }
}
1490