1 // -*- mode: rust; -*-
2 //
3 // This file is part of curve25519-dalek.
4 // Copyright (c) 2016-2021 isis lovecruft
5 // Copyright (c) 2016-2020 Henry de Valence
6 // See LICENSE for licensing information.
7 //
8 // Authors:
9 // - isis agora lovecruft <isis@patternsinthevoid.net>
10 // - Henry de Valence <hdevalence@hdevalence.ca>
11 
12 // We allow non snake_case names because coordinates in projective space are
13 // traditionally denoted by the capitalisation of their respective
14 // counterparts in affine space.  Yeah, you heard me, rustc, I'm gonna have my
15 // affine and projective cakes and eat both of them too.
16 #![allow(non_snake_case)]
17 
18 //! An implementation of [Ristretto][ristretto_main], which provides a
19 //! prime-order group.
20 //!
21 //! # The Ristretto Group
22 //!
23 //! Ristretto is a modification of Mike Hamburg's Decaf scheme to work
24 //! with cofactor-\\(8\\) curves, such as Curve25519.
25 //!
26 //! The introduction of the Decaf paper, [_Decaf:
27 //! Eliminating cofactors through point
28 //! compression_](https://eprint.iacr.org/2015/673.pdf), notes that while
29 //! most cryptographic systems require a group of prime order, most
30 //! concrete implementations using elliptic curve groups fall short –
31 //! they either provide a group of prime order, but with incomplete or
32 //! variable-time addition formulae (for instance, most Weierstrass
33 //! models), or else they provide a fast and safe implementation of a
34 //! group whose order is not quite a prime \\(q\\), but \\(hq\\) for a
35 //! small cofactor \\(h\\) (for instance, Edwards curves, which have
36 //! cofactor at least \\(4\\)).
37 //!
38 //! This abstraction mismatch is commonly “handled” by pushing the
39 //! complexity upwards, adding ad-hoc protocol modifications.  But
40 //! these modifications require careful analysis and are a recurring
41 //! source of [vulnerabilities][cryptonote] and [design
42 //! complications][ed25519_hkd].
43 //!
44 //! Instead, Decaf (and Ristretto) use a quotient group to implement a
45 //! prime-order group using a non-prime-order curve.  This provides
46 //! the correct abstraction for cryptographic systems, while retaining
47 //! the speed and safety benefits of an Edwards curve.
48 //!
49 //! Decaf is named “after the procedure which divides the effect of
50 //! coffee by \\(4\\)”.  However, Curve25519 has a cofactor of
51 //! \\(8\\).  To eliminate its cofactor, Ristretto restricts further;
52 //! this [additional restriction][ristretto_coffee] gives the
53 //! _Ristretto_ encoding.
54 //!
55 //! More details on why Ristretto is necessary can be found in the
56 //! [Why Ristretto?][why_ristretto] section of the Ristretto website.
57 //!
58 //! Ristretto
59 //! points are provided in `curve25519-dalek` by the `RistrettoPoint`
60 //! struct.
61 //!
62 //! ## Encoding and Decoding
63 //!
64 //! Encoding is done by converting to and from a `CompressedRistretto`
65 //! struct, which is a typed wrapper around `[u8; 32]`.
66 //!
67 //! The encoding is not batchable, but it is possible to
68 //! double-and-encode in a batch using
69 //! `RistrettoPoint::double_and_compress_batch`.
70 //!
71 //! ## Equality Testing
72 //!
73 //! Testing equality of points on an Edwards curve in projective
74 //! coordinates requires an expensive inversion.  By contrast, equality
75 //! checking in the Ristretto group can be done in projective
76 //! coordinates without requiring an inversion, so it is much faster.
77 //!
78 //! The `RistrettoPoint` struct implements the
79 //! `subtle::ConstantTimeEq` trait for constant-time equality
80 //! checking, and the Rust `Eq` trait for variable-time equality
81 //! checking.
82 //!
83 //! ## Scalars
84 //!
85 //! Scalars are represented by the `Scalar` struct.  Each scalar has a
86 //! canonical representative mod the group order.  To attempt to load
87 //! a supposedly-canonical scalar, use
88 //! `Scalar::from_canonical_bytes()`. To check whether a
89 //! representative is canonical, use `Scalar::is_canonical()`.
90 //!
91 //! ## Scalar Multiplication
92 //!
93 //! Scalar multiplication on Ristretto points is provided by:
94 //!
95 //! * the `*` operator between a `Scalar` and a `RistrettoPoint`, which
96 //! performs constant-time variable-base scalar multiplication;
97 //!
98 //! * the `*` operator between a `Scalar` and a
99 //! `RistrettoBasepointTable`, which performs constant-time fixed-base
100 //! scalar multiplication;
101 //!
102 //! * an implementation of the
103 //! [`MultiscalarMul`](../traits/trait.MultiscalarMul.html) trait for
104 //! constant-time variable-base multiscalar multiplication;
105 //!
106 //! * an implementation of the
107 //! [`VartimeMultiscalarMul`](../traits/trait.VartimeMultiscalarMul.html)
108 //! trait for variable-time variable-base multiscalar multiplication;
109 //!
110 //! ## Random Points and Hashing to Ristretto
111 //!
112 //! The Ristretto group comes equipped with an Elligator map.  This is
113 //! used to implement
114 //!
115 //! * `RistrettoPoint::random()`, which generates random points from an
116 //! RNG;
117 //!
118 //! * `RistrettoPoint::from_hash()` and
119 //! `RistrettoPoint::hash_from_bytes()`, which perform hashing to the
120 //! group.
121 //!
122 //! The Elligator map itself is not currently exposed.
123 //!
124 //! ## Implementation
125 //!
126 //! The Decaf suggestion is to use a quotient group, such as \\(\mathcal
127 //! E / \mathcal E[4]\\) or \\(2 \mathcal E / \mathcal E[2] \\), to
128 //! implement a prime-order group using a non-prime-order curve.
129 //!
130 //! This requires only changing
131 //!
132 //! 1. the function for equality checking (so that two representatives
133 //!    of the same coset are considered equal);
134 //! 2. the function for encoding (so that two representatives of the
135 //!    same coset are encoded as identical bitstrings);
136 //! 3. the function for decoding (so that only the canonical encoding of
137 //!    a coset is accepted).
138 //!
139 //! Internally, each coset is represented by a curve point; two points
140 //! \\( P, Q \\) may represent the same coset in the same way that two
141 //! points with different \\(X,Y,Z\\) coordinates may represent the
142 //! same point.  The group operations are carried out with no overhead
143 //! using Edwards formulas.
144 //!
145 //! Notes on the details of the encoding can be found in the
146 //! [Details][ristretto_notes] section of the Ristretto website.
147 //!
148 //! [cryptonote]:
149 //! https://moderncrypto.org/mail-archive/curves/2017/000898.html
150 //! [ed25519_hkd]:
151 //! https://moderncrypto.org/mail-archive/curves/2017/000858.html
152 //! [ristretto_coffee]:
153 //! https://en.wikipedia.org/wiki/Ristretto
154 //! [ristretto_notes]:
155 //! https://ristretto.group/details/index.html
156 //! [why_ristretto]:
157 //! https://ristretto.group/why_ristretto.html
158 //! [ristretto_main]:
159 //! https://ristretto.group/
160 
161 use core::borrow::Borrow;
162 use core::fmt::Debug;
163 use core::iter::Sum;
164 use core::ops::{Add, Neg, Sub};
165 use core::ops::{AddAssign, SubAssign};
166 use core::ops::{Mul, MulAssign};
167 
168 use rand_core::{CryptoRng, RngCore};
169 
170 use digest::generic_array::typenum::U64;
171 use digest::Digest;
172 
173 use constants;
174 use field::FieldElement;
175 
176 use subtle::Choice;
177 use subtle::ConditionallySelectable;
178 use subtle::ConditionallyNegatable;
179 use subtle::ConstantTimeEq;
180 
181 use zeroize::Zeroize;
182 
183 use edwards::EdwardsBasepointTable;
184 use edwards::EdwardsPoint;
185 
186 #[allow(unused_imports)]
187 use prelude::*;
188 
189 use scalar::Scalar;
190 
191 use traits::Identity;
192 #[cfg(any(feature = "alloc", feature = "std"))]
193 use traits::{MultiscalarMul, VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul};
194 
195 #[cfg(not(all(
196     feature = "simd_backend",
197     any(target_feature = "avx2", target_feature = "avx512ifma")
198 )))]
199 use backend::serial::scalar_mul;
200 #[cfg(all(
201     feature = "simd_backend",
202     any(target_feature = "avx2", target_feature = "avx512ifma")
203 ))]
204 use backend::vector::scalar_mul;
205 
206 // ------------------------------------------------------------------------
207 // Compressed points
208 // ------------------------------------------------------------------------
209 
/// A Ristretto point, in compressed wire format.
///
/// This is a newtype around the raw 32-byte encoding.  Constructing a
/// `CompressedRistretto` does not check that the bytes form a valid
/// encoding; that check happens in `decompress`.
///
/// The Ristretto encoding is canonical, so two points are equal if and
/// only if their encodings are equal.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct CompressedRistretto(pub [u8; 32]);
216 
217 impl ConstantTimeEq for CompressedRistretto {
ct_eq(&self, other: &CompressedRistretto) -> Choice218     fn ct_eq(&self, other: &CompressedRistretto) -> Choice {
219         self.as_bytes().ct_eq(other.as_bytes())
220     }
221 }
222 
223 impl CompressedRistretto {
224     /// Copy the bytes of this `CompressedRistretto`.
to_bytes(&self) -> [u8; 32]225     pub fn to_bytes(&self) -> [u8; 32] {
226         self.0
227     }
228 
229     /// View this `CompressedRistretto` as an array of bytes.
as_bytes(&self) -> &[u8; 32]230     pub fn as_bytes(&self) -> &[u8; 32] {
231         &self.0
232     }
233 
234     /// Construct a `CompressedRistretto` from a slice of bytes.
235     ///
236     /// # Panics
237     ///
238     /// If the input `bytes` slice does not have a length of 32.
from_slice(bytes: &[u8]) -> CompressedRistretto239     pub fn from_slice(bytes: &[u8]) -> CompressedRistretto {
240         let mut tmp = [0u8; 32];
241 
242         tmp.copy_from_slice(bytes);
243 
244         CompressedRistretto(tmp)
245     }
246 
247     /// Attempt to decompress to an `RistrettoPoint`.
248     ///
249     /// # Return
250     ///
251     /// - `Some(RistrettoPoint)` if `self` was the canonical encoding of a point;
252     ///
253     /// - `None` if `self` was not the canonical encoding of a point.
decompress(&self) -> Option<RistrettoPoint>254     pub fn decompress(&self) -> Option<RistrettoPoint> {
255         // Step 1. Check s for validity:
256         // 1.a) s must be 32 bytes (we get this from the type system)
257         // 1.b) s < p
258         // 1.c) s is nonnegative
259         //
260         // Our decoding routine ignores the high bit, so the only
261         // possible failure for 1.b) is if someone encodes s in 0..18
262         // as s+p in 2^255-19..2^255-1.  We can check this by
263         // converting back to bytes, and checking that we get the
264         // original input, since our encoding routine is canonical.
265 
266         let s = FieldElement::from_bytes(self.as_bytes());
267         let s_bytes_check = s.to_bytes();
268         let s_encoding_is_canonical =
269             &s_bytes_check[..].ct_eq(self.as_bytes());
270         let s_is_negative = s.is_negative();
271 
272         if s_encoding_is_canonical.unwrap_u8() == 0u8 || s_is_negative.unwrap_u8() == 1u8 {
273             return None;
274         }
275 
276         // Step 2.  Compute (X:Y:Z:T).
277         let one = FieldElement::one();
278         let ss = s.square();
279         let u1 = &one - &ss;      //  1 + as²
280         let u2 = &one + &ss;      //  1 - as²    where a=-1
281         let u2_sqr = u2.square(); // (1 - as²)²
282 
283         // v == ad(1+as²)² - (1-as²)²            where d=-121665/121666
284         let v = &(&(-&constants::EDWARDS_D) * &u1.square()) - &u2_sqr;
285 
286         let (ok, I) = (&v * &u2_sqr).invsqrt(); // 1/sqrt(v*u_2²)
287 
288         let Dx = &I * &u2;         // 1/sqrt(v)
289         let Dy = &I * &(&Dx * &v); // 1/u2
290 
291         // x == | 2s/sqrt(v) | == + sqrt(4s²/(ad(1+as²)² - (1-as²)²))
292         let mut x = &(&s + &s) * &Dx;
293         let x_neg = x.is_negative();
294         x.conditional_negate(x_neg);
295 
296         // y == (1-as²)/(1+as²)
297         let y = &u1 * &Dy;
298 
299         // t == ((1+as²) sqrt(4s²/(ad(1+as²)² - (1-as²)²)))/(1-as²)
300         let t = &x * &y;
301 
302         if ok.unwrap_u8() == 0u8 || t.is_negative().unwrap_u8() == 1u8 || y.is_zero().unwrap_u8() == 1u8 {
303             None
304         } else {
305             Some(RistrettoPoint(EdwardsPoint{X: x, Y: y, Z: one, T: t}))
306         }
307     }
308 }
309 
310 impl Identity for CompressedRistretto {
identity() -> CompressedRistretto311     fn identity() -> CompressedRistretto {
312         CompressedRistretto([0u8; 32])
313     }
314 }
315 
316 impl Default for CompressedRistretto {
default() -> CompressedRistretto317     fn default() -> CompressedRistretto {
318         CompressedRistretto::identity()
319     }
320 }
321 
322 // ------------------------------------------------------------------------
323 // Serde support
324 // ------------------------------------------------------------------------
325 // Serializes to and from `RistrettoPoint` directly, doing compression
326 // and decompression internally.  This means that users can create
327 // structs containing `RistrettoPoint`s and use Serde's derived
328 // serializers to serialize those structures.
329 
330 #[cfg(feature = "serde")]
331 use serde::{self, Serialize, Deserialize, Serializer, Deserializer};
332 #[cfg(feature = "serde")]
333 use serde::de::Visitor;
334 
#[cfg(feature = "serde")]
impl Serialize for RistrettoPoint {
    /// Serialize the point as its compressed encoding: a tuple of 32
    /// bytes.
    ///
    /// The point is compressed first and the work is handed to the
    /// `Serialize` impl of `CompressedRistretto`, which emits the
    /// 32-element tuple.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where S: Serializer
    {
        self.compress().serialize(serializer)
    }
}
348 
#[cfg(feature = "serde")]
impl Serialize for CompressedRistretto {
    /// Serialize the encoding as a tuple of 32 bytes, one element per
    /// byte, in order.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where S: Serializer
    {
        use serde::ser::SerializeTuple;

        let mut tuple = serializer.serialize_tuple(32)?;
        // Stops at, and propagates, the first element that fails.
        self.as_bytes()
            .iter()
            .try_for_each(|byte| tuple.serialize_element(byte))?;
        tuple.end()
    }
}
362 
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for RistrettoPoint {
    /// Deserialize a point from a tuple of 32 bytes holding its
    /// compressed encoding.
    ///
    /// # Errors
    ///
    /// - an `invalid_length` error if the sequence ends before 32
    ///   elements have been read;
    /// - a custom "decompression failed" error if the bytes are not
    ///   accepted by `CompressedRistretto::decompress`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where D: Deserializer<'de>
    {
        struct RistrettoPointVisitor;

        impl<'de> Visitor<'de> for RistrettoPointVisitor {
            type Value = RistrettoPoint;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                formatter.write_str("a valid point in Ristretto format")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<RistrettoPoint, A::Error>
                where A: serde::de::SeqAccess<'de>
            {
                let mut bytes = [0u8; 32];
                for (i, byte) in bytes.iter_mut().enumerate() {
                    // Build the error lazily: it is only needed when
                    // the sequence runs out early.
                    *byte = seq.next_element()?
                        .ok_or_else(|| serde::de::Error::invalid_length(i, &"expected 32 bytes"))?;
                }
                CompressedRistretto(bytes)
                    .decompress()
                    .ok_or_else(|| serde::de::Error::custom("decompression failed"))
            }
        }

        deserializer.deserialize_tuple(32, RistrettoPointVisitor)
    }
}
394 
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for CompressedRistretto {
    /// Deserialize a `CompressedRistretto` from a tuple of 32 bytes.
    ///
    /// The bytes are stored as read; no decompression is attempted, so
    /// the result is not checked for being a valid encoding.
    ///
    /// # Errors
    ///
    /// An `invalid_length` error if the sequence ends before 32
    /// elements have been read.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where D: Deserializer<'de>
    {
        struct CompressedRistrettoVisitor;

        impl<'de> Visitor<'de> for CompressedRistrettoVisitor {
            type Value = CompressedRistretto;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                formatter.write_str("32 bytes of data")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<CompressedRistretto, A::Error>
                where A: serde::de::SeqAccess<'de>
            {
                let mut bytes = [0u8; 32];
                for (i, byte) in bytes.iter_mut().enumerate() {
                    // Build the error lazily: it is only needed when
                    // the sequence runs out early.
                    *byte = seq.next_element()?
                        .ok_or_else(|| serde::de::Error::invalid_length(i, &"expected 32 bytes"))?;
                }
                Ok(CompressedRistretto(bytes))
            }
        }

        deserializer.deserialize_tuple(32, CompressedRistrettoVisitor)
    }
}
424 
425 // ------------------------------------------------------------------------
426 // Internal point representations
427 // ------------------------------------------------------------------------
428 
/// A `RistrettoPoint` represents a point in the Ristretto group for
/// Curve25519.  Ristretto, a variant of Decaf, constructs a
/// prime-order group as a quotient group of a subgroup of (the
/// Edwards form of) Curve25519.
///
/// Internally, a `RistrettoPoint` is implemented as a wrapper type
/// around `EdwardsPoint`, with custom equality, compression, and
/// decompression routines to account for the quotient.  This means that
/// operations on `RistrettoPoint`s are exactly as fast as operations on
/// `EdwardsPoint`s.
///
/// Equality is implemented by hand rather than derived, because
/// different inner `EdwardsPoint`s may represent the same Ristretto
/// point.
#[derive(Copy, Clone)]
pub struct RistrettoPoint(pub(crate) EdwardsPoint);
442 
443 impl RistrettoPoint {
444     /// Compress this point using the Ristretto encoding.
compress(&self) -> CompressedRistretto445     pub fn compress(&self) -> CompressedRistretto {
446         let mut X = self.0.X;
447         let mut Y = self.0.Y;
448         let Z = &self.0.Z;
449         let T = &self.0.T;
450 
451         let u1 = &(Z + &Y) * &(Z - &Y);
452         let u2 = &X * &Y;
453         // Ignore return value since this is always square
454         let (_, invsqrt) = (&u1 * &u2.square()).invsqrt();
455         let i1 = &invsqrt * &u1;
456         let i2 = &invsqrt * &u2;
457         let z_inv = &i1 * &(&i2 * T);
458         let mut den_inv = i2;
459 
460         let iX = &X * &constants::SQRT_M1;
461         let iY = &Y * &constants::SQRT_M1;
462         let ristretto_magic = &constants::INVSQRT_A_MINUS_D;
463         let enchanted_denominator = &i1 * ristretto_magic;
464 
465         let rotate = (T * &z_inv).is_negative();
466 
467         X.conditional_assign(&iY, rotate);
468         Y.conditional_assign(&iX, rotate);
469         den_inv.conditional_assign(&enchanted_denominator, rotate);
470 
471         Y.conditional_negate((&X * &z_inv).is_negative());
472 
473         let mut s = &den_inv * &(Z - &Y);
474         let s_is_negative = s.is_negative();
475         s.conditional_negate(s_is_negative);
476 
477         CompressedRistretto(s.to_bytes())
478     }
479 
    /// Double-and-compress a batch of points.  The Ristretto encoding
    /// is not batchable, since it requires an inverse square root.
    ///
    /// However, given input points \\( P\_1, \ldots, P\_n, \\)
    /// it is possible to compute the encodings of their doubles \\(
    /// \mathrm{enc}( [2]P\_1), \ldots, \mathrm{enc}( [2]P\_n ) \\)
    /// in a batch.
    ///
    /// The returned `Vec` has one entry per input point, in input
    /// order.  Internally a single call to `FieldElement::batch_invert`
    /// is made for the whole batch instead of one inverse square root
    /// per point.
    ///
    /// ```
    /// # extern crate curve25519_dalek;
    /// # use curve25519_dalek::ristretto::RistrettoPoint;
    /// extern crate rand_core;
    /// use rand_core::OsRng;
    ///
    /// # // Need fn main() here in comment so the doctest compiles
    /// # // See https://doc.rust-lang.org/book/documentation.html#documentation-as-tests
    /// # fn main() {
    /// let mut rng = OsRng;
    /// let points: Vec<RistrettoPoint> =
    ///     (0..32).map(|_| RistrettoPoint::random(&mut rng)).collect();
    ///
    /// let compressed = RistrettoPoint::double_and_compress_batch(&points);
    ///
    /// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
    ///     assert_eq!(*P2_compressed, (P + P).compress());
    /// }
    /// # }
    /// ```
    #[cfg(feature = "alloc")]
    pub fn double_and_compress_batch<'a, I>(points: I) -> Vec<CompressedRistretto>
        where I: IntoIterator<Item = &'a RistrettoPoint>
    {
        // Per-point intermediate values, computed once up front so the
        // inversion can be shared across the whole batch.
        #[derive(Copy, Clone, Debug)]
        struct BatchCompressState {
            e: FieldElement,
            f: FieldElement,
            g: FieldElement,
            h: FieldElement,
            // Cached product e*g.
            eg: FieldElement,
            // Cached product f*h.
            fh: FieldElement,
        }

        impl BatchCompressState {
            // The product e*f*g*h, i.e. the value that gets inverted.
            fn efgh(&self) -> FieldElement {
                &self.eg * &self.fh
            }
        }

        impl<'a> From<&'a RistrettoPoint> for BatchCompressState {
            // Derive e, f, g, h (and the cached products) from the
            // extended coordinates (X:Y:Z:T) of `P`.
            fn from(P: &'a RistrettoPoint) -> BatchCompressState {
                let XX = P.0.X.square();
                let YY = P.0.Y.square();
                let ZZ = P.0.Z.square();
                let dTT = &P.0.T.square() * &constants::EDWARDS_D;

                let e = &P.0.X * &(&P.0.Y + &P.0.Y); // = 2*X*Y
                let f = &ZZ + &dTT;                  // = Z^2 + d*T^2
                let g = &YY + &XX;                   // = Y^2 - a*X^2
                let h = &ZZ - &dTT;                  // = Z^2 - d*T^2

                let eg = &e * &g;
                let fh = &f * &h;

                BatchCompressState{ e, f, g, h, eg, fh }
            }
        }

        let states: Vec<BatchCompressState> = points.into_iter().map(BatchCompressState::from).collect();

        // One value e*f*g*h per point.
        let mut invs: Vec<FieldElement> = states.iter().map(|state| state.efgh()).collect();

        // NOTE(review): assumed to replace each entry with its inverse
        // in place; `batch_invert` is defined elsewhere -- confirm.
        FieldElement::batch_invert(&mut invs[..]);

        states.iter().zip(invs.iter()).map(|(state, inv): (&BatchCompressState, &FieldElement)| {
            // With inv = 1/(e*f*g*h): Zinv = 1/(f*h) and Tinv = 1/(e*g).
            let Zinv = &state.eg * &inv;
            let Tinv = &state.fh * &inv;

            let mut magic = constants::INVSQRT_A_MINUS_D;

            let negcheck1 = (&state.eg * &Zinv).is_negative();

            let mut e = state.e;
            let mut g = state.g;
            let mut h = state.h;

            // Replacement values, computed before any of e, g, h is
            // overwritten below.
            let minus_e = -&e;
            let f_times_sqrta = &state.f * &constants::SQRT_M1;

            // When negcheck1 is set: (e, g, h) <- (g, -e, f*sqrt(-1)).
            e.conditional_assign(&state.g,       negcheck1);
            g.conditional_assign(&minus_e,       negcheck1);
            h.conditional_assign(&f_times_sqrta, negcheck1);

            magic.conditional_assign(&constants::SQRT_M1, negcheck1);

            // Uses the (possibly replaced) h and e.
            let negcheck2 = (&(&h * &e) * &Zinv).is_negative();

            g.conditional_negate(negcheck2);

            let mut s = &(&h - &g) * &(&magic * &(&g * &Tinv));

            // The encoding uses the nonnegative representative of s.
            let s_is_negative = s.is_negative();
            s.conditional_negate(s_is_negative);

            CompressedRistretto(s.to_bytes())
        }).collect()
    }
586 
587 
588     /// Return the coset self + E[4], for debugging.
coset4(&self) -> [EdwardsPoint; 4]589     fn coset4(&self) -> [EdwardsPoint; 4] {
590         [  self.0
591         , &self.0 + &constants::EIGHT_TORSION[2]
592         , &self.0 + &constants::EIGHT_TORSION[4]
593         , &self.0 + &constants::EIGHT_TORSION[6]
594         ]
595     }
596 
    /// Computes the Ristretto Elligator map.
    ///
    /// Maps a single field element `r_0` to a `RistrettoPoint`.  Both
    /// data-dependent choices (which `s` and which `c` to use) are
    /// made with conditional assignment rather than branches.
    ///
    /// # Note
    ///
    /// This method is not public because it's just used for hashing
    /// to a point -- proper elligator support is deferred for now.
    pub(crate) fn elligator_ristretto_flavor(r_0: &FieldElement) -> RistrettoPoint {
        let i = &constants::SQRT_M1;
        let d = &constants::EDWARDS_D;
        let one_minus_d_sq = &constants::ONE_MINUS_EDWARDS_D_SQUARED;
        let d_minus_one_sq = &constants::EDWARDS_D_MINUS_ONE_SQUARED;
        // c starts as -1 and may be replaced by r further down.
        let mut c = constants::MINUS_ONE;

        let one = FieldElement::one();

        // r = i * r_0^2
        let r = i * &r_0.square();
        // N_s = (r + 1) * ONE_MINUS_EDWARDS_D_SQUARED
        let N_s = &(&r + &one) * &one_minus_d_sq;
        // D = (c - d*r) * (r + d), using the initial c = -1
        let D = &(&c - &(d * &r)) * &(&r + d);

        // The flag reports whether N_s / D was square.
        let (Ns_D_is_sq, mut s) = FieldElement::sqrt_ratio_i(&N_s, &D);
        // s' = s * r_0, negated whenever it is not negative, so it ends
        // up negative (or zero).
        let mut s_prime = &s * r_0;
        let s_prime_is_pos = !s_prime.is_negative();
        s_prime.conditional_negate(s_prime_is_pos);

        // In the non-square case switch to s' and set c = r.
        s.conditional_assign(&s_prime, !Ns_D_is_sq);
        c.conditional_assign(&r, !Ns_D_is_sq);

        // N_t = c * (r - 1) * EDWARDS_D_MINUS_ONE_SQUARED - D, using
        // the possibly updated c.
        let N_t = &(&(&c * &(&r - &one)) * &d_minus_one_sq) - &D;
        let s_sq = s.square();

        use backend::serial::curve_models::CompletedPoint;

        // The conversion from W_i is exactly the conversion from P1xP1.
        RistrettoPoint(CompletedPoint{
            X: &(&s + &s) * &D,
            Z: &N_t * &constants::SQRT_AD_MINUS_ONE,
            Y: &FieldElement::one() - &s_sq,
            T: &FieldElement::one() + &s_sq,
        }.to_extended())
    }
637 
638     /// Return a `RistrettoPoint` chosen uniformly at random using a user-provided RNG.
639     ///
640     /// # Inputs
641     ///
642     /// * `rng`: any RNG which implements the `RngCore + CryptoRng` interface.
643     ///
644     /// # Returns
645     ///
646     /// A random element of the Ristretto group.
647     ///
648     /// # Implementation
649     ///
650     /// Uses the Ristretto-flavoured Elligator 2 map, so that the
651     /// discrete log of the output point with respect to any other
652     /// point should be unknown.  The map is applied twice and the
653     /// results are added, to ensure a uniform distribution.
random<R: RngCore + CryptoRng>(rng: &mut R) -> Self654     pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
655         let mut uniform_bytes = [0u8; 64];
656         rng.fill_bytes(&mut uniform_bytes);
657 
658         RistrettoPoint::from_uniform_bytes(&uniform_bytes)
659     }
660 
661     /// Hash a slice of bytes into a `RistrettoPoint`.
662     ///
663     /// Takes a type parameter `D`, which is any `Digest` producing 64
664     /// bytes of output.
665     ///
666     /// Convenience wrapper around `from_hash`.
667     ///
668     /// # Implementation
669     ///
670     /// Uses the Ristretto-flavoured Elligator 2 map, so that the
671     /// discrete log of the output point with respect to any other
672     /// point should be unknown.  The map is applied twice and the
673     /// results are added, to ensure a uniform distribution.
674     ///
675     /// # Example
676     ///
677     /// ```
678     /// # extern crate curve25519_dalek;
679     /// # use curve25519_dalek::ristretto::RistrettoPoint;
680     /// extern crate sha2;
681     /// use sha2::Sha512;
682     ///
683     /// # // Need fn main() here in comment so the doctest compiles
684     /// # // See https://doc.rust-lang.org/book/documentation.html#documentation-as-tests
685     /// # fn main() {
686     /// let msg = "To really appreciate architecture, you may even need to commit a murder";
687     /// let P = RistrettoPoint::hash_from_bytes::<Sha512>(msg.as_bytes());
688     /// # }
689     /// ```
690     ///
hash_from_bytes<D>(input: &[u8]) -> RistrettoPoint where D: Digest<OutputSize = U64> + Default691     pub fn hash_from_bytes<D>(input: &[u8]) -> RistrettoPoint
692         where D: Digest<OutputSize = U64> + Default
693     {
694         let mut hash = D::default();
695         hash.update(input);
696         RistrettoPoint::from_hash(hash)
697     }
698 
699     /// Construct a `RistrettoPoint` from an existing `Digest` instance.
700     ///
701     /// Use this instead of `hash_from_bytes` if it is more convenient
702     /// to stream data into the `Digest` than to pass a single byte
703     /// slice.
from_hash<D>(hash: D) -> RistrettoPoint where D: Digest<OutputSize = U64> + Default704     pub fn from_hash<D>(hash: D) -> RistrettoPoint
705         where D: Digest<OutputSize = U64> + Default
706     {
707         // dealing with generic arrays is clumsy, until const generics land
708         let output = hash.finalize();
709         let mut output_bytes = [0u8; 64];
710         output_bytes.copy_from_slice(&output.as_slice());
711 
712         RistrettoPoint::from_uniform_bytes(&output_bytes)
713     }
714 
715     /// Construct a `RistrettoPoint` from 64 bytes of data.
716     ///
717     /// If the input bytes are uniformly distributed, the resulting
718     /// point will be uniformly distributed over the group, and its
719     /// discrete log with respect to other points should be unknown.
720     ///
721     /// # Implementation
722     ///
723     /// This function splits the input array into two 32-byte halves,
724     /// takes the low 255 bits of each half mod p, applies the
725     /// Ristretto-flavored Elligator map to each, and adds the results.
from_uniform_bytes(bytes: &[u8; 64]) -> RistrettoPoint726     pub fn from_uniform_bytes(bytes: &[u8; 64]) -> RistrettoPoint {
727         let mut r_1_bytes = [0u8; 32];
728         r_1_bytes.copy_from_slice(&bytes[0..32]);
729         let r_1 = FieldElement::from_bytes(&r_1_bytes);
730         let R_1 = RistrettoPoint::elligator_ristretto_flavor(&r_1);
731 
732         let mut r_2_bytes = [0u8; 32];
733         r_2_bytes.copy_from_slice(&bytes[32..64]);
734         let r_2 = FieldElement::from_bytes(&r_2_bytes);
735         let R_2 = RistrettoPoint::elligator_ristretto_flavor(&r_2);
736 
737         // Applying Elligator twice and adding the results ensures a
738         // uniform distribution.
739         &R_1 + &R_2
740     }
741 }
742 
743 impl Identity for RistrettoPoint {
identity() -> RistrettoPoint744     fn identity() -> RistrettoPoint {
745         RistrettoPoint(EdwardsPoint::identity())
746     }
747 }
748 
749 impl Default for RistrettoPoint {
default() -> RistrettoPoint750     fn default() -> RistrettoPoint {
751         RistrettoPoint::identity()
752     }
753 }
754 
755 // ------------------------------------------------------------------------
756 // Equality
757 // ------------------------------------------------------------------------
758 
759 impl PartialEq for RistrettoPoint {
eq(&self, other: &RistrettoPoint) -> bool760     fn eq(&self, other: &RistrettoPoint) -> bool {
761         self.ct_eq(other).unwrap_u8() == 1u8
762     }
763 }
764 
765 impl ConstantTimeEq for RistrettoPoint {
766     /// Test equality between two `RistrettoPoint`s.
767     ///
768     /// # Returns
769     ///
770     /// * `Choice(1)` if the two `RistrettoPoint`s are equal;
771     /// * `Choice(0)` otherwise.
ct_eq(&self, other: &RistrettoPoint) -> Choice772     fn ct_eq(&self, other: &RistrettoPoint) -> Choice {
773         let X1Y2 = &self.0.X * &other.0.Y;
774         let Y1X2 = &self.0.Y * &other.0.X;
775         let X1X2 = &self.0.X * &other.0.X;
776         let Y1Y2 = &self.0.Y * &other.0.Y;
777 
778         X1Y2.ct_eq(&Y1X2) | X1X2.ct_eq(&Y1Y2)
779     }
780 }
781 
782 impl Eq for RistrettoPoint {}
783 
784 // ------------------------------------------------------------------------
785 // Arithmetic
786 // ------------------------------------------------------------------------
787 
788 impl<'a, 'b> Add<&'b RistrettoPoint> for &'a RistrettoPoint {
789     type Output = RistrettoPoint;
790 
add(self, other: &'b RistrettoPoint) -> RistrettoPoint791     fn add(self, other: &'b RistrettoPoint) -> RistrettoPoint {
792         RistrettoPoint(&self.0 + &other.0)
793     }
794 }
795 
796 define_add_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint, Output = RistrettoPoint);
797 
798 impl<'b> AddAssign<&'b RistrettoPoint> for RistrettoPoint {
add_assign(&mut self, _rhs: &RistrettoPoint)799     fn add_assign(&mut self, _rhs: &RistrettoPoint) {
800         *self = (self as &RistrettoPoint) + _rhs;
801     }
802 }
803 
804 define_add_assign_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint);
805 
806 impl<'a, 'b> Sub<&'b RistrettoPoint> for &'a RistrettoPoint {
807     type Output = RistrettoPoint;
808 
sub(self, other: &'b RistrettoPoint) -> RistrettoPoint809     fn sub(self, other: &'b RistrettoPoint) -> RistrettoPoint {
810         RistrettoPoint(&self.0 - &other.0)
811     }
812 }
813 
814 define_sub_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint, Output = RistrettoPoint);
815 
816 impl<'b> SubAssign<&'b RistrettoPoint> for RistrettoPoint {
sub_assign(&mut self, _rhs: &RistrettoPoint)817     fn sub_assign(&mut self, _rhs: &RistrettoPoint) {
818         *self = (self as &RistrettoPoint) - _rhs;
819     }
820 }
821 
822 define_sub_assign_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint);
823 
824 impl<T> Sum<T> for RistrettoPoint
825 where
826     T: Borrow<RistrettoPoint>
827 {
sum<I>(iter: I) -> Self where I: Iterator<Item = T>828     fn sum<I>(iter: I) -> Self
829     where
830         I: Iterator<Item = T>
831     {
832         iter.fold(RistrettoPoint::identity(), |acc, item| acc + item.borrow())
833     }
834 }
835 
836 impl<'a> Neg for &'a RistrettoPoint {
837     type Output = RistrettoPoint;
838 
neg(self) -> RistrettoPoint839     fn neg(self) -> RistrettoPoint {
840         RistrettoPoint(-&self.0)
841     }
842 }
843 
844 impl Neg for RistrettoPoint {
845     type Output = RistrettoPoint;
846 
neg(self) -> RistrettoPoint847     fn neg(self) -> RistrettoPoint {
848         -&self
849     }
850 }
851 
852 impl<'b> MulAssign<&'b Scalar> for RistrettoPoint {
mul_assign(&mut self, scalar: &'b Scalar)853     fn mul_assign(&mut self, scalar: &'b Scalar) {
854         let result = (self as &RistrettoPoint) * scalar;
855         *self = result;
856     }
857 }
858 
859 impl<'a, 'b> Mul<&'b Scalar> for &'a RistrettoPoint {
860     type Output = RistrettoPoint;
861     /// Scalar multiplication: compute `scalar * self`.
mul(self, scalar: &'b Scalar) -> RistrettoPoint862     fn mul(self, scalar: &'b Scalar) -> RistrettoPoint {
863         RistrettoPoint(self.0 * scalar)
864     }
865 }
866 
867 impl<'a, 'b> Mul<&'b RistrettoPoint> for &'a Scalar {
868     type Output = RistrettoPoint;
869 
870     /// Scalar multiplication: compute `self * scalar`.
mul(self, point: &'b RistrettoPoint) -> RistrettoPoint871     fn mul(self, point: &'b RistrettoPoint) -> RistrettoPoint {
872         RistrettoPoint(self * point.0)
873     }
874 }
875 
876 define_mul_assign_variants!(LHS = RistrettoPoint, RHS = Scalar);
877 
878 define_mul_variants!(LHS = RistrettoPoint, RHS = Scalar, Output = RistrettoPoint);
879 define_mul_variants!(LHS = Scalar, RHS = RistrettoPoint, Output = RistrettoPoint);
880 
881 // ------------------------------------------------------------------------
882 // Multiscalar Multiplication impls
883 // ------------------------------------------------------------------------
884 
885 // These use iterator combinators to unwrap the underlying points and
886 // forward to the EdwardsPoint implementations.
887 
#[cfg(feature = "alloc")]
impl MultiscalarMul for RistrettoPoint {
    type Point = RistrettoPoint;

    /// Multiscalar multiplication over `RistrettoPoint`s.
    ///
    /// Each input point is unwrapped to its inner `EdwardsPoint`, the work
    /// is forwarded to `EdwardsPoint::multiscalar_mul`, and the result is
    /// wrapped back up as a `RistrettoPoint`.
    fn multiscalar_mul<I, J>(scalars: I, points: J) -> RistrettoPoint
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator,
        J::Item: Borrow<RistrettoPoint>,
    {
        let edwards_points = points.into_iter().map(|point| point.borrow().0);
        let result = EdwardsPoint::multiscalar_mul(scalars, edwards_points);
        RistrettoPoint(result)
    }
}
905 
#[cfg(feature = "alloc")]
impl VartimeMultiscalarMul for RistrettoPoint {
    type Point = RistrettoPoint;

    /// Variable-time multiscalar multiplication over optional points.
    ///
    /// Every `Some(point)` is unwrapped to its inner `EdwardsPoint` and the
    /// work is forwarded to `EdwardsPoint::optional_multiscalar_mul`; a
    /// `None` from that call is passed through unchanged.
    fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<RistrettoPoint>
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator<Item = Option<RistrettoPoint>>,
    {
        let edwards_points = points
            .into_iter()
            .map(|maybe_point| maybe_point.map(|point| point.0));

        let result = EdwardsPoint::optional_multiscalar_mul(scalars, edwards_points)?;
        Some(RistrettoPoint(result))
    }
}
921 
/// Precomputation for variable-time multiscalar multiplication with `RistrettoPoint`s.
///
/// This type is only compiled when the `alloc` feature is enabled.
// This wraps the inner implementation in a facade type so that we can
// decouple stability of the inner type from the stability of the
// outer type.  The single private field holds the wrapped
// `VartimePrecomputedStraus` value.
#[cfg(feature = "alloc")]
pub struct VartimeRistrettoPrecomputation(scalar_mul::precomputed_straus::VartimePrecomputedStraus);
928 
#[cfg(feature = "alloc")]
impl VartimePrecomputedMultiscalarMul for VartimeRistrettoPrecomputation {
    type Point = RistrettoPoint;

    /// Build the precomputation from `static_points`, unwrapping each one
    /// to its inner `EdwardsPoint` before handing it to
    /// `VartimePrecomputedStraus::new`.
    fn new<I>(static_points: I) -> Self
    where
        I: IntoIterator,
        I::Item: Borrow<Self::Point>,
    {
        let edwards_points = static_points.into_iter().map(|point| point.borrow().0);
        let inner = scalar_mul::precomputed_straus::VartimePrecomputedStraus::new(edwards_points);
        VartimeRistrettoPrecomputation(inner)
    }

    /// Forward a mixed static/dynamic multiscalar multiplication to the
    /// wrapped precomputation, unwrapping every `Some` dynamic point to
    /// its inner `EdwardsPoint` and re-wrapping a `Some` result as a
    /// `RistrettoPoint`.  A `None` result is passed through unchanged.
    fn optional_mixed_multiscalar_mul<I, J, K>(
        &self,
        static_scalars: I,
        dynamic_scalars: J,
        dynamic_points: K,
    ) -> Option<Self::Point>
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator,
        J::Item: Borrow<Scalar>,
        K: IntoIterator<Item = Option<Self::Point>>,
    {
        let edwards_points = dynamic_points
            .into_iter()
            .map(|maybe_point| maybe_point.map(|point| point.0));

        let result = self.0.optional_mixed_multiscalar_mul(
            static_scalars,
            dynamic_scalars,
            edwards_points,
        )?;
        Some(RistrettoPoint(result))
    }
}
967 
968 impl RistrettoPoint {
969     /// Compute \\(aA + bB\\) in variable time, where \\(B\\) is the
970     /// Ristretto basepoint.
vartime_double_scalar_mul_basepoint( a: &Scalar, A: &RistrettoPoint, b: &Scalar, ) -> RistrettoPoint971     pub fn vartime_double_scalar_mul_basepoint(
972         a: &Scalar,
973         A: &RistrettoPoint,
974         b: &Scalar,
975     ) -> RistrettoPoint {
976         RistrettoPoint(
977             EdwardsPoint::vartime_double_scalar_mul_basepoint(a, &A.0, b)
978         )
979     }
980 }
981 
/// A precomputed table of multiples of a basepoint, used to accelerate
/// scalar multiplication.
///
/// This type is a wrapper around an `EdwardsBasepointTable`; the wrapped
/// table is only accessible from within this crate.
///
/// A precomputed table of multiples of the Ristretto basepoint is
/// available in the `constants` module:
/// ```
/// use curve25519_dalek::constants;
/// use curve25519_dalek::scalar::Scalar;
///
/// let a = Scalar::from(87329482u64);
/// let P = &a * &constants::RISTRETTO_BASEPOINT_TABLE;
/// ```
#[derive(Clone)]
pub struct RistrettoBasepointTable(pub(crate) EdwardsBasepointTable);
996 
997 impl<'a, 'b> Mul<&'b Scalar> for &'a RistrettoBasepointTable {
998     type Output = RistrettoPoint;
999 
mul(self, scalar: &'b Scalar) -> RistrettoPoint1000     fn mul(self, scalar: &'b Scalar) -> RistrettoPoint {
1001         RistrettoPoint(&self.0 * scalar)
1002     }
1003 }
1004 
1005 impl<'a, 'b> Mul<&'a RistrettoBasepointTable> for &'b Scalar {
1006     type Output = RistrettoPoint;
1007 
mul(self, basepoint_table: &'a RistrettoBasepointTable) -> RistrettoPoint1008     fn mul(self, basepoint_table: &'a RistrettoBasepointTable) -> RistrettoPoint {
1009         RistrettoPoint(self * &basepoint_table.0)
1010     }
1011 }
1012 
1013 impl RistrettoBasepointTable {
1014     /// Create a precomputed table of multiples of the given `basepoint`.
create(basepoint: &RistrettoPoint) -> RistrettoBasepointTable1015     pub fn create(basepoint: &RistrettoPoint) -> RistrettoBasepointTable {
1016         RistrettoBasepointTable(EdwardsBasepointTable::create(&basepoint.0))
1017     }
1018 
1019     /// Get the basepoint for this table as a `RistrettoPoint`.
basepoint(&self) -> RistrettoPoint1020     pub fn basepoint(&self) -> RistrettoPoint {
1021         RistrettoPoint(self.0.basepoint())
1022     }
1023 }
1024 
1025 // ------------------------------------------------------------------------
1026 // Constant-time conditional selection
1027 // ------------------------------------------------------------------------
1028 
1029 impl ConditionallySelectable for RistrettoPoint {
1030     /// Conditionally select between `self` and `other`.
1031     ///
1032     /// # Example
1033     ///
1034     /// ```
1035     /// # extern crate subtle;
1036     /// # extern crate curve25519_dalek;
1037     /// #
1038     /// use subtle::ConditionallySelectable;
1039     /// use subtle::Choice;
1040     /// #
1041     /// # use curve25519_dalek::traits::Identity;
1042     /// # use curve25519_dalek::ristretto::RistrettoPoint;
1043     /// # use curve25519_dalek::constants;
1044     /// # fn main() {
1045     ///
1046     /// let A = RistrettoPoint::identity();
1047     /// let B = constants::RISTRETTO_BASEPOINT_POINT;
1048     ///
1049     /// let mut P = A;
1050     ///
1051     /// P = RistrettoPoint::conditional_select(&A, &B, Choice::from(0));
1052     /// assert_eq!(P, A);
1053     /// P = RistrettoPoint::conditional_select(&A, &B, Choice::from(1));
1054     /// assert_eq!(P, B);
1055     /// # }
1056     /// ```
conditional_select( a: &RistrettoPoint, b: &RistrettoPoint, choice: Choice, ) -> RistrettoPoint1057     fn conditional_select(
1058         a: &RistrettoPoint,
1059         b: &RistrettoPoint,
1060         choice: Choice,
1061     ) -> RistrettoPoint {
1062         RistrettoPoint(EdwardsPoint::conditional_select(&a.0, &b.0, choice))
1063     }
1064 }
1065 
1066 // ------------------------------------------------------------------------
1067 // Debug traits
1068 // ------------------------------------------------------------------------
1069 
1070 impl Debug for CompressedRistretto {
fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result1071     fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
1072         write!(f, "CompressedRistretto: {:?}", self.as_bytes())
1073     }
1074 }
1075 
1076 impl Debug for RistrettoPoint {
fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result1077     fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
1078         let coset = self.coset4();
1079         write!(f, "RistrettoPoint: coset \n{:?}\n{:?}\n{:?}\n{:?}",
1080                coset[0], coset[1], coset[2], coset[3])
1081     }
1082 }
1083 
1084 // ------------------------------------------------------------------------
1085 // Zeroize traits
1086 // ------------------------------------------------------------------------
1087 
1088 impl Zeroize for CompressedRistretto {
zeroize(&mut self)1089     fn zeroize(&mut self) {
1090         self.0.zeroize();
1091     }
1092 }
1093 
1094 impl Zeroize for RistrettoPoint {
zeroize(&mut self)1095     fn zeroize(&mut self) {
1096         self.0.zeroize();
1097     }
1098 }
1099 
1100 // ------------------------------------------------------------------------
1101 // Tests
1102 // ------------------------------------------------------------------------
1103 
1104 #[cfg(test)]
1105 mod test {
1106     use rand_core::OsRng;
1107 
1108     use scalar::Scalar;
1109     use constants;
1110     use edwards::CompressedEdwardsY;
1111     use traits::{Identity};
1112     use super::*;
1113 
1114     #[test]
1115     #[cfg(feature = "serde")]
serde_bincode_basepoint_roundtrip()1116     fn serde_bincode_basepoint_roundtrip() {
1117         use bincode;
1118 
1119         let encoded = bincode::serialize(&constants::RISTRETTO_BASEPOINT_POINT).unwrap();
1120         let enc_compressed = bincode::serialize(&constants::RISTRETTO_BASEPOINT_COMPRESSED).unwrap();
1121         assert_eq!(encoded, enc_compressed);
1122 
1123         // Check that the encoding is 32 bytes exactly
1124         assert_eq!(encoded.len(), 32);
1125 
1126         let dec_uncompressed: RistrettoPoint = bincode::deserialize(&encoded).unwrap();
1127         let dec_compressed: CompressedRistretto = bincode::deserialize(&encoded).unwrap();
1128 
1129         assert_eq!(dec_uncompressed, constants::RISTRETTO_BASEPOINT_POINT);
1130         assert_eq!(dec_compressed, constants::RISTRETTO_BASEPOINT_COMPRESSED);
1131 
1132         // Check that the encoding itself matches the usual one
1133         let raw_bytes = constants::RISTRETTO_BASEPOINT_COMPRESSED.as_bytes();
1134         let bp: RistrettoPoint = bincode::deserialize(raw_bytes).unwrap();
1135         assert_eq!(bp, constants::RISTRETTO_BASEPOINT_POINT);
1136     }
1137 
1138     #[test]
scalarmult_ristrettopoint_works_both_ways()1139     fn scalarmult_ristrettopoint_works_both_ways() {
1140         let P = constants::RISTRETTO_BASEPOINT_POINT;
1141         let s = Scalar::from(999u64);
1142 
1143         let P1 = &P * &s;
1144         let P2 = &s * &P;
1145 
1146         assert!(P1.compress().as_bytes() == P2.compress().as_bytes());
1147     }
1148 
1149     #[test]
impl_sum()1150     fn impl_sum() {
1151 
1152         // Test that sum works for non-empty iterators
1153         let BASE = constants::RISTRETTO_BASEPOINT_POINT;
1154 
1155         let s1 = Scalar::from(999u64);
1156         let P1 = &BASE * &s1;
1157 
1158         let s2 = Scalar::from(333u64);
1159         let P2 = &BASE * &s2;
1160 
1161         let vec = vec![P1.clone(), P2.clone()];
1162         let sum: RistrettoPoint = vec.iter().sum();
1163 
1164         assert_eq!(sum, P1 + P2);
1165 
1166         // Test that sum works for the empty iterator
1167         let empty_vector: Vec<RistrettoPoint> = vec![];
1168         let sum: RistrettoPoint = empty_vector.iter().sum();
1169 
1170         assert_eq!(sum, RistrettoPoint::identity());
1171 
1172         // Test that sum works on owning iterators
1173         let s = Scalar::from(2u64);
1174         let mapped = vec.iter().map(|x| x * s);
1175         let sum: RistrettoPoint = mapped.sum();
1176 
1177         assert_eq!(sum, &P1 * &s + &P2 * &s);
1178     }
1179 
1180     #[test]
decompress_negative_s_fails()1181     fn decompress_negative_s_fails() {
1182         // constants::d is neg, so decompression should fail as |d| != d.
1183         let bad_compressed = CompressedRistretto(constants::EDWARDS_D.to_bytes());
1184         assert!(bad_compressed.decompress().is_none());
1185     }
1186 
1187     #[test]
decompress_id()1188     fn decompress_id() {
1189         let compressed_id = CompressedRistretto::identity();
1190         let id = compressed_id.decompress().unwrap();
1191         let mut identity_in_coset = false;
1192         for P in &id.coset4() {
1193             if P.compress() == CompressedEdwardsY::identity() {
1194                 identity_in_coset = true;
1195             }
1196         }
1197         assert!(identity_in_coset);
1198     }
1199 
1200     #[test]
compress_id()1201     fn compress_id() {
1202         let id = RistrettoPoint::identity();
1203         assert_eq!(id.compress(), CompressedRistretto::identity());
1204     }
1205 
1206     #[test]
basepoint_roundtrip()1207     fn basepoint_roundtrip() {
1208         let bp_compressed_ristretto = constants::RISTRETTO_BASEPOINT_POINT.compress();
1209         let bp_recaf = bp_compressed_ristretto.decompress().unwrap().0;
1210         // Check that bp_recaf differs from bp by a point of order 4
1211         let diff = &constants::RISTRETTO_BASEPOINT_POINT.0 - &bp_recaf;
1212         let diff4 = diff.mul_by_pow_2(2);
1213         assert_eq!(diff4.compress(), CompressedEdwardsY::identity());
1214     }
1215 
1216     #[test]
encodings_of_small_multiples_of_basepoint()1217     fn encodings_of_small_multiples_of_basepoint() {
1218         // Table of encodings of i*basepoint
1219         // Generated using ristretto.sage
1220         let compressed = [
1221             CompressedRistretto([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
1222             CompressedRistretto([226, 242, 174, 10, 106, 188, 78, 113, 168, 132, 169, 97, 197, 0, 81, 95, 88, 227, 11, 106, 165, 130, 221, 141, 182, 166, 89, 69, 224, 141, 45, 118]),
1223             CompressedRistretto([106, 73, 50, 16, 247, 73, 156, 209, 127, 236, 181, 16, 174, 12, 234, 35, 161, 16, 232, 213, 185, 1, 248, 172, 173, 211, 9, 92, 115, 163, 185, 25]),
1224             CompressedRistretto([148, 116, 31, 93, 93, 82, 117, 94, 206, 79, 35, 240, 68, 238, 39, 213, 209, 234, 30, 43, 209, 150, 180, 98, 22, 107, 22, 21, 42, 157, 2, 89]),
1225             CompressedRistretto([218, 128, 134, 39, 115, 53, 139, 70, 111, 250, 223, 224, 179, 41, 58, 179, 217, 253, 83, 197, 234, 108, 149, 83, 88, 245, 104, 50, 45, 175, 106, 87]),
1226             CompressedRistretto([232, 130, 177, 49, 1, 107, 82, 193, 211, 51, 112, 128, 24, 124, 247, 104, 66, 62, 252, 203, 181, 23, 187, 73, 90, 184, 18, 196, 22, 15, 244, 78]),
1227             CompressedRistretto([246, 71, 70, 211, 201, 43, 19, 5, 14, 216, 216, 2, 54, 167, 240, 0, 124, 59, 63, 150, 47, 91, 167, 147, 209, 154, 96, 30, 187, 29, 244, 3]),
1228             CompressedRistretto([68, 245, 53, 32, 146, 110, 200, 31, 189, 90, 56, 120, 69, 190, 183, 223, 133, 169, 106, 36, 236, 225, 135, 56, 189, 207, 166, 167, 130, 42, 23, 109]),
1229             CompressedRistretto([144, 50, 147, 216, 242, 40, 126, 190, 16, 226, 55, 77, 193, 165, 62, 11, 200, 135, 229, 146, 105, 159, 2, 208, 119, 213, 38, 60, 221, 85, 96, 28]),
1230             CompressedRistretto([2, 98, 42, 206, 143, 115, 3, 163, 28, 175, 198, 63, 143, 196, 143, 220, 22, 225, 200, 200, 210, 52, 178, 240, 214, 104, 82, 130, 169, 7, 96, 49]),
1231             CompressedRistretto([32, 112, 111, 215, 136, 178, 114, 10, 30, 210, 165, 218, 212, 149, 43, 1, 244, 19, 188, 240, 231, 86, 77, 232, 205, 200, 22, 104, 158, 45, 185, 95]),
1232             CompressedRistretto([188, 232, 63, 139, 165, 221, 47, 165, 114, 134, 76, 36, 186, 24, 16, 249, 82, 43, 198, 0, 74, 254, 149, 135, 122, 199, 50, 65, 202, 253, 171, 66]),
1233             CompressedRistretto([228, 84, 158, 225, 107, 154, 160, 48, 153, 202, 32, 140, 103, 173, 175, 202, 250, 76, 63, 62, 78, 83, 3, 222, 96, 38, 227, 202, 143, 248, 68, 96]),
1234             CompressedRistretto([170, 82, 224, 0, 223, 46, 22, 245, 95, 177, 3, 47, 195, 59, 196, 39, 66, 218, 214, 189, 90, 143, 192, 190, 1, 103, 67, 108, 89, 72, 80, 31]),
1235             CompressedRistretto([70, 55, 107, 128, 244, 9, 178, 157, 194, 181, 246, 240, 197, 37, 145, 153, 8, 150, 229, 113, 111, 65, 71, 124, 211, 0, 133, 171, 127, 16, 48, 30]),
1236             CompressedRistretto([224, 196, 24, 247, 200, 217, 196, 205, 215, 57, 91, 147, 234, 18, 79, 58, 217, 144, 33, 187, 104, 29, 252, 51, 2, 169, 217, 154, 46, 83, 230, 78]),
1237         ];
1238         let mut bp = RistrettoPoint::identity();
1239         for i in 0..16 {
1240             assert_eq!(bp.compress(), compressed[i]);
1241             bp = &bp + &constants::RISTRETTO_BASEPOINT_POINT;
1242         }
1243     }
1244 
1245     #[test]
four_torsion_basepoint()1246     fn four_torsion_basepoint() {
1247         let bp = constants::RISTRETTO_BASEPOINT_POINT;
1248         let bp_coset = bp.coset4();
1249         for i in 0..4 {
1250             assert_eq!(bp, RistrettoPoint(bp_coset[i]));
1251         }
1252     }
1253 
1254     #[test]
four_torsion_random()1255     fn four_torsion_random() {
1256         let mut rng = OsRng;
1257         let B = &constants::RISTRETTO_BASEPOINT_TABLE;
1258         let P = B * &Scalar::random(&mut rng);
1259         let P_coset = P.coset4();
1260         for i in 0..4 {
1261             assert_eq!(P, RistrettoPoint(P_coset[i]));
1262         }
1263     }
1264 
1265     #[test]
elligator_vs_ristretto_sage()1266     fn elligator_vs_ristretto_sage() {
1267         // Test vectors extracted from ristretto.sage.
1268         //
1269         // Notice that all of the byte sequences have bit 255 set to 0; this is because
1270         // ristretto.sage does not mask the high bit of a field element.  When the high bit is set,
1271         // the ristretto.sage elligator implementation gives different results, since it takes a
1272         // different field element as input.
1273         let bytes: [[u8;32]; 16] = [
1274             [184, 249, 135, 49, 253, 123, 89, 113, 67, 160, 6, 239, 7, 105, 211, 41, 192, 249, 185, 57, 9, 102, 70, 198, 15, 127, 7, 26, 160, 102, 134, 71],
1275             [229, 14, 241, 227, 75, 9, 118, 60, 128, 153, 226, 21, 183, 217, 91, 136, 98, 0, 231, 156, 124, 77, 82, 139, 142, 134, 164, 169, 169, 62, 250, 52],
1276             [115, 109, 36, 220, 180, 223, 99, 6, 204, 169, 19, 29, 169, 68, 84, 23, 21, 109, 189, 149, 127, 205, 91, 102, 172, 35, 112, 35, 134, 69, 186, 34],
1277             [16, 49, 96, 107, 171, 199, 164, 9, 129, 16, 64, 62, 241, 63, 132, 173, 209, 160, 112, 215, 105, 50, 157, 81, 253, 105, 1, 154, 229, 25, 120, 83],
1278             [156, 131, 161, 162, 236, 251, 5, 187, 167, 171, 17, 178, 148, 210, 90, 207, 86, 21, 79, 161, 167, 215, 234, 1, 136, 242, 182, 248, 38, 85, 79, 86],
1279             [251, 177, 124, 54, 18, 101, 75, 235, 245, 186, 19, 46, 133, 157, 229, 64, 10, 136, 181, 185, 78, 144, 254, 167, 137, 49, 107, 10, 61, 10, 21, 25],
1280             [232, 193, 20, 68, 240, 77, 186, 77, 183, 40, 44, 86, 150, 31, 198, 212, 76, 81, 3, 217, 197, 8, 126, 128, 126, 152, 164, 208, 153, 44, 189, 77],
1281             [173, 229, 149, 177, 37, 230, 30, 69, 61, 56, 172, 190, 219, 115, 167, 194, 71, 134, 59, 75, 28, 244, 118, 26, 162, 97, 64, 16, 15, 189, 30, 64],
1282             [106, 71, 61, 107, 250, 117, 42, 151, 91, 202, 212, 100, 52, 188, 190, 21, 125, 218, 31, 18, 253, 241, 160, 133, 57, 242, 3, 164, 189, 68, 111, 75],
1283             [112, 204, 182, 90, 220, 198, 120, 73, 173, 107, 193, 17, 227, 40, 162, 36, 150, 141, 235, 55, 172, 183, 12, 39, 194, 136, 43, 153, 244, 118, 91, 89],
1284             [111, 24, 203, 123, 254, 189, 11, 162, 51, 196, 163, 136, 204, 143, 10, 222, 33, 112, 81, 205, 34, 35, 8, 66, 90, 6, 164, 58, 170, 177, 34, 25],
1285             [225, 183, 30, 52, 236, 82, 6, 183, 109, 25, 227, 181, 25, 82, 41, 193, 80, 77, 161, 80, 242, 203, 79, 204, 136, 245, 131, 110, 237, 106, 3, 58],
1286             [207, 246, 38, 56, 30, 86, 176, 90, 27, 200, 61, 42, 221, 27, 56, 210, 79, 178, 189, 120, 68, 193, 120, 167, 77, 185, 53, 197, 124, 128, 191, 126],
1287             [1, 136, 215, 80, 240, 46, 63, 147, 16, 244, 230, 207, 82, 189, 74, 50, 106, 169, 138, 86, 30, 131, 214, 202, 166, 125, 251, 228, 98, 24, 36, 21],
1288             [210, 207, 228, 56, 155, 116, 207, 54, 84, 195, 251, 215, 249, 199, 116, 75, 109, 239, 196, 251, 194, 246, 252, 228, 70, 146, 156, 35, 25, 39, 241, 4],
1289             [34, 116, 123, 9, 8, 40, 93, 189, 9, 103, 57, 103, 66, 227, 3, 2, 157, 107, 134, 219, 202, 74, 230, 154, 78, 107, 219, 195, 214, 14, 84, 80],
1290         ];
1291         let encoded_images: [CompressedRistretto; 16] = [
1292             CompressedRistretto([176, 157, 237, 97, 66, 29, 140, 166, 168, 94, 26, 157, 212, 216, 229, 160, 195, 246, 232, 239, 169, 112, 63, 193, 64, 32, 152, 69, 11, 190, 246, 86]),
1293             CompressedRistretto([234, 141, 77, 203, 181, 225, 250, 74, 171, 62, 15, 118, 78, 212, 150, 19, 131, 14, 188, 238, 194, 244, 141, 138, 166, 162, 83, 122, 228, 201, 19, 26]),
1294             CompressedRistretto([232, 231, 51, 92, 5, 168, 80, 36, 173, 179, 104, 68, 186, 149, 68, 40, 140, 170, 27, 103, 99, 140, 21, 242, 43, 62, 250, 134, 208, 255, 61, 89]),
1295             CompressedRistretto([208, 120, 140, 129, 177, 179, 237, 159, 252, 160, 28, 13, 206, 5, 211, 241, 192, 218, 1, 97, 130, 241, 20, 169, 119, 46, 246, 29, 79, 80, 77, 84]),
1296             CompressedRistretto([202, 11, 236, 145, 58, 12, 181, 157, 209, 6, 213, 88, 75, 147, 11, 119, 191, 139, 47, 142, 33, 36, 153, 193, 223, 183, 178, 8, 205, 120, 248, 110]),
1297             CompressedRistretto([26, 66, 231, 67, 203, 175, 116, 130, 32, 136, 62, 253, 215, 46, 5, 214, 166, 248, 108, 237, 216, 71, 244, 173, 72, 133, 82, 6, 143, 240, 104, 41]),
1298             CompressedRistretto([40, 157, 102, 96, 201, 223, 200, 197, 150, 181, 106, 83, 103, 126, 143, 33, 145, 230, 78, 6, 171, 146, 210, 143, 112, 5, 245, 23, 183, 138, 18, 120]),
1299             CompressedRistretto([220, 37, 27, 203, 239, 196, 176, 131, 37, 66, 188, 243, 185, 250, 113, 23, 167, 211, 154, 243, 168, 215, 54, 171, 159, 36, 195, 81, 13, 150, 43, 43]),
1300             CompressedRistretto([232, 121, 176, 222, 183, 196, 159, 90, 238, 193, 105, 52, 101, 167, 244, 170, 121, 114, 196, 6, 67, 152, 80, 185, 221, 7, 83, 105, 176, 208, 224, 121]),
1301             CompressedRistretto([226, 181, 183, 52, 241, 163, 61, 179, 221, 207, 220, 73, 245, 242, 25, 236, 67, 84, 179, 222, 167, 62, 167, 182, 32, 9, 92, 30, 165, 127, 204, 68]),
1302             CompressedRistretto([226, 119, 16, 242, 200, 139, 240, 87, 11, 222, 92, 146, 156, 243, 46, 119, 65, 59, 1, 248, 92, 183, 50, 175, 87, 40, 206, 53, 208, 220, 148, 13]),
1303             CompressedRistretto([70, 240, 79, 112, 54, 157, 228, 146, 74, 122, 216, 88, 232, 62, 158, 13, 14, 146, 115, 117, 176, 222, 90, 225, 244, 23, 94, 190, 150, 7, 136, 96]),
1304             CompressedRistretto([22, 71, 241, 103, 45, 193, 195, 144, 183, 101, 154, 50, 39, 68, 49, 110, 51, 44, 62, 0, 229, 113, 72, 81, 168, 29, 73, 106, 102, 40, 132, 24]),
1305             CompressedRistretto([196, 133, 107, 11, 130, 105, 74, 33, 204, 171, 133, 221, 174, 193, 241, 36, 38, 179, 196, 107, 219, 185, 181, 253, 228, 47, 155, 42, 231, 73, 41, 78]),
1306             CompressedRistretto([58, 255, 225, 197, 115, 208, 160, 143, 39, 197, 82, 69, 143, 235, 92, 170, 74, 40, 57, 11, 171, 227, 26, 185, 217, 207, 90, 185, 197, 190, 35, 60]),
1307             CompressedRistretto([88, 43, 92, 118, 223, 136, 105, 145, 238, 186, 115, 8, 214, 112, 153, 253, 38, 108, 205, 230, 157, 130, 11, 66, 101, 85, 253, 110, 110, 14, 148, 112]),
1308         ];
1309         for i in 0..16 {
1310             let r_0 = FieldElement::from_bytes(&bytes[i]);
1311             let Q = RistrettoPoint::elligator_ristretto_flavor(&r_0);
1312             assert_eq!(Q.compress(), encoded_images[i]);
1313         }
1314     }
1315 
1316     #[test]
random_roundtrip()1317     fn random_roundtrip() {
1318         let mut rng = OsRng;
1319         let B = &constants::RISTRETTO_BASEPOINT_TABLE;
1320         for _ in 0..100 {
1321             let P = B * &Scalar::random(&mut rng);
1322             let compressed_P = P.compress();
1323             let Q = compressed_P.decompress().unwrap();
1324             assert_eq!(P, Q);
1325         }
1326     }
1327 
1328     #[test]
double_and_compress_1024_random_points()1329     fn double_and_compress_1024_random_points() {
1330         let mut rng = OsRng;
1331 
1332         let points: Vec<RistrettoPoint> =
1333             (0..1024).map(|_| RistrettoPoint::random(&mut rng)).collect();
1334 
1335         let compressed = RistrettoPoint::double_and_compress_batch(&points);
1336 
1337         for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
1338             assert_eq!(*P2_compressed, (P + P).compress());
1339         }
1340     }
1341 
1342     #[test]
vartime_precomputed_vs_nonprecomputed_multiscalar()1343     fn vartime_precomputed_vs_nonprecomputed_multiscalar() {
1344         let mut rng = rand::thread_rng();
1345 
1346         let B = &::constants::RISTRETTO_BASEPOINT_TABLE;
1347 
1348         let static_scalars = (0..128)
1349             .map(|_| Scalar::random(&mut rng))
1350             .collect::<Vec<_>>();
1351 
1352         let dynamic_scalars = (0..128)
1353             .map(|_| Scalar::random(&mut rng))
1354             .collect::<Vec<_>>();
1355 
1356         let check_scalar: Scalar = static_scalars
1357             .iter()
1358             .chain(dynamic_scalars.iter())
1359             .map(|s| s * s)
1360             .sum();
1361 
1362         let static_points = static_scalars.iter().map(|s| s * B).collect::<Vec<_>>();
1363         let dynamic_points = dynamic_scalars.iter().map(|s| s * B).collect::<Vec<_>>();
1364 
1365         let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
1366 
1367         let P = precomputation.vartime_mixed_multiscalar_mul(
1368             &static_scalars,
1369             &dynamic_scalars,
1370             &dynamic_points,
1371         );
1372 
1373         use traits::VartimeMultiscalarMul;
1374         let Q = RistrettoPoint::vartime_multiscalar_mul(
1375             static_scalars.iter().chain(dynamic_scalars.iter()),
1376             static_points.iter().chain(dynamic_points.iter()),
1377         );
1378 
1379         let R = &check_scalar * B;
1380 
1381         assert_eq!(P.compress(), R.compress());
1382         assert_eq!(Q.compress(), R.compress());
1383     }
1384 }
1385